diff --git a/.gitattributes b/.gitattributes index f35660ca423c63db9376cac53ad08dc8afbddfee..67b3a75fef9ccbc3c54a6b447a6a0894169c1b70 100644 --- a/.gitattributes +++ b/.gitattributes @@ -48,3 +48,9 @@ phivenv/Lib/site-packages/numpy/_core/__pycache__/fromnumeric.cpython-39.pyc fil phivenv/Lib/site-packages/numpy/_core/tests/__pycache__/test_umath.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/numpy/_core/__pycache__/_add_newdocs.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text phivenv/Lib/site-packages/numpy.libs/msvcp140-23ebcc0b37c8e3d074511f362feac48b.dll filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text +phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll b/phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll new file mode 100644 index 0000000000000000000000000000000000000000..9b65eee367ba6a562d43062fcef9fbeaccd74ae5 --- /dev/null +++ b/phivenv/Lib/site-packages/numpy.libs/libscipy_openblas64_-caad452230ae4ddb57899b8b3a33c55c.dll @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:44629a7d27806ea076daeae8e829b0cfbdec9e25099561a19af8e5910bd635c5 +size 32816640 diff --git a/phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe b/phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe new file mode 100644 index 0000000000000000000000000000000000000000..884ca00fc6441a59625a4b882c9f8792397c9e5a --- /dev/null +++ b/phivenv/Lib/site-packages/pip/_vendor/distlib/t64-arm.exe @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f1618387a688f162408e7811350a72269076d52bf6d0f09860548d5b57d677ac +size 180736 diff --git a/phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe b/phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe new file mode 100644 index 0000000000000000000000000000000000000000..73256d682a45367868d4551f0097e3f53e349f02 --- /dev/null +++ b/phivenv/Lib/site-packages/pip/_vendor/distlib/t64.exe @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a00a877acefcad45953343ad56a22152f7aaba5fcf2a10215d84169d47fbcd1d +size 105984 diff --git a/phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe b/phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe new file mode 100644 index 0000000000000000000000000000000000000000..10b515951f762a0486e33623cf864dbeddcef908 --- /dev/null +++ b/phivenv/Lib/site-packages/pip/_vendor/distlib/w64-arm.exe @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:43f1ddcd5bbdcf161d6816b79b4889e7f75d2ce12ab4f7bcc77d16003a17cdaf +size 166400 diff --git a/phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc b/phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79604ba514825210f20290c0522aa2c22f044ad0 --- /dev/null +++ b/phivenv/Lib/site-packages/pip/_vendor/idna/__pycache__/uts46data.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a8acb4dd7cd594effc85e8c2b9ac052d6f4fe88744cd4749a8e8b8b93ba88246 +size 151716 diff --git a/phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4f7110d3377306af4097bfd4fcf967b32ba3c69 --- /dev/null +++ b/phivenv/Lib/site-packages/pip/_vendor/pyparsing/__pycache__/core.cpython-39.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:7a6e2c125af98ae3013115aad3c6156dd30340dd0c77863105db036c061ddc8e +size 176641 diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08f282c4804a0e51cf71144dcec3393e993feae3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..adb38e06459d064a777b244ecbb2b8d4883edda4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_checkpointable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36f82d6cda57df268d93ef61b6374a9465d99149 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_composable_state.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..823a199bd6488185c31835a940626592ec9ac2e7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daac39bd3735f80ab281639cc7830fb8bdec0edd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_functional_collectives_impl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/_serialization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_serialization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60f7583107860f1e3865a71162ecc7fc9ebef787 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_serialization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c50adf57cc1555bdc1e8e0006ab64834c1b089e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/_state_dict_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..169da4e1747e2eb0a4179358d968c9c47503fa2f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/argparse_util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91f028bf6a7c810f81b6ebe35390a90e73a8075c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/c10d_logger.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14bd5f5a401e11e2eb089268036e0201de5f201d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/collective_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b2352b82bb74755b4c216448014779f8bf1cb4b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d149d0fff857eb6a3fe8d3ad9b3ebec46218565 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/device_mesh.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1838b528225c2293895ec30407147329b029292f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/launch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9f48130da22b252f60020697cf44b23ed8ae55 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/logging_handlers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bdcea32769be5c573e4ca916771900e450d35aa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/remote_device.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1c00e6c2e354f86455762eac79070a52f4f9f01 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/rendezvous.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39439a9895f19e0d2a27ba80192d7e30c9428158 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/run.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..938d3cf71283a7382d1969f605ec8c16b26dc4b8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_composable/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..daac58f42ae1bc3df5c1df321c0b436ccf9cb080 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_composable/__init__.py @@ -0,0 +1,3 @@ +from .checkpoint_activation import checkpoint +from .contract import _get_registry, contract +from .replicate import replicate diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61b671d17d8d445571a34aacec66b587e7faf19f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7913079bdd3e6a1b72df0487d2fb0cc38fc3443f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/checkpoint_activation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a2e9534b44a2f65be99045fd9d7fb46fe5a69e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/contract.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f71ef67091d3cc53ded8c5af3bdd6875a97d40e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_composable/__pycache__/replicate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py b/phivenv/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py new file mode 100644 index 0000000000000000000000000000000000000000..85daddc9c630bc7f9eaf40e71052e6d1ff143815 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_composable/checkpoint_activation.py @@ -0,0 +1,132 @@ +# mypy: allow-untyped-defs +from collections.abc import Generator +from contextlib import AbstractContextManager, contextmanager, nullcontext +from typing import Any, Optional + +import torch +import torch.nn as nn +from torch.utils.checkpoint import ( + _checkpoint_without_reentrant_generator, + _DEFAULT_DETERMINISM_MODE, +) + +from .contract import _State, contract + + +@contextmanager +def _no_hook(module: nn.Module, user_ctx: Optional[AbstractContextManager] = None): + r""" + Disable hooks installed by checkpoint to avoid unintentional recursion + during backward recomputation. + """ + + with user_ctx if user_ctx else nullcontext(): + orig_enable_hook = checkpoint.state(module).enable_hook + checkpoint.state(module).enable_hook = False + try: + yield + finally: + checkpoint.state(module).enable_hook = orig_enable_hook + + +class _CheckpointState(_State): + enable_hook: bool = False + _ac_generator: Optional[Generator[None, None, None]] + + +@contract(_CheckpointState) +def checkpoint(module: nn.Module, **kwargs) -> nn.Module: + r""" + This is a composable activation checkpointing API. Unlike functional + activation checkpointing APIs, this one does not require changing model + source code. Unlike ``nn.Module`` wrapper activation checkpointing APIs, + this one does not modify model structure or fully-qualified names either. + Under the hood, it registers activation checkpointing logic as pre- and + post-forward hooks. Hence, this API can be easily applied to any model or + sub-modules in the model. + + Args: + module (nn.Module): the target model or sub-module to apply activation + checkpointing. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> model = MyModel() + >>> checkpoint(model.l1) # apply activation checkpointing only to l1 + >>> model(torch.zeros(2, 10)).sum().backward() + + """ + torch._C._log_api_usage_once("torch.distributed.checkpoint") + + use_reentrant = kwargs.pop("use_reentrant", False) + if use_reentrant: + raise NotImplementedError( + "use_reentrant=True is not supported in composable checkpoint. " + "Please use torch.utils.checkpoint.checkpoint instead." + ) + preserve_rng_state = kwargs.pop("preserve_rng_state", True) + user_context_fns = kwargs.pop("context_fn", None) + determinism_check = kwargs.pop("determinism_check", _DEFAULT_DETERMINISM_MODE) + debug = kwargs.pop("debug", False) + + if kwargs: + raise ValueError( + "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs) + ) + + def forward_pre_hook( + module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> None: + if checkpoint.state(module).enable_hook: + + def context_fns(): + if user_context_fns is not None: + ctx1, ctx2 = user_context_fns() + return ctx1, _no_hook(module, ctx2) + else: + return nullcontext(), _no_hook(module) + + gen = _checkpoint_without_reentrant_generator( + module, + preserve_rng_state, + context_fns, + determinism_check, + debug, + *args, + **kwargs, + ) + checkpoint.state(module)._ac_generator = gen + next(gen) + + def forward_hook(module: nn.Module, inputs: tuple[Any, ...], output: Any) -> Any: + if checkpoint.state(module).enable_hook: + try: + gen = checkpoint.state(module)._ac_generator + assert gen is not None + next(gen) + except StopIteration: + pass + else: + raise RuntimeError( + "Expected non-reentrant activation checkpoint generator to be exhausted, but it was not!" + ) + + # Ensure that we no longer hold on to the generator. always_call=True helps ensure we + # clear this even in the case of exception in fwd pass. + checkpoint.state(module)._ac_generator = None + + checkpoint.state(module).enable_hook = True + module.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + module.register_forward_hook(forward_hook, prepend=True, always_call=True) + return module diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/contract.py b/phivenv/Lib/site-packages/torch/distributed/_composable/contract.py new file mode 100644 index 0000000000000000000000000000000000000000..00b3b45592f77c0b522274e441e53647acc29c0f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_composable/contract.py @@ -0,0 +1,248 @@ +# mypy: allow-untyped-defs +import uuid +from collections import OrderedDict +from functools import wraps +from typing import Callable, Generic, Optional, Protocol +from typing_extensions import Concatenate, ParamSpec, TypeVar + +import torch +import torch.nn as nn +from torch.distributed._composable_state import _State +from torch.distributed.utils import _get_root_modules + + +_T = TypeVar("_T", covariant=True) +_P = ParamSpec("_P") + + +def generate_state_key(string="__composable_api_state_key"): + return f"{string}_{str(uuid.uuid4())}" + + +STATE_KEY = generate_state_key() +REGISTRY_KEY = generate_state_key() + + +# TODO: we can add additional info to RegistryItem to share across APIs. E.g., +# we can add args and kwargs here, and then we can detect whether fully_shard +# is combined with reentrant activation checkpointing and error out with a clear +# message. +class RegistryItem: + pass + + +_TState = TypeVar("_TState", bound="_State", covariant=True) +_M = TypeVar("_M", nn.Module, list[nn.Module]) + + +class _ContractFn(Protocol, Generic[_P, _T, _TState]): + def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _T: ... + + def state(self, module: nn.Module) -> _TState: ... + + +def contract( + state_cls: type[_TState] = _State, # type: ignore[assignment] +) -> Callable[ + [Callable[Concatenate[_M, _P], _M]], + _ContractFn[Concatenate[_M, _P], _M, _TState], +]: + r""" + Decorate a function as a composable distributed API, where the first + argument of the function must be an :class:`nn.Module` instance or sequence + of :class:`nn.Module` instances. + + The decorator verifies that the decorated function does not modify + fully-qualified names (FQNs) for parameters, buffers, or modules. The + decorated function can return different module instances than the input + modules; the FQN invariant will be enforced following the input order. + + When a function ``func`` is decorated by ``@contract()``, a + ``.state(module: nn.Module)`` method will be installed to the decorated + function. Then you can retrieve and modify the state on a module by calling + ``func.state(module)``. + + Example:: + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> + >>> class MyModel(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.l1 = nn.Linear(10, 10) + >>> self.l2 = nn.Linear(10, 10) + >>> + >>> def forward(self, x): + >>> return self.l2(self.l1(x)) + >>> + >>> @contract() + >>> def my_feature(module: nn.Module) -> nn.Module: + >>> my_feature.state(module).some_state = "any value" + >>> return module + >>> + >>> model = MyModel() + >>> my_feature(model.l1) + >>> assert my_feature.state(model.l1).some_state == "any value" + >>> my_feature(model.l2) + >>> model(torch.randn(2, 10)).sum().backward() + """ + + # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package + @wraps(state_cls) # type: ignore[arg-type] + def inner( + func: Callable[Concatenate[_M, _P], _M], + ) -> _ContractFn[Concatenate[_M, _P], _M, _TState]: + @wraps(func) + def wrapper( + module: _M, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> _M: + inp_module = module + modules: list[nn.Module] + if isinstance(module, nn.Module): + modules = [module] + else: + # If the user passes a sequence of modules, then we assume that + # we only need to insert the state object on the root modules + # (i.e. those without a parent) among the passed-in modules. + modules = _get_root_modules(list(module)) + state = state_cls() # shared across all modules + registry_item = RegistryItem() # shared across all modules + + # `func` is allowed to return different module instances than the + # input modules as long as FQNs are preserved following the input + # module order + all_orig_named_params: list[dict[str, nn.Parameter]] = [] + all_orig_named_buffers: list[dict[str, torch.Tensor]] = [] + all_orig_named_modules: list[dict[str, nn.Module]] = [] + + for module in modules: + default_all_state: dict[Callable, _State] = OrderedDict() + default_registry: dict[str, RegistryItem] = OrderedDict() + all_state: dict[Callable, _State] = module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, default_all_state + ) + if not isinstance(all_state, dict): + raise AssertionError( + f"Distributed composable API states corrupted: {all_state}" + ) + registry: dict[str, RegistryItem] = module.__dict__.setdefault( # type: ignore[call-overload] + REGISTRY_KEY, default_registry + ) + if not isinstance(registry, dict): + raise AssertionError( + f"Distributed composable API registry corrupted: {registry}" + ) + if func in all_state or func.__name__ in registry: + raise AssertionError( + "Each distinct composable distributed API can only be applied to a " + f"module once. {func.__name__} has already been applied to the " + f"following module:\n{module}" + ) + all_state.setdefault(func, state) + registry.setdefault(func.__name__, registry_item) + + all_orig_named_params.append(OrderedDict(module.named_parameters())) + all_orig_named_buffers.append(OrderedDict(module.named_buffers())) + all_orig_named_modules.append(OrderedDict(module.named_modules())) + + updated = func(inp_module, *args, **kwargs) + if updated is None: + updated = inp_module # type: ignore[assignment] + updated_modules: list[nn.Module] + if isinstance(updated, nn.Module): + updated_modules = [updated] + else: + updated_modules = _get_root_modules(list(inp_module)) # type: ignore[arg-type, call-overload] + + all_new_named_params: list[dict[str, nn.Parameter]] = [] + all_new_named_buffers: list[dict[str, torch.Tensor]] = [] + all_new_named_modules: list[dict[str, nn.Module]] = [] + for module in updated_modules: + all_new_named_params.append(OrderedDict(module.named_parameters())) + all_new_named_buffers.append(OrderedDict(module.named_buffers())) + all_new_named_modules.append(OrderedDict(module.named_modules())) + + num_orig_modules = len(all_orig_named_modules) + num_new_modules = len(all_new_named_modules) + if num_orig_modules != num_new_modules: + raise AssertionError( + f"{func.__name__} should return the same number of modules as input modules" + f"Inputs: {num_orig_modules} modules\n" + f"Outputs: {num_new_modules} modules" + ) + + def check_fqn(orig_fqns: list[str], new_fqns: list[str], check_key: str): + if orig_fqns == new_fqns: + return + + orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns) + orig_only = orig_fqn_set - new_fqn_set + new_only = new_fqn_set - orig_fqn_set + if len(orig_only) or len(new_only): + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify FQNs.\n" + f"FQNs only in original: {orig_only}\n" + f"FQNs only in new: {new_only}" + ) + else: + raise RuntimeError( + f"{check_key}" + "Composable distributed API implementations cannot modify " + "the order of FQNs.\n" + f"Original FQNs: {orig_only}\n" + f"New FQNs: {new_only}" + ) + + for orig_named_params, new_named_params in zip( + all_orig_named_params, all_new_named_params + ): + check_fqn( + list(orig_named_params.keys()), + list(new_named_params.keys()), + "Checking parameters: ", + ) + for orig_named_buffers, new_named_buffers in zip( + all_orig_named_buffers, all_new_named_buffers + ): + check_fqn( + list(orig_named_buffers.keys()), + list(new_named_buffers.keys()), + "Checking buffers: ", + ) + for orig_named_modules, new_named_modules in zip( + all_orig_named_modules, all_new_named_modules + ): + check_fqn( + list(orig_named_modules.keys()), + list(new_named_modules.keys()), + "Checking modules: ", + ) + + # TODO: verify that installed distributed paradigms are compatible with + # each other. + + return updated + + def get_state(module: nn.Module) -> _State: + return module.__dict__.setdefault( # type: ignore[call-overload] + STATE_KEY, + {}, # TODO(@yhcharles): this is a temporary fix, need a better way + ).get(func) # type: ignore[call-overload] + + wrapper.state = get_state # type: ignore[attr-defined] + + return wrapper # type: ignore[return-value] + + return inner # type: ignore[return-value] + + +def _get_registry(module: nn.Module) -> Optional[dict[str, RegistryItem]]: + r""" + Get an ``OrderedDict`` of composable APIs that have been applied to the + ``module``, indexed by the API name. If no API has been applied, then this + returns ``None``. + """ + return getattr(module, REGISTRY_KEY, None) diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a227c5cfc5646a961a4fd74892f3ee0541b57503 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__init__.py @@ -0,0 +1,3 @@ +from torch.distributed.fsdp import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy + +from .fully_shard import FSDPModule, fully_shard, register_fsdp_forward_method diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..789a8c6cdbaa673b15d8e3bf7bf11c93e5963f72 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db3eda8cb48170551a2023de46971f1e02788993 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/__pycache__/fully_shard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6518227edb3a26366805ba4bc210cd6178144c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_composable/fsdp/fully_shard.py @@ -0,0 +1,8 @@ +# TODO: For backward compatibility, we are importing the public objects +# originally from this file. +from torch.distributed.fsdp import ( # noqa: F401 + FSDPModule, + fully_shard, + register_fsdp_forward_method, + UnshardHandle, +) diff --git a/phivenv/Lib/site-packages/torch/distributed/_composable/replicate.py b/phivenv/Lib/site-packages/torch/distributed/_composable/replicate.py new file mode 100644 index 0000000000000000000000000000000000000000..b881276bf2537048b6a1ede7a187155fc3600c91 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_composable/replicate.py @@ -0,0 +1,256 @@ +# mypy: allow-untyped-defs +import weakref +from collections.abc import Iterable +from typing import Any, NoReturn, Optional + +import torch +import torch.nn as nn +from torch.distributed._composable_state import _State +from torch.nn.parallel import DistributedDataParallel + +from .contract import _get_registry, contract + + +_ROOT_MODULE_PREFIX = "" + + +class _ReplicateState(_State): + _ddp_weakref: weakref.ref + + def __init__(self) -> None: + super().__init__() + self.module: nn.Module = nn.ParameterList() + self.has_initialized: bool = False + self._param_list: nn.ParameterList = nn.ParameterList() + # TODO(@fegin): this variable is originally create for testing, we + # should remove this if possible. + self._orig_module = self.module + self._param_names: list[str] = [] + self._no_sync: bool = False + self._init_args: Optional[tuple[Any, ...]] = None + self._init_kwargs: dict[str, Any] = {} + self._comm_hook_args: list[Any] = [] + + def _collect_params( + self, + module: nn.Module, + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + prefix: str = _ROOT_MODULE_PREFIX, + ) -> None: + # skip if managed by fully_sharded API + if _is_fully_sharded(module): + return + + # if a module is ignored, all descendants of the module are ignored. + if module in ignored_modules: + return + + recurse_prefix = ( + f"{prefix}." if prefix != _ROOT_MODULE_PREFIX else _ROOT_MODULE_PREFIX + ) + + for n, p in module.named_parameters(recurse=False): + if p not in ignored_params: + self._param_list.append(p) + self._param_names.append(f"{recurse_prefix}{n}") + + for name, child_module in module.named_children(): + self._collect_params( + child_module, + ignored_modules, + ignored_params, + prefix=f"{recurse_prefix}{name}", + ) + + def lazy_init(self) -> None: + @torch._disable_dynamo(recursive=True) + def _lazy_init(): + assert self._init_args is not None + self.init(*self._init_args, **self._init_kwargs) + self.register_comm_hook() + self._init_args = () + self._init_kwargs = {} + + _lazy_init() + + def init( + self, + module: nn.Module, + ignored_modules: set[nn.Module], + **kwargs, + ) -> None: + if self.has_initialized: + return + + self.has_initialized = True + self.module = module + ignored_params = {p for m in ignored_modules for p in m.parameters()} + for submodule in module.modules(): + if _is_fully_sharded(submodule): + ignored_params.update(submodule.parameters()) + from torch.distributed.tensor.parallel.ddp import _localize_dtensor + + _localize_dtensor(module, ignored_params=ignored_params) + self._collect_params(module, ignored_modules, ignored_params) + + if "device_id" in kwargs: + # replicate() supports a small usability enhancement where + # user can pass in device_id as a Union[int, torch.device] even for + # CPU devices so users don't have to change code for CPU/GPU runs. + # We derive the right device_ids to feed into DDP to support this. + if kwargs["device_id"] is not None: + device_id = kwargs["device_id"] + # Convert to device_ids that DDP expects. + if isinstance(device_id, torch.device) and device_id.type == "cpu": + # CPU modules receive device_ids None + kwargs["device_ids"] = None + else: + # GPU modules expect device_ids=[cuda_device] + kwargs["device_ids"] = [device_id] + else: + kwargs["device_ids"] = None + kwargs.pop("device_id") + + self._ddp = DistributedDataParallel(self._param_list, **kwargs) + # Weakref to the DDP instance is currently only used for testing. + replicate.state(self.module)._ddp_weakref = weakref.ref(self._ddp) + + def register_comm_hook(self) -> None: + for comm_args, comm_kwargs in self._comm_hook_args: + self._ddp.register_comm_hook(*comm_args, **comm_kwargs) + self._comm_hook_args.clear() + + def record_init_args(self, *args, **kwargs) -> None: + self._init_args = args + self._init_kwargs = kwargs + + def forward_pre_hook( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> Any: + if self._init_args or self._init_kwargs: + self.lazy_init() + self._ddp.require_backward_grad_sync = not self._no_sync + return self._ddp._pre_forward(*args, **kwargs) + + def forward_post_hook( + self, + module: nn.Module, + input: tuple[torch.Tensor], + output: torch.Tensor, + ) -> torch.Tensor: + return self._ddp._post_forward(output) + + +def unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: + raise AssertionError( + "DDP does not support deepcopy. Please use state dict for serialization." + ) + + +# Follow the same pattern as FSDP/fully_shard +class DDP: + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the DDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `DDP<...>` class + # and index 1 is the `DDP` class itself + orig_cls = cls.__mro__[2] + return orig_cls.__new__(orig_cls, *args, **kwargs) + + def set_requires_gradient_sync(self, requires_gradient_sync: bool) -> None: + """ + Sets if the module should sync gradients. This can be used to implement + gradient accumulation without communication. + + Args: + requires_gradient_sync (bool): Whether to reduce gradients for the + module's parameters. + """ + replicate.state(self)._no_sync = not requires_gradient_sync # type: ignore[arg-type] + + def register_comm_hook(self, *args, **kwargs) -> None: + replicate.state(self)._comm_hook_args.append((args, kwargs)) # type: ignore[arg-type] + + +@contract(state_cls=_ReplicateState) +def replicate( + module: nn.Module, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + **kwargs, +) -> nn.Module: + r"""Replicates a module + + Args: + module (torch.nn.Module): module to replicate + + Example:: + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> module = nn.Linear(3, 3) + >>> replicate(module) + """ + torch._C._log_api_usage_once("torch.distributed.replicate") + + # TODO(fegin): using kwargs is not a good idea if we would like to make + # replicate a formal API to replace DDP. + if "device_id" in kwargs: + if not isinstance(kwargs["device_id"], (int, torch.device)): + raise RuntimeError( + "Expected device_id to be int or torch.device, " + f"but got {type(kwargs['device_id'])}" + ) + + if _is_fully_sharded(module): + raise RuntimeError( + "Cannot apply `replicate()` on a Module already managed by `fully_shard`" + ) + + if ignored_modules is None: + ignored_modules = {} + else: + ignored_modules = set(ignored_modules) + + state = replicate.state(module) + module.register_forward_pre_hook(state.forward_pre_hook, with_kwargs=True) + device_mesh = kwargs.get("device_mesh", None) + if device_mesh is not None: + from torch.distributed.device_mesh import _mesh_resources + + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh != device_mesh: + # TODO: This is a temporary work around to enable DDP + TP. + # We should do the logic in DDP so that the 2D implementation is + # sound and the state_dict works out of the box. + # + # This won't conflict with what is done in DDP class as the module + # replicate is going to pass is NOT the original module. + from torch.distributed.tensor.parallel.ddp import ( + _localize_dtensor, + _reconstruct_dtensor, + ) + + module.register_forward_pre_hook(_reconstruct_dtensor) + module.register_forward_hook(_localize_dtensor) + + module.register_forward_hook(state.forward_post_hook) # type: ignore[arg-type] + + state.record_init_args(module, ignored_modules, **kwargs) + + # Place DDP leftmost for highest priority in the method resolution order + cls = module.__class__ + dct = {"__deepcopy__": unimplemented_deepcopy} + new_cls = type(f"DDP{cls.__name__}", (DDP, cls), dct) + module.__class__ = new_cls + return module + + +def _is_fully_sharded(module: nn.Module) -> bool: + r"""Check if module is marked with fully_shard.""" + registry = _get_registry(module) + if registry is None: + return False + return "fully_shard" in registry diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23542f81cc832309d7dee068cc91289df5e58a25 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/__init__.py @@ -0,0 +1 @@ +from .api import _shard_tensor, load_with_process_group, shard_module, shard_parameter diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c130f241288a34a72a6b94c16f26b417046941f2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..538f5b504cb68063f20b2a6cf63c9f3c446270dc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca0bdbce7a903e61d73a6564c78b389c6dafe4a2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8917cbc6f84e668add55ce9b2c65046e1a87e6d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/common_op_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a105ed129ac4b1d7ab7820381bb747c414a7f659 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/metadata.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..91d6ea60bd5dbbc0b79fe7f80dcec4e895c54f1e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/op_registry_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e14afcfd46fce7dbe4779082d86d2cfdb5368969 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/__pycache__/sharder.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/_utils.py b/phivenv/Lib/site-packages/torch/distributed/_shard/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b9a8e5719828523763206aa7ba6ad483e857eca8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/_utils.py @@ -0,0 +1,32 @@ +from collections.abc import Sequence + +import torch +from torch.distributed._shard.metadata import ShardMetadata + + +DEPRECATE_MSG = "Please use DTensor instead and we are deprecating ShardedTensor." + + +def narrow_tensor_by_index( + tensor: torch.Tensor, + offsets: Sequence[int], + sizes: Sequence[int], +) -> torch.Tensor: + """ + Narrow the tensor according to ``offsets`` and ``sizes``. + """ + narrowed_tensor = tensor + for idx, (offset, size) in enumerate(zip(offsets, sizes)): + if size < tensor.size(idx): + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrowed_tensor.narrow(idx, offset, size) + return narrowed_tensor + + +def narrow_tensor(tensor: torch.Tensor, metadata: ShardMetadata) -> torch.Tensor: + """ + Narrow the tensor according to the metadata + """ + return narrow_tensor_by_index(tensor, metadata.shard_offsets, metadata.shard_sizes) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/api.py b/phivenv/Lib/site-packages/torch/distributed/_shard/api.py new file mode 100644 index 0000000000000000000000000000000000000000..96c91391fc18b372d828701b7db373711169f275 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/api.py @@ -0,0 +1,306 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager +from typing import Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed import distributed_c10d +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .sharder import Sharder +from .sharding_plan import ShardingPlan +from .sharding_spec import ChunkShardingSpec, ShardingSpec + + +def _shard_tensor( + tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None +) -> ShardedTensor: + """ + Given a :class:`torch.Tensor`, it shards that tensor according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + Args: + tensor (:class:`torch.Tensor`): Tensor needs to be sharded. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + + .. warning:: + Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + if not tensor.is_contiguous(): + raise ValueError("input tensor is not a contiguous Tensor") + + pg = ( + process_group + if process_group is not None + else distributed_c10d._get_default_group() + ) + world_size = dist.get_world_size(pg) + current_rank = dist.get_rank(pg) + + # Validate src_rank and sharding_spec are same across all ranks. + gathered_list = [None] * world_size + dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg) + + for idx, entry in enumerate(gathered_list): + if src_rank != entry[0]: # type: ignore[index] + raise ValueError( + f"src_rank={src_rank} on rank: {current_rank} does not " # type: ignore[index] + f"match with src_rank={entry[0]} on rank: {idx}" # type: ignore[index] + ) + if sharding_spec != entry[1]: # type: ignore[index] + raise ValueError( + f"sharding_spec={sharding_spec} on rank: {current_rank} does not " # type: ignore[index] + f"match with sharding_spec={entry[1]} on rank: {idx}" # type: ignore[index] + ) + + st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=pg) + + return st + + +def shard_parameter( + module: torch.nn.Module, + param_name: str, + sharding_spec: ShardingSpec, + src_rank=0, + process_group=None, +): + """ + Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that + module, it shards that parameter according to the provided + ``sharding_spec``. ``src_rank`` denotes the source rank which would be + used as the ground truth of the data which would be scattered as shards + across the rest of the ranks. + + This method replaces ``module.param_name`` with a + :class:`torch.distributed._sharded_tensor.ShardedTensor` + + Args: + module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded. + param_name (str): Name of the parameter of ``module`` that needs to be sharded. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + .. warning:: + Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is + currently supported as the ``sharding_spec``. + """ + # Perform some validation first. + if not hasattr(module, param_name): + raise AttributeError(f"{module._get_name()} has no attribute `{param_name}`") + + tensor = getattr(module, param_name) + if not isinstance(tensor, torch.Tensor): + raise ValueError( + f"Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}" + ) + + if not tensor.is_contiguous(): + raise ValueError(f"param: {param_name} is not a contiguous Tensor") + + st = _shard_tensor(tensor, sharding_spec, src_rank, process_group) + + # Replace param with ShardedTensor. + module.register_parameter(param_name, nn.Parameter(st)) + + +# Tracks the current process group in the load context manager. +_CURRENT_PROCESS_GROUP: Optional[dist.ProcessGroup] = None + + +@contextmanager +def load_with_process_group(process_group): + """ + Context manager to set the process group with which to load a ShardedTensor. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is not None: + raise RuntimeError( + 'ProcessGroup already set by previous "load_with_process_group" ' + "context manager" + ) + _CURRENT_PROCESS_GROUP = process_group + try: + yield process_group + finally: + _CURRENT_PROCESS_GROUP = None + + +def _get_current_process_group(): + """ + Retrieves the current process group set by ``load_with_process_group``. + If not set, it just returns the default group. + """ + global _CURRENT_PROCESS_GROUP + if _CURRENT_PROCESS_GROUP is None: + return distributed_c10d._get_default_group() + else: + return _CURRENT_PROCESS_GROUP + + +def _reshard_output( + module: torch.nn.Module, resharding_spec: ShardingSpec +) -> torch.nn.Module: + """ + Hook a module with output resharding in the forward pass according + to the given ``resharding_spec``. + + Args: + module (:class:`torch.nn.Module`): Module whose output needs to be resharded. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how the output of the module will be resharded. + + Returns: + A :class:`torch.nn.Module` object with reshard API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + return output.reshard(resharding_spec) + return output + + module.register_forward_hook(hook_func) + return module + + +def _collect_local_shard(module: torch.nn.Module) -> torch.nn.Module: + """ + Hook a module with local shards collection in the forward pass. + + This API is typically used to convert a sharded representation back to data parallel + representation. In particular, it returns the local tensor for this Shard. If the + size along the sharding dimension for the local tensor is 1, this dimension is removed + from the final result. For example a [4, 16] ShardedTensor across 4 ranks is typically + a local Tensor of size [16] across each rank and not [1, 16] across each rank. + + Args: + module (:class:`torch.nn.Module`): Module whose output is ShardedTensor and the + local tensor value needs to be returned. + + Returns: + A :class:`torch.nn.Module` object with collection API hooked. + """ + + def hook_func(_module, _input, output): + if isinstance(output, ShardedTensor): + local_tensor = output.local_tensor() + # Squeeze the # of dimensions manually, only applicable to ChunkShardingSpec + sharding_spec = output._sharding_spec + if ( + isinstance(sharding_spec, ChunkShardingSpec) + and local_tensor.size(sharding_spec.dim) == 1 # type: ignore[attr-defined, arg-type] + ): + local_tensor = local_tensor.squeeze( + output._sharding_spec.dim # type: ignore[attr-defined] + ) + return local_tensor + + module.register_forward_hook(hook_func) + return module + + +def shard_module(module: nn.Module, plan: ShardingPlan, src_rank=0, process_group=None): + """ + Shards a given module according to the provided sharding `plan`. This method + first shards all the parameters according to the given sharding `plan`. Then if + `output_plan` and `return_local_tensor` are specified in the sharding `plan`, it + will tag the output of modules according `output_plan`, convert the module's + output back to data parallel according to `return_local_tensor`. + + Needs to be called on all ranks in an SPMD fashion. + + Args: + module (:class:`torch.nn.Module`): The module to apply sharding to + plan (:class:`torch.distributed._shard.sharding_plan.ShardingPlan`): + The ShardingPlan which specified param name to ShardingSpec to apply to + each parameter. + + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the module that would be sharded and scattered across the rest + of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + """ + # record Sharder paths for sanity check on the plan to ensure items in the plan + # does not conflict with the submodule tree that the Sharder is working with + sharder_paths = [] + for name, spec in plan.plan.items(): + if isinstance(spec, Sharder): + sharder_paths.append(name) + + # shard the parameter according to the ShardingPlan + for name, spec in plan.plan.items(): + if isinstance(spec, ShardingSpec): + # if found a sharding spec, try to shard the parameter + module_path, _, param_name = name.rpartition(".") + + for sharder_path in sharder_paths: + if module_path.startswith(sharder_path): + raise RuntimeError( + f"ShardingPlan is in-valid, trying to shard a parameter: {name}," + f" but there's already a Sharder entry for module {sharder_path}," + f" parameter sharding should not conflict with the submodule tree" + f" that a Sharder is working with!" + ) + + mod = module.get_submodule(module_path) + shard_parameter( + mod, param_name, spec, src_rank=src_rank, process_group=process_group + ) + elif isinstance(spec, Sharder): + parent_mod_path, _, _mod_name = name.rpartition(".") + if name == "": + raise KeyError("Module path must not be empty for custom sharder!") + mod = module.get_submodule(name) + parent_mod = module.get_submodule(parent_mod_path) + sharded_mod = spec.shard(mod) + # swap this submodule with the sharded module + parent_mod.mod_name = sharded_mod + else: + raise TypeError( + f"Only `ShardingSpec` and `Sharder` are supported to shard '{name}'" + ) + + # reshard output if there's an entry in `reshard_output` for this module + if plan.output_plan is not None: + for module_path, output_spec in plan.output_plan.items(): + if isinstance(output_spec, ShardingSpec): + mod = module.get_submodule(module_path) + _reshard_output(mod, output_spec) + else: + raise TypeError( + f"Only `ShardingSpec` is supported as output_plan for '{module_path}'" + ) + # convert the output back to data parallel for the modules appears in + # `return_local_tensor` of the plan, we will call `_collect_local_shard` + # to collect the local tensor for output of modules + if plan.return_local_tensor is not None: + for module_path in plan.return_local_tensor: + mod = module.get_submodule(module_path) + _collect_local_shard(mod) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9d1df78c9bd63547b73461e98ca8db2e25d7b197 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__init__.py @@ -0,0 +1,19 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed.checkpoint` package. +import sys +import warnings + +import torch +from torch.distributed.checkpoint import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._shard.checkpoint` will be deprecated, " + "use `torch.distributed.checkpoint` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules["torch.distributed._shard.checkpoint"] = torch.distributed.checkpoint diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fc58f6b7dd249545964b23625c9a7be8b98e07c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/checkpoint/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/common_op_utils.py b/phivenv/Lib/site-packages/torch/distributed/_shard/common_op_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cfbc822129b4778076ccfcf20df4fec398ed341c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/common_op_utils.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch.utils import _pytree as pytree + + +def _basic_validation(op, args=(), kwargs=None): + """ + Common validation across all ops go in here. + """ + from torch.distributed._shard.sharded_tensor import ShardedTensor + + if len(args) == 0 and (kwargs is None or len(kwargs) == 0): + raise ValueError(f" No input for '{op.__name__}'!") + + # Validate types + has_distributed_tensor = False + + def is_distributed_tensor(e): + nonlocal has_distributed_tensor + if isinstance(e, ShardedTensor): + has_distributed_tensor = True + + pytree.tree_map_(is_distributed_tensor, args) + pytree.tree_map_(is_distributed_tensor, kwargs) + + if not has_distributed_tensor: + raise TypeError( + f"torch function '{op.__name__}', with args: {args} and " + f"kwargs: {kwargs} are called without any distributed tensor!" + ) + + # Validate all distributed tensors use the same PG. + cur_pg: Optional[torch.distributed.ProcessGroup] = None + + def validate_pg(e): + nonlocal cur_pg + if isinstance(e, ShardedTensor): + if cur_pg is not None and e._process_group is not cur_pg: + raise RuntimeError( + "All distributed tensors should use the " + "same ProcessGroup if used together in an op." + ) + cur_pg = e._process_group + + pytree.tree_map_(validate_pg, args) + pytree.tree_map_(validate_pg, kwargs) + + +def _register_default_op(op, decorator): + @decorator(op) + def tensor_default_op(types, args=(), kwargs=None, pg=None): + """ + Handles ``__torch_function__`` dispatch for the default tensor ops that + behave the same as ``torch.Tensor`` such as ``torch.Tensor.shape`` or + ``torch.Tensor.dtype``. We simply lower to the real op call with + DisableTorchFunctionSubclass context like ``torch.Tensor.__torch_function__`` + to avoid recursions. + """ + if kwargs is None: + kwargs = {} + + with torch._C.DisableTorchFunctionSubclass(): + return op(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/metadata.py b/phivenv/Lib/site-packages/torch/distributed/_shard/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..650f649beb7b3061b82d88d4c6e6aac0de403a39 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/metadata.py @@ -0,0 +1,64 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from functools import reduce +from typing import Optional, Union + +from torch.distributed.remote_device import _remote_device + + +@dataclass +class ShardMetadata: + """ + Represents a shard of the overall Tensor including its + offsets, lengths and device placement. + + Args: + shard_offsets(List[int]): Offsets in the original tensor indicating + the start offsets for this shard. Should have the same rank as + the original tensor. + shard_sizes(List[int]): Integers indicating the size of each + dimension for this shard. Should have the same rank as the + original tensor. + placement(:class:`torch.distributed._remote_device`): + Specifies the placement of this shard. + """ + + __slots__ = ["shard_offsets", "shard_sizes", "placement"] + + shard_offsets: list[int] + shard_sizes: list[int] + placement: Optional[_remote_device] + + def __init__( + self, + shard_offsets: list[int], + shard_sizes: list[int], + placement: Optional[Union[str, _remote_device]] = None, + ): + self.shard_offsets = shard_offsets + self.shard_sizes = shard_sizes + if isinstance(placement, str): + self.placement = _remote_device(placement) + else: + self.placement = placement + if len(self.shard_offsets) != len(self.shard_sizes): + raise ValueError( + f"shard_offsets and shard_sizes should have " + f"the same number of elements, found {len(self.shard_offsets)} " + f"and {self.shard_sizes} respectively" + ) + + for i in range(len(self.shard_offsets)): + if self.shard_offsets[i] < 0: + raise ValueError("shard_offsets should be >=0") + if self.shard_sizes[i] < 0: + raise ValueError("shard_sizes should be >= 0") + + def __hash__(self): + def _hash_reduce(a, b): + return (a << 8) + hash(b) + + res = reduce(_hash_reduce, self.shard_offsets, 37) + res = reduce(_hash_reduce, self.shard_sizes, res) + res = _hash_reduce(res, self.placement) + return res diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/op_registry_utils.py b/phivenv/Lib/site-packages/torch/distributed/_shard/op_registry_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..669af1cbbda88715825274c28afa4aba13585b3d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/op_registry_utils.py @@ -0,0 +1,41 @@ +# mypy: allow-untyped-defs +import functools +from inspect import signature + +from .common_op_utils import _basic_validation + + +""" +Common utilities to register ops on ShardedTensor +and PartialTensor. +""" + + +def _register_op(op, func, op_table): + """ + Performs basic validation and registers the provided op in the given + op_table. + """ + if len(signature(func).parameters) != 4: + raise TypeError( + f"Custom sharded op function expects signature: " + f"(types, args, kwargs, process_group), but received " + f"signature: {signature(func)}" + ) + + op_table[op] = func + + +def _decorator_func(wrapped_func, op, op_table): + """ + Decorator function to register the given ``op`` in the provided + ``op_table`` + """ + + @functools.wraps(wrapped_func) + def wrapper(types, args, kwargs, process_group): + _basic_validation(op, args, kwargs) + return wrapped_func(types, args, kwargs, process_group) + + _register_op(op, wrapper, op_table) + return wrapper diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..010b2e8177476a62103675570f3930bcebd2252f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__init__.py @@ -0,0 +1,53 @@ +from collections.abc import Iterator +from typing import Union + +import torch.nn as nn +from torch.distributed._shard.sharded_tensor import ShardedTensor + +from .api import ShardedOptimizer + + +def named_params_with_sharded_tensor( + module: nn.Module, + prefix: str = "", + recurse: bool = True, +) -> Iterator[tuple[str, Union[nn.Parameter, ShardedTensor]]]: + r"""Returns an iterator over module parameters (together with the + ShardedTensor parameters), yielding both the name of the parameter + as well as the parameter itself. This is typically passed to a + :class:torch.distributed._shard.sharded_optim.ShardedOptimizer + + Args: + prefix (str): prefix to prepend to all parameter names. + recurse (bool): if True, then yields parameters of this module + and all submodules. Otherwise, yields only parameters that + are direct members of this module. + + Yields: + (str, Union[Tensor, ShardedTensor]): Tuple containing + the name and parameter (or ShardedTensor parameter) + + Example:: + + >>> # xdoctest: +SKIP + >>> model = torch.nn.Linear(*linear_size) + >>> shard_parameter(model, "weight", spec) + >>> for name, param in named_params_with_sharded_tensor(model): + >>> if name in ['weight']: + >>> print(param.size()) + + """ + modules = module.named_modules(prefix=prefix) if recurse else [(prefix, module)] + + memo = set() + for mod_prefix, mod in modules: + # find all sharded tensor params + for name, val in vars(mod).items(): + if isinstance(val, ShardedTensor) and val not in memo: + memo.add(val) + name = mod_prefix + ("." if mod_prefix else "") + name + yield name, val + + # find all nn.Parameters + for name, val in module.named_parameters(): + yield name, val diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7fe58e2cf331c910b2f41c340eeeaa1edf3e288 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f1d3f189eb68ac64d074faca1054bf68e1450f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/api.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/api.py new file mode 100644 index 0000000000000000000000000000000000000000..25b9e4af7941fd955afdd0c6effdfae4a6e152ed --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_optim/api.py @@ -0,0 +1,102 @@ +# mypy: allow-untyped-defs +from collections.abc import Mapping +from typing import Any, Union + +import torch.optim as optim +from torch import Tensor +from torch.distributed._shard.sharded_tensor import ShardedTensor + + +class ShardedOptimizer(optim.Optimizer): + def __init__( + self, + named_params: Mapping[str, Union[Tensor, ShardedTensor]], + optimizer_class, + *optimizer_args, + **optimizer_kwargs, + ): + """ + ShardedOptimizer collects all tensors and local shard tensors of + ShardedTensor, then use these tensors as ``params`` for optimizers + + Args: + named_params (Dict[str, Union[Tensor, ShardedTensor]]) : a Dict + of parameters, where key is the parameter key, value is either + Tensor or ShardedTensor parameter. + optimizer_class (torch.optim.Optimizer): the Optimizer to use + locally, i.e. torch.optim.SGD, torch.optim.Adagrad, etc. + *optimizer_args: the arguments to initialize the optimizer. + **optimizer_kwargs: the key-word arguments to initialize the optimizer. + + """ + tensors: list[Tensor] = [] + for value in named_params.values(): + if isinstance(value, ShardedTensor): + tensors.extend( + local_shard.tensor for local_shard in value.local_shards() + ) + else: + tensors.append(value) + + self.named_params = named_params + self._optim = optimizer_class(tensors, *optimizer_args, **optimizer_kwargs) + self.param_groups = self._optim.param_groups + self.state = self._optim.state + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + r"""Resets the gradients of all optimized :class:`torch.Tensor` s. + + Args: + set_to_none (bool): instead of setting to zero, set the grads to None. + This will in general have lower memory footprint, and can modestly improve performance. + However, it changes certain behaviors. For example: + 1. When the user tries to access a gradient and perform manual ops on it, + a None attribute or a Tensor full of 0s will behave differently. + 2. If the user requests ``zero_grad(set_to_none=True)`` followed by a backward pass, ``.grad``\ s + are guaranteed to be None for params that did not receive a gradient. + 3. ``torch.optim`` optimizers have a different behavior if the gradient is 0 or None + (in one case it does the step with a gradient of 0 and in the other it skips + the step altogether). + """ + self._optim.zero_grad(set_to_none) + + def step(self, closure=None): + r"""Performs a single optimization step (parameter update). + + Args: + closure (Callable): A closure that reevaluates the model and + returns the loss. Optional for most optimizers. + + .. note:: + Unless otherwise specified, this function should not modify the + ``.grad`` field of the parameters. + """ + self._optim.step(closure) + + def state_dict(self) -> dict[str, Any]: + """ + Returned state and param_groups will contain parameter keys + instead of parameter indices like torch.optim.Optimizer. + This allows for advanced functionality like optimizer re-sharding to be implemented. + """ + # TODO: implement state_dict + raise NotImplementedError("ShardedOptimizer state_dict not implemented yet!") + + def load_state_dict(self, state_dict: Mapping[str, Any]): + r"""Loads the ShardedOptimizer state. + + Args: + state_dict (dict): ShardedOptimizer state. Should be an object returned + from a call to :meth:`state_dict`. + """ + # TODO: implement load_state_dict + raise NotImplementedError( + "ShardedOptimizer load_state_dict not implemented yet!" + ) + + def add_param_group(self, param_group: Any): + r"""Add a new param group""" + # TODO: implement add_param_group + raise NotImplementedError( + "ShardedOptimizer add_param_group not implemented yet!" + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..84bf2c5dd320c4f679b7ad4c7af4e203a17e7094 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__init__.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +import functools +from typing import TYPE_CHECKING + +import torch +from torch.distributed._shard.op_registry_utils import _decorator_func + +from .api import ( + _CUSTOM_SHARDED_OPS, + _SHARDED_OPS, + Shard, + ShardedTensor, + ShardedTensorBase, + ShardedTensorMetadata, + TensorProperties, +) +from .metadata import ShardMetadata # noqa: F401 + + +if TYPE_CHECKING: + from torch.distributed._shard.sharding_spec import ShardingSpec +else: + ShardingSpec = "ShardingSpec" + + +def empty( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with uninitialized data. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def ones( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` with the scalar value 1. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=1, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def zeros( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Returns a :class:`ShardedTensor` filled with the scalar value 0. + Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + return full( + sharding_spec, + size, + fill_value=0, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + + +def full( + sharding_spec: ShardingSpec, + size, + fill_value, + *, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with fill_value. The tensor's dtype + is inferred from fill_value. If dtype is specified, it will override the + inferred type from fill_value. Needs to be called on all ranks in an SPMD fashion. + Args: + sharding_spec (:class:`torch.distributed._sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + fill_value (Scalar) - the value to fill the output tensor with. + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.constant_(sharded_tensor, fill_value) # type: ignore[arg-type] + return sharded_tensor + + +def rand( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + on the interval :math:`[0, 1)`. The shape of the tensor is defined by the + variable argument `size`. Needs to be called on all ranks in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.uniform_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def randn( + sharding_spec: ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, +) -> ShardedTensor: + """ + Creates a :class:`ShardedTensor` filled with random numbers from a uniform distribution + with mean `0` and variance `1` (also called standard normal distribution). The shape + of the tensor is defined by the variable argument `size`. Needs to be called on all ranks + in an SPMD fashion. + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a list, tuple, or `torch.Size` of integers defining the shape of the + output tensor. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object on each rank + """ + sharded_tensor = ShardedTensor( + sharding_spec, + *size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + pin_memory=pin_memory, + memory_format=memory_format, + process_group=process_group, + init_rrefs=init_rrefs, + ) + torch.nn.init.normal_(sharded_tensor, 0, 1) # type: ignore[arg-type] + return sharded_tensor + + +def init_from_local_shards( + local_shards: list[Shard], *global_size, process_group=None, init_rrefs=False +) -> ShardedTensor: + """ + Creates an :class:`ShardedTensor` from local shards and the global metadata. + Needs to be called on all ranks in an SPMD fashion. + + Args: + local_shards (List[:class `torch.distributed._shard.sharded_tensor.Shard`]): A list + of shards that represent the local shards on this rank. + global_size (int...): a list, tuple, or `torch.Size` of integers defining the + shape of the overall sharded tensor. + + Keyword args: + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` object handle on this rank + + + Examples: + Suppose we want construct a sharded tensor on two ranks, global size = (10, 5), + each shard have a (5, 5) local tensor, we can do it like below: + + on rank 0: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[0, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:0/cuda:0" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + + on rank 1: + >>> # xdoctest: +SKIP("not distributed") + >>> local_shard_metadata = ShardMetadata( + >>> shard_offsets=[5, 0], + >>> shard_lengths=[5, 5], + >>> placement="rank:1/cuda:1" + >>> ) + >>> local_shards = [Shard(torch.randn(5, 5), local_shard_metadata)] + >>> sharded_tensor = init_from_local_shards(local_shards, [10, 5]) + """ + return ShardedTensor._init_from_local_shards( + local_shards, *global_size, process_group=process_group, init_rrefs=init_rrefs + ) + + +def state_dict_hook(module, destination, prefix, local_metadata): + """ + Hook to add ShardedTensor to Module's ``state_dict``. Needs to be + registered to the Module using + :meth:`torch.nn.Module._register_state_dict_hook`. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name, attr in submodule.__dict__.items(): + if isinstance(attr, ShardedTensor): + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + destination[key] = attr + + +def pre_load_state_dict_hook( + module, + state_dict, + prefix, + local_metadata, + strict, + missing_keys, + unexpected_keys, + error_msgs, +): + """ + Pre-load state dict hook to add ShardedTensor to the module. + """ + for submodule_name, submodule in module.named_modules(): + for attr_name in submodule.__dict__.keys(): + mod_prefix = prefix + submodule_name + key = mod_prefix + ("." if mod_prefix else "") + attr_name + if key in state_dict: + if isinstance(state_dict[key], ShardedTensor): + setattr(submodule, attr_name, state_dict[key]) + + +def custom_sharded_op_impl(func): + """ + Provides a way for users to write their own custom sharded operator. This + can be used to override existing ShardedTensor operators or write a new + one not supported by ShardedTensor. If the operator in question is covered + by ``__torch_function__`` dispatch and has a ShardedTensor as any of its + parameters, the function provided will be invoked for that operator. + + Example:: + >>> # xdoctest: +SKIP + >>> @custom_sharded_op_impl(torch.nn.functional.linear) + >>> def my_custom_sharded_linear(types, args, kwargs, process_group): + >>> ... + >>> # xdoctest: +SKIP("Undefined variables") + >>> input = torch.rand(10, 32) + >>> weight = sharded_tensor.rand(32, 16) + >>> bias = torch.rand(16) + >>> # This will call 'my_custom_sharded_linear' + >>> torch.nn.functional.linear(input, weight, bias) + + The types, args and kwargs parameters are the same parameters that are + passed to ``__torch_function__`` dispatch API + (https://pytorch.org/docs/stable/notes/extending.html#extending-torch). + There is an additional ``process_group`` parameter which is the + process_group used for the ShardedTensor and can be used by + implementations for communications within a sharded implementation. + + Args: + func(Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.nn.functional.linear) + """ + return functools.partial(_decorator_func, op=func, op_table=_CUSTOM_SHARDED_OPS) + + +def _sharded_op_impl(func): + """ + Decorator to register a default sharded op. + """ + return functools.partial(_decorator_func, op=func, op_table=_SHARDED_OPS) + + +# Import all builtin sharded ops +from ._ops import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b79b0760f24812711cd0636e3cc9d1a149e50b13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..874ce6e161bd0b8f6da57b204fb2be89d81b5a59 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef4e7baabd091070bd5a056be00385cbda306764 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logger.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9dd9e2df342c08fe7b4f6c30e965764d12f2b79 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/logging_handlers.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acc105bf03ac865a639bde23aeb41e15fe88b5bc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/metadata.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5193d91e6ae54a106a0ba1bd3f3cb083243f1861 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/reshard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fc954217718abafbb2707ad0e270383d261165a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/shard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71fa20c6b5e23fa529076cf912d536f57ef7435f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d759f879ec717b0fdc023737f94f06aa970e942 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__init__.py @@ -0,0 +1,13 @@ +import torch.distributed._shard.sharded_tensor._ops.misc_ops +import torch.distributed._shard.sharded_tensor._ops.tensor_ops + +# Import all ChunkShardingSpec ops +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import ( + sharded_embedding, +) +from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import ( + sharded_embedding_bag, +) + +from .binary_cmp import allclose, equal +from .init import constant_, kaiming_uniform_, normal_, uniform_ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c6b0f59de482c30635f2f539c6fa00173b2ad77 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..14b99dff3fa4c50bc661d5497a8c8d13f7e6724e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..115a611415f2afd72e52d4434da42aba7b7ae654 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/binary_cmp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..970dacaa1740fdcde99ddd78bab0cda28a99a018 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/init.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2528fe9c3f39133ecdecab85f7e789dead8b14d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/misc_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87076f4ba464fe165e3e0f4f4ebd7a2a4f209436 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/__pycache__/tensor_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..d3ef41c8ad5e101b1d3bf1ad879853838af7c26f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs +import functools + +from torch.distributed._shard.common_op_utils import _basic_validation +from torch.distributed._shard.sharded_tensor import ( + _sharded_op_impl, + Shard, + ShardedTensor, +) + + +def _sharded_op_common(op, early_stop_func, extra_check): + """ + Inject sharded tensor op registration with common logics executed before + different behaviors are done on either local shards or a local tensor. + + Example:: + >>> # xdoctest: +SKIP("Undefined variables") + >>> op = torch.transpose + >>> @_sharded_op_impl(op) + >>> @_sharded_op_common(op, early_stop_func, extra_check) + >>> def sharded_tensor_op(types, args, kwargs, process_group): + >>> ... + >>> + >>> st = sharded_tensor.rand(32, 16) + >>> st.transpose(1, 2) + >>> # This will call '_sharded_op_common' + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + + Return: + func (Callable): Torch function for which we want to provide a sharded + implementation (ex: torch.transpose) + """ + + def decorator_sharded_func(wrapped_func): + @functools.wraps(wrapped_func) + def wrapper(types, args=(), kwargs=None, pg=None): + _basic_validation(op, args, kwargs) + + st = args[0] + if kwargs is None: + kwargs = {} + if extra_check: + extra_check(*args, **kwargs) + if early_stop_func: + early_stop = early_stop_func(*args, **kwargs) + if early_stop: + return st + return wrapped_func(types, args, kwargs, pg) + + return wrapper + + return decorator_sharded_func + + +def _register_sharded_op_on_local_shards( + op, early_stop_func=None, extra_check=None, customized_func=None +): + """ + Handles ``__torch_function__`` dispatch for ops which are performed on + each shard of the sharded tensor such as elementwise op like + ``torch.nn.functional.gelu`` or ``torch.nn.functional.relu``. + + For more complicated ops, a customized func can be used to generate + the new shards and sharded tensor size. + + This function expects that the original ShardingSpec for the ShardedTensor + is preserved irrespective of whether or not a customized function is used. + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + customized_func (Callable, optional): the func for customized logic + to generate new shards and sharded tensor size. + Default: if ``None``, we simply lower to the real op call with + all local shards of the st. + + Return: + func (Callable): registered implementation for sharded op for + ``__torch_function__`` dispatch. + """ + + @_sharded_op_impl(op) + @_sharded_op_common(op, early_stop_func, extra_check) + def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): + st = args[0] + st_metadata = st.metadata() + local_shards = st.local_shards() + local_shards_new = [] + if customized_func: + local_shards_new, st_metadata = customized_func(args, kwargs, pg) + else: + for local_shard in local_shards: + args = (local_shard.tensor, *args[1:]) + local_shards_new.append( + Shard(op(*args, **kwargs), local_shard.metadata) + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards_new, + st_metadata, + process_group=pg, + init_rrefs=st._init_rrefs, + sharding_spec=st.sharding_spec(), + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py new file mode 100644 index 0000000000000000000000000000000000000000..6d72c9822d15888cf50a8c38eacd2bc148970e80 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/binary_cmp.py @@ -0,0 +1,78 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +import torch.distributed.distributed_c10d as distributed_c10d +from torch.distributed._shard.sharded_tensor import _sharded_op_impl, ShardedTensor + + +def _communicate_result(result, pg): + # Gather results from all ranks. + if result: + result_tensor = torch.ones(1, device=torch.device(torch.cuda.current_device())) + else: + result_tensor = torch.zeros(1, device=torch.device(torch.cuda.current_device())) + + dist.all_reduce(result_tensor, group=pg) + + expected_result = torch.ones( + 1, device=torch.device(torch.cuda.current_device()) + ) * dist.get_world_size(pg) + + return torch.equal(result_tensor, expected_result) + + +def binary_cmp(cmp_fun, types, args, kwargs=None, process_group=None): + if len(args) != 2: + raise ValueError(f"Expected two arguments for torch.{cmp_fun.__name__}") + + st1 = args[0] + st2 = args[1] + if not (isinstance(st1, ShardedTensor) and isinstance(st2, ShardedTensor)): + raise TypeError( + f"Both arguments to torch.{cmp_fun.__name__} need to be of type ShardedTensor" + ) + + # Verify same PG + if st1._process_group != st2._process_group: + return False + + if distributed_c10d._rank_not_in_group( + st1._process_group + ) or distributed_c10d._rank_not_in_group(st2._process_group): + return distributed_c10d._rank_not_in_group( + st1._process_group + ) == distributed_c10d._rank_not_in_group(st2._process_group) + + # Verify metadata + if st1.metadata() != st2.metadata(): + return _communicate_result(False, st1._process_group) + + # Verify number of local shards + st1_local_shards = st1.local_shards() + st2_local_shards = st2.local_shards() + if len(st1_local_shards) != len(st2_local_shards): + return _communicate_result(False, st1._process_group) + + # kwargs must be dict-like + if kwargs is None: + kwargs = {} + # Verify each local shard + for idx in range(len(st1_local_shards)): + if st1_local_shards[idx].metadata != st2_local_shards[idx].metadata: + return _communicate_result(False, st1._process_group) + if not cmp_fun( + st1_local_shards[idx].tensor, st2_local_shards[idx].tensor, **kwargs + ): + return _communicate_result(False, st1._process_group) + + return _communicate_result(True, st1._process_group) + + +@_sharded_op_impl(torch.equal) +def equal(types, args, kwargs, process_group): + return binary_cmp(torch.equal, types, args, kwargs, process_group) + + +@_sharded_op_impl(torch.allclose) +def allclose(types, args, kwargs, process_group): + return binary_cmp(torch.allclose, types, args, kwargs, process_group) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py new file mode 100644 index 0000000000000000000000000000000000000000..f70afd2d371f586815cd0b7708315eef34d8f3a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -0,0 +1,151 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed._shard.sharded_tensor as sharded_tensor +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + + +def validate_param(param, param_name): + if param is None: + raise ValueError(f"param: {param_name} shouldn't be None!") + + +@_sharded_op_impl(torch.nn.init.uniform_) +def uniform_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensor in tensor.local_shards with values drawn from the uniform + distribution :math:`\mathcal{U}(a, b)`. + Args: + tensor: tensor sharded across devices + a: the lower bound of the uniform distribution + b: the upper bound of the uniform distribution + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + a = kwargs["a"] + validate_param(a, "a") + b = kwargs["b"] + validate_param(b, "b") + + for shard in sharded_tensor.local_shards(): + torch.nn.init.uniform_(shard.tensor, a=a, b=b) + return sharded_tensor + + +@_sharded_op_impl(torch.nn.init.normal_) +def normal_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensors in tensor.local_shards with values drawn from the normal + distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`. + Args: + tensor: tensor sharded across devices + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + mean = kwargs["mean"] + validate_param(mean, "mean") + std = kwargs["std"] + validate_param(std, "std") + + for shard in sharded_tensor.local_shards(): + torch.nn.init.normal_(shard.tensor, mean=mean, std=std) + return sharded_tensor + + +@_sharded_op_impl(torch.nn.init.kaiming_uniform_) +def kaiming_uniform_(types, args=(), kwargs=None, pg=None): + r""" + Fills the Tensors in tensor.local_shards with values according to the method + described in `Delving deep into rectifiers: Surpassing human-level + performance on ImageNet classification` - He, K. et al. (2015), using a + uniform distribution. The resulting tensor will have values sampled from + :math:`\mathcal{U}(-\text{bound}, \text{bound})` where + .. math:: + \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} + Also known as He initialization. + Args: + tensor: tensor sharded across devices + a: the negative slope of the rectifier used after this layer (only + used with ``'leaky_relu'``) + mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'`` + preserves the magnitude of the variance of the weights in the + forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the + backwards pass. + nonlinearity: the non-linear function (`nn.functional` name), + recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + a = kwargs["a"] + validate_param(a, "a") + mode = kwargs["mode"] + validate_param(mode, "mode") + nonlinearity = kwargs["nonlinearity"] + validate_param(nonlinearity, "nonlinearity") + + for shard in sharded_tensor.local_shards(): + torch.nn.init.kaiming_uniform_( + shard.tensor, a=a, mode=mode, nonlinearity=nonlinearity + ) + return sharded_tensor + + +@_sharded_op_impl(torch.nn.init.constant_) +def constant_(types, args=(), kwargs=None, pg=None): + r""" + Fills the input ShardedTensor with the value \text{val}val. + Args: + tensor: tensor sharded across devices + val: the value to fill the tensor with + """ + validate_param(kwargs, "kwargs") + sharded_tensor = kwargs["tensor"] + validate_param(sharded_tensor, "tensor") + val = kwargs["val"] + validate_param(val, "val") + for shard in sharded_tensor.local_shards(): + torch.nn.init.constant_(shard.tensor, val=val) + return sharded_tensor + + +tensor_like_creation_op_map = { + torch.full_like: sharded_tensor.full, + torch.empty_like: sharded_tensor.empty, + torch.zeros_like: sharded_tensor.zeros, + torch.ones_like: sharded_tensor.ones, + torch.rand_like: sharded_tensor.rand, + torch.randn_like: sharded_tensor.randn, +} + + +# tensor ops that behave the same as the default tensor +def register_tensor_creation_op(op): + @_sharded_op_impl(op) + def tensor_creation_op(types, args=(), kwargs=None, pg=None): + """ + Handles ``__torch_function__`` dispatch for tensor creation ops that + takes a ShardedTensor as argument, such as ``torch.zeros_like`` or + ``torch.full_like``. + """ + creation_op = tensor_like_creation_op_map.get(op, None) + if creation_op is None: + raise RuntimeError(f"Tensor creation {op} not supported!") + if kwargs is None: + kwargs = {} + + st = args[0] + + new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs) # type: ignore[operator] + return new_st + + +register_tensor_creation_op(torch.full_like) +register_tensor_creation_op(torch.empty_like) +register_tensor_creation_op(torch.zeros_like) +register_tensor_creation_op(torch.ones_like) +register_tensor_creation_op(torch.rand_like) +register_tensor_creation_op(torch.randn_like) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..b84a78c026bf1662b9fb6cc93e7aa0142eade364 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/misc_ops.py @@ -0,0 +1,12 @@ +# mypy: allow-untyped-defs +import torch +from torch.distributed._shard.sharded_tensor import _sharded_op_impl + + +# This is used by `_apply()` within module.py to set new +# parameters after apply a certain method, we should follow +# the future behavior of overwriting the existing tensor +# instead of doing in-place change using `.data = `. +@_sharded_op_impl(torch._has_compatible_shallow_copy_type) +def tensor_has_compatible_shallow_copy_type(types, args=(), kwargs=None, pg=None): + return False diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..8f2af42532216eecdd8b2a10201a9b0c26b5179e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -0,0 +1,219 @@ +# mypy: allow-untyped-defs +import copy + +import torch +from torch.distributed._shard.common_op_utils import _register_default_op +from torch.distributed._shard.sharded_tensor import ( + _sharded_op_impl, + Shard, + ShardedTensor, +) + +from ._common import _register_sharded_op_on_local_shards + + +# Tensor properties access +_register_default_op(torch.Tensor.shape.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.dtype.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.layout.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.size, _sharded_op_impl) +_register_default_op(torch.Tensor.dim, _sharded_op_impl) +_register_default_op(torch.Tensor.ndim.__get__, _sharded_op_impl) # type: ignore[attr-defined] +_register_default_op(torch.Tensor.is_contiguous, _sharded_op_impl) +_register_default_op(torch.Tensor.contiguous, _sharded_op_impl) +_register_default_op(torch.Tensor.is_floating_point, _sharded_op_impl) + +# __reduce_ex__ to dispatch to get_state/set_state +_register_default_op(torch.Tensor.__reduce_ex__, _sharded_op_impl) + +# autograd related properties +_register_default_op(torch.Tensor.requires_grad.__get__, _sharded_op_impl) # type: ignore[attr-defined] +# TODO: set grad with a ShardedTensor that consists of all local grads +_register_default_op(torch.Tensor.grad.__get__, _sharded_op_impl) # type: ignore[union-attr] +_register_default_op(torch.Tensor.grad_fn.__get__, _sharded_op_impl) # type: ignore[union-attr] +_register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ignore[attr-defined] + + +# device property is ambiguous as from a global prospective, +# ShardedTensor.device consists of multiple devices (might even across hosts) +# We choose to return the current device of the local tensor to represent +# the device property on each rank +@_sharded_op_impl(torch.Tensor.device.__get__) +def tensor_device(types, args=(), kwargs=None, pg=None): + self_st = args[0] + # Validate types + if not isinstance(self_st, ShardedTensor): + raise TypeError("input needs to be a ShardedTensor") + dev: torch.device + if self_st._local_shards: + dev = self_st._local_shards[0].tensor.device + elif pg and pg._get_backend_name() == "gloo": + dev = torch.device("cpu") + else: + dev = torch.device(torch.cuda.current_device()) + return dev + + +@_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined] +def st_is_meta(types, args=(), kwargs=None, pg=None): + return args[0].local_tensor().is_meta + + +def sharded_type_as_check(*args, **kwargs): + """ + Perform extra checks for the sharded_type_as op such as the input needs to + be either a Tensor or ShardedTensor. + + Args: same as ``torch.Tensor.type_as``. + + Return: None + """ + if len(args) < 2: + raise ValueError("Needs to give a tensor to cast type as!") + if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor): + raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!") + + +def same_dtype(*args, **kwargs): + """ + When the dtype is the same, return the original ShardedTensor. + + Args: same as ``torch.Tensor.type_as``. + + Return (bool): Whether to return early or not. + """ + return args[0].dtype == args[1].dtype + + +def sharded_type_as(args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op. + + Args: same as ``torch.Tensor.type_as``. + + Return: + new_local_shards (List[Shard]): Local shards for the new sharded tensor. + st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor. + """ + st = args[0] + tensor = args[1] + if isinstance(tensor, ShardedTensor): + tensor = tensor.local_tensor() + new_local_shards = [ + Shard(shard.tensor.type_as(tensor), shard.metadata) + for shard in st.local_shards() + ] + st_meta = copy.deepcopy(st._metadata) + st_meta.tensor_properties.dtype = tensor.dtype + return new_local_shards, st_meta + + +_register_sharded_op_on_local_shards( + torch.Tensor.type_as, + early_stop_func=same_dtype, + extra_check=sharded_type_as_check, + customized_func=sharded_type_as, +) + + +def sharded_deepcopy(args, kwargs, pg): + # NOTE: we directly implement deepcopy magic method + # instead of using the default tensor.__deepcopy__ + # and implement clone(). This is because the default + # tensor deepcopy copies every attribute, but the + # process_group in ShardedTensor cannot be deep copied. + self_st = args[0] + new_local_shards = copy.deepcopy(self_st.local_shards()) + new_metadata = copy.deepcopy(self_st.metadata()) + return new_local_shards, new_metadata + + +_register_sharded_op_on_local_shards( + torch.Tensor.__deepcopy__, + customized_func=sharded_deepcopy, +) + + +@_sharded_op_impl(torch.Tensor.copy_) +def sharded_inplace_copy(types, args, kwargs, pg): + # NOTE: inplace op don't need to rewrap + kwargs = {} if kwargs is None else kwargs + self_st = args[0] + new_st = args[1] + nonblocking = kwargs.get("non_blocking", False) + for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()): + if local_shard.metadata != new_shard.metadata: + raise RuntimeError( + "inplace copy can only happen between two ShardedTensor with same metadata!" + ) + for local_shard, new_shard in zip(self_st.local_shards(), new_st.local_shards()): + local_shard.tensor.copy_(new_shard.tensor, nonblocking) + + return self_st + + +def sharded_clone(args, kwargs, pg): + self_st = args[0] + desire_memory_format = kwargs.get("memory_format", None) + if desire_memory_format and desire_memory_format != torch.preserve_format: + raise RuntimeError("Only support torch.preserve_format for ShardedTensor!") + cloned_local_shards = [ + Shard( + local_shard.tensor.clone(memory_format=desire_memory_format), + metadata=copy.deepcopy(local_shard.metadata), + ) + for local_shard in self_st.local_shards() + ] + new_metadata = copy.deepcopy(self_st.metadata()) + return cloned_local_shards, new_metadata + + +_register_sharded_op_on_local_shards( + torch.Tensor.clone, + customized_func=sharded_clone, +) + + +def sharded_detach(args, kwargs, pg): + self_st = args[0] + detached_local_shards = [ + Shard( + local_shard.tensor.detach(), + metadata=copy.deepcopy(local_shard.metadata), + ) + for local_shard in self_st.local_shards() + ] + new_metadata = copy.deepcopy(self_st.metadata()) + new_metadata.tensor_properties.requires_grad = False + return detached_local_shards, new_metadata + + +_register_sharded_op_on_local_shards( + torch.Tensor.detach, + customized_func=sharded_detach, +) + + +@_sharded_op_impl(torch.Tensor.requires_grad_) +def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): + self_st = args[0] + # Validate types + if not isinstance(self_st, ShardedTensor): + raise TypeError("input needs to be a ShardedTensor") + + if kwargs is None: + kwargs = {} + + requires_grad = args[1] if len(args) > 1 else kwargs.get("requires_grad", True) + if requires_grad == self_st.requires_grad: + return self_st + + for local_shard in self_st.local_shards(): + local_shard.tensor.requires_grad_(requires_grad) + + # update the wrapper class property + with torch._C.DisableTorchFunctionSubclass(): + self_st.requires_grad_(requires_grad) + # update the metadata in the meanwhile + self_st._metadata.tensor_properties.requires_grad = requires_grad + return self_st diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/api.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/api.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5a45b76010f7f9c9720fd4b6e4fa22abf89aba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/api.py @@ -0,0 +1,1357 @@ +# mypy: allow-untyped-defs +from __future__ import annotations # type: ignore[attr-defined] + +import copy +import operator +import threading +import warnings +import weakref +from dataclasses import dataclass +from functools import reduce +from typing import Callable, cast, Optional, TYPE_CHECKING +from typing_extensions import deprecated + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +from torch._utils import _get_device_module +from torch.distributed import distributed_c10d, rpc +from torch.distributed._shard._utils import DEPRECATE_MSG +from torch.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) +from torch.distributed._shard.sharding_spec.api import ( + _dispatch_custom_op, + _has_custom_op, +) +from torch.distributed.remote_device import _remote_device +from torch.utils import _pytree as pytree + +from .metadata import ShardedTensorMetadata, TensorProperties +from .reshard import reshard_local_shard, reshuffle_local_shard +from .shard import Shard +from .utils import ( + _flatten_tensor_size, + _parse_and_validate_remote_device, + _validate_output_tensor_for_gather, + build_global_metadata, + build_metadata_from_local_shards, +) + + +if TYPE_CHECKING: + from collections.abc import Sequence + + from torch.distributed._shard.metadata import ShardMetadata + + +# Tracking for sharded tensor objects. +_sharded_tensor_lock = threading.Lock() +_sharded_tensor_current_id = 0 +_sharded_tensor_map: dict[int, weakref.ReferenceType[ShardedTensor]] = {} + +# Default sharded ops +_SHARDED_OPS: dict[Callable, Callable] = {} + +# Customized user ops +_CUSTOM_SHARDED_OPS: dict[Callable, Callable] = {} + + +def _register_remote_shards( + sharded_tensor_id: int, rrefs: list[rpc.RRef[Shard]], rpc_rank: int +): + with _sharded_tensor_lock: + if sharded_tensor_id not in _sharded_tensor_map: + raise RuntimeError( + f"Could not find sharded_tensor_id: {sharded_tensor_id} in map: {_sharded_tensor_map.keys()}" + ) + + sharded_tensor = _sharded_tensor_map[sharded_tensor_id]() + if sharded_tensor is None: + raise RuntimeError("ShardedTensor weakref has been deallocated") + else: + sharded_tensor._register_remote_shards(rrefs, rpc_rank) + + +class ShardedTensorBase(torch.Tensor): + _sharding_spec: shard_spec.ShardingSpec + _metadata: ShardedTensorMetadata + _local_shards: list[Shard] + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + # Use __new__ to construct a wrapper tensor, for recording tensor + # properties and logging purposes. + torch._C._log_api_usage_once("torch.distributed._shard.sharded_tensor") + + # check sharding spec and build sharded tensor metadata + if not isinstance(sharding_spec, shard_spec.ShardingSpec): + raise ValueError(f"Expecting ShardingSpec but got: {type(sharding_spec)}") + + sizes = _flatten_tensor_size(size) + dtype = kwargs["dtype"] + layout = kwargs["layout"] + pin_memory = kwargs["pin_memory"] + requires_grad = kwargs["requires_grad"] + + if dtype is None: + dtype = torch.get_default_dtype() + + tensor_properties = TensorProperties( + dtype, layout, requires_grad, pin_memory=pin_memory + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + sizes, tensor_properties=tensor_properties + ) + + r = torch.Tensor._make_wrapper_subclass( + cls, + sizes, + dtype=dtype, + layout=layout, + pin_memory=pin_memory, + requires_grad=requires_grad, + ) + # set sharding spec + r._sharding_spec = sharding_spec + # set metadata + r._metadata = sharded_tensor_metadata + # set local shards + r._local_shards = [] + return r + + def metadata(self) -> ShardedTensorMetadata: + """ + Returns a :class:`ShardedTensorMetadata` object corresponding to the + metadata for the entire tensor. + """ + return self._metadata + + def local_shards(self) -> list[Shard]: + """ + Returns a list of :class:`Shard' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + @classmethod + def _init_from_local_shards_and_global_metadata( + cls, + local_shards: list[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + sharding_spec=None, + ) -> ShardedTensorBase: + """ + Initialize a ShardedTensorBase with local shards and a global + ShardedTensorMetadata built on each rank. + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + if tensor_properties.layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor_base = ShardedTensorBase.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor_base._local_shards = local_shards + return sharded_tensor_base + + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + raise RuntimeError( + f"A {cls.__name__} object is being used from c++ while calling {func.__module__}.{func.__name__} " + "but the there is no custom __torch_dispatch__ implementation for it." + ) + + +class ShardedTensor(ShardedTensorBase): + """ + ShardedTensor is an torch.Tensor subclass to represent Tensors that are sharded + across multiple devices and multiple processes. + + ShardedTensor is initialized in an SPMD like fashion where each rank + initializes the ShardedTensor. The ShardedTensor object on each rank + then only stores the local shard for the Tensor and provides global + metadata for all the shards. + + ShardedTensor doesn't provide any Tensor like operations but is a wrapper + providing the Tensor representing the local shard and the global metadata. + Using these, users can build their custom distributed._sharded computations + on top of this primitive. The local shards are all initialized using the + create_op specified by tensor_init_params.create_op, e.g., torch.ones, or + torch.empty + + Args: + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification + describing how to shard the Tensor. + size (int...): a sequence of integers defining the shape of the output + tensor. Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned tensor. Default: ``False``. + pin_memory (bool, optional): If set, returned tensor would be allocated in + the pinned memory. Works only for CPU tensors. Default: ``False``. + memory_format (:class:`torch.memory_format`, optional): the desired memory format of + returned Tensor. Default: ``torch.contiguous_format``. + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + .. note:: ShardedTensor uses collectives to do various operations, i.e. it + uses all_gather to do cross rank validations. For NCCL-based process + groups, internal tensor representations of objects must be moved to the + GPU device before communication takes place. In this case, the device + used is given by ``torch.cuda.current_device()`` and it is the user's + responsibility to ensure that this is set so that each rank has an + individual GPU, via ``torch.cuda.set_device()`` + + """ + + def __new__(cls, sharding_spec: shard_spec.ShardingSpec, *size, **kwargs): + self = super().__new__(cls, sharding_spec, *size, **kwargs) + return self + + def __init__( + self, + sharding_spec: shard_spec.ShardingSpec, + *size, + dtype=None, + layout=torch.strided, + requires_grad=False, + pin_memory=False, + memory_format=torch.contiguous_format, + process_group=None, + init_rrefs=False, + ): + # prepare initialization, initialize fields like + # _process_group, _local_shards, etc. + self._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + if layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if memory_format != torch.contiguous_format: + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) + + self._metadata.tensor_properties.memory_format = memory_format + + current_rank = dist.get_rank() # global rank + + for shard_metadata in self._metadata.shards_metadata: + rank, device = _parse_and_validate_remote_device( + self._process_group, shard_metadata.placement + ) + if rank == current_rank: + local_tensor = _create_tensor_from_params( + shard_metadata.shard_sizes, + local_device=device, + tensor_properties=self._metadata.tensor_properties, + ) + self._local_shards.append(Shard(local_tensor, shard_metadata)) + + # do post initialization (i.e. register sharded_tensor_id, initialize_rpc) + self._post_init() + + def _prepare_init(self, process_group=None, init_rrefs=False): + self._init_rrefs = init_rrefs + self._sharded_tensor_id = None + + self._process_group = self._normalize_pg(process_group) + self._remote_shards: dict[int, list[rpc.RRef[Shard]]] = {} + + def _post_init(self): + # Initialize RPC if available. + if self._init_rrefs: + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + self._sharded_tensor_id = _sharded_tensor_current_id + _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self) + _sharded_tensor_current_id += 1 + + if not rpc._is_current_rpc_agent_set(): + raise RuntimeError( + "RPC Framework needs to be initialized using" + " torch.distributed.rpc.init_rpc if init_rrefs is set to True" + ) + self._init_rpc() + + def __del__(self): + # Clean up the global map. + with _sharded_tensor_lock: + global _sharded_tensor_current_id, _sharded_tensor_map + if ( + hasattr(self, "_sharded_tensor_id") + and self._sharded_tensor_id in _sharded_tensor_map + ): + _sharded_tensor_map.pop(self._sharded_tensor_id) # type: ignore[call-overload] + + def _init_rpc(self): + # Validate PG and RPC ranks match. + pg_rank = dist.get_rank() + rpc_rank = rpc.get_worker_info().id + if pg_rank != rpc_rank: + raise ValueError( + f"Default ProcessGroup and RPC ranks must be " + f"the same for ShardedTensor, found process group rank: " + f"{pg_rank} and RPC rank: {rpc_rank}" + ) + + self._remote_shards = {} + + # Gather all the sharded tensor ids. + worker_infos = rpc._get_current_rpc_agent().get_worker_infos() + rank_to_name = {} + name_to_rank = {} + + for worker_info in worker_infos: + rank_to_name[worker_info.id] = worker_info.name + name_to_rank[worker_info.name] = worker_info.id + + all_tensor_ids = rpc.api._all_gather(self._sharded_tensor_id) + + # Share the local shards to the entire world. + futs = [] + rpc_rank = rpc.get_worker_info().id + for rank in range(dist.get_world_size()): + # Skip self. + if rank == dist.get_rank(): + continue + + if len(self.local_shards()) != 0: + rrefs: list[rpc.RRef[Shard]] = [ + rpc.RRef(shard) for shard in self.local_shards() + ] + fut = rpc.rpc_async( + rank, + _register_remote_shards, + args=(all_tensor_ids[rank_to_name[rank]], rrefs, rpc_rank), + ) + futs.append(fut) + + torch.futures.wait_all(futs) + + # Barrier for all RPCs to finish on all ranks. + rpc.api._all_gather(None) + + def _get_preferred_device(self) -> torch.device: + """ + Return the preferred device to be used when creating tensors for collectives. + This method takes into account the associated process group + """ + backend = dist.get_backend(self._process_group) + if backend == dist.Backend.NCCL: + return torch.device(torch.cuda.current_device()) + elif backend == dist.Backend.GLOO: + return torch.device("cpu") + else: + backend_config = dist.BackendConfig(backend) + for device, backend_str in backend_config.get_device_backend_map().items(): + if backend_str == backend and device != "cpu": + return torch.device( + device, _get_device_module(device).current_device() + ) + return torch.device("cpu") + + def gather( # type: ignore[override] + self, + dst: int = 0, + out: Optional[torch.Tensor] = None, + enforce_dtype: bool = False, + dtype: Optional[torch.dtype] = None, + ) -> None: + """ + Creates a full :class:`Tensor` on rank ``dst`` by gathering all shards of the + sharded tensor. + + The API needs to be called on all ranks in SPMD fashion. All ranks should have + the same ``dst``. ``out`` should be a tensor of the same size as the overall + size of the sharded tensor on ``dst`` and ``None`` on all other ranks. + + Args: + dst(int): The rank where full tensor is constructed. + Default: 0 + out (:class `torch.Tensor`, optional): The output full tensor. + Must to be provided ONLY on ``dst`` rank. + Default: ``None`` + enforce_dtype (bool): Deprecated, please use dtype instead. Force the + gathered tensors to be the same type as input and output. + dtype (torch.dtype): Force the gathered tensors to be this dtype. + Default: ``None`` + """ + + def shard_size(shard_md): + return reduce(operator.mul, shard_md.shard_sizes) # type: ignore[attr-defined] + + if enforce_dtype: + warnings.warn( + "`enforce_dtype` is deprecated. Please use `dtype` instead.", + FutureWarning, + stacklevel=2, + ) + + rank = dist.get_rank(self._process_group) + full_size = self.metadata().size + _validate_output_tensor_for_gather(rank, dst, full_size, out) + + local_shards = self.local_shards() + world_size = dist.get_world_size(self._process_group) + rank_sizes = [0 for _ in range(world_size)] + max_rank_size = 0 + shard_placement: dict[ShardMetadata, tuple[int, int]] = {} + # collect sizes + for shard_md in self.metadata().shards_metadata: + shard_rank = cast(_remote_device, shard_md.placement).rank() + assert shard_rank is not None + + shard_placement[shard_md] = (shard_rank, rank_sizes[shard_rank]) + rank_sizes[shard_rank] += shard_size(shard_md) + max_rank_size = max(max_rank_size, rank_sizes[shard_rank]) + + gather_list: Optional[list[torch.Tensor]] + if rank == dst: + assert out is not None + if enforce_dtype: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = out.dtype + # TODO make it as a view of out tensor + gather_list = [ + torch.empty((max_rank_size,), device=out.device, dtype=dtype) + for _ in range(world_size) + ] + else: + gather_list = None + + with torch.no_grad(): + if enforce_dtype and len(local_shards) > 0: + # enforce_dtype is deprecated. Do it for backward compatibility. + dtype = local_shards[0].tensor.dtype + data = torch.empty( + max_rank_size, device=self._get_preferred_device(), dtype=dtype + ) + + for shard in local_shards: + src = shard.tensor.flatten() + if src.nelement() == 0: + warnings.warn( + "Gathering a tensor with zero elements on rank " + str(rank) + ) + continue + shard_offset = shard_placement[shard.metadata][1] + data[shard_offset : shard_offset + src.numel()].copy_(src) + + dist.gather( + tensor=data, + gather_list=gather_list, + dst=dst, + group=self._process_group, + ) + if rank != dst: + return + # In _validate_output_tensor_for_gather, we raise if out == None and rank == dst + out = cast(torch.Tensor, out) + assert gather_list is not None + + full_size = self.metadata().size + dims = len(full_size) + for shard_md in self.metadata().shards_metadata: + rank, rank_offset = shard_placement[shard_md] + tensor = gather_list[rank] + tensor = tensor[rank_offset : rank_offset + shard_size(shard_md)] + tensor = tensor.view(shard_md.shard_sizes) + + out_narrow_view = out + for dim in range(dims): + out_narrow_view = out_narrow_view.narrow( + dim, + shard_md.shard_offsets[dim], + shard_md.shard_sizes[dim], + ) + + out_narrow_view.copy_(tensor) + + def cpu( + self, memory_format=torch.preserve_format, process_group=None + ) -> ShardedTensor: + """ + Returns a copy of this object in CPU memory. + + If this ShardedTensor is already on CPU memory, then no copy is + performed and original object is returned. + + .. note:: When moving a ShardedTensor from GPU to CPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupGloo), + it is the user's responsibility to explicitly pass in a new process_group that + is compatible with CPU. + """ + # TODO: make this a __torch_function__ op once ShardedTensor becomes a + # torch.Tensor subclass, see https://github.com/pytorch/pytorch/issues/75402 + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) + all_on_cpu = True + for meta in self.metadata().shards_metadata: + all_on_cpu &= meta.placement.device().type == "cpu" # type: ignore[union-attr] + + # if every shard is already on CPU, return the original object + if all_on_cpu: + return self + + # if not, returns a copy of this object on CPU + list_shards: list[Shard] = [] + # move all local shards to cpu, and change metadata + for shard in self._local_shards: + cpu_tensor = shard.tensor.cpu(memory_format=memory_format) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = torch.device("cpu") # type: ignore[union-attr] + list_shards.append(Shard(cpu_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cpu": # type: ignore[union-attr] + meta.placement._device = torch.device("cpu") # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + st_cpu = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cpu + + def cuda( + self, + device=None, + non_blocking=False, + memory_format=torch.preserve_format, + process_group=None, + ) -> ShardedTensor: + """ + Returns a copy of this object in CUDA memory, if the original ShardedTensor + is on CPU, we will move the local shard to the current GPU device of each + process in a SPMD fashion. + If this ShardedTensor is already on CUDA memory and local shards on each rank are + already on current device, we still returns a new ShardedTensor object with new + metadata, but no underlying data movements are performed. + .. note:: When moving a ShardedTensor from CPU to GPU, the ShardedTensor might + need to be managed by a different type of ProcessGroup(i.e. ProcessGroupNCCL), + it is the user's responsibility to explicitly pass in a new process_group that + is compatible with GPU. + """ + if ( + memory_format != torch.preserve_format + and memory_format != torch.contiguous_format + ): + raise RuntimeError( + "Only `torch.contiguous_format` or " + "`torch.preserve_format` is supported!" + ) + + if device is not None: + device = torch.device(device) if isinstance(device, str) else device + assert ( + isinstance(device, torch.device) + and device.index == torch.cuda.current_device() + ), ( + """Only device without device id (e.g. "cpu" or "cuda") is expected for ShardedTensor!""" + ) + + current_device = torch.device(torch.cuda.current_device()) + # returns a copy of ShardedTensor on CUDA current device + list_shards: list[Shard] = [] + # move all local shards to current device, and change metadata + # if local shards already on the current device, there's no + # real data movement, only the metadata are copied. + for shard in self._local_shards: + cuda_tensor = shard.tensor.cuda( + device=current_device, + non_blocking=non_blocking, + memory_format=memory_format, + ) # type: ignore[call-arg] + metadata = copy.deepcopy(shard.metadata) + metadata.placement._device = current_device # type: ignore[union-attr] + + list_shards.append(Shard(cuda_tensor, metadata)) + + st_meta = copy.deepcopy(self.metadata()) + for meta in st_meta.shards_metadata: + if meta.placement.device().type != "cuda": # type: ignore[union-attr] + meta.placement._device = current_device # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_cuda = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_cuda + + def to(self, *args, **kwargs) -> ShardedTensor: + current_device: torch.device + if self._local_shards: + current_device = self._local_shards[0].tensor.device + elif self._process_group._get_backend_name() == "gloo": + current_device = torch.device("cpu") + else: + current_device = torch.device(torch.cuda.current_device()) + current_dtype = self.dtype + device_to = current_device + dtype_to = current_dtype + if len(args) == 1: + if isinstance(args[0], torch.dtype): + dtype_to = args[0] + elif isinstance(args[0], torch.device): + device_to = args[0] + elif isinstance(args[0], (str, int)): + device_to = torch.device(args[0]) + elif isinstance(args[0], torch.Tensor): + dtype_to = args[0].dtype + device_to = args[0].device + else: + raise RuntimeError(f"ShardedTensor.to() have wrong arguments: {args}") + elif len(args) == 2: + device_to, dtype_to = args + else: + dtype_to = kwargs.get("dtype", current_dtype) + device_to = kwargs.get("device", current_device) + + device_to = ( + torch.device(device_to) if isinstance(device_to, (str, int)) else device_to + ) + + if device_to.type == "cuda": + # if device_to set to cuda, set to current device even + # if user specify the device index. + current_idx = torch.cuda.current_device() + if device_to.index != current_idx: + warnings.warn( + "ShardedTensor.to only move tensor to its current device" + "If you want to put to different device, use `reshard` instead." + ) + device_to = torch.device(current_idx) + + copy_tensor = kwargs.get("copy", False) + non_blocking = kwargs.get("non_blocking", False) + memory_format = kwargs.get("memory_format", torch.preserve_format) + process_group = kwargs.get("process_group", None) + + if ( + not copy_tensor + and dtype_to == current_dtype + and device_to == current_device + ): + # already have correct dtype and device, return itself + return self + + # returns a copy of ShardedTensor on CUDA current device + list_shards: list[Shard] = [] + + for shard in self._local_shards: + new_tensor = shard.tensor.to( # type: ignore[call-overload] + device=device_to, + dtype=dtype_to, + non_blocking=non_blocking, + copy=copy_tensor, + memory_format=memory_format, + ) + metadata = copy.deepcopy(shard.metadata) + if metadata.placement is not None: + metadata.placement._device = device_to + list_shards.append(Shard(new_tensor, metadata)) + + # update metadata + st_meta = copy.deepcopy(self.metadata()) + st_meta.tensor_properties.dtype = dtype_to + for meta in st_meta.shards_metadata: + meta.placement._device = device_to # type: ignore[union-attr] + + pg = self._process_group if process_group is None else process_group + # we need to use `init_from_local_shards` to communicate between ranks + # and update the sharding spec/shards metadata. + st_to = ShardedTensor._init_from_local_shards_and_global_metadata( + list_shards, + sharded_tensor_metadata=st_meta, + process_group=pg, + init_rrefs=self._init_rrefs, + ) + return st_to + + @classmethod + def _normalize_pg( + cls, process_group: Optional[dist.ProcessGroup] + ) -> dist.ProcessGroup: + if process_group is not None: + return process_group + return distributed_c10d._get_default_group() + + @classmethod + def _init_from_local_shards( + cls, + local_shards: list[Shard], + *global_size, + process_group=None, + init_rrefs=False, + ): + # recalc metadata handles special ST creation cases like each rank only has tensor available + # caller need to provide None on the unknown dimension of the global size + # We will change None into zeros and go through the same amount of checks as before to create ST + # and use all_gather to calculate the offsets and global size for metadata + # It is compatible with the current use case since, conventionally we don't pass None as global size + # Therefore the old path won't trigger the new feature + recalc_metadata = False + for dim in global_size: + if dim is None: + recalc_metadata = True + if recalc_metadata: + global_size = tuple( + 0 if dim_size is None else dim_size for dim_size in global_size + ) + # STEP 1: Validate the Shardmetadatas locally + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + world_size = dist.get_world_size(process_group) + + local_sharded_tensor_metadata: Optional[ShardedTensorMetadata] = None + global_tensor_size = _flatten_tensor_size(global_size) + + if len(local_shards) > 0: + local_sharded_tensor_metadata = build_metadata_from_local_shards( + local_shards, global_tensor_size, current_rank, process_group + ) + + # STEP 2. Validate metadata across ranks, and build a global sharded tensor + # metadata by gathering local ShardedTensorMetadata + gathered_metadatas: list[Optional[ShardedTensorMetadata]] = [] + if world_size > 1: + gathered_metadatas = [None for _ in range(world_size)] + + dist.all_gather_object( + gathered_metadatas, local_sharded_tensor_metadata, group=process_group + ) + else: + gathered_metadatas = [local_sharded_tensor_metadata] + + global_sharded_tensor_metadata = build_global_metadata( + gathered_metadatas, recalc_metadata=recalc_metadata + ) + if recalc_metadata: + # for recalc use cases, we only support rw for now, limit the blast radius + # will modify here once we support more sharding type + assert ( + len(local_shards) > 0 + and len(global_sharded_tensor_metadata.shards_metadata) > current_rank + ), ( + f"# for metadata recalculation, local_shards must be larger than 0 " + f"actual:{len(local_shards)}, # glb metadata must be greater than any rank id, " + f"# metadata:{len(global_sharded_tensor_metadata.shards_metadata)}, rank id:{current_rank}" + ) + local_md = [ + shard_md + for shard_md in global_sharded_tensor_metadata.shards_metadata + if shard_md.placement.rank() == current_rank + ] + assert len(local_md) == 1, ( + f"should has and only has one metadata for local rank, actual:{local_md}" + ) + local_shards[0].metadata = local_md[0] + tensor_properties = global_sharded_tensor_metadata.tensor_properties + + # STEP 3: Validation done, create the actual ShardedTensor and populate fields + # prepare initialization + spec = shard_spec._infer_sharding_spec_from_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + sharded_tensor = cls.__new__( + cls, + spec, + global_sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # attach local_shards to the ShardedTensor created + sharded_tensor._local_shards = local_shards + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def _init_from_local_tensor( + cls, + local_tensor: torch.Tensor, + sharding_spec: shard_spec.ShardingSpec, + *global_size: Sequence[int], + process_group: Optional[dist.ProcessGroup] = None, + init_rrefs=False, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor given only one local tensor, global sharded tensor + size and sharding spec on each rank. + + Args: + local_tensor (Tensor): Single tensor of local shard stored in each rank. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): + The specification describing how to shard the Tensor. + global_size (Sequence[int]): Size of the sharded tensor. + process_group (ProcessGroup, optional): The process group to aggregate on. + Default: None + init_rrefs (bool, optional): Whether or not to initialize + :class:`torch.distributed.rpc.RRef`s pointing to remote shards. + Need to initialize the RPC Framework if specified as ``True``. + Default: ``False``. + + Returns: + A :class:`ShardedTensor` sharded based on the given sharding_spec with local + tensor stored in the current rank. + + Examples: + >>> # xdoctest: +SKIP + >>> # All tensors below are of torch.int64 type. + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(2, dtype=torch.int64) + 1 + 2 * rank + >>> local_tensor = torch.unsqueeze(torch.cat([tensor, tensor + 2])) + >>> local_tensor + tensor([[1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6]]) # Rank 1 + >>> sharding_dim = 0 + >>> sharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + ], + ) + >>> st = ShardedTensor._init_from_local_tensor( + ... local_tensor, sharding_spec, [2, 4] + ... ) + >>> st + ShardedTensor( + ShardedTensorMetadata( + shards_metadata=[ + ShardMetadata(shard_offsets=[0, 0], shard_sizes=[1, 4], placement=rank:0/cuda:0), + ShardMetadata(shard_offsets=[1, 0], shard_sizes=[1, 4], placement=rank:1/cuda:1), + ], + size=torch.Size([2, 4]) + ) + >>> st.local_tensor() + tensor([1, 2, 3, 4]) # Rank 0 + tensor([3, 4, 5, 6]) # Rank 1 + + Warning: This API is experimental and subject to change. It lacks of a fully across + rank validations, and we only validate the local shard on the current rank. + We fully rely on the user to ensure local tensor is sharded based on the + sharding spec. + """ + if not local_tensor.is_contiguous(): + raise ValueError("local_tensor is not a contiguous Tensor.") + + global_tensor_size = _flatten_tensor_size(global_size) + tensor_properties = TensorProperties( + dtype=local_tensor.dtype, + layout=local_tensor.layout, + requires_grad=local_tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=local_tensor.is_pinned(), + ) + sharded_tensor_metadata = sharding_spec.build_metadata( + global_tensor_size, tensor_properties + ) + + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + local_shards: list[Shard] = [] + for shard_metadata in sharded_tensor_metadata.shards_metadata: + rank, _device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + if rank == current_rank: + local_shards.append(Shard(local_tensor, shard_metadata)) + + # TODO: figure out what the API should behave when some rank have no shard + # see https://github.com/pytorch/pytorch/issues/7313 + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, + sharded_tensor_metadata, + process_group=process_group, + init_rrefs=init_rrefs, + sharding_spec=sharding_spec, + ) + + @classmethod + def _init_from_local_shards_and_global_metadata( # type: ignore[override] + cls, + local_shards: list[Shard], + sharded_tensor_metadata: ShardedTensorMetadata, + process_group=None, + init_rrefs=False, + sharding_spec=None, + ) -> ShardedTensor: + """ + Initialize a ShardedTensor with local shards and a global + ShardedTensorMetadata built on each rank. + + Warning: This API is experimental and subject to change. It does + not do cross rank validations, and fully rely on the user + for the correctness of sharded_tensor_metadata on each rank + """ + process_group = cls._normalize_pg(process_group) + current_rank = dist.get_rank() # intentional to get global rank + + shards_metadata = sharded_tensor_metadata.shards_metadata + + local_shard_metadatas = [] + + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + rank, local_device = _parse_and_validate_remote_device( + process_group, shard_metadata.placement + ) + + if current_rank == rank: + local_shard_metadatas.append(shard_metadata) + + if len(local_shards) != len(local_shard_metadatas): + raise RuntimeError( + f"Number of local shards ({len(local_shards)}) does not match number of local " + f"shards metadata in sharded_tensor_metadata ({len(local_shard_metadatas)}) " + f"on rank ({current_rank}) " + ) + + shards_metadata = sharded_tensor_metadata.shards_metadata + tensor_properties = sharded_tensor_metadata.tensor_properties + + if len(shards_metadata) == 0: + raise ValueError("shards_metadata must not be empty!") + + if tensor_properties.layout != torch.strided: + raise ValueError("Only torch.strided layout is currently supported") + + if sharding_spec is None: + spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata) + else: + spec = sharding_spec + + sharded_tensor = ShardedTensor.__new__( + ShardedTensor, + spec, + sharded_tensor_metadata.size, + dtype=tensor_properties.dtype, + layout=tensor_properties.layout, + pin_memory=tensor_properties.pin_memory, + requires_grad=tensor_properties.requires_grad, + ) + + def _raise_if_mismatch(expected, actual, prop_name, rank, is_property=False): + tensor_property_or_metadata = ( + "tensor property" if is_property else "local ShardMetadata" + ) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property is incompatible with " + f"{tensor_property_or_metadata} on rank {rank}: " + f"{tensor_property_or_metadata} {prop_name}={expected}, " + f"local shard tensor {prop_name}={actual}." + ) + + for shard in local_shards: + shard_meta = shard.metadata + local_shard_tensor = shard.tensor + placement = shard_meta.placement + assert placement is not None, "Must specify placement for `Shard`!" + rank = placement.rank() + local_device = placement.device() + + _raise_if_mismatch( + tensor_properties.layout, + local_shard_tensor.layout, + "layout", + rank, + True, + ) + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported" + ) + + _raise_if_mismatch( + shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + rank, + ) + _raise_if_mismatch( + tensor_properties.pin_memory, + local_shard_tensor.is_pinned(), + "pin_memory", + rank, + True, + ) + _raise_if_mismatch(local_device, local_shard_tensor.device, "device", rank) + _raise_if_mismatch( + tensor_properties.dtype, + local_shard_tensor.dtype, + "dtype", + rank, + True, + ) + _raise_if_mismatch( + tensor_properties.requires_grad, + local_shard_tensor.requires_grad, + "requires_grad", + rank, + True, + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata(shards_metadata) + + # check if the shards_metadata is compatible with overall size of the sharded tensor. + check_tensor(shards_metadata, list(sharded_tensor_metadata.size)) + + # done validation, add local_shards + sharded_tensor._local_shards = local_shards + sharded_tensor._prepare_init(process_group=process_group, init_rrefs=init_rrefs) + + # run post initialization, i.e. map registration, rpc initialization + sharded_tensor._post_init() + return sharded_tensor + + def sharding_spec(self) -> shard_spec.ShardingSpec: + """ + Returns the ShardingSpec for the tensor. + """ + return self._sharding_spec + + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def reshard(self, resharding_spec: shard_spec.ShardingSpec) -> ShardedTensor: + """ + Reshard a sharded tensor given the ``resharding_spec``. For now, we only support + single local shard. + + If ``resharding_spec`` is same as the original one, this becomes a no-op. + If only ``resharding_spec`` shares the same sharding dim with the original one, + we swap local shards directly. + For more generic cases, we merge different shards across different ranks and split + the local shards based on the ``resharding_spec`` via `all_to_all` collective API. + + Args: + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + + Returns: + A :class:`ShardedTensor` object whose local shards are resharded. + + Examples: + >>> # xdoctest: +SKIP + >>> # We have 2 process groups, 2 ranks. + >>> tensor = torch.arange(4, dtype=torch.int64) + 1 + 2 * rank + >>> tensor = torch.stack([tensor, tensor]) + >>> tensor + tensor([[1, 2, 3, 4], [1, 2, 3, 4]]) # Rank 0 + tensor([[3, 4, 5, 6], [3, 4, 5, 6]]) # Rank 1 + tensor([[5, 6, 7, 8], [5, 6, 7, 8]]) # Rank 2 + tensor([[7, 8, 9, 10], [7, 8, 9, 10]]) # Rank 3 + >>> sharding_dim = 0 + >>> spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> current_offsets = [0] * 2 + >>> current_offsets[0] = rank * 2 + >>> shard_metadata = ShardMetadata( + shard_offsets=copy.deepcopy(current_offsets), + shard_sizes=tensor.size(), + placement=spec.placements[rank], + ) + >>> local_shards = [ + Shard( + tensor=tensor, + metadata=shard_metadata, + ) + ] + >>> st = ShardedTensor._init_from_local_shards(local_shards, tensor.size()) + >>> sharding_dim = 1 + >>> resharding_spec = ChunkShardingSpec( + dim=sharding_dim, + placements=[ + "rank:0/cuda:0", + "rank:1/cuda:1", + "rank:2/cuda:2", + "rank:3/cuda:3", + ], + ) + >>> st.reshard(resharding_spec) + >>> tensor = st.local_shards()[0].tensor + >>> tensor + tensor([[1], [1], [3], [3], [5], [5], [7], [7]]) # Rank 0 + tensor([[2], [2], [4], [4], [6], [6], [8], [8]]) # Rank 1 + tensor([[3], [3], [5], [5], [7], [7], [9], [9]]) # Rank 2 + tensor([[4], [4], [6], [6], [8], [8], [10], [10]]) # Rank 3 + """ + if not isinstance( + resharding_spec, shard_spec.ChunkShardingSpec + ) or not isinstance(self._sharding_spec, shard_spec.ChunkShardingSpec): + raise NotImplementedError("Only ChunkShardingSpec supported for reshard.") + if len(self.local_shards()) != 1: + raise NotImplementedError("Only single local shard supported for reshard.") + + if self._sharding_spec.dim == resharding_spec.dim: # type: ignore[attr-defined] + if self._sharding_spec.placements == resharding_spec.placements: # type: ignore[attr-defined] + return self + else: + local_shards, shards_metadata = reshuffle_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + else: + local_shards, shards_metadata = reshard_local_shard( + self.local_tensor(), + self.size(), # type: ignore[arg-type] + self._sharding_spec, + resharding_spec, + self._process_group, + ) + self._local_shards = local_shards + self._metadata.shards_metadata = shards_metadata + self._sharding_spec = resharding_spec + return self + + def local_tensor(self) -> torch.Tensor: + """ + Return local tensor for a sharded_tensor. For now we only support single local shard. + + Returns: + A :class:`torch.Tensor` of the local shard. + """ + if len(self.local_shards()) != 1: + raise NotImplementedError("Only single local shard is supported.") + return self.local_shards()[0].tensor + + @classmethod + @deprecated(DEPRECATE_MSG, category=FutureWarning) + def __torch_function__(cls, func, types, args=(), kwargs=None): + def dispatch(st: ShardedTensor, func: Callable): + # Dispatch to custom user provided op first if it exists. + if func in _CUSTOM_SHARDED_OPS: + return _CUSTOM_SHARDED_OPS[func](types, args, kwargs, st._process_group) + + # Dispatch to custom sharding spec op if it has one. + if _has_custom_op(st._sharding_spec, func): + return _dispatch_custom_op( + st._sharding_spec, func, types, args, kwargs, st._process_group + ) + + if func in _SHARDED_OPS: + return _SHARDED_OPS[func](types, args, kwargs, st._process_group) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + # Find ShardedTensor instance to get process_group and sharding_spec. + st_instance = None + + def find_sharded_tensor(e): + nonlocal st_instance + if st_instance is None and isinstance(e, ShardedTensor): + st_instance = e + + pytree.tree_map_(find_sharded_tensor, args) + pytree.tree_map_(find_sharded_tensor, kwargs) + + if st_instance is not None: + return dispatch(st_instance, func) + + raise RuntimeError( + f"torch function '{func.__name__}', with args: {args} and " + f"kwargs: {kwargs} not supported for ShardedTensor!" + ) + + def is_pinned(self) -> bool: # type: ignore[override] + """ + Returns True if the sharded tensor (each local shard) resides in pinned memory. + """ + return self._metadata.tensor_properties.pin_memory + + def _register_remote_shards( + self, remote_shards: list[rpc.RRef[Shard]], rpc_rank: int + ): + self._remote_shards[rpc_rank] = remote_shards + + def remote_shards(self) -> dict[int, list[rpc.RRef[Shard]]]: + """ + Returns a Dict[int, RRef] with keys being the RPC rank and values + being RRefs to shards on that rank. Need to initialize the + RPC framework for this functionality. + + Raises an exception if ShardedTensor was created with ``init_rrefs=False`` + """ + if not self._init_rrefs: + raise RuntimeError( + "ShardedTensor created with init_rrefs=False, no RRefs to remote shards available" + ) + return self._remote_shards + + def __hash__(self): + return id(self) + + def __repr__(self) -> str: # type: ignore[override] + return f"ShardedTensor({self._metadata})" + + @dataclass + class ProcessGroupState: + """ + State for ser-de of process group + """ + + local_rank: int + global_rank: int + local_world_size: int + global_world_size: int + + def __getstate__(self): + pg_state = ShardedTensor.ProcessGroupState( + distributed_c10d.get_rank(self._process_group), + distributed_c10d.get_rank(), + distributed_c10d.get_world_size(self._process_group), + distributed_c10d.get_world_size(), + ) + + return ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) + + def __setstate__(self, state): + self._sharded_tensor_id = None + if not distributed_c10d.is_initialized(): + raise RuntimeError( + "Need to initialize default process group using " + '"init_process_group" before loading ShardedTensor' + ) + + ( + self._local_shards, + self._metadata, + pg_state, + self._sharding_spec, + self._init_rrefs, + ) = state + + # Setup process group + from torch.distributed._shard.api import _get_current_process_group + + self._process_group = _get_current_process_group() + + # Validate process group. + local_rank = distributed_c10d.get_rank(self._process_group) + if pg_state.local_rank != local_rank: + raise RuntimeError( + f"Local rank at save time was {pg_state.local_rank}, but at " + f"load time was {local_rank}" + ) + + global_rank = distributed_c10d.get_rank() + if pg_state.global_rank != global_rank: + raise RuntimeError( + f"Global rank at save time was {pg_state.global_rank}, but at " + f"load time was {global_rank}" + ) + + local_world_size = distributed_c10d.get_world_size(self._process_group) + if pg_state.local_world_size != local_world_size: + raise RuntimeError( + f"Local world size at save time was {pg_state.local_world_size}, " + f"but at load time was {local_world_size}" + ) + + global_world_size = distributed_c10d.get_world_size() + if pg_state.global_world_size != global_world_size: + raise RuntimeError( + f"Global world size at save time was {pg_state.global_world_size}, " + f"but at load time was {global_world_size}" + ) + + self._post_init() + + +def _create_tensor_from_params( + *size, local_device, tensor_properties: TensorProperties +): + """Helper to construct tensor from size, device and common params.""" + dtype = tensor_properties.dtype + layout = tensor_properties.layout + requires_grad = tensor_properties.requires_grad + memory_format = tensor_properties.memory_format + pin_memory = tensor_properties.pin_memory + + return torch.empty( + *size, + dtype=dtype, + layout=layout, + device=local_device, + requires_grad=requires_grad, + memory_format=memory_format, + pin_memory=pin_memory, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logger.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..a646b2dc35d1085e892b12388f46bf8607aacdab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logger.py @@ -0,0 +1,35 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +from torch.distributed._shard.sharded_tensor.logging_handlers import _log_handlers + + +__all__: list[str] = [] + + +def _get_or_create_logger() -> logging.Logger: + logging_handler, log_handler_name = _get_logging_handler() + logger = logging.getLogger(f"sharding-spec-{log_handler_name}") + logger.setLevel(logging.DEBUG) + formatter = logging.Formatter( + "%(asctime)s %(filename)s:%(lineno)s %(levelname)s p:%(processName)s t:%(threadName)s: %(message)s" + ) + logging_handler.setFormatter(formatter) + logger.propagate = False + logger.addHandler(logging_handler) + return logger + + +def _get_logging_handler( + destination: str = "default", +) -> tuple[logging.Handler, str]: + log_handler = _log_handlers[destination] + log_handler_name = type(log_handler).__name__ + return (log_handler, log_handler_name) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py new file mode 100644 index 0000000000000000000000000000000000000000..8d070b48bcf0c49a95d7e7d83ae54d693b744b5a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/logging_handlers.py @@ -0,0 +1,16 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + + +__all__: list[str] = [] + +_log_handlers: dict[str, logging.Handler] = { + "default": logging.NullHandler(), +} diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py new file mode 100644 index 0000000000000000000000000000000000000000..1ab06f44fc77904e70aecb114bfab3f81e61db4d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/metadata.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass, field +from enum import Enum + +import torch +from torch.distributed._shard.metadata import ShardMetadata + + +class MEM_FORMAT_ENCODING(Enum): + TORCH_CONTIGUOUS_FORMAT = 0 + TORCH_CHANNELS_LAST = 1 + TORCH_PRESERVE_FORMAT = 2 + + +@dataclass +class TensorProperties: + """Properties used to create :class:`Tensor`""" + + # Regular tensor fields + dtype: torch.dtype = field(default=torch.get_default_dtype()) + layout: torch.layout = field(default=torch.strided) + requires_grad: bool = False + memory_format: torch.memory_format = field(default=torch.contiguous_format) + pin_memory: bool = False + + def __getstate__(self): + # Since torch.memory_format cannot be pickled! + memory_format = self.memory_format + if memory_format == torch.contiguous_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT + elif memory_format == torch.channels_last: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST + elif memory_format == torch.preserve_format: + mem_format_encoding = MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT + else: + raise RuntimeError(f"Invalid torch.memory_format: {memory_format}") + + return ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) + + def __setstate__( + self, + state, + ): + ( + self.dtype, + self.layout, + self.requires_grad, + mem_format_encoding, + self.pin_memory, + ) = state + + if mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CONTIGUOUS_FORMAT: + memory_format = torch.contiguous_format + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_CHANNELS_LAST: + memory_format = torch.channels_last + elif mem_format_encoding == MEM_FORMAT_ENCODING.TORCH_PRESERVE_FORMAT: + memory_format = torch.preserve_format + else: + raise RuntimeError( + f"Invalid torch.memory_format encoding: {mem_format_encoding}" + ) + + self.memory_format = memory_format + + @staticmethod + def create_from_tensor(tensor: torch.Tensor) -> "TensorProperties": + return TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + + +@dataclass +class ShardedTensorMetadata: + """ + Represents metadata for :class:`ShardedTensor` + """ + + # Metadata about each shard of the Tensor + shards_metadata: list[ShardMetadata] = field(default_factory=list) + + # Size of each dim of the overall Tensor. + size: torch.Size = field(default=torch.Size([])) + + tensor_properties: TensorProperties = field(default_factory=TensorProperties) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py new file mode 100644 index 0000000000000000000000000000000000000000..7e101f2546be90cef6c447d37b077816965c3e36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/reshard.py @@ -0,0 +1,243 @@ +# mypy: allow-untyped-defs +import copy + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +from torch._C._distributed_c10d import ProcessGroup +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharding_spec._internals import ( + get_chunked_dim_size, + get_split_size, +) +from torch.distributed.nn.functional import all_to_all, all_to_all_single + +from .shard import Shard + + +def get_idx_from_placements(placements, current_rank) -> int: + """ + Return the position of the current rank in the given placements. + + Args: + placements(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`torch.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`torch.distributed._remote_device` + current_rank (int): number of current device. + + Returns: + A int which contains the position of current device in the placement list. + """ + for idx, placement in enumerate(placements): # type: ignore[attr-defined] + if current_rank == placement.rank(): # type: ignore[union-attr] + return idx + raise RuntimeError("current_rank not in the placement.") + + +def build_reshard_metadata( + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + world_size: int, +) -> tuple[list[ShardMetadata], list[int]]: + """ + Based the given sharding spec, we calculate the offset and local shard size. + We then build a ShardMetadata on top of the calculation result. + + Args: + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded. + world_size (int): number of ranks. + + Returns: + A Tuple of the followings: + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + A List[int] which contains the ranks in the order of placement. + """ + shard_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + shards_metadata = [None] * world_size + ranks = [] + offsets = [0] * len(st_size) + split_size = get_split_size(st_size[shard_dim], world_size) + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + ranks.append(placement.rank()) + sharded_dim_size = get_chunked_dim_size(st_size[shard_dim], split_size, idx) + local_tensor_size = list(st_size) + local_tensor_size[shard_dim] = sharded_dim_size + shards_metadata[placement.rank()] = ShardMetadata( # type: ignore[call-overload] + shard_offsets=copy.deepcopy(offsets), + shard_sizes=local_tensor_size, + placement=placement, + ) + offsets[shard_dim] += sharded_dim_size + return shards_metadata, ranks # type: ignore[return-value] + + +def reshuffle_local_shard( + local_shard: torch.Tensor, + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> tuple[list[Shard], list[ShardMetadata]]: + """ + Reshuffle the local shard directly when the reshard dim is same as the original + sharding dim. Logically we do this in two step: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending the local tensor to + the new shard directly based on the resharding spec. + + Args: + local_shard (Tensor): Local tensor stored in the current rank. + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + # Get input split size for all2all. + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + split_size = get_split_size(st_size[reshard_dim], world_size) + input_split_sizes = [0] * world_size + idx = get_idx_from_placements(sharding_spec.placements, current_rank) # type: ignore[attr-defined] + new_rank = resharding_spec.placements[idx].rank() # type: ignore[union-attr, attr-defined] + input_split_sizes[new_rank] = local_shard.size(reshard_dim) + # Get output split size for all2all. + output_split_sizes = [0] * world_size + new_idx = ranks.index(current_rank) + sharded_dim_size = get_chunked_dim_size(st_size[reshard_dim], split_size, new_idx) + output_split_sizes[new_rank] = sharded_dim_size + # Get gathered_input for all2all. + local_shard = local_shard.transpose(0, reshard_dim).contiguous() + gathered_input_size = list(local_shard.size()) + gathered_input_size[0] = sharded_dim_size + gathered_input = torch.empty( + gathered_input_size, device=local_shard.device, dtype=local_shard.dtype + ) + # all2all. + local_shard = all_to_all_single( + gathered_input, + local_shard, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes, + group=pg, + ) + local_tensor = local_shard.transpose(0, reshard_dim).contiguous() + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata + + +def reshard_local_shard( + local_tensor: torch.Tensor, + st_size: torch.Size, + sharding_spec: shard_spec.ShardingSpec, + resharding_spec: shard_spec.ShardingSpec, + pg: ProcessGroup, +) -> tuple[list[Shard], list[ShardMetadata]]: + """ + Reshard a sharded tensor given the ``resharding_spec``. When the reshard dim is + different from the original sharding dim, we need to do two steps logically: + 1. To collect all shards based on original sharding spec. + 2. Reshard the tensor based on the given resharding spec. + + In reality, we consolidate the two steps into one by sending each rank the new + shard based on the resharding spec. + + Args: + local_tensor (Tensor): Local tensor stored in the current rank. + st_size (torch.Size): The size of the sharded tensor. + sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor is sharded originally. + resharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The + specification describing how the tensor will be resharded. + pg (ProcessGroup): The process group to aggregate on. + + Returns: + A Tuple of the followings: + A List[`Shard`] which contains the local tensor and its metadata. + A List[`ShardMetadata`] which contains the metadata for the shard, including + offsets, lengths and device placement. + """ + current_rank = dist.get_rank(pg) + world_size = dist.get_world_size(pg) + current_sharding_dim = int(sharding_spec.dim) # type: ignore[attr-defined] + reshard_dim = int(resharding_spec.dim) # type: ignore[attr-defined] + + # Build shards_metadata first. + shards_metadata, ranks = build_reshard_metadata( + st_size, resharding_spec, world_size + ) + + # Compute expected size + input_split_sizes = [ + metadata.shard_sizes[reshard_dim] for metadata in shards_metadata + ] + rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1)) + + if rearrange_input: + # Need to re-arrange reshard_dim of local_tensor before all2all. + indices: list[int] = [] + for metadata in shards_metadata: + offset_start_idx = metadata.shard_offsets[reshard_dim] + split_size = metadata.shard_sizes[reshard_dim] + indices += range(offset_start_idx, offset_start_idx + split_size) + local_tensor = local_tensor.index_select( + reshard_dim, torch.tensor(indices, device=local_tensor.device) + ) + + # Because reshard_dim != original shard_dim. We need to compute the + # size of tensor from each rank. + output_tensor_list = [torch.tensor(1)] * world_size + split_size = get_split_size(st_size[current_sharding_dim], world_size) + rearrange_output_list = False + indices = [] + for idx, placement in enumerate(sharding_spec.placements): # type: ignore[attr-defined] + sharded_dim_size = get_chunked_dim_size( + st_size[current_sharding_dim], split_size, idx + ) + output_tensor_size = list(st_size) + output_tensor_size[current_sharding_dim] = sharded_dim_size + output_tensor_size[reshard_dim] = input_split_sizes[current_rank] + output_tensor_list[placement.rank()] = torch.empty( # type: ignore[union-attr, index] + output_tensor_size, device=local_tensor.device, dtype=local_tensor.dtype + ) + indices.append(placement.rank()) # type: ignore[union-attr, index, arg-type] + if idx != placement.rank(): # type: ignore[union-attr] + rearrange_output_list = True + + # Perform autograd enabled all2all. + input_tensor_tuple = torch.split(local_tensor, input_split_sizes, dim=reshard_dim) + input_tensor_list = [tensor.contiguous() for tensor in input_tensor_tuple] + output_tensor_list = all_to_all( + output_tensor_list, + input_tensor_list, + group=pg, + ) + + if rearrange_output_list: + # Need to re-arrange original shard_dim of output_tensor_list. + output_tensor_list = [output_tensor_list[idx] for idx in indices] # type: ignore[call-overload] + local_tensor = torch.cat(output_tensor_list, dim=current_sharding_dim) + local_shards = [Shard(local_tensor, shards_metadata[current_rank])] + return local_shards, shards_metadata diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/shard.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/shard.py new file mode 100644 index 0000000000000000000000000000000000000000..e4cb62659c38060133f31483fde7be355aa5cdb4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/shard.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass + +import torch +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed.remote_device import _remote_device + + +@dataclass +class Shard: + """ + Container which holds the data for a shard as a Tensor and also + the associated metadata for that shard. + + Args: + tensor(torch.Tensor): Local tensor for the shard. + metadata(:class `torch.distributed._shard.sharded_tensor.ShardMetadata`): + The metadata for the shard, including offsets, lengths and device placement. + """ + + __slots__ = ["tensor", "metadata"] + tensor: torch.Tensor + metadata: ShardMetadata + + def __post_init__(self) -> None: + # verification between local tensor and metadata + if list(self.tensor.size()) != self.metadata.shard_sizes: + raise ValueError( + "Shard tensor size does not match with metadata.shard_lengths! " + f"Found shard tensor size: {list(self.tensor.size())}, " + f"metadata.shard_lengths: {self.metadata.shard_sizes}, " + ) + placement_device = self.metadata.placement + if ( + placement_device is not None + and placement_device.device() != self.tensor.device + ): + raise ValueError( + f"Local shard tensor device does not match with local Shard's placement! " + f"Found local shard tensor device: {self.tensor.device}, " + f"local shard metadata placement device: {placement_device.device()}" + ) + + @classmethod + def from_tensor_and_offsets( + cls, tensor: torch.Tensor, shard_offsets: list[int], rank: int + ) -> "Shard": + """ + Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank. + + Args: + tensor(torch.Tensor): Local tensor for the shard. + shard_offsets(List[int]): List of integers specify the offset + of the shard on each dimension. + rank(int): Specify the rank for the shard. + """ + shard_sizes = list(tensor.size()) + placement = _remote_device(f"rank:{rank}/{str(tensor.device)}") + shard_meta = ShardMetadata( + shard_offsets=shard_offsets, shard_sizes=shard_sizes, placement=placement + ) + return Shard(tensor, shard_meta) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/utils.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1acfdae703fbd50b2e1bca92a9128062c8f4497e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharded_tensor/utils.py @@ -0,0 +1,323 @@ +# mypy: allow-untyped-defs +import collections.abc +import copy +import itertools +from collections.abc import Sequence +from typing import Optional, TYPE_CHECKING + +import torch +from torch.distributed import distributed_c10d as c10d, rpc +from torch.distributed._shard.sharding_spec._internals import ( + check_tensor, + validate_non_overlapping_shards_metadata, +) + +from .metadata import ShardedTensorMetadata, TensorProperties +from .shard import Shard + + +if TYPE_CHECKING: + from torch.distributed._shard.metadata import ShardMetadata + + +def _parse_and_validate_remote_device(pg, remote_device): + if remote_device is None: + raise ValueError("remote device is None") + + worker_name = remote_device.worker_name() + rank = remote_device.rank() + device = remote_device.device() + + # Validate rank, skip validation if rank is not part of process group. + if rank is not None and not c10d._rank_not_in_group(pg): + pg_global_ranks = c10d.get_process_group_ranks(pg) + if rank not in pg_global_ranks: + raise ValueError( + f"Global rank {rank} does not exist in input process group: {pg_global_ranks}" + ) + + if worker_name is not None: + if not rpc._is_current_rpc_agent_set(): + raise RuntimeError( + f"RPC framework needs to be initialized for using worker names: {worker_name}" + ) + + workers = rpc._get_current_rpc_agent().get_worker_infos() + for worker in workers: + if worker.name == worker_name: + return worker.id, device + + raise ValueError(f"Invalid worker name: {worker_name}") + + return rank, device + + +def _validate_output_tensor_for_gather( + my_rank: int, + dst_rank: int, + size: torch.Size, + dst_tensor: Optional[torch.Tensor], +) -> None: + if dst_rank == my_rank: + if dst_tensor is None: + raise ValueError( + f"Argument ``dst_tensor`` must be specified on destination rank {dst_rank}" + ) + if tuple(size) != (dst_tensor.size()): + raise ValueError( + f"Argument ``dst_tensor`` have size {tuple(dst_tensor.size())}," + f"but should be {tuple(size)}" + ) + elif dst_tensor: + raise ValueError( + "Argument ``dst_tensor`` must NOT be specified on non-destination ranks." + ) + + +def _flatten_tensor_size(size) -> torch.Size: + """ + Checks if tensor size is valid, then flatten/return a torch.Size object. + """ + if len(size) == 1 and isinstance(size[0], collections.abc.Sequence): + dims = list(*size) + else: + dims = list(size) + + for dim in dims: + if not isinstance(dim, int): + raise TypeError(f"size has to be a sequence of ints, found: {dims}") + + return torch.Size(dims) + + +def _raise_if_mismatch(expected, actual, prop_name, ranks, is_local=True): + if is_local: + assert isinstance(ranks, int) + if expected != actual: + raise ValueError( + f"Local shards' tensor {prop_name} property need to be the same on rank:{ranks}! " + f"Found one local shard tensor {prop_name}={expected}, " + f"the other local shard tensor {prop_name}={actual}." + ) + else: + # compare failure check across ranks, ranks list should have two rank + assert len(ranks) == 2 + if expected != actual: + raise ValueError( + f"ShardedTensor {prop_name} property does not match from different ranks! " + f"Found {prop_name}={expected} on rank:{ranks[0]}, " + f"and {prop_name}={actual} on rank:{ranks[1]}." + ) + + +def build_metadata_from_local_shards( + local_shards: list[Shard], + global_size: torch.Size, + current_rank: int, + pg: c10d.ProcessGroup, +) -> ShardedTensorMetadata: + assert len(local_shards) > 0, "must have local shards!" + local_shard_metadatas: list[ShardMetadata] = [] + + first_shard_dtype = local_shards[0].tensor.dtype + first_shard_layout = local_shards[0].tensor.layout + first_shard_requires_grad = local_shards[0].tensor.requires_grad + first_shard_is_pinned = local_shards[0].tensor.is_pinned() + + # 1). Validate local tensors and associated metadatas + for local_shard in local_shards: + local_shard_tensor = local_shard.tensor + local_shard_meta = local_shard.metadata + local_shard_metadatas.append(local_shard_meta) + rank, local_device = _parse_and_validate_remote_device( + pg, local_shard_meta.placement + ) + + if ( + local_shard_tensor.layout != torch.strided + or local_shard_tensor.layout != first_shard_layout + ): + raise ValueError( + f"Only torch.strided layout is currently supported, but found " + f"{local_shard_tensor.layout} on rank:{current_rank}!" + ) + + if not local_shard_tensor.is_contiguous(): + raise ValueError( + "Only torch.contiguous_format memory_format is currently supported!" + ) + + if rank != current_rank: + raise ValueError( + f"Local shard metadata's rank does not match with the rank in its process group! " + f"Found current rank in the process group: {current_rank}, " + f"local ShardMetadata placement's rank: {rank}" + ) + if local_shard_tensor.device != local_device: + raise ValueError( + f"Local shard tensor device does not match with local Shard's placement! " + f"Found local shard tensor device: {local_shard_tensor.device}, " + f"local shard metadata placement device: {local_device}" + ) + + _raise_if_mismatch( + local_shard_meta.shard_sizes, + list(local_shard_tensor.size()), + "size", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.is_pinned(), + first_shard_is_pinned, + "pin_memory", + current_rank, + ) + _raise_if_mismatch( + local_shard_tensor.dtype, first_shard_dtype, "dtype", current_rank + ) + _raise_if_mismatch( + local_shard_tensor.requires_grad, + first_shard_requires_grad, + "requires_grad", + current_rank, + ) + + # 2). Build a "local" ShardedTensorMetadata with all local shards on this rank, then + # do all_gather to collect local_sharded_tensor_metadata from all ranks + local_tensor_properties = TensorProperties( + dtype=first_shard_dtype, + layout=first_shard_layout, + requires_grad=first_shard_requires_grad, + memory_format=torch.contiguous_format, + pin_memory=first_shard_is_pinned, + ) + + local_sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=local_shard_metadatas, + size=global_size, + tensor_properties=local_tensor_properties, + ) + + return local_sharded_tensor_metadata + + +def build_global_metadata( + gathered_metadatas: Sequence[Optional[ShardedTensorMetadata]], + recalc_metadata: bool = False, +): + global_sharded_tensor_metadata = None + global_metadata_rank = 0 + + for rank, rank_metadata in enumerate(gathered_metadatas): + if rank_metadata is None: + continue + + if global_sharded_tensor_metadata is None: + global_sharded_tensor_metadata = copy.deepcopy(rank_metadata) + global_metadata_rank = rank + else: + _raise_if_mismatch( + global_sharded_tensor_metadata.size, + rank_metadata.size, + "global_size", + [global_metadata_rank, rank], + is_local=False, + ) + + # don't need to check layout and memory format as we already checked in local shards validation stage + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.dtype, + rank_metadata.tensor_properties.dtype, + "dtype", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.requires_grad, + rank_metadata.tensor_properties.requires_grad, + "requires_grad", + [global_metadata_rank, rank], + is_local=False, + ) + + _raise_if_mismatch( + global_sharded_tensor_metadata.tensor_properties.pin_memory, + rank_metadata.tensor_properties.pin_memory, + "pin_memory", + [global_metadata_rank, rank], + is_local=False, + ) + # pass all validations, extend shards metadata + global_sharded_tensor_metadata.shards_metadata.extend( + rank_metadata.shards_metadata + ) + + if global_sharded_tensor_metadata is not None: + if recalc_metadata: + recalc_global_sharded_tensor_metadata( + global_sharded_tensor_metadata, + 0, # sharded on 0th dim + ) + + # check if shards_metadata have overlap shards + validate_non_overlapping_shards_metadata( + global_sharded_tensor_metadata.shards_metadata + ) + + # check if the shards_metadata is compatible with global size of the sharded tensor. + check_tensor( + global_sharded_tensor_metadata.shards_metadata, + global_sharded_tensor_metadata.size, + ) + else: + raise ValueError("ShardedTensor have no local shards on all ranks!") + + return global_sharded_tensor_metadata + + +def recalc_global_sharded_tensor_metadata( + global_sharded_tensor_metadata: ShardedTensorMetadata, sharded_dim: int +) -> None: + # recalculate global ShardedTensorMetadata + + # reorder here in case shard metadata is not sorted on sharded_dim + placement_idx_pairs = [] + for i, shard_metadata in enumerate(global_sharded_tensor_metadata.shards_metadata): + if shard_metadata.placement: + placement_idx_pairs.append((shard_metadata.placement.rank(), i)) + else: + raise AssertionError( + "currently only support rw, it should always have valid rank info" + ) + sorted_idx = sorted(placement_idx_pairs) + shard_sizes = [ + global_sharded_tensor_metadata.shards_metadata[idx].shard_sizes[sharded_dim] + for _, idx in sorted_idx + ] + cum_sum = [0] + list(itertools.accumulate(shard_sizes)) + + for shard_id, shard_metadata in enumerate( + global_sharded_tensor_metadata.shards_metadata + ): + # update shard offset for each shard on the sharded dimension + shard_metadata.shard_offsets[sharded_dim] = cum_sum[shard_id] + for other_dim in range( + len(global_sharded_tensor_metadata.shards_metadata[0].shard_sizes) + ): + if other_dim != sharded_dim: + # shard offset for each shard on the unsharded dimension + shard_metadata.shard_offsets[other_dim] = 0 + + # update global size for ShardedTensorMetadata + global_size_list = [] + for other_dim in range( + len(global_sharded_tensor_metadata.shards_metadata[0].shard_sizes) + ): + if other_dim != sharded_dim: + global_size_list.append( + global_sharded_tensor_metadata.shards_metadata[0].shard_sizes[other_dim] + ) + else: + global_size_list.append(cum_sum[-1]) + global_sharded_tensor_metadata.size = torch.Size(global_size_list) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharder.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharder.py new file mode 100644 index 0000000000000000000000000000000000000000..236a54f843d400591583fd3c216f5b20b6c3663c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharder.py @@ -0,0 +1,29 @@ +import abc + +import torch.nn as nn + + +class Sharder(abc.ABC): + """ + This is an interface which allows user to create more advanced + sharding strategies that are not easily be composed by the + `ShardingSpec`. + + :class:`torch.distributed._shard.sharding_plan.ShardingPlan` could + take an object of the `Sharder` and call `shard` to shard the module, + then replace the original module with sharded module returned. + """ + + @abc.abstractmethod + def shard(self, module: nn.Module) -> nn.Module: + """ + Shard a module base on the implementation of this method, and + return the sharded version of the module. + + Args: + module (:class:`torch.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`torch.nn.Module` object that represents a module + that's already been sharded. + """ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3c8662fbac0e4127543b83570260b07eec63fe1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__init__.py @@ -0,0 +1 @@ +from .api import ShardingPlan, ShardingPlanner diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2f54a0e3fb7c48cb3f4ec5faf88c8eb41aec5701 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da9f6004edea824f10c00fd49a1923e20156f793 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/api.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/api.py new file mode 100644 index 0000000000000000000000000000000000000000..7c8d04e9a3b4d37edaa3cb52ee48691c7353702d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_plan/api.py @@ -0,0 +1,87 @@ +import abc +from dataclasses import dataclass +from typing import Optional, Union + +import torch.nn as nn +from torch.distributed._shard.sharder import Sharder +from torch.distributed._shard.sharding_spec import ShardingSpec + + +@dataclass +class ShardingPlan: + """ + Representation of a sharding plan, describes how to shard a module + across hosts. `plan` is used to shard module parameters according to the spec provided, + `output_plan` and `return_local_tensor` are optional, they are used to specify the output + layout of a module with a spec, and when to convert back to data parallel fashion. + + Args: + plan (Dict[str, Union[:class:`torch.distributed._shard.sharding_spec.ShardingSpec`, + :class:`torch.distributed._shard.sharder.Sharder`]): + a dict describes how to shard a module, there're currently two ways to shard a module: + 1. directly shard a module parameter by a `ShardingSpec`, keyed by the name of + a parameter to a `ShardingSpec`. + 2. shard a submodule by applying a `Sharder` on it, keyed by the name of a module + to a `Sharder` object. + output_plan (Dict[str, :class:`torch.distributed._shard.sharding_spec.ShardingSpec`), optional): + a dict specifies the layout of a module's output which produces a ShardedTensor, + keyed by the name of module to ShardingSpec("" in key means the root module). + Default: `None` + return_local_tensor (List[str], optional): a list of string, each element enables + a module's sharded output to be returned as a Tensor from its local shards to + ensure further processing in a data parallel fashion. ("" in list means the + root module). + Default: None + Example: + Suppose we want to shard a module with two linear layers and then run it with DDP, we also + want to convert the output of the second linear layer back to DDP, we can do it as follows: + + >>> # xdoctest: +REQUIRES(module:torch._C._distributed_c10d) + >>> class MyModule(nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.fc1 = nn.Linear() + >>> self.gelu = nn.GELU() + >>> self.fc2 = nn.Linear() + >>> self.relu = nn.Linear() + >>> + >>> def forward(self, input): + >>> return self.relu(self.fc2(self.gelu(self.fc1(input)))) + + + >>> # xdoctest: +SKIP("Undefined spec1, spec2) + >>> sharding_plan = ShardingPlan( + >>> plan={ + >>> "fc1.weight": spec1, + >>> "fc2.weight": spec2 + >>> }, + >>> output_plan={ + >>> "fc2": output_spec + >>> }, + >>> return_local_tensor=["fc2"] + >>> ) + """ + + plan: dict[str, Union[ShardingSpec, Sharder]] + output_plan: Optional[dict[str, ShardingSpec]] = None + return_local_tensor: Optional[list[str]] = None + + +class ShardingPlanner(abc.ABC): + """ + Default ShardingPlanner interface, can be extended and + implement advanced sharding strategies. + """ + + @abc.abstractmethod + def build_plan(self, module: nn.Module) -> ShardingPlan: + """ + Given a nn.Module, define how to shard the module across + ranks, return a ShardingPlan + Args: + module (:class:`torch.nn.Module`): + The module to apply sharding to. + Returns: + A :class:`torch.distributed._shard.sharding_plan.ShardingPlan` object that + represents how to shard the module. + """ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ad0f25d552dd40a52190c390f002014bde7d88f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__init__.py @@ -0,0 +1,10 @@ +from torch.distributed._shard.metadata import ShardMetadata + +from .api import ( + _infer_sharding_spec_from_shards_metadata, + DevicePlacementSpec, + EnumerableShardingSpec, + PlacementSpec, + ShardingSpec, +) +from .chunk_sharding_spec import ChunkShardingSpec as ChunkShardingSpec diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..54761b6820a92be579ec29d616d853c6e2476ec3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82e59a6442cf92d3953c1d44c24b7e0852ccf7cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/_internals.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc410cd3806fedb60c88df4b6c0881fdc0a305d1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62d8de47391106a54a8fce9186ef98b79bf0ebbb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/__pycache__/chunk_sharding_spec.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/_internals.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/_internals.py new file mode 100644 index 0000000000000000000000000000000000000000..4a6ee27eb926a1eb4de3f6c1cfb051748e1dfeb4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/_internals.py @@ -0,0 +1,228 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional + +from torch.distributed._shard.metadata import ShardMetadata + + +def _check_shard_metadata_pair_overlap(shard1: ShardMetadata, shard2: ShardMetadata): + """ + Checks if two shards overlap. + """ + + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + ndims = len(shard1.shard_offsets) + for i in range(ndims): + if shard1.shard_offsets[i] >= shard2.shard_offsets[i] + shard2.shard_sizes[i]: + return False + if shard2.shard_offsets[i] >= shard1.shard_offsets[i] + shard1.shard_sizes[i]: + return False + + return True + + +def _find_nd_overlapping_shards( + shards: list[ShardMetadata], sharded_dims: list[int] +) -> Optional[tuple[int, int]]: + # Each rank has len(sharded_dims) tuples. Each tuple represent the + # [begin, end] (inclusive) pair of that dimension. + shard_intervals = [ + [ + (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1) + for dim in sharded_dims + ] + for s in shards + ] + + for i in range(len(shards)): + shard_i = shard_intervals[i] + for j in range(i + 1, len(shards)): + shard_j = shard_intervals[j] + # For each dim of each shard, check if one shard resides on the other + # end of second shard with respect to that dim. As an example for a 2D + # shard, we would check if one shard is above or on the left of the + # other shard. + overlap = True + for interval_i, interval_j in zip(shard_i, shard_j): + if interval_i[0] > interval_j[1] or interval_j[0] > interval_i[1]: + overlap = False + break + if overlap: + return (i, j) + return None + + +def _find_1d_overlapping_shards( + shards: list[ShardMetadata], dim: int +) -> Optional[tuple[int, int]]: + # (begin, end, index_in_shards). Begin and end are inclusive. + intervals = [ + (s.shard_offsets[dim], s.shard_offsets[dim] + s.shard_sizes[dim] - 1, i) + for i, s in enumerate(shards) + ] + intervals.sort() + for i in range(len(shards) - 1): + if intervals[i][1] >= intervals[i + 1][0]: + return (intervals[i][2], intervals[i + 1][2]) + return None + + +def validate_non_overlapping_shards_metadata(shards: list[ShardMetadata]): + """ + Ensures none of the shards overlap with each other. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. + Raises: + ``ValueError`` if there's overlap in any two shards. + """ + if not shards or len(shards) == 1: + return + + sharded_dims: list[int] = [] + for dim in range(len(shards[0].shard_offsets)): + for i in range(1, len(shards)): + if ( + shards[i].shard_offsets[dim] != shards[0].shard_offsets[dim] + or shards[i].shard_sizes[dim] != shards[0].shard_sizes[dim] + ): + sharded_dims.append(dim) + break + + pair: Optional[tuple[int, int]] = None + if len(sharded_dims) == 0: + # if shard is all zeros, we should consider as pass + all_zeros: bool = all( + # strictly limited all offsets to be 0 to pass + # could loose it later on + shard.shard_offsets == [0] * len(shards[0].shard_offsets) + and math.prod(shard.shard_sizes) == 0 # one dimension is 0 + for shard in shards + ) + if all_zeros: + return + # All shards are the same, all dims are not partitioned. Choose any 2. + pair = (0, 1) + elif len(sharded_dims) == 1: + # Shards are partitioned over only one dimension. Overlap can be found + # using a O(nlogn) overlapping interval algorithm. + pair = _find_1d_overlapping_shards(shards, sharded_dims[0]) + else: + # Shards are partitioned over more than one dimension. Fall back to + # pair-wise check. Even though O(nlogn) algorithms (line sweep) exist + # for 2D overlap, the implementation is not trivial and may not justify + # the time saving in most cases. + pair = _find_nd_overlapping_shards(shards, sharded_dims) + + if pair: + raise ValueError(f"Shards {shards[pair[0]]} and {shards[pair[1]]} overlap") + + +def check_tensor(shards_metadata, tensor_dims) -> None: + """ + Checks if the shards_metadata is compatible with the provided tensor dims. + + Args: + shards_metadata(List[ShardMetadata]): List of :class:`ShardMetadata` + objects representing each shard of the tensor. + tensor_dims(Sequence of int): Dimensions of tensor to verify + Raises: + ``ValueError`` if not compatible. + """ + + # If the tensor's volume matches the total volume of all shards and + # all shard boundaries are within tensor dims, we have a compatible + # sharding spec for this tensor. Note that we have already verified + # we don't have overlapping shards. + tensor_rank = len(tensor_dims) + shards_rank = len(shards_metadata[0].shard_offsets) + if tensor_rank != shards_rank: + raise ValueError( + f"Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}" + ) + + total_shard_volume = 0 + for shard in shards_metadata: + shard_volume = 1 + for i, shard_length in enumerate(shard.shard_sizes): + shard_volume *= shard_length + if shard.shard_offsets[i] + shard.shard_sizes[i] > tensor_dims[i]: + raise ValueError( + f"Shard offset {shard.shard_offsets[i]} and length " + f"{shard.shard_sizes[i]} exceeds tensor dim: {tensor_dims[i]} for shard {shard}" + ) + total_shard_volume += shard_volume + + tensor_volume = 1 + for size in tensor_dims: + tensor_volume *= size + + if total_shard_volume != tensor_volume: + # TODO: Can we improve this error message to point out the gaps? + raise ValueError( + f"Total volume of shards: {total_shard_volume} " + f"does not match tensor volume: {tensor_volume}, in other words " + f"all the individual shards do not cover the entire tensor" + ) + + +def get_split_size(dim_size, chunks): + """ + Computes the split size inline with ``torch.chunk`` + + Args: + dim_size(int): Size of the dimension being chunked. + chunks(int): Number of chunks to create for ``dim_size``. + + Returns: + An int indicating the split size to use. + """ + return (dim_size + chunks - 1) // chunks + + +def get_chunked_dim_size(dim_size, split_size, idx): + """ + Computes the dim size of the chunk for provided ``idx`` given ``dim_size`` + and ``split_size``. + + Args: + dim_size(int): Size of the dimension being chunked. + split_size(int): The chunk size for each chunk of ``dim_size``. + idx(int): The index of chunk whose dim size is being requested. + + Returns: + An int indicating the dim size of the chunk. + """ + return max(min(dim_size, split_size * (idx + 1)) - split_size * idx, 0) + + +def get_chunk_sharding_params(sharding_dim_size, world_size, spec, rank): + """ + Generate the start pos and offset length for the current rank for + chunk sharding. + + Args: + sharding_dim_size(int): The dimension length which we shard on. + world_size(int): number of ranks. + spec (:class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec`): + sharding spec. + rank(int): # of cuda process. + + Returns: + start_pos(int): start position of sharded tensor on the given rank. + chunk_size(int): chunk size of sharded tensor on the given rank. + """ + split_size = get_split_size(sharding_dim_size, world_size) + current_offsets = 0 + start_pos = current_offsets + for idx, placement in enumerate(spec.placements): + chunk_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + if rank == placement.rank(): + start_pos = current_offsets + break + current_offsets += chunk_size + return start_pos, chunk_size # type: ignore[possibly-undefined] diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/api.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a2505e15ca985729dfa8e7c387123359f8755aea --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/api.py @@ -0,0 +1,263 @@ +# mypy: allow-untyped-defs +import functools +import operator +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, TYPE_CHECKING + +import torch +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.op_registry_utils import _decorator_func + +from ._internals import ( + check_tensor, + get_chunked_dim_size, + get_split_size, + validate_non_overlapping_shards_metadata, +) + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +class PlacementSpec(ABC): # noqa: B024 + """ + Base class representing the placement of an entity. Subclasses of this + class can be used to specify customized placements which might not be + covered by existing APIs. + """ + + +@dataclass +class DevicePlacementSpec(PlacementSpec): + """ + Associates placement of an entity with a single device. + + Args: + device(:class:`torch.distributed._remote_device`): The device to place the entity on. + """ + + device: torch.distributed._remote_device + + def __post_init__(self): + if not isinstance(self.device, torch.distributed._remote_device): + self.device = torch.distributed._remote_device(self.device) + + +class ShardingSpec(ABC): + """ + Base class representing sharding specifications. + """ + + @abstractmethod + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + """ + Given a global tensor size, define how to shard a tensor like this shape + across ranks, return ShardedTensorMetadata + Args: + tensor_sizes (:class:`torch.Size`): + The tensor shape to shard on, a `torch.Size` object that represents the + tensor shape to be sharded according to the ShardingSpec. + tensor_properties(:class:`torch.distributed._shard.sharded_tensor.TensorProperties): + Tensor properties used to create a ShardedTensor. + Returns: + A :class:`ShardedTensorMetadata` object that encodes the information about + the layout of the ShardedTensor and its properties. + """ + + @abstractmethod + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Given a global tensor on src_rank, shard this tensor + across ranks within the process group, return a ShardedTensor. + Args: + tensor (:class:`torch.Tensor`): Tensor needs to be sharded. + Keyword args: + src_rank (int, optional): The source rank which is used as the ground truth of + the data for the parameter that would be sharded and scattered + across the rest of the ranks. + Default: 0. + process_group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + Returns: + A :class:`ShardedTensor` sharded from the given tensor. + """ + + +# Ops customized for a particular ShardingSpec. +_CUSTOM_SHARDING_SPEC_OPS: dict[str, dict[Callable, Callable]] = {} + + +def _has_custom_op(sharding_spec, op): + """ + Returns whether or not the ShardingSpec has a custom op implementation. + """ + class_name = type(sharding_spec).__qualname__ + return ( + class_name in _CUSTOM_SHARDING_SPEC_OPS + and op in _CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +def _dispatch_custom_op( + sharding_spec, op: Callable, types, args, kwargs, process_group +): + """ + Calls the custom op for this ShardingSpec if it exists. + """ + class_name = type(sharding_spec).__qualname__ + if not _has_custom_op(sharding_spec, op): + raise RuntimeError(f"Custom op: {op} not registered for {class_name}") + func = _CUSTOM_SHARDING_SPEC_OPS[class_name][op] + return func(types, args, kwargs, process_group) + + +def custom_sharding_spec_op(sharding_spec_class, func): + """ + Decorator to allow custom registration of ops. + Args: + sharding_spec_class(type): The ShardingSpec for which we need to add this custom op. + func(Callable): The op to override (ex: torch.bmm) + """ + class_name = sharding_spec_class.__qualname__ + if class_name not in _CUSTOM_SHARDING_SPEC_OPS: + _CUSTOM_SHARDING_SPEC_OPS[class_name] = {} + return functools.partial( + _decorator_func, op=func, op_table=_CUSTOM_SHARDING_SPEC_OPS[class_name] + ) + + +@dataclass +class EnumerableShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that allows users to specify a generic + sharding scheme by enumerating exactly how each shard is laid out. + + Args: + shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing + each shard. Note that none of the shards should overlap. + """ + + shards: list[ShardMetadata] + + def __post_init__(self): + if len(self.shards) == 0: + raise ValueError(f"Empty shard list provided: {self.shards}") + + # Validate each shard has same rank. + rank = -1 + for shard in self.shards: + if rank != -1 and rank != len(shard.shard_offsets): + raise ValueError( + f"Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}" + ) + rank = len(shard.shard_offsets) + + validate_non_overlapping_shards_metadata(self.shards) + + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + # check if shards form a valid tensor + check_tensor(self.shards, tensor_sizes) + return sharded_tensor_meta.ShardedTensorMetadata( + self.shards, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + # TODO: figure out a generic and efficient way to scatter the shards for EnumerableShardingSpec + raise NotImplementedError("EnumerableShardingSpec.shard not implemented yet!") + + +def _infer_sharding_spec_from_shards_metadata(shards_metadata): + """ + Infer the sharding spec from the metadata of each shard of a ShardedTensor. + If the tensor is sharded only on one dimension, we can then verify whether it's + a ChunkShardingSpec or not. The way to verify it is to first get the total length + and perform a chunk sharding with the given placements to see if we can have the + same chunk size as the given shards_metadata. If not, we assume it's enum sharded. + + Args: + shards_metadata (List[ShardMetadata]): List of Metadata of local shards. + + Returns: + A :class:`torch.distributed._shard.sharding_spec.ShardingSpec` object of sharding + spec for one sharded tensor. + """ + placements = [] + chunk_sharding_dim = None + chunk_offset_list = [] + shard_size_list = [] + shard_offset_list = [] + # collect local shard metadatas from the global sharded_tensor_metadata + for shard_metadata in shards_metadata: # type: ignore[attr-defined] + placements.append(shard_metadata.placement) + local_offsets = shard_metadata.shard_offsets + chunk_offset_list.append(sum(local_offsets)) + shard_size_list.append(shard_metadata.shard_sizes) + shard_offset_list.append(shard_metadata.shard_offsets) + shard_dims = [idx for idx, e in enumerate(local_offsets) if e != 0] + # If the offset is [0, 0, ..., 0] (all zeros), + # we cannot decide whether how the tensor is sharded. + if len(shard_dims) == 0: + continue + # If the offset is [0, N, .,0, M, 0, .., 0], + # we are sure it's sharded by more than one dimension. + if len(shard_dims) != 1: + chunk_sharding_dim = None + break + # If the offset is [0, 0, .,0, M, 0, .., 0], aka, it's sharded by just + # one dimension, we need to make sure all ranks share the same dimension. + if not chunk_sharding_dim: + chunk_sharding_dim = shard_dims[0] + elif chunk_sharding_dim != shard_dims[0]: + chunk_sharding_dim = None + break + + if chunk_sharding_dim is not None: + # Ensure we infer the correct placement order from offsets + placements = [ + x + for _, x in sorted( + zip(chunk_offset_list, placements), key=operator.itemgetter(0) + ) + ] + + from .chunk_sharding_spec import ChunkShardingSpec + + chunk_spec = ChunkShardingSpec( + dim=chunk_sharding_dim, + placements=placements, + ) + + shard_sizes = sorted([x[chunk_sharding_dim] for x in shard_size_list]) + shard_total_length = sum(shard_sizes) + shard_offsets = sorted([x[chunk_sharding_dim] for x in shard_offset_list]) + + chunks = len(placements) + split_size = get_split_size(shard_total_length, chunks) + chunk_shard_sizes = sorted( + [ + get_chunked_dim_size(shard_total_length, split_size, idx) + for idx in range(chunks) + ] + ) + # Should match ChunkShardingSpec offsets calculation + chunk_shard_offsets = [split_size * idx for idx in range(chunks)] + if shard_sizes == chunk_shard_sizes and shard_offsets == chunk_shard_offsets: + return chunk_spec + return EnumerableShardingSpec(shards_metadata) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..5f14ed06962a6328dbb6e24081bf60c8b45cebcb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -0,0 +1,228 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import cast, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta +import torch.distributed.distributed_c10d as distributed_c10d +from torch.distributed._shard._utils import narrow_tensor +from torch.distributed._shard.metadata import ShardMetadata +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed._shard.sharded_tensor.utils import ( + _parse_and_validate_remote_device, +) + +from ._internals import get_chunked_dim_size, get_split_size +from .api import ShardingSpec + + +if TYPE_CHECKING: + # Only include ShardedTensor when do type checking, exclude it + # from run-time to resolve circular dependency. + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +@dataclass +class ChunkShardingSpec(ShardingSpec): + """ + This is a type of PlacementSpec that defines the placement as being sharded + across multiple devices. In particular, it represents sharding a Tensor + along a single dimension into equal chunks (similar to :meth:`torch.chunk`). + + The semantics of how a tensor is partitioned is inline with + :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the + specified ``dim`` and ``chunks`` in torch.chunk is the number of elements + in the placement specified. + + Args: + dim (int or str): + The dimension to shard on, could be an integer representing the + dimension or a string in case of named tensors where dimensions are + named. Note that named tensor support is not added yet. + placement(List[Union[_remote_device, str]]): + Specifies the placement of each shard of the Tensor. The size of + the list represents the number of shards to be created. This could + be a list of + :class:`torch.distributed._remote_device`'s. This list + could also contain a string which represents remote + device as accepted by + :class:`torch.distributed._remote_device` + """ + + ShardingDim = Union[int, str] + + dim: ShardingDim + placements: list[Union[torch.distributed._remote_device, str]] + + def __post_init__(self): + self._verify_dim(self.dim) + for i, remote_device in enumerate(self.placements): + if not isinstance(remote_device, torch.distributed._remote_device): + self.placements[i] = torch.distributed._remote_device(remote_device) + + @staticmethod + def _verify_dim(dim): + # Validate the sharding spec. + # TODO: support named dimension + if isinstance(dim, str): + raise NotImplementedError( + "ChunkShardingSpec does not support named dimension yet!" + ) + + if not isinstance(dim, int): + raise ValueError(f"Sharding dim needs to be an integer, found: {dim}") + + def build_metadata( + self, + tensor_sizes: torch.Size, + tensor_properties: sharded_tensor_meta.TensorProperties, + ) -> sharded_tensor_meta.ShardedTensorMetadata: + tensor_num_dim = len(tensor_sizes) + + self._verify_dim(self.dim) + if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim: # type: ignore[operator] + raise ValueError(f"Invalid sharding dim: {self.dim}") + + shards_metadata = [] + sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + for idx, placement in enumerate(self.placements): + # generate ShardMetadata for each placement device + chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx) + shard_size = list(tensor_sizes) + current_offsets = [0] * tensor_num_dim + current_offsets[self.dim] = split_size * idx # type: ignore[index] + shard_size[self.dim] = chunked_dim_size # type: ignore[index] + + shard_metadata = ShardMetadata( + shard_offsets=current_offsets, + shard_sizes=shard_size, + placement=placement, + ) + shards_metadata.append(shard_metadata) + + return sharded_tensor_meta.ShardedTensorMetadata( + shards_metadata, tensor_sizes, tensor_properties + ) + + def shard( + self, tensor: torch.Tensor, src_rank: int = 0, process_group=None + ) -> "ShardedTensor": + """ + Args: + src_rank: group rank relative to ``process_group`` + + N.B. If ``process_group`` is None, ``src_rank`` is a global rank. + """ + # relative imports to avoid circular dependency + from torch.distributed._shard.sharded_tensor import ShardedTensor + + tensor_properties = sharded_tensor_meta.TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=tensor.requires_grad, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ) + current_rank = dist.get_rank(process_group) + current_global_rank = dist.get_rank() + tensor_meta = self.build_metadata(tensor.size(), tensor_properties) + local_shards = [] + local_tensor = None + local_metadata = None + + tensors_to_scatter = cast( + list[Optional[torch.Tensor]], + [None] * dist.get_world_size(process_group), + ) + + sharding_dim_size = tensor.size()[self.dim] # type: ignore[index] + chunks = len(self.placements) + split_size = get_split_size(sharding_dim_size, chunks) + scatter_shape = list(tensor.size()) + scatter_shape[self.dim] = split_size # type: ignore[index] + + for shard_meta in tensor_meta.shards_metadata: + remote_global_rank, device = _parse_and_validate_remote_device( + process_group, shard_meta.placement + ) + if current_rank == src_rank: + # Reshape to get shard for this rank and we don't want autograd + # recording here for the narrow op and 'local_shard' should be a + # leaf variable in the autograd graph. + narrowed_tensor = narrow_tensor(tensor, shard_meta) + if shard_meta.shard_sizes[self.dim] < split_size: # type: ignore[index] + # for the last shard that might be smaller to other shards + # resize the narrowed tensor to the same size and use it for + # the scatter collective as dist.scatter requires same size + # inputs on every rank + tensor_to_scatter = ( + narrowed_tensor.detach().clone().resize_(scatter_shape) + ) + else: + tensor_to_scatter = narrowed_tensor.detach().clone( + memory_format=torch.contiguous_format + ) + + tensors_to_scatter[ + dist.get_group_rank(process_group, remote_global_rank) + ] = tensor_to_scatter + + if current_global_rank == remote_global_rank: + local_tensor = torch.empty( + scatter_shape, + dtype=tensor.dtype, + layout=tensor.layout, + device=device, + ) + local_metadata = shard_meta + + # each rank should have local_tensor and local_metadata initialized if we build + # the metadata list in a correct way. + assert local_tensor is not None + assert local_metadata is not None + + # Scatter the shards to all ranks in the pg + # scatter takes the global rank as ``src`` + src_for_scatter = src_rank + if ( + process_group is not None + and process_group is not distributed_c10d._get_default_group() + ): + src_for_scatter = distributed_c10d.get_global_rank( + process_group, src_for_scatter + ) + + tensors_to_scatter_: Optional[list[torch.Tensor]] = None + if current_rank == src_rank: + tensors_to_scatter_ = [] + for t in tensors_to_scatter: + assert isinstance(t, torch.Tensor) + tensors_to_scatter_.append(t) + + dist.scatter( + local_tensor, + scatter_list=tensors_to_scatter_, + src=src_for_scatter, + group=process_group, + ) + + if list(local_tensor.size()) != local_metadata.shard_sizes: + # detach again after receiving to ensure local shards remain a leaf node + local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach() + + # Sync requires_grad to local_shard. + local_tensor.requires_grad = tensor.requires_grad + + local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata)) + + st = ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, tensor_meta, process_group=process_group + ) + + # Manually set sharding_spec + st._sharding_spec = self + + return st diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e3252ae8cd811d2502f1f52fdf69e51acd46e1a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6401ea5697cd01d3af2e9e510c8ef37d05549ac8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3c1a1da4c51a469870391d289597c748ac55630 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af216611c8d03788fd277f175202f6077fe0e96 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__pycache__/embedding_bag.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py new file mode 100644 index 0000000000000000000000000000000000000000..c0216c22a1a3c4ebac96c815921023f1e0d6f48d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py @@ -0,0 +1,348 @@ +# mypy: allow-untyped-defs + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharded_tensor._ops._common import _sharded_op_common +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharding_spec._internals import ( + get_chunk_sharding_params, + get_chunked_dim_size, + get_split_size, +) +from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from torch.distributed.nn.functional import ( + _all_gather_base, + all_reduce, + all_to_all_single, +) + + +def _chunk_sharding_spec_check(spec, op): + """ + For the given op implementation check if the sharding spec is ChunkShardingSpec. + """ + if not isinstance(spec, ChunkShardingSpec): + raise NotImplementedError( + f"Only ChunkShardingSpec supported for '{op.__name__}'." + ) + + +def _register_sharded_op_on_local_tensor( + op, early_stop_func=None, extra_check=None, customized_func=None +): + """ + Handles ``__torch_function__`` dispatch for ops which are performed on + the single local tensor of the sharded tensor such as op like + ``torch.nn.functional.softmax`` or ``torch.Tensor.view``. + + For more complicated ops, a customized func can be used to generate + the new local tensor, sharding spec and sharded tensor size. + + Args: + op: The op to be registered and applied to all shards of the st. + early_stop_func (Callable, optional): the func for early stop. + Default: if ``None``, no early stop. + extra_check (Callable, optional): the func for extra condition check. + Default: if ``None``, no extra check. + customized_func (Callable, optional): the func for customized logic + to generate the new local tensor, sharding spec and sharded tensor size. + Default: if ``None``, we simply lower to the real op call with + the single local tensor of the st. + + Return: + func (Callable): registered implementation for sharded op for + ``__torch_function__`` dispatch. + """ + + @custom_sharding_spec_op(ChunkShardingSpec, op) + @_sharded_op_common(op, early_stop_func, extra_check) + def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None): + st = args[0] + sharding_spec = st.sharding_spec() + if len(st.local_shards()) != 1: + raise TypeError( + f"torch function '{op.__name__}', with args: {args} and " + f"kwargs: {kwargs} only supported for single local tensor!" + ) + st_size = st.size() + if customized_func: + local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg) + else: + args = (st.local_tensor(), *args[1:]) + local_tensor = op(*args, **kwargs) + return ShardedTensor._init_from_local_tensor( + local_tensor.contiguous(), + sharding_spec, + st_size, # type: ignore[arg-type] + process_group=pg, + init_rrefs=st._init_rrefs, + ) + + +def _handle_col_wise_sharding_base( + op_func, + col_dim, + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + mode=None, + gathered_per_sample_weights=None, + gathered_offsets=None, + padding_idx=None, +): + """ + For col-wise sharding of weight, lots of logic are common. + So we extract the common logic and put in this function: + Step 1. To get input from each rank and + Step 2. To perform the op on the concatenated tensor. + Step 3. To distribute results to each rank with col rearrangement. + Step 4. To concatenate all results from all ranks. + + Args: + op_func: operator which is applied to the input tensor. + col_dim: dim of result tensor after the operation. + input: tensor to be applied op on. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise sharded weight tensor. + pg: process group. + gathered_inputs: list of inputs from all ranks. If specified, we + don't need to communicate with each rank any more. + mode: aggregation mode of EmbeddingBag. + gathered_per_sample_weights: per_sample_weights across all ranks. + gathered_offsets: offsets across all ranks. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + + Return: final result of input being applied with the op. + """ + # run the operator's function for all the inputs. + results = [] + for i, inp in enumerate(gathered_inputs): + if op_func == torch.nn.functional.embedding_bag: + result = op_func( + inp, + local_shard, + offsets=gathered_offsets[i] if gathered_offsets is not None else None, + mode=mode, + per_sample_weights=gathered_per_sample_weights[i] + if gathered_per_sample_weights is not None + else None, + padding_idx=padding_idx, + ) + elif op_func == torch.nn.functional.embedding: + result = op_func( + inp, + local_shard, + padding_idx=padding_idx, + ) + else: + result = op_func(inp, local_shard) + results.append(torch.transpose(result, 0, col_dim)) + + # Distribute results to each rank with col rearrangement. + output = _result_distribute_with_col_rearrange( + results, input, world_size, weight, pg + ) + + # transpose the output and return result. + return torch.transpose(output, 0, col_dim) + + +def _result_distribute_with_col_rearrange(results, input, world_size, weight, pg): + """ + For col-wise sharding of weight, we need to distribute + results to each rank. We do them in this function. + Note that, if the index in the Sharding Spec is not equal to + the rank number, we need to do the rearrangement based on the + order given by the Sharding Spec (placement). + + Args: + results: results from ops applied to inputs from all ranks. + We need to distribute them back to their original ranks. + input: tensor to be applied op to. + world_size: number of ranks. + weight: sharded weight tensor. + pg: process group. + + Return: column rearranged result. + """ + # Process results and outputs for all2all. + sharding_dim = weight._sharding_spec.dim + sharding_dim_size = weight.size(sharding_dim) + dims = list(results[0].size()) + dims[0] = sharding_dim_size + combined_results = torch.cat(results) + output = torch.empty( + *dims, device=combined_results.device, dtype=combined_results.dtype + ) + + # Compute output splits + split_size = get_split_size(sharding_dim_size, world_size) + output_split_sizes = [0] * world_size + for idx, placement in enumerate(weight._sharding_spec.placements): + output_split_sizes[placement.rank()] = get_chunked_dim_size( + sharding_dim_size, split_size, idx + ) + + # distribute the outputs using all2all. + output = all_to_all_single( + output, combined_results, output_split_sizes=output_split_sizes, group=pg + ) + + # Check if we need to rearrange columns appropriately for output. + rearrange_columns = any( + idx != placement.rank() + for idx, placement in enumerate(weight._sharding_spec.placements) + ) + if not rearrange_columns: + return output + + indices = [] + for placement in weight._sharding_spec.placements: + dim_size = output_split_sizes[placement.rank()] + start = sum( + split_size if i < placement.rank() else 0 + for i, split_size in enumerate(output_split_sizes) + ) + indices += list(range(start, start + dim_size)) + + return output.index_select(0, torch.tensor(indices, device=output.device)) + + +def _handle_max_norm_col_wise( + max_norm, + norm_type, + local_shard, + input, + world_size, + gathered_inputs, + pg, +): + """ + For col-wise sharding of weight, we need to aggregate the + norm across all ranks before we can perform the proper re-norm. + Note that, the max_norm logic is only applied to the embedding + indices that are looked up and not the whole shard. + + Args: + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + local_shard: col-wise shared local weight used for lookup. + input: tensor to be applied op to. + world_size: number of ranks. + gathered_inputs: list of inputs from all ranks. + pg: process group. + + Return: + local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger + than it. + + """ + norm_type = norm_type if norm_type is not None else 2.0 + unique_inp = torch.unique(torch.cat(gathered_inputs)) + local_shard_sum = torch.sum( + torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype + ) + # For col-wise sharding, we need to first aggregate the powered sum + # from each rank first and then calculate the norm. + local_shard_sum = all_reduce(local_shard_sum, group=pg) + local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type) + max_norm_tensor = torch.full( + (local_shard.size(0),), + float("inf"), + dtype=local_shard.dtype, + device=input.device, + ) + max_norm_tensor[unique_inp] = max_norm + local_shard_t = local_shard.t().contiguous() + normalized_tensor = torch.where( + local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm + ) + # Make sure divisor is not zero. + local_shard_norm[local_shard_norm == 0.0] = 1.0 + local_shard_norm_renormed = ( + torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm) + .t() + .contiguous() + ) + return local_shard_norm_renormed + + +def _all_gather_base_input(input, pg): + """ + Use _all_gather_base to get a concatenated input from each rank. + + Args: + input: tensor to be applied op on. + pg: process group. + + Returns: + gathered_inputs: input gathered from each rank and concat by dim 0. + """ + # allgather the inputs first. + gather_inp_size = list(input.size()) + gather_inp_size[0] = input.size(0) * dist.get_world_size(pg) + gather_inp = torch.empty(gather_inp_size, device=input.device, dtype=input.dtype) + return _all_gather_base(gather_inp, input, group=pg) + + +def _handle_row_wise_mask(gather_inp, padding_idx, weight, world_size, rank): + """ + Mask the input for embedding look-up for IDs which are not stored + on the current rank. This function also adjust the ``padding_idx`` + so that it is only used on the rank where the corresponding row is + stored. + + Note that, with ``max_norm`` flag on, only weights of rows being + looked up will be re-normed. So we need an extra row for masked ID + so that it does not affect the final result and ``max_norm``. + + Args: + gather_inp: tensor to be applied op on gathered from all ranks. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + weight: weight tensor of Embedding look-up table. + world_size: number of ranks. + rank: # of cuda process. + + Returns: + lookup_input: Tensor of masked input. + padding_idx: adjusted padding_idx. + padding_row: The extra row we used during lookup so that + looking up does not affect ``max_norm``. + """ + (start_pos, chunk_size) = get_chunk_sharding_params( + weight.size(0), world_size, weight._sharding_spec, rank + ) + mask = (gather_inp < start_pos) | (gather_inp >= start_pos + chunk_size) + lookup_input = gather_inp.clone() - start_pos + lookup_input[mask] = chunk_size + if ( + padding_idx is not None + and padding_idx >= start_pos + and padding_idx < (start_pos + chunk_size) + ): + padding_idx = padding_idx - start_pos + else: + padding_idx = None + + # When max_norm is set, it will only re-norm the row being looked up. + padding_row = torch.zeros( + 1, weight.size(1), device=gather_inp.device, dtype=weight.dtype + ) + return lookup_input, padding_idx, padding_row diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..140d97c0fad8a802e61f7008da5855b1a446bcac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py @@ -0,0 +1,294 @@ +# mypy: allow-untyped-defs + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from torch.distributed.nn.functional import all_gather, reduce_scatter + +from ._common import ( + _all_gather_base_input, + _handle_col_wise_sharding_base, + _handle_max_norm_col_wise, + _handle_row_wise_mask, +) + + +@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding) +def sharded_embedding(types, args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``. + This method computes a sharded embedding lookup and has the following limitations: + + 1. Supports only sharding of ``weight``. + 2. Supports only ``ChunkShardingSpec``. + 3. Supports only a single local shard per rank. + 4. Supports all specs except for scale_grad_by_freq, sparse, etc. + + Based on the dimension that the weight is sharded on, there are two + algorithms: + + ROWWISE SHARDING + ================ + For row-wise sharding the weight is sharded on dimension 0. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (10 x 17) and W is sharded across + 4 GPUs creating 3 shard of (3 x 17) and 1 shard of (1 x 17). + The algorithm is as follows: + + 1. First the input is all gathered to all ranks, since this is SPMD and + input is actually sharded across all ranks. The inputs then become a + 4 (4 x 6) tensor on each rank. For example if the given input is + tensor([[6, 5, 2, 9, 6, 3], + [3, 1, 2, 4, 7, 6], + [4, 0, 4, 9, 8, 9], + [8, 6, 6, 4, 6, 1]]) + on rank 0. + Then on every rank, we will have this tensor. + If input itself is already replicated, no all-gather will be done. + 2. Next, we mask the ID which are not stored on that rank. + For example on rank 0, we store ID [0, 1, 2]. We only keep the ID + inside the set of numbers. The rest of them will be masked to an extra row. + The masked matrix will be used for embedding look up and is like: + tensor([[4, 4, 2, 4, 4, 4], + [4, 1, 2, 4, 4, 4], + [4, 0, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 1]]) + The reason of having an extra row (aka, number 4 in the example) is + because when max_norm is specified only weight which has looked will + be re-normed so mask IDs whose embeddings are not stored in current + rank will to an extra row will ensure max_norm still works as expected. + 3. If max_norm is specified, the extra row guarantees that the mask ID will + not affect the behavior of weigh re-norm. + + COLWISE SHARDING + ================ + For col-wise sharding the weight is sharded on dimension 1. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2). + The algorithm is as follows: + + 1. First the input is broadcasted to all ranks, since this is SPMD we + actually do an all_gather for all the inputs resulting in 4 (4 x 6) + inputs on each rank. + 2. Next we perform local embedding lookup operation by apply each + input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last). + This results in 4 (5 x 6 x 4) ((2 x 6 x 4) for the last) matrices + on each rank. We transpose dim 0 and dim 2. + 3. Next, we concat these 4 matrices and perform an all2all to share the + appropriate (5 x 6 x 4) or (2 x 6 x 4) matrices to each rank. + 4. Now, each rank receives a (17 x 6 x 4) matrix which is basically the + size of the result we need. + 5. If placements are not in order any appropriate rearrangement of columns + are done for the (17 x 6 x 4) matrix and finally we transpose the + dim 0 and dim 2 again. + 6. If max_norm is specified, we manually sum up the norm and renorm. Because + the renorm must be in place, we need to override the local_shard to mimic + this behavior. + """ + # Validate input params + _validate_embedding_param(args, kwargs) + + input = args[0] + weight = args[1] + max_norm = kwargs.get("max_norm") + norm_type = kwargs.get("norm_type") + padding_idx = kwargs.get("padding_idx") + + local_shard = weight.local_tensor().contiguous() + sharding_dim = weight._sharding_spec.dim + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + + if sharding_dim == 1: + output, local_shard = _handle_col_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg + ) + weight.local_shards()[0].tensor = local_shard + return output + elif sharding_dim == 0: + return _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + max_norm, + norm_type, + padding_idx, + rank, + pg, + ) + else: + raise RuntimeError( + f"nn.Embedding weight sharded on dim {sharding_dim} not supported!" + ) + + +def _validate_embedding_param(args, kwargs): + """ + Validate input params of sharded embedding op. + + Args: + input: list of ID used for lookup. + weight: sharded weight tensor. + kwargs: same as normal Embedding. + + Return: None. + """ + + input = args[0] + weight = args[1] + max_norm = kwargs.get("max_norm") + scale_grad_by_freq = kwargs.get("scale_grad_by_freq") + sparse = kwargs.get("sparse") + + # Validate types + if not isinstance(input, torch.Tensor): + raise TypeError("input need to be torch.Tensor") + if not isinstance(weight, ShardedTensor): + raise TypeError("weight needs to be ShardedTensor") + weight_size = weight.size() + if len(weight_size) != 2: + raise ValueError("Weight needs to have exactly 2 dims") + if int(torch.min(input).item()) < 0: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.min(input).item()), + weight_size[1], + ) + if int(torch.max(input).item()) >= weight_size[0]: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.max(input).item()), + weight_size[1], + ) + if scale_grad_by_freq: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!' + ) + if sparse: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "sparse" not supported!' + ) + if max_norm and max_norm <= 0.0: + raise ValueError('"max_norm" must be larger than zero!') + + if not isinstance(weight._sharding_spec, ChunkShardingSpec): + raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!") + if len(weight.local_shards()) != 1: + raise ValueError("Only one local shard supported!") + + +def _handle_col_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, pg +): + """ + Entry-point function to handle the logic of col-wise sharding of weight + for embedding. (Detailed explanations of the logic can be found in + the comment for sharded_embedding.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise shared local weight used for lookup. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + pg: process group. + + Returns: final result of lookup. + """ + # allgather the inputs first for non Replicated Tensor. + gathered_inputs = all_gather(input, group=pg) + + if max_norm is not None: + # max_norm changes the weight in-place + local_shard = _handle_max_norm_col_wise( + max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg + ) + + output = _handle_col_wise_sharding_base( + torch.nn.functional.embedding, + len(input.size()), + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + padding_idx=padding_idx, + ) + return (output, local_shard) + + +def _handle_row_wise_sharding( + input, world_size, weight, local_shard, max_norm, norm_type, padding_idx, rank, pg +): + """ + Entry-point function to handle the logic of row-wise sharding of weight + for embedding. (Detailed explanations of the logic can be found in + the comment for sharded_embedding.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: row-wise shared local weight used for lookup. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + rank: # of cuda process. + pg: process group. + + Returns: final result of lookup. + """ + # allgather the inputs first for non Replicated Tensor. + gather_inp = _all_gather_base_input(input, pg) + + # Mask the input according to sharding spec. + lookup_input, padding_idx, padding_row = _handle_row_wise_mask( + gather_inp, padding_idx, weight, world_size, rank + ) + + # When input is a large tensor, the value of weight is changed. + # This is a walk-around for now. GH issue: #81717 + if max_norm is not None: + torch.nn.functional.embedding( + torch.unique(lookup_input)[:-1], + local_shard, + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + ) + max_norm = None + + local_input_embeddings = torch.nn.functional.embedding( + lookup_input, + torch.cat([local_shard, padding_row]), + padding_idx=padding_idx, + max_norm=max_norm, + norm_type=norm_type, + ) + + # TODO: Make the result a PartialTensor. + local_shards = local_input_embeddings.chunk(pg.size()) + return reduce_scatter( + torch.empty_like(local_shards[0]), + list(local_shards), + group=pg, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py new file mode 100644 index 0000000000000000000000000000000000000000..57081a3bd3c56e431964e692bf9744f451900903 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -0,0 +1,477 @@ +# mypy: allow-untyped-defs + +from typing import cast + +import torch +import torch.distributed as dist +from torch._C._distributed_c10d import ReduceOp +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed._shard.sharding_spec import ChunkShardingSpec +from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op +from torch.distributed.nn.functional import all_gather, reduce_scatter + +from ._common import ( + _all_gather_base_input, + _handle_col_wise_sharding_base, + _handle_max_norm_col_wise, + _handle_row_wise_mask, +) + + +@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding_bag) +def sharded_embedding_bag(types, args, kwargs, pg): + """ + Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``. + This method computes a sharded embedding bag aggregation and has the following limitations: + + 1. Supports only sharding of ``weight``. + 2. Supports only ``ChunkShardingSpec``. + 3. Supports only a single local shard per rank. + 4. Supports all specs except for scale_grad_by_freq, sparse, etc. + + Based on the dimension that the weight is sharded on, there are two + algorithms: + + ROWWISE SHARDING + ================ + For row-wise sharding the weight is sharded on dimension 0. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 4 shard of (4 x 17). + The algorithm is as follows: + + 1. First the input is all gathered to all ranks, since this is SPMD and + input is actually sharded across all ranks. The inputs then become a + 4 (4 x 6) tensor on each rank. For example if the given input is + tensor([[6, 5, 2, 9, 6, 3], + [3, 1, 2, 4, 7, 6], + [4, 0, 4, 9, 8, 9], + [8, 6, 6, 4, 6, 1]]) + on rank 0. + Then on every rank, we will have this tensor. + If input itself is already replicated, no all-gather will be done. + 2. Next, we mask the ID which are not stored on that rank. + For example on rank 0, we store ID [0, 1, 2]. We only keep the ID + inside the set of numbers. The rest of them will be masked to an extra row. + The masked matrix will be used for embedding look up and is like: + tensor([[4, 4, 2, 4, 4, 4], + [4, 1, 2, 4, 4, 4], + [4, 0, 4, 4, 4, 4], + [4, 4, 4, 4, 4, 1]]) + 3. If ``max_norm`` is specified, the extra row guarantees that the mask ID will + not affect the behavior of weigh re-norm. + 4. The example above only happens in one rank and each rank does a very similar thing. + For "Mean" mode we need to divide by either column size (2D) or the interval length + defined by the offset (excluding the row specified in ``padding_idx``). + We also need to mask the unexisting row to neg Inf so that negative value does not + gets wiped out in the "Max" mode. + + COLWISE SHARDING + ================ + For col-wise sharding the weight is sharded on dimension 1. + + The overall algorithm can be best explained with an example. Let's assume + the dims for input are (4 x 6) and W are (16 x 17) and W is sharded across + 4 GPUs creating 3 shards of (16 x 5) and 1 shard of (16 x 2). + The algorithm is as follows: + + 1. First the input is broadcasted to all ranks, since this is SPMD we + actually do an all_gather for all the inputs resulting in 4 (4 x 6) + inputs on each rank. + 2. Next we perform local embedding bag operation under the given mode by + apply each input (4 x 6) with the local shard (16 x 5) ((16 x 2) for the last). + This results in 4 (5 x 4) ((2 x 4) for the last) matrices on each rank. + We transpose the aggregation result. + 3. Next, we concatenate these 4 matrices and perform an all2all to share the + appropriate (5 x 4) or (2 x 4) matrices to each rank. + 4. Now, each rank receives a (17 x 4) matrix which is basically the + size of the result we need. + 5. If placements are not in order any appropriate rearrangement of columns + are done for the (17 x 4) matrix and finally we transpose the output again. + 6. If max_norm is specified, we manually sum up the norm and renorm. Because + the renorm must be in place, we need to override the local_shard to mimic + this behavior. + """ + # Validate input params + _validate_embedding_bag_param(args, kwargs) + + input = args[0] + weight = args[1] + offsets = kwargs.get("offsets") + per_sample_weights = kwargs.get("per_sample_weights") + mode = kwargs.get("mode") + max_norm = kwargs.get("max_norm") + norm_type = kwargs.get("norm_type") + include_last_offset = kwargs.get("include_last_offset") + padding_idx = kwargs.get("padding_idx") + + local_shard = weight.local_tensor().contiguous() + sharding_dim = weight._sharding_spec.dim + world_size = dist.get_world_size(pg) + rank = dist.get_rank(pg) + if include_last_offset: + offsets = offsets[:-1] + + if sharding_dim == 1: + output, local_shard = _handle_col_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + pg, + ) + weight.local_shards()[0].tensor = local_shard + return output + elif sharding_dim == 0: + return _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + rank, + pg, + ) + else: + raise RuntimeError( + f"nn.EmbeddingBag weight sharded on dim {sharding_dim} not supported!" + ) + + +def _validate_embedding_bag_param(args, kwargs): + """ + Validate input params of sharded embeddingBag op. + + Args: + input: list of ID used for lookup and aggregation. + weight: sharded weight tensor. + kwargs: same as normal EmbeddingBag. + + Return: None. + """ + + input = args[0] + weight = args[1] + offsets = kwargs.get("offsets") + per_sample_weights = kwargs.get("per_sample_weights") + mode = kwargs.get("mode") + max_norm = kwargs.get("max_norm") + scale_grad_by_freq = kwargs.get("scale_grad_by_freq") + sparse = kwargs.get("sparse") + include_last_offset = kwargs.get("include_last_offset") + + # Validate types + if not isinstance(input, torch.Tensor): + raise TypeError("input need to be torch.Tensor") + if offsets is not None and not isinstance(offsets, torch.Tensor): + raise TypeError("offsets need to be torch.Tensor") + if per_sample_weights is not None and not isinstance( + per_sample_weights, torch.Tensor + ): + raise TypeError("per_sample_weights need to be torch.Tensor") + if not isinstance(weight, ShardedTensor): + raise TypeError("weight needs to be ShardedTensor") + if len(input.size()) > 2: + raise ValueError("Input more than 2 dims not supported") + weight_size = weight.size() + if len(weight_size) != 2: + raise ValueError("Weight needs to have exactly 2 dims") + if int(torch.min(input).item()) < 0: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.min(input).item()), + weight_size[1], + ) + if int(torch.max(input).item()) >= weight_size[0]: + raise ValueError( + "Index out of range in Input %d %d", + int(torch.max(input).item()), + weight_size[1], + ) + if offsets is not None and len(input.size()) != 1: + raise ValueError("Input dimension needs to be exactly 1 dim") + if len(input.size()) == 1 and offsets is None: + raise ValueError("offsets is required for 1D input") + if per_sample_weights is not None and per_sample_weights.size() != input.size(): + raise ValueError( + f"per_sample_weights size {per_sample_weights.size()} not equal to input size {input.size()}" + ) + if mode is None: + mode = "mean" + if mode not in ["sum", "mean", "max"]: + raise ValueError(f"mode '{mode}' is not supported") + if scale_grad_by_freq: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "scale_grad_by_freq" not supported!' + ) + if sparse: + raise RuntimeError( + 'nn.Embedding weight sharded with flag on "sparse" not supported!' + ) + if include_last_offset and offsets is None: + raise ValueError('offsets is required for flag "include_last_offset"!') + if include_last_offset and cast(list[int], offsets)[-1] != input.size(0): + raise ValueError( + 'offsets need to have the input size in the end when the flag "include_last_offset" is on!' + ) + + if max_norm and max_norm <= 0.0: + raise ValueError('"max_norm" must be larger than zero!') + + if not isinstance(weight._sharding_spec, ChunkShardingSpec): + raise ValueError("Only ChunkShardingSpec supported for ShardedTensor ops!") + if len(weight.local_shards()) != 1: + raise ValueError("Only one local shard supported!") + + +def _handle_col_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + pg, +): + """ + Entry-point function to handle the logic of col-wise sharding of weight + for embeddingBag. (Detailed explanations of the logic can be found in + the comment for sharded_embedding_bag.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: col-wise shared local weight used for lookup. + offsets: list of start positions of each bag for 1D input. + per_sample_weights: weights for weighted sum mode. + mode: aggregation method of each bag. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + pg: process group. + + Return: + output: final result of lookup and aggregation. + local_shard: col-wise shared local weight used for lookup. + If max_norm, this will be the renormed weight. + """ + # allgather the special input of embedding bag first. + ( + gathered_inputs, + gathered_per_sample_weights, + gathered_offsets, + ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg) + + if max_norm is not None: + # max_norm changes the weight in-place + local_shard = _handle_max_norm_col_wise( + max_norm, norm_type, local_shard, input, world_size, gathered_inputs, pg + ) + + output = _handle_col_wise_sharding_base( + torch.nn.functional.embedding_bag, + 1, + input, + world_size, + weight, + local_shard, + pg, + gathered_inputs, + mode=mode, + gathered_per_sample_weights=gathered_per_sample_weights, + gathered_offsets=gathered_offsets, + padding_idx=padding_idx, + ) + return (output, local_shard) + + +def _handle_row_wise_sharding( + input, + world_size, + weight, + local_shard, + offsets, + per_sample_weights, + mode, + max_norm, + norm_type, + padding_idx, + rank, + pg, +): + """ + Entry-point function to handle the logic of row-wise sharding of weight + for embeddingBag. (Detailed explanations of the logic can be found in + the comment for sharded_embedding_bag.) + + Args: + input: list of ID used for lookup and aggregation. + world_size: number of ranks. + weight: sharded weight tensor. + local_shard: row-wise shared local weight used for lookup. + offsets: list of start positions of each bag for 1D input. + per_sample_weights: weights for weighted sum mode. + mode: aggregation method of each bag. + max_norm: If given, each embedding vector with norm larger + than max_norm is renormalized to have norm max_norm. + Note: this will modify weight in-place. + norm_type: The p in the p-norm to compute for the max_norm option. + padding_idx: If specified, the entries at padding_idx do + not contribute to the gradient; therefore, the embedding + vector at padding_idx is not updated during training, + i.e. it remains as a fixed "pad". + Note that the embedding vector at padding_idx is + excluded from the reduction. + rank: # of cuda process. + pg: process group. + + Returns: + gathered_output: final result of lookup and aggregation. + """ + if input.dim() > 1 and per_sample_weights is None: + # allgather the inputs first for non Replicated Tensor. + gather_inp = _all_gather_base_input(input, pg) + else: + ( + gathered_inputs, + gathered_per_sample_weights, + gathered_offsets, + ) = _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg) + cat_dim = 0 if input.dim() != 1 else -1 + gather_inp = torch.cat(gathered_inputs, dim=cat_dim) + if per_sample_weights is not None: + per_sample_weights = torch.cat(gathered_per_sample_weights, dim=cat_dim) + offset_add = 0 if input.dim() > 1 else input.size(0) + if offsets is not None: + offsets_list = torch.cat( + [gathered_offsets[i] + (offset_add * i) for i in range(pg.size())], + dim=cat_dim, + ) + + # Mask the input according to sharding spec. + lookup_input, padding_local, padding_row = _handle_row_wise_mask( + gather_inp, padding_idx, weight, world_size, rank + ) + if mode == "max": + padding_row[:] = -float("Inf") + + # When input is a large tensor, the value of weight is changed. + # This is a walk-around for now. GH issue: #81717. + if max_norm is not None: + torch.nn.functional.embedding_bag( + torch.unique(lookup_input)[:-1], + local_shard, + offsets=torch.tensor([0], device=local_shard.device, dtype=torch.long), + mode=mode, + per_sample_weights=None, + max_norm=max_norm, + norm_type=norm_type, + padding_idx=padding_local, + ) + max_norm = None + result = torch.nn.functional.embedding_bag( + lookup_input, + torch.cat([local_shard, padding_row]), + offsets=offsets_list if offsets is not None else offsets, # type: ignore[possibly-undefined] + mode=mode if mode != "mean" else "sum", + per_sample_weights=per_sample_weights, + max_norm=max_norm, + norm_type=norm_type, + padding_idx=padding_local, + ) + + op = ReduceOp.SUM if mode != "max" else ReduceOp.MAX + # TODO: Make the result a PartialTensor and move the logic below there. + local_shards = result.chunk(pg.size()) + result = reduce_scatter( + torch.empty_like(local_shards[0]), + list(local_shards), + op=op, + group=pg, + ) + + # For Mean, we cannot do the division until very end because the sum of means + # not equal to the mean of sum. (Divisor is different) + if mode == "mean": + if input.dim() > 1: + padding_idx = padding_idx if padding_idx is not None else -1 + split_sizes = torch.sum( + torch.ne(input, padding_idx), dim=-1, dtype=local_shard.dtype + ) + else: + split_sizes = torch.cat( + ( + offsets[1 : offsets.size(0)] - offsets[0:-1], + (input.size(0) - offsets[-1]).unsqueeze(0), + ), + dim=-1, + ) + return torch.div(result, split_sizes.unsqueeze(1)) + + # Return the appropriate local result. + return result + + +def _all_gather_embedding_bag_input(input, per_sample_weights, offsets, pg): + """ + In case we need to gather input and all other parameters of embeddingBag + ops, we need to stack all input together to perform ``all_gather`` + collective communication just once. + + Note that since offsets does not share the same size as input and + is always smaller than input, we resize it during the communication. + + Args: + input: tensor to be applied op on. + per_sample_weights: weights for weighted sum mode. + offsets: when input is 1D. offsets determines the starting + index position of each bag (sequence) in input. + pg: process group. + + Returns: + gathered_inputs: list of input tensor gathered from each rank. + gathered_per_sample_weights: list of per_sample_weights from each rank. + gathered_offsets: list of offsets from each rank. + """ + input_to_gather = [input] + if per_sample_weights is not None: + input_to_gather.append(per_sample_weights) + if offsets is not None: + input_to_gather.append(offsets.clone().resize_(input.size())) + gathered_inputs = all_gather(torch.stack(input_to_gather), group=pg) + + gathered_per_sample_weights = None + if per_sample_weights is not None: + gathered_per_sample_weights = [t[1] for t in gathered_inputs] + gathered_offsets = None + if offsets is not None: + idx = 2 if per_sample_weights is not None else 1 + gathered_offsets = [ + t[idx].resize_(offsets.size()).to(offsets.dtype) for t in gathered_inputs + ] + gathered_inputs = [t[0].to(input.dtype) for t in gathered_inputs] + return gathered_inputs, gathered_per_sample_weights, gathered_offsets diff --git a/phivenv/Lib/site-packages/torch/distributed/_sharded_tensor/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_sharded_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..03ec4ad628705f52db693d9d0a2126ef42eebf6b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_sharded_tensor/__init__.py @@ -0,0 +1,21 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. +import sys +import warnings + +import torch +from torch.distributed._shard.sharded_tensor import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharded_tensor` will be deprecated, " + "use `torch.distributed._shard.sharded_tensor` instead", + DeprecationWarning, + stacklevel=2, + ) + +sys.modules["torch.distributed._sharded_tensor"] = ( + torch.distributed._shard.sharded_tensor +) diff --git a/phivenv/Lib/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a34b259f37724ca02c10064d1d3dcb731978f725 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_sharded_tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_sharding_spec/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_sharding_spec/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0d1ba4fc0b6ed57da9598817df1015ee06879cab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_sharding_spec/__init__.py @@ -0,0 +1,22 @@ +# Keep old package for BC purposes, this file should be removed once +# everything moves to the `torch.distributed._shard` package. +import sys +import warnings + +import torch +from torch.distributed._shard.sharding_spec import * # noqa: F403 + + +with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`torch.distributed._sharding_spec` will be deprecated, " + "use `torch.distributed._shard.sharding_spec` instead", + DeprecationWarning, + stacklevel=2, + ) + +import torch.distributed._shard.sharding_spec as _sharding_spec + + +sys.modules["torch.distributed._sharding_spec"] = _sharding_spec diff --git a/phivenv/Lib/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a35eacc8d92c5030cc3a671895702fa7a65f3380 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_sharding_spec/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5334a0eb7a2beddc1cd41429ec6df2734bc045b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__init__.py @@ -0,0 +1,1730 @@ +import math +import os +import socket +import uuid +from collections.abc import Generator +from contextlib import contextmanager +from datetime import timedelta +from enum import Enum +from functools import partial +from typing import Any, Callable, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch._C._autograd import DeviceType +from torch._C._distributed_c10d import _SymmetricMemory, Work as _Work + + +_group_name_to_store: dict[str, c10d.Store] = {} + + +def enable_symm_mem_for_group(group_name: str) -> None: + """ + Enables symmetric memory for a process group. + + Args: + group_name (str): the name of the process group. + """ + if group_name in _group_name_to_store: + return + + group = c10d._resolve_process_group(group_name) + global_ranks = sorted(c10d._world.pg_group_ranks[group].keys()) + # Different subgroups with the same name should use different stores + global_ranks_str = "_".join(map(str, global_ranks)) + store = c10d.PrefixStore( + f"symmetric_memory-{global_ranks_str}", + c10d._get_process_group_store(group), + ) + _group_name_to_store[group_name] = store + _SymmetricMemory.set_group_info( + group_name, + group.rank(), + group.size(), + store, + ) + + +_is_test_mode: bool = False +_mocked_group_names: Optional[set[str]] = None + + +@contextmanager +def _test_mode(group_names: Optional[set[str]] = None) -> Generator[None, None, None]: + """ + Forces ``is_symm_mem_enabled_for_group()`` to return ``True`` and the ops + defined in the ``symm_mem`` namespace to use fallback implementations. + + The context manager is not thread safe. + """ + global _is_test_mode + global _mocked_group_names + prev = _is_test_mode + prev_group_names = _mocked_group_names + try: + _is_test_mode = True + _mocked_group_names = group_names + yield + finally: + _is_test_mode = prev + _mocked_group_names = prev_group_names + + +def is_symm_mem_enabled_for_group(group_name: str) -> bool: + """ + Check if symmetric memory is enabled for a process group. + + Args: + group_name (str): the name of the process group. + """ + if _is_test_mode: + return _mocked_group_names is None or group_name in _mocked_group_names + return group_name in _group_name_to_store + + +_group_name_to_workspace_tensor: dict[str, Optional[torch.Tensor]] = {} + + +def get_symm_mem_workspace(group_name: str, min_size: int) -> _SymmetricMemory: + """ + Get the symmetric memory workspace associated with the process group. If + ``min_size`` is greater than the workspace associated with ``group_name``, + the workspace will be re-allocated and re-rendezvous'd. + + Args: + group_name (str): the name of the process group. + min_size (int): the size requirement for the workspace in bytes. + + Returns: + _SymmetricMemory: the symmetric memory workspace associated with the + group. + """ + enable_symm_mem_for_group(group_name) + + tensor = _group_name_to_workspace_tensor.get(group_name) + size = tensor.numel() * tensor.element_size() if tensor is not None else 0 + if tensor is None or size < min_size: + if torch.cuda.is_current_stream_capturing(): + curr_size = 0 if tensor is None else tensor.numel() * tensor.element_size() + raise RuntimeError( + f"get_symm_mem_workspace(): the requested size ({min_size} bytes) " + "is greater than the size of the currently allocated workspace " + f"({curr_size} bytes). It's currently not possible to expand the " + "workspace size during graph capture. Please invoke " + f'`get_symm_mem_workspace(group_name="{group_name}", ' + f'min_size="{min_size}")` before initiating the graph capture ' + "and try again." + ) + tensor = _SymmetricMemory.empty_strided_p2p( + (max(size, min_size),), + [1], + torch.uint8, + torch.device(f"cuda:{torch.cuda.current_device()}"), + group_name, + ) + _group_name_to_workspace_tensor[group_name] = tensor + return _SymmetricMemory.rendezvous(tensor) + + +_backend_streams: dict[int, torch.cuda.Stream] = {} + + +def _get_backend_stream(priority: int = 0) -> torch.cuda.Stream: + if priority not in _backend_streams: + _backend_streams[priority] = torch.cuda.Stream(priority=priority) + return _backend_streams[priority] + + +def _pipelined_multi_all_gather_and_consume( + shard: list[torch.Tensor], + shard_consumer: Callable[[list[torch.Tensor], int], None], + ag_out: list[torch.Tensor], + group_name: str, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + gathered = [ + all_gather_tensor(x, gather_dim=0, group=group) + for x in shard + ] + + shards = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + p2p_workspace_size_req = 0 + for x in shard: + p2p_workspace_size_req += x.numel() * x.element_size() + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + for x, y in zip(shard, ag_out): + assert x.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `shard` must be contiguous" + ) + assert y.is_contiguous(), ( + "_pipelined_all_gather_and_consume: all tensors " + "in `ag_out` must be contiguous" + ) + assert x.shape[0] * group_size == y.shape[0] + assert x.shape[1:] == y.shape[1:] + + def copy_shard(dst: list[torch.Tensor], src: list[torch.Tensor]) -> None: + for d, s in zip(dst, src): + d.copy_(s) + + def get_p2p_bufs(remote_rank: int) -> list[torch.Tensor]: + offset_bytes = 0 + bufs = [] + for x in shard: + buf = symm_mem.get_buffer( + remote_rank, + x.shape, + x.dtype, + storage_offset=offset_bytes // x.element_size(), + ) + bufs.append(buf) + offset_bytes += buf.numel() * buf.element_size() + return bufs + + local_p2p_bufs = get_p2p_bufs(rank) + + # shards[i] => shard from rank i + shards: list[list[torch.Tensor]] = [[] for _ in range(group_size)] + for x in ag_out: + for i, y in enumerate(x.chunk(group_size)): + shards[i].append(y) + + # Parallelization strategy: after each rank copies its shard into its local + # p2p buffer, every rank issues independent p2p copy -> shard_consumer + # sequences to two streams. In addition to computation/communication + # overlapping, the strategy allows for computation/computation overlapping, + # greatly reducing quantization inefficiency. + # + # Notation: + # - "mv" for the copy to local buffer + # - "cp" for p2p copies + # - "b" for barriers + # + # Constraints: + # - The GPU scheduler may or may not overlap "mv" with the first shard_consumer. + # - "cp" from different streams cannot overlap. + # + # Ideal scenario 0 - "mv" overlaps with the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Ideal scenario 1 - "mv" is scheduled before the first shard_consumer: + # + # stream 0: [ shard_consumer ][ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "mv" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ][b][ cp ][ shard_consumer ] + # + # Suboptimal scenario 0 - "b" is scheduled after the first shard_consumer: + # + # stream 0: [ shard_consumer ] [ cp ][ shard_consumer ] + # stream 1: [ mv ] [b][ cp ][ shard_consumer ] + # + # We haven't yet figured out a way to ensure "mv" and "b" are either + # overlapped with or scheduled before the first shard_consumer. Thus, to + # prevent suboptimal scenarios, we are giving up the chance to overlap "mv" + # and "b" with the first shard_consumer for now. + copy_shard(dst=local_p2p_bufs, src=shard) + symm_mem.barrier(channel=1) + backend_stream.wait_stream(torch.cuda.current_stream()) + + # At this point, all ranks have copied their local shard to + # their local p2p buffer. Each rank can now copy and consume + # remote shards. + shard_consumer(shard, rank) + + for step in range(1, group_size): + if step % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + remote_rank = (step + rank) % group_size + remote_p2p_bufs = get_p2p_bufs(remote_rank) + with stream: + copy_shard(dst=shards[remote_rank], src=remote_p2p_bufs) + shard_consumer(shards[remote_rank], remote_rank) + + if ag_out_needed: + # Copy from input to the all-gather output. Opportunistically overlap + # it with the last shard_consumer. + if group_size % 2 == 0: + stream = torch.cuda.current_stream() + else: + stream = backend_stream + with stream: + copy_shard(dst=shards[rank], src=shard) + + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +def _pipelined_all_gather_and_consume( + shard: torch.Tensor, + shard_consumer: Callable[[torch.Tensor, int], None], + ag_out: torch.Tensor, + group_name: str, + ag_out_needed: bool = True, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + ag_out = all_gather_tensor(shard, gather_dim=0, group=group) + shards = ag_out.chunk(group.size()) + for src_rank, shard in enumerate(shards): + shard_consumer(shard, src_rank) + """ + + def adapter(shard: list[torch.Tensor], rank: int) -> None: + shard_consumer(shard[0], rank) + + _pipelined_multi_all_gather_and_consume( + [shard], + adapter, + [ag_out], + group_name, + ag_out_needed, + ) + + +def _pipelined_produce_and_all2all( + chunk_producer: Callable[[int, torch.Tensor], None], + output: torch.Tensor, + group_name: str, +) -> None: + """ + Perform the following logic with micro-pipelined computation and + communication: + + chunks = [ + chunk_producer(dst_rank, chunks[dst_rank]) + for dst_rank in range(group_size): + ] + dist.all_to_all_single(output=output, input=torch.cat(chunks)) + """ + out_chunks = output.chunk(c10d._get_group_size_by_name(group_name)) + p2p_workspace_size_req = out_chunks[0].numel() * out_chunks[0].element_size() * 2 + symm_mem = get_symm_mem_workspace(group_name, min_size=p2p_workspace_size_req) + group_size = symm_mem.world_size + rank = symm_mem.rank + + symm_mem.barrier(channel=0) + backend_stream = _get_backend_stream() + backend_stream.wait_stream(torch.cuda.current_stream()) + + def get_p2p_buf(rank: int, idx: int) -> torch.Tensor: + assert idx in (0, 1) + offset = 0 if idx == 0 else out_chunks[0].numel() + return symm_mem.get_buffer( + rank, out_chunks[0].shape, out_chunks[0].dtype, offset + ) + + # Prepare two local p2p buffers, so that a remote rank can pull the result + # of step [i] in one p2p buffer while the local rank can compute the + # result of step [i+1] and write it directly the other p2p buffer. + local_p2p_buf_0 = get_p2p_buf(rank, 0) + local_p2p_buf_1 = get_p2p_buf(rank, 1) + + for step in range(1, group_size): + remote_rank = (rank - step) % group_size + if step % 2 == 0: + stream = torch.cuda.current_stream() + p2p_buf = local_p2p_buf_1 + remote_p2p_buf = get_p2p_buf(remote_rank, 1) + else: + stream = backend_stream + p2p_buf = local_p2p_buf_0 + remote_p2p_buf = get_p2p_buf(remote_rank, 0) + with stream: + # Parallelization strategy: every rank issues independent compute + # -> barrier -> p2p copy sequences on two streams. In addition to + # computation/communication overlapping, the strategy allows for + # computation/computation overlapping, greatly reducing + # quantization inefficiency. + # + # Ideally, stream activities would look like this ("b" for + # barriers, "cp" for p2p copies): + # + # [rank 0] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # stream 1: [ chunk_producer ][b][ cp ][ chunk_producer ][b][ cp ] + # + # Note that the barriers synchronize streams with the same ID + # across ranks. They don't synchronize streams on the same rank. + # + # Since the work on both streams is independent, there's no + # guarantee that the chunk_producer from stream 0 or stream 1 will + # be scheduled first. If there is a scheduling mismatch across + # ranks, the barrier forces all ranks to wait for the slowest. + # + # When scheduling mismatches occur among ranks, the stream + # activities might look like this (note that p2p copies from + # different streams cannot overlap with each other): + # + # [rank 0] + # stream 0: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # stream 1: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # + # [rank 1] + # stream 0: [ chunk_producer ][b] [ cp ][ chunk_producer ][b] [ cp ] + # stream 1: [ chunk_producer ][b ][ cp ][ chunk_producer ][b ][ cp ] + # + # To prevent this, we need to ensure that the chunk_producer on + # stream 1 gets scheduled first on every rank. Without access to + # the underlying kernels, CUDA offers no API to control the + # scheduling order of two independent, overlapping kernels. Our + # solution is to issue a small sleep kernel in stream 0. The sleep + # duration is insignificant, but having an extra task in stream 0 + # will almost guarantee that the chunk_producer on stream 1 gets + # scheduled first. Once the first chunk_producer is scheduled in + # the correct order, there's very little room for the scheduling + # order of subsequent kernels to be inconsistent across ranks. + if step == 2: + torch.cuda._sleep(100) + chunk_producer((rank + step) % group_size, p2p_buf) + symm_mem.barrier(channel=step % 2) + out_chunks[remote_rank].copy_(remote_p2p_buf) + # The local P2P buffer can only be overwritten by the next + # chunk_producer after all peers have finished reading from it. + symm_mem.barrier(channel=step % 2) + + # If the sleep wasn't issued in the above loop, do it now. + if group_size == 2: + torch.cuda._sleep(100) + + chunk_producer(rank, out_chunks[rank]) + torch.cuda.current_stream().wait_stream(backend_stream) + symm_mem.barrier(channel=0) + + +lib = torch.library.Library("symm_mem", "DEF") # noqa: TOR901 +lib.define( + "fused_all_gather_matmul(" + "Tensor A, Tensor[] Bs, int gather_dim, str group_name, *, bool return_A = True) -> (Tensor?, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_all_gather_scaled_matmul(" + "Tensor A, Tensor[] Bs, Tensor A_scale, Tensor[] B_scales, " + "int gather_dim, str group_name, " + "Tensor?[] biases, " + "Tensor?[] result_scales, " + "ScalarType?[] out_dtypes, " + "bool[] use_fast_accum) -> (Tensor, Tensor[])", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_matmul_reduce_scatter(Tensor A, Tensor B, str reduce_op, int scatter_dim, str group_name) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define( + "fused_scaled_matmul_reduce_scatter(" + "Tensor A, Tensor B, Tensor A_scale, Tensor B_scale, " + "str reduce_op, int orig_scatter_dim, int scatter_dim_after_maybe_reshape, str group_name, int[]? output_shape, " + "Tensor? bias = None, " + "Tensor? result_scale = None, " + "ScalarType? out_dtype = None, " + "bool use_fast_accum = False) -> Tensor", + tags=[torch._C.Tag.needs_fixed_stride_order], +) +lib.define("_low_contention_all_gather(Tensor tensor, str group_name) -> Tensor") +lib.define( + "_low_contention_reduce_scatter(Tensor tensor, str reduce_op, str group_name) -> Tensor" +) + + +class _ScaleMode(Enum): + UNSCALED = "unscaled" + TENSOR_WISE = "tensor-wise" + ROW_WISE_SHARDED = "row-wise-sharded" + ROW_WISE_REPLICATED = "row-wise-replicated" + + +def _check_and_verify_fp8_all_gather_scale_mode( + shard: torch.Tensor, scale: Optional[torch.Tensor], gather_dim: int, group_size: int +) -> _ScaleMode: + full_shape = list(shard.shape) + full_shape[gather_dim] *= group_size + + if scale is None: + return _ScaleMode.UNSCALED + elif scale.shape[:-1] == shard.shape[:-1] and scale.shape[-1] == 1: + # Row-wise scaling + # + # NOTE: when the last dim of both A_shard and A_scale is one, we can't + # tell if A_scale is replicated tensor-wise scale or sharded row-wise + # scale. Treating it as row-wise scaling for safety. + return _ScaleMode.ROW_WISE_SHARDED + elif scale.numel() == 1: + return _ScaleMode.TENSOR_WISE + elif list(scale.shape[:-1]) == full_shape[:-1]: + return _ScaleMode.ROW_WISE_REPLICATED + else: + raise ValueError( + "Invalid scale shape for fp8 all-gather " + f"(shard shape: {shard.shape}, scale shape: {scale.shape})" + ) + + +def _fused_all_gather_matmul_impl( + mm_out_op: torch._ops.OpOverload, + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: Optional[torch.Tensor], + kwargs_list: list[dict[str, Any]], + out_dtypes: list[Optional[torch.dtype]], + gather_dim: int, + group_name: str, + return_A: bool, +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: + if A_shard.dim() < 2: + raise ValueError("A_shard must be a matrix") + for B in Bs: + if B.dim() != 2: + raise ValueError("B must be a matrix") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_types) must be the same as len(Bs)") + if len(kwargs_list) != len(Bs): + raise ValueError("len(kwargs_list) must be the same as len(Bs)") + if gather_dim < 0 or gather_dim >= A_shard.dim(): + raise ValueError("Invalid gather_dim") + + group = c10d._resolve_process_group(group_name) + + # Move the gather_dim to the front and flatten the tensor into a 2D matrix. + # The flattened tensor doesn't need to be contiguous (for computation + # efficiency), as _pipelined_all_gather_and_consume guarantees that shards + # passed to shard_consumer are contiguous. + A_shard_flat = A_shard.movedim(gather_dim, 0) + leading_dims = [group.size()] + list(A_shard_flat.shape[:-1]) + A_shard_flat = A_shard_flat.flatten(0, -2) + + # Helper function for reverting the above transformation + def unflatten(t: torch.Tensor) -> torch.Tensor: + return t.view(*leading_dims, -1).flatten(0, 1).movedim(0, gather_dim) + + A_flat = A_shard_flat.new_empty( + A_shard_flat.shape[0] * group.size(), + A_shard_flat.shape[1], + ) + + outputs = [ + A_flat.new_empty(A_flat.shape[0], B.shape[1], dtype=out_dtype or B.dtype) + for B, out_dtype in zip(Bs, out_dtypes) + ] + output_shards = [output.chunk(group.size()) for output in outputs] + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group.size() + ) + + # Computing block-wise matmul along the first dim of A + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + assert A_scale is not None + A_scale_shard = A_scale.movedim(gather_dim, 0).flatten(0, -2) + A_scale_flat = A_scale_shard.new_empty( + A_scale_shard.shape[0] * group.size(), + A_scale_shard.shape[1], + ) + + def row_wise_sharded_consumer(shard: list[torch.Tensor], rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard[0], + B, + scale_a=shard[1], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_multi_all_gather_and_consume( + [A_shard_flat, A_scale_shard], + row_wise_sharded_consumer, + [A_flat, A_scale_flat], + group_name, + return_A, + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + assert A_scale is not None + A_scale_shards = ( + A_scale.movedim(gather_dim, 0).flatten(0, -2).chunk(group.size()) + ) + + def row_wise_replicated_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op( + shard, + B, + scale_a=A_scale_shards[rank], + **kwargs, + out=output_shards[idx][rank], + ) + + _pipelined_all_gather_and_consume( + A_shard_flat, + row_wise_replicated_consumer, + A_flat, + group_name, + return_A, + ) + else: + if scale_mode == _ScaleMode.TENSOR_WISE: + assert A_scale is not None + for kwargs in kwargs_list: + kwargs["scale_a"] = A_scale + else: + assert scale_mode == _ScaleMode.UNSCALED + + def default_consumer(shard: torch.Tensor, rank: int) -> None: + for idx, (B, kwargs) in enumerate(zip(Bs, kwargs_list)): + mm_out_op(shard, B, **kwargs, out=output_shards[idx][rank]) + + _pipelined_all_gather_and_consume( + A_shard_flat, + default_consumer, + A_flat, + group_name, + return_A, + ) + + A = unflatten(A_flat) if return_A else None + return A, [unflatten(output) for output in outputs] + + +@torch.library.impl(lib, "fused_all_gather_matmul", "Meta") +def _fused_all_gather_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: str, + *, + return_A: bool = True, +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + res = [torch.matmul(A, B).movedim(0, gather_dim) for B in Bs] + if return_A: + return A.movedim(0, gather_dim), res + else: + return None, res + + +@torch.library.impl(lib, "fused_all_gather_matmul", "CUDA") +def _fused_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: str, + *, + return_A: bool = True, +) -> tuple[Optional[torch.Tensor], list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + all_gather_tensor(A_shard, gather_dim, group_name) @ B + + Optimal stride order for A_shard - if A_shard.movedim(gather_dim, 0) is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + if _is_test_mode: + return _fused_all_gather_matmul_fallback( + A_shard, Bs, gather_dim, group_name, return_A=return_A + ) + + if _should_use_fused_all_gather_matmul_native(A_shard, Bs, gather_dim, group_name): + group = c10d._resolve_process_group(group_name) + leading_dims = list(A_shard.shape[:-1]) + leading_dims[0] *= group.size() + A, out = _fused_all_gather_matmul_native( + A_shard.flatten(0, -2), Bs[0], group_name + ) + return A.view(*leading_dims, -1), [out.view(*leading_dims, -1)] + + if _should_use_multimem_all_gather_matmul( + A_shard, gather_dim, group_name, return_A + ): + return None, _multimem_all_gather_matmul(A_shard, Bs, group_name) + + with torch.profiler.record_function("fused_all_gather_matmul"): + return _fused_all_gather_matmul_impl( + torch.ops.aten.mm.out, + A_shard, + Bs, + None, + [{} for B in Bs], + [B.dtype for B in Bs], + gather_dim, + group_name, + return_A, + ) + + +def _should_use_fused_all_gather_matmul_native( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + gather_dim: int, + group_name: str, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + + return ( + "TORCH_SYMM_MEM_ENABLE_NATIVE_ASYNC_TP" in os.environ + and A_shard.is_contiguous() + and gather_dim == 0 + # _async_input_mm requires local_M to be divisible by world_size. + and local_M % group.size() == 0 + # _async_input_mm outperforms the decomposition-based approach when the + # global M is small. + and 2048 < local_M * group.size() <= 4096 + # _async_input_mm only supports a single B. + and len(Bs) == 1 + ) + + +def _fused_all_gather_matmul_native( + A_shard: torch.Tensor, + B: torch.Tensor, + group_name: str, +) -> tuple[torch.Tensor, torch.Tensor]: + symm_mem = rendezvous(A_shard, group_name) + if symm_mem is None: + symm_mem = get_symm_mem_workspace( + group_name, A_shard.numel() * A_shard.element_size() + ) + symm_mem.barrier() + buf = symm_mem.get_buffer(symm_mem.rank, A_shard.shape, A_shard.dtype) + buf.copy_(A_shard) + A_shard = buf + + rank = symm_mem.rank + world_size = symm_mem.world_size + + current_stream = torch.cuda.current_stream() + backend_stream = _get_backend_stream(priority=-1) + + symm_mem.barrier() + backend_stream.wait_stream(current_stream) + current_stream.wait_stream(backend_stream) + + A = A_shard.new_empty(A_shard.shape[0] * world_size, A_shard.shape[1]) + A_signals = torch.zeros(world_size, dtype=torch.uint32, device=A_shard.device) + A_shards = A.chunk(world_size) + + A_shards[rank].copy_(A_shard) + if not torch.cuda.is_current_stream_capturing(): + _SymmetricMemory.stream_write_value32(A_signals, rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=rank, val=1, count=1) + + out = torch.ops.symm_mem._async_input_mm(A, B, A_signals, rank) + for step in range(1, world_size): + src_rank = (rank + step) % world_size + src_buf = symm_mem.get_buffer(src_rank, A_shard.shape, A_shard.dtype) + with backend_stream: + A_shards[src_rank].copy_(src_buf) + if not torch.cuda.is_current_stream_capturing(): + # cuStreamWriteValue32 issues a system level fence before the write + _SymmetricMemory.stream_write_value32(A_signals, src_rank, 1) + else: + _SymmetricMemory.memset32(A_signals, offset=src_rank, val=1, count=1) + + current_stream.wait_stream(backend_stream) + backend_stream.wait_stream(current_stream) + + symm_mem.barrier() + return A, out + + +def _should_use_multimem_all_gather_matmul( + A_shard: torch.Tensor, + gather_dim: int, + group_name: str, + return_A: bool, +) -> bool: + group = c10d._resolve_process_group(group_name) + local_M = math.prod(A_shard.shape[:-1]) + has_multicast_support = ( + A_shard.device.type == "cuda" + and _SymmetricMemory.has_multicast_support( + DeviceType.CUDA, A_shard.device.index + ) + ) + + return ( + has_multicast_support + and not return_A + and A_shard.is_contiguous() + and gather_dim == 0 + # The heuristic is empirical. We could refine it with a more + # sophisticated perf model. + and local_M * group.size() <= 2048 + ) + + +def _multimem_all_gather_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + group_name: str, +) -> list[torch.Tensor]: + group = c10d._resolve_process_group(group_name) + A_shape = torch.Size((A_shard.shape[0] * group.size(), *A_shard.shape[1:])) + symm_mem = get_symm_mem_workspace( + group_name, A_shape.numel() * A_shard.element_size() + ) + A = symm_mem.get_buffer(symm_mem.rank, A_shape, A_shard.dtype) + torch.ops.symm_mem.multimem_all_gather_out(A_shard, group_name, A) + return [torch.matmul(A, B) for B in Bs] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "Meta") +def _fused_all_gather_scaled_matmul_fallback( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: str, + biases: list[Optional[torch.Tensor]], + result_scales: list[Optional[torch.Tensor]], + out_dtypes: list[Optional[torch.dtype]], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + group_size = c10d._get_group_size_by_name(group_name) + A = torch.ops._c10d_functional.all_gather_into_tensor( + A_shard.contiguous(), group_size, group_name + ) + A = torch.ops._c10d_functional.wait_tensor(A) + A = A.view(group_size, *A_shard.shape).movedim(gather_dim + 1, 1).flatten(0, 1) + + scale_mode = _check_and_verify_fp8_all_gather_scale_mode( + shard=A_shard, scale=A_scale, gather_dim=gather_dim, group_size=group_size + ) + if scale_mode == _ScaleMode.ROW_WISE_SHARDED: + A_scale_shard = A_scale + A_scale = torch.ops._c10d_functional.all_gather_into_tensor( + A_scale.contiguous(), group_size, group_name + ) + A_scale = torch.ops._c10d_functional.wait_tensor(A_scale) + A_scale = ( + A_scale.view(group_size, *A_scale_shard.shape) + .movedim(gather_dim + 1, 1) + .flatten(0, -2) + ) + elif scale_mode == _ScaleMode.ROW_WISE_REPLICATED: + A_scale = A_scale.movedim(gather_dim, 0).flatten(0, -2) + else: + assert scale_mode == _ScaleMode.TENSOR_WISE + + def scaled_matmul( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + bias: Optional[torch.Tensor], + result_scale: Optional[torch.Tensor], + out_dtype: Optional[torch.dtype], + use_fast_accum: bool, + ) -> torch.Tensor: + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm( + A.flatten(0, -2), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype=out_dtype, + use_fast_accum=use_fast_accum, + ) + return res.unflatten(0, leading_dims) + + return A.movedim(0, gather_dim), [ + scaled_matmul( + A, B, A_scale, B_scale, bias, result_scale, out_dtype, fast_accum + ).movedim(0, gather_dim) + for B, B_scale, bias, result_scale, out_dtype, fast_accum in zip( + Bs, B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ] + + +@torch.library.impl(lib, "fused_all_gather_scaled_matmul", "CUDA") +def _fused_all_gather_scaled_matmul( + A_shard: torch.Tensor, + Bs: list[torch.Tensor], + A_scale: torch.Tensor, + B_scales: list[torch.Tensor], + gather_dim: int, + group_name: str, + biases: list[Optional[torch.Tensor]], + result_scales: list[Optional[torch.Tensor]], + out_dtypes: list[Optional[torch.dtype]], + use_fast_accum: list[bool], +) -> tuple[torch.Tensor, list[torch.Tensor]]: + """ + Perform the following logic with micro-pipelined computation and + communication: + + A = all_gather_tensor(A_shard, gather_dim, group_name) + leading_dims = A.shape[:-1] + res = torch.ops.aten._scaled_mm(A.flatten(0, -2), B, A_scale, B_scale) + res = res.unflatten(0, leading_dims) + + The input `A_scale` can be tensor-wise, row-wise-sharded or + row-wise-replicated. + + Optimal stride order for `A_shard` - if `A_shard.movedim(gather_dim, 0)` is + contiguous, no extra copy is required for input layout transformation. + Otherwise A_shard needs to be copied once. + """ + out_dtypes = _maybe_convert_scalar_types_to_dtypes(out_dtypes) + + if len(biases) != len(Bs): + raise ValueError("len(biases) must be the same as len(Bs)") + if len(result_scales) != len(Bs): + raise ValueError("len(result_scales) must be the same as len(Bs)") + if len(out_dtypes) != len(Bs): + raise ValueError("len(out_dtypes) must be the same as len(Bs)") + if len(use_fast_accum) != len(Bs): + raise ValueError("len(use_gast_accum_list) must be the same as len(Bs)") + + if _is_test_mode: + return _fused_all_gather_scaled_matmul_fallback( + A_shard, + Bs, + A_scale, + B_scales, + gather_dim, + group_name, + biases, + result_scales, + out_dtypes, + use_fast_accum, + ) + + with torch.profiler.record_function("fused_all_gather_scaled_matmul"): + A, res = _fused_all_gather_matmul_impl( + torch.ops.aten._scaled_mm.out, + A_shard, + Bs, + A_scale, + [ + { + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": fast_accum, + } + for B_scale, bias, result_scale, out_dtype, fast_accum in zip( + B_scales, biases, result_scales, out_dtypes, use_fast_accum + ) + ], + out_dtypes, + gather_dim, + group_name, + True, + ) + assert A is not None + return A, res + + +def make_contiguous_for_perm( + t: torch.Tensor, + perm: list[int], +) -> torch.Tensor: + """ + Restride `t` such that `t.permute(perm)` is contiguous. + """ + inv_perm = [0] * len(perm) + for i, p in enumerate(perm): + inv_perm[p] = i + return t.permute(perm).contiguous().permute(inv_perm) + + +def restride_A_shard_for_fused_all_gather_matmul( + t: torch.Tensor, + gather_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_all_gather_matmul` for optimal perf. + See the doc for `fused_all_gather_matmul` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(gather_dim)) + return make_contiguous_for_perm(t, perm) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "CUDA") +def _fused_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + """ + Perform the following logic with micro-pipelined computation and + communication: + + reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + + Optimal stride order for A - if A.movedim(scatter_dim, 0) is contiguous, no + extra copy is required for input layout transformation. Otherwise A needs + to be copied once. + """ + if _is_test_mode: + return _fused_matmul_reduce_scatter_fallback( + A, B, reduce_op, scatter_dim, group_name + ) + + with torch.profiler.record_function("fused_matmul_reduce_scatter"): + return _fused_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten.mm.out, + A=A, + B=B, + kwargs={}, + out_dtype=A.dtype, + reduce_op=reduce_op, + scatter_dim=scatter_dim, + group_name=group_name, + ) + + +@torch.library.impl(lib, "fused_matmul_reduce_scatter", "Meta") +def _fused_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + res = funcol.reduce_scatter_tensor(A @ B, reduce_op, scatter_dim, group_name) + res = funcol.wait_tensor(res) + return res + + +def _fused_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: Optional[torch.dtype], + reduce_op: str, + scatter_dim: int, + group_name: str, +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if scatter_dim < 0 or scatter_dim >= A.dim(): + raise ValueError("Invalid gather_dim") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + out_shape = [*A.shape[:-1], B.shape[1]] + out_shape[scatter_dim] //= group.size() + + # Move the scatter_dim to the front and flatten the tensor into a 2D matrix + x = A.movedim(scatter_dim, 0) + leading_dims = [group.size()] + list(x.shape[:-1]) + leading_dims[1] //= group.size() + x = x.flatten(0, -2) + A_shards = x.chunk(group.size()) + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, **kwargs, out=out) + + stacked_partials = x.new_empty(x.shape[0], B.shape[1], dtype=out_dtype or A.dtype) + + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + return reduce_fn( + stacked_partials.view(*leading_dims, -1) + .movedim(1, scatter_dim + 1) + .movedim(0, scatter_dim), + dim=scatter_dim, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "CUDA") +def _fused_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: Optional[torch.Tensor] = None, + result_scale: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if _is_test_mode: + return _fused_scaled_matmul_reduce_scatter_fallback( + A, + B, + A_scale, + B_scale, + reduce_op, + orig_scatter_dim, + scatter_dim_after_maybe_reshape, + group_name, + output_shape, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + with torch.profiler.record_function("fused_scaled_matmul_reduce_scatter"): + return _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op=torch.ops.aten._scaled_mm.out, + A=A, + B=B, + A_scale=A_scale, + kwargs={ + "scale_b": B_scale, + "bias": bias, + "scale_result": result_scale, + "out_dtype": out_dtype, + "use_fast_accum": use_fast_accum, + }, + out_dtype=out_dtype, + reduce_op=reduce_op, + orig_scatter_dim=orig_scatter_dim, + scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape, + group_name=group_name, + output_shape=output_shape, + ) + + +@torch.library.impl(lib, "fused_scaled_matmul_reduce_scatter", "Meta") +def _fused_scaled_matmul_reduce_scatter_fallback( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + bias: Optional[torch.Tensor] = None, + result_scale: Optional[torch.Tensor] = None, + out_dtype: Optional[torch.dtype] = None, + use_fast_accum: bool = False, +) -> torch.Tensor: + if A_scale.numel() > 1: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = A_scale.flatten(0, -2).contiguous() + elif A_scale.numel() != 1: + raise ValueError( + "Invalid A_scale shape " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + + C = torch._scaled_mm( + A.flatten(0, -2).contiguous(), + B, + A_scale, + B_scale, + bias, + result_scale, + out_dtype, + use_fast_accum, + ) + C = C.view(*output_shape[:-1], B.shape[1]) + res = funcol.reduce_scatter_tensor( + C, + reduce_op, + orig_scatter_dim, # need original scatter dim for 3D+ output tensor here + group_name, + ) + res = funcol.wait_tensor(res) + return res + + +def _fused_scaled_matmul_reduce_scatter_impl( + mm_out_op: torch._ops.OpOverload, + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + kwargs: dict[str, Any], + out_dtype: Optional[torch.dtype], + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], +) -> torch.Tensor: + if A.dim() < 2: + raise ValueError("A_shard must be a matrix") + if ( + scatter_dim_after_maybe_reshape < 0 + or scatter_dim_after_maybe_reshape >= A.dim() + ): + raise ValueError("Invalid scatter dim for 2D tensor input to scaled_mm") + if orig_scatter_dim < 0 or orig_scatter_dim >= len(output_shape): + raise ValueError("Invalid scatter dim for 3D+ output tensor") + if B.dim() != 2: + raise ValueError("B must be a matrix") + if reduce_op == "sum": + reduce_fn = partial(torch.sum, dim=0) + elif reduce_op == "avg": + reduce_fn = partial(torch.mean, dim=0) + else: + raise ValueError("reduce_op must be sum or avg") + + group = c10d._resolve_process_group(group_name) + + # Move scatter to first dim, then shard the tensor along the first dim, so the chunk producer + # can perform matmuls along the first dim. + A_with_scatter_dim_0 = A.movedim(scatter_dim_after_maybe_reshape, 0) + + # To handle case where A is 3D+, reshape to 2D to prepare for mm which requires 2D inputs. + A_2D_with_scatter_dim_0 = A_with_scatter_dim_0.flatten(0, -2) + + # Partition A along the first dim to prepare for sharding across TP process group. + A_shards = A_2D_with_scatter_dim_0.chunk(group.size()) + + # Now that 'A' is sharded along the first dim, we need to update its scale(s) accordingly. + # How we do this depends on if we are using tensorwise scaling, rowwise scaling, or no scaling. + tensorwise_scaling = A_scale is not None and A_scale.numel() == 1 + rowwise_scaling = A_scale is not None and A_scale.numel() > 1 + + # For tensorwise scaling, the scale should be replicated so each shard has a copy. + if tensorwise_scaling: + A_scale_shards = [A_scale] * group.size() + + # For rowwise scaling, we need to move the scatter dim to the first dim to match the + # dim swap of the 'A' tensor. Then we can shard the scales along the first dim, just like + # the 'A' tensor. + elif rowwise_scaling: + if A_scale.shape[:-1] != A.shape[:-1]: + raise ValueError( + "For row-wise scaling, the leading dims of A_scale " + "must match the leading dims of A " + f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})" + ) + A_scale = ( + A_scale.movedim(scatter_dim_after_maybe_reshape, 0) + .contiguous() + .flatten(0, -2) + ) + A_scale_shards = list(A_scale.chunk(group.size())) + else: + raise ValueError("A_scale cannot be none for scaled_mm") + + # Computing block-wise matmul along the first dim of A + def chunk_producer(rank: int, out: torch.Tensor) -> None: + mm_out_op(A_shards[rank], B, scale_a=A_scale_shards[rank], **kwargs, out=out) + + # Stacked partials will be the 2D outputs of the the pipelined scaled mm, and will + # have the shape (A_with_scatter_dim_0_tensor.shape[0], B.shape[1]) to align with the formula: + # (a*b,c) @ (c,d) = (a*b,d) + stacked_partials = A_with_scatter_dim_0.new_empty( + A_2D_with_scatter_dim_0.shape[0], B.shape[1], dtype=out_dtype or A.dtype + ) + + # Execute the pipelined mm/scaled_mm. + _pipelined_produce_and_all2all( + chunk_producer, + stacked_partials, + group_name, + ) + + # We now need to transform the *unreduced* stacked 2D partial mm outputs to an *unreduced* 3D+ output, + # then reduce-scatter. To do this, we first need to determine the shape of the unreduced 3D+ output, + # to reshape our stacked partials so we can apply the reduce-scatter. + # + # The *unreduced* 3D+ tensor will have dim 0 = `group_size`, as we have `group_size` instances of + # stacked partial outputs. The next dims will be A's leading dims (sharded along the original scatter dim), + # as it was the left operand of the mm op. We can use -1 as the final dim of the view to populate the rest. + stacked_partials_3D_leading_dims = [group.size()] + list( + # We use A from after the dim swap 0<=>scatter_dim, but before the flatten, + # to get the leading dims of the 3D+ view of stacked partials. + A_with_scatter_dim_0.shape[:-1] + ) + + # The `group_size` leading dim has been prepended to `stacked_partials_3D_leading_dims`, + # to capture the partial output from each rank. We need to divide the sharding/scatter dim + # by the group size. If the original scatter dim was 0, then it is now dim 1 in this + # tensor, since this new `group_size` dim was prepended. + stacked_partial_scatter_dim = orig_scatter_dim if orig_scatter_dim > 0 else 1 + stacked_partials_3D_leading_dims[stacked_partial_scatter_dim] //= group.size() + + # Ensures that the transpose and reduction produce contiguous result + # in a single reduction kernel. + reduced_out = reduce_fn( + # View 2D stacked partials as 3D+ tensor of shape (`group_size`, ...) + stacked_partials.view(*stacked_partials_3D_leading_dims, -1) + # We originally swapped 0<=>scatter_dim_after_maybe_reshape. Now after + # prepending the `group_size` dim, to undo this original swap, we + # must swap 1<=>scatter_dim_after_maybe_reshape+1. + .movedim(1, scatter_dim_after_maybe_reshape + 1), + # Reduce along the `group_size` dim (0). + dim=0, + ) + + # Output shape must be scattered along original scatter dim as well. + output_shape[orig_scatter_dim] //= group.size() + out = reduced_out.view(*output_shape) + return out + + +def restride_A_for_fused_matmul_reduce_scatter( + t: torch.Tensor, + scatter_dim: int, +) -> torch.Tensor: + """ + Restride the `A_shard` arg of `fused_matmul_reduce_scatter` for optimal + perf. See the doc for `fused_matmul_reduce_scatter` for detail. + """ + perm = list(range(len(t.shape))) + perm.insert(0, perm.pop(scatter_dim)) + return make_contiguous_for_perm(t, perm) + + +def _maybe_convert_scalar_types_to_dtypes( + scalar_types: list[Any], +) -> list[Optional[torch.dtype]]: + """ + When a list of `torch.dtype`s is passed through the dispatcher as + `ScalarType[]`, it is converted to a list of scalar type enum values. This + function converts it back to a list of `torch.dtype`s. + """ + # Order defined in https://github.com/pytorch/pytorch/blob/344defc9733a45fee8d0c4d3f5530f631e823196/c10/core/ScalarType.h + _SCALAR_TYPE_TO_DTYPE = { + 0: torch.uint8, + 1: torch.int8, + 2: torch.short, + 3: torch.int, + 4: torch.int64, + 5: torch.half, + 6: torch.float, + 7: torch.double, + 8: torch.complex32, + 9: torch.complex64, + 10: torch.complex128, + 11: torch.bool, + 12: torch.qint8, + 13: torch.quint8, + 14: torch.qint32, + 15: torch.bfloat16, + 16: torch.float8_e5m2, + 17: torch.float8_e4m3fn, + 18: torch.float8_e5m2fnuz, + 19: torch.float8_e4m3fnuz, + } + if any(not isinstance(x, (type(None), int)) for x in scalar_types): + return scalar_types + + dtypes: list[Optional[torch.dtype]] = [] + for scalar_type in scalar_types: + if scalar_type is None: + dtypes.append(scalar_type) + elif scalar_type not in _SCALAR_TYPE_TO_DTYPE: + raise ValueError("Unrecognized scalar type {scalar_type}") + else: + dtypes.append(_SCALAR_TYPE_TO_DTYPE[scalar_type]) + return dtypes + + +class Work(_Work): + def __init__(self) -> None: + super().__init__() + self.event = torch.cuda.Event() + self.event.record() + + def wait(self, timeout: timedelta = timedelta(seconds=0)) -> bool: + self.event.wait() + return True + + +""" +NOTE [low-contention collectives] +When a collective is overlapped with abundant compute, it makes sense to +prioritize reducing the contention between the collective and the overlapped +compute, even at the cost of a slightly slower collective. + +Common collective implementations (e.g., NCCL without user buffer +registration) optimize for throughput with no ambient compute. However, such +implementations may not be optimal when they are overlapped with compute: +- These implementations typically fuse the entire collective into a single +kernel and reserve SM resources based on the most demanding portion of the +collective, even when a large portion of the collective does not require this +much resource. +- These implementations often use SM-based P2P copy as opposed to copy +engine-based P2P copy. Copy engine-based P2P copy may not have a significant +advantage when there's no ambient compute. However, it may significantly +improve overall resource utilization in the presence of ambient compute. + +When overlapped with intensive compute (e.g., persistent matmul kernels), the +SM-usage of a collective can lead to inefficient overlapping. + +Low-contention collectives achieve their goals with the following strategies: +- Use copy engine-based copy whenever possible. +- Break down portions of a collective with different resource requirements +into multiple kernels. This improves the overlapping efficiency at the cost +of additional launching overhead. +""" + + +@torch.library.impl(lib, "_low_contention_all_gather", "Meta") +def _low_contention_all_gather_meta( + tensor: torch.Tensor, + group_name: str, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.new_empty(tensor.shape[0] * group_size, *tensor.shape[1:]) + + +@torch.library.impl(lib, "_low_contention_all_gather", "CUDA") +def _low_contention_all_gather( + tensor: torch.Tensor, + group_name: str, +) -> torch.Tensor: + """ + Performs all-gather with symmetric memory in a low-contention fashion. + + When `tensor` is already in symmetric memory: + - The collective is carried out without using SMs. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - An extra SM-based copy is performed to copy the input data into the + symmetric memory workspace. + - Symmetric memory workspace size requirement: the size of `tensor`. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + input_is_symm_mem = True + else: + symm_mem = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + input_is_symm_mem = False + + rank = symm_mem.rank + world_size = symm_mem.world_size + + output = tensor.new_empty(tensor.shape[0] * world_size, *tensor.shape[1:]) + chunks = output.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + if not input_is_symm_mem: + local_buf = symm_mem.get_buffer(rank, tensor.shape, tensor.dtype) + local_buf.copy_(tensor) + # pull + symm_mem.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer(remote_rank, tensor.shape, tensor.dtype) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + torch._C._distributed_c10d._register_work(output, Work()) + return output + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "Meta") +def _low_contention_reduce_scatter_meta( + tensor: torch.Tensor, + reduce_op: str, + group_name: str, +) -> torch.Tensor: + group_size = c10d._get_group_size_by_name(group_name) + return tensor.unflatten(0, (group_size, -1)).mean(dim=0) + + +def _low_contention_reduce_scatter_with_symm_mem_input( + tensor: torch.Tensor, + reduce_op: str, + symm_mem: _SymmetricMemory, +) -> torch.Tensor: + rank = symm_mem.rank + world_size = symm_mem.world_size + + assert tensor.shape[0] % world_size == 0 + a2a_res = torch.empty_like(tensor) + chunks = a2a_res.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # pull + offline reduction + symm_mem.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + src_buf = symm_mem.get_buffer( + remote_rank, + chunks[0].shape, + chunks[0].dtype, + chunks[0].numel() * rank, + ) + chunks[remote_rank].copy_(src_buf) + symm_mem.barrier() + + ret = a2a_res.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + torch._C._distributed_c10d._register_work(ret, Work()) + return ret + + +def _low_contention_reduce_scatter_with_workspace( + tensor: torch.Tensor, + reduce_op: str, + workspace: _SymmetricMemory, +) -> torch.Tensor: + rank = workspace.rank + world_size = workspace.world_size + + assert tensor.shape[0] % world_size == 0 + chunks = tensor.chunk(world_size) + + _get_backend_stream().wait_stream(torch.cuda.current_stream()) + with _get_backend_stream(): + # push + offline reduction + workspace.barrier() + for step in range(0, world_size): + remote_rank = (rank - step) % world_size + dst_buf = workspace.get_buffer( + remote_rank, chunks[0].shape, chunks[0].dtype, chunks[0].numel() * rank + ) + dst_buf.copy_(chunks[remote_rank]) + workspace.barrier() + + buf = workspace.get_buffer(rank, tensor.shape, tensor.dtype) + ret = buf.unflatten(0, (world_size, -1)) + if reduce_op == "sum": + ret = ret.sum(dim=0) + elif reduce_op == "avg": + ret = ret.mean(dim=0) + else: + raise ValueError(f"reduce_op ({reduce_op}) is not supported") + torch._C._distributed_c10d._register_work(ret, Work()) + return ret + + +@torch.library.impl(lib, "_low_contention_reduce_scatter", "CUDA") +def _low_contention_reduce_scatter( + tensor: torch.Tensor, + reduce_op: str, + group_name: str, +) -> torch.Tensor: + """ + Performs reduce-scatter with symmetric memory in a low-contention fashion. + + This implementation performs a P2P-based all-to-all followed by an offline + reduction. + + When `tensor` is already in symmetric memory: + - Pull-based all-to-all is used. + - No symmetric memory workspace is required. + + When `tensor` is not in symmetric memory: + - Push-based all-to-all is used. + - Symmetric memory workspace size requirement: the size of `tensor`. + + SM-usage: + - SM-based copy of the rank's own chunk for the all-to-all. + - Reduction on the all-to-all result. + + TODO(yifu): the SM-based copy can be avoided with a list-based reduction + kernel. + """ + symm_mem = rendezvous(tensor, group_name) + if symm_mem is not None: + return _low_contention_reduce_scatter_with_symm_mem_input( + tensor, reduce_op, symm_mem + ) + else: + workspace = get_symm_mem_workspace( + group_name, tensor.numel() * tensor.element_size() + ) + return _low_contention_reduce_scatter_with_workspace( + tensor, reduce_op, workspace + ) + + +# ============================================================================= +# User-facing APIs +# ============================================================================= + + +from collections.abc import Sequence +from typing import Any, overload, TYPE_CHECKING, Union + +from torch.types import _device, _dtype, _int + + +if TYPE_CHECKING: + from torch._C._distributed_c10d import ProcessGroup + + +@overload +def empty( + *size: _int, dtype: Optional[_dtype] = None, device: Optional[_device] = None +) -> torch.Tensor: ... + + +@overload +def empty( + size: Sequence[_int], + *, + dtype: Optional[_dtype] = None, + device: Optional[_device] = None, +) -> torch.Tensor: ... + + +def empty( # type: ignore[misc] + *size: Any, + dtype: Optional[_dtype] = None, + device: Optional[_device] = None, +) -> torch.Tensor: + r""" + empty(*size, *, dtype=None, device=None) -> Tensor + + Similar to :func:`torch.empty()`. The returned tensor can be used by + :func:`torch._distributed._symmetric_memory.rendezvous()` to establish a + symmetric memory tensor among participating processes. + + Args: + size (int...): a sequence of integers defining the shape of the output tensor. + Can be a variable number of arguments or a collection like a list or tuple. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + device (:class:`torch.device`, optional): the desired device of returned tensor. + Default: if ``None``, uses the current device for the default tensor type + (see :func:`torch.set_default_device`). :attr:`device` will be the CPU + for CPU tensor types and the current CUDA device for CUDA tensor types. + """ + if len(size) == 1 and isinstance(size[0], Sequence): + size = tuple(size[0]) + else: + size = tuple(size) + + if dtype is None: + dtype = torch.get_default_dtype() + + if device is None: + device = torch.get_default_device() + + return _SymmetricMemory.empty_strided_p2p( + size=size, + stride=torch._prims_common.make_contiguous_strides_for(size), + dtype=dtype, + device=torch.device(device), + ) + + +def rendezvous( + tensor: torch.Tensor, group: Union[str, "ProcessGroup"] +) -> _SymmetricMemory: + r""" + rendezvous(tensor, group) -> _SymmetricMemory + + Establish a symmetric memory tensor among participating processes. This is + a collective operation. + + Args: + tensor (:class:`torch.Tensor`): the local tensor used to establish the symmetric memory tensor. + It must be allocated via :func:`torch._distributed._symmetric_memory.empty()`. The shape, + dtype, and device type must be identical across all participating processes. + group (Union[str, :class:`torch.distributed.ProcessGroup`]): The group identifying the + participating processes. This can be either a group name or a process group object. + """ + from torch._C._distributed_c10d import ProcessGroup + + if isinstance(group, str): + group_name = group + elif isinstance(group, ProcessGroup): + group_name = group.group_name + else: + raise TypeError(f"rendezvous: unsupported group type: {type(group)}") + + enable_symm_mem_for_group(group_name) + return _SymmetricMemory.rendezvous(tensor, group_name) + + +def is_nvshmem_available() -> bool: + r""" + is_nvshmem_available() -> bool + + Check if NVSHMEM is available in current build and on current system. + """ + try: + from torch._C._distributed_c10d import _is_nvshmem_available + except ImportError: + # Not all builds have NVSHMEM support. + return False + + # Check if NVSHMEM is available on current system. + return _is_nvshmem_available() + + +__all__ = ["empty", "rendezvous", "is_nvshmem_available"] diff --git a/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b2a15a3b101321f02497433dc0b2d7d500bddc3f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__pycache__/_nvshmem_triton.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__pycache__/_nvshmem_triton.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f8c3f96e1e9a669025882a4c6fd8123862a57393 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/__pycache__/_nvshmem_triton.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..afe01b03becf0278c17c16caea6ed4edcef6d4f5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_symmetric_memory/_nvshmem_triton.py @@ -0,0 +1,181 @@ +import os +import sysconfig +from typing import Optional + +from torch.utils._triton import has_triton + + +def enable_triton(lib_dir: Optional[str] = None) -> dict[str, str]: + """ + Enable NVSHMEM device functions for Triton. It performs a NVSHMEM + device-side initialization on the kernel module created by Triton. + + Args: + lib_dir (Optional[str]): The directory where the NVSHMEM device library + is located. If not provided, it will use the default path where NVSHMEM + wheel is installed. + + Returns: + dict[str, str]: A dictionary containing the NVSHMEM device library name + and path. + """ + from triton.runtime.jit import JITFunction + + from torch._C._distributed_c10d import _nvshmemx_cumodule_init + + # Detect NVSHMEM device library path from python library path + if lib_dir is None: + py_lib_path = sysconfig.get_path("purelib") + lib_dir = py_lib_path + "/nvidia/nvshmem/lib" + + lib_path = os.path.join(lib_dir, "libnvshmem_device.bc") + if not os.path.exists(lib_path): + raise RuntimeError("NVSHMEM device library not found") + + extern_libs = {"libnvshmem_device": lib_path} + + # A hook function to initialize NVSHMEM in Triton + def nvshmem_init_hook(*args, **kwargs) -> None: # type: ignore[no-untyped-def] + key = kwargs["key"] + device = kwargs["compile"]["device"] + jit_function = kwargs["fn"].jit_function + kernel_cache, _, _, _ = jit_function.device_caches[device] + kernel = kernel_cache.get(key, None) + kernel.run + _nvshmemx_cumodule_init(kernel.module) + + # Register the function as a post-compile hook + JITFunction.compiled_hook = nvshmem_init_hook + + # Return to user so that they can use it in Triton kernel invocation + return extern_libs + + +if has_triton(): + from triton.language import core + + @core.extern + def putmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_putmem_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def getmem_block(dst, src, nelems, pe, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_getmem_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def putmem_signal_block( # type: ignore[no-untyped-def] + dst, + src, + nelems, + sig_addr, + signal, + sig_op, + pe, + _builder=None, + ): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [dst, src, nelems, sig_addr, signal, sig_op, pe], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmemx_putmem_signal_block", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def wait_until(ivar, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [ivar, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_longlong_wait_until", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def signal_wait_until(sig_addr, cmp, cmp_val, _builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [sig_addr, cmp, cmp_val], + { + ( + core.dtype("int64"), + core.dtype("int64"), + core.dtype("int64"), + ): ("nvshmem_signal_wait_until", core.dtype("int32")) + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def fence(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_fence", core.dtype("int32")), + }, + is_pure=False, + _builder=_builder, + ) + + @core.extern + def quiet(_builder=None): # type: ignore[no-untyped-def] + return core.extern_elementwise( + "", + "", + [], + { + (): ("nvshmem_quiet", core.dtype("int32")), + }, + is_pure=False, + _builder=_builder, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/_tensor/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2da4997d4152e2651b8e9f25a66315593a67dd33 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tensor/__init__.py @@ -0,0 +1,45 @@ +""" +NOTICE: DTensor has moved to torch.distributed.tensor + +This file is a shim to redirect to the new location, and +we keep the old import path starts with `_tensor` for +backward compatibility. We will remove this folder once +we resolve all the BC issues. +""" + +import sys +from importlib import import_module + + +submodules = [ + # TODO: _shards_wrapper/_utils here mainly for checkpoint BC, remove them + "_shards_wrapper", + "_utils", + "experimental", + "device_mesh", +] + +# Redirect imports +for submodule in submodules: + full_module_name = f"torch.distributed.tensor.{submodule}" + sys.modules[f"torch.distributed._tensor.{submodule}"] = import_module( + full_module_name + ) + +from torch.distributed.tensor import ( # noqa: F401 + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + init_device_mesh, + ones, + Partial, + Placement, + rand, + randn, + Replicate, + Shard, + zeros, +) diff --git a/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ea2dad396b3c21d1763c1263059114b09f94ebc1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b98570de93245679d1af697454d8a9cb6604c13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a0d048eb9ba1e03e1231d7f9adf7b46f5433f4e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tensor/__pycache__/placement_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tensor/api.py b/phivenv/Lib/site-packages/torch/distributed/_tensor/api.py new file mode 100644 index 0000000000000000000000000000000000000000..a9cabfea3189070e4db5c9b64e7988384ed9e7fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tensor/api.py @@ -0,0 +1,9 @@ +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from torch.distributed.tensor._api import * # noqa: F401, F403 diff --git a/phivenv/Lib/site-packages/torch/distributed/_tensor/placement_types.py b/phivenv/Lib/site-packages/torch/distributed/_tensor/placement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..b5716dd807d02a9f6e9c55ffa78ee4b04c37996c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tensor/placement_types.py @@ -0,0 +1,10 @@ +""" +NOTE: torch.distributed._tensor has been moved to torch.distributed.tensor. +The imports here are purely for backward compatibility. We will remove these +imports in a few releases + +TODO: throw warnings when this module imported +""" + +from torch.distributed.tensor._dtensor_spec import * # noqa: F401, F403 +from torch.distributed.tensor.placement_types import * # noqa: F401, F403 diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__init__.py b/phivenv/Lib/site-packages/torch/distributed/_tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..284b6180a4a6495228ec2bbe5988b14e783e47e6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/__init__.py @@ -0,0 +1,12 @@ +from .fsdp2_mem_tracker import FSDPMemTracker +from .mem_tracker import MemTracker +from .memory_tracker import MemoryTracker +from .mod_tracker import ModTracker +from .runtime_estimator import RuntimeEstimator +from .sac_estimator import ( + MSPS, + SACEstimator, + SACGreedyOrderMeta, + SACStats, + SACTradeOffStats, +) diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6c34c482341377fdf951d20d4923e3d1af9f6c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3fe8f6541e9df37cd12e74c855cab89cad02aee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/common_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..256b813b01bd08c13dba9fcadb3a98b81efa34e2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/fake_collectives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cefd5a17c8ef45da95c5ae17476a41ec344be7d6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/fsdp2_mem_tracker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..324940153fecdf1833259b091c6490292e521a8d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/ilp_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8eb0fe7256155b22488e3ba370e3c649a409a9a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/mem_tracker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a2d04a0f1fefa01a834bcc9d57e374a168828c7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/memory_tracker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c44e0b724f65acd8b337592a9bc4ec61281d346c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/mod_tracker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..616023f76c0fd7ce3036104579842dcafbfd8adb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/runtime_estimator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ef1ab77e69550b956ff3f77df08cb48bc4c1990 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/sac_estimator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..933d01d78a2215a8d7d31146b021d382f8042d13 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/_tools/__pycache__/sac_ilp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/common_utils.py b/phivenv/Lib/site-packages/torch/distributed/_tools/common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..33a18ddb4041c770df1a99f5e31a9da171e5b2ef --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/common_utils.py @@ -0,0 +1,33 @@ +import warnings + +import torch +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +def get_untyped_storages(t: torch.Tensor) -> set[torch.UntypedStorage]: + """ + Recursively extracts untyped storages from a tensor or its subclasses. + + Args: + t (torch.Tensor): The tensor to extract storages from. + + Returns: + Set[torch.UntypedStorage]: A set of untyped storages. + """ + unflattened_tensors = [t] + flattened_tensor_storages = set() + while len(unflattened_tensors) > 0: + obj = unflattened_tensors.pop() + if is_traceable_wrapper_subclass(obj): + attrs, _ = obj.__tensor_flatten__() # type: ignore[attr-defined] + unflattened_tensors.extend([getattr(obj, attr) for attr in attrs]) + else: + if not hasattr(obj, "untyped_storage"): + warnings.warn( + f"Expected a tensor or a traceable wrapper-subclass of tensor, but got {type(obj)}", + category=UserWarning, + stacklevel=2, + ) + else: + flattened_tensor_storages.add(obj.untyped_storage()) + return flattened_tensor_storages diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/fake_collectives.py b/phivenv/Lib/site-packages/torch/distributed/_tools/fake_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..c3a978c0af5e28c62b4609974201ea5722649687 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/fake_collectives.py @@ -0,0 +1,307 @@ +import random +from typing import Any + +import torch +from torch._C._distributed_c10d import ( + _resolve_process_group, + FakeWork, + ProcessGroup, + Work, +) +from torch.utils._pytree import tree_map_only + + +torch.distributed.batch_isend_irecv + +c10d = torch.ops.c10d +_c10d_functional = torch.ops._c10d_functional +_c10d_functional_autograd = torch.ops._c10d_functional_autograd +_dtensor = torch.ops._dtensor +used_ids: set[int] = set() + + +def generate_unique_id() -> int: + while True: + new_id = random.randint(1, 10**9) + if new_id not in used_ids: + used_ids.add(new_id) + return new_id + + +# Function to create and return FakeWork object +def create_fakework(args, return_first_arg=True): # type: ignore[no-untyped-def] + work = FakeWork() + work.seq_id = generate_unique_id() + fakework_script_obj = work.boxed() + return (args[0], fakework_script_obj) if return_first_arg else fakework_script_obj + + +# Dictionary mapping collective operations to their meta functions +# All 20 ops from torch.csrc.distributed.c10d.Ops.cpp are included +# _DEPRECATED_META_FUNCTIONS = { +# "allreduce_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "allgather_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "allgather_into_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# "reduce_scatter_tensor_coalesced_": lambda *args: create_fakework(args, return_first_arg=False), +# } +_META_FUNCTIONS = { + "broadcast_": lambda *args: create_fakework(args), + "allreduce_": lambda *args: create_fakework(args), + "allgather_": lambda *args: create_fakework(args), + "_allgather_base_": lambda *args: create_fakework(args), + "reduce_scatter_": lambda *args: create_fakework(args), + "_reduce_scatter_base_": lambda *args: create_fakework(args), + "reduce_": lambda *args: create_fakework(args, return_first_arg=False), + "gather_": lambda *args: create_fakework(args, return_first_arg=False), + "scatter_": lambda *args: create_fakework(args), + "alltoall_": lambda *args: create_fakework(args), + "alltoall_base_": lambda *args: create_fakework(args, return_first_arg=False), + "barrier": lambda *args: create_fakework(args, return_first_arg=False), + "monitored_barrier_": lambda *args: None, + "send": lambda *args: create_fakework(args, return_first_arg=False), + "recv_": lambda *args: create_fakework(args, return_first_arg=False), + "recv_any_source_": lambda *args: create_fakework(args, return_first_arg=False), +} + +if not torch._running_with_deploy(): + lib_impl = torch.library.Library("c10d", "IMPL") # noqa: TOR901 + for op, meta_func in _META_FUNCTIONS.items(): + lib_impl.impl(op, meta_func, "Meta") + +# List of collective operation functions including functional collectives +# Note: The following collectives might be deprecated soon hence not adding them +# depcreated_non_functional_collectives = [ +# c10d.allreduce_coalesced_.default, +# c10d.reduce_scatter_tensor_coalesced_.default, +# c10d.allgather_into_tensor_coalesced_.default, +# c10d.allgather_coalesced_.default, +# ] +non_functional_collectives: set[torch._ops.OpOverload] = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.reduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.allgather_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d._allgather_base_.default, + c10d.gather_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + c10d.alltoall_base_.default, + c10d.barrier.default, + c10d.monitored_barrier_.default, +} +functional_collectives: set[torch._ops.OpOverload] = { + _c10d_functional.broadcast.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _c10d_functional.wait_tensor.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + _c10d_functional.all_gather_into_tensor_out.default, + _c10d_functional.all_gather_into_tensor_coalesced.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + _c10d_functional.broadcast_.default, + _dtensor.shard_dim_alltoall.default, +} + +sync_ops: set[torch._ops.OpOverload] = { + c10d.barrier.default, + c10d.monitored_barrier_.default, + _c10d_functional.wait_tensor.default, +} + +collective_ops = set.union(functional_collectives, non_functional_collectives) + + +class CollectiveOp: + # Static sets for performance optimization + PG_ARG_1 = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.reduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.barrier.default, + # c10d.allreduce_coalesced_.default + } + + PG_ARG_2 = { + c10d.allgather_.default, + c10d._allgather_base_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d.gather_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + c10d.alltoall_base_.default, + # c10d.allgather_coalesced_.default, + # c10d.allgather_into_tensor_coalesced_.default + # c10d.reduce_scatter_tensor_coalesced_.default + } + + PG_ARG_3 = { + _c10d_functional.broadcast.default, + _c10d_functional.broadcast_.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor_out.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + _c10d_functional.all_gather_into_tensor_coalesced.default, + } + + PG_ARG_4 = { + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _dtensor.shard_dim_alltoall.default, + } + + WK_ARG_1 = { + c10d.broadcast_.default, + c10d.allreduce_.default, + c10d.allgather_.default, + c10d.reduce_scatter_.default, + c10d._reduce_scatter_base_.default, + c10d._allgather_base_.default, + c10d.scatter_.default, + c10d.alltoall_.default, + } + + WK = { + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.reduce_.default, + c10d.gather_.default, + c10d.alltoall_base_.default, + c10d.barrier.default, + } + + COMM_TENSOR_ARG_0 = { + c10d.allreduce_.default, + c10d.send.default, + c10d.recv_.default, + c10d.recv_any_source_.default, + c10d.allgather_.default, + c10d.gather_.default, + c10d.reduce_.default, + c10d.broadcast_.default, + _c10d_functional.all_reduce_coalesced.default, + _c10d_functional.all_reduce_coalesced_.default, + # c10d.allreduce_coalesced_.default + # c10d.allgather_coalesced_.default + # c10d.allgather_into_tensor_coalesced_.default, + } + + COMM_TENSOR_ARG_1 = { + c10d.reduce_scatter_.default, + c10d.scatter_.default, + # c10d.reduce_scatter_tensor_coalesced_.default, + } + + COMM_TENSOR_ARG_RES = { + _c10d_functional.all_gather_into_tensor.default, + _c10d_functional_autograd.all_gather_into_tensor.default, + } + + COMM_TENSOR_SINGLE_UNTYPED_STORAGE = { + c10d._allgather_base_.default, + _c10d_functional.broadcast.default, + _c10d_functional.broadcast_.default, + _c10d_functional.all_reduce.default, + _c10d_functional.all_reduce_.default, + _c10d_functional.reduce_scatter_tensor.default, + _c10d_functional_autograd.reduce_scatter_tensor.default, + } + + COMM_TENSOR_ARG_0_AND_RES = { + _c10d_functional.all_to_all_single.default, + _c10d_functional_autograd.all_to_all_single.default, + _dtensor.shard_dim_alltoall.default, + } + + COMM_TENSOR_RES_SUM = { + _c10d_functional.all_gather_into_tensor_coalesced.default, + _c10d_functional.reduce_scatter_tensor_coalesced.default, + } + + @staticmethod + def sum_tensors(arg: Any) -> int: + """Calculate total memory consumed by the tensors in the argument.""" + total_memory = 0 + + def sum_bytes(t: torch.Tensor) -> None: + nonlocal total_memory + total_memory += t.untyped_storage().nbytes() + + tree_map_only(torch.Tensor, sum_bytes, arg) + return total_memory + + @staticmethod + def get_process_group(func, args) -> ProcessGroup: # type: ignore[no-untyped-def] + """Retrieve the process group for collective operations, except `wait_tensor`.""" + if func in CollectiveOp.PG_ARG_1: + return ProcessGroup.unbox(args[1]) + if func in CollectiveOp.PG_ARG_2: + return ProcessGroup.unbox(args[2]) + if func in CollectiveOp.PG_ARG_3: + return _resolve_process_group(args[2]) + if func in CollectiveOp.PG_ARG_4: + return _resolve_process_group(args[3]) + raise TypeError(f"Func {func} not found in {collective_ops}") + + @staticmethod + def get_comm_tensor_size(func, res, args, kwargs) -> int: # type: ignore[no-untyped-def] + """Compute the communication tensor size, except for `wait_tensor`, `barrier`, and `monitored_barrier`.""" + if func in CollectiveOp.COMM_TENSOR_ARG_0: + return CollectiveOp.sum_tensors(args[0]) + if func in CollectiveOp.COMM_TENSOR_ARG_1: + return CollectiveOp.sum_tensors(args[1]) + if func in CollectiveOp.COMM_TENSOR_ARG_RES: + return res.untyped_storage().nbytes() + if func in CollectiveOp.COMM_TENSOR_SINGLE_UNTYPED_STORAGE: + return args[0].untyped_storage().nbytes() + if func == c10d._reduce_scatter_base_.default: + return args[1].untyped_storage().nbytes() + if func == c10d.alltoall_.default: + # TODO(@sanketpurandare) - Confirm size computation + return max( + CollectiveOp.sum_tensors(args[0]), CollectiveOp.sum_tensors(args[1]) + ) + if func == c10d.alltoall_base_.default: + # TODO(@sanketpurandare) - Confirm size computation + return max( + args[0].untyped_storage().nbytes(), args[1].untyped_storage().nbytes() + ) + if func == _c10d_functional.all_gather_into_tensor_out.default: + return args[-1].untyped_storage().nbytes() + if func in CollectiveOp.COMM_TENSOR_RES_SUM: + return CollectiveOp.sum_tensors(res) + if func in CollectiveOp.COMM_TENSOR_ARG_0_AND_RES: + # TODO(@sanketpurandare) - Confirm size computation + return args[0].untyped_storage().nbytes() + res.untyped_storage().nbytes() + raise TypeError(f"Unknown function: {func} in {collective_ops}") + + @staticmethod + def get_work(func, res) -> Work: # type: ignore[no-untyped-def] + if func in CollectiveOp.WK: + return FakeWork.unbox(res) + elif func in CollectiveOp.WK_ARG_1: + return FakeWork.unbox(res[1]) + raise TypeError(f"Func {func} not found in {collective_ops}") diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py b/phivenv/Lib/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..dd17aa6b4a0695bf8ace21198d2802c84d0b1863 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -0,0 +1,547 @@ +from copy import deepcopy +from enum import auto, Enum +from functools import partial, wraps +from typing import Any, Callable, NamedTuple, Optional, TypeVar, Union +from typing_extensions import ParamSpec, TypeVarTuple, Unpack + +import torch +import torch.distributed._tools.fake_collectives +from torch import nn, optim +from torch._guards import active_fake_mode +from torch.distributed._tools.mem_tracker import _RefType, _State, MemTracker +from torch.distributed.fsdp import FSDPModule +from torch.distributed.fsdp._fully_shard._fsdp_param_group import FSDPParamGroup +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_map_only +from torch.utils.weak import WeakIdKeyDictionary, weakref + + +_TOTAL_KEY = "Total" + +__all__ = ["FSDPMemTracker"] + +_P = ParamSpec("_P") +_R = TypeVar("_R") +_Ts = TypeVarTuple("_Ts") + +c10d = torch.ops.c10d + + +class _FSDPRefType(_RefType): + """ + Enumerates categories of memory usage in FSDP modules, including parameters, gradients, activations, + and optimizer states. + + Attributes: + SHARDED_PARAM (str): Memory usage of sharded parameters. + UNSHARDED_PARAM (str): Memory usage of unsharded parameters. + SHARDED_GRAD (str): Memory usage of sharded gradients corresponding to the sharded parameters. + UNSHARDED_GRAD (str): Memory usage of unsharded gradients corresponding to the unsharded parameters. + ACT (str): Memory usage of activations and tensors from forward and AC recomputation. + TEMP (str): Memory usage of temporary tensors during the backward pass including gradients of activations. + ALL_GATHER (str): Memory usage of all_gather output tensor. + REDUCE_SCATTER (str): Memory usage of reduce_scatter input tensor. + OPT (str): Memory usage of tensors storing optimizer states. + INP (str): Memory usage of input tensors. + """ + + SHARDED_PARAM = "Sharded Param" + UNSHARDED_PARAM = "Unsharded Param" + BUFFER = "Buffer" + SHARDED_GRAD = "Sharded Grad" + UNSHARDED_GRAD = "Unsharded Grad" + ACT = "Activation" + TEMP = "Temp" + ALL_GATHER = "All Gather" + REDUCE_SCATTER = "Reduce Scatter" + OPT = "OptState" + INP = "Inputs" + + +class _SavedFSDPMethods(NamedTuple): + pre_backward: Callable + post_backward: Callable + + +class _FSDPModState(_State): + """ + Enumerates the states of FSDP modules during the forward and backward passes. + """ + + BEF_PRE_FW = "Before Pre-Forward" + AFT_PRE_FW = "After Pre-Forward" + BEF_POST_FW = "Before Post-Forward" + AFT_POST_FW = "After Post-Forward" + BEF_PRE_BW = "Before Pre-Backward" + AFT_PRE_BW = "After Pre-Backward" + BEF_POST_BW = "Before Post-Backward" + AFT_POST_BW = "After Post-Backward" + PRE_FW_AC = "Pre-Forward AC" + POST_FW_AC = "Post-Forward AC" + PEAK_FW = "Peak Forward" + PEAK_BW = "Peak Backward" + + +class _FSDPModMemStats: + """ + A class to store the memory statistics of an FSDP module. + + Args: + mod_fqn (str): The fully qualified name of the FSDP module. + + Attributes: + snapshots (Dict[_FSDPModState, Dict[torch.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states as defined by ``_FSDPModState``. Each key is a device, and + each value is another dictionary with keys as memory reference types defined by ``_FSDPRefType`` and + values as the memory consumed in bytes. + + """ + + def __init__(self, mod_fqn: str) -> None: + self.mod_fqn = mod_fqn + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[ + _FSDPModState, list[dict[torch.device, dict[str, int]]] + ] = {} + + +class _FSDPState(Enum): + PRE_FW = auto() + FW = auto() + POST_FW = auto() + PRE_BW = auto() + BW = auto() + POST_BW = auto() + + +class FSDPMemTracker(MemTracker): + """ + A ``TorchDispatchMode`` based context manager that extends ``torch.distributed._tools.mem_tracker.MemTracker`` to track + and categorize the peak memory and module-wise memory usage of FSDP modules. + + It tracks the peak memory usage across all the devices of all the FSDP modules in the module tree and categorizes + the tensor memory usage as defined by ``_FSDPRefType``. Further, it captures memory `snapshots` at different stages of + the module execution defined by ``_FSDPModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key is a reference + to a module, and each value is a ``_FSDPModMemStats`` object that stores the memory statistics of the module. + + Args: + mod (torch.nn.Module): The root FSDP module to be tracked. + optm (torch.optim.Optimizer, optional): The optimizer to be tracked. + + Note: Please refer to ``torch.distributed._tools.mem_tracker.MemTracker`` to learn about the limitations. + + Example usage + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + fmt = FSDPMemTracker(module, optimizer) + fmt.track_inputs((inp,)) + with fmt: + optimizer.zero_grad() + loss = module(inp) + print("After Forward:") + fmt.display_snapshot("current") + loss.backward() + optimizer.step() + fmt.display_snapshot("peak") + fmt.display_modulewise_snapshots(depth=3, units="MB") + + """ + + def __init__( + self, + mod: torch.nn.Module, + optm: Optional[torch.optim.Optimizer] = None, + ) -> None: + super().__init__() + assert isinstance(mod, FSDPModule), "FSDPMemTracker only supports FSDP modules" + self._root_mod = mod + self._optm = optm + self._fsdp_mod_to_saved_methods: WeakIdKeyDictionary = WeakIdKeyDictionary() + self._fsdp_state: _FSDPState = _FSDPState.PRE_FW + self._ref_class: type[_RefType] = _FSDPRefType + + def _instrument_fsdp_sharded_params_grads( + self, fsdp_param_group: FSDPParamGroup + ) -> None: + # Track sharded params and grads after initialization + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.sharded_param, + _FSDPRefType.SHARDED_PARAM, + ) + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + def _fsdp_state_pre_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_pre_fw: Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]], + ) -> Callable[_P, tuple[tuple[Unpack[_Ts]], dict[str, Any]]]: + # We capture memory snapshots before and after ``FSDPState._pre_forward`` to attribute the `unsharded` params + # and `all_gather` buffers. There are three cases: + # Case 1: If the module is not in the ``memory_tracking`` dictionary, create a new ``_FSDPModMemStats`` + # instance for the module and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + # For Case 2 we only capture 1 snapshot after ``FSDPState._pre_forward`` runs because it is a no-op. + @wraps(orig_fsdp_state_pre_fw) + def inner( + *args: _P.args, **kwargs: _P.kwargs + ) -> tuple[tuple[Unpack[_Ts]], dict[str, Any]]: + self._fsdp_state = _FSDPState.PRE_FW + mod_fqn = self._mod_tracker.get_known_fqn(fsdp_mod) + assert mod_fqn is not None + if fsdp_mod not in self.memory_tracking: + mod_stat = _FSDPModMemStats(mod_fqn) + self.memory_tracking[fsdp_mod] = mod_stat + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_FW, []).append( + snapshot + ) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_FW, []).append( + deepcopy(snapshot) + ) + elif not self._mod_tracker.is_bw: + parents = self._mod_tracker.parents - {mod_fqn} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "FSDPMemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + + args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs) + + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + self._update_and_maybe_create_winfos( + fsdp_param.unsharded_param, + _FSDPRefType.UNSHARDED_PARAM, + ) + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(fsdp_mod) + self._in_ac = True + else: + state = _FSDPModState.AFT_PRE_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + self._fsdp_state = _FSDPState.FW + return args, kwargs + + return inner + + def _fsdp_state_post_forward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_state_post_fw: Callable[_P, _R], + ) -> Callable[_P, _R]: + # We capture memory snapshots before and after ``FSDPState._post_forward`` to capture the resharded state + # if ``reshard_after_forward`` is not ``False``. There are two cases: + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward. + @wraps(orig_fsdp_state_post_fw) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> _R: + mod_stat = self.memory_tracking[fsdp_mod] + if self._mod_tracker.is_bw: + state = _FSDPModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is fsdp_mod: + self._ac_mod = None + self._in_ac = False + else: + state = _FSDPModState.BEF_POST_FW + mod_stat.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + self._fsdp_state = _FSDPState.POST_FW + + output = orig_fsdp_state_post_fw(*args, **kwargs) + + if not self._mod_tracker.is_bw: + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_FW, []).append( + self.get_tracker_snapshot() + ) + return output + + return inner + + def _fsdp_param_group_pre_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_pre_backward: Callable[_P, Any], + ) -> Callable[_P, None]: + # We capture memory snapshots before and after ``FSDPParamGroup.pre_backward`` to capture the pre-fetching + # and unsharding of params. We also initialize ``local_peak`` and ``PEAK_BW`` snapshot for the module. + @wraps(orig_fsdp_param_group_pre_backward) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> None: + self._fsdp_state = _FSDPState.PRE_BW + mod_stat = self.memory_tracking[fsdp_mod] + snapshot = self.get_tracker_snapshot() + mod_stat.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in snapshot.items() + } + mod_stat.snapshots.setdefault(_FSDPModState.PEAK_BW, []).append(snapshot) + mod_stat.snapshots.setdefault(_FSDPModState.BEF_PRE_BW, []).append( + deepcopy(snapshot) + ) + orig_fsdp_param_group_pre_backward(*args, **kwargs) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_PRE_BW, []).append( + self.get_tracker_snapshot() + ) + self._fsdp_state = _FSDPState.BW + + return inner + + def _fsdp_param_group_post_backward( + self, + fsdp_mod: FSDPModule, + orig_fsdp_param_group_post_backward: Callable[_P, Any], + ) -> Callable[_P, None]: + # We capture the memory snapshots before and after ``FSDPParamGroup.post_backward`` to track and attribute + # the `unsharded` grads before the post backward and then `sharded` grads and `reduce_scatter` buffers + # after the post backward. + @wraps(orig_fsdp_param_group_post_backward) + def inner(*args: _P.args, **kwargs: _P.kwargs) -> None: + fsdp_state = fsdp_mod._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + unsharded_grad = fsdp_param._unsharded_param.grad + if unsharded_grad is not None: + self._update_and_maybe_create_winfos( + unsharded_grad, + _FSDPRefType.UNSHARDED_GRAD, + update_existing=True, + ) + + mod_stat = self.memory_tracking[fsdp_mod] + mod_stat.snapshots.setdefault(_FSDPModState.BEF_POST_BW, []).append( + self.get_tracker_snapshot() + ) + self._fsdp_state = _FSDPState.POST_BW + orig_fsdp_param_group_post_backward(*args, **kwargs) + + if fsdp_param_group := fsdp_state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + sharded_grad = fsdp_param.sharded_param.grad + if sharded_grad is not None: + self._update_and_maybe_create_winfos( + sharded_grad, + _FSDPRefType.SHARDED_GRAD, + ) + + mod_stat.snapshots.setdefault(_FSDPModState.AFT_POST_BW, []).append( + self.get_tracker_snapshot() + ) + + return inner + + def _instrument_fsdp_module(self) -> None: + # We uninstall the existing `FSDPState._pre_forward` and `FSDPState._post_forward` hooks and install + # our own hooks that wrap them. We choose this over monkey-patching `FSDPParamGroup.pre_forward` and + # `FSDPParamGroup.post_forward` because during AC these won't be called. + # TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786) + # lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`. + for module in self._root_mod.modules(): + if isinstance(module, FSDPModule): + fsdp_state = module._get_fsdp_state() + if fsdp_param_group := fsdp_state._fsdp_param_group: + self._instrument_fsdp_sharded_params_grads(fsdp_param_group) + fsdp_state._pre_forward_hook_handle.remove() + fsdp_state._post_forward_hook_handle.remove() + fsdp_state._pre_forward_hook_handle = ( + module.register_forward_pre_hook( + self._fsdp_state_pre_forward( + module, fsdp_state._pre_forward + ), + prepend=True, + with_kwargs=True, + ) + ) + fsdp_state._post_forward_hook_handle = module.register_forward_hook( + self._fsdp_state_post_forward(module, fsdp_state._post_forward), + prepend=False, + always_call=True, + ) + self._fsdp_mod_to_saved_methods[module] = _SavedFSDPMethods( + fsdp_param_group.pre_backward, + fsdp_param_group.post_backward, + ) + fsdp_param_group.pre_backward = self._fsdp_param_group_pre_backward( # type: ignore[assignment] + module, fsdp_param_group.pre_backward + ) + fsdp_param_group.post_backward = ( # type: ignore[assignment] + self._fsdp_param_group_post_backward( + module, fsdp_param_group.post_backward + ) + ) + + for buffer in self._root_mod.buffers(): + self._update_and_maybe_create_winfos( + buffer, + _FSDPRefType.BUFFER, + ) + + def _instrument_optimizer(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + if self._optm is not None: + self._track_optimizer_states(_FSDPRefType.OPT, self._optm) + + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_FSDPRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + self._optm.register_step_pre_hook(_opt_step_pre_hook), + self._optm.register_step_post_hook(_opt_step_post_hook), + ) + + def _register_module_and_optimizer_hooks(self) -> None: + self._instrument_fsdp_module() + self._instrument_optimizer() + + def _deregister_module_and_optimizer_hooks(self) -> None: + for ( + fsdp_mod, + saved_methods, + ) in self._fsdp_mod_to_saved_methods.items(): + fsdp_state = fsdp_mod._get_fsdp_state() + fsdp_state._pre_forward_hook_handle.remove() + fsdp_state._post_forward_hook_handle.remove() + fsdp_state._pre_forward_hook_handle = fsdp_mod.register_forward_pre_hook( + fsdp_state._pre_forward, prepend=True, with_kwargs=True + ) + fsdp_state._post_forward_hook_handle = fsdp_mod.register_forward_hook( + fsdp_state._post_forward, prepend=False + ) + if fsdp_param_group := fsdp_state._fsdp_param_group: + fsdp_param_group.pre_backward = saved_methods.pre_backward + fsdp_param_group.post_backward = saved_methods.post_backward + self._fsdp_mod_to_saved_methods.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def track_inputs(self, inputs: tuple[Any, ...]) -> None: + """ + This is used to track the input tensors to the model and annotate them as ``Inputs``. + Args: + inputs (Tuple[Any]): A tuple containing the input data. This can include tensors + as well as other data types. Only tensors will be tracked. + """ + + def _track_inputs(t: torch.Tensor) -> None: + self._update_and_maybe_create_winfos( + t, + _FSDPRefType.INP, + ) + + tree_map_only(torch.Tensor, _track_inputs, inputs) + + def track_external( + self, *external: Union[nn.Module, optim.Optimizer, torch.Tensor] + ) -> None: + """This is no-op for ``FSDPMemTracker``""" + + def __enter__(self) -> "FSDPMemTracker": + if self._depth == 0: + self._register_module_and_optimizer_hooks() + self._track_resize() + self._track_dtensor_dispatch() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] + for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + TorchDispatchMode.__enter__(self) + self._depth += 1 + return self + + def __exit__(self, *args: Any) -> None: + self._depth -= 1 + if self._depth == 0: + self._deregister_module_and_optimizer_hooks() + self._restore_resize() + self._restore_dtensor_dispatch() + self._mod_tracker.__exit__(*args) + TorchDispatchMode.__exit__(self, *args) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + if ( + func == torch.ops._c10d_functional.wait_tensor.default + and active_fake_mode() + ): + # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns + # a new tensor which does not happen in eager mode, when a wait_tensor is called. + res = args[0] + else: + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _FSDPRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _FSDPRefType.TEMP + else: + reftype = _FSDPRefType.ACT + if func == c10d._allgather_base_.default and self._fsdp_state in [ + _FSDPState.PRE_FW, + _FSDPState.PRE_BW, + ]: + output_tensor = args[0] + self._update_and_maybe_create_winfos( + output_tensor, + _FSDPRefType.ALL_GATHER, + update_existing=True, + ) + if ( + func == c10d._reduce_scatter_base_.default + and self._fsdp_state == _FSDPState.POST_BW + ): + input_tensor = args[1] + self._update_and_maybe_create_winfos( + input_tensor, + _FSDPRefType.REDUCE_SCATTER, + update_existing=True, + ) + + tree_map_only(torch.Tensor, partial(self._track, reftype), res) + peak_state = ( + _FSDPModState.PEAK_BW if self._mod_tracker.is_bw else _FSDPModState.PEAK_FW + ) + self._update_peak_stats(peak_state) + return res diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/ilp_utils.py b/phivenv/Lib/site-packages/torch/distributed/_tools/ilp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1d708edb581861d63e70656b6c845c38fb57caaf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/ilp_utils.py @@ -0,0 +1,292 @@ +import copy +from collections import OrderedDict +from typing import cast, TypedDict + +import numpy as np + +import torch +from torch.distributed._tools.mem_tracker import ( + _MemRefType, + _ModMemStats, + _ModState, + MemTracker, +) +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.distributed._tools.sac_estimator import SACEstimator, SACTradeOffStats + + +class ModOrder(TypedDict): + fw_pre_order: list[str] + bw_pre_order: list[str] + fw_post_order: list[str] + bw_post_order: list[str] + + +class ModRuntime(TypedDict): + fw: float + bw: float + + +class ModStats(TypedDict): + fqn: str + # per-module params + param_per_module: int + # per-module grads + grad_per_module: int + # total accumulated gradients up to and including this module + grad_total: int + # per module fw activation size (excluding input and output) + act_fw_per_module: int + # per module bw activation size during peak_bw + act_bw_per_module: int + # per module activation grad size during peak_bw + act_grad_per_module: int + # total activation size up to but excluding the current module + # includes input of the current module (i.e., output of previous module) + act_total: int + # Inputs to the module + input_per_module: int + # Outputs of the module + output_per_module: int + # Total fw run-time of the module + fw_runtime_per_module: float + # Total bw run-time of the module + bw_runtime_per_module: float + # Is this module a leaf module + is_leaf: bool + # Total ac run-time of the module + sac_runtime: float + # Total ac_memory for the module + sac_memory: int + # Number of piecewise-linear functions used for approximating ac tradeoff curve + n_segments: int + # Slopes of the of piecewise-linear functions + slopes: list[float] + # Intercepts of the of piecewise-linear functions + intercepts: list[float] + # X breakpoints of the of piecewise-linear functions + breakpoints: list[float] + # Original trade-off curves + tradeoff_curve: OrderedDict[float, float] + + +class ModuleInfo(TypedDict): + mod_order: ModOrder + mod_stats: list[ModStats] + + +def aggregate_stats( + model: torch.nn.Module, + mem_tracker: MemTracker, + runtime_estimator: RuntimeEstimator, + sac_estimator: SACEstimator, + dev: torch.device, +) -> ModuleInfo: + """ + Collect modulewise stats for a given model, including memory, runtime, and AC tradeoff stats. + + Args: + model: nn.Module object + runtime_estimator: RuntimeEstimator object with runtime stats + mem_tracker: MemTracker object with memory stats + sac_estimator: SACEstimator object with AC tradeoff stats + dev: device the model was run on (used to extract memory stats from MemTracker) + + Returns: + ModuleInfo: A dictionary with module order and module stats. + """ + + # Memory stats + mod_mem_stats: dict[torch.nn.Module, _ModMemStats] = dict( + copy.deepcopy(mem_tracker.memory_tracking) + ) + + # Runtime stats + mod_runtime_stats: dict[str, ModRuntime] = { + fqn: {"fw": v["fw"], "bw": v["bw"]} + for fqn, v in runtime_estimator.mod_runtimes.items() + } + + # Module order + mod_order: ModOrder = { + "fw_pre_order": list(runtime_estimator.mod_fw_pre_order), + "bw_pre_order": list(runtime_estimator.mod_bw_pre_order), + "fw_post_order": list(runtime_estimator.mod_fw_post_order), + "bw_post_order": list(runtime_estimator.mod_bw_post_order), + } + + # Selective Activation Checkpointing stats + sac_estimator.pwlf_sac_tradeoff_curve() + mod_sac_tradeoff_stats: dict[str, SACTradeOffStats] = copy.deepcopy( + sac_estimator.sac_mod_tradeoff_stats + ) + + module_info: ModuleInfo = { + "mod_order": mod_order, + "mod_stats": [], + } + + for mod in model.modules(): + if mod_mem_stat := mod_mem_stats.get(mod, None): + if tradeoff_stats := mod_sac_tradeoff_stats.get(mod_mem_stat.mod_fqn, None): + sac_runtime = tradeoff_stats.sac_runtime + sac_memory = tradeoff_stats.sac_memory + n_segments = tradeoff_stats.n_segments + slopes = tradeoff_stats.slopes + intercepts = tradeoff_stats.intercepts + breakpoints = tradeoff_stats.fit_breaks + tradeoff_curve = tradeoff_stats.tradeoff_curve + is_leaf = False + else: + sac_runtime = sac_memory = n_segments = 0 + slopes = intercepts = breakpoints = [] + tradeoff_curve: OrderedDict[float, float] = OrderedDict() # type: ignore[no-redef] + is_leaf = True + mod_stat: ModStats = { + "fqn": mod_mem_stat.mod_fqn, + "param_per_module": mod_mem_stat.parameter_mem, + "grad_per_module": mod_mem_stat.parameter_mem, + "grad_total": mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.GRAD + ], + "act_fw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.snapshots[_ModState.PRE_FW][-1][dev][_MemRefType.ACT] + - mod_mem_stat.output_mem, + ), + "act_bw_per_module": max( + 0, + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.ACT], + ), + "act_grad_per_module": ( + mod_mem_stat.snapshots[_ModState.PEAK_BW][-1][dev][_MemRefType.TEMP] + - mod_mem_stat.snapshots[_ModState.PRE_BW][-1][dev][ + _MemRefType.TEMP + ] + ), + "act_total": mod_mem_stat.snapshots[_ModState.POST_FW][-1][dev][ + _MemRefType.ACT + ], + "input_per_module": mod_mem_stat.input_mem, + "output_per_module": mod_mem_stat.output_mem, + "fw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["fw"], + "bw_runtime_per_module": mod_runtime_stats[mod_mem_stat.mod_fqn]["bw"], + "is_leaf": is_leaf, + "sac_runtime": sac_runtime, + "sac_memory": sac_memory, + "n_segments": n_segments, + "slopes": slopes, + "intercepts": intercepts, + "breakpoints": breakpoints, + "tradeoff_curve": tradeoff_curve, + } + module_info["mod_stats"].append(mod_stat) + + return module_info + + +class Node(ModStats): + index: int # index according to forward pre-order + pos_fw_post_order: int # index according to forward post-order + + +class Graph: + def __init__(self, n: int) -> None: + self.nodes: list[Node] = [] + self.name2node: dict[str, Node] = {} + self.ad_matrix = np.zeros((n, n)) + self.fw_post_order: list[str] = [] + + def add_node(self, node: Node) -> None: + self.nodes.append(node) + self.name2node[node["fqn"]] = node + + +def parse_module_info(module_info: ModuleInfo) -> Graph: + """ + Parse module info and create a graph (tree) of modules. The graph will be + used by MILP solver to find optimal SAC and/or FSDP configurations. + """ + mod_stats = module_info["mod_stats"] + fw_pre_order = module_info["mod_order"]["fw_pre_order"] + # assertion and number of nodes + assert len(mod_stats) == len(fw_pre_order) + n_nodes = len(mod_stats) + + # create graph + g = Graph(n_nodes) + g.fw_post_order = module_info["mod_order"]["fw_post_order"] + + # sort the modules by pre-order and add them to the graph + module_info["mod_stats"] = sorted( + mod_stats, key=lambda x: fw_pre_order.index(x["fqn"]) + ) + for i, one_mod_stats in enumerate(mod_stats): + node: Node = cast(Node, one_mod_stats) + node["index"] = i + node["pos_fw_post_order"] = g.fw_post_order.index(node["fqn"]) + g.add_node(node) + + # set up ancestor-descendant matrix + for i in range(n_nodes): + for j in range(i, n_nodes): + if is_self_or_submodule(g.nodes[j]["fqn"], g.nodes[i]["fqn"]): + g.ad_matrix[i][j] = 1 + else: + break + + return g + + +def is_self_or_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + check if name_descendant is a submodule of name_ancestor, or if they are the same + """ + return name_descendant == name_ancestor or name_ancestor + "." in name_descendant + + +def is_submodule(name_descendant: str, name_ancestor: str) -> bool: + """ + if name_descendant is a submodule of name_ancestor, but not the same + """ + return name_ancestor + "." in name_descendant + + +def display_bytes(b: int, unit: str = "MiB") -> str: + """ + return a string that represent the number of bytes in a desired unit + """ + if unit == "KiB": + return f"{b / 2**10:.2f} KiB" + if unit == "MiB": + return f"{b / 2**20:.2f} MiB" + if unit == "GiB": + return f"{b / 2**30:.2f} GiB" + return f"{b:.2f} bytes" + + +def get_peak_memory_runtime_baseline(graph: Graph) -> tuple[int, float]: + """ + Get the baseline peak memory and runtime. + Baseline here means there is no FSDP or AC. + Memory includes the parameters, gradients, activations, and activation gradients. + Memory does not include e.g., optimizer states, embedding tables, etc. + + Returns: + int: peak memory in bytes + float: compute time in ms + """ + P_1 = graph.nodes[0]["param_per_module"] + num_nodes = len(graph.nodes) + peak_mem = 0 + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] + AG_i = graph.nodes[i]["act_grad_per_module"] + TA_i = graph.nodes[i]["act_total"] + peak_mem = max(peak_mem, P_1 + TG_i + AG_i + TA_i) + compute_time = ( + graph.nodes[0]["fw_runtime_per_module"] + + graph.nodes[0]["bw_runtime_per_module"] + ) + return (peak_mem, compute_time) diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/mem_tracker.py b/phivenv/Lib/site-packages/torch/distributed/_tools/mem_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..47fbcd50f1d0dc409d1d8ff145f7ce61d48a8787 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/mem_tracker.py @@ -0,0 +1,949 @@ +import math +import os +import re +import warnings +from contextlib import nullcontext +from copy import deepcopy +from enum import auto, Enum +from functools import partial, wraps +from typing import Any, Callable, Optional, TYPE_CHECKING, Union +from typing_extensions import Self + +import torch +import torch.distributed._tools.fake_collectives +from torch import nn, optim +from torch._guards import active_fake_mode +from torch.distributed._tools.common_utils import get_untyped_storages +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed.tensor import DTensor +from torch.optim.optimizer import ( + register_optimizer_step_post_hook, + register_optimizer_step_pre_hook, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten, tree_map_only +from torch.utils.weak import WeakIdKeyDictionary, weakref + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) +_TOTAL_KEY = "Total" + +__all__ = ["MemTracker"] + + +class _RefType(str, Enum): + """Base Class for defining memory reference types, categorizing tensors based on their usage within a model.""" + + +class _State(str, Enum): + """Base Class for defining module state to capture snapshots .""" + + +class _MemRefType(_RefType): + """ + An enum to define memory reference types, categorizing tensors based on their usage within a model. + + - PARAM: Tensors registered as nn.Parameter within modules. + - BUFFER: Tensors registered as nn.Buffer within modules. + - GRAD: Gradients associated with parameters. + - ACT: Tensors produced during the forward pass and recomputation in activation checkpointing. + - TMP: Temporary memory used during the backward pass, including gradients of activations. + - OPT: Tensors holding optimizer states. + - OTH: Tensors registered via `track_external` that do not fit the above categories. + """ + + PARAM = "Parameter" + BUFFER = "Buffer" + GRAD = "Gradient" + ACT = "Activation" + TEMP = "Temp" + OPT = "Optstate" + OTH = "Other" + + +class _ModState(_State): + """ + An enum to define the state of a module. + + - PRE_FW: The module is about to run the forward pass. + - POST_FW: The module has finished running the forward pass. + - PEAK_FW: The module has reached the peak memory usage during the forward pass. + - PRE_BW: The module is about to run the backward pass. + - PRE_FW_AC: The module is about to run the forward pass with activation checkpointing. + - POST_FW_AC: The module has finished running the forward pass with activation checkpointing. + - POST_BW: The module has finished running the backward pass. + - PEAK_BW: The module has reached the peak memory usage during the backward pass. + """ + + PRE_FW = "Pre-Forward" + POST_FW = "Post-Forward" + PEAK_FW = "Peak-Forward" + PRE_BW = "Pre-Backward" + PRE_FW_AC = "Pre-Forward-AC" + POST_FW_AC = "Post-Forward-AC" + POST_BW = "Post-Backward" + PEAK_BW = "Peak-Backward" + + +class _ModMemStats: + """ + A class to store the memory statistics of a module. + + Args: + mod_fqn (str): The fully qualified name of the module. + Attributes: + mod_fqn (str): The fully qualified name of the module. + parameter_mem (int): The memory usage of the parameters of the module. + buffer_mem (int): The memory usage of the buffers of the module. + input_mem (int): The memory usage of the inputs to the module. + output_mem (int): The memory usage of the outputs from the module. + snapshots (Dict[_ModState, Dict[torch.device, Dict[str, int]]]): A dictionary of memory snapshots + of the module at different states defined by ``_ModState``. + Note: + The memory snapshot is stored as a dictionary - Dict[torch.device, Dict[str, int]], where each key is a device, + and each value is another dictionary with keys as memory reference types defined by `_MemRefType` and + values as the memory consumed in bytes. + """ + + def __init__(self, mod_fqn: str): + self.mod_fqn = mod_fqn + self.parameter_mem: int + self.buffer_mem: int + self.input_mem: int + self.output_mem: int + self.local_peak: dict[torch.device, int] = {} + self.snapshots: dict[_ModState, list[dict[torch.device, dict[str, int]]]] = {} + + +class _WeakRefInfo: + """ + Manages memory statistics and device attributes for tensor storages. + """ + + def __init__( + self, size: int, element_size: int, device: torch.device, reftype: _RefType + ) -> None: + """ + Initializes the ``_WeakRefInfo`` object with tensor storage properties. + + Args: + size (int): The number of elements in the tensor storage. + element_size (int): The size of each element in the tensor storage. + device (torch.device): The device on which the tensor is allocated. + reftype (_RefType): The reference type of the tensor. + """ + self.size = size + self.element_size = element_size + self.reftype = reftype + self.device = device + self.mem_consumed = self._calculate_mem_consumed() + + def _calculate_mem_consumed(self) -> int: + """ + Calculates the memory consumed by the tensor storage, considering device-specific allocation rules. + + Returns: + int: The memory consumed in bytes. + """ + mem = self.size * self.element_size + if self.device.type == "cuda": + return math.ceil((mem) / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + return mem + + def update_mem_consumed(self, st: torch.UntypedStorage) -> int: + """ + Updates and returns the memory consumed if the storage size has changed. + + Args: + st (torch.UntypedStorage): The tensor storage to check for size updates. + + Returns: + int: The updated memory consumed in bytes. + """ + if st.size() != self.size: + self.size = st.size() + self.mem_consumed = self._calculate_mem_consumed() + return self.mem_consumed + + @classmethod + def create_winfo( + cls, + st: torch.UntypedStorage, + device: torch.device, + reftype: _RefType, + callback: Optional[Callable[[Self, weakref.ref], Any]] = None, + ) -> tuple[Self, weakref.ref]: + """ + Creates a new ``_WeakRefInfo`` instance and a weak reference to a ``torch.UntypedStorage`` object, + optionally attaching a callback to the weak reference. + + Args: + st (torch.UntypedStorage): The storage object for which to create the weak reference info. + device (torch.device): The device associated with the storage object. + reftype (_RefType): The type of reference, used to categorize the storage. + callback (Optional[Callable[[Self, weakref.ref]]]): A callback function that is called when + the storage object is about to be finalized (garbage collected). The callback function + should accept two arguments: the ``_WeakRefInfo`` instance and the weak reference to the storage. + Returns: + Tuple[Self, weakref.ref]: A tuple containing the newly created ``_WeakRefInfo`` instance and the + weak reference to the storage object. The weak reference may have an attached callback if provided. + """ + + winfo = cls(st.size(), st.element_size(), device, reftype) + w_st = weakref.ref(st, partial(callback, winfo) if callback else None) + return winfo, w_st + + +def _get_mem_divisor(units: str) -> int: + unit_dict = {"B": 1, "KiB": 2**10, "MiB": 2**20, "GiB": 2**30} + if units in unit_dict: + return unit_dict[units] + else: + raise ValueError( + f"Unsupported unit: {units}. Supported units are: {', '.join(unit_dict.keys())}" + ) + + +def _rounding_fn(value: int, divisor: int, precision: int) -> Union[float, int]: + return value if divisor == 1 else round(value / divisor, precision) + + +def _print_snapshot(snapshot: dict[torch.device, dict[str, int]], units: str) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + divisor = _get_mem_divisor(units) + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + print( + f"Device: {dev}", + *( + f"\t{k.value}: {_rounding_fn(v, divisor, 2)} {units}" + if isinstance(k, _RefType) + else f"\t{k}: {_rounding_fn(v, divisor, 2)} {units}" + for k, v in dev_snap.items() + ), + sep="\n", + ) + + +def _print_snapshot_tabular( + snapshot: dict[torch.device, dict[str, int]], units: str +) -> None: + if len(snapshot) == 0: + print("No memory tracked.") + return + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + divisor = _get_mem_divisor(units) + table_data = [] + key_list = list(next(iter(snapshot.values())).keys()) + headers = ["Device"] + [ + f"{key.value}" if isinstance(key, _RefType) else f"{key}" for key in key_list + ] + + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = [str(dev)] + row.extend(f"{_rounding_fn(v, divisor, 2)} {units}" for v in dev_snap.values()) + table_data.append(row) + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +def _print_state_snapshots( + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str +) -> None: + for state, snapshot_list in snapshots.items(): + print(f"{state.value}") + for i, snapshot in enumerate(snapshot_list): + print(f"# {i + 1}:") + _print_snapshot(snapshot, units) + print() + + +def _print_state_snapshots_tabular( + snapshots: dict[_State, list[dict[torch.device, dict[str, int]]]], units: str +) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError( + "Please install tabulate to use the tabulate option." + ) from err + + table_data = [] + last_state_call = None + divisor = _get_mem_divisor(units) + for state, snapshot_list in snapshots.items(): + for i, snapshot in enumerate(snapshot_list): + state_call = f"{state.value} # {i + 1}" + for dev, dev_snap in snapshot.items(): + if _rounding_fn(dev_snap[_TOTAL_KEY], divisor, 2) <= 0: + continue + row = { + "State & Call": ( + state_call if state_call != last_state_call else "" + ), + "Device": str(dev), + } + last_state_call = state_call + for k, v in dev_snap.items(): + row[f"{k.value}" if isinstance(k, _RefType) else f"{k}"] = ( + f"{_rounding_fn(v, divisor, 2)} {units}" + ) + table_data.append(row) + print(tabulate(table_data, headers="keys", tablefmt="rst")) + + +class _UpdateType(Enum): + # These are used for tracking updates to the continuouly maintained memory snapshot. + # ADD - When a new tensor storage is tracked + # DEL - When a tensor storage is about to be finalized (garbage collected). + # REF - When a tensor reference is updated, for instance, the gradients are marked as + # generic backward reference types until the grad_hook categorizes them as gradients. + # SIZE - When a tensor's storage is resized. + ADD = auto() + DEL = auto() + REF = auto() + SIZE = auto() + + +class MemTracker(TorchDispatchMode): + """ + A TorchDispatchMode to track, categorize and attribute the tensor memory created or accessed within its context. + + It categorizes the tracked tensors as parameters, buffers, activations, gradients, temporary memory and optimizer states + as defined by ``_MemRefType`` within its context. It captures memory `snapshots` for the modules, called within its context, + at various states defined by ``_ModState``. + + Attributes: + memory_tracking: A weakref key dictionary to store the memory statistics of each module. Each key + is a reference to a module, and each value is a ``_ModMemStats`` object that stores the memory + statistics of the module. + + Note: + The MemTracker should be used as a context manager. The modules, optimizers, and any other tensors created within + the context of MemTracker will be tracked by default. Any tensors or stateful objects such as modules, optimizers etc. + that need to be tracked but are created outside the MemTracker should be registered using the `track_external` method. + The `track_external` method should be called before the MemTracker is used. Any tensors created outside the ``MemTracker`` + and not supplied to the `track_external` method will not be tracked by the ``MemTracker``. + + Example usage: + + .. code-block:: python + + module = ... + optimizer = ... + inp = ... + mem_tracker = MemTracker() + mem_tracker.track_external(module, optimizer, inp) + with mem_tracker as mt: + loss = module(inp) + print("After Forward:") + mt.display_snapshot("current") + loss.backward() + optimizer.step() + optimizer.zero_grad() + mt.display_snapshot("peak") + mt.display_modulewise_snapshots(depth=3, units="MiB") + + Known Limitations: + - The ``MemTracker`` does not track memory for tensors that bypass the ``TorchDispatchMode`` ex. under ``no_dispatch``. + - Resizing tensor storages directly by using non-Tensor methods other than using ``torch.Untyped_Storage.resize_`` + is not tracked. File a Github issue if you have use-cases for this. + - If the tensors are not traceable or wrappable subclasses of ``torch.Tensor``, then the tracker does not know how to + track their storages. File a Github issue if you have use-cases for this. + - During AC in the backward pass there might be misattribution between activation and temp memory, but the peak memory + will be tracked accurately. This will be fixed in the next update by hooking intricately with ``torch.uitls.checkpoint``. + """ + + def __init__(self) -> None: + self.memory_tracking = WeakIdKeyDictionary() + self._curr_mem_snap: dict[torch.device, dict[str, int]] = {} + self._peak_mem: dict[torch.device, int] = {} + self._peak_mem_snap: dict[torch.device, dict[str, int]] = {} + self._param_to_grad_hook_handles = WeakIdKeyDictionary() + self._optimizer_hook_handles: Optional[ + tuple[RemovableHandle, RemovableHandle] + ] = None + # Dictionary to store the ``_WeakRefInfo`` instances corresponding to each tensor's storage. + self._WINFO = WeakIdKeyDictionary() + self._mod_tracker = ModTracker() + # This is a general memory tracker which can be used with any ``_RefType`` subclass + self._ref_class: type[_RefType] = _MemRefType + # Flags to track if we are in the AC region or optimizer step region + self._in_opt: bool = False + self._in_ac: bool = False + # Weak references to the topmost AC module currently active + self._ac_mod: Optional[weakref.ref] = None + self._orig_resize = torch.UntypedStorage.resize_ + self._orig_dtensor_dispatch = DTensor._op_dispatcher.dispatch + self._depth = 0 + + def _update_snap( + self, + u_type: _UpdateType, + winfo: _WeakRefInfo, + old_mem_consumed: Optional[int] = None, + old_reftype: Optional[_RefType] = None, + ) -> None: + # Initialize a flag to track if the total memory might drop to zero after updates. + maybe_zero = False + # Ensure the device entry exists in the current memory snapshot, initializing if necessary. + dev_snap = self._curr_mem_snap.setdefault( + winfo.device, dict.fromkeys(self._ref_class, 0) + ) + dev_snap.setdefault(_TOTAL_KEY, 0) + # Handle different types of updates based on the update type (`u_type`). + if u_type == _UpdateType.ADD: + # Increase the memory consumed for the specific reference type and update the total. + dev_snap[winfo.reftype] += winfo.mem_consumed + dev_snap[_TOTAL_KEY] += winfo.mem_consumed + elif u_type == _UpdateType.DEL: + # Decrease the memory consumed for the specific reference type and reduce the total. + dev_snap[winfo.reftype] -= winfo.mem_consumed + dev_snap[_TOTAL_KEY] -= winfo.mem_consumed + maybe_zero = True + elif u_type == _UpdateType.REF: + assert old_reftype is not None + # Adjust memory consumption between two reference types within the same device. + dev_snap[old_reftype] -= winfo.mem_consumed + dev_snap[winfo.reftype] += winfo.mem_consumed + elif u_type == _UpdateType.SIZE: + assert old_mem_consumed is not None + # Adjust the memory consumed for a reference type due to a change in size. + change = winfo.mem_consumed - old_mem_consumed + dev_snap[winfo.reftype] += change + dev_snap[_TOTAL_KEY] += change + maybe_zero = True + else: + raise ValueError(f"Invalid update type: {u_type}") + # Check if the total memory for the device has dropped to zero. + if maybe_zero: + if self._curr_mem_snap[winfo.device][_TOTAL_KEY] == 0: + # Remove the device entry from the memory snapshot if the total memory is zero. + del self._curr_mem_snap[winfo.device] + + def _update_and_maybe_create_winfos( + self, + t: torch.Tensor, + reftype: _RefType, + update_existing: bool = False, + ) -> set[_WeakRefInfo]: + sts = get_untyped_storages(t) + winfos = set() + for st in sts: + # Attempt to retrieve existing ``_WeakRefInfo`` and its weak reference from the tracking dictionary. + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + # If ``_WeakRefInfo`` exists, check if the reference type needs to be updated. + old_reftype = winfo.reftype + if old_reftype != reftype: + # Update the reference type and apply changes via ``_update_snap``. + winfo.reftype = reftype + self._update_snap(_UpdateType.REF, winfo, old_reftype=old_reftype) + winfos.add(winfo) + elif update_existing: + # If no existing ``_WeakRefInfo`` is found and update_existing is True, raise an error. + raise KeyError("No existing winfo found") + else: + # If no existing _WeakRefInfo is found and update_existing is False, create a new ``_WeakRefInfo``. + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + # Store the new ``_WeakRefInfo`` and its weak reference in the tracking dictionary. + self._WINFO[st] = (winfo, w_st) + # Update the snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + winfos.add(winfo) + return winfos + + def _delete_callback(self, winfo: _WeakRefInfo, w_st: weakref.ref) -> None: + # Callback to be called when the storage object corresponding to the ``_WeakRefInfo`` + # instance is about to be finalized. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.DEL, winfo) + + def _track_resize(self) -> None: + # Need to monkey-patch this because ``torch.UntypedStorage.resize_`` is not captured + # by ``TorchDispatchMode``. + @wraps(self._orig_resize) + def resize_(st: torch.UntypedStorage, size: int) -> None: + self._orig_resize(st, size) + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None and winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + + torch.UntypedStorage.resize_ = resize_ # type: ignore[method-assign, assignment] + + def _restore_resize(self) -> None: + torch.UntypedStorage.resize_ = self._orig_resize # type: ignore[method-assign] + + def _update_peak_stats(self, peak_state: _State) -> None: + # We first capture the current memory snapshot of the current tracker state then, + # We step through each of the modules we have tracked so far in ``memory_tracking`` + # and check if it is currently active by querying ``_mod_tracker.parents`` + # If it is active, we update the per device peak memory usage for the module + # corresponding to the ``_State`` which can be ``PEAK_FW`` or ``PEAK_BW``. + curr_snap = self._curr_mem_snap + + for mod_stats in self.memory_tracking.values(): + if mod_stats.mod_fqn in self._mod_tracker.parents: + if peak_state in mod_stats.snapshots: + for dev, dev_snap in curr_snap.items(): + if mod_stats.local_peak.get(dev, 0) < dev_snap[_TOTAL_KEY]: + mod_stats.local_peak[dev] = dev_snap[_TOTAL_KEY] + mod_stats.snapshots[peak_state][-1][dev] = deepcopy( + dev_snap + ) + + for dev, dev_snap in curr_snap.items(): + if self._peak_mem.get(dev, 0) < dev_snap[_TOTAL_KEY]: + self._peak_mem[dev] = dev_snap[_TOTAL_KEY] + self._peak_mem_snap[dev] = deepcopy(dev_snap) + + def _track(self, reftype: _RefType, t: torch.Tensor) -> None: + # Get the storages of the tensor and check if we have already tracked them. + # If yes, then check if the storage size has changed and update the current snapshot. + # Else create a new ``_WeakRefInfo`` instance and add it to the dictionary. + sts = get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + if winfo.size != st.size(): + old_mem_consumed = winfo.mem_consumed + winfo.update_mem_consumed(st) + self._update_snap( + _UpdateType.SIZE, winfo, old_mem_consumed=old_mem_consumed + ) + return + else: + winfo, w_st = _WeakRefInfo.create_winfo( + st, t.device, reftype, self._delete_callback + ) + self._WINFO[st] = (winfo, w_st) + # Update the current snapshot for the newly added ``_WeakRefInfo``. + if winfo.mem_consumed > 0: + self._update_snap(_UpdateType.ADD, winfo) + + def get_tracker_snapshot( + self, type: str = "current" + ) -> dict[torch.device, dict[str, int]]: + """ + Capture a snapshot of the memory usage breakdown per device, based on the specified type. + + Args: + type (str): The type of snapshot to capture. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + Returns: + Dict[torch.device, Dict[str, int]]: A dictionary where each key is a torch.device, and each value is another + dictionary. This inner dictionary has keys representing memory reference + types as defined in ``_MemRefType`` and values representing the amount of + memory consumed in bytes. + Raises: + ValueError: If an invalid type is specified. + """ + if type == "current": + return deepcopy(self._curr_mem_snap) + elif type == "peak": + return deepcopy(self._peak_mem_snap) + else: + raise ValueError(f"Invalid type {type}") + + def _track_module_params_and_buffers( + self, module: nn.Module, install_grad_hooks: bool = True + ) -> tuple[int, int]: + # Track the parameters and buffers of the module if not already tracked. + # If the parameters have gradients, track the gradients as well. + # If install_grad_hooks is True, install a gradient hook on the parameters + # to track the gradients, if it has not already been installed. + # Return the total memory consumed by the parameters and buffers. + def _grad_hook(grad: torch.Tensor) -> None: + self._update_and_maybe_create_winfos( + grad, + _MemRefType.GRAD, + ) + + param_memory = 0 + for param in module.parameters(): + winfos = self._update_and_maybe_create_winfos( + param, + _MemRefType.PARAM, + ) + param_memory += sum(winfo.mem_consumed for winfo in winfos) + if param.grad is not None: + self._update_and_maybe_create_winfos( + param.grad, + _MemRefType.GRAD, + ) + if ( + self._param_to_grad_hook_handles.get(param, None) is None + and install_grad_hooks + ): + grad_hook_handle = param.register_hook(_grad_hook) + post_acc_grad_hook_handle = param.register_post_accumulate_grad_hook( + lambda p: (_grad_hook(p.grad)) + ) + self._param_to_grad_hook_handles[param] = ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) + buffer_memory = 0 + for buffer in module.buffers(): + winfos = self._update_and_maybe_create_winfos( + buffer, + _MemRefType.BUFFER, + ) + buffer_memory += sum(winfo.mem_consumed for winfo in winfos) + return (param_memory, buffer_memory) + + def _track_inputs_or_outputs(self, args: Any) -> int: + # Calculate the memory consumed by the inputs or outputs of the module. + input_or_output_memory = 0 + + def add_inps_or_outs(t: torch.Tensor) -> None: + nonlocal input_or_output_memory + sts = get_untyped_storages(t) + for st in sts: + winfo, _ = self._WINFO.get(st, (None, None)) + if winfo is not None: + input_or_output_memory += winfo.mem_consumed + + tree_map_only(torch.Tensor, add_inps_or_outs, args) + return input_or_output_memory + + def _pre_fw_hook(self, module: nn.Module, inputs: Any) -> None: + # This is installed as a pre-fwd user hook with ``ModTracker.`` Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: If the module is not in the ``memory_tracking`` dictionary, we track the parameters, buffers, + # input and output memory of the module. Create a new ``_ModMemStats`` instance for the module + # and add it to the ``memory_tracking`` dictionary. + # Case 2: If the module is already in the ``memory_tracking`` dictionary and we are in backward, this means + # we are in the AC region. We check if this is the top most module in the AC region. If it is, + # we store a weak reference and set the flag ``_in_ac`` to True. + # Case 3: If the module is already in the ``memory_tracking`` dictionary and we are in forward, this means + # this module is called for the second time. If it is a root module, that means we are in the next + # iteration and we error out. If it is not a root module, that means it's a submodule that is being + # used multiple times in the same iteration, which we allow and track. + # For Case 1 and 3, we also initialize the ``local_peak`` and ``PEAK_FW`` snapshot for the module. + mod_name = self._mod_tracker.get_known_fqn(module) + assert mod_name is not None + if module not in self.memory_tracking: + mod_stats = _ModMemStats(mod_name) + param_mem, buffer_mem = self._track_module_params_and_buffers( + module, install_grad_hooks=True + ) + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.parameter_mem = param_mem + mod_stats.buffer_mem = buffer_mem + mod_stats.input_mem = input_mem + self.memory_tracking[module] = mod_stats + state = _ModState.PRE_FW + + elif self._mod_tracker.is_bw: + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW_AC + if self._ac_mod is None: + self._ac_mod = weakref.ref(module) + self._in_ac = True + else: + parents = set(self._mod_tracker.parents) - {mod_name} + if len(parents) == 1 and "Global" in parents: + raise NotImplementedError( + "MemTracker does not support memory tracking for multiple iterative calls." + " Either use ``reset_mod_stats`` to clear module memory stats for the previous iteration" + " or file a github issue if you need this feature." + ) + mod_stats = self.memory_tracking[module] + state = _ModState.PRE_FW + input_mem = self._track_inputs_or_outputs(inputs) + mod_stats.mod_fqn = mod_name + mod_stats.input_mem = input_mem + + mem_snapshot = self.get_tracker_snapshot() + if state == _ModState.PRE_FW: + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_FW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(state, []).append(deepcopy(mem_snapshot)) + + def _post_fw_hook(self, module: nn.Module, inputs: Any, outputs: Any) -> None: + # This is installed as a post-fwd user hook with ``ModTracker``. Based on the following cases we + # set the state and capture the memory snapshot for the module. + # Case 1: This is called in backward, which means we are in the AC region. If this is the top most module + # in the AC region, we set the flag ``_in_ac`` to False. + # Case 2: This is called in forward so we calculate the output memory + # of the module and update its mod_stats. + mod_stats = self.memory_tracking[module] + if self._mod_tracker.is_bw: + state = _ModState.POST_FW_AC + if self._ac_mod is not None and self._ac_mod() is module: + self._ac_mod = None + self._in_ac = False + else: + state = _ModState.POST_FW + output_mem = self._track_inputs_or_outputs(outputs) + mod_stats.output_mem = output_mem + mod_stats.snapshots.setdefault(state, []).append(self.get_tracker_snapshot()) + + def _pre_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a pre-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module. We also initialize the ``local_peak`` and ``PEAK_BW`` snapshot for it. + # If the module is None, we skip the hook. + # This can happen since this installed inside a multi-grad hook on the module's output tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping PRE_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mem_snapshot = self.get_tracker_snapshot() + mod_stats.local_peak = { + dev: dev_snap[_TOTAL_KEY] for dev, dev_snap in mem_snapshot.items() + } + mod_stats.snapshots.setdefault(_ModState.PEAK_BW, []).append(mem_snapshot) + mod_stats.snapshots.setdefault(_ModState.PRE_BW, []).append( + deepcopy(mem_snapshot) + ) + + def _post_bw_hook(self, module: nn.Module, args: Any) -> None: + # This is installed as a post-bwd user hook with ``ModTracker``. We set the state and capture the + # snapshot for the module if it is not None. + # This can happen since this installed inside a multi-grad hook on the module's input tensors + # and the module itself may not be alive during backward. + if module is None: + warnings.warn("Module is None. Skipping POST_BW hook.", stacklevel=2) + return + mod_stats = self.memory_tracking[module] + mod_stats.snapshots.setdefault(_ModState.POST_BW, []).append( + self.get_tracker_snapshot() + ) + + def _track_optimizer_states( + self, reftype: _RefType, optimizer: optim.Optimizer + ) -> None: + for states in optimizer.state.values(): + for val in states.values(): + if isinstance(val, torch.Tensor): + self._update_and_maybe_create_winfos( + val, + reftype, + ) + + def _register_global_optimizer_hook(self) -> None: + # Register a hook on the optimizer step to track the optimizer states. + # The pre-hook is to set the flag ``_in_opt`` to True. The post-hook unsets the flag, + # and also tracks any optimizer states that are created during the optimizer step. + def _opt_step_pre_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._in_opt = True + + def _opt_step_post_hook( + optimizer: optim.Optimizer, args: Any, kwargs: Any + ) -> None: + self._track_optimizer_states(_MemRefType.OPT, optimizer) + self._in_opt = False + + self._optimizer_hook_handles = ( + register_optimizer_step_pre_hook(_opt_step_pre_hook), + register_optimizer_step_post_hook(_opt_step_post_hook), + ) + + def _deregister_param_and_optimizer_hooks(self) -> None: + for ( + grad_hook_handle, + post_acc_grad_hook_handle, + ) in self._param_to_grad_hook_handles.values(): + grad_hook_handle.remove() + post_acc_grad_hook_handle.remove() + self._param_to_grad_hook_handles.clear() + + if self._optimizer_hook_handles is not None: + for handle in self._optimizer_hook_handles: + handle.remove() + self._optimizer_hook_handles = None + + def track_external( + self, *external: Union[nn.Module, optim.Optimizer, torch.Tensor] + ) -> None: + """ + Track tensors and stateful objects like modules, optimizers etc. that are created outside the MemTracker. + + This method should be called before the ``MemTracker`` is used. Any tensors that are not module parameters, buffers, + gradients activations, or optimizer states will be categorized as ``Other``. If you want them categorized with a + custom name, please file a GitHub issue. Any tensors created outside the MemTracker and not supplied to this + method will not be be tracked by ``MemTracker``. + + Args: + *external (Union[nn.Module, optim.Optimizer, torch.Tensor]): The external modules, optimizers, and + tensors to be tracked. + """ + flat_external, _ = tree_flatten(external) + for obj in flat_external: + if isinstance(obj, torch.Tensor): + self._update_and_maybe_create_winfos( + obj, + _MemRefType.OTH, + ) + elif isinstance(obj, torch.nn.Module): + self._track_module_params_and_buffers(obj, install_grad_hooks=False) + elif isinstance(obj, optim.Optimizer): + self._track_optimizer_states(_MemRefType.OPT, obj) + elif obj is None: + continue + else: + raise TypeError( + f"Object of type {type(obj)} is not supported for tracking. " + f"Only stateful objects like modules, optimizers, and tensors are supported." + ) + + def display_snapshot( + self, type: str = "current", units: str = "B", tabulate: bool = False + ) -> None: + """ + Display the memory usage breakdown snapshot of the tracker based on the specified type and units. + + Keyword args: + type (str): The type of snapshot to display. Can be "current" for the current memory usage or "peak" for the + peak memory usage. Defaults to "current". + units (str): The units to use for displaying memory usage. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool): Whether to display the snapshot in a tabular format. Defaults to False. + """ + snapshot = self.get_tracker_snapshot(type) + if tabulate: + _print_snapshot_tabular(snapshot, units) + else: + _print_snapshot(snapshot, units) + + def display_modulewise_snapshots( + self, depth: int = 2, units: str = "B", tabulate: bool = False + ) -> None: + """ + Print per device memory breakdown snapshot for each module called within MemTracker. + + Snapshots are displayed for the states defined by ``_ModState``. + The module hierarchy is displayed up to the specified depth. + + Keyword Args: + depth (int, optional): The depth of the module hierarchy to display. Defaults to 2. + units (str, optional): The units to use for memory tracking. Defaults to "B". Supports ["B", "KiB", "MiB", "GiB"]. + tabulate (bool, optional): Whether to display the snapshot in a tabular format. Defaults to False. + """ + + def natural_sort_key(s: str) -> list[Union[int, str]]: + return [ + int(text) if text.isdigit() else text.lower() + for text in re.split("([0-9]+)", s) + ] + + for mod_stats in sorted( + self.memory_tracking.values(), + key=lambda m_stats: natural_sort_key(m_stats.mod_fqn), + ): + mod_fqn = mod_stats.mod_fqn + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + if tabulate: + _print_state_snapshots_tabular(mod_stats.snapshots, units) + else: + _print_state_snapshots(mod_stats.snapshots, units) + + def reset_mod_stats(self) -> None: + """ + Reset all the module memory stats. Clears ``memory_tracking`` dictionary. + """ + self.memory_tracking.clear() + + def _track_dtensor_dispatch(self) -> None: + def track_dtensor_dispatch( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + with ( + self + if op_call in DTensor._op_dispatcher._custom_op_handlers + else nullcontext() + ): + return self._orig_dtensor_dispatch(op_call, args, kwargs) + + DTensor._op_dispatcher.dispatch = track_dtensor_dispatch # type: ignore[method-assign, assignment] + + def _restore_dtensor_dispatch(self) -> None: + DTensor._op_dispatcher.dispatch = self._orig_dtensor_dispatch # type: ignore[method-assign] + + def __enter__(self) -> "MemTracker": + if self._depth == 0: + self._register_global_optimizer_hook() + self._mod_tracker.register_user_hooks( + self._pre_fw_hook, + self._post_fw_hook, + self._pre_bw_hook, + self._post_bw_hook, + ) + self._track_resize() + self._track_dtensor_dispatch() + self._peak_mem_snap = self.get_tracker_snapshot() + self._peak_mem = { + dev: dev_snap[_TOTAL_KEY] + for dev, dev_snap in self._peak_mem_snap.items() + } + self._mod_tracker.__enter__() + super().__enter__() + self._depth += 1 + return self + + def __exit__(self, *args: Any) -> None: + self._depth -= 1 + if self._depth == 0: + self._deregister_param_and_optimizer_hooks() + self._mod_tracker.clear_user_hooks() + self._restore_resize() + self._restore_dtensor_dispatch() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): # type: ignore[no-untyped-def] + if ( + func == torch.ops._c10d_functional.wait_tensor.default + and active_fake_mode() + ): + # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns + # a new tensor which does not happen in eager mode, when a wait_tensor is called. + res = args[0] + else: + res = func(*args, **kwargs or {}) + # If we are tracking an optimizer state, we use the optimizer reference type. + # If we are in backward region and not in AC region, we use the backward reference type. + # Else we use the forward reference type. + if self._in_opt: + reftype = _MemRefType.OPT + elif self._mod_tracker.is_bw and not self._in_ac: + reftype = _MemRefType.TEMP + else: + reftype = _MemRefType.ACT + tree_map_only(torch.Tensor, partial(self._track, reftype), res) + peak_state = _ModState.PEAK_BW if self._mod_tracker.is_bw else _ModState.PEAK_FW + self._update_peak_stats(peak_state) + return res diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/memory_tracker.py b/phivenv/Lib/site-packages/torch/distributed/_tools/memory_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..061ad878380a9e519904ba6e8f346c7cc46417be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/memory_tracker.py @@ -0,0 +1,300 @@ +# mypy: allow-untyped-defs +import operator +import pickle +from collections import defaultdict +from collections.abc import Sequence +from itertools import chain +from typing import Any, Callable, no_type_check, TYPE_CHECKING + +import torch +import torch.nn as nn +from torch.utils._python_dispatch import TorchDispatchMode + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + + +BYTES_PER_MB = 1024 * 1024.0 + + +class MemoryProfileDispatchMode(TorchDispatchMode): + """Run in ``TorchDispatchMode`` to get memory stats at operator level.""" + + def __init__(self, memory_tracker) -> None: + self.memory_tracker = memory_tracker + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): + rs = func(*args, **kwargs) + if func == torch.ops.aten.detach.default: + return rs + func_name: str = ( + self.memory_tracker._cur_module_name + + "." + + func.__name__ + + "_" + + str(self.memory_tracker._operator_names[func.__name__]) + ) + self.memory_tracker._operator_names[func.__name__] = ( + self.memory_tracker._operator_names[func.__name__] + 1 + ) + self.memory_tracker._record_memory_stats(func_name) + + return rs + + +class MemoryTracker: + """ + Collect and plot the memory stats at operator level. + + Includes ``memories_allocated``, ``memories_active`` and ``memories_reserved``. + It also prints a summary for the top 20 operators that generate the most memories. + + Example usage: + + >>> # xdoctest: +SKIP(failing) + >>> net.cuda() + >>> input = input.cuda() + + >>> mem_tracker = MemoryTracker() + >>> mem_tracker.start_monitor(net) + + >>> net.zero_grad(True) + >>> loss = net(input) + >>> if isinstance(loss, dict): + >>> loss = loss['out'] + >>> loss.sum().backward() + >>> net.zero_grad(set_to_none=True) + + >>> mem_tracker.stop() + >>> mem_tracker.summary() + >>> mem_tracker.show_traces() + """ + + def __init__(self) -> None: + torch._C._log_api_usage_once("torch.distributed.memory_tracker") + self._hooks: list[RemovableHandle] = [] + self._operator_names: dict[str, int] = defaultdict(int) + self.memories_allocated: dict[int, dict[str, float]] = defaultdict() + self.memories_active: dict[int, dict[str, float]] = defaultdict() + self.memories_reserved: dict[int, dict[str, float]] = defaultdict() + self._markers: dict[str, int] = defaultdict(int) + self._cur_module_name: str = "" + self._op_index: int = 0 + self._num_alloc_retries: int = 0 + self._device_module = torch.get_device_module() + + @no_type_check + def start_monitor(self, root_module: nn.Module) -> None: + """ + Register module hooks and entering ``MemoryProfileDispatchMode``. + + This enables operator level memory stats can be tracked during module runtime. + """ + self._clear_state() + root_module.__setattr__("_memory_tracker_is_root", True) + for name, m in root_module.named_modules(): + if m is not root_module: + m.__setattr__("_memory_tracker_is_root", False) + # fused_proxy_group does not support hooks + if ".fused_proxy_grouped_embedding_bag" in name: + continue + # hook ordering with other hooks added by users is not managed, so + # the memory stats tracked here may not completely accurate. + h1 = m.register_forward_pre_hook(self._create_pre_forward_hook(name)) + h2 = m.register_forward_hook(self._create_post_forward_hook(name)) + # it does not work well with jagged tensor somehow, the root cause is not + # clear and remove it for now as it does not really capture important info. + # h3 = m.register_backward_hook(self._create_backward_hook(name)) + self._hooks.extend([h1, h2]) + self._device_module.empty_cache() + assert getattr(self, "profile_mode", None) is None + self.profile_mode = MemoryProfileDispatchMode(self) + self.profile_mode.__enter__() + + @no_type_check + def stop(self) -> None: + """ + Remove module hooks and exit ``MemoryProfileDispatchMode`` to stop tracking memory stats at operator level. + + Get some aggregated stats when the memory_tracker() is enabled, like ``num_alloc_retries``. + """ + self._num_alloc_retries = self._device_module.memory_stats().get( + "num_alloc_retries", 0 + ) + + for h in self._hooks: + h.remove() + self._hooks.clear() + assert getattr(self, "profile_mode", None) is not None + self.profile_mode.__exit__(None, None, None) + self.profile_mode = None + + @no_type_check + def summary(self, top: int = 20) -> None: + """ + Print out the top operators that generate the most memories. + + The number of the top operators can be configured. + """ + op_diff: dict[str, float] = defaultdict(float) + op_name, previous_allocated_memory = self.memories_allocated[0] + for i in range(1, self._op_index): + op_name, current_allocated_memory = self.memories_allocated[i] + op_diff[op_name] = current_allocated_memory - previous_allocated_memory + previous_allocated_memory = current_allocated_memory + + print("------------------------------------------------") + print(f"The number of alloc retries are: {self._num_alloc_retries}") + print(f"Top {top} ops that generates memory are:") + for k, v in sorted(op_diff.items(), key=operator.itemgetter(1), reverse=True)[ + :top + ]: + print(f"{k}: {v}MB") + print("------------------------------------------------") + + @no_type_check + def show_traces(self, path: str = "") -> None: + import matplotlib.pyplot as plt + + def _plot_figure(x, y_values, labels): + min_val = min(chain.from_iterable(y_values)) * 0.999 + max_val = max(chain.from_iterable(y_values)) * 1.001 + plt.figure() + for y, label in zip(y_values, labels): + plt.plot(x, y, label=label) + plt.xlabel("# Operator Calls") + plt.ylabel("Memory (MB)") + plt.legend() + for marker_name, marker in self._markers.items(): + if marker_name == "fw_bw_boundary": + plt.plot( + [marker, marker], + [min_val, max_val], + "r", + lw=2, + label=marker_name, + ) + else: + plt.plot( + [marker, marker], + [min_val, max_val], + "k-", + lw=2, + label=marker_name, + ) + + if path != "": + self.load(path) + + y_1 = [gb for (name, gb) in self.memories_allocated.values()] + y_2 = [gb for (name, gb) in self.memories_active.values()] + y_3 = [gb for (name, gb) in self.memories_reserved.values()] + x = list(range(len(y_1))) + # Split figures when there is big difference between + # "reserved_memory" and "allocated_memory" or "active_memory". + _plot_figure( + x, + [list(y_1), list(y_2), list(y_3)], + ["allocated_memory", "active_memory", "reserved_memory"], + ) + _plot_figure(x, [list(y_1)], ["allocated_memory"]) + _plot_figure(x, [list(y_2)], ["active_memory"]) + _plot_figure(x, [list(y_3)], ["reserved_memory"]) + + def save_stats(self, path: str) -> None: + """Save the stats using pickle during runtime if users want to plot the traces in other places like notebook.""" + stats = { + "memories_allocated": self.memories_allocated, + "memories_active": self.memories_active, + "memories_reserved": self.memories_reserved, + "markers": self._markers, + "num_alloc_retries": self._num_alloc_retries, + } + + with open(path, "wb") as f: + pickle.dump(stats, f, pickle.HIGHEST_PROTOCOL) + + def load(self, path: str) -> None: + """Load the pickled memory stats to plot the traces or print the summary.""" + with open(path, "rb") as f: + stats = pickle.load(f) + + self.memories_allocated = stats["memories_allocated"] + self.memories_active = stats["memories_active"] + self.memories_reserved = stats["memories_reserved"] + self._markers = stats["markers"] + self._num_alloc_retries = stats["num_alloc_retries"] + + def _create_pre_forward_hook(self, name: str) -> Callable: + """Prefix operator name with current module and 'forward', and insert 'fw_start' marker at forward pass start.""" + + def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: + self._cur_module_name = f"{name}.forward" + if ( + hasattr(module, "_memory_tracker_is_root") + and module._memory_tracker_is_root + ): + self._add_marker("fw_start") + + return _pre_forward_hook + + def _create_post_forward_hook(self, name: str) -> Callable: + """Insert the marker 'fw_bw_boundary' at the boundary of forward and backward pass.""" + + def _post_forward_hook( + module: nn.Module, + inputs: Sequence[torch.Tensor], + outputs: Sequence[torch.Tensor], + ) -> None: + if ( + hasattr(module, "_memory_tracker_is_root") + and module._memory_tracker_is_root + ): + self._add_marker("fw_bw_boundary") + + return _post_forward_hook + + def _create_backward_hook(self, name: str) -> Callable: + """Insert the current module name with backward prefix for the operator name.""" + + def _backward_hook( + module: nn.Module, grad_input: torch.Tensor, grad_output: torch.Tensor + ) -> None: + self._cur_module_name = f"{name}.backward" + + return _backward_hook + + @no_type_check + def _record_memory_stats(self, fn_name: str) -> None: + """ + Record current memory allocated, current memory active and current memory reserved. + + The memory stats dict is indexed with ``self._op_index``. + """ + memory_allocated: float = self._device_module.memory_allocated() / BYTES_PER_MB + memory_reserved: float = self._device_module.memory_reserved() / BYTES_PER_MB + memory_active: float = ( + self._device_module.memory_stats().get("active_bytes.all.current", 0) + / BYTES_PER_MB + ) + self.memories_allocated[self._op_index] = (fn_name, memory_allocated) + self.memories_reserved[self._op_index] = (fn_name, memory_reserved) + self.memories_active[self._op_index] = (fn_name, memory_active) + self._op_index += 1 + + def _add_marker(self, marker_name: str) -> None: + """Set the marker's x-axis value.""" + marker_val = len(self.memories_allocated.values()) + self._markers[marker_name] = marker_val + + def _clear_state(self) -> None: + """Clear states when start_monitor() is called.""" + self._operator_names.clear() + self.memories_allocated.clear() + self.memories_active.clear() + self.memories_reserved.clear() + self._markers.clear() + self._cur_module_name = "" + self._op_index = 0 + self._num_alloc_retries = 0 diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/mod_tracker.py b/phivenv/Lib/site-packages/torch/distributed/_tools/mod_tracker.py new file mode 100644 index 0000000000000000000000000000000000000000..74f122ddeb71b6f2fa7c0194404ff2723aca1a8c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/mod_tracker.py @@ -0,0 +1,251 @@ +# mypy: allow-untyped-defs +import warnings +import weakref +from typing import Callable, Optional + +import torch +from torch.autograd.graph import register_multi_grad_hook +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, +) +from torch.utils._pytree import tree_flatten + + +__all__ = ["ModTracker"] + + +class ModTracker: + """ + ``ModTracker`` is a context manager that tracks the nn.Module hierarchy during execution + so that other system can query which Module is currently being executed (or its backward is being + executed). + + You can access the ``parents`` attribute on this context manager to get the set of all the + Modules currently being executed via their fqn (fully qualified name, also used as the key within + the state_dict). + You can access the ``is_bw`` attribute to know if you are currently running in backward or not. + + Note that ``parents`` is never empty and always contains the "Global" key. The ``is_bw`` flag + will remain ``True`` after the forward until another Module is executed. If you need it to be + more accurate, please submit an issue requesting this. Adding a map from fqn to the module instance + is possible but not done yet, please submit an issue requesting this if you need it. + + Example usage + + .. code-block:: python + + mod = torch.nn.Linear(2, 2) + + with ModTracker() as tracker: + # Access anything during the forward pass + def my_linear(m1, m2, bias): + print(f"Current modules: {tracker.parents}") + return torch.mm(m1, m2.t()) + bias + + torch.nn.functional.linear = my_linear + + mod(torch.rand(2, 2)) + + """ + + parents: set[str] + """ + A Set containing the fqn for each module currently running their forward + """ + + def __init__(self): + self.parents = {"Global"} + self._active_module_cnt = {} + self._known_modules: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + self._seen_modules: weakref.WeakSet = weakref.WeakSet() + self._has_callback = False + self._post_bw_callbacks_to_enqueue: list[Callable] = [] + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _maybe_set_engine_callback(self): + # This assumes no concurrent calls to backward + if self._has_callback: + return + + for post_bw_callback in reversed(self._post_bw_callbacks_to_enqueue): + torch.autograd.Variable._execution_engine.queue_callback(post_bw_callback) + self._post_bw_callbacks_to_enqueue.clear() + + def callback(): + self.parents = {"Global"} + self._has_callback = False + + torch.autograd.Variable._execution_engine.queue_callback(callback) + self._has_callback = True + + @property + def is_bw(self): + """ + A boolean marking if this is currently running during the backward pass or not + """ + return torch._C._current_graph_task_id() != -1 + + def get_known_fqn(self, mod): + """ + Return the fqn for the given module if it is known to the ``ModTracker``, otherwise ``None``. + """ + return self._known_modules.get(mod, None) + + def register_user_hooks( + self, + pre_fw_hook: Optional[Callable] = None, + post_fw_hook: Optional[Callable] = None, + pre_bw_hook: Optional[Callable] = None, + post_bw_hook: Optional[Callable] = None, + ): + """ + Registers user-specified hooks to be called before/after the forward/backward pass for each + module tracked by the ``ModTracker``. One or more can be ``None``. + Args: + pre_fw_hook (Callable, optional): A hook to be called before the forward pass for the + module. It should have the following signature: + pre_fw_hook (module, input) -> None + post_fw_hook (Callable, optional): A hook to be called after the forward pass for the + module. It should have the following signature: + post_fw_hook (module, input, output) -> None + pre_bw_hook (Callable, optional): A multi-grad hook to be called on all the outputs of + the module that require gradients. It should have the following signature: + pre_bw_hook (module, grad_output) -> None + post_bw_hook (Callable, optional): A multi-grad hook to be called on all the inputs of + the module that require gradients. It should have the following signature: + post_bw_hook (module, grad_input) -> None + Raises: + AssertionError: If a new hook is provided when one is already registered. + Note: + If the module is not alive during the backward pass, the pre_bw_hook and post_bw_hook will + will receive None as the module argument. + The module fqn will be present in the ``parents`` attribute when each of the hooks is called. + Hooks are intended to be used as markers only not to modify the inputs/outputs. + """ + + def set_hook(hook, user_hook, hook_name): + if hook is not None and user_hook is not None: + raise AssertionError( + f"Only one {hook_name} can be registered at a time" + f" Clear the existing hook by calling ``clear_user_hooks`` before registering a new one" + ) + return hook + + self._user_pre_fw_hook = set_hook( + pre_fw_hook, self._user_pre_fw_hook, "pre_fw_hook" + ) + self._user_post_fw_hook = set_hook( + post_fw_hook, self._user_post_fw_hook, "post_fw_hook" + ) + self._user_pre_bw_hook = set_hook( + pre_bw_hook, self._user_pre_bw_hook, "pre_bw_hook" + ) + self._user_post_bw_hook = set_hook( + post_bw_hook, self._user_post_bw_hook, "post_bw_hook" + ) + + def clear_user_hooks(self): + """ + Clears the user specified hooks registered with ``register_user_hooks`` + """ + self._user_pre_fw_hook = None + self._user_post_fw_hook = None + self._user_pre_bw_hook = None + self._user_post_bw_hook = None + + def _get_mod_name(self, mod): + if mod not in self._known_modules: + self._known_modules[mod] = type(mod).__name__ + mod_name = self._known_modules[mod] + if mod not in self._seen_modules: + for name, submod in mod.named_children(): + self._known_modules[submod] = f"{mod_name}.{name}" + self._get_mod_name(submod) + self._seen_modules.add(mod) + return mod_name + + def _get_append_fn(self, w_mod, name, is_bw): + def fn(*args): + if is_bw: + self._maybe_set_engine_callback() + if name in self.parents and not self.is_bw: + + def custom_formatwarning(msg, category, filename, lineno, line=None): + return f"{filename}:{lineno}: {category.__name__}: {msg} \n" + + warnings.formatwarning = custom_formatwarning + warnings.warn( + "The module hierarchy tracking maybe be messed up." + " Please file a bug to PyTorch, if it is the case." + ) + if name not in self.parents: + self._active_module_cnt[name] = 1 + self.parents.add(name) + else: + self._active_module_cnt[name] += 1 + + if self._user_pre_bw_hook is not None and is_bw: + self._user_pre_bw_hook(w_mod(), args) + + return fn + + def _get_pop_fn(self, w_mod, name, is_bw): + def fn(*args): + if self._user_post_bw_hook is not None and is_bw: + self._user_post_bw_hook(w_mod(), args) + if name in self.parents: + self._active_module_cnt[name] -= 1 + if self._active_module_cnt[name] == 0: + self.parents.remove(name) + elif not self.is_bw: + # Due to some input/output not requiring gradients, we cannot enforce + # proper nesting in backward + raise RuntimeError( + "The Module hierarchy tracking is wrong. Report a bug to PyTorch" + ) + + return fn + + def _fw_pre_hook(self, mod, input): + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + self._get_append_fn(w_mod, name, False)() + if self._user_pre_fw_hook is not None: + self._user_pre_fw_hook(mod, input) + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw: + if tensors: + register_multi_grad_hook(tensors, self._get_pop_fn(w_mod, name, True)) + else: + self._post_bw_callbacks_to_enqueue.append( + self._get_pop_fn(w_mod, name, True) + ) + + def _fw_post_hook(self, mod, input, output): + name = self._get_mod_name(mod) + w_mod = weakref.ref(mod) + if self._user_post_fw_hook is not None: + self._user_post_fw_hook(mod, input, output) + self._get_pop_fn(w_mod, name, False)() + args, _ = tree_flatten(output) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook( + tensors, self._get_append_fn(w_mod, name, True), mode="any" + ) + + def __enter__(self): + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook( + self._fw_post_hook, always_call=True + ) + return self + + def __exit__(self, *args): + self._fw_pre_handle.remove() + self._fw_post_handle.remove() diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/runtime_estimator.py b/phivenv/Lib/site-packages/torch/distributed/_tools/runtime_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..a3720055850625580c74e7b72ef04d942232718a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/runtime_estimator.py @@ -0,0 +1,527 @@ +# Owner(s): ["module: unknown"] +import math +import os +from collections import defaultdict +from typing import Any, Callable +from typing_extensions import Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import active_fake_mode +from torch._inductor.utils import get_device_tflops, get_gpu_dram_gbps +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.mod_tracker import ModTracker +from torch.utils._mode_utils import no_dispatch +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils.flop_counter import flop_registry + + +aten = torch.ops.aten + +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + +# No fall-back kernel needed/exists for view ops +_VIEW_OPS = { + aten.lift_fresh, + aten.t, + aten.transpose, + aten.view, + aten.detach, + aten._unsafe_view, + aten.split, + aten.adjoint, + aten.as_strided, + aten.diagonal, + aten.expand, + aten.expand_as, + aten.movedim, + aten.permute, + aten.select, + aten.squeeze, + aten.mT, + aten.mH, + aten.real, + aten.imag, + aten.view_as, + aten.unflatten, + aten.unfold, + aten.unbind, + aten.unsqueeze, + aten.vsplit, + aten.hsplit, + aten.split_with_sizes, + aten.swapaxes, + aten.swapdims, + aten.chunk, +} +# We can ignore benchmarking tensor create ops +_CREATE_OPS = { + aten.randint, + aten.randn, + aten.rand, + aten.randn_like, + aten.rand_like, + aten.randint_like, + aten.arange, + aten.ones_like, + aten.zeros_like, +} + +_IGNORE_OPS = _VIEW_OPS | _CREATE_OPS + +__all__ = ["RuntimeEstimator"] + + +class RuntimeEstimator(TorchDispatchMode): + """ + Estimates the GPU runtime in milliseconds using various estimation methods under the ``FakeTensorMode``. + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the eager + runtime of PyTorch functions. It supports two estimation modes, benchmarking (`operator-level-benchmark`) and + roofline cost modeling (`operator-level-cost-model`). + For modules executed under this context manager, it aggregates the forward and backward operation runtimes + and also records their execution orders. + + Attributes: + mod_runtimes (Dict[str, Dict[str, float]]): A dictionary of module runtimes. The key to the outer dictionary + is the fully qualified name (FQN) of the module. For each module the forward and backward runtimes of the + operations are aggregated in the inner dictionary keyed by 'fw' and 'bw'. + mod_fw_pre_order (List[str]): List of module FQNs in pre-forward execution order. + mod_bw_pre_order (List[str]): List of module FQNs in pre-backward execution order. + mod_fw_post_order (List[str]): List of module FQNs in post-forward execution order. + mod_bw_post_order (List[str]): List of module FQNs in post-backward execution order. + total_runtime (float): The total estimated runtime in milliseconds. + + Note: + 1) The benchmarking estimate mode will execute kernels on GPU and assumes that every operation can run in + isolation without causing an OOM error. It is also designed to be used only under ``FakeTensorMode``. + 2) Currently wrapper tensor sub-classes such as ``DTensor`` won't produce correct estimates. We plan to support + them in future PRs. + 3) We only estimate the compute time, if your code has communication, it will not be considered. Again, we will + support this in future PRs. + + Example usage: + + .. code-block:: python + + runtime_estimator = RuntimeEstimator() + with FakeTensorMode(): + module = ... + optimizer = ... + inp = ... + with runtime_estimator(estimate_mode_type="operator-level-cost-model"): + loss = module(inp) + loss.backward() + optimizer.step() + optimizer.zero_grad() + runtime_estimator.display_modulewise_stats() + """ + + _float_types: set[torch.dtype] = { + torch.float16, + torch.bfloat16, + torch.float32, + torch.float64, + } + _no_fallback_kernel: set[torch._ops._OpNamespace] = set() + fake_mode: FakeTensorMode + + def __init__(self) -> None: + super().__init__() + self._estimate: Callable + self._estimate_mode_type: str + self._mod_tracker = ModTracker() + self.mod_runtimes: dict[str, dict[str, float]] = defaultdict( + lambda: defaultdict(lambda: 0.0) + ) + self.mod_fw_pre_order: list[str] = [] + self.mod_bw_pre_order: list[str] = [] + self.mod_fw_post_order: list[str] = [] + self.mod_bw_post_order: list[str] = [] + self.total_runtime: float = 0.0 + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_subclasses/fake_tensor.py#L1969 # noqa: PGH004,B950 + # NB: returns fake tensors + @classmethod + def _maybe_run_and_benchmark_fallback_kernel( # type: ignore[no-untyped-def] + cls, + func, + args, + kwargs, + orig_not_implemented_exception, + ): + """ + Runs and benchmarks a fallback kernel for a given function. + + Args: + func (Callable): The function to benchmark. + args (Tuple): The arguments to pass to the function. + kwargs (Dict[str, Any]): The keyword arguments to pass to the function. + orig_not_implemented_exception (Exception): The original exception to raise if the fallback kernel + is not implemented. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + # these should all be supported, just to be safe + # avoid fallback for operators which inplace modify metadata + # because the input fake tensors would be umodified + if torch.Tag.inplace_view in func.tags: # type: ignore[attr-defined] + raise orig_not_implemented_exception + + inp_impls = {} + flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # Don't use in_kernel_invocation_manager(fake_mode) as we want to do + # REAL compute (not with meta device) + with no_dispatch(): + + def to_real_tensor(e): # type: ignore[no-untyped-def] + if cls.fake_mode.is_our_fake(e): + if e.dtype in cls._float_types: + out = torch.rand_like(e, device=e.fake_device) + else: + out = torch.ones_like(e, device=e.fake_device) + if e.is_sparse: + out._coalesced_(e.is_coalesced()) + inp_impls[id(out)] = e + return out + return e + + flat_args = [to_real_tensor(a) for a in flat_args] + args, kwargs = pytree.tree_unflatten(flat_args, args_spec) + r = func(*args, **kwargs) + warmup_iters, actual_iters = 2, 3 + for _ in range(warmup_iters): + func(*args, **kwargs) + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record(torch.cuda.current_stream()) + for _ in range(actual_iters): + func(*args, **kwargs) + end_event.record(torch.cuda.current_stream()) + torch.cuda.synchronize() + cuda_time = start_event.elapsed_time(end_event) + mean_op_time = cuda_time / actual_iters + + storages = set() + + for e in flat_args: + if isinstance(e, torch.Tensor): + if not e.is_sparse: + storages.add(e._typed_storage()._cdata) + + # TODO: also check metadata change on inputs + # proper aliasing/metadata relationship between outputs and inputs will + # not be set up, bc of conversion to device, unless we can reuse an + # input impl + + def map_out(e): # type: ignore[no-untyped-def] + if id(e) not in inp_impls and ( + isinstance(e, torch.Tensor) + and not e.is_sparse + and e._typed_storage()._cdata in storages + ): + raise orig_not_implemented_exception + + if isinstance(e, torch.Tensor): + if id(e) in inp_impls: + return inp_impls[id(e)] + else: + return cls.fake_mode.fake_tensor_converter.from_real_tensor( + cls.fake_mode, e + ) + else: + return e + + return (pytree.tree_map(map_out, r), mean_op_time) + + @classmethod + def _benchmark_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using benchmarking. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + res: The result of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert isinstance(cls.fake_mode, FakeTensorMode), ( + "Initialize/Assign FakeTensorMode before using this function" + ) + mean_op_time = 0.0 + if func._overloadpacket not in _VIEW_OPS: + try: + res, mean_op_time = cls._maybe_run_and_benchmark_fallback_kernel( + func, + args, + kwargs, + NotImplementedError, + ) + return (res, mean_op_time) + except NotImplementedError: + cls._no_fallback_kernel.add(func._overloadpacket) + res = func(*args, **kwargs or {}) + return (res, mean_op_time) + + # Adapted from: https://github.com/pytorch/pytorch/blob/9b902b3ee3bd608a19543362b66bf06c373dd374/torch/_inductor/scheduler.py#L589 # noqa: PGH004,B950 + @classmethod + def _roofline_estimate(cls, func, args, kwargs) -> tuple[Any, float]: # type: ignore[no-untyped-def] + """ + Estimates the runtime of a function using a roofline cost model. + + Args: + func: The function to estimate. + args: The arguments to pass to the function. + kwargs: The keyword arguments to pass to the function. + out: The output of the function. + + Returns: + Tuple[Any, float]: A tuple containing the result of the function and + the mean operation time in milliseconds. + """ + assert torch.cuda.is_available(), ( + "Roofline estimation needs to access CUDA capabilities to make estimations" + ) + + def get_num_bytes(t: torch.Tensor) -> int: + """ + Calculates the memory consumption of a tensor. + + Args: + t (torch.Tensor): The input tensor. + + Returns: + int: The memory consumption of the tensor in bytes. + """ + num_bytes = t.untyped_storage().nbytes() + mem_consumed = ( + math.ceil(num_bytes / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + ) + return mem_consumed + + def get_compute_time(func_packet, args, kwargs, out, out_dtypes) -> float: # type: ignore[no-untyped-def] + """ + Estimates the compute time of an aten operator. + + Args: + func_packet: The operator overload packet. + args: The arguments to the operator. + kwargs: The keyword arguments to the operator. + out: The output of the operator. + out_dtypes: The output data types. + + Returns: + float: The estimated compute time in nanoseconds. + """ + if func_packet in flop_registry: + assert len(out_dtypes) == 1, ( + f"Only support single out dtype got {out_dtypes} for {func_packet}" + ) + dtype = out_dtypes.pop() + # This actually gives peta-FLOPs/s hence multiply by 1e15 to get the FLOPs/s + peak_gpu_flops = get_device_tflops(dtype) * 1e15 + # We can expect to achieve 75% of theoretical peak flops + factor = 0.75 + peak_empirical_flops = factor * peak_gpu_flops + flop_count_func = flop_registry[func_packet] + # We divide by a factor of 2 to get the MACs (multiply and accumulate) + flop_count = flop_count_func(*args, **kwargs, out_val=out) / 2 + # We multiply by 1e9 to get the time in nano seconds + compute_time = (flop_count / peak_empirical_flops) * 1e9 + return compute_time + return 0.0 + + def get_transfer_time(flat_args_kwargs, flat_outs) -> float: # type: ignore[no-untyped-def] + """ + Estimates the memory transfer time of input and output tensors. + + Args: + flat_args_kwargs (List[torch.Tensor]): The flat list of arguments and keyword arguments. + flat_outs (List[torch.Tensor]): The flat list of outputs. + + Returns: + float: The estimated memory transfer time in nanoseconds. + """ + gpu_memory_bandwidth = get_gpu_dram_gbps() + read_bytes = sum( + get_num_bytes(t) + for t in flat_args_kwargs + if isinstance(t, torch.Tensor) + ) + write_bytes = sum( + get_num_bytes(t) for t in flat_outs if isinstance(t, torch.Tensor) + ) + counted_bytes = read_bytes + write_bytes + # The GPU memory bandwidth is in GB/s so the transfer time is in nanoseconds + transfer_time = counted_bytes / gpu_memory_bandwidth + return transfer_time + + # Roofline Cost Model Explanation + + # The roofline cost model estimates the execution time of an operator based on + # the device's empirical maximum FLOPs/sec (pi) and device DRAM bandwidth (beta). + + # Variables: + # - pi: Maximum empirical FLOPs/sec of the device + # - beta: Maximum empirical device DRAM bandwidth (bytes/sec) of the device + # - I: Arithmetic intensity of the operator (FLOPs/bytes) + # - op_flops: FLOPs required by the operator + # - op_bytes: Bytes transferred to and from DRAM for the operator + + # Calculation Steps: + # 1. Calculate arithmetic intensity: I = op_flops / op_bytes + # 2. Calculate estimated FLOPs/sec: est_flops_sec = min(pi, beta * I) + # 3. Calculate estimated operator time: estimated_op_time = op_flops / est_flops_sec + # This simplifies to: estimated_op_time = max(op_flops / pi, op_flops / (beta * I)) + # Further simplifying: estimated_op_time = max(op_flops / pi, op_bytes / beta) + + # Simplified Formulas: + # - compute_time = op_flops / pi + # - transfer_time = op_bytes / beta + # - estimated_op_time = max(compute_time, transfer_time) + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + op_time = 0.0 + func_packet = func._overloadpacket + if func_packet not in _IGNORE_OPS: + flat_args_kwargs, args_spec = pytree.tree_flatten((args, kwargs)) + flat_outs, out_spec = pytree.tree_flatten(out) + transfer_time = get_transfer_time(flat_args_kwargs, flat_outs) + + out_dtypes = { + t.dtype + for t in flat_outs + if isinstance(t, torch.Tensor) and t.dtype in cls._float_types + } + + args, kwargs = pytree.tree_unflatten(flat_args_kwargs, args_spec) + out = pytree.tree_unflatten(flat_outs, out_spec) + + compute_time = get_compute_time(func_packet, args, kwargs, out, out_dtypes) + # We get the estimated time as the max of the transfer time and + # compute time. We divide by 1e6 to get the time in ms + op_time = max(transfer_time, compute_time) / 1e6 + + return (out, op_time) + + def display_modulewise_stats(self, depth: int = 2) -> None: + """ + Displays module-wise statistics collected by ``RuntimeEstimator``. + + Prints the pre-forward and pre-backward execution orders. + Displays the module-wise forward and backward runtimes in milliseconds. + + Args: + depth (int): The maximum depth of module hierarchy to display (default to 2). + """ + print("Pre-Forward Execution Order: ") + for mod_fqn in self.mod_fw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + print("Pre-Backward Execution Order: ") + for mod_fqn in self.mod_bw_pre_order: + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(mod_fqn) + for mod_fqn, runtimes in self.mod_runtimes.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print( + f"{mod_fqn} fw: {runtimes.get('fw', 0.0):.3f}ms bw: {runtimes.get('bw', 0.0):.3f}ms" + ) + + def __torch_dispatch__(self, func, types, args=..., kwargs=None): # type: ignore[no-untyped-def] + # TODO: @sanketpurandare: Flatten tensors by desugaring the tensor subclasses + # TODO: @sanketpurandare: Add logic for incorporating communication time + res, op_time = self._estimate(func, args, kwargs) + for par in self._mod_tracker.parents: + if self._mod_tracker.is_bw: + self.mod_runtimes[par]["bw"] += op_time + else: + self.mod_runtimes[par]["fw"] += op_time + self.total_runtime += op_time + return res + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + RuntimeEstimator: The runtime estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + self._estimate_mode_type = estimate_mode_type + return self + + def __enter__(self) -> Self: + fake_mode = active_fake_mode() + assert isinstance(fake_mode, FakeTensorMode), ( + "No FakeTensorMode found, designed to used under FakeTensorMode" + ) + RuntimeEstimator.fake_mode = fake_mode + self.total_runtime = 0.0 + self.mod_runtimes = defaultdict(lambda: defaultdict(lambda: 0.0)) + self.mod_fw_pre_order.clear() + self.mod_bw_pre_order.clear() + self.mod_fw_post_order.clear() + self.mod_bw_post_order.clear() + self._mod_tracker.register_user_hooks( + pre_fw_hook=lambda mod, inp: self.mod_fw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + pre_bw_hook=lambda mod, g_out: self.mod_bw_pre_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_fw_hook=lambda mod, inp, out: self.mod_fw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + post_bw_hook=lambda mod, g_inp: self.mod_bw_post_order.append( + self._mod_tracker.get_known_fqn(mod) + ), + ) + self._mod_tracker.__enter__() + super().__enter__() + return self + + def __exit__(self, *args: Any) -> None: + print( + f"Estimated ({self._estimate_mode_type})" + f"total_time: {self.total_runtime:.3f} ms" + ) + if len(self._no_fallback_kernel) > 0: + print("no_fallback_kernel: ", list(self._no_fallback_kernel)) + super().__exit__(*args) + self._mod_tracker.clear_user_hooks() + self._mod_tracker.__exit__() diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/sac_estimator.py b/phivenv/Lib/site-packages/torch/distributed/_tools/sac_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..1fda8a175fc7ff8f2b748b1a3e4addc197547c77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/sac_estimator.py @@ -0,0 +1,960 @@ +import math +import os +import sys +from collections import OrderedDict +from dataclasses import astuple, dataclass +from typing import Any, NamedTuple, Optional +from typing_extensions import Self + +import torch +from torch import nan, nn, UntypedStorage +from torch._guards import active_fake_mode +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.distributed._tools.common_utils import get_untyped_storages +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed._tools.runtime_estimator import RuntimeEstimator +from torch.testing._internal.composite_compliance import ( + is_inplace, + is_inplace_view_fn, + is_view_fn, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten +from torch.utils.checkpoint import SAC_IGNORED_OPS + + +__all__ = ["SACEstimator", "SACStats", "MSPS", "SACTradeOffStats", "SACGreedyOrderMeta"] +aten = torch.ops.aten + +_ADDITIONAL_IGNORED_OPS = { + aten.lift_fresh.default, # type: ignore[attr-defined] + torch.ops.profiler._record_function_exit._RecordFunction, # type: ignore[attr-defined] + aten.clone.default, # type: ignore[attr-defined] # seems needed for torch.compile +} +OPS_TO_ALWAYS_SKIP = SAC_IGNORED_OPS | _ADDITIONAL_IGNORED_OPS +# This value is hard-coded here: +# https://github.com/pytorch/pytorch/blob/5fba5d83f0703ff8077ab65448a998e9ad6598fd/c10/cuda/CUDACachingAllocator.cpp#L117 +_PYTORCH_MIN_ALLOCATE = ( + 2**9 if int(os.environ.get("PYTORCH_NO_CUDA_MEMORY_CACHING", 0)) == 0 else 1 +) + + +def _display_stats_tabular(headers: list[str], table_data: list[list[Any]]) -> None: + try: + from tabulate import tabulate + except ImportError as err: + raise ImportError("Please install tabulate.") from err + + # Use tabulate to print the table + print(tabulate(table_data, headers=headers, tablefmt="rst")) + + +# Based on: +# https://github.com/facebookresearch/xformers/blob/main/xformers/checkpoint.py#L71 +@dataclass +class _SACMetadata: + """ + Stores metadata for a single operator for SAC. + + Attributes: + func (Any): The operator function. + time_taken (float): The time taken by the operator. + memory_used (float): The memory used by the operator. + curr_idx (int): The current operator index. + output_ids (Tuple[int, ...]): The storage IDs of the operator's outputs. + inplace_info (Tuple[int, ...]): Tuple of self and parent operator for in-place operator. + is_view_like (bool): Whether the operator is view-like. + is_rand_op (bool): Whether the operator is a random operator. + """ + + func: Any + time_taken: float + memory_used: float + curr_idx: int + output_ids: tuple[int, ...] + inplace_info: tuple[int, ...] + is_view_like: bool + is_rand_op: bool + + +@dataclass +class _SACModMetadata: + """ + Stores metadata for a module for SAC. + + Attributes: + start_idx (int): The starting index of the module's operators. + force_store_random (bool): Whether to force store random operators in the module. + sac_metadata (List[_SACMetadata]): List of metadata for each operator in the module. + """ + + start_idx: int + force_store_random: bool + sac_metadata: list[_SACMetadata] + + +@dataclass +class SACStats: + """ + A class for storing Activation Checkpointing statistics corresponding to a module. + + Attributes: + func_names (List[str]): List of operator names. + runtimes (List[float]): List of operator runtimes in millliseconds. + memory (List[int]): List of operator memory usage in bytes. + view_like_ops (List[int]): Indices of view-like operators. + rand_ops (List[int]): Indices of random operators. + saved_autograd_ops (List[int]): Indices of operator results saved by autograd engine. + inplace_ops (List[Tuple[int, int]]): Tuple of indices of op and its first parent for Inplace operators. + force_store_random (bool): Whether to force store random operator results. + """ + + func_names: list[str] + runtimes: list[float] + memory: list[int] + view_like_ops: list[int] + rand_ops: list[int] + saved_autograd_ops: list[int] + inplace_ops: list[tuple[int, int]] + force_store_random: bool + + +class MSPS(NamedTuple): + """ + Represents Memory and Runtime Statistics for an operator/operator group. + + Attributes: + func_names (set[str]): Set of operator/operator group names. + op_idx (int): Operator index (group head index in case of operator groups). + memory (int): Memory usage in bytes. + runtime (float): Runtime in milliseconds. + msps (float): Memory per second calculated as memory/runtime. + """ + + func_names: set[str] + op_idx: int + memory: int + runtime: float + msps: float + + +@dataclass +class SACTradeOffStats: + """ + Stores statistics for activation-checkpointing trade-off. + + Attributes: + n_segments (int): Number of piecewise linear segments fitted to the trade-off curve. + slopes (List[float]): Slopes of the pieces of linear segments fitted to the trade-off curve. + intercepts (List[float]): Intercepts of the of the pieces of linear segments fitted to the trade-off curve. + fit_breaks (List[float]): Breakpoints of the of the pieces of linear segments fitted to the trade-off curve. + tradeoff_curve (OrderedDict[float, float]): Trade-off curve data of memory discarded vs recomputation time. + sac_memory (int): Total memory of operations available for activation checkpointing in bytes. + sac_runtime (float): Total runtime of operations available for activation checkpointing in milliseconds. + """ + + n_segments: int + slopes: list[float] + intercepts: list[float] + fit_breaks: list[float] + tradeoff_curve: OrderedDict[float, float] + sac_memory: int + sac_runtime: float + + +@dataclass +class SACGreedyOrderMeta: + """ + Stores metadata for Greedy-order SAC. + + Attributes: + recomputed_ops (set[int]): Set of operator indices to be recomputed. + stored_ops (set[int]): Set of operator indices to be stored. + inplace_op_groups (dict[int, set[int]]): Dictionary of inplace operator groups from group-head to operators. + random_ops_group (dict[int, set[int]]): Dictionary of random op group head to random ops. + msps_meta (list[MSPS]): List of Memory and Runtime Statistics for operators. + """ + + recomputed_ops: set[int] + stored_ops: set[int] + inplace_op_groups: dict[int, set[int]] + random_ops_group: dict[int, set[int]] + msps_meta: list[MSPS] + + +class SACEstimator(TorchDispatchMode): + """ + Estimates the memory and recomputation time trade-offs for applying Selective Activation Checkpointing (SAC). + + This class provides a ``TorchDispatchMode`` based context manager that can be used to estimate the memory and + runtime trade-offs of functions or ``torch.nn.Module``s for Selective Activation Checkpointing (SAC). It provides + detailed statistics and metadata information for operators of each module and provides a greedy order for selecting + the operators to be recomputed/checkpointed. It also constructs the per-module trade-off graph of discarded memory + vs recomputation time for the obtained greedy order. Using ``RuntimeEstimator`` under the hood, it supports two + estimation modes, `operator-level-benchmark` and (`operator-level-cost-model` (roofline model). + + Attributes: + sac_mod_stats (Dict[str, SACStats]): Dictionary from module FQN (fully qualified name) to ``SACStats``. + sac_mod_tradeoff_stats (Dict[str, SACTradeOffStats]): Dictionary from module FQN to ``SACTradeOffStats``. + sac_mod_greedy_order_meta (Dict[str, SACGreedyOrderMeta]): Dictionary from module FQN to ``SACGreedyOrderMeta``. + + Note: + 1) This class is designed to be used under ``FakeTensorMode``. + 2) Currently, it only supports estimation of compute time and memory usage, and does not consider communication. + + Example usage: + + .. code-block:: python + + sac_estimator = SACEstimator() + with FakeTensorMode(): + module = ... + inp = ... + with sac_estimator("operator-level-cost-model"): + output = module(inp) + sac_estimator.display_modulewise_sac_stats(depth=4, print_tabular=True) + """ + + def __init__(self) -> None: + self.sac_mod_stats: dict[str, SACStats] = {} + self.sac_mod_tradeoff_stats: dict[str, SACTradeOffStats] = {} + self.sac_mod_greedy_order_meta: dict[str, SACGreedyOrderMeta] = {} + self._mod_tracker = ModTracker() + self._sac_metadata: list[_SACMetadata] = [] + self._sac_mod_metadata: dict[str, _SACModMetadata] = {} + self._leaf_modules: set[str] = set() + self._saved_tensor_hook_ctx = torch.autograd.graph.saved_tensors_hooks( + self._pack_hook, lambda x: x + ) + self._saved_tensor_ids: set[int] = set() + self._estimate_runtime = RuntimeEstimator._roofline_estimate + + def _pack_hook(self, x: torch.Tensor) -> torch.Tensor: + # Hook function to track underlying storage IDs of tensors + # Updates the _saved_tensor_ids set with the IDs of the tensor's storages + # Used in conjunction with torch.autograd.graph.saved_tensors_hooks + untyped_storages = get_untyped_storages(x) + storage_ids = (hash(st) for st in untyped_storages) + self._saved_tensor_ids.update(storage_ids) + return x + + def _pre_fw_hook(self, mod: nn.Module, inputs: Any) -> None: + # Pre-forward hook function to prepare module metadata + # Tracks module FQN, force store random flag, and ``SACModMetadata`` + # Initializes metadata for non-leaf modules, marks leaf modules + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + num_children = sum(1 for _ in mod.children()) + if num_children > 0: + force_store_random = self._get_force_store_random(inputs) + self._sac_mod_metadata[mod_fqn] = _SACModMetadata( + start_idx=len(self._sac_metadata), + force_store_random=force_store_random, + sac_metadata=[], + ) + else: + self._leaf_modules.add(mod_fqn) + + def _post_fw_hook(self, mod: nn.Module, inputs: Any, outputs: Any) -> None: + # 1. Retrieves the module's FQN and checks if it's a leaf module + # 2. If not a leaf module, computes: + # - ``SACStats`` using the module's metadata and force store random flag + # - ``SACGreedyOrderMeta`` using the computed SAC statistics + mod_fqn = self._mod_tracker.get_known_fqn(mod) + assert mod_fqn is not None + if mod_fqn in self._leaf_modules: + return + else: + self.sac_mod_stats[mod_fqn] = self._get_sac_stats( + data=self._sac_mod_metadata[mod_fqn].sac_metadata, + force_store_random=self._sac_mod_metadata[mod_fqn].force_store_random, + ) + self.sac_mod_greedy_order_meta[mod_fqn] = self._get_greedy_order_meta( + self.sac_mod_stats[mod_fqn] + ) + + def _get_force_store_random(self, inputs: Any) -> bool: + flat_inputs, _ = tree_flatten(inputs) + return all(not isinstance(x, torch.Tensor) for x in flat_inputs) + + def _get_sac_stats( + self, data: list[_SACMetadata], force_store_random: bool + ) -> SACStats: + # 1. Ignore the operations that should be skipped by SAC such as aten.detach.default because autograd + # inserts those during backward and it breaks the fwd-bwd alignment + filtered_data = [x for x in data if x.func not in OPS_TO_ALWAYS_SKIP] + + ( + ops, + runtimes_, + memory_, + new_ids, + output_ids, + inplace_ops_, + view_like_ops_, + rand_ops_, + ) = zip(*[astuple(x) for x in filtered_data], strict=True) + + # 2. Extract the metadata information + runtimes = list(runtimes_) + memory = list(memory_) + func_names = [op._overloadpacket.__name__ for op in ops] + view_like_ops = [i for i, x in enumerate(view_like_ops_) if x] + rand_ops = [i for i, x in enumerate(rand_ops_) if x] + saved_autograd_ops = [ + i + for i, out_ids in enumerate(output_ids) + if set(out_ids).issubset(self._saved_tensor_ids) + ] + + # 3. Remap the inplace indices as we have removed OPS_TO_ALWAYS_SKIP + # FIXME @sanketpurandare: Fix this by changing the parent of the inplace-op + # to itself if the original parent is in OPS_TO_ALWAYS_SKIP. + try: + inplace_ops = [tuple(map(new_ids.index, x)) for x in inplace_ops_ if x] + except ValueError as err: + raise ValueError( + f"The remapping of inplace ops failed since one of the inplace op parents" + f" must have been present in {OPS_TO_ALWAYS_SKIP}" + ) from err + + # 4. The last operation is always stored as the output of the checkpoint + # block, so we can avoid recomputing it. We set the memory to zero + # instead of adding a new constraint because we want both the 0 and 1 + # endpoints for memory_budget to be valid + # FIXME @sanketpurandare: this heuristic for finding the last non-view non-inplace op + # might not always be correct, which would yield suboptimal policies + last_op = len(ops) - 1 + skip_ops_ = set(view_like_ops) | set({x[0] for x in inplace_ops}) + reversed_skip_ops = sorted(skip_ops_, reverse=True) + for op in reversed_skip_ops: + if op == last_op: + last_op -= 1 + + memory[last_op] = 0 + + # 5. Create a single ``SACStats`` object for the entire block of ``_SACMetadata``. + return SACStats( + func_names=func_names, + runtimes=runtimes, + memory=memory, + view_like_ops=view_like_ops, + rand_ops=rand_ops, + saved_autograd_ops=saved_autograd_ops, + inplace_ops=inplace_ops, # type: ignore[arg-type] + force_store_random=force_store_random, + ) + + def _get_inplace_metadata( + self, func: Any, out_storages: set[UntypedStorage] + ) -> tuple[int, tuple[int, ...], dict[str, tuple[int, ...]]]: + # 1. Get the current index of the metadata obtained so far + curr_idx = len(self._sac_metadata) + # 2. Get the set of active modules that are not leaf + active_mod_fqns: set[str] = { + par for par in self._mod_tracker.parents if par not in self._leaf_modules + } + # 3. Output ids are the identifies of the storage objects corresponding to the tensors + output_ids = tuple(hash(st) for st in out_storages) + # 4. If the function is not inplace, return + if not is_inplace(func): + return curr_idx, output_ids, dict.fromkeys(active_mod_fqns, ()) + + op_idx = curr_idx + # 5. Initialize the parent op ids of the inplace op for each of the active modules + mod_op_parent_idxs: dict[str, int] = dict.fromkeys(active_mod_fqns, -1) + for i, d in enumerate(self._sac_metadata): + # 6. Find the first occurrence of a tensor corresponding to each module that + # shares the same storage as the current tensor + past_output_ids = d.output_ids + if set(output_ids).issubset(set(past_output_ids)): + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx == -1: + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + if i >= acm_stats.start_idx: + mod_op_parent_idxs[mod_fqn] = i + else: + assert mod_fqn == "Global" + mod_op_parent_idxs[mod_fqn] = i + # 7. If no parent tensor is found, then it's probably an inplace op on the arguments + # so one can just store the current-op idx as parent idx + for mod_fqn, op_parent_idx in mod_op_parent_idxs.items(): + if op_parent_idx < 0: + mod_op_parent_idxs[mod_fqn] = op_idx + mod_inplace_info = { + mod_fqn: (op_idx, mod_op_parent_idxs[mod_fqn]) + for mod_fqn in active_mod_fqns + } + return curr_idx, output_ids, mod_inplace_info # type: ignore[return-value] + + def __torch_dispatch__( # type: ignore[no-untyped-def] + self, func, types, args=..., kwargs=None + ): + # 1. Get the runtime estimate + out, op_time = self._estimate_runtime(func, args, kwargs) + flat_outs, _ = tree_flatten(out) + out_storages_cuda: set[UntypedStorage] = set() + out_storages_cpu: set[UntypedStorage] = set() + cuda_devices: set[torch.device] = set() + for o in flat_outs: + if isinstance(o, torch.Tensor): + if o.device.type == "cuda": + out_storages_cuda.update(get_untyped_storages(o)) + cuda_devices.add(o.device) + else: + out_storages_cpu.update(get_untyped_storages(o)) + + # Check if there's more than 1 CUDA device + assert len(cuda_devices) <= 1, ( + f"{func.__name__}'s output has more than 1 CUDA devices {cuda_devices}" + ) + + # 2. Get the memory consumed by output + nbytes_cuda = sum( + math.ceil(st.nbytes() / _PYTORCH_MIN_ALLOCATE) * _PYTORCH_MIN_ALLOCATE + for st in out_storages_cuda + ) + nbytes_cpu = sum(st.nbytes() for st in out_storages_cpu) + nbytes = nbytes_cuda + nbytes_cpu + # 3. Get the current operator index, output storage identifiers and inplace metadata + out_storages = out_storages_cuda | out_storages_cpu + curr_idx, output_ids, mod_inplace_info = self._get_inplace_metadata( + func, out_storages + ) + # 4. Determine if the function is in-place, random-op or a view-like + is_view_like = is_view_fn(func) or is_inplace_view_fn(func) + is_rand_op = torch.Tag.nondeterministic_seeded in func.tags + if is_view_like: + nbytes = 0 + # sdpa has non-deterministic seed, but might be deterministic + # if no dropout is applied + if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention": + is_rand_op = kwargs.get("dropout_p", 0) != 0 + # 5. Create metadata information per active non-leaf module + for mod_fqn in self._mod_tracker.parents: + if mod_fqn in self._leaf_modules: + continue + acm = _SACMetadata( + func=func, + time_taken=op_time, + memory_used=nbytes, + curr_idx=curr_idx, + output_ids=output_ids, + inplace_info=mod_inplace_info[mod_fqn], + is_view_like=is_view_like, + is_rand_op=is_rand_op, + ) + if acm_stats := self._sac_mod_metadata.get(mod_fqn, None): + acm_stats.sac_metadata.append(acm) + else: + assert mod_fqn == "Global", ( + f"Module {mod_fqn} not found in AC Mod Stats" + ) + self._sac_metadata.append(acm) + + return out + + def _get_greedy_order_meta(self, sac_stats: SACStats) -> SACGreedyOrderMeta: + # An inplace-op group is a set of inplace-ops that operate on the same underlying tensor storage. + # 1. inplace_op_groups: A dictionary from the top-most parent of inplace-ops to the inplace-ops in the group + # The top-most op can itself be an inplace-op or can be a non-inplace op. + # 2. inplace_op_to_group_head: A dictionary that maps all the inplace-ops to their respective group heads. + inplace_op_groups: dict[int, set[int]] = {} + inplace_op_to_group_head: dict[int, int] = dict(sac_stats.inplace_ops) + + # Initialize inplace_op_groups using inplace_op_to_group_head + for op_idx, group_head_idx in inplace_op_to_group_head.items(): + op_group = inplace_op_groups.setdefault(group_head_idx, {group_head_idx}) + op_group.add(op_idx) + + # Like inplace ops, all of the random ops in the function/module should all be either recomputed or saved + # as a group. This is because, they affect the ranom seed generator. If force_store_random is set True, + # all of the random ops will be stored by default. For easy of manageability, we store the top-most random op + # as the leader of the random_ops_group. + random_ops_group: dict[int, set[int]] = {} + random_group_head_idx = min(sac_stats.rand_ops, default=-1) + has_rand_ops = bool(sac_stats.rand_ops) + if has_rand_ops: + random_ops_group[random_group_head_idx] = set(sac_stats.rand_ops) + + # 1. Random ops are stored if force_store_random is set + # 2. View-like ops are recomputed by default + # 3. For inplace_op_groups: + # a) If the head of this group is an inplace op, then we have to store the entire group. + # b) If any op in the group is random and force_store_random is set, then entire group will be stored. + # c) If none of ops in the group are random and the head of the group is not an in-place op, then + # this group can be considered for recomputation in its entirety + stored_ops: set[int] = set() + recomputed_ops: set[int] = set() + # Case 1: + if has_rand_ops and sac_stats.force_store_random: + stored_ops.add(random_group_head_idx) + # Case 2: + recomputed_ops.update(set(sac_stats.view_like_ops)) + + for group_head_idx, op_group in inplace_op_groups.items(): + # Case 3a: + if group_head_idx in inplace_op_to_group_head: + stored_ops.add(group_head_idx) + # Case 3b: + if ( + sac_stats.force_store_random & len(op_group & set(sac_stats.rand_ops)) + > 0 + ): + stored_ops.add(group_head_idx) + + # The potential recompute candidates are populated as: + recompute_candidates: set[int] = set() + # 1) The random group head if it is not stored + if has_rand_ops and random_group_head_idx not in stored_ops: + recompute_candidates.add(random_group_head_idx) + # 2) The in-place op group heads that are not stored + recompute_candidates.update(set(inplace_op_groups.keys()) - stored_ops) + # 3) The non-inplace and non-random ops that are neither stored nor recomputed by default + recompute_candidates.update( + set(range(len(sac_stats.memory))) + - recomputed_ops + - stored_ops + - set(inplace_op_to_group_head.keys()) + - set(sac_stats.rand_ops) + ) + + # We define msps for a recomp candidate as the ratio of memory/runtime aka memory savings per second + msps_meta: list[MSPS] = [] + for cand_idx in recompute_candidates: + op_indices = {cand_idx} + if cand_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand_idx]) + if has_rand_ops and cand_idx == random_group_head_idx: + op_indices.update(sac_stats.rand_ops) + + mem = sum(sac_stats.memory[op_idx] for op_idx in op_indices) + runtime = sum(sac_stats.runtimes[op_idx] for op_idx in op_indices) + func_names = {sac_stats.func_names[op_idx] for op_idx in op_indices} + msps = (mem / runtime) if runtime > 0 else sys.float_info.max + msps_meta.append(MSPS(func_names, cand_idx, mem, runtime, msps)) + # We choose candidates to be recomputed based on increasing msps + msps_meta.sort(key=lambda x: x.msps, reverse=True) + return SACGreedyOrderMeta( + recomputed_ops, stored_ops, inplace_op_groups, random_ops_group, msps_meta + ) + + def _get_sac_tradeoff_pwlf_stats( + self, + sac_stats: SACStats, + greedy_order_meta: SACGreedyOrderMeta, + n_segments: int = 2, + save_tradeoff_graph: bool = False, + filename: str = "ac_tradeoff", + ) -> SACTradeOffStats: + try: + import numpy as np # type: ignore[import-not-found] + import pwlf # type: ignore[import-untyped, import-not-found] + except ImportError as err: + raise ImportError("Please install pwlf and numpy package.") from err + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + # 1. Initialize the discarded memory and recomputation runtime to sum of already chosen recomputed_ops + recomp_indices: set[int] = set() + for r_idx in recomputed_ops: + recomp_indices.add(r_idx) + if r_idx in inplace_op_groups: + recomp_indices.update(inplace_op_groups[r_idx]) + if r_idx in random_ops_group: + recomp_indices.update(random_ops_group[r_idx]) + + discarded_mem = sum(sac_stats.memory[op_idx] for op_idx in recomp_indices) + recomp_runtime = sum(sac_stats.runtimes[op_idx] for op_idx in recomp_indices) + # 2. Initialize the max recomputation time and total recomputation memory + sac_runtime = sum(sac_stats.runtimes) + sac_memory = sum(sac_stats.memory) + # 3. Tradeoff curve stores the KV pair of the discarded memory to total memory and, + # recomputation time to total runtime incurred. + delta = 1e-2 + tradeoff_curve = OrderedDict() + # 4. Initialize the trade-off curve with the stats of of already chosen recomputed_ops + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 5. Update the trade-off curve with memory and runtime stats of SAC candidates in the + # greedy order of their ``MSPS``. + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + # 6. Finally, we add the memory and recomputation time of the always stored ops. + stored_indices: set[int] = set() + for s_idx in stored_ops: + stored_indices.add(s_idx) + if s_idx in inplace_op_groups: + stored_indices.update(inplace_op_groups[s_idx]) + if s_idx in random_ops_group: + stored_indices.update(random_ops_group[s_idx]) + discarded_mem += sum(sac_stats.memory[op_idx] for op_idx in stored_indices) + recomp_runtime += sum(sac_stats.runtimes[op_idx] for op_idx in stored_indices) + tradeoff_curve[(discarded_mem / sac_memory) + delta] = ( + recomp_runtime / sac_runtime + ) + x_ = list(tradeoff_curve.keys()) + y_ = list(tradeoff_curve.values()) + # 7. We shift the y values to left and x values to right to upperbound the trade-off function + # TODO: Write a better explanation why this needs to be done + x = x_[: len(x_) - 1] + y = y_[1:] + tradeoff_pwlf = pwlf.PiecewiseLinFit(x, y) + # 8. Fit a piecewise linear function with the specified number of segments to the trade-off curve. + n_segments = max(min(len(x) - 2, n_segments), 1) + tradeoff_pwlf.fit(n_segments=n_segments) + + # save prediction graph + def save_prediction_graph( + pwlf_: pwlf.PiecewiseLinFit, x: list[float], y: list[float], filename: str + ) -> None: + try: + import matplotlib.pyplot as plt # type: ignore[import-not-found] + import numpy as np # type: ignore[import-not-found] + except ImportError as err: + raise ImportError( + "Install matplotlib and numpy using pip: pip install matplotlib numpy" + ) from err + # predict for the determined points + xHat = np.linspace(min(x), max(x), num=10000) + yHat = pwlf_.predict(xHat) + + # plot the results + plt.figure() + plt.plot(x, y, "o", label="Shifted") + plt.plot(xHat, yHat, "-", label="Predicted") + plt.plot(x_, y_, "x", label="Original") + plt.ylabel("Recomp time / Total recomp time") + plt.xlabel("Memory discarded / Total memory") + plt.legend() + plt.title(f"{filename}") + plt.suptitle( + f"Total Memory = {sac_memory} B Total Runtime = {sac_runtime:.4f} ms", + fontsize=10, + ) + folder_name = "tradeoff_graphs" + if not os.path.exists(folder_name): + os.makedirs(folder_name) + # Save the plots in the folder + plt.savefig(os.path.join(folder_name, f"{filename}.png")) + + if save_tradeoff_graph: + save_prediction_graph(tradeoff_pwlf, x, y, filename) + # 9. Obtain the slopes, intercepts and breakpoints of the fitted piecewise linear functions + slopes = tradeoff_pwlf.calc_slopes().tolist() + assert isinstance(tradeoff_pwlf.intercepts, np.ndarray) and isinstance( + tradeoff_pwlf.fit_breaks, np.ndarray + ) + intercepts = tradeoff_pwlf.intercepts.tolist() + fit_breaks = tradeoff_pwlf.fit_breaks.tolist() + return SACTradeOffStats( + n_segments=n_segments, + slopes=slopes, + intercepts=intercepts, # type: ignore[arg-type] + fit_breaks=fit_breaks, # type: ignore[arg-type] + tradeoff_curve=tradeoff_curve, + sac_memory=sac_memory, + sac_runtime=sac_runtime, + ) + + def display_sac_stats( + self, sac_stats: SACStats, print_tabular: bool = False + ) -> None: + """ + Displays the SAC statistics. + + Args: + sac_stats (SACStats): The SAC statistics to display. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + 1. Total Memory: The total memory usage in bytes. + 2. Total Runtime: The total runtime in milliseconds. + 3. Store Random: A flag indicating whether to force store random operator results. + + Followed by a table with the following columns: + 1. Op Idx: The operator index. + 2. Op Name: The operator name. + 3. Runtimes (ms): The operator runtime in milliseconds. + 4. Memory (B): The operator memory usage in bytes. + 5. View-like: A flag indicating whether the operator is view-like. + 6. Random: A flag indicating whether the operator is random. + 7. Saved Autograd: A flag indicating whether the operator's result is saved by autograd engine. + 8. In-place: The index of the operator's first parent, or None if not in-place. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + print( + f"Total Memory: {sum(sac_stats.memory)} B Total Runtime: {sum(sac_stats.runtimes)} ms" + f" Store Random: {sac_stats.force_store_random}" + ) + table_data = [] + op_parent = dict(sac_stats.inplace_ops) + for i, fn_name in enumerate(sac_stats.func_names): + row = [ + str(i), + fn_name, + f"{sac_stats.runtimes[i]:.4f}", + str(sac_stats.memory[i]), + str(i in sac_stats.view_like_ops), + str(i in sac_stats.rand_ops), + str(i in sac_stats.saved_autograd_ops), + str(op_parent.get(i, None)), + ] + table_data.append(row) + # Define headers + headers = [ + "Op Idx", + "Op Name", + "Runtimes(ms)", + "Memory (B)", + "View-like", + "Random", + "Saved Autograd", + "In-place", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def display_sac_tradeoff_stats( + self, + greedy_order_meta: SACGreedyOrderMeta, + sac_stats: SACStats, + print_tabular: bool = False, + ) -> None: + """ + Displays the SAC trade-off statistics. + + Args: + greedy_order_meta (SACGreedyOrderMeta): The SAC greedy order metadata. + sac_stats (SACStats): The SAC statistics. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + A table with the following columns: + 1. Op Id(s): The operator index(es). + 2. Op Name(s): The operator name(s). + 3. Discarded Mem (%): The percentage of discarded memory. + 4. Discarded Mem (B): The discarded memory in bytes. + 5. Recomp time (%): The percentage of recomputed time. + 6. Recomp time (ms): The recomputed time in milliseconds. + 7. MSPS: The memory per second. + 8. Always Stored: A flag indicating whether the operator is always stored. + 9. Always Recomputed: A flag indicating whether the operator is always recomputed. + + If print_tabular is True, the table is printed in a tabular format. + Otherwise, the table is printed in a plain text format. + """ + table_data = [] + total_memory, total_runtime = sum(sac_stats.memory), sum(sac_stats.runtimes) + discarded_mem: int = 0 + recomp_runtime: float = 0.0 + + def append_row( + op_indices: set[int], + func_names: set[str], + msps: Optional[float] = None, + stored: Optional[bool] = False, + recomputed: Optional[bool] = False, + ) -> None: + row = [ + str(op_indices), + str(func_names), + f"{discarded_mem / total_memory:.4f}", + str(discarded_mem), + f"{recomp_runtime / total_runtime:.4f}", + str(recomp_runtime), + f"{msps:.2e}" if msps is not None else str(nan), + str(stored), + str(recomputed), + ] + table_data.append(row) + + stored_ops, recomputed_ops, inplace_op_groups, random_ops_group, msps_meta = ( + greedy_order_meta.stored_ops, + greedy_order_meta.recomputed_ops, + greedy_order_meta.inplace_op_groups, + greedy_order_meta.random_ops_group, + greedy_order_meta.msps_meta, + ) + + for op_idx in recomputed_ops: + op_indices: set[int] = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, recomputed=True) + + for cand in msps_meta: + discarded_mem += cand.memory + recomp_runtime += cand.runtime + op_indices = {cand.op_idx} + if cand.op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[cand.op_idx]) + if cand.op_idx in random_ops_group: + op_indices.update(random_ops_group[cand.op_idx]) + append_row(op_indices, cand.func_names, msps=cand.msps) + + for op_idx in stored_ops: + op_indices = {op_idx} + if op_idx in inplace_op_groups: + op_indices.update(inplace_op_groups[op_idx]) + if op_idx in random_ops_group: + op_indices.update(random_ops_group[op_idx]) + discarded_mem += sum(sac_stats.memory[i] for i in op_indices) + recomp_runtime += sum(sac_stats.runtimes[i] for i in op_indices) + func_names = {sac_stats.func_names[i] for i in op_indices} + append_row(op_indices, func_names, stored=True) + + headers = [ + "Op Id(s)", + "Op Name(s)", + "Discarded Mem (%)", + "Discarded Mem (B)", + "Recomp time (%)", + "Recomp time (ms)", + "MSPS", + "Always Stored", + "Always Recomputed", + ] + if print_tabular: + _display_stats_tabular(headers, table_data) + else: + max_widths = [0 for _ in range(len(headers))] + table_data.insert(0, headers) + for row in table_data: + for i, elem in enumerate(row): + max_widths[i] = max(max_widths[i], len(elem)) + for row in table_data: + print( + "\t".join( + [f"{elem:<{max_widths[i]}}" for i, elem in enumerate(row)] + ) + ) + + def pwlf_sac_tradeoff_curve( + self, + n_segments: int = 2, + save_tradeoff_graphs: bool = False, + ) -> None: + """ + Fits a piecewise linear function with the specified sumber of segments to the SAC trade-off curve of + discarded memory vs recomputation time. + + Args: + n_segments (int, optional): The number of segments to be used for fitting the piecewise linear function to + the trade-off curve. Defaults to 2. + save_tradeoff_graphs (bool, optional): Whether to save the trade-off graphs to file. Defaults to False. + + If save_tradeoff_graphs is True, the trade-off graphs are saved to file using the module FQN as the filename. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + self.sac_mod_tradeoff_stats[mod_fqn] = self._get_sac_tradeoff_pwlf_stats( + sac_stats=sac_stats, + greedy_order_meta=self.sac_mod_greedy_order_meta[mod_fqn], + n_segments=n_segments, + save_tradeoff_graph=save_tradeoff_graphs, + filename=mod_fqn, + ) + + def display_modulewise_sac_stats( + self, depth: int = 2, print_tabular: bool = False + ) -> None: + """ + Displays the SAC and trade-off statistics for each module. + + Args: + depth (int, optional): The maximum depth of modules to display. Defaults to 2. + print_tabular (bool, optional): Whether to print the statistics in a tabular format. Defaults to False. + + Prints: + For each module with depth less than or equal to the specified depth: + 1. The SAC statistics for the module (using display_sac_stats). + 2. The SAC trade-off statistics for the module (using display_sac_tradeoff_stats). + + If print_tabular is True, the statistics are printed in a tabular format. + Otherwise, the statistics are printed in a plain text format. + """ + for mod_fqn, sac_stats in self.sac_mod_stats.items(): + mod_depth = mod_fqn.count(".") + 1 + if mod_depth > depth: + continue + print(f"Module: {mod_fqn}") + self.display_sac_stats(sac_stats, print_tabular) + print(f"AC Trade-off for Module: {mod_fqn} MSPS = Memory/Runtime") + self.display_sac_tradeoff_stats( + self.sac_mod_greedy_order_meta[mod_fqn], sac_stats, print_tabular + ) + + def __call__(self, estimate_mode_type: str) -> Self: + """ + Sets the estimate mode type. + + Currently supported modes: + - "operator-level-benchmark": Estimates runtime using operator benchmarking. + - "operator-level-cost-model": Estimates runtime using roofline cost model. + + Args: + estimate_mode_type (str): The type of estimate mode to use. + + Returns: + SACEstimator: The SAC estimator instance. + + Raises: + NotImplementedError: If the estimate mode type is not supported. + """ + if estimate_mode_type == "operator-level-benchmark": + self._estimate_runtime = RuntimeEstimator._benchmark_estimate + elif estimate_mode_type == "operator-level-cost-model": + self._estimate_runtime = RuntimeEstimator._roofline_estimate + else: + raise NotImplementedError( + f"estimate_mode_type {estimate_mode_type} not supported" + ) + return self + + def __enter__(self) -> Self: # type: ignore[no-untyped-def] + fake_mode = active_fake_mode() + assert isinstance(fake_mode, FakeTensorMode), ( + "SAC Estimator should be called in FakeTensorMode" + ) + RuntimeEstimator.fake_mode = fake_mode + self._mod_tracker.register_user_hooks( + pre_fw_hook=self._pre_fw_hook, + post_fw_hook=self._post_fw_hook, + ) + self._mod_tracker.__enter__() + self._saved_tensor_hook_ctx.__enter__() + return super().__enter__() + + def __exit__(self, *args: Any) -> None: # type: ignore[no-untyped-def] + self._saved_tensor_hook_ctx.__exit__() + self._mod_tracker.__exit__(*args) + super().__exit__(*args) diff --git a/phivenv/Lib/site-packages/torch/distributed/_tools/sac_ilp.py b/phivenv/Lib/site-packages/torch/distributed/_tools/sac_ilp.py new file mode 100644 index 0000000000000000000000000000000000000000..16569493091c1de27693f452d02be50fb975cf09 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/_tools/sac_ilp.py @@ -0,0 +1,295 @@ +import logging +import math +from enum import IntEnum +from typing import Optional + +from torch.distributed._tools.ilp_utils import Graph, is_submodule +from torch.distributed._tools.sac_estimator import SACStats + + +try: + from pulp import ( # type: ignore[import-untyped,import-not-found] + lpDot, + LpInteger, + LpMaximize, + LpMinimize, + LpProblem, + LpStatus, + lpSum, + LpVariable, + PULP_CBC_CMD, + value, + ) +except ImportError as err: + raise ImportError( + "Please install pulp package. See: https://github.com/coin-or/pulp." + ) from err + +# Create a logger object +logger = logging.getLogger(__name__) + +# Set the logging level to INFO +logger.setLevel(logging.INFO) + + +def sac_milp( + graph: Graph, + memory_budget: float, + world_size: int = 1, + ac_units: Optional[list[str]] = None, + fsdp_units: Optional[list[str]] = None, +) -> tuple[dict[str, float], float, int]: + """ + MILP to decide which modules to AC and how much memory to discard. + The objective is to minimize recomputation time. + The constraint is to ensure peak memory is under budget. + + Args: + graph: graph representation of the model as a module submodule tree + where each node is a submodule with memory & runtime stats + memory_budget: memory budget in GiB + world_size: number of GPUs. In the case of FSDP, world_size will be + used to compute the amount of parameter and gradient memory on each rank + ac_units: a list of user-specified AC units. + fsdp_units: a list of FSDP units. AC units cannot be supermodules of FSDP units. + + Returns: + Dict[str, float]: the optimal SAC solution, mapping from module fqn to + the percentage of activation memory to **discard** + float: the recomputation time of the optimal SAC solution + int: upper bound on the peak memory of the optimal SAC solution. + note that value of -1 means that the ILP solver failed to find a solution. + + """ + num_nodes = len(graph.nodes) + M = 10**2 # note: numerical issue may occur if M is too big + MEM_MULTIPLIER = 2**30 + + # Create a MILP problem + prob = LpProblem("SAC", LpMinimize) + + # Create decision variables + # y_i: indicator for if module i is AC'ed + y = LpVariable.matrix("y", list(range(num_nodes)), 0, 1, LpInteger) + # r_i: percentage of discarded activation memory + r = LpVariable.matrix("r", list(range(num_nodes)), 0, 1) + # d_i: discarded activation memory for module i + d = LpVariable.matrix("d", list(range(num_nodes)), 0) + # a_i: total activation memory at module i + a = LpVariable.matrix("a", list(range(num_nodes)), 0) + # m_i: memory at module i, combining parameters, gradients, and activations + m = LpVariable.matrix("m", list(range(num_nodes)), 0) + # rcp_i: percentage of recomputation time + rcp = LpVariable.matrix("rcp", list(range(num_nodes)), 0) + # rct_i: recomputation time for module i (in ms) + rct = LpVariable.matrix("rct", list(range(num_nodes)), 0) + # max_m: peak memory + max_m = LpVariable("max_m", 0) + + # Add constraints + # [Constraint] User specified AC units + if ac_units: + ac_units_set = set(ac_units) + for i in range(num_nodes): + if graph.nodes[i]["fqn"] not in ac_units_set: + prob += y[i] == 0 + + # [Constraint] AC units cannot be supmodules of user specified FSDP units + if fsdp_units: + for i in range(num_nodes): + if any( + is_submodule(fsdp_unit, graph.nodes[i]["fqn"]) + for fsdp_unit in fsdp_units + ): + prob += y[i] == 0 + + # [Constraint] No nested AC units + for i in range(num_nodes): + for j in range(i + 1, num_nodes): + if graph.ad_matrix[i][j] == 1: + prob += y[i] + y[j] <= 1 + + # [Constraint] Do not AC leaf modules + for i in range(num_nodes): + if graph.nodes[i]["is_leaf"]: + prob += y[i] == 0 + + # [Constraint] Express amount of discarded activation memory + for i in range(num_nodes): + # There are two measures for activation memory: ACM and IA + # 1. IA is the activation memory saved when not using AC + # 2. ACM is the total activation memory, including those + # that are not typically saved when not using AC + # Note: ACM >= IA + if (not graph.nodes[i]["is_leaf"]) and graph.nodes[i][ + "sac_memory" + ] < graph.nodes[i]["act_fw_per_module"]: + logger.warning("For module {%s}: ", graph.nodes[i]["fqn"]) + logger.warning( + "activation memory from memory tracker is {%d},", + graph.nodes[i]["act_fw_per_module"], + ) + logger.warning( + "activation memory from SAC estimator is {%d}.", + graph.nodes[i]["sac_memory"], + ) + logger.warning("Something is wrong. Please check!") + logger.warning("Overriding the latter with the former.") + graph.nodes[i]["sac_memory"] = graph.nodes[i]["act_fw_per_module"] + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += d[i] == ACM_i * r[i] - (ACM_i - IA_i) * y[i] + + # [Constraint] Ensure correctness of r_i + # There are two parts to its correctness + # 1. r_i > 0 only if y_i == 1 (discard only if it is an AC unit) + # 2. r_i needs to be large enough to cover the difference between + # ACM and IA. Otherwise, we are not saving any memory + for i in range(num_nodes): + prob += y[i] >= r[i] + if graph.nodes[i]["is_leaf"]: + continue + ACM_i = graph.nodes[i]["sac_memory"] / MEM_MULTIPLIER + IA_i = graph.nodes[i]["act_fw_per_module"] / MEM_MULTIPLIER + prob += r[i] >= (ACM_i - IA_i) / ACM_i * y[i] + + # [Constraint] Express total activation memory in the backward pass + for i in range(num_nodes): + AG_i = graph.nodes[i]["act_grad_per_module"] / MEM_MULTIPLIER + TA_i = graph.nodes[i]["act_total"] / MEM_MULTIPLIER + # related to discarded amount of memory + pos = graph.nodes[i]["pos_fw_post_order"] + coeff = [0] * num_nodes + for p in range(pos): + j = graph.name2node[graph.fw_post_order[p]]["index"] + coeff[j] = 1 + prob += a[i] == TA_i + AG_i - lpDot(coeff, d) + + # [Constraint] Express the total amount of memory at each module + # Note that unsharded parameters and gradients are not included here + P_1 = graph.nodes[0]["param_per_module"] / MEM_MULTIPLIER + for i in range(num_nodes): + TG_i = graph.nodes[i]["grad_total"] / MEM_MULTIPLIER + prob += m[i] == a[i] + (P_1 + TG_i) / world_size + + # [Constraint] Express peak memory + for i in range(num_nodes): + prob += max_m >= m[i] + + # [Constraint] Express percentage of recomputation time + for i in range(num_nodes): + for s in range(graph.nodes[i]["n_segments"]): + slope = graph.nodes[i]["slopes"][s] + intercept = graph.nodes[i]["intercepts"][s] + prob += rcp[i] >= slope * r[i] + intercept + + # [Constraint] Express recomputation time + # rct_i = (rcp_i * ACT_i) if y_i == 1 else 0 + for i in range(num_nodes): + ACT_i = graph.nodes[i]["sac_runtime"] + prob += rct[i] <= M * y[i] + prob += rct[i] <= ACT_i * rcp[i] + prob += rct[i] >= ACT_i * rcp[i] - M * (1 - y[i]) + + # [Constraint] Peak memory should be below budget + prob += max_m <= memory_budget + + # Set Objeictive + prob += lpSum(rct) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=180, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return {}, 0, -1 + + # Gather and return solution if optimal solution is found + ac_decisions = {} + for i in range(num_nodes): + if round(y[i].varValue) == 1: + ac_decisions[graph.nodes[i]["fqn"]] = round(r[i].varValue, 4) + recomputation_time = round(value(prob.objective), 2) + peak_mem = round(max_m.varValue * MEM_MULTIPLIER) + + return ac_decisions, recomputation_time, peak_mem + + +class SACDecision(IntEnum): + RECOMPUTE = 0 + SAVE = 1 + + +def get_optimal_checkpointing_policy_per_module( + sac_stats: SACStats, memory_budget: float +) -> list[int]: + """ + This is adapted from -- + https://github.com/facebookresearch/xformers/blob/c6c0ac31f1b08542a0bc27278c6ed10f825f6963/xformers/checkpoint.py#L375 + + Given the SACStats of a module, including list of operators, their memory, runtimes, and metadata, + decide via MILP an optimal set of operators to checkpoint under a given ``memory_budget``. + + Args: + sac_stats: the SACStats object of the module + memory_budget: a float between zero and one + + Returns: + List[int]: the decision whether each operator should be saved (1) or recomptued (0). + """ + if not (0 <= memory_budget <= 1): + raise ValueError( + f"`memory_budget` must be a float between 0 and 1. Got {memory_budget}." + ) + num_ops = len(sac_stats.func_names) + + # Create a MILP problem + prob = LpProblem("SAC-per-module", LpMaximize) + + # Create decision variables + # x[i] = 1 means the i-th operator should be saved, otherwise it should be recomputed + x = LpVariable.matrix("x", list(range(num_ops)), 0, 1, LpInteger) + + # Add constraints + # [Constraint] random ops should be saved if ``force_store_random`` is True + # otherwise, random ops should either be all recomputed or all saved + if sac_stats.force_store_random: + for i in sac_stats.rand_ops: + prob += x[i] == SACDecision.SAVE.value + else: + for i1, i2 in zip(sac_stats.rand_ops[:-1], sac_stats.rand_ops[1:]): + prob += x[i1] == x[i2] + + # [Constraint] view-like ops should always be recomputed + for i in sac_stats.view_like_ops: + prob += x[i] == SACDecision.RECOMPUTE.value + + # [Constraint] inplace ops should always be done in conjunction with its parent op + for op, op_parent in sac_stats.inplace_ops: + if op != op_parent: + prob += x[op] == x[op_parent] + else: + prob += x[op] == SACDecision.SAVE.value + + # [Constraint] saved memory should be under the ``memory_budget`` + max_memory = math.ceil(memory_budget * sum(sac_stats.memory)) + prob += lpDot(x, sac_stats.memory) <= max_memory + + # [Objective] minimize recomputation time, note the ILP is a maximization problem + # because x[i] == 1 means the op is saved (not recomputed), and thus recomputation + # time is sum(sac_stats.runtimes) - lpDot(x, sac_stats.runtimes) + prob += lpDot(x, sac_stats.runtimes) + + # Solve + solver = PULP_CBC_CMD(gapRel=0.05, timeLimit=10, msg=0) + status = prob.solve(solver) + + # If solver fails, print status and return empty solution + if status != 1: + logger.error("Solver failed to find a solution: %s", LpStatus[status]) + return [] + + # Gather and return solution if optimal solution is found + return [round(x[i].varValue) for i in range(num_ops)] diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9db983e0b995e8f81839aafd6b1b08c6579200b6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53de0608c3fd471c9fa4a8875c0c937399a5d4b0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/__pycache__/control_plane.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..94f919155b854e7f5af47deb81e01f8ecca7f834 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__init__.py @@ -0,0 +1,163 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +In the context of Torch Distributed Elastic we use the term *rendezvous* to +refer to a particular functionality that combines a **distributed +synchronization** primitive with **peer discovery**. + +It is used by Torch Distributed Elastic to gather participants of a training +job (i.e. nodes) such that they all agree on the same list of participants and +everyone's roles, as well as make a consistent collective decision on when +training can begin/resume. + +Torch Distributed Elastic rendezvous provides the following critical +functionalities: + +**Barrier**: + +Nodes performing rendezvous will all block until the rendezvous is considered +complete - this happens when at least ``min`` total number of nodes have joined +the rendezvous barrier (for the same job). This also implies the barrier is not +necessarily of fixed size. + +There's an additional small waiting time after reaching ``min`` number of +nodes - this is used to ensure the rendezvous is not completed "too quickly" +(which could potentially exclude additional nodes attempting to join at +approximately the same time). + +If ``max`` number of nodes is gathered at the barrier, the rendezvous is +completed immediately. + +There's also an overall timeout which causes the rendezvous to fail if ``min`` +number of nodes is never reached - this is meant to be a simple fail-safe to +help release partially allocated job resources, in case there's a problem with +the resource manager, and is meant to be interpreted as non-retryable. + +**Exclusivity**: + +A simple distributed barrier would not be sufficient, as we also need to ensure +that only one group of nodes exists at any given time (for a given job). In +other words, new nodes (i.e. joining late) should not be able to form a parallel +independent group of workers for the same job. + +Torch Distributed Elastic rendezvous ensures that if a group of nodes has +already completed a rendezvous (and hence might already be training), then +additional "late" nodes attempting to rendezvous will only announce themselves +as waiting, and will have to wait until the (previously completed) existing +rendezvous is destroyed first. + +**Consistency**: + +When a rendezvous is completed, all its members will agree on the job membership +and everyone's role in it. This role is represented using an integer, called +rank, that is between between 0 and world size. + +Note that ranks are *not stable*, in the sense that the same node can be +assigned a different rank in the next (re-)rendezvous. + +**Fault-tolerance**: + +Torch Distributed Elastic rendezvous is designed to tolerate node failures +during the rendezvous process. Should a process crash (or lose network +connectivity, etc), between joining the rendezvous and it being completed, then +a re-rendezvous with remaining healthy nodes will happen automatically. + +A node can also fail *after* it has completed (or *has been observed* by other +nodes to have completed) the rendezvous - this scenario will be handled by the +Torch Distributed Elastic ``train_loop`` instead (where it will also trigger a +re-rendezvous). + +**Shared key-value store**: + +When the rendezvous is completed, a shared key-value store is created and +returned. This store implements a ``torch.distributed.Store`` API (see +`distributed communication docs +`__). + +This store is only shared by the members of the completed rendezvous. It +is intended to be used by Torch Distributed Elastic to exchange information +necessary to initialize job control and data-planes. + +**Waiting workers and rendezvous closing**: + +Torch Distributed Elastic rendezvous handler object provides additional +functionalities, which are technically not part of the rendezvous process: + +1. Querying how many workers arrived late at the barrier, who can participate in + *next* rendezvous. + +2. Setting the rendezvous *closed* to signal all nodes not to participate in + next rendezvous. + +**DynamicRendezvousHandler**: + +Torch Distributed Elastic comes with the :py:class:`.DynamicRendezvousHandler` +class that implements the rendezvous mechanism described above. It is a backend- +agnostic type that expects a particular :py:class:`.RendezvousBackend` instance +to be specified during construction. + +Torch distributed users can either implement their own backend type or use one +of the following implementations that come with PyTorch: + +- :py:class:`.C10dRendezvousBackend`: Uses a C10d store (by default + ``TCPStore``) as the rendezvous backend. The main advantage of using a C10d + store is that it requires no 3rd-party dependency (such as etcd) to establish + a rendezvous. +- :py:class:`.EtcdRendezvousBackend`: Supersedes the legacy + :py:class:`.EtcdRendezvousHandler` class. Passing an + :py:class:`.EtcdRendezvousBackend` instance to + :py:class:`.DynamicRendezvousHandler` is functionally equivalent to + instantiating an :py:class:`.EtcdRendezvousHandler`. + + :: + + store = TCPStore("localhost") + + backend = C10dRendezvousBackend(store, "my_run_id") + + rdzv_handler = DynamicRendezvousHandler.from_backend( + run_id="my_run_id", store=store, backend=backend, min_nodes=2, max_nodes=4 + ) +""" + +from .api import ( + rendezvous_handler_registry, + RendezvousClosedError, + RendezvousConnectionError, + RendezvousError, + RendezvousGracefulExitError, + RendezvousHandler, + RendezvousHandlerCreator, + RendezvousHandlerRegistry, + RendezvousInfo, + RendezvousParameters, + RendezvousStateError, + RendezvousStoreInfo, + RendezvousTimeoutError, +) +from .registry import _register_default_handlers, _register_out_of_tree_handlers + + +_register_default_handlers() +_register_out_of_tree_handlers() + + +__all__ = [ + "RendezvousClosedError", + "RendezvousConnectionError", + "RendezvousError", + "RendezvousGracefulExitError", + "RendezvousHandler", + "RendezvousHandlerCreator", + "RendezvousHandlerRegistry", + "RendezvousInfo", + "RendezvousParameters", + "RendezvousStateError", + "RendezvousStoreInfo", + "RendezvousTimeoutError", + "rendezvous_handler_registry", +] diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac9a3fda69cf499a61bd29afa967addc812188c3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2c7337940c2092bf94d1170c41fdcb48cc732f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/_etcd_stub.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fb615c16b77be8ed2901d968bd21dd053490161 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9afc0b0718fc90867ce4132060497c68975e4a49 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/c10d_rendezvous_backend.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbeb470cb63a89983be7033da42b149f6476ed19 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/dynamic_rendezvous.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e3b75ca3dbd18d074dc1db5dc6e5795f4383111f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82ff8500fa3a5e6faad36adbbc1df172a84de14b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_rendezvous_backend.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebb7bf0d72af1b56bf0fd3ba0148ed8c7da18c17 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_server.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d08c24dcdfddb525564d1e707991d1b7986b07c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/etcd_store.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d6ccedfcfd993810deaa5cfc11f49c7ef5770fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52acd2751a56a28ddbfcc9df2623718755f11fc9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/static_tcp_rendezvous.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..309d517a4180fb161ade5eda4db06fb42ceac29f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/rendezvous/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d97b00d5bf7a3ba5915f556f1ed9832a9340866f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__init__.py @@ -0,0 +1,54 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Expiration timers are set up on the same process as the agent and +used from your script to deal with stuck workers. When you go into +a code-block that has the potential to get stuck you can acquire +an expiration timer, which instructs the timer server to kill the +process if it does not release the timer by the self-imposed expiration +deadline. + +Usage:: + + import torchelastic.timer as timer + import torchelastic.agent.server as agent + + def main(): + start_method = "spawn" + message_queue = mp.get_context(start_method).Queue() + server = timer.LocalTimerServer(message, max_interval=0.01) + server.start() # non-blocking + + spec = WorkerSpec( + fn=trainer_func, + args=(message_queue,), + ...) + agent = agent.LocalElasticAgent(spec, start_method) + agent.run() + + def trainer_func(message_queue): + timer.configure(timer.LocalTimerClient(message_queue)) + with timer.expires(after=60): # 60 second expiry + # do some work + +In the example above if ``trainer_func`` takes more than 60 seconds to +complete, then the worker process is killed and the agent retries the worker group. +""" + +from .api import ( # noqa: F401 + configure, + expires, + TimerClient, + TimerRequest, + TimerServer, +) +from .file_based_local_timer import ( # noqa: F401 + FileTimerClient, + FileTimerRequest, + FileTimerServer, +) +from .local_timer import LocalTimerClient, LocalTimerServer # noqa: F401 diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7224515e48f15a7e5704b2a71b76c5d967039f42 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49003040f8d402c799ebe67dff22ef2e8d15428f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..151f650eba7fde934c44ecd4c2ece5676ef88004 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/debug_info_logging.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f964c67d243d806f0f63f959ed282134e17d1e40 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/file_based_local_timer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f9fb2e0f7346f89d0e813d6e2bbd467fd16e3fd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/__pycache__/local_timer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/api.py b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/api.py new file mode 100644 index 0000000000000000000000000000000000000000..02e2270ce0b342ac433ba84500b8b9a8c261e9c0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/api.py @@ -0,0 +1,283 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import abc +import logging +import threading +import time +from contextlib import contextmanager +from inspect import getframeinfo, stack +from typing import Any, Optional + + +__all__ = [ + "TimerRequest", + "TimerClient", + "RequestQueue", + "TimerServer", + "configure", + "expires", +] + +logger = logging.getLogger(__name__) + + +class TimerRequest: + """ + Data object representing a countdown timer acquisition and release + that is used between the ``TimerClient`` and ``TimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + + .. note:: the type of ``worker_id`` is implementation specific. + It is whatever the TimerServer and TimerClient implementations + have on to uniquely identify a worker. + """ + + __slots__ = ["worker_id", "scope_id", "expiration_time"] + + def __init__(self, worker_id: Any, scope_id: str, expiration_time: float): + self.worker_id = worker_id + self.scope_id = scope_id + self.expiration_time = expiration_time + + def __eq__(self, other): + if isinstance(other, TimerRequest): + return ( + self.worker_id == other.worker_id + and self.scope_id == other.scope_id + and self.expiration_time == other.expiration_time + ) + return False + + +class TimerClient(abc.ABC): + """ + Client library to acquire and release countdown timers by communicating + with the TimerServer. + """ + + @abc.abstractmethod + def acquire(self, scope_id: str, expiration_time: float) -> None: + """ + Acquires a timer for the worker that holds this client object + given the scope_id and expiration_time. Typically registers + the timer with the TimerServer. + """ + + @abc.abstractmethod + def release(self, scope_id: str): + """ + Releases the timer for the ``scope_id`` on the worker this + client represents. After this method is + called, the countdown timer on the scope is no longer in effect. + """ + + +class RequestQueue(abc.ABC): + """ + Consumer queue holding timer acquisition/release requests + """ + + @abc.abstractmethod + def size(self) -> int: + """ + Returns the size of the queue at the time this method is called. + Note that by the time ``get`` is called the size of the queue + may have increased. The size of the queue should not decrease + until the ``get`` method is called. That is, the following assertion + should hold: + + size = q.size() + res = q.get(size, timeout=0) + assert size == len(res) + + -- or -- + + size = q.size() + res = q.get(size * 2, timeout=1) + assert size <= len(res) <= size * 2 + """ + + @abc.abstractmethod + def get(self, size: int, timeout: float) -> list[TimerRequest]: + """ + Gets up to ``size`` number of timer requests in a blocking fashion + (no more than ``timeout`` seconds). + """ + + +class TimerServer(abc.ABC): + """ + Entity that monitors active timers and expires them + in a timely fashion. This server is responsible for + reaping workers that have expired timers. + """ + + def __init__( + self, request_queue: RequestQueue, max_interval: float, daemon: bool = True + ): + """ + :param request_queue: Consumer ``RequestQueue`` + :param max_interval: max time (in seconds) to wait + for an item in the request_queue + :param daemon: whether to run the watchdog thread as a daemon + """ + super().__init__() + self._request_queue = request_queue + self._max_interval = max_interval + self._daemon = daemon + self._watchdog_thread: Optional[threading.Thread] = None + self._stop_signaled = False + + @abc.abstractmethod + def register_timers(self, timer_requests: list[TimerRequest]) -> None: + """ + Processes the incoming timer requests and registers them with the server. + The timer request can either be a acquire-timer or release-timer request. + Timer requests with a negative expiration_time should be interpreted + as a release-timer request. + """ + + @abc.abstractmethod + def clear_timers(self, worker_ids: set[Any]) -> None: + """ + Clears all timers for the given ``worker_ids``. + """ + + @abc.abstractmethod + def get_expired_timers(self, deadline: float) -> dict[str, list[TimerRequest]]: + """ + Returns all expired timers for each worker_id. An expired timer + is a timer for which the expiration_time is less than or equal to + the provided deadline. + """ + + @abc.abstractmethod + def _reap_worker(self, worker_id: Any) -> bool: + """ + Reaps the given worker. Returns True if the worker has been + successfully reaped, False otherwise. If any uncaught exception + is thrown from this method, the worker is considered reaped + and all associated timers will be removed. + """ + + def _reap_worker_no_throw(self, worker_id: Any) -> bool: + """ + Wraps ``_reap_worker(worker_id)``, if an uncaught exception is + thrown, then it considers the worker as reaped. + """ + try: + return self._reap_worker(worker_id) + except Exception: + logger.exception( + "Uncaught exception thrown from _reap_worker(), " + "check that the implementation correctly catches exceptions", + ) + return True + + def _watchdog_loop(self): + while not self._stop_signaled: + try: + self._run_watchdog() + except Exception: + logger.exception("Error running watchdog") + + def _run_watchdog(self): + batch_size = max(1, self._request_queue.size()) + timer_requests = self._request_queue.get(batch_size, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_ids = set() + for worker_id, expired_timers in self.get_expired_timers(now).items(): + logger.info( + "Reaping worker_id=[%s]. Expired timers: %s", + worker_id, + self._get_scopes(expired_timers), + ) + if self._reap_worker_no_throw(worker_id): + logger.info("Successfully reaped worker=[%s]", worker_id) + reaped_worker_ids.add(worker_id) + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id + ) + self.clear_timers(reaped_worker_ids) + + def _get_scopes(self, timer_requests): + return [r.scope_id for r in timer_requests] + + def start(self) -> None: + logger.info( + "Starting %s... max_interval=%s, daemon=%s", + type(self).__name__, + self._max_interval, + self._daemon, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + + +_timer_client: Optional[TimerClient] = None + + +def configure(timer_client: TimerClient): + """ + Configures a timer client. Must be called before using ``expires``. + """ + global _timer_client + _timer_client = timer_client + logger.info("Timer client configured to: %s", type(_timer_client).__name__) + + +@contextmanager +def expires( + after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None +): + """ + Acquires a countdown timer that expires in ``after`` seconds from now, + unless the code-block that it wraps is finished within the timeframe. + When the timer expires, this worker is eligible to be reaped. The + exact meaning of "reaped" depends on the client implementation. In + most cases, reaping means to terminate the worker process. + Note that the worker is NOT guaranteed to be reaped at exactly + ``time.now() + after``, but rather the worker is "eligible" for being + reaped and the ``TimerServer`` that the client talks to will ultimately + make the decision when and how to reap the workers with expired timers. + + Usage:: + + torch.distributed.elastic.timer.configure(LocalTimerClient()) + with expires(after=10): + torch.distributed.all_reduce(...) + """ + if client is None: + if _timer_client is None: + raise RuntimeError("Configure timer client before using countdown timers.") + client = _timer_client + if scope is None: + # grab the caller file + lineno + caller = getframeinfo(stack()[1][0]) + scope = f"{caller.filename}#{caller.lineno}" + expiration = time.time() + after + client.acquire(scope, expiration) + try: + yield + finally: + client.release(scope) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/debug_info_logging.py b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/debug_info_logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e42e2fb0c2f02ea3c577570ffe7417130357d518 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/debug_info_logging.py @@ -0,0 +1,24 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torch.distributed.elastic.utils.logging import get_logger + + +logger = get_logger(__name__) + +__all__ = ["log_debug_info_for_expired_timers"] + + +def log_debug_info_for_expired_timers( + run_id: str, + expired_timers: dict[int, list[str]], +): + if expired_timers: + logger.info("Timers expired for run:[%s] [%s].", run_id, expired_timers) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..9c39f1893333a09349c9e826ed2cb51bc251943c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/file_based_local_timer.py @@ -0,0 +1,442 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import io +import json +import os +import select +import signal +import sys +import threading +import time +from typing import Callable, Optional, TypeVar +from typing_extensions import ParamSpec + +from torch.distributed.elastic.timer.api import TimerClient, TimerRequest +from torch.distributed.elastic.timer.debug_info_logging import ( + log_debug_info_for_expired_timers, +) +from torch.distributed.elastic.utils.logging import get_logger + + +_P = ParamSpec("_P") +_R = TypeVar("_R") + +__all__ = ["FileTimerClient", "FileTimerRequest", "FileTimerServer"] + +logger = get_logger(__name__) + + +def _retry(max_retries: int, sleep_time: float) -> Callable: + """ + A simple retry wrapper. + + Args: + max_retries: int, the maximum number of retries. + sleep_time: float, the time to sleep between retries. + """ + + def wrapper(func: Callable[_P, _R]) -> Callable[_P, _R]: + def wrapper(*args: _P.args, **kwargs: _P.kwargs): + for i in range(max_retries): + try: + return func(*args, **kwargs) + except Exception: + logger.exception("Error running %s. Retrying...", func.__name__) + if i < max_retries - 1: + time.sleep(sleep_time) + else: + raise + + return wrapper + + return wrapper + + +class FileTimerRequest(TimerRequest): + """ + Data object representing a countdown timer acquisition and release + that is used between the ``FileTimerClient`` and ``FileTimerServer``. + A negative ``expiration_time`` should be interpreted as a "release" + request. + ``signal`` is the signal to reap the worker process from the server + process. + """ + + __slots__ = ["version", "worker_pid", "scope_id", "expiration_time", "signal"] + + def __init__( + self, worker_pid: int, scope_id: str, expiration_time: float, signal: int = 0 + ) -> None: + self.version = 1 + self.worker_pid = worker_pid + self.scope_id = scope_id + self.expiration_time = expiration_time + self.signal = signal + + def __eq__(self, other) -> bool: + if isinstance(other, FileTimerRequest): + return ( + self.version == other.version + and self.worker_pid == other.worker_pid + and self.scope_id == other.scope_id + and self.expiration_time == other.expiration_time + and self.signal == other.signal + ) + return False + + def to_json(self) -> str: + return json.dumps( + { + "version": self.version, + "pid": self.worker_pid, + "scope_id": self.scope_id, + "expiration_time": self.expiration_time, + "signal": self.signal, + }, + ) + + +class FileTimerClient(TimerClient): + """ + Client side of ``FileTimerServer``. This client is meant to be used + on the same host that the ``FileTimerServer`` is running on and uses + pid to uniquely identify a worker. + This client uses a named_pipe to send timer requests to the + ``FileTimerServer``. This client is a producer while the + ``FileTimerServer`` is a consumer. Multiple clients can work with + the same ``FileTimerServer``. + + Args: + + file_path: str, the path of a FIFO special file. ``FileTimerServer`` + must have created it by calling os.mkfifo(). + + signal: signal, the signal to use to kill the process. Using a + negative or zero signal will not kill the process. + """ + + def __init__( + self, + file_path: str, + signal=(signal.SIGKILL if sys.platform != "win32" else signal.CTRL_C_EVENT), # type: ignore[attr-defined] + ) -> None: + super().__init__() + self._file_path = file_path + self.signal = signal + + @_retry(max_retries=10, sleep_time=0.1) + def _open_non_blocking(self) -> Optional[io.TextIOWrapper]: + # The server may have crashed or may haven't started yet. + # In such case, calling open() in blocking model blocks the client. + # To avoid such issue, open it in non-blocking mode, and an OSError will + # be raised if the server is not there. + fd = os.open(self._file_path, os.O_WRONLY | os.O_NONBLOCK) + return os.fdopen(fd, "wt") + + def _send_request(self, request: FileTimerRequest) -> None: + try: + file = self._open_non_blocking() + except Exception as e: + raise BrokenPipeError( + "Could not send the FileTimerRequest because FileTimerServer is not available." + ) from e + with file: + json_request = request.to_json() + # Write request with no greater than select.PIPE_BUF is guarantee to be atomic. + if len(json_request) > select.PIPE_BUF: + raise RuntimeError( + f"FileTimerRequest larger than {select.PIPE_BUF} bytes " + f"is not supported: {json_request}" + ) + file.write(json_request + "\n") + + def acquire(self, scope_id: str, expiration_time: float) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), + scope_id=scope_id, + expiration_time=expiration_time, + signal=self.signal, + ), + ) + + def release(self, scope_id: str) -> None: + self._send_request( + request=FileTimerRequest( + worker_pid=os.getpid(), scope_id=scope_id, expiration_time=-1, signal=0 + ), + ) + + +class FileTimerServer: + """ + Server that works with ``FileTimerClient``. Clients are expected to be + running on the same host as the process that is running this server. + Each host in the job is expected to start its own timer server locally + and each server instance manages timers for local workers (running on + processes on the same host). + + Args: + + file_path: str, the path of a FIFO special file to be created. + + max_interval: float, max interval in seconds for each watchdog loop. + + daemon: bool, running the watchdog thread in daemon mode or not. + A daemon thread will not block a process to stop. + log_event: Callable[[Dict[str, str]], None], an optional callback for + logging the events in JSON format. + """ + + def __init__( + self, + file_path: str, + run_id: str, + max_interval: float = 10, + daemon: bool = True, + log_event: Optional[Callable[[str, Optional[FileTimerRequest]], None]] = None, + ) -> None: + self._file_path = file_path + self._run_id = run_id + self._max_interval = max_interval + self._daemon = daemon + self._timers: dict[tuple[int, str], FileTimerRequest] = {} + self._stop_signaled = False + self._watchdog_thread: Optional[threading.Thread] = None + + self._is_client_started = False + if os.path.exists(self._file_path): + os.remove(self._file_path) + os.mkfifo(self._file_path) + # For test only. Count the number of requests received. + self._request_count = 0 + # For test only. Process all requests and stop the server. + self._run_once = False + self._log_event = ( + log_event if log_event is not None else lambda name, request: None + ) + self._last_progress_time = int(time.time()) + + def start(self) -> None: + logger.info( + "Starting %s... max_interval=%s, daemon=%s, file_path=%s", + type(self).__name__, + self._max_interval, + self._daemon, + self._file_path, + ) + self._watchdog_thread = threading.Thread( + target=self._watchdog_loop, daemon=self._daemon + ) + logger.info("Starting watchdog thread...") + self._watchdog_thread.start() + self._log_event("watchdog started", None) + + def stop(self) -> None: + logger.info("Stopping %s", type(self).__name__) + self._stop_signaled = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join(self._max_interval) + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + self._log_event("watchdog stopped", None) + + def run_once(self) -> None: + self._run_once = True + if self._watchdog_thread: + logger.info("Stopping watchdog thread...") + self._watchdog_thread.join() + self._watchdog_thread = None + else: + logger.info("No watchdog thread running, doing nothing") + if os.path.exists(self._file_path): + os.remove(self._file_path) + + @staticmethod + def is_process_running(pid: int): + """ + function to check process is running or not + """ + try: + # Check if the process exists and we can send signals to it + os.kill(pid, 0) + return True + except OSError: + return False + + def _watchdog_loop(self) -> None: + # Open the pipe in blocking mode blocks the server thread. + # This is fine for the following reasons: + # 1. No client case usually does not happen. + # 2. We are running the watchdog loop in a separate daemon + # thread, which will not block the process to stop. + try: + fd = open(self._file_path) + except Exception: + logger.exception("Could not open the FileTimerServer pipe") + raise + + with fd: + self._is_client_started = True + while not self._stop_signaled: + try: + run_once = self._run_once + self._run_watchdog(fd) + if run_once: + break + self._last_progress_time = int(time.time()) + except Exception: + logger.exception("Error running watchdog") + + def _run_watchdog(self, fd: io.TextIOWrapper) -> None: + timer_requests = self._get_requests(fd, self._max_interval) + self.register_timers(timer_requests) + now = time.time() + reaped_worker_pids = set() + kill_process = False + reap_signal = 0 + + all_expired_timers = self.get_expired_timers(now) + log_debug_info_for_expired_timers( + self._run_id, + { + pid: [expired_timer.to_json() for expired_timer in expired_timers] + for pid, expired_timers in all_expired_timers.items() + }, + ) + + for worker_pid, expired_timers in all_expired_timers.items(): + logger.info( + "Reaping worker_pid=[%s]. Expired timers: %s", + worker_pid, + self._get_scopes(expired_timers), + ) + reaped_worker_pids.add(worker_pid) + # In case we have multiple expired timers, we find the first timer + # with a valid signal (>0) in the expiration time order. + expired_timers.sort(key=lambda timer: timer.expiration_time) + signal = 0 + expired_timer = None + for timer in expired_timers: + self._log_event("timer expired", timer) + if timer.signal > 0: + signal = timer.signal + expired_timer = timer + break + if signal <= 0: + logger.info( + "No signal specified with worker=[%s]. Do not reap it.", worker_pid + ) + continue + if self._reap_worker(worker_pid, signal): + logger.info( + "Successfully reaped worker=[%s] with signal=%s", worker_pid, signal + ) + self._log_event("kill worker process", expired_timer) + kill_process = True + reap_signal = signal + else: + logger.error( + "Error reaping worker=[%s]. Will retry on next watchdog.", + worker_pid, + ) + if kill_process and reap_signal > 0: + logger.info( + "Terminating the server process=[%s] because of expired timers", + os.getpid(), + ) + self._reap_worker(os.getpid(), reap_signal) + + self.clear_timers(reaped_worker_pids) + + def _get_scopes(self, timer_requests: list[FileTimerRequest]) -> list[str]: + return [r.scope_id for r in timer_requests] + + def _get_requests( + self, fd: io.TextIOWrapper, max_interval: float + ) -> list[FileTimerRequest]: + start = time.time() + requests = [] + while not self._stop_signaled or self._run_once: + # For named pipe, readline() is blocking when at least one writer opens. + # It returns only when flush() is called at the writer side. + # Note that flush() is automatically called inside close(). + # After the last writer closes, readline() is not blocking. + # It will return an empty string when it's at end-of-file. + # Since the client side always opens the pipe, writes a message and closes + # the pipe immediately, the readline() call below is not blocking for long. + json_request = fd.readline() + if len(json_request) == 0: + if self._run_once: + break + time.sleep(min(max_interval, 1)) + else: + request = json.loads(json_request) + pid = request["pid"] + scope_id = request["scope_id"] + expiration_time = request["expiration_time"] + signal = request["signal"] + requests.append( + FileTimerRequest( + worker_pid=pid, + scope_id=scope_id, + expiration_time=expiration_time, + signal=signal, + ) + ) + now = time.time() + if now - start > max_interval: + break + return requests + + def register_timers(self, timer_requests: list[FileTimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_pid + scope_id = request.scope_id + expiration_time = request.expiration_time + self._request_count += 1 + + key = (pid, scope_id) + # negative expiration is a proxy for a release call + if expiration_time < 0: + if key in self._timers: + del self._timers[key] + else: + self._timers[key] = request + + def clear_timers(self, worker_pids: set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_pids or not FileTimerServer.is_process_running(pid): + del self._timers[(pid, scope_id)] + + def get_expired_timers(self, deadline: float) -> dict[int, list[FileTimerRequest]]: + # pid -> [timer_requests...] + expired_timers: dict[int, list[FileTimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_pid, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_pid: int, signal: int) -> bool: + try: + os.kill(worker_pid, signal) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_pid) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_pid) + return False + + def get_last_progress_time(self) -> int: + return self._last_progress_time if self._is_client_started else int(time.time()) diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/timer/local_timer.py b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/local_timer.py new file mode 100644 index 0000000000000000000000000000000000000000..13bc0462df082e36c3043a45e520a97135c1acdc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/timer/local_timer.py @@ -0,0 +1,128 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import logging +import multiprocessing as mp +import os +import signal +import time +from queue import Empty +from typing import Any + +from .api import RequestQueue, TimerClient, TimerRequest, TimerServer + + +__all__ = ["LocalTimerClient", "MultiprocessingRequestQueue", "LocalTimerServer"] + +logger = logging.getLogger(__name__) + + +class LocalTimerClient(TimerClient): + """ + Client side of ``LocalTimerServer``. This client is meant to be used + on the same host that the ``LocalTimerServer`` is running on and uses + pid to uniquely identify a worker. This is particularly useful in situations + where one spawns a subprocess (trainer) per GPU on a host with multiple + GPU devices. + """ + + def __init__(self, mp_queue): + super().__init__() + self._mp_queue = mp_queue + + def acquire(self, scope_id, expiration_time): + pid = os.getpid() + acquire_request = TimerRequest(pid, scope_id, expiration_time) + self._mp_queue.put(acquire_request) + + def release(self, scope_id): + pid = os.getpid() + release_request = TimerRequest(pid, scope_id, -1) + self._mp_queue.put(release_request) + + +class MultiprocessingRequestQueue(RequestQueue): + """ + A ``RequestQueue`` backed by python ``multiprocessing.Queue`` + """ + + def __init__(self, mp_queue: mp.Queue): + super().__init__() + self._mp_queue = mp_queue + + def size(self) -> int: + return self._mp_queue.qsize() + + def get(self, size, timeout: float) -> list[TimerRequest]: + requests = [] + wait = timeout + for _ in range(0, size): + start = time.time() + + try: + r = self._mp_queue.get(block=True, timeout=wait) + except Empty: + break + + requests.append(r) + wait = wait - (time.time() - start) + if wait <= 0: + break + + return requests + + +class LocalTimerServer(TimerServer): + """ + Server that works with ``LocalTimerClient``. Clients are expected to be + subprocesses to the parent process that is running this server. Each host + in the job is expected to start its own timer server locally and each + server instance manages timers for local workers (running on processes + on the same host). + """ + + def __init__( + self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True + ): + super().__init__(MultiprocessingRequestQueue(mp_queue), max_interval, daemon) + self._timers: dict[tuple[Any, str], TimerRequest] = {} + + def register_timers(self, timer_requests: list[TimerRequest]) -> None: + for request in timer_requests: + pid = request.worker_id + scope_id = request.scope_id + expiration_time = request.expiration_time + + # negative expiration is a proxy for a release call + if expiration_time < 0: + self._timers.pop((pid, scope_id), None) + else: + self._timers[(pid, scope_id)] = request + + def clear_timers(self, worker_ids: set[int]) -> None: + for pid, scope_id in list(self._timers.keys()): + if pid in worker_ids: + self._timers.pop((pid, scope_id)) + + def get_expired_timers(self, deadline: float) -> dict[Any, list[TimerRequest]]: + # pid -> [timer_requests...] + expired_timers: dict[Any, list[TimerRequest]] = {} + for request in self._timers.values(): + if request.expiration_time <= deadline: + expired_scopes = expired_timers.setdefault(request.worker_id, []) + expired_scopes.append(request) + return expired_timers + + def _reap_worker(self, worker_id: int) -> bool: + try: + os.kill(worker_id, signal.SIGKILL) + return True + except ProcessLookupError: + logger.info("Process with pid=%s does not exist. Skipping", worker_id) + return True + except Exception: + logger.exception("Error terminating pid=%s", worker_id) + return False diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5fbc76bf70244c273d84c617a96dfc9827f1ae70 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__init__.py @@ -0,0 +1,9 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .api import get_env_variable_or_raise, get_socket_with_port, macros # noqa: F401 diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e06c1efe2ba4185e15564d58005cfda2a3932f11 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..937e9cd8745abd2dfdf4588975b4aa977469c272 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c41179569b27672986e5e99d4e0a1e5b6ab0eb8d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/distributed.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9b71a533047113ee9e93845c705ef67e01ff640 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/log_level.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ee8a906a921dfe2ee0524b00c14b902e1207ae2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/logging.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb8e782831b02e95bca912cd957af02d0bb2c7ed Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/__pycache__/store.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/api.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/api.py new file mode 100644 index 0000000000000000000000000000000000000000..02e1ace69264e937bb6a1046bc8bfbcd648ae1e7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/api.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +import socket +from string import Template +from typing import Any + + +def get_env_variable_or_raise(env_name: str) -> str: + r""" + Tries to retrieve environment variable. Raises ``ValueError`` + if no environment variable found. + + Args: + env_name (str): Name of the env variable + """ + value = os.environ.get(env_name, None) + if value is None: + msg = f"Environment variable {env_name} expected, but not set" + raise ValueError(msg) + return value + + +def get_socket_with_port() -> socket.socket: + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError: + s.close() + raise RuntimeError("Failed to create a socket") + + +class macros: + """ + Defines simple macros for caffe2.distributed.launch cmd args substitution + """ + + local_rank = "${local_rank}" + + @staticmethod + def substitute(args: list[Any], local_rank: str) -> list[str]: + args_sub = [] + for arg in args: + if isinstance(arg, str): + sub = Template(arg).safe_substitute(local_rank=local_rank) + args_sub.append(sub) + else: + args_sub.append(arg) + return args_sub diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__init__.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..73fd6cdd4431a77cc1cb7ae49efc92cedebfab2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__init__.py @@ -0,0 +1,10 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from .cycling_iterator import CyclingIterator # noqa: F401 +from .elastic_distributed_sampler import ElasticDistributedSampler # noqa: F401 diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3363233513d6e7c39054ef17a12026927a4a0ce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7bead2a7517589366f1f06dd51487b6bf61f1f58 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/cycling_iterator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad66867c7d563eb91f5a58db19b4bf888d6ac0bb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/__pycache__/elastic_distributed_sampler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py new file mode 100644 index 0000000000000000000000000000000000000000..8a2dbc11646cefc5d8aedc3c0ef51054ad05c78c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/cycling_iterator.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 + +from collections.abc import Iterator +from typing import Callable, TypeVar +from typing_extensions import Self + + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +_T = TypeVar("_T") + +__all__ = ["CyclingIterator"] + + +class CyclingIterator(Iterator[_T]): + """ + An iterator decorator that cycles through the + underlying iterator "n" times. Useful to "unroll" + the dataset across multiple training epochs. + + The generator function is called as ``generator_fn(epoch)`` + to obtain the underlying iterator, where ``epoch`` is a + number less than or equal to ``n`` representing the ``k``th cycle + + For example if ``generator_fn`` always returns ``[1,2,3]`` + then ``CyclingIterator(n=2, generator_fn)`` will iterate through + ``[1,2,3,1,2,3]`` + """ + + def __init__( + self, + n: int, + generator_fn: Callable[[int], Iterator[_T]], + start_epoch: int = 0, + ): + self._n = n + self._epoch = start_epoch + self._generator_fn = generator_fn + self._iter = generator_fn(self._epoch) + + def __iter__(self) -> Self: + return self + + def __next__(self) -> _T: + try: + return next(self._iter) + except StopIteration as eod: # eod == end of data + if self._epoch < self._n - 1: + self._epoch += 1 + self._iter = self._generator_fn(self._epoch) + return self.__next__() + else: + raise eod diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..f3b7718edf79f8873c88a7d81d90f24a69b6de73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -0,0 +1,92 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from collections.abc import Iterator, Sized +from typing import cast, Optional, TypeVar + +import torch +from torch.utils.data import Dataset +from torch.utils.data.distributed import DistributedSampler + + +T = TypeVar("T") + +__all__ = ["ElasticDistributedSampler"] + + +class ElasticDistributedSampler(DistributedSampler[T]): + """ + Sampler that restricts data loading to a subset of + the dataset for elastic training. + + It is especially useful in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSampler instance as a DataLoader sampler, + and load a subset of the original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be of constant size. + + Args: + dataset: Dataset used for sampling. + num_replicas (optional): Number of processes participating in + distributed training. + rank (optional): Rank of the current process within num_replicas. + start_index (optional): Which index of the dataset to start sampling from + """ + + def __init__( + self, + dataset: Dataset[T], + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + start_index: int = 0, + ): + super().__init__(dataset=dataset, num_replicas=num_replicas, rank=rank) + if not isinstance(dataset, Sized): + raise TypeError("Dataset must be an instance of collections.abc.Sized") + + # Cast to Sized for mypy + sized_dataset = cast(Sized, dataset) + + if start_index >= len(sized_dataset): + raise ValueError( + f"Start index {start_index} should be less than dataset size {len(sized_dataset)}" + ) + + self.start_index = start_index + sized_dataset = cast(Sized, self.dataset) + self.num_samples = int( + math.ceil(float(len(sized_dataset) - self.start_index) / self.num_replicas) + ) + self.total_size = self.num_samples * self.num_replicas + + def __iter__(self) -> Iterator[T]: + # deterministically shuffle based on epoch + g = torch.Generator() + g.manual_seed(self.epoch) + sized_dataset = cast(Sized, self.dataset) + indices = ( + torch.randperm(len(sized_dataset) - self.start_index, generator=g) + .add(self.start_index) + .tolist() + ) + + # add extra samples to make it evenly divisible + indices += indices[: (self.total_size - len(indices))] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/distributed.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..8f696c819bc01f23a6862e0e177a04c41c19c225 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/distributed.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import datetime +import os +import socket +from contextlib import closing +from typing import Optional + +import torch.distributed as dist +from torch.distributed.elastic.utils.logging import get_logger +from torch.distributed.elastic.utils.store import barrier + + +__all__ = ["create_c10d_store", "get_free_port", "get_socket_with_port"] + +logger = get_logger(__name__) + +_ADDRESS_IN_USE = "Address already in use" +_SOCKET_TIMEOUT = "Socket Timeout" + +_TCP_STORE_INIT = "_tcp_store/num_members" + + +def create_c10d_store( + is_server: bool, + server_addr: str, + server_port: int = -1, + world_size: int = 1, + timeout: float = (60 * 10), # 10 min + wait_for_workers: bool = True, + retries=3, + use_libuv: Optional[bool] = None, +): + if use_libuv is not None: + logger.warning( + "argument use_libuv is deprecated and ignored. Set USE_LIBUV environment " + 'variable to "0" to disable libuv, or "1" to enable it. If the env var ' + "is not set, libuv will be used by default." + ) + + # check os.environ for use_libuv + use_libuv = os.environ.get("USE_LIBUV", "1") == "1" # libuv is the default option + + if server_port == -1 and world_size > 1: + raise ValueError( + f"server_port must be specified when world_size > 1, got server_port={server_port}, world_size={world_size}" + ) + + if server_port != -1: + logger.info("sever_port: %s, specified, ignoring retries", server_port) + + # only retry when server_port is NOT static + attempt = retries if server_port == -1 else 1 + while True: + if server_port != -1: + port = server_port + else: + port = get_free_port() + + logger.info( + "Creating c10d store on %s:%s\n" + " world_size : %s\n" + " is_server : %s\n" + " timeout(sec): %s\n" + " use_libuv : %s\n", + server_addr, + port, + world_size, + is_server, + timeout, + use_libuv, + ) + + try: + store = dist.TCPStore( + host_name=server_addr, + port=port, + world_size=world_size, + is_master=is_server, + timeout=datetime.timedelta(seconds=timeout), + wait_for_workers=wait_for_workers, + use_libuv=use_libuv, + ) + # skips full rank check when we don't have to wait for all workers + if wait_for_workers: + _check_full_rank(store, world_size, timeout=timeout) + logger.info("Successfully created c10d store") + return store + except RuntimeError as e: + # this is brittle, but the underlying exception type is not properly pybinded + # so we parse the error msg for now, interestingly this is how torch itself + # detects timeouts and port conflicts in their own unittests + # see - caffe2/torch/testing/_internal/common_utils.py + # TODO properly map the exceptions in pybind (c10d/init.cpp) + if str(e) == _ADDRESS_IN_USE: # this will only happen on the server + if attempt < retries: + logger.warning( + "port: %s already in use, attempt: [%s/%s]", + port, + attempt, + retries, + ) + attempt += 1 + else: + raise RuntimeError( + f"on {server_addr}, port: {port} already in use" + ) from e + else: + raise + + +def _check_full_rank(store, world_size, timeout): + try: + barrier(store, world_size, key_prefix=_TCP_STORE_INIT, barrier_timeout=timeout) + except RuntimeError as e: + if str(e) == _SOCKET_TIMEOUT: + raise TimeoutError( + f"timed out waiting for all {world_size} members to join" + ) from e + else: + raise + + +def get_free_port(): + """ + Returns an unused port on localhost. + + This function finds an unused port on localhost by opening to socket to bind + to a port and then closing it. + + Returns: + int: an unused port on localhost + + Example: + >>> # xdoctest: +SKIP("Nondeterministic") + >>> get_free_port() + 63976 + + .. note:: + The port returned by :func:`get_free_port` is not reserved and may be + taken by another process after this function returns. + """ + sock = get_socket_with_port() + with closing(sock): + return sock.getsockname()[1] + + +def get_socket_with_port() -> socket.socket: + """ + Returns a free port on localhost that is "reserved" by binding a temporary + socket on it. Close the socket before passing the port to the entity + that requires it. Usage example + + :: + + sock = _get_socket_with_port() + with closing(sock): + port = sock.getsockname()[1] + sock.close() + # there is still a race-condition that some other process + # may grab this port before func() runs + func(port) + """ + + addrs = socket.getaddrinfo( + host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM + ) + for addr in addrs: + family, type, proto, _, _ = addr + s = socket.socket(family, type, proto) + try: + s.bind(("localhost", 0)) + s.listen(0) + return s + except OSError as e: + s.close() + logger.warning("Socket creation attempt failed.", exc_info=e) + raise RuntimeError("Failed to create a socket") diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/log_level.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/log_level.py new file mode 100644 index 0000000000000000000000000000000000000000..cf2d31347aeeb3ebc63af253a3f4db678cfdc0fc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/log_level.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +def get_log_level() -> str: + """ + Return default log level for pytorch. + """ + return "WARNING" diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/logging.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..e46ea9512e78e72506eb9a7a84e27726b7464908 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/logging.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import inspect +import logging +import os +import warnings +from typing import Optional + +from torch.distributed.elastic.utils.log_level import get_log_level + + +def get_logger(name: Optional[str] = None) -> logging.Logger: + """ + Util function to set up a simple logger that writes + into stderr. The loglevel is fetched from the LOGLEVEL + env. variable or WARNING as default. The function will use the + module name of the caller if no name is provided. + + Args: + name: Name of the logger. If no name provided, the name will + be derived from the call stack. + """ + + # Derive the name of the caller, if none provided + # Use depth=2 since this function takes up one level in the call stack + return _setup_logger(name or _derive_module_name(depth=2)) + + +def _setup_logger(name: Optional[str] = None) -> logging.Logger: + logger = logging.getLogger(name) + logger.setLevel(os.environ.get("LOGLEVEL", get_log_level())) + return logger + + +def _derive_module_name(depth: int = 1) -> Optional[str]: + """ + Derives the name of the caller module from the stack frames. + + Args: + depth: The position of the frame in the stack. + """ + try: + stack = inspect.stack() + assert depth < len(stack) + # FrameInfo is just a named tuple: (frame, filename, lineno, function, code_context, index) + frame_info = stack[depth] + + module = inspect.getmodule(frame_info[0]) + if module: + module_name = module.__name__ + else: + # inspect.getmodule(frame_info[0]) does NOT work (returns None) in + # binaries built with @mode/opt + # return the filename (minus the .py extension) as modulename + filename = frame_info[1] + module_name = os.path.splitext(os.path.basename(filename))[0] + return module_name + except Exception as e: + warnings.warn( + f"Error deriving logger module name, using . Exception: {e}", + RuntimeWarning, + ) + return None diff --git a/phivenv/Lib/site-packages/torch/distributed/elastic/utils/store.py b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/store.py new file mode 100644 index 0000000000000000000000000000000000000000..45632abb765c8e0e3a3b47b561f38f0d0a4bdca1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/elastic/utils/store.py @@ -0,0 +1,226 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from collections.abc import Iterable +from contextlib import contextmanager +from datetime import timedelta +from typing import Callable, Optional + +import torch + + +DistStoreError = torch._C._DistStoreError + +_NUM_MEMBERS = "/num_members" +_LAST_MEMBER_CHECKIN = "/last_member" +_TRACE = "/TRACE" +_TRACING_GATE = "/TRACING_GATE" +_MAX_TRACE_MISSING_RANKS = 16 + + +__all__ = ["store_timeout", "get_all", "synchronize", "barrier"] + + +@contextmanager +def store_timeout(store, timeout: float): + """ + This sets the timeout and then restores the old timeout when the context + manager exits. + + Args: + store: the store to set the timeout on + timeout: the timeout to set + """ + + old_timeout = store.timeout + store.set_timeout(timedelta(seconds=timeout)) + yield + store.set_timeout(old_timeout) + + +def get_all(store, rank: int, prefix: str, world_size: int): + r""" + Given a store and a prefix, the method goes through the array of keys + of the following format: ``{prefix}{idx}``, where idx is in a range + from 0 to size, and tries to retrieve the data. + + The Rank0 process waits at the end to make sure all other processes + finished the procedure before exiting. + + Usage + + :: + + values = get_all(store, "torchelastic/data", 3) + value1 = values[0] # retrieves the data for key torchelastic/data0 + value2 = values[1] # retrieves the data for key torchelastic/data1 + value3 = values[2] # retrieves the data for key torchelastic/data2 + + """ + data_arr = store.multi_get([f"{prefix}{idx}" for idx in range(world_size)]) + + barrier_key = _barrier_nonblocking( + store=store, + world_size=world_size, + key_prefix=f"{prefix}/finished", + ) + if rank == 0: + # Rank0 runs the TCPStore daemon, as a result it needs to exit last. + # Otherwise, the barrier may timeout if rank0 process finished the work + # before other processes finished `get_all` method + store.wait([barrier_key]) + + return data_arr + + +def synchronize( + store, + data: bytes, + rank: int, + world_size: int, + key_prefix: str, + timeout: float = 300, +) -> list[bytes]: + """ + Synchronizes ``world_size`` agents between each other using the underlying c10d store. + The ``data`` will be available on each of the agents. + + Note: The data on the path is not deleted, as a result there can be stale data if + you use the same key_prefix twice. + + Time complexity: O(N) per worker, O(N^2) globally. + """ + with store_timeout(store, timeout): + store.set(f"{key_prefix}{rank}", data) + agent_data = get_all(store, rank, key_prefix, world_size) + return agent_data + + +def _try_detecting_missing_ranks( + store, + world_size: int, + key_prefix: str, + rank: int, + rank_decoder: Callable[[int], str], + trace_timeout: float, +) -> Optional[Iterable[str]]: + store.set(f"{key_prefix}{rank}{_TRACE}", "") + + def _find_missing_ranks(): + missing_rank_info = set() + ranks_missing = 0 + for i in range(1, world_size): + # reduce noise, assuming in general 8 ranks per node + # It is valuable to know that 1 or >1 nodes have timed-out. + if ranks_missing >= _MAX_TRACE_MISSING_RANKS: + break + try: + if ranks_missing == 0: + store.wait( + [f"{key_prefix}{i}{_TRACE}"], timedelta(seconds=trace_timeout) + ) + else: + # use a shortest timeout, some ranks have failed to check-in + store.wait([f"{key_prefix}{i}{_TRACE}"], timedelta(milliseconds=1)) + except DistStoreError: + ranks_missing += 1 + missing_rank_info.add(rank_decoder(i)) + return missing_rank_info + + def _checkin(): + try: + store.wait([f"{key_prefix}{_TRACING_GATE}"]) + return [f"[]"] + except DistStoreError: + # in case rank0 is the source of the timeout, original exception will be raised + return None + + if rank == 0: + missing_rank_info = _find_missing_ranks() + store.set(f"{key_prefix}{_TRACING_GATE}", "") + return missing_rank_info + else: + return _checkin() + + +def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str: + """ + Does all the non-blocking operations for a barrier and returns the final key + that can be waited on. + """ + num_members_key = key_prefix + _NUM_MEMBERS + last_member_key = key_prefix + _LAST_MEMBER_CHECKIN + + idx = store.add(num_members_key, 1) + if idx == world_size: + store.set(last_member_key, "") + + return last_member_key + + +def barrier( + store, + world_size: int, + key_prefix: str, + barrier_timeout: float = 300, + rank: Optional[int] = None, + rank_tracing_decoder: Optional[Callable[[int], str]] = None, + trace_timeout: float = 10, +) -> None: + """ + A global lock between agents. This will pause all workers until at least + ``world_size`` workers respond. + + This uses a fast incrementing index to assign waiting ranks and a success + flag set by the last worker. + + Time complexity: O(1) per worker, O(N) globally. + + Optionally, passing rank will enable tracing of missing ranks on timeouts. + `rank_tracing_decoder` lambda arg can be used to convert rank data + into a more meaningful information at an app level (e.g. hostname). + + Note: Since the data is not removed from the store, the barrier can be used + once per unique ``key_prefix``. + """ + + if rank is None: + assert rank_tracing_decoder is None, "Tracing requires rank information" + + with store_timeout(store, barrier_timeout): + last_member_key = _barrier_nonblocking( + store=store, world_size=world_size, key_prefix=key_prefix + ) + try: + store.wait([last_member_key]) + except DistStoreError as e: + if rank is None: + raise e + else: + missing_ranks = _try_detecting_missing_ranks( + store, + world_size, + key_prefix, + rank, + rank_tracing_decoder or (lambda x: str(x)), + trace_timeout, + ) + if missing_ranks is not None: + raise DistStoreError( + "Timed out waiting on barrier on " + "rank {}, for key prefix: {} (world_size={}, missing_ranks={}, timeout={})".format( + rank, + key_prefix, + world_size, + f"[{', '.join(missing_ranks)}]", + barrier_timeout, + ) + ) from None + else: + raise e diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__init__.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0fce47e484e7155630e3db14f0cbb43205ae69e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/__init__.py @@ -0,0 +1,66 @@ +from ._flat_param import FlatParameter as FlatParameter +from ._fully_shard import ( + CPUOffloadPolicy, + FSDPModule, + fully_shard, + MixedPrecisionPolicy, + OffloadPolicy, + register_fsdp_forward_method, + UnshardHandle, +) +from .fully_sharded_data_parallel import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + FullyShardedDataParallel, + LocalOptimStateDictConfig, + LocalStateDictConfig, + MixedPrecision, + OptimStateDictConfig, + OptimStateKeyType, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictSettings, + StateDictType, +) + + +__all__ = [ + # FSDP1 + "BackwardPrefetch", + "CPUOffload", + "FullOptimStateDictConfig", + "FullStateDictConfig", + "FullyShardedDataParallel", + "LocalOptimStateDictConfig", + "LocalStateDictConfig", + "MixedPrecision", + "OptimStateDictConfig", + "OptimStateKeyType", + "ShardedOptimStateDictConfig", + "ShardedStateDictConfig", + "ShardingStrategy", + "StateDictConfig", + "StateDictSettings", + "StateDictType", + # FSDP2 + "CPUOffloadPolicy", + "FSDPModule", + "fully_shard", + "MixedPrecisionPolicy", + "OffloadPolicy", + "register_fsdp_forward_method", + "UnshardHandle", +] + +# Set namespace for exposed private names +CPUOffloadPolicy.__module__ = "torch.distributed.fsdp" +FSDPModule.__module__ = "torch.distributed.fsdp" +fully_shard.__module__ = "torch.distributed.fsdp" +MixedPrecisionPolicy.__module__ = "torch.distributed.fsdp" +OffloadPolicy.__module__ = "torch.distributed.fsdp" +register_fsdp_forward_method.__module__ = "torch.distributed.fsdp" +UnshardHandle.__module__ = "torch.distributed.fsdp" diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25599e4d0f30c21d7ddfa4c6b7b499778adc21fe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa4b86ea682c63e66175761736bbd26812fcc614 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_common_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b512cefc75eaf6ecd46ee7b84bf297b852585ee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_debug_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46ffd4d3cdd698396ae944f4d36f8e817f598255 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_dynamo_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12ea9cb625767dbcb4bdfd189496627b09cc3a15 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_exec_order_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_flat_param.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_flat_param.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8544eaf99fa3bd06f898dcfbdd7708947f02fd2d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_flat_param.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07473b50d28ee50acd745d425f299a3ba2100ebe Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_fsdp_extensions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9a7af0dc50b033dd0bc22f34929d97a5540b7e3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_init_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5bf2ef9b1dc8f110190f74d3ffa1bc61a09b00e5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_limiter_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a79331be9d4a9124a222ea2094458b9487ee6bef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_optim_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59cd93a40fbce51fea1b829558087ef512438965 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_runtime_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..578f00e343f03a78de6fc32b6c371ca31c918c64 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_shard_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b20029c9f99ffa04c9f0a28bd9e923225f299e9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_state_dict_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..500ca1be9e6af119d176207d659797c70bdfc7f5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_trace_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..579c8302d7c6e7c297bcbf110a7f583a9f964866 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_traversal_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f72fa2bbd621cbf015c098767faef476e99e7114 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_unshard_param_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3b9b71000ea5f1de33d6abe316509ebd734fd2f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/_wrap_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23882e2b50e68aa1ee8237dc75f64bcc214b22de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/fully_sharded_data_parallel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/fully_sharded_data_parallel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20396c764addf827728802a39f61150de9385c0d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/fully_sharded_data_parallel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d533a7087e8110dbb99a6c6120ac2f6ffd60e43 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/sharded_grad_scaler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f72a77a5b2d2395af9899532647e4b6673a970d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/__pycache__/wrap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_common_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_common_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4489b77c03c163aac091462f39d8d15ebc993d3c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_common_utils.py @@ -0,0 +1,546 @@ +# mypy: allow-untyped-defs +""" +This file includes private common utilities for FSDP. +""" + +import logging +import traceback +import warnings +import weakref +from collections.abc import Generator, Iterable +from enum import auto, Enum +from functools import partial +from itertools import chain +from typing import Any, Callable, cast, no_type_check, Optional, TYPE_CHECKING + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._flat_param as flat_param_file +import torch.nn as nn +from torch.distributed._composable_state import _get_module_state, _State +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_PREFIX, +) +from torch.distributed.utils import _apply_to_tensors +from torch.utils._mode_utils import no_dispatch + +from .api import ( + FullOptimStateDictConfig, + FullStateDictConfig, + OptimStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictType, +) + + +if TYPE_CHECKING: + from torch.distributed.device_mesh import DeviceMesh + from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions + + from ._flat_param import FlatParamHandle + +FSDP_WRAPPED_MODULE = "_fsdp_wrapped_module" +FSDP_PREFIX = FSDP_WRAPPED_MODULE + "." +FSDP_FLATTENED = "_fsdp_flattened" + +# Save a global mapping from module to its input tensor dtype to be populated +# during the forward pre-hook and consumed in the forward post-hook when +# overriding a module's mixed precision +# NOTE: We currently take the last input tensor's dtype in the case of multiple +# floating-point input tensors, which may be incorrect. However, since there is +# not a 1:1 correspondence between input and output tensors, we must use *some* +# heuristic like this to predict the desired output dtype. +_MODULE_TO_INP_DTYPE: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary() + + +class _FSDPDeviceHandle: + """ + This is a simple abstraction for FSDP computing devices, + which enables custom backends that implement CUDA-like + semantics to be integrated with FSDP. + """ + + def __init__(self, device: torch.device, backend: Any = None): + if backend is None: + try: + self.__backend = getattr(torch, device.type) + self.__device = device + except AttributeError as exc: + raise AttributeError( + f"Device '{device}' does not have a corresponding backend registered as 'torch.{device.type}'." + ) from exc + else: + self.__backend = backend + + @classmethod + def from_device(cls, device: torch.device) -> "_FSDPDeviceHandle": + """ + Return a device handle corresponding to the device, and through this handle, + operations with the same semantics as CUDA can be performed on the device. + Just return torch.cuda if the device is cuda to make attribute-access faster. + Custom backend must first register a module with the same name with {device.type} on torch. + """ + if device.type == "cuda": + return cast(_FSDPDeviceHandle, torch.cuda) + elif device.type == "mtia": + return cast(_FSDPDeviceHandle, torch.mtia) + return cls(device) + + def __getattr__(self, name: str, /) -> Any: + try: + return getattr(self.__backend, name) + except AttributeError as exc: + raise AttributeError( + f"Custom backend '{self.__device.type}' not implement 'torch.{self.__device.type}.{name}'" + ) from exc + + +class _UninitializedDeviceHandle(_FSDPDeviceHandle): + def __init__(self) -> None: + pass + + def __getattribute__(self, name: str, /) -> Any: + raise RuntimeError("Trying to use an uninitialized device handle.") + + +class _FSDPState(_State): + def __init__(self) -> None: + # TODO: Move all the attributes to this class to enable typing for + # FSDP/fully_shard. + self._ignored_modules: set[nn.Module] = set() + self._ignored_params: set[nn.Parameter] = set() + # Buffer names are cleaned (without wrapper prefixes) + self._ignored_buffer_names: set[str] = set() + self.process_group: Optional[dist.ProcessGroup] = None + self.rank: int = -1 + self.world_size: int = -1 + self._device_mesh: Optional[DeviceMesh] = None + self.sharding_strategy = ShardingStrategy.FULL_SHARD + self._use_orig_params: bool = False + self.training_state = TrainingState.IDLE + self._unshard_params_ctx: dict[nn.Module, Generator] = {} + self._state_dict_type: StateDictType = StateDictType.FULL_STATE_DICT + self._state_dict_config: StateDictConfig = FullStateDictConfig() + self._optim_state_dict_config: OptimStateDictConfig = FullOptimStateDictConfig() + self._is_root: Optional[bool] = None + self._handle: Optional[flat_param_file.FlatParamHandle] = None + self._fully_sharded_module_to_handle: dict[ + nn.Module, Optional[flat_param_file.FlatParamHandle] + ] = {} + self.compute_device: Optional[torch.device] = None + self._gradient_predivide_factor: int = 0 + self._gradient_postdivide_factor: int = 0 + self._comm_hook: Optional[Callable] = None + self._comm_hook_state: Optional[Any] = None + self._unshard_event: Optional[torch.Event] = None + # Abstract device handle for fsdp compute device. For now, + # the compute device must implement cuda semantics used by fsdp + self._device_handle: _FSDPDeviceHandle = _UninitializedDeviceHandle() + # All following attributes should only be used for root states: + # Save these static lists to avoid the repeated tree traversals + self._all_fsdp_states: list[_FSDPState] = [] + self._all_handles: list[flat_param_file.FlatParamHandle] = [] + self._fsdp_extension: Optional[FSDPExtensions] = None + + +def _get_module_fsdp_state(module: nn.Module) -> Optional[_FSDPState]: + state = _get_module_state(module) + if state is None or not isinstance(state, _FSDPState): + return None + return state + + +def _get_module_fsdp_state_if_fully_sharded_module( + module: nn.Module, +) -> Optional[_FSDPState]: + state = _get_module_fsdp_state(module) + if state is None: + return None + if state == module: # FullyShardedDataParallel module case. + return state + if module in state._fully_sharded_module_to_handle: # fully_shard case. + return state + return None + + +class TrainingState(Enum): + """ + An enum that indicates the state of a ``FullyShardedDataParallel` instance. + """ + + IDLE = auto() + FORWARD_BACKWARD = auto() + SUMMON_FULL_PARAMS = auto() + + +class HandleTrainingState(Enum): + """ + An enum that indicates the state of a ``FlatParamHandle`. + """ + + IDLE = auto() + FORWARD = auto() + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + SUMMON_FULL_PARAMS = auto() + + +def _is_composable(state: _FSDPState): + # TODO: This is a temporary hack for differentiate between code paths. + return not isinstance(state, nn.Module) + + +@no_type_check +def _module_handle(state: _FSDPState, module: nn.Module) -> Optional["FlatParamHandle"]: + """ + Returns the ``FlatParamHandle`` s corresponding to ``module``. This is + the handle that contains some parameter in ``module``. + """ + if _is_composable(state): + # A valid FSDP state may have no managed parameters and hence no + # handles, meaning no entry in `_fully_sharded_module_to_handles` + if state._handle is None: + return None + assert module in state._fully_sharded_module_to_handle, ( + f"Expects a fully sharded module but got {module} on rank {state.rank}" + ) + return state._fully_sharded_module_to_handle[module] + else: + # NOTE: This assumes `module` is a `FullyShardedDataParallel` instance. + return module._handle + + +@no_type_check +def _has_fsdp_params(state: _FSDPState, module: nn.Module) -> bool: + """Returns if ``module`` has parameters managed by FSDP.""" + return _module_handle(state, module) is not None + + +def _get_sharding_strategy(handle): + """ + Returns the sharding strategy of the handle. + """ + return handle._sharding_strategy if handle else None + + +def clean_tensor_name(tensor_name: str) -> str: + """ + Cleans the parameter or buffer name by removing any module wrapper + prefixes. + """ + tensor_name = tensor_name.replace(FSDP_PREFIX, "") + # TODO: Explicitly replacing the checkpoint wrapper prefix is not ideal as + # it couples `CheckpointWrapper` and FSDP and also does not scale for more + # module wrappers. + tensor_name = tensor_name.replace(_CHECKPOINT_PREFIX, "") + return tensor_name + + +def _set_fsdp_flattened(tensor: torch.Tensor) -> None: + """ + Sets an attribute on ``tensor`` to mark it as flattened by FSDP. This is to + avoid re-flattening it during nested construction. + """ + setattr(tensor, FSDP_FLATTENED, True) + + +def _is_fsdp_flattened(tensor: torch.Tensor) -> bool: + """Returns if ``tensor`` has been marked as flattened by FSDP.""" + return getattr(tensor, FSDP_FLATTENED, False) + + +def _named_parameters_with_duplicates( + module: nn.Module, **kwargs: Any +) -> list[tuple[str, nn.Parameter]]: + """ + This API is required as some modules overwrite `named_parameters()` but do not support + `remove_duplicate`. + """ + assert "remove_duplicate" not in kwargs, ( + "_named_parameters_with_duplicates cannot be used with `remove_duplicate` argument." + ) + kwargs["remove_duplicate"] = False + try: + ret = list(module.named_parameters(**kwargs)) + except AssertionError: + kwargs.pop("remove_duplicate") + ret = list(module.named_parameters(**kwargs)) + return ret + + +def _get_param_to_fqns( + model: torch.nn.Module, + dedup_shared_params: bool = True, +) -> dict[nn.Parameter, list[str]]: + """ + Constructs a mapping from parameter to a list of its \"canonical\" FQNs. Here, + we use canonical to mean the fully-qualified name assigned to the parameter + based on its position in the original nn.Module hierarchy before any wrapper + or parallelism has been applied to it. This is in contrast to FQNs that may be + generated after parallelisms or wrappers have been applied to the model. + + Each normal parameter maps to a singleton list containing its FQN, while each + ``FlatParameter`` maps to a list of its original parameter FQNs, which may + have length greater than one. All FQNs are prefixed starting from ``model``. + + In the case where FSDP was applied with ``use_orig_params=True``, there should be no + ``FlatParameter`` s registered to the model's modules and this mapping will only + contain mappings from ``nn.Parameter`` s to singleton FQN lists. + + It is only in the case where FSDP was applied with ``use_orig_params=False`` where + a ``FlatParameter`` will be registered in place of the original parameters and there + will be mappings from each ``FlatParameter`` to lists of FQNs corresponding to the + original parameters. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance). + dedup_shared_params (bool): For shared parameters, if ``True``, only + includes the FQNs corresponding to the first encounter of the + shared parameter in the module traversal; if ``False``, then + includes the FQNs across all encounters. (Default: ``True``) + """ + + def module_fn(module, prefix, tree_level, param_to_fqns): + for param_name, param in _named_parameters_with_duplicates( + module, recurse=False + ): + local_fqns = ( + param._fqns + if isinstance(param, flat_param_file.FlatParameter) + else [param_name] + ) # prefixed from `module` + global_fqns = [ + clean_tensor_name(prefix + name) for name in local_fqns + ] # prefixed from the top level `model` (i.e. including `prefix`) + is_shared_param = param in param_to_fqns + if not is_shared_param: + param_to_fqns[param] = global_fqns + else: + if isinstance(param, flat_param_file.FlatParameter): + # DMP overwrites `named_parameters` and skip (advance to + # the next child module) the wrapped_module (e.g., + # _dmp_wrapped_module and _fsdp_wrapped_module). When a user + # calls `named_child` to traverse the module recursively and + # calls `named_parameters` with `recurse=False`, parameters + # will be traversed more than once. + # This hack is specified designed for DMP + FSDP. We + # overwrite the flat_parameters traversal result to only obtain + # the last one, which happens to be the correct one. + # + # TODO: Remove this hack once DMP + FSDP is not supported. + warnings.warn( + "FlatParameter is being traversed more than once. " + "This case should only happen when using " + "DistributedModelParallel with FullyShardedDataParallel." + ) + param_to_fqns[param] = global_fqns + elif not dedup_shared_params: + param_to_fqns[param].extend(global_fqns) + + def return_fn(param_to_fqns): + return param_to_fqns + + param_to_unflat_param_names: dict[torch.nn.Parameter, list[str]] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [key for key, _ in _named_parameters_with_duplicates(model)], + param_to_unflat_param_names, + ) + + +@no_type_check +def _log_post_backward_hook( + state: _FSDPState, handle: "FlatParamHandle", logger: logging.Logger +) -> None: + # Under TORCH_DISTRIBUTED_DEBUG=INFO, log the module names this hook fires for. + # Below logging of module names this post-bwd hook fires for can help debug certain + # cases where hooks don't fire, such as under certain activation checkpoint configs. + if state._use_orig_params and handle._debug_level == dist.DebugLevel.INFO: + param_fqns = _get_handle_fqns_from_root(state, handle) + logger.warning("FSDP firing post-backward hooks for parameters %s", param_fqns) + + +@no_type_check +def _get_handle_fqns_from_root( + state: _FSDPState, handle: "FlatParamHandle" +) -> Optional[list[str]]: + if handle is None: + return None + param_to_fqn = state._exec_order_data.param_to_fqn + handle_params = handle.flat_param._params # only populated for use_orig_params + param_fqns = [*chain.from_iterable(param_to_fqn[p] for p in handle_params)] + return param_fqns + + +def _apply_to_modules( + root_module: torch.nn.Module, + module_fn: Callable, + return_fn: Callable, + filter_fqns: Optional[list[str]] = None, + *args, + **kwargs, +): + """ + Performs a pre-order traversal of the modules in the hierarchy rooted at + ``root_module``, applying ``module_fn`` at each module and finally + returning a value using ``return_fn``. The traversal constructs the full + module prefix name (e.g. "module.submodule." just like in model state dict) + and makes that available to ``module_fn``. + + ``filter_fqns`` is used because some module may have its own prefix similar + to ``FullyShardedDataParallel`` and the ``named_parameters()`` is overwritten + to remove the prefix. + """ + + def f(module: torch.nn.Module, prefix: str, tree_level: int, *args, **kwargs): + # Call the module function before recursing over children (pre-order) + module_fn(module, prefix, tree_level, *args, **kwargs) + for submodule_name, submodule in module.named_children(): + if submodule is None: + continue + new_prefix = prefix + submodule_name + "." + new_tree_level = tree_level + 1 + if filter_fqns is not None: + for fqn in filter_fqns: + if fqn.startswith(new_prefix): + break + else: + # DMP's named_parameter() will mess up the traversal with + # ``named_children`` + `named_parameter(recurse=False)``. + # This hack is a must to make the traversal work. + # TODO: Remove this hack once DMP + FSDP is not supported. + # It turns out that recursive wrapping may trigger this as + # well. + if ( + submodule_name == "_fsdp_wrapped_module" + or submodule_name == "_dmp_wrapped_module" + ): + new_prefix = prefix + elif submodule_name == "module": + new_prefix = prefix + f(submodule, new_prefix, new_tree_level, *args, **kwargs) + + f(root_module, "", 0, *args, **kwargs) + return return_fn(*args, **kwargs) + + +@no_type_check +def _assert_in_training_states( + state: _FSDPState, + training_states: list[TrainingState], +) -> None: + """Asserts that FSDP is in the states ``_training_states``.""" + # Raise a `ValueError` instead of using `assert` to ensure that these + # logical assertions run even if `assert`s are disabled + if state.training_state not in training_states: + msg = ( + f"expected to be in states {training_states} but current state is " + f"{state.training_state}" + ) + # Print the error on rank 0 in case this is called in the backward pass + if state.rank == 0: + if isinstance(state, nn.Module): + print(f"Asserting FSDP instance is: {state}") + print(f"ERROR: {msg}") + traceback.print_stack() + raise ValueError(msg) + + +def _get_root_modules(modules: set[nn.Module]) -> set[nn.Module]: + """ + Returns: + Set[nn.Module]: The subset of ``modules`` that are root modules (i.e. + parent-less) with respect to the modules in the set itself. In other + words, these are the modules in ``modules`` that are not the child of + any other module in ``modules``. + """ + root_modules: set[nn.Module] = set() + module_to_submodules = {module: set(module.modules()) for module in modules} + for candidate_module in modules: + is_root_module = True + for module, submodules in module_to_submodules.items(): + is_child_module = ( + candidate_module is not module and candidate_module in submodules + ) + if is_child_module: + is_root_module = False + break + if is_root_module: + root_modules.add(candidate_module) + return root_modules + + +def _override_module_mixed_precision( + root: torch.nn.Module, + module_classes_to_override: Iterable[type[nn.Module]], + wrap_override_dict: dict[str, Any] = {"mixed_precision": None}, # noqa: B006 +) -> set[type[nn.Module]]: + module_classes_to_override = tuple(set(module_classes_to_override)) + # Return a set of the actually overridden module classes + overridden_module_classes: set[type[nn.Module]] = set() + for mod in root.modules(): + if isinstance(mod, module_classes_to_override): + overridden_module_classes.add(type(mod)) + mod._wrap_overrides = wrap_override_dict # type: ignore[assignment] + # TODO: We need to run this mixed precision ignored module in fp32, + # but ensure subsequent modules, that may possibly be running with + # mixed precision, still receive the appropriate precision inputs + # without user having to adjust mixed precision config too much. + # As a result, we attach pre and post forward hooks to up / down + # cast. We should revisit this design. + + def cast_fn( + dtype: torch.dtype, module: nn.Module, x: torch.Tensor + ) -> torch.Tensor: + if not torch.is_floating_point(x) or x.dtype == dtype: + return x + _MODULE_TO_INP_DTYPE[module] = x.dtype + return x.to(dtype) + + def forward_pre_hook(module, args): + return _apply_to_tensors(partial(cast_fn, torch.float32, module), args) + + def forward_post_hook(module, args, output): + # NOTE: If the forward did not have any floating-point tensors, + # then the dtype will not be set for this module, and we do not + # upcast the dtype. + if module in _MODULE_TO_INP_DTYPE: + old_dtype = _MODULE_TO_INP_DTYPE[module] + return _apply_to_tensors( + partial(cast_fn, old_dtype, module), output + ) + + # We intentionally append both of these hooks so that they run after + # all other hooks. + mod.register_forward_pre_hook(forward_pre_hook, prepend=False) + mod.register_forward_hook(forward_post_hook, prepend=False) + return overridden_module_classes + + +def _no_dispatch_record_stream(tensor: torch.Tensor, stream: torch.Stream) -> None: + # FIXME record_stream doesn't work with non-cuda/mtia/xpu tensors + if tensor.device.type not in [ + "cuda", + "mtia", + "xpu", + torch._C._get_privateuse1_backend_name(), + ]: + return + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + return + # from @ezyang: + # The no_dispatch was added in https://github.com/pytorch/pytorch/pull/88014 cc @fegin + # Looking over the PR, it looks like this is because we don't actually support Stream arguments + # in torch dispatch, so it just chokes. + # If Dynamo is able to answer "are there any torch dispatch modes" active (it should answer False), + # a better version of this would just be to check if there are any modes before disabling dispatch. + # TODO(voz): Extend a dynamo util to answer the above, unify the codepaths here. + tensor.record_stream(stream) + else: + with no_dispatch(): + tensor.record_stream(stream) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_debug_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_debug_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3acb42992527436787819180140937e0883e0670 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_debug_utils.py @@ -0,0 +1,157 @@ +# mypy: allow-untyped-defs +import logging +import time +from collections import defaultdict +from collections.abc import Iterator +from contextlib import contextmanager +from enum import Enum + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._flat_param as flat_param_file +from torch.distributed.fsdp._common_utils import ( + _apply_to_modules, + _get_module_fsdp_state, + clean_tensor_name, +) + + +logger = logging.getLogger(__name__) + + +class SimpleProfiler: + class Type(str, Enum): + ALL = "all" + ALLGATHER = "all_gather" + ALLGATHER_OBJ = "all_gather_object" + RESHARDING = "resharding" + H2D = "H2D" + D2H = "D2H" + + results: dict[str, float] = defaultdict(float) + profiling: set[str] = set() + + @classmethod + def reset(cls) -> None: + cls.results.clear() + cls.profiling.clear() + + @classmethod + @contextmanager + def profile(cls, profile_type: str) -> Iterator[None]: + assert profile_type not in cls.profiling, ( + f"{profile_type} is already being profiled. " + "SimpleProfiler does not support profiling multiple instances at " + "the same time. " + ) + + cls.profiling.add(profile_type) + begin = time.monotonic() + try: + yield + finally: + end = time.monotonic() + cls.results[profile_type] += end - begin + cls.profiling.remove(profile_type) + + @classmethod + def dump_and_reset(cls, msg: str) -> None: + # This cannot be combined with DETAIL distributed log + # as the profiling will be very incorrect. + if dist.get_rank() == 0 and dist.get_debug_level() == dist.DebugLevel.INFO: + logger.info("%s %s", msg, cls.results) + cls.reset() + + +def _get_sharded_module_tree_with_module_name_to_fqns( + model: torch.nn.Module, +) -> tuple[str, dict[str, list[str]]]: + """ + It is used for composable fully_shard() code path, it returns + 1. sharded module tree info: each line represents a submodule name that contains the + submodule's FQN and its submodule class name, if the submodule is sharded by `fully_shard`, + the submodule name will add a postfix with ' FULLY SHARDED'. Each increased tree + level adds 4 spaces before the printed name. A printed sharded module tree info for a toy model + is like this: + [CompositeModel] FULLY SHARDED + l1[Linear] + u1[UnitModule] FULLY SHARDED + u1.l1[Linear] + u1.seq[Sequential] + u1.seq.0[ReLU] + u1.seq.1[Linear] + u1.seq.2[ReLU] + u1.l2[Linear] + u2[UnitModule] FULLY SHARDED + u2.l1[Linear] + u2.seq[Sequential] + u2.seq.0[ReLU] + u2.seq.1[Linear] + u2.seq.2[ReLU] + u2.l2[Linear] + l2[Linear] + 2. a dict mapping from the concated module FQN and class name to a list of its managed + original parameters' FQNs. An example of the dict for the above toy sharded model is like this: + {'[CompositeModel]': ['l1.weight', 'l1.bias', 'l2.weight', 'l2.bias'], + 'u1[UnitModule]': ['u1.l1.weight', 'u1.l1.bias', 'u1.seq.1.weight', 'u1.seq.1.bias', 'u1.l2.weight', 'u1.l2.bias'], + 'u2[UnitModule]': ['u2.l1.weight', 'u2.l1.bias', 'u2.seq.1.weight', 'u2.seq.1.bias', 'u2.l2.weight', 'u2.l2.bias'] + } + All FQNs are prefixed starting from ``model``. + + Args: + model (torch.nn.Module): Root module (which may or may not be passed to + composable `fully_shard()`). + """ + + def module_fn( + module, prefix, tree_level, sharded_tree_info, sharded_module_name_to_fqns + ): + num_spaces = tree_level * 4 + trimed_prefix = ( + prefix[:-1] if (len(prefix) > 0 and prefix[-1] == ".") else prefix + ) + prefixed_module_name = trimed_prefix + "[" + module.__class__.__name__ + "]" + printed_prefixed_module_name = " " * num_spaces + prefixed_module_name + + state = _get_module_fsdp_state(module) + if state is None: + sharded_tree_info[0] += printed_prefixed_module_name + "\n" + return + + handle = state._fully_sharded_module_to_handle.get(module, None) + + if handle: + sharded_tree_info[0] += ( + printed_prefixed_module_name + " FULLY SHARDED" + "\n" + ) + else: + sharded_tree_info[0] += printed_prefixed_module_name + "\n" + + if handle: + param = handle.flat_param + assert isinstance(param, flat_param_file.FlatParameter) + global_fqns = [ + clean_tensor_name(prefix + name) for name in param._fqns + ] # prefixed from the top level `model` (i.e. including `prefix`) + + if prefixed_module_name in sharded_module_name_to_fqns: + sharded_module_name_to_fqns[prefixed_module_name].extend(global_fqns) + else: + sharded_module_name_to_fqns[prefixed_module_name] = global_fqns + + def return_fn(sharded_tree_info, sharded_module_name_to_fqns): + return sharded_tree_info[0], sharded_module_name_to_fqns + + # Use List to mutate its value in place while running the recursive functions + sharded_tree_info: list[str] = [ + "", + ] + sharded_module_name_to_fqns: dict[str, list[str]] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [key for key, _ in model.named_parameters()], + sharded_tree_info, + sharded_module_name_to_fqns, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_dynamo_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_dynamo_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..51fbddd5988ee28b0ae815681262bd89813a9992 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_dynamo_utils.py @@ -0,0 +1,43 @@ +import torch.nn as nn + + +def _annotate_modules_for_dynamo( + module: nn.Module, + ignored_modules: set[nn.Module], + use_orig_params: bool, +) -> None: + """ + Annotates the submodules in ``module`` 's tree, except those in + ``ignored_modules``, indicating that the submodules are FSDP-managed and + saving the ``use_orig_params`` setting passed to the FSDP constructor. + """ + for submodule in module.modules(): + if submodule not in ignored_modules: + """[note: Dynamo treats FSDP wrapped modules as UnspecializedNNModule] + + Dynamo doesn't get to see this instance (FullyShardedDataParallel) during tracing, since + it skips tracing all the torch.distributed.fsdp code. + - Why? Running the FSDP code eagerly avoids lots of issues trying to trace complex hooks, and also + gets us graph-breaks on FSDP module boundaries which we want anyway for comm ops. + - However, we _also_ want dynamo to treat the wrapped module inside FSDP 'unspecially' (*), + and we need a way to indicate to dynamo which modules are wrapped by FSDP. + + (*) UnspecializedNNModules in dynamo are traced-through without any assumptions, and with thorough + guards. NNModules otherwise are 'specialized', meaning there is less overhead due to assuming + their code is well-behaved. + + One particular issue with specialized NNModules for FSDP is that the + views created for orig_params are captured into the compiled graph on the first iteration, and while + they are always going to point to the correct flatparameter and give correct results, their order + of creation influences the order of backward execution, preventing overlap of comm and computation + during backward. We need to _use_ the new parameter views created on each forward iteration, in + order for backward to interleave hooks with compute per layer. UnspecializedNNModule lets us achieve + this by capturing the module code more 'functionally' and passing parameters in as inputs each time. + """ + submodule._is_fsdp_managed_module = True # type: ignore[assignment] + + # Dynamo only supports FSDP with use_orig_params=True. + # This is hacky, but I could not think of another way to add an assertion to dynamo + # for this, since Dynamo skips all the FSDP code frames and thus can't inspect the + # FSDP module directly + submodule._fsdp_use_orig_params = use_orig_params # type: ignore[assignment] diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_exec_order_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_exec_order_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb5cb699769484f9dab5150fe446138c485ae85 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_exec_order_utils.py @@ -0,0 +1,364 @@ +# mypy: allow-untyped-defs +import itertools +import warnings +from enum import auto, Enum +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.fsdp._common_utils import _FSDPState, _get_param_to_fqns +from torch.distributed.fsdp._flat_param import FlatParamHandle + + +class _ExecOrderWarnStatus(Enum): + """Used internally for execution order validation.""" + + NONE = auto() # no deviation yet + WARNING = auto() # deviated this iteration; currently issuing warnings + WARNED = auto() # deviated in a previous iteration + + +class _ExecOrderData: + """ + This contains the data structures to track the execution order. We track + the pre-forward order on the *first* iteration for forward prefetching + (which thus assumes static graph) and the post-forward order on *every* + iteration for backward prefetching (which thus does not assume static + graph but may be provide an incorrect order). + """ + + def __init__( + self, + debug_level: dist.DebugLevel, + backward_prefetch_limit: int, + forward_prefetch_limit: int, + ) -> None: + # Tracks the (static) pre-forward order for execution order validation + # and forward prefetching + self.handles_pre_forward_order: list[FlatParamHandle] = [] + # Tracks the post-forward order for pre-backward prefetching + self.handles_post_forward_order: list[Optional[FlatParamHandle]] = [] + self._iter = 0 + + # Gives the max number of backward/forward prefetched all-gathers by a + # single module + self._backward_prefetch_limit = backward_prefetch_limit + self._forward_prefetch_limit = forward_prefetch_limit + + # Data structures for execution order validation + self._checking_order: bool = debug_level == dist.DebugLevel.DETAIL + self.process_group: Optional[dist.ProcessGroup] = None + self.world_size: Optional[int] = None + self.all_handles: list[FlatParamHandle] = [] + # Names are prefixed from the root module + self.param_to_fqn: dict[nn.Parameter, list[str]] = {} + # Current index in the pre-forward execution order + self.current_order_index = 0 + self.warn_status = _ExecOrderWarnStatus.NONE + + def init( + self, + state: _FSDPState, + root_module: nn.Module, + process_group: dist.ProcessGroup, + ) -> None: + """ + Initializes the data structures needed for checking the forward order. + This should be called after a root FSDP instance has been set during + lazy initialization. + """ + self.process_group = process_group + self.rank = process_group.rank() + self.world_size = process_group.size() + # Fix an order over the handles, which should be the same across ranks + for handle in traversal_utils._get_fsdp_handles(root_module): + index = len(self.all_handles) + self.all_handles.append(handle) + handle._handle_index = index + self.param_to_fqn = _get_param_to_fqns(root_module) + # TODO (awgu): We can broadcast the metadata of rank 0's `all_handles` + # to check that all ranks have the same handles in the same order. + # https://github.com/pytorch/pytorch/issues/79620 + + @property + def is_first_iter(self) -> bool: + return self._iter == 0 + + def get_handle_to_backward_prefetch( + self, + current_handle: FlatParamHandle, + ) -> Optional[FlatParamHandle]: + """ + Returns a :class:`list` of the handles keys of the handles to backward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. + """ + current_index = current_handle._post_forward_index + if current_index is None: + return None + target_index = current_index - 1 + target_handle: Optional[FlatParamHandle] = None + for _ in range(self._backward_prefetch_limit): + if target_index < 0: + break + target_handle = self.handles_post_forward_order[target_index] + target_index -= 1 + return target_handle + + def get_handle_to_forward_prefetch( + self, + current_handle: FlatParamHandle, + ) -> Optional[FlatParamHandle]: + """ + Returns a :class:`list` of the handles keys of the handles to forward + prefetch given the current handles key. If there are no valid handles + keys to prefetch, then this returns an empty :class:`list`. + """ + current_index = current_handle._pre_forward_order_index + if current_index is None: + return None + target_index = current_index + 1 + target_handle: Optional[FlatParamHandle] = None + for _ in range(self._forward_prefetch_limit): + if target_index >= len(self.handles_pre_forward_order): + break + target_handle = self.handles_pre_forward_order[target_index] + target_index += 1 + return target_handle + + def record_post_forward(self, handle: Optional[FlatParamHandle]) -> None: + """ + Records ``handles`` in the post-forward order, where ``handles`` should + be a group of handles used in the same module's forward. If ``handles`` + is empty, then it is omitted. + + Unlike :meth:`record_pre_forward`, this records the order *every* + iteration with the expectation that the recorded order is reset in + :meth:`next_iter`. + """ + if not handle: + return + # Only record the first usage of a handles key + if handle._post_forward_index: + self.handles_post_forward_order.append(handle) + return + index = len(self.handles_post_forward_order) + handle._post_forward_index = index + self.handles_post_forward_order.append(handle) + + def record_pre_forward( + self, handle: Optional[FlatParamHandle], is_training: bool + ) -> None: + """ + Records ``handles`` in the pre-forward order, where ``handles`` should + be a group of handles used in the same module's forward. If ``handles`` + is empty, then it is omitted. + + On the first iteration, this checks the execution order across ranks. + See :meth:`_check_order` for details. + """ + if not handle: + return + self._check_order(handle, is_training) + # Fix the order after the first iteration and only record the first + # usage of a handles key + if not self.is_first_iter or handle._pre_forward_order_index is not None: + return + index = len(self.handles_pre_forward_order) + handle._pre_forward_order_index = index + self.handles_pre_forward_order.append(handle) + + def _check_order(self, handle: FlatParamHandle, is_training: bool) -> None: + """ + Checks the forward execution order as long as ``is_training`` is + ``True`` since checking in eval mode is not supported. This only checks + if the distributed debug level is DETAIL. + + - On the first iteration, this uses all-gathers to check that all ranks + are all-gathering the same handles and hence ``FlatParameter`` s, + raising an error if not. + - On subsequent iterations, this checks that each rank is locally + consistent with its own forward order from the first iteration, issuing + a warning if not. This issues a warning on the first deviating + iteration and stops warning thereafter. + """ + # Do not check order in eval mode since the post-backward callback does + # not run so it cannot be used to mark the end of an iteration + if not is_training or not self._checking_order: + return + if self.is_first_iter: + msg_prefix = "Forward order differs across ranks:" + optional_local_indices: tuple[Optional[int], ...] = ( + self._get_handle_indices(handle) + ) + device = handle.device # guaranteed to be non-CPU + num_valid_indices = sum( + (index is not None) for index in optional_local_indices + ) + tensor_kwargs: dict[str, Union[torch.dtype, torch.device]] = { + "dtype": torch.int32, + "device": device, + } + world_num_valid_indices = torch.zeros(self.world_size, **tensor_kwargs) # type: ignore[arg-type, call-overload] + local_num_valid_indices = torch.tensor([num_valid_indices], **tensor_kwargs) # type: ignore[arg-type, call-overload] + dist.all_gather_into_tensor( + world_num_valid_indices, + local_num_valid_indices, + group=self.process_group, + ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_num_valid_indices = world_num_valid_indices.cpu() + # Check that all ranks plan to all-gather the same number of + # parameters + # TODO (awgu): Since every module has at most one handle in the + # current implementation, this should never raise the error. + assert self.world_size is not None # mypy + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # TODO(voz): Don't graph break on this - dynamo hates the n1 != n2 + # tensor comparison control flow. + # https://github.com/pytorch/pytorch/issues/107055 + for (r1, n1), (r2, n2) in itertools.combinations( + ( + (rank, world_num_valid_indices[rank]) + for rank in range(self.world_size) + ), + 2, + ): + if n1 != n2: + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering {n1} parameters " + f"while rank {r2} is all-gathering {n2} parameters" + ) + world_indices = torch.zeros( # type: ignore[call-overload] + self.world_size * num_valid_indices, **tensor_kwargs + ) + local_indices = torch.tensor(optional_local_indices, **tensor_kwargs) # type: ignore[arg-type] + dist.all_gather_into_tensor( + world_indices, local_indices, group=self.process_group + ) + # Copy entire tensor from D2H once to avoid per element D2H copies + world_indices = world_indices.cpu() + # Check that all ranks plan to all-gather the same index parameters + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # TODO(voz): Don't graph break on this - dynamo hates the i1 != i2 + # tensor comparison control flow. + # https://github.com/pytorch/pytorch/issues/107055 + for (r1, i1), (r2, i2) in itertools.combinations( + ( + ( + rank, + world_indices[ + rank * num_valid_indices : (rank + 1) + * num_valid_indices + ], + ) + for rank in range(self.world_size) + ), + 2, + ): + if i1 != i2: + r1_param_names = self._get_names_from_handle_indices(i1) + r2_param_names = self._get_names_from_handle_indices(i2) + raise RuntimeError( + f"{msg_prefix} rank {r1} is all-gathering parameters " + f"for {r1_param_names} while rank {r2} is all-gathering " + f"parameters for {r2_param_names}" + ) + else: + # Only issue warnings on the first deviating iteration and stop + # checking thereafter to avoid flooding the console + if self.warn_status == _ExecOrderWarnStatus.WARNED: + return + msg_prefix = None # non-`None` means we should warn + if self.current_order_index >= len(self.handles_pre_forward_order): + # This iteration sees extra all-gather(s) compared to the first + msg_prefix = ( + "Expected to not all-gather any more parameters in the " + "forward but trying to all-gather parameters for " + ) + else: + expected_handle = self.handles_pre_forward_order[ + self.current_order_index + ] + if expected_handle != handle: + expected_param_names = self._get_names_from_handles(expected_handle) + msg_prefix = ( + f"Expected to all-gather for {expected_param_names} " + "but trying to all-gather parameters for " + ) + if msg_prefix is not None: + param_names = self._get_names_from_handles(handle) + msg_suffix = ( + f"{param_names}" + if param_names + else "a newly-added parameter since construction time" + ) + warnings.warn( + "Forward order differs from that of the first iteration " + f"on rank {self.rank}. Collectives are unchecked and may " + f"give incorrect results or hang.\n{msg_prefix}{msg_suffix}" + ) + self.warn_status = _ExecOrderWarnStatus.WARNING + self.current_order_index += 1 + + def _get_handle_indices( + self, + handle: FlatParamHandle, + ) -> tuple[Optional[int], ...]: + """ + Returns the handle indices (i.e. indices into ``self.all_handles``) + corresponding to the handles in ``handle``. An entry in the + returned tuple is ``None`` if the handle is invalid. + """ + indices: list[Optional[int]] = [] + if handle: + indices.append(handle._handle_index) + return tuple(indices) + + def _get_names_from_handle_indices( + self, + handle_indices: tuple[int, ...], + ) -> list[list[str]]: + """ + Returns a list of FQNs for each handle in ``handle_indices``. If a + handle index is invalid, then its FQNs are omitted from the returned + list. + """ + fqns: list[list[str]] = [] + for index in handle_indices: + if index is None or index < 0 or index >= len(self.all_handles): + continue + handle = self.all_handles[index] + flat_param = handle.flat_param + fqns.append(self.param_to_fqn[flat_param]) + return fqns + + def _get_names_from_handles( + self, + handle: FlatParamHandle, + ) -> list[list[str]]: + """ + Returns a list of FQNs for each handle in ``handles_key``. If a handle + is invalid, then its FQNs are omitted from the returned list. + """ + fqns: list[list[str]] = [] + if handle: + flat_param = handle.flat_param + if flat_param in self.param_to_fqn: + fqns.append(self.param_to_fqn[flat_param]) + return fqns + + def next_iter(self): + """ + Advances the internal data structures per iteration. This should be + called in the post-backward callback since that marks the true end of + an iteration. + """ + self._iter += 1 + self.handles_post_forward_order.clear() + if self._checking_order: + self.current_order_index = 0 + if self.warn_status == _ExecOrderWarnStatus.WARNING: + self.warn_status = _ExecOrderWarnStatus.WARNED diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_flat_param.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_flat_param.py new file mode 100644 index 0000000000000000000000000000000000000000..c20d2315ea9ce9cb11ab33abab7ff9ffb369c0bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_flat_param.py @@ -0,0 +1,2788 @@ +# mypy: allow-untyped-defs +import contextlib +import functools +import logging +import os +import warnings +from collections.abc import Generator, Iterator, Sequence +from enum import auto, Enum +from itertools import accumulate, chain +from typing import Any, Callable, cast, NamedTuple, no_type_check, Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torch.distributed.fsdp._common_utils import ( + _FSDPDeviceHandle, + _named_parameters_with_duplicates, + _no_dispatch_record_stream, + _set_fsdp_flattened, + HandleTrainingState, +) +from torch.distributed.utils import ( + _alloc_storage, + _data_ptr_allocated, + _free_storage, + _p_assert, +) +from torch.nn.parameter import _ParameterMeta # type: ignore[attr-defined] +from torch.testing._internal.distributed.fake_pg import FakeProcessGroup + +from ._fsdp_extensions import ( + _ext_post_unflatten_transform, + _ext_pre_flatten_transform, + FSDPExtensions, +) + + +__all__ = [ + "FlatParameter", + "FlatParamHandle", + "FlatParamShardMetadata", + "ParamInfo", + "SharedParamInfo", + "HandleShardingStrategy", +] + +logger = logging.getLogger(__name__) + + +""" +[Note: Fully Sharded Module] +We define the "fully sharded module" to be the original ``nn.Module`` that owns +a ``FlatParamHandle``. It is the *single* module logically responsible for the +*single* unshard/reshard pair for the handle's ``FlatParameter`` for a given +forward or backward pass. The fully sharded module should be passed to the +``FlatParamHandle`` constructor. + +For the wrapper code path: +- The ``FullyShardedDataParallel`` module wrapping the fully sharded module +runs the unshard/reshard on behalf of the fully sharded module by overriding +``nn.Module.forward``. +- The fully sharded module is exactly the module passed to the +``FullyShardedDataParallel`` constructor's ``module`` argument. + +For the non-wrapper code path: +- Hooks registered on the fully sharded module run the unshard/reshard. +- The fully sharded module may either be the direct argument to ``fully_shard`` +or a submodule chosen by the provided wrapping policy. +""" + +# Environment variable toggling whether to use unsafe `setattr()` for view +# setting in `_use_sharded_views()` and `_use_unsharded_views()` +# We should use 'safe' by default since it respects method overrides, but for +# special cases such as for high CPU overhead or for intentionally bypassing +# checks in the overrides, we may use 'unsafe'. +_FSDP_USE_UNSAFE_SETATTR = "FSDP_USE_UNSAFE_SETATTR" + +# Environment variable toggling whether to check for parameter/gradient +# writeback in case their storages change after FSDP initialization +# We should check by default since it prevents silent correctness errors, but +# since such changes are atypical, we may want to skip the check to save CPU +# overhead, especially since the check happens in the pre-forward and +# pre-backward each iteration. +_FSDP_SKIP_WRITEBACK_CHECK = "FSDP_SKIP_WRITEBACK_CHECK" + +# Env var toggling whether when model is in .eval() mode, should we run in fp32 +# or the reduced precision. +_FSDP_USE_FULL_PREC_IN_EVAL = "FSDP_USE_FULL_PREC_IN_EVAL" + +# Some value to set padding in tensors to for debuggability +_FLAT_PARAM_PADDING_VALUE = 42 + +# Environment variables for disabling the all-gather and reduce-scatter +# communication ops for ablation studies. Note that without these communication +# ops the training won't converge, and you probably need to disable correctness +# checks in your model. +_FSDP_USE_FAKE_ALL_GATHER = "FSDP_USE_FAKE_ALL_GATHER" +_FSDP_USE_FAKE_REDUCE = "FSDP_USE_FAKE_REDUCE" + + +# TODO: Define this for now to avoid circular imports. See if we can remove. +class HandleShardingStrategy(Enum): + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + HYBRID_SHARD = auto() + _HYBRID_SHARD_ZERO2 = auto() + + +RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = ( + HandleShardingStrategy.FULL_SHARD, + HandleShardingStrategy.HYBRID_SHARD, +) +NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES = ( + HandleShardingStrategy.SHARD_GRAD_OP, + HandleShardingStrategy._HYBRID_SHARD_ZERO2, +) + + +class ParamInfo(NamedTuple): + """Information for an original parameter.""" + + param_name: str # unprefixed + module: nn.Module + module_name: str + + +class SharedParamInfo(NamedTuple): + """ + Additional information for a shared parameter. + + For each shared parameter, we designate one module and its parameter + variable to be the primary owner, determined as the first one encountered + in the parameter walk. These are prefixed with "prim". The primary module + and parameter do not have their own :class:`SharedParamInfo` instance. + """ + + param_name: str # unprefixed + module: nn.Module + module_name: str + prim_param_name: str # unprefixed + prim_module: nn.Module + prim_module_name: str + + +class _ShardParamInfo(NamedTuple): + """Shard-related information for an original parameter.""" + + in_shard: bool + # Use to index into the sharded flat parameter, e.g. + # `flat_param[offset_in_shard : offset_in_shard + numel_in_shard]` + offset_in_shard: Optional[int] + numel_in_shard: Optional[int] + # Use to get part of the parameter in the local shard from a flattened + # version of the unsharded parameter, e.g. either + # `param.flatten()[intra_param_start_idx : intra_param_end_idx + 1]` or + # `param.as_strided((param.numel(),), (1,))[intra_param_start_idx : intra_param_end_idx + 1]` + intra_param_start_idx: Optional[int] + intra_param_end_idx: Optional[int] # inclusive + + +class FlatParamShardMetadata(NamedTuple): + """ + This holds metadata specific to this rank's shard of the flat parameter. + + Attributes: + param_names (Tuple[str, ...]): Prefixed parameter names of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_shapes (Tuple[torch.Size, ...]): Parameter shapes of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_strides (Tuple[torch.Size, ...]): Parameter strides of this rank's + shard of the parameters; see :class:`FlatParameter`. + param_contiguities (Tuple[bool, ...]): Parameter `.contiguous` call results + of this rank's shard of the parameters; see :class:`FlatParameter`. + param_numels (Tuple[int, ...]): Parameter numels of this rank's shard + of the parameters; see :class:`FlatParameter`. + param_offsets (Tuple[Tuple[int, int], ...]): [start, end] offsets (in + units of numels) giving this rank's part of each flattened + original parameter. + """ + + param_names: tuple[str, ...] + param_shapes: tuple[torch.Size, ...] + param_strides: tuple[tuple[int, ...], ...] + param_contiguities: tuple[bool, ...] + param_numels: tuple[int, ...] + param_offsets: tuple[tuple[int, int], ...] + + +class _FlatParameterMeta(_ParameterMeta): + # Make `isinstance(t, FlatParameter)` return True for custom tensor + # instances that have the _is_flat_param flag for BC + def __instancecheck__(self, instance): + # NB: do NOT test the super implementation + return isinstance(instance, torch.Tensor) and getattr( + instance, "_is_flat_param", False + ) + + +class FlatParameter(nn.Parameter, metaclass=_FlatParameterMeta): + """ + This is the flat parameter used by :class:`FullyShardedDataParallel`. + + It is comprised of one or more original parameters, which are flattened and + concatenated to construct the flat parameter. + + Under the current design, this parameter logically represents both the + unsharded and sharded flat parameter, and its data changes storages + dynamically. + - In the :class:`FullyShardedDataParallel` constructor, the parameter + is initialized as unsharded and then sharded in-place. + - At runtime, the parameter is lazily (re)-initialized. The sharded + parameter data is saved in ``self._local_shard``, and a new ``Tensor`` + ``self._full_param_padded`` is created, which is the all-gather + destination and owns the unsharded parameter storage thereafter. (See + :meth:`FlatParamHandle.init_flat_param_attributes`.) + - Throughout runtime, the parameter data changes storages as needed, + e.g. to the sharded flat parameter, low precision sharded flat + parameter, or the unsharded flat parameter. + + NOTE: Since ``use_orig_params=True`` supports intra-``FlatParameter`` + padding, we have two versions of the per-parameter numels, one that + includes the padding (``_numels_with_padding``) and one that does not + (``_numels``). The former may have length longer than the other data + structures, while the latter has the same length as the number of actual + original parameters like the other per-parameter data structures. + + NOTE: This is not a real class; instead, you will always get a Parameter + back out if you try to create one of these. This is similar to the trick + we implemented for Parameter to get it to work with subclasses; this + is primarily so that FlatParameter supports combination with FakeTensor. + + Attributes: + _unpadded_unsharded_size (torch.Size): Unsharded flat parameter's size + without right-hand-side padding for divisibility by the world size. + For ``use_orig_params=True``, this includes alignment padding. + _padded_unsharded_size (torch.Size): Unsharded flat parameter's size + with right-hand-side padding for divisibility by the world size. + For ``use_orig_params=True``, this includes alignment padding. This + is only set for sharded strategies since they require padding for + the all-gather. + _sharded_size (torch.Size): Sharded flat parameter's size with padding. + This is also set for ``NO_SHARD``, in which case it is the same as + the unsharded sizes. (We omit "padded" because there is no + analogous unpadded one.) + + _num_params (int): Number of original parameters flattened into this + flat parameter. This is the length of the per-parameter data + structures. + _param_infos (Tuple[ParamInfo, ...]): Each parameter's parameter info + entry; see :class:`ParamInfo` for details. + _shapes (Tuple[torch.Size, ...]): Each parameter's original shape. + _strides (Tuple[torch.Size, ...]): Each parameter's original stride. + _contiguities (Tuple[bool, ...]): Each parameter's ``contiguous()`` + call result. + _fqns (Tuple[str, ...]): Each parameter's fully-qualified name (FQN) + prefixed from the ``_fully_sharded_module``. The names are + guaranteed to be unique in the subtree rooted at that module. + _param_extensions (Tuple[Optional[Any], ...]): Each parameter's + extension (i.e. some per-parameter state) used to customize + pre-flatten and post-unflatten behavior or ``None``. This is + experimental, and users should not depend on its existence in the + future. + _numels_with_padding (Tuple[int, ...]): Each parameter's numel + including entries for the padding. This is used to construct views + into the flat parameter via ``torch.split()``. This may have length + longer than ``_num_params``. + _numels (Tuple[int, ...]): Each parameter's numel excluding entries for + padding. This has length equal to ``_num_params``. + _shard_param_infos (Tuple[_ShardParamInfo, ...]): Each parameter's + shard parameter info; see :class:`_ShardParamInfo` for details. + _shared_param_infos (Tuple[SharedParamInfo, ...]): Shared parameter + info entries; see :class:`SharedParamInfo` for details. + _modules (set[nn.Module]): Modules that contain some original parameter + that is flattened into the flat parameter. + + _shard_numel_padded (int): Numel padded for this rank's sharded flat + parameter. + _local_shard (Tensor): Sharded flat parameter with padding if using a + sharded strategy. If using ``NO_SHARD``, then this is the unpadded + unsharded flat parameter, and there is no notion of a sharded flat + parameter or padded unsharded flat parameter. + _full_param_padded (Tensor): Unsharded flat parameter with padding. + This is not defined for ``NO_SHARD``. When using mixed precision + for parameters, this has the low precision. + _full_prec_full_param_padded (Tensor): Full precision unsharded flat + parameter with padding. This is used for unsharding outside of + computation when using mixed precision for parameters. This is + never defined for ``NO_SHARD``. + _post_backward_hook_handle (RemovableHandle): + Flat parameter's post-backward hook handle. (Compile only) + _post_backward_hook_state (Tuple[AccumulateGrad, RemovableHandle]): + Flat parameter's :class:`AccumulateGrad` object and post-backward + hook handle. (Eager only) + _mp_shard (Tensor): Low precision sharded flat parameter with padding. + This is only defined when parameter mixed precision is enabled. For + ``NO_SHARD``, this is used for computation. + _cpu_grad (Tensor): Sharded gradient with padding stored on CPU. + This is only defined when offloading parameters is enabled. + _saved_grad_shard (Tensor): Sharded gradient with padding from previous + iterations for gradient accumulation without :meth:`no_sync`. + + _params (Optional[List[nn.Parameter]]): If ``use_orig_params=True``, + then each original parameter variable; otherwise, ``None``. This + does not include any padding tensors. + _shared_params (Optional[List[nn.Parameter]]): The original shared + parameter variables if ``use_orig_params=True`` and ``None`` + otherwise. + _tensors (Optional[List[Optional[Tensor]]]): This saves the ``Tensor`` + views created in the forward and tracked by autograd when + ``use_orig_params=True`` and is ``None`` otherwise. This is to + preserve those ``Tensor`` variables for the backward to ensure that + the ``FlatParameter`` 's ``AccumulateGrad`` object does not change + in which case the post-backward hook does not run. This is relevant + for cases like reentrant activation checkpointing. + _is_grad_none_mask (Optional[List[bool]]): If ``use_orig_params=True``, + a mask over the original parameters' gradients indicating if it is + logically ``None`` or not; otherwise, ``None``. This does not + include entries for padding. This mask is needed because only some + of the parameters may have ``None`` gradient, in which case the + flat gradient must be non-``None`` and must use zeros to + approximate those original ``None`` gradients. This mask informs + FSDP to set the original parameter gradients to ``None`` (instead + of zeros) as needed. + """ + + _unpadded_unsharded_size: torch.Size + _padded_unsharded_size: torch.Size + _sharded_size: torch.Size + _num_params: int + _param_infos: tuple[ParamInfo, ...] + _shapes: tuple[torch.Size, ...] + _strides: tuple[tuple[int, ...], ...] + _contiguities: tuple[bool, ...] + _fqns: tuple[str, ...] + _param_extensions: tuple[Optional[Any], ...] + _numels_with_padding: tuple[int, ...] + _numels: tuple[int, ...] + _shard_param_infos: tuple[_ShardParamInfo, ...] + _shared_param_infos: tuple[SharedParamInfo, ...] + _modules: set[nn.Module] + _shard_numel_padded: int + _local_shard: Tensor + _full_param_padded: Tensor + _full_prec_full_param_padded: Tensor + # Eager only + _post_backward_hook_state: tuple[Any, Any] + # Compile only + _post_backward_hook_handle: Any + _mp_shard: Tensor + _cpu_grad: Tensor + _saved_grad_shard: Tensor + _params: Optional[list[nn.Parameter]] + _shared_params: Optional[list[nn.Parameter]] + _tensors: Optional[list[Optional[Tensor]]] + _is_grad_none_mask: Optional[list[bool]] + + _is_padding_mask: list[bool] + + def __new__(cls, data=None, requires_grad=True): + assert cls is FlatParameter, "subclasses FlatParameter not supported" + r = nn.Parameter.__new__(nn.Parameter, data, requires_grad) # type: ignore[call-arg] + r._is_flat_param = True # type: ignore[attr-defined] + return r + + # NB: This is not a regular method, because FlatParameters are not actually + # instances of this class (see __new__ above). So you must indirectly + # call this directly through the classmethod. + @classmethod + def _init_metadata( + cls, + self, + param_infos: list[ParamInfo], + numels: list[int], + shapes: list[torch.Size], + strides: list[tuple[int, ...]], + contiguities: list[bool], + fqns: list[str], + shared_param_infos: list[SharedParamInfo], + param_extensions: list[Optional[Any]], + params: Optional[list[nn.Parameter]], + shared_params: Optional[list[nn.Parameter]], + is_padding_mask: list[bool], + ) -> None: + """ + Initialize attributes holding metadata about the original parameters comprising the flat parameter. + + We expose this method separate from the constructor to keep the + constructor only responsible for the flat parameter's tensor data. This + method should only be called once per model, while the constructor may + be called multiple times, e.g. when reloading from a checkpoint, in + which case only the tensor data needs to be passed to the constructor. + Since :meth:`load_state_dict` is implemented via :meth:`copy_`, the + metadata is correctly assumed to be unchanged. + + Args: + See the Attributes in the class docstring. + """ + assert len(param_infos) == len(shapes) + assert len(param_infos) == len(strides) + assert len(param_infos) == len(contiguities) + assert len(param_infos) == len(fqns) + assert len(param_infos) == len(param_extensions) + self._num_params = len(param_infos) + self._param_infos = param_infos + self._shapes = shapes + self._strides = strides + self._contiguities = contiguities + self._fqns = fqns + self._param_extensions = param_extensions + self._is_padding_mask = is_padding_mask + + numels_without_padding: list[int] = [] + for numel, is_padding in zip(numels, is_padding_mask): + if not is_padding: + numels_without_padding.append(numel) + self._numels = tuple(numels_without_padding) + self._numels_with_padding = tuple(numels) + assert len(self._numels) == self._num_params + + self._shared_param_infos = tuple(shared_param_infos) + self._modules = {pi.module for pi in self._param_infos}.union( + {spi.module for spi in self._shared_param_infos} + ) + assert (params is None) == (shared_params is None) + if params is not None: + assert shared_params is not None and len(shared_params) == len( + shared_param_infos + ) + self._params = [] + for param, is_padding in zip(params, is_padding_mask): + if not is_padding: + self._params.append(param) + self._shared_params = shared_params + # Mark the original parameters to avoid flattening them into + # another `FlatParameter` during recursive construction + for param in chain(self._params, self._shared_params): + _set_fsdp_flattened(param) + self._is_grad_none_mask = [False for _ in range(self._num_params)] + self._tensors = [None for _ in range(self._num_params)] + else: + self._params = None + self._shared_params = None + self._is_grad_none_mask = None + self._tensors = None + self._unpadded_unsharded_size = self.size() + _set_fsdp_flattened(self) + # Tracks whether the `FlatParameter`'s post-backward hook has been + # called to modify the behavior of the post-backward callback + self._post_backward_called = False + + +class FlatParamHandle: + """ + A handle that manages a flat parameter (:class:`FlatParameter`). + + This includes sharding and view management. + + Args: + params (Sequence[nn.Parameter]): The parameters to flatten into the + flat parameter. + fully_sharded_module (nn.Module): See [Note: Fully Sharded Module]. + device (torch.device): The compute and communication device, which + should be a non-CPU device. We refer to it as the compute device. + sharding_strategy (ShardingStrategy): Sharding strategy to apply to + this handle's ``FlatParameter``. + offload_params (bool): Whether to offload the handle's + ``FlatParameter`` to CPU. + mp_param_dtype (Optional[torch.dtype]): Parameter mixed precision + setting passed to the FSDP constructor. + mp_reduce_dtype (Optional[torch.dtype]): Gradient reduction mixed + precision setting passed to the FSDP constructor. + keep_low_precision_grads (bool): Whether to keep gradients in low + precision. + use_orig_params (bool): If ``True``, then FSDP preserves the original + parameter variables and returns them from ``named_parameters()`` + (e.g. to support different optimizer hyperparameters within one + :class:`FlatParameter`). If ``False``, then FSDP reconstructs the + parameters every iteration and returns the :class:`FlatParameter` s + from ``named_parameters()``. + """ + + ################## + # INITIALIZATION # + ################## + def __init__( + self, + params: Sequence[Union[nn.Parameter, Tensor]], + fully_sharded_module: nn.Module, + device: torch.device, + sharding_strategy: HandleShardingStrategy, + offload_params: bool, + mp_param_dtype: Optional[torch.dtype], + mp_reduce_dtype: Optional[torch.dtype], + keep_low_precision_grads: bool, + process_group: dist.ProcessGroup, + use_orig_params: bool, + *, + fsdp_extension: Optional[FSDPExtensions] = None, + ): + super().__init__() + params = list(params) + if len(params) == 0: + raise ValueError( + f"Cannot construct a {self.__class__.__name__} with an empty parameter list" + ) + self._init_setattr_fns() + self._skip_writeback_check = ( + os.environ.get(_FSDP_SKIP_WRITEBACK_CHECK, "") == "1" + ) + self._use_full_prec_in_eval = ( + os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1" + ) + self._use_fake_all_gather = os.environ.get(_FSDP_USE_FAKE_ALL_GATHER, "") == "1" + self._use_fake_reduce = os.environ.get(_FSDP_USE_FAKE_REDUCE, "") == "1" + if self._skip_writeback_check: + _warn_skip_writeback_check( + logger, + f"Since {_FSDP_SKIP_WRITEBACK_CHECK}=1, FSDP will not check " + "for parameter or gradient writeback. Changing parameter or " + "gradient storages may lead to silent correctness errors.", + ) + if self._use_fake_all_gather: + _warn_use_fake_all_gather( + logger, + f"Since {_FSDP_USE_FAKE_ALL_GATHER}=1, FSDP will not execute " + "all-gather ops. Your training will be incorrect, but " + "can reveal how much time spent on all-gather ops.", + ) + if self._use_fake_reduce: + _warn_use_fake_reduce( + logger, + f"Since {_FSDP_USE_FAKE_REDUCE}=1, FSDP will not execute " + "reduce-scatter ops. Your training will be incorrect, but " + "can reveal how much time spent on reduce-scatter ops.", + ) + # Only align addresses for `use_orig_params=True` (for now) + align_addresses = use_orig_params + self._init_get_unflat_views_fn(align_addresses) + self.device = device + self._device_handle = _FSDPDeviceHandle.from_device(self.device) + self.process_group = process_group + if self._use_fake_all_gather or self._use_fake_reduce: + self._fake_process_group = FakeProcessGroup( + rank=process_group.rank(), world_size=process_group.size() + ) + self.rank = process_group.rank() + self.world_size = process_group.size() + self._sharding_strategy = sharding_strategy + self._offload_params = offload_params + self._use_orig_params = use_orig_params + self._keep_low_precision_grads = keep_low_precision_grads + self._training_state = HandleTrainingState.IDLE + self._debug_level = dist.get_debug_level() + self._fully_sharded_module = fully_sharded_module + # For strategies that do not free after forward, we skip using sharded + # views after forward since the unsharded data exists. We still switch + # `self.flat_param` to point to the sharded flat parameter since what + # it points to parameterizes behavior. We use the following attribute + # to track which tensor data the parameters are unsharded views into. + self._unsharded_flat_param_for_skipped_views: Optional[Tensor] = None + # The index in the state's `all_handles`, which must be the + # same across ranks for the execution order validation to work + self._handle_index: Optional[int] = None + # Index in handles_to_pre_forward_order + self._pre_forward_order_index: Optional[int] = None + # Index in `handles_post_forward_order` + self._post_forward_index: Optional[int] = None + # Used for guarding against mistargeted forward prefetches + self._needs_pre_forward_unshard = False + # Used for guarding against mistargeted backward prefetches + self._needs_pre_backward_unshard = False + # Was the handle prefetched? Set on successful _prefetch_handle and unshard + self._prefetched = False + # Optimistically assume a valid input `params` and set dtype attributes + # before `_init_flat_param()`, which performs the actual validation + self._orig_param_dtype = params[0].dtype + self._init_param_reduce_dtypes(mp_param_dtype, mp_reduce_dtype) + assert self._fwd_bwd_param_dtype is not None # mypy + self._aligned_numel = ( + _get_aligned_numel(unsharded_dtype=self._fwd_bwd_param_dtype) + if align_addresses + else 0 + ) + self._fsdp_extension = fsdp_extension + self._init_flat_param_and_metadata( + params, + fully_sharded_module, + self._aligned_numel, + use_orig_params, # type: ignore[arg-type] + ) + self._use_unsharded_views(as_params=False) + + def __repr__(self): + return f"FlatParamHandle(flat_param.fqns={self.flat_param._fqns})" + + def _init_setattr_fns(self): + use_unsafe_setattr = os.environ.get(_FSDP_USE_UNSAFE_SETATTR, "") == "1" + self._setattr_tensor: Callable[[nn.Module, str, Tensor], None] + self._setattr_param: Callable[[nn.Module, str, nn.Parameter], None] + if use_unsafe_setattr: + self._setattr_tensor = _unsafe_setattr_tensor + self._setattr_param = _unsafe_setattr_param + else: + self._setattr_tensor = _safe_setattr_tensor_or_param + self._setattr_param = _safe_setattr_tensor_or_param + + def _init_get_unflat_views_fn(self, align_addresses: bool): + self._get_unflat_views = ( + self._get_unflat_views_aligned + if align_addresses + else self._get_unflat_views_unaligned + ) + + def _init_flat_param_and_metadata( + self, + params: list[Union[Tensor, nn.Parameter]], + module: nn.Module, + aligned_numel: int, + use_orig_params: bool, + ) -> None: + """ + Initialize the ``FlatParameter`` and its metadata. + + NOTE: This should only be called once at construction time, after which + the ``FlatParameter`` metadata is assumed to be static. + + NOTE: The elements of ``params`` should only be ``Tensor`` s when + composing with ``DTensor`` -based tensor parallelism, in which case the + elements may be ``DTensor`` local shards. + """ + if len(params) == 0: + raise ValueError("Expects non-empty `params`") + if aligned_numel < 0: + raise ValueError( + f"Expects non-negative `aligned_numel` but got {aligned_numel}" + ) + ( + dtype, + flat_param_requires_grad, + device, + ) = self._validate_tensors_to_flatten(params) + params_set = set(params) + # For alignment padding, only `numels` gets strictly non-`None` + # elements, and all other lists get `None` elements for padding. + param_infos: list[ParamInfo] = [] + numels: list[int] = [] + shapes: list[torch.Size] = [] + strides: list[tuple[int, ...]] = [] + contiguities: list[bool] = [] + fqns: list[str] = [] + shared_param_infos: list[SharedParamInfo] = [] + shared_param_memo: dict[ + Union[Tensor, nn.Parameter], tuple[nn.Module, str, str] + ] = {} + params_to_flatten: list[Union[Tensor, nn.Parameter]] = [] + shared_params: list[Union[Tensor, nn.Parameter]] = [] + param_extensions: list[Any] = [] + is_padding_mask: list[bool] = [] + total_numel = total_numel_without_padding = 0 + for submodule_name, submodule in module.named_modules(remove_duplicate=False): + for param_name, param in _named_parameters_with_duplicates( + submodule, recurse=False + ): + if param not in params_set: + continue + if param in shared_param_memo: # shared reference + prim_module, prim_module_name, prim_param_name = shared_param_memo[ + param + ] + shared_params.append(param) + shared_param_infos.append( + SharedParamInfo( + param_name, + submodule, + submodule_name, + prim_param_name, + prim_module, + prim_module_name, + ) + ) + else: + if aligned_numel > 0: + numel_to_pad = aligned_numel - (total_numel % aligned_numel) + if numel_to_pad > 0 and numel_to_pad < aligned_numel: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + params_to_flatten.append(padding_tensor) + is_padding_mask.append(True) + numels.append(numel_to_pad) + total_numel += numel_to_pad + transform_t, extension = _ext_pre_flatten_transform( + param, + self._fsdp_extension, + ) + param = cast(nn.Parameter, transform_t) + param_extensions.append(extension) + shared_param_memo[param] = (submodule, submodule_name, param_name) + params_to_flatten.append(param) + is_padding_mask.append(False) + param_infos.append(ParamInfo(param_name, submodule, submodule_name)) + numels.append(param.numel()) + shapes.append(param.shape) + strides.append(param.stride()) + contiguities.append(_is_truly_contiguous(param)) + fqn = ( + submodule_name + "." + param_name + if submodule_name + else param_name + ) + fqns.append(fqn) + total_numel += param.numel() + total_numel_without_padding += param.numel() + if len(params_to_flatten) == 0: + raise ValueError( + f"`params` were not found in `module`'s tree" + f"params: {params}\nmodule: {module}" + ) + if ( + self.rank == 0 + and aligned_numel > 0 + and total_numel != total_numel_without_padding + ): + logger.debug( + "FSDP FlatParameter address alignment created " + "%s numel of padding (%s vs. %s)", + total_numel - total_numel_without_padding, + total_numel, + total_numel_without_padding, + ) + if aligned_numel > 0: + # Pad to be divisible by world size to avoid a copy for the + # post-backward reduce-scatter + numel_to_pad = self.world_size - (total_numel % self.world_size) + if numel_to_pad > 0 and numel_to_pad < self.world_size: + if self.rank == 0: + logger.info( + "FSDP FlatParameter world size divisibility created " + "%s numel of padding", + numel_to_pad, + ) + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + params_to_flatten.append(padding_tensor) + is_padding_mask.append(True) + numels.append(numel_to_pad) + total_numel += numel_to_pad + # Pass `aligned_numel=0` since we already included padding tensors + self.flat_param: FlatParameter = self.flatten_tensors_into_flat_param( + params_to_flatten, + aligned_numel=0, + requires_grad=flat_param_requires_grad, + ) + FlatParameter._init_metadata( + self.flat_param, + param_infos, + numels, + shapes, + strides, + contiguities, + fqns, + shared_param_infos, + param_extensions, + _convert_to_params(params_to_flatten) if use_orig_params else None, + _convert_to_params(shared_params) if use_orig_params else None, + is_padding_mask, + ) + + def _validate_tensors_to_flatten( + self, tensors: list[Union[Tensor, nn.Parameter]] + ) -> tuple: + """Validate the tensors to flatten and returns any necessary metadata.""" + dtype: Optional[torch.dtype] = None + # Return as the logical OR over each tensor's value + flat_param_requires_grad: Optional[bool] = None + device: Optional[torch.device] = None + # For `use_orig_params=True`, permit non-uniform `requires_grad` + for tensor in tensors: + if isinstance(tensor, FlatParameter): + raise ValueError("Cannot flatten a `FlatParameter`") + if dtype is None and not tensor.is_floating_point(): + raise ValueError("Cannot flatten integer dtype tensors") + if dtype is not None and tensor.dtype != dtype: + raise ValueError( + f"Must flatten tensors with uniform dtype but got {dtype} " + f"and {tensor.dtype}" + ) + if ( + not self._use_orig_params + and flat_param_requires_grad is not None + and tensor.requires_grad != flat_param_requires_grad + ): + raise ValueError( + "Must flatten tensors with uniform `requires_grad` when " + "`use_orig_params=False`" + ) + if device is not None and tensor.device != device: + raise ValueError( + "Must flatten tensors on the same device but got both " + f"{device} and {tensor.device}" + ) + dtype = tensor.dtype + flat_param_requires_grad = flat_param_requires_grad or tensor.requires_grad + device = tensor.device + assert flat_param_requires_grad is not None, "Requires non-empty `tensors` list" + return dtype, flat_param_requires_grad, device + + def flatten_tensors( + self, + tensors: list[Tensor], + aligned_numel: int, + ) -> Tensor: + """ + Flatten ``tensors`` into a single flat tensor. + + The flattening optionally includes + padding if ``aligned_numel`` is greater than 0, where ``aligned_numel`` + gives the numel required to have address alignment. + + NOTE: The padding alignment algorithm must be kept in sync with + :meth:`_init_flat_param_metadata`. We separate the two methods because + the initialization happens once, whereas this method may be called + multiple times throughout training (e.g. for checkpointing). + """ + if len(tensors) == 0: + raise ValueError("Expects non-empty `tensors`") + if aligned_numel < 0: + raise ValueError( + f"Expects non-negative `aligned_numel` but got {aligned_numel}" + ) + dtype, _, device = self._validate_tensors_to_flatten(tensors) + flat_tensors: list[Tensor] = [] + if aligned_numel > 0: + total_numel = 0 + for tensor in tensors: + numel_to_pad = aligned_numel - (total_numel % aligned_numel) + if numel_to_pad > 0 and numel_to_pad < aligned_numel: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + flat_tensors.append(padding_tensor) + total_numel += numel_to_pad + flat_tensors.append( + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + ) + total_numel += tensor.numel() + numel_to_pad = self.world_size - (total_numel % self.world_size) + if numel_to_pad > 0 and numel_to_pad < self.world_size: + padding_tensor = _construct_padding_tensor( + numel_to_pad, dtype, False, device + ) + flat_tensors.append(padding_tensor) + total_numel += numel_to_pad + else: + flat_tensors = [ + torch.flatten(_detach_if_needed(tensor)) + if _is_truly_contiguous(tensor) + else _detach_if_needed(tensor).as_strided((tensor.numel(),), (1,)) + for tensor in tensors + ] + return torch.cat(flat_tensors, dim=0) + + def flatten_tensors_into_flat_param( + self, + tensors: list[Tensor], + aligned_numel: int, + requires_grad: bool, + ) -> FlatParameter: + flat_param_data = self.flatten_tensors(tensors, aligned_numel) + return FlatParameter(flat_param_data, requires_grad=requires_grad) + + def _init_param_reduce_dtypes( + self, + mp_param_dtype: Optional[torch.dtype], + mp_reduce_dtype: Optional[torch.dtype], + ) -> None: + """ + Initialize param and reduce dtypes. + + Precondition: ``self.flat_param`` is set. This ensures that this + handle's parameters have a single dtype. + + Postcondition: This sets ``self._fwd_bwd_param_dtype`` and + ``self._reduce_dtype``. If ``mp_param_dtype`` or ``mp_reduce_dtype`` + is ``None``, then we assume the original parameter dtype. One special + case is if ``mp_param_dtype`` is not ``None`` and ``mp_reduce_dtype`` + is ``None``, in which case we assume the gradient reduction dtype + matches the forward/backward parameter dtype. + """ + # Save whether these dtypes were specified so that we permit the + # parameter dtype to change up until the lazy initialization + self._low_prec_param_dtype_specified = mp_param_dtype is not None + self._low_prec_reduce_dtype_specified = mp_reduce_dtype is not None + if ( + self._low_prec_param_dtype_specified + and not self._low_prec_reduce_dtype_specified + ): + # Special case: infer gradient reduction mixed precision + self._fwd_bwd_param_dtype = mp_param_dtype + self._reduce_dtype = self._fwd_bwd_param_dtype + else: + self._fwd_bwd_param_dtype = mp_param_dtype or self._orig_param_dtype + self._reduce_dtype = mp_reduce_dtype or self._orig_param_dtype + assert self._fwd_bwd_param_dtype is not None + assert self._reduce_dtype is not None + + ################################### + # SHARD INITIALIZATION & METADATA # + ################################### + @torch.no_grad() + def shard(self): + """ + Shard the handle's ``FlatParameter``. + + This allocates new memory for + the sharded flat parameter and frees the unsharded flat parameter's + storage. + + Postcondition: ``self.flat_param`` is the sharded flat parameter. Shard + metadata attributes are set for all sharding strategies. + """ + flat_param = self.flat_param + if not self.uses_sharded_strategy: + self._init_shard_metadata(0, 0, flat_param.numel() - 1) + else: + _p_assert( + flat_param.storage_offset() == 0, + "The `FlatParameter` is not the sole occupant of its storage", + ) + sharded_flat_param, numel_padded = FlatParamHandle._get_shard( + flat_param, self.rank, self.world_size + ) + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + allocated = flat_param._typed_storage()._size() > 0 + if allocated: + flat_param._typed_storage()._resize_(0) + flat_param.set_(sharded_flat_param) # type: ignore[call-overload] + start_idx = sharded_flat_param.numel() * self.rank + end_idx = sharded_flat_param.numel() * (self.rank + 1) - 1 # inclusive + self._init_shard_metadata(numel_padded, start_idx, end_idx) + if self._use_orig_params: + self._use_sharded_views() + + def _init_shard_metadata( + self, + numel_padded: int, + unsharded_start_idx: int, + unsharded_end_idx: int, + ) -> None: + """ + Initialize shard-related metadata for this rank's shard of the flat parameter. + + This includes ``_sharded_size``, ``_shard_param_infos``, and ``_shard_numel_padded``. + + Args: + numel_padded (int): Numel padded for this rank's sharded flat + parameter. + unsharded_start_idx (int): Start index in the unsharded flat + parameter assigned to this rank. + unsharded_end_idx (int): End index (inclusive) in the unsharded + flat parameter assigned to this rank. + + Precondition: ``self.flat_param`` 's data is the sharded flat + parameter. + """ + flat_param = self.flat_param + flat_param._sharded_size = flat_param.size() # type: ignore[attr-defined] + sharded_flat_param_numel = flat_param.numel() # includes `numel_padded` + _p_assert( + unsharded_start_idx >= 0 and unsharded_start_idx <= unsharded_end_idx, + f"unsharded_start_idx: {unsharded_start_idx} unsharded_end_idx: {unsharded_end_idx}", + ) + _p_assert( + numel_padded <= sharded_flat_param_numel, + f"numel_padded: {numel_padded} " + f"sharded_flat_param_numel: {sharded_flat_param_numel}", + ) + shard_param_infos = self._get_shard_metadata( + unsharded_start_idx, unsharded_end_idx + ) + assert len(shard_param_infos) == flat_param._num_params, ( + f"Expects length {flat_param._num_params} but got {len(shard_param_infos)}" + ) + flat_param._shard_param_infos = shard_param_infos # type: ignore[attr-defined] + flat_param._shard_numel_padded = numel_padded # type: ignore[attr-defined] + + def _get_shard_metadata( + self, + unsharded_start_idx: int, + unsharded_end_idx: int, + ) -> tuple[_ShardParamInfo, ...]: + """ + Compute the shard metadata based on ``unsharded_start_idx`` and ``unsharded_end_idx`` (inclusive). + + ``unsharded_start_idx`` and ``unsharded_end_idx`` give the interval of the + unsharded flat parameter specifying the shard. + """ + flat_param_offsets = self._get_flat_param_offsets() + assert len(flat_param_offsets) == len(self.flat_param._numels_with_padding), ( + f"Expected {len(self.flat_param._numels_with_padding)} but got {len(flat_param_offsets)}" + ) + shard_param_infos: list[_ShardParamInfo] = [] + sharded_flat_param_numel = unsharded_end_idx - unsharded_start_idx + 1 + # `unsharded_param_start_idx` and `unsharded_param_end_idx` are indices + # into the unsharded flat parameter (inclusive) of the given parameter + for ( + (unsharded_param_start_idx, unsharded_param_end_idx), + is_padding, + ) in zip(flat_param_offsets, self.flat_param._is_padding_mask): + if is_padding: + continue + in_sharded_flat_param = ( + unsharded_start_idx <= unsharded_param_end_idx + and unsharded_end_idx >= unsharded_param_start_idx + ) + if not in_sharded_flat_param: + shard_param_info = _ShardParamInfo(False, None, None, None, None) + else: + if unsharded_start_idx <= unsharded_param_start_idx: + # This branch can only happen once since the rank's + # unsharded start index can only intersect one parameter + intra_param_start_idx = 0 + offset_in_shard = unsharded_param_start_idx - unsharded_start_idx + else: + intra_param_start_idx = ( + unsharded_start_idx - unsharded_param_start_idx + ) + offset_in_shard = 0 + assert ( + offset_in_shard >= 0 and offset_in_shard < sharded_flat_param_numel + ), ( + f"Invalid `offset_in_shard` of {offset_in_shard} for " + f"sharded flat parameter with {sharded_flat_param_numel} numel" + ) + intra_param_end_idx = ( + min(unsharded_param_end_idx, unsharded_end_idx) + - unsharded_param_start_idx + ) + numel_in_shard = intra_param_end_idx - intra_param_start_idx + 1 + shard_param_info = _ShardParamInfo( + True, + offset_in_shard, + numel_in_shard, + intra_param_start_idx, + intra_param_end_idx, + ) + shard_param_infos.append(shard_param_info) + return tuple(shard_param_infos) + + @staticmethod + def _get_unpadded_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> tuple[Tensor, int]: + """ + Return the unpadded shard of ``tensor`` for the given ``rank`` and ``world_size``. + + The returned value is a tuple of the shard of ``tensor`` without any + padding and the numel to pad for that shard. + + If ``tensor`` is already flattened or may be viewed in the flattened + shape (which is true in the expected usage), then this method does not + allocate any new tensor memory. + """ + chunks = ( + torch.flatten(tensor).chunk(world_size) + if _is_truly_contiguous(tensor) + else tensor.as_strided((tensor.numel(),), (1,)).chunk(world_size) + ) + if len(chunks) < (rank + 1): + # This rank gets an empty chunk fully padded with zeros since there + # are not enough chunks across ranks + chunk = chunks[0].new_empty(0) + else: + chunk = chunks[rank] + numel_to_pad = chunks[0].numel() - chunk.numel() + assert numel_to_pad >= 0, ( + "Chunk's size should be at most the first chunk's size" + ) + return chunk, numel_to_pad + + @staticmethod + def _get_shard( + tensor: Tensor, + rank: int, + world_size: int, + ) -> tuple[Tensor, int]: + """ + Return the shard of ``tensor`` with padding for the given ``rank`` and ``world_size`` and the numel padded for that shard. + + This method allocates new memory (via :meth:`clone`) since the + unsharded ``tensor`` may be deallocated after this method returns. + """ + chunk, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + shard = chunk.clone() + if numel_to_pad > 0: + shard = F.pad(shard, [0, numel_to_pad]) + return shard, numel_to_pad + + @staticmethod + def _get_sharded_size(tensor: Tensor, rank: int, world_size: int) -> torch.Size: + """ + Return the shape of ``tensor`` after sharding including padding. + + This requires ``tensor`` to have 1D shape and ensures that the returned + shape is 1D. + """ + assert len(tensor.shape) == 1, f"{tensor.shape}" + unpadded_sharded_tensor, numel_to_pad = FlatParamHandle._get_unpadded_shard( + tensor, rank, world_size + ) + unpadded_sharded_size = unpadded_sharded_tensor.size() + assert len(unpadded_sharded_size) == 1, f"{unpadded_sharded_size}" + return torch.Size([unpadded_sharded_size[0] + numel_to_pad]) + + def _get_flat_param_offsets(self) -> list[tuple[int, int]]: + """ + Return [start, end] offsets of each original parameter's flattened data in the unsharded flat parameter (without padding). + + NOTE: The returned list includes elements for alignment padding. + """ + cumulative_sum = list(accumulate(self.flat_param._numels_with_padding)) + starts = [0] + cumulative_sum[:-1] + ends = [end - 1 for end in cumulative_sum] # inclusive + param_offsets = list(zip(starts, ends)) + return param_offsets + + @no_type_check + def shard_metadata( + self, + ) -> FlatParamShardMetadata: + """ + Return the shard-related metadata specific to this rank's shard of the flat parameter. + + NOTE: The returned tuple does not include elements for alignment + padding but does account for the padding. + """ + fqns_list = [] + shapes_list = [] + strides_list = [] + contiguities_list = [] + numels_list = [] + shard_param_offsets = [] + for fqn, shape, stride, contiguous, numel, shard_param_info in zip( + self.flat_param._fqns, + self.flat_param._shapes, + self.flat_param._strides, + self.flat_param._contiguities, + self.flat_param._numels, + self.flat_param._shard_param_infos, + ): + if not shard_param_info.in_shard: + continue + fqns_list.append(fqn) + shapes_list.append(shape) + strides_list.append(stride) + contiguities_list.append(contiguous) + numels_list.append(numel) + shard_param_offsets.append( + ( + shard_param_info.intra_param_start_idx, + shard_param_info.intra_param_end_idx, + ) + ) + return FlatParamShardMetadata( + tuple(fqns_list), + tuple(shapes_list), + tuple(strides_list), + tuple(contiguities_list), + tuple(numels_list), + tuple(shard_param_offsets), + ) + + @no_type_check + @torch.no_grad() + def init_flat_param_attributes(self) -> None: + """ + This initializes some attributes on the handle's ``FlatParameter``. + This should be called during lazy initialization since it requires the + parameter to be on the compute device if not offloading to CPU and we + want to give users the chance to move the parameter appropriately after + the FSDP constructor. + + For each tensor attribute on the ``FlatParameter``, see the unshard and + reshard methods in this class for the allocation and free pattern. + """ + flat_param = self.flat_param + if flat_param.dtype != self._orig_param_dtype: + # Entering this branch means that the user changed the parameter + # dtype after FSDP initialization, in which case we may need to + # refresh some saved dtype attributes (dtypes specified as a part + # of mixed precision take precedence). + if not self._low_prec_param_dtype_specified: + self._fwd_bwd_param_dtype = flat_param.dtype + # For `reduce_dtype`, require `param_dtype` was not specified since + # then we infer the `reduce_dtype` from the specified `param_dtype` + if ( + not self._low_prec_reduce_dtype_specified + and not self._low_prec_param_dtype_specified + ): + self._reduce_dtype = flat_param.dtype + self._orig_param_dtype = flat_param.dtype + cpu_device = torch.device("cpu") + if self._offload_params: + _p_assert( + flat_param.device == cpu_device, + f"Expects the `FlatParameter` to be on CPU when parameter CPU " + f"offloading is enabled, not {flat_param.device}", + ) + else: + self._check_on_compute_device(self.flat_param) + flat_param._local_shard = flat_param.data + if self._offload_params: + # Pin the memory for faster H2D transfer + flat_param._local_shard = flat_param._local_shard.pin_memory() + # Pre-allocate the sharded gradient on CPU to enable non-blocking + # D2H transfer during the backward pass + flat_param._cpu_grad = torch.zeros_like( + flat_param._local_shard, device=cpu_device + ).pin_memory() + if self._uses_param_mixed_precision: + # For parameter mixed precision, we maintain a low precision + # sharded tensor on the compute device to be all-gathered (for + # sharded strategies) or directly used (for `NO_SHARD`) for + # computation. + flat_param._mp_shard = torch.empty_like( + flat_param._local_shard, + device=self.device, + dtype=self._fwd_bwd_param_dtype, + ) + _free_storage(flat_param._mp_shard) + if self.uses_sharded_strategy: + # We maintain a padded unsharded tensor that serves as the + # all-gather destination and owns the original parameter storages. + unsharded_param_dtype = ( + self._fwd_bwd_param_dtype + if self._uses_param_mixed_precision + else flat_param.dtype + ) # use low precision if parameter mixed precision is enabled + padded_unsharded_numel = flat_param.numel() * self.world_size + flat_param._full_param_padded = torch.empty( + padded_unsharded_numel, + device=self.device, + dtype=unsharded_param_dtype, + ) + flat_param._padded_unsharded_size = flat_param._full_param_padded.size() + _free_storage(flat_param._full_param_padded) + + if self._uses_param_mixed_precision: + # For parameter mixed precision, we maintain a full precision + # padded unsharded tensor for when we force full precision. + flat_param._full_prec_full_param_padded = torch.empty( + padded_unsharded_numel, + device=self.device, + dtype=flat_param.dtype, # full precision + ) + _free_storage(flat_param._full_prec_full_param_padded) + + ################### + # UNSHARD/RESHARD # + ################### + def pre_unshard(self) -> bool: + """ + Return ``False`` if this is a no-op and ``True`` otherwise. + + Postcondition: ``self.flat_param`` 's data is on the device for + communication and is what should be all-gathered. This means that it + matches the dtype of the expected unsharded parameter. + """ + if ( + self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS + and self._skipped_use_sharded_views + ): + # Since this path imposes special semantics for the unsharded flat + # parameter (e.g. forcing full precision), use sharded views to + # reuse the existing logic for that special handling + self._use_sharded_views() + ret = False + if self._use_orig_params and not self._skip_writeback_check: + ret = self._writeback_orig_params() + if ( + self.uses_sharded_strategy + and not self._offload_params + and not self.needs_unshard() + ): + pass # no-op + elif self._uses_param_mixed_precision and not self._force_full_precision: + self._use_low_precision_shard() + ret = True + elif self._offload_params and self.flat_param.device != self.device: + # NOTE: This creates a new tensor distinct from any attributes. + self.flat_param_to(self.device, non_blocking=True) + ret = True + self._check_on_compute_device(self.flat_param) + return ret + + def _use_low_precision_shard(self): + """Allocate on the compute device and switch to using the low precision sharded flat parameter.""" + self._check_low_precision_shard() + flat_param = self.flat_param + _alloc_storage( + flat_param._mp_shard, + flat_param._local_shard.size(), # type: ignore[attr-defined] + ) + # `copy_()` implicitly casts to the low precision + flat_param._mp_shard.copy_( # type: ignore[attr-defined] + flat_param._local_shard.to( # type: ignore[attr-defined] + self.device, non_blocking=True + ) + ) + # Invariant: `_mp_shard` is always on the compute device. + flat_param.data = flat_param._mp_shard # type: ignore[attr-defined] + + def unshard(self): + """ + Run the unshard logic. + + This includes all-gathering the flat parameter + and switching to using the unsharded flat parameter. If the handle does + not need unsharding, then this only switches to using the unsharded + flat parameter. For ``NO_SHARD``, this is a no-op. + + If FSDP is in :meth:`summon_full_params` and the handle uses parameter + mixed precision, then the parameter is forced to full precision. + """ + if not self.needs_unshard(): + # Even when not needing an unshard, we should switch to using + # the unsharded flat parameter + unsharded_flat_param = ( + self._get_padded_unsharded_flat_param() + if self.uses_sharded_strategy + else self.flat_param + ) + self._use_unsharded_flat_param(unsharded_flat_param) + return + unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + padded_unsharded_flat_param = self._all_gather_flat_param(unsharded_flat_param) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def needs_unshard(self) -> bool: + """Return if the handle's flat parameter needs to be unsharded.""" + if not self.uses_sharded_strategy: + return False + unsharded_flat_param = self._get_padded_unsharded_flat_param() + already_unsharded = _same_storage_size( + unsharded_flat_param, unsharded_flat_param.numel() + ) + return not already_unsharded + + def _alloc_padded_unsharded_flat_param(self): + """ + Allocate the *padded* unsharded flat parameter. + + The unpadded unsharded + flat parameter is always a view into the padded one. This padded + parameter is saved to a different attribute on the ``FlatParameter`` + depending on if we force full precision. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_storage_freed(unsharded_flat_param) + _alloc_storage(unsharded_flat_param, flat_param._padded_unsharded_size) # type: ignore[attr-defined] + return unsharded_flat_param + + def _get_padded_unsharded_flat_param(self) -> torch.Tensor: + """ + Return a reference to the padded unsharded flat parameter depending on the calling context. + + This should only be called if using a sharded strategy. + """ + self._check_sharded_strategy() + flat_param = self.flat_param + if self._force_full_precision and self._uses_param_mixed_precision: + # When parameter mixed precision is enabled, we use a different + # tensor as the all-gather destination to preserve the invariant + # that `_full_param_padded` is in the low precision + unsharded_flat_param = flat_param._full_prec_full_param_padded # type: ignore[attr-defined] + _p_assert( + unsharded_flat_param.dtype != self._fwd_bwd_param_dtype, + f"Expects full precision but got {self._fwd_bwd_param_dtype}", + ) + # For no-reshard-after-forward strategies, `_full_param_padded` may + # still be allocated from a previous forward. As we are forcing + # full precision here, the full-precision unsharded copy may be + # modified, invalidating the existing low-precision unsharded copy, + # so we should free it here to ensure a new all-gather for the next + # forward/backward computation to persist the modifications. + if flat_param._full_param_padded.untyped_storage().size() > 0: + _free_storage(flat_param._full_param_padded) + else: + unsharded_flat_param = flat_param._full_param_padded # type: ignore[attr-defined] + return unsharded_flat_param + + def _all_gather_flat_param( + self, + padded_unsharded_flat_param: Tensor, + ) -> Tensor: + """ + All-gather the handle's flat parameter to the destination ``padded_unsharded_flat_param``. + + Then switch to use the all-gathered tensor. + """ + _p_assert( + hasattr(self, "process_group") and hasattr(self, "world_size"), + "Expects a process group and world size to have been set via `shard()`", + ) + sharded_flat_param = self.flat_param.data + expected_numel = sharded_flat_param.numel() * self.world_size + _p_assert( + padded_unsharded_flat_param.numel() == expected_numel, + f"Expects {expected_numel} numel but got {padded_unsharded_flat_param.numel()}", + ) + + pg = ( + self._fake_process_group + if self._use_fake_all_gather + else self.process_group + ) + + # HACK this should be handled by C10D + if sharded_flat_param.is_cpu: # type: ignore[attr-defined] + tensor_list = list( + torch.chunk( + padded_unsharded_flat_param, + dist.get_world_size(pg), # type: ignore[arg-type] + ) + ) + dist.all_gather(tensor_list, sharded_flat_param, group=pg) + else: + dist.all_gather_into_tensor( + padded_unsharded_flat_param, + sharded_flat_param, + pg, + ) + + if self._offload_params: + # In case of offloading, `flat_param.data` (i.e. sharded param) is + # created on the pre-unshard stream. We need to hand it over to the + # unshard stream for all-gather + _no_dispatch_record_stream( + sharded_flat_param, + self._device_handle.current_stream(), # unshard_stream + ) + return padded_unsharded_flat_param + + def _use_unsharded_flat_param( + self, + padded_unsharded_flat_param: torch.Tensor, + ) -> None: + """ + Switch to use the *unpadded* unsharded flat parameter. + + This is a view into the *padded* unsharded flat parameter. + """ + unsharded_size = self.flat_param._unpadded_unsharded_size + flat_param_part = padded_unsharded_flat_param[: unsharded_size.numel()] + # slicing [:] is not visible to autograd because of .data + self.flat_param.data = flat_param_part + in_forward = self._training_state == HandleTrainingState.FORWARD + in_pre_backward = self._training_state == HandleTrainingState.BACKWARD_PRE + if self._use_orig_params: + if self._skipped_use_sharded_views and in_pre_backward: + # This call corresponds to the complementary pre-backward + # `_use_unsharded_views()` to the skipped pre-forward + # `_use_sharded_views()`, so we should skip this one too. + return + # We use `Tensor` views in the forward so that they are tracked by + # autograd. We use them in the pre-backward as well to support + # reentrant activation checkpointing, which needs the views to be + # tracked by autograd in the backward pass's recomputed forward. + self._use_unsharded_views( + as_params=(not in_forward and not in_pre_backward) + ) + elif in_forward: + self._use_unsharded_views(as_params=False) + + def post_unshard(self): + """ + Run the post-unshard logic. + + This includes freeing the low precision shard if needed. + """ + if self._uses_param_mixed_precision and self.uses_sharded_strategy: + self._free_low_precision_sharded_param() + self._check_on_compute_device(self.flat_param) + + def _free_low_precision_sharded_param(self): + """Frees the low precision sharded flat parameter.""" + self._check_low_precision_shard() + # `_mp_shard` is allocated in the pre-unshard stream, consumed in the + # unshard stream for sharded strategies, and consumed in both the + # unshard and default streams for `NO_SHARD`. For sharded strategies, + # the current stream here is the unshard stream, and for `NO_SHARD`, + # it is the default stream. For `NO_SHARD`, only recording for the + # default stream suffices since the default stream waits for the + # unshard stream. + _no_dispatch_record_stream( + self.flat_param._mp_shard, + self._device_handle.current_stream(), # type: ignore[attr-defined] + ) + _free_storage(self.flat_param._mp_shard) # type: ignore[attr-defined] + + @torch.no_grad() + def unshard_grad(self): + """ + Unshard the handle's ``FlatParameter``'s gradient. + + If all ranks have + ``None`` gradient, then all original parameters will as well. This + method performs an all-reduce and an all-gather. The additional + all-reduce is tolerable since this method is not meant to be used on + the computation critical path. + + Postcondition: ``_saved_grad_shard`` is defined and contains the value + to set ``flat_param.grad`` after gradients are resharded. + """ + if not self.uses_sharded_strategy: + self._use_unsharded_grad_views() + return + flat_param = self.flat_param + self._check_unsharded(flat_param) + + # Check if all ranks have a `None` gradient + num_grad_none = torch.zeros(1, dtype=torch.int32, device=self.device) + num_grad_none[0] = flat_param.grad is None + dist.all_reduce(num_grad_none, group=self.process_group) + if num_grad_none[0] == self.world_size: + flat_param._saved_grad_shard = None # type: ignore[assignment] + self._use_unsharded_grad_views() + return + + if flat_param.grad is None: + # In the case that only some ranks have `None` gradient, we use + # zeros to approximate as a best effort attempt + if self._debug_level == dist.DebugLevel.INFO: + warnings.warn( + f"[Rank {self.rank}] Only some but not all ranks have a " + "`None` `FlatParameter` gradient, so FSDP is using zeros to " + "approximate those ranks' sharded gradients being `None`" + ) + flat_param._saved_grad_shard = None # type: ignore[assignment] + sharded_grad = torch.zeros(flat_param._sharded_size, device=self.device) # type: ignore[attr-defined] + else: + self._check_sharded(flat_param.grad) + flat_param._saved_grad_shard = flat_param.grad # type: ignore[attr-defined] + sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + padded_unsharded_grad = torch.empty( + flat_param._padded_unsharded_size, # type: ignore[attr-defined] + device=self.device, + dtype=sharded_grad.dtype, + ) + dist.all_gather_into_tensor( + padded_unsharded_grad, sharded_grad, self.process_group + ) + unsharded_size = self.flat_param._unpadded_unsharded_size + flat_param.grad = padded_unsharded_grad[: unsharded_size.numel()].view( + unsharded_size + ) + self._use_unsharded_grad_views() + + def reshard_grad(self): + if self._use_orig_params: + self._use_sharded_grad_views() + if not self.uses_sharded_strategy: + return + self.flat_param.grad = self.flat_param._saved_grad_shard # type: ignore[attr-defined] + delattr(self.flat_param, "_saved_grad_shard") + + def prepare_gradient_for_backward(self): + """ + Prepare the gradient for the backward computation. + + This is done by saving and clearing any existing sharded gradient + in ``.grad`` to enable computing a new unsharded gradient. + """ + _p_assert( + self._training_state + in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.IDLE), + "Expects to be in `BACKWARD_PRE` or `IDLE` (if prefetching)", + ) + flat_param = self.flat_param + if flat_param.grad is not None and ( + flat_param.grad.size() != flat_param._unpadded_unsharded_size + or flat_param.grad.device != flat_param.device # grad on CPU + ): + self._check_on_compute_device(self.flat_param) + grad_offloaded = flat_param.grad.device != self.device + _p_assert( + not grad_offloaded or self._offload_params, + f"Expects the sharded gradient to be on {self.device} " + f"but got {flat_param.grad.device}", + ) + prev_iter_synced_gradients = ( + flat_param.grad.size() == flat_param._local_shard.size() # type: ignore[attr-defined] + ) + if prev_iter_synced_gradients: + # TODO (awgu): Gradient accumulation outside `no_sync()` + # does not work with CPU offloading. The issue should be + # that, in the post-backward hook, we cannot do an addition + # between a CPU tensor (the existing sharded gradient) and + # a GPU tensor (the new sharded gradient). + if not grad_offloaded: + flat_param._saved_grad_shard = flat_param.grad.data # type: ignore[attr-defined] + sharded_grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + else: + _p_assert( + hasattr(flat_param, "_cpu_grad"), + "`_cpu_grad` should be defined if the gradient is on CPU", + ) + sharded_grad = flat_param._cpu_grad # type: ignore[attr-defined] + # If user specified to keep the gradient in low precision, then + # the gradient may still be of the low precision dtype if the + # user did not set the gradient to `None` after the previous + # backward, in which case FSDP should cast back to the full + # precision dtype so that FSDP can accumulate in that dtype in + # the post-backward hook and assign to `.grad` in that dtype in + # the post-backward callback. + local_shard_dtype = flat_param._local_shard.dtype # type: ignore[attr-defined] + if ( + self._keep_low_precision_grads + and sharded_grad.dtype != local_shard_dtype + ): + sharded_grad.data = sharded_grad.to(local_shard_dtype) + else: + padded_unsharded_size = flat_param._padded_unsharded_size # type: ignore[attr-defined] + _p_assert( + flat_param.grad.size() == padded_unsharded_size, + "Expects `.grad` to be the unsharded gradient in " + f"`no_sync()` with size {padded_unsharded_size} " + f"but got size {flat_param.grad.size()}", + ) + flat_param.grad = None + + def prepare_gradient_for_optim(self): + """Prepare the gradient for optimizer computation by moving the sharded gradient to the ``.grad`` attribute.""" + + def cast_grad_to_param_dtype_if_needed(flat_param): + # TODO (rohan-varma): test for full precision with keep_low_precision_grads + if not self._force_full_precision and self._keep_low_precision_grads: + _p_assert(flat_param.grad is not None, "Unexpected None grad!") + if flat_param.grad.dtype != self._fwd_bwd_param_dtype: + flat_param.grad.data = flat_param.grad.to(self._fwd_bwd_param_dtype) + if self._use_orig_params: + self._use_sharded_grad_views() + + flat_param = self.flat_param + # TODO (awgu): We should replace these conditional checks to encode + # the logical intention more directly. + if hasattr(flat_param, "_cpu_grad"): + # NOTE: This branch includes `NO_SHARD`. + self._check_sharded(flat_param) + self._check_on_cpu(flat_param) + flat_param.grad = flat_param._cpu_grad # type: ignore[attr-defined] + cast_grad_to_param_dtype_if_needed(flat_param) + elif hasattr(flat_param, "_saved_grad_shard"): + self._check_sharded(flat_param) + self._check_on_compute_device(flat_param) + if flat_param._saved_grad_shard is not None: + self._check_on_compute_device(flat_param._saved_grad_shard) # type: ignore[attr-defined] + # If no sharded gradient was computed this iteration, then there is + # no need to forward `_saved_grad_shard` to `grad` + if flat_param._post_backward_called: # type: ignore[attr-defined] + flat_param.grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + if flat_param.grad is not None: + cast_grad_to_param_dtype_if_needed(flat_param) + else: + _p_assert( + not self.uses_sharded_strategy or not flat_param._post_backward_called, # type: ignore[attr-defined] + "All sharded parameters that received a gradient in the " + "post-backward should use `_saved_grad_shard`", + ) + # Delete `_saved_grad_shard` since its existence indicates a previous + # gradient to accumulate with in the post-backward hook + if hasattr(flat_param, "_saved_grad_shard"): + delattr(flat_param, "_saved_grad_shard") + + @contextlib.contextmanager + def to_cpu(self): + """ + Move the unpadded unsharded flat parameter to CPU while in the context and moves it back to the previous device upon exit. + + For now, this assumes the ``FlatParameter`` is the unpadded unsharded flat parameter + since (1) there is no reason to include the padding in the copy and (2) + there is no use case for the sharded flat parameter. + + Precondition: ``self.flat_param`` 's data is the unpadded unsharded + flat parameter on the compute device, and the handle uses a sharded + strategy. + Postcondition: Same as the precondition. + """ + self._check_sharded_strategy() + _p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + self._check_on_compute_device(self.flat_param) + # Check that the unpadded unsharded flat parameter is a view into the + # padded unsharded flat parameter as expected + # NOTE: This check is not strictly needed for correctness but is a + # useful sanity check since the tensor should only be used internally. + _p_assert( + _same_storage(self.flat_param, self._get_padded_unsharded_flat_param()), + "Expects the unpadded parameter to be a view into the padded parameter", + ) + self.flat_param_to(torch.device("cpu")) + self._free_unsharded_flat_param() + try: + yield + finally: + _p_assert( + self.flat_param.size() == self.flat_param._unpadded_unsharded_size, + f"Expects size {self.flat_param._unpadded_unsharded_size} but got {self.flat_param.size()}", + ) + padded_unsharded_flat_param = self._alloc_padded_unsharded_flat_param() + # Copy from CPU to the compute device + padded_unsharded_flat_param[: self.flat_param.numel()].copy_( + self.flat_param + ) + self._use_unsharded_flat_param(padded_unsharded_flat_param) + + def reshard(self, free_unsharded_flat_param: bool): + """ + Run the reshard logic. + + This includes freeing the unsharded flat + parameter if ``free_unsharded_flat_param`` and switching to using the + sharded flat parameter. Note that this also implicitly offloads + the sharded flat parameter (if CPU offload is enabled) by pointing + it to the ``_local_shard`` attribute which resides on CPU. + """ + # Switch to the sharded `FlatParameter` before freeing to prevent + # "use-after-free"-type bugs with external profiling tools, where for + # `use_orig_params=True`, the `param` does not point to valid memory + # when setting `param.data = ...` in `_use_sharded_views()`. + self._use_sharded_flat_param() + if free_unsharded_flat_param: + self._free_unsharded_flat_param() + + def post_reshard(self): + """ + Run the post-reshard logic. + + This includes freeing any memory that + can now be freed given that the ``FlatParameter`` points to the full + precision sharded flat parameter. + + Precondition: ``self.flat_param`` 's data points to the full precision + sharded flat parameter. + """ + # For `NO_SHARD`, `_mp_shard` is not freed in the post-unshard since it + # is also the low precision *unsharded* flat parameter. Hence, we delay + # the free until the reshard. + if ( + self._uses_param_mixed_precision + and not self.uses_sharded_strategy + and not self._force_full_precision # did not use the low precision shard + ): + self._free_low_precision_sharded_param() + + def _free_unsharded_flat_param(self): + """ + Free the padded unsharded flat parameter. We allow this + function to be called even when storage is not allocated + + The tensor to free depends + on the calling context since the unshard may have forced full + precision, in which case a different tensor is used. + """ + self._check_sharded_strategy() + unsharded_flat_param = self._get_padded_unsharded_flat_param() + self._check_on_compute_device(unsharded_flat_param) + # Do not free the memory until all ops in the current stream finish + _no_dispatch_record_stream( + unsharded_flat_param, self._device_handle.current_stream() + ) + _free_storage(unsharded_flat_param) + + def _use_sharded_flat_param(self) -> None: + """Switches to using the sharded flat parameter.""" + flat_param = self.flat_param + if self._use_orig_params: + in_forward = self._training_state == HandleTrainingState.FORWARD + skip_use_sharded_views = ( + torch.is_grad_enabled() + and in_forward + and self._sharding_strategy + in NO_RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + # Only incur the extra `.data` call if needed + if skip_use_sharded_views: + unsharded_flat_param = flat_param.data + if self._offload_params: + device = flat_param._local_shard.device # type: ignore[attr-defined] + _p_assert( + device == torch.device("cpu"), + f"Expects the local shard to be on CPU but got {device}", + ) + flat_param.data = flat_param._local_shard # type: ignore[attr-defined] + if self._use_orig_params: + if skip_use_sharded_views: # type: ignore[possibly-undefined] + self._unsharded_flat_param_for_skipped_views = unsharded_flat_param # type: ignore[possibly-undefined] + else: + self._use_sharded_views() + # For the post-forward reshard, we may try to use sharded gradient + # views (or unsharded gradient views if a gradient was accumulated + # in `no_sync()`), but for the post-backward reshard, we delay the + # call to after the reduce-scatter. + if ( + in_forward # type: ignore[possibly-undefined] + # Skip using gradient views if skipped using sharded views + # since exposing unsharded parameters with sharded gradients + # may be confusing to the user + and not self._skipped_use_sharded_views + ): + # TODO: Change `_unpadded_unsharded_size` if we change the + # gradient to be computed directly with padding. + accumulated_grad_in_no_sync = ( + flat_param.grad is not None + and self.uses_sharded_strategy + and flat_param.grad.shape == flat_param._unpadded_unsharded_size + ) + if accumulated_grad_in_no_sync: + self._use_unsharded_grad_views() + else: + self._use_sharded_grad_views() + + ######### + # VIEWS # + ######### + @no_type_check + def _get_unflat_views_unaligned( + self, + tensor: Optional[torch.Tensor] = None, + ) -> Iterator[Tensor]: + """ + Return unflattened ``Tensor`` views into ``tensor``. + + If `tensor`` is ``None``, ``flat_param`` is used. The unflattening is based + on ``flat_param`` 's metadata. + + Examples for ``tensor`` include ``flat_param.grad`` or unsharded + tensor optimizer state. + """ + flat_param = self.flat_param + if tensor is None: + tensor = flat_param + views = ( + _ext_post_unflatten_transform( + subtensor.view(shape) + if contiguous + else subtensor.as_strided(shape, stride), + param_extension, + self._fsdp_extension, + ) + for (subtensor, shape, stride, contiguous, param_extension) in zip( + torch.split(tensor, flat_param._numels, dim=0), + flat_param._shapes, + flat_param._strides, + flat_param._contiguities, + flat_param._param_extensions, + ) + ) + return views + + @no_type_check + def _get_unflat_views_aligned( + self, + tensor: Optional[Tensor] = None, + ) -> list[Tensor]: + """ + Return unflattened ``Tensor`` views into ``tensor`` with handling for padding. + + This method has the same contract as :meth:`_get_unflat_views_unaligned` + except it checks for ``None`` placeholders representing padding for + alignment, which may incur slightly more CPU overhead. + """ + flat_param = self.flat_param + if tensor is None: + tensor = flat_param + splits: list[Tensor] = torch.split( + tensor, flat_param._numels_with_padding, dim=0 + ) + idx = 0 + views: list[Tensor] = [] + for split, is_padding in zip(splits, flat_param._is_padding_mask): + if is_padding: + continue + views.append( + _ext_post_unflatten_transform( + split.view(flat_param._shapes[idx]) + if flat_param._contiguities[idx] + else split.as_strided( + flat_param._shapes[idx], flat_param._strides[idx] + ), + flat_param._param_extensions[idx], + self._fsdp_extension, + ) + ) + idx += 1 + return views + + @no_type_check + @torch.enable_grad() + def _use_unsharded_views(self, as_params: bool) -> None: + """ + Unflatten the unsharded flat parameter by setting the original parameter variables to be views into it. + + Args: + as_params (bool): If ``True``, then registers the original + parameters as ``nn.Parameter`` s; if ``False``, then registers + the original parameters only as ``Tensor`` s. ``False`` should + be used during forward/backward computation and when hiding the + original parameters from :meth:`nn.Module.named_parameters`. + + Note: + when prefetching for next forward, current forward may be + annotated with `@torch.no_grad()` + `@torch.enable_grad()` ensures non-empty `view.grad_fn` + otherwise `_post_backward_hook` will not get called + """ + flat_param = self.flat_param + self._check_unsharded(flat_param) + views = self._get_unflat_views() + from torch.distributed.tensor import DTensor + + for i, (view, (param_name, module, _)) in enumerate( + zip(views, flat_param._param_infos) + ): + if self._use_orig_params and as_params: + if type(view) is DTensor: + # A `DTensor` `view` is not compatible with assigning + # `param.data = view`, so we cannot preserve the parameter + # variable. + self._setattr_param( + module, + param_name, + nn.Parameter(view, requires_grad=flat_param.requires_grad), + ) + continue + param = self.flat_param._params[i] + self._setattr_param(module, param_name, param) + param.data = view + elif as_params: + self._setattr_param( + module, + param_name, + nn.Parameter(view, requires_grad=flat_param.requires_grad), + ) + else: # `as_params=False` + param_var: Tensor = view + if self._use_orig_params: + if self._training_state == HandleTrainingState.FORWARD: + # Save the `Tensor` for the pre-backward + self.flat_param._tensors[i] = view # save for pre-backward + elif self._training_state == HandleTrainingState.BACKWARD_PRE: + # Use the saved `Tensor` variable from the forward to + # preserve the autograd graph so that the post-backward + # hook fires (e.g. for reentrant AC) + tensor = self.flat_param._tensors[i] + tensor.data = view + param_var = tensor + self._setattr_tensor(module, param_name, param_var) + if ( + self._use_orig_params + and self._training_state == HandleTrainingState.FORWARD + ): + module._parameters[param_name] = param_var + for i, ( + param_name, + module, + _, + prim_param_name, + prim_module, + _, + ) in enumerate(self.flat_param._shared_param_infos): + prim_param: Union[Tensor, nn.Parameter] = getattr( + prim_module, prim_param_name + ) + _p_assert( + not as_params or isinstance(prim_param, nn.Parameter), + f"as_params={as_params} type(prim_param)={type(prim_param)}", + ) + if self._use_orig_params and as_params: + shared_param = self.flat_param._shared_params[i] + self._setattr_param(module, param_name, shared_param) + shared_param.data = prim_param + elif as_params: + self._setattr_param(module, param_name, prim_param) + else: + self._setattr_tensor(module, param_name, prim_param) + if ( + self._use_orig_params + and self._training_state == HandleTrainingState.FORWARD + ): + module._parameters[param_name] = prim_param + + @no_type_check + def _use_unsharded_grad_views(self) -> None: + """ + Unflatten the unsharded flat parameter's gradient. + + The original parameter variables' gradients are set to be views into + the unsharded flat parameter's gradient. + """ + # Expects the gradient to be in `flat_param.grad` + if self.flat_param.grad is None: + for param in chain(self.flat_param._params, self.flat_param._shared_params): + param.grad = None + return + self._check_unsharded(self.flat_param.grad) + views = self._get_unflat_views(self.flat_param.grad) + for i, (view, (param_name, module, _)) in enumerate( + zip(views, self.flat_param._param_infos) + ): + _p_assert( + hasattr(module, param_name), + f"{self.flat_param._fqns[i]} is missing", + ) + param = getattr(module, param_name) + if ( + param.shape != view.shape + or param.dtype != view.dtype + or param.device != view.device + ): + # NOTE: This is a hack using `.data` to side step the check + # that parameter/gradient sizes/dtypes/devices match. From + # calling `reshard()`, `param` has the sharded size, has the + # full precision dtype, and if CPU offloading is enabled, is on + # CPU. Thus, one or more of the following cases can hold when + # in `no_sync()`, where `view` is the original parameter's + # gradient: + # 1. `view` can have the unsharded size. + # 2. `view` can have the parameter low precision dtype. + # 3. `view` can be on GPU. + if param.grad is None: + param.grad = torch.empty_like(param) + param.grad.data = view + else: + param.grad = view + for i, ( + param_name, + module, + module_name, + prim_param_name, + prim_module, + _, + ) in enumerate(self.flat_param._shared_param_infos): + _p_assert( + hasattr(module, param_name), + f"{module_name + '.' + param_name if module_name else param_name} is missing", + ) # did not save FQN info in `_shared_param_infos` + param = getattr(module, param_name) + prim_param = getattr(prim_module, prim_param_name) + if ( + param.shape != prim_param.grad.shape + or param.dtype != prim_param.grad.dtype + or param.device != prim_param.grad.device + ): + # NOTE: This is the same hack to use `.data` to side step the + # size check. + if param.grad is None: + param.grad = torch.empty_like(param) + param.grad.data = prim_param.grad + else: + param.grad = prim_param.grad + + @contextlib.contextmanager + def unflatten_as_params(self) -> Generator: + """ + Unflatten the original parameters. + + The function assumes that the flat parameter is unsharded. When in the context, + unflattens the original parameters as ``nn.Parameter`` views into the + flat parameter, and after the context, restores the original parameters + as ``Tensor`` views into the flat parameter. + """ + self._use_unsharded_views(as_params=True) + try: + yield + finally: + self._use_unsharded_views(as_params=False) + + @no_type_check + @torch.no_grad() + def _use_sharded_views(self) -> None: + """ + Set the original parameter variables' data to be flattened views into the sharded flat parameter. + + The views are kept as flattened to simplify the case where a parameter + is sharded across ranks. Parameters whose data is not present in the + sharded flat parameter have their data set to a size-0 empty tensor. We + do not delete them to ensure to preserve expected behaviors like model + printability. Parameters whose data is present must preserve their + variables to be passable to an optimizer. + """ + self._unsharded_flat_param_for_skipped_views = None + if not self.uses_sharded_strategy: + # For `NO_SHARD`, use the *unflattened* unsharded views since we + # have the unsharded parameter + self._use_unsharded_views(as_params=True) + return + flat_param = self.flat_param + self._check_sharded(flat_param) + # Construct once and reuse for all parameters not in the local shard + size_0_empty_tensor = torch.empty( + 0, + dtype=self.flat_param.dtype, # in case `flat_param` changed dtype + device=self.flat_param.device, + requires_grad=False, + ) + for param, shard_param_info, (param_name, module, _) in zip( + flat_param._params, flat_param._shard_param_infos, flat_param._param_infos + ): + self._setattr_param(module, param_name, param) + if not shard_param_info.in_shard: + # Allow the original data to be freed via garbage collection + param.data = size_0_empty_tensor + else: + offset = shard_param_info.offset_in_shard + numel_in_shard = shard_param_info.numel_in_shard + param.data = flat_param[offset : offset + numel_in_shard] + assert self.flat_param._shared_params is not None + for i, ( + param, + (param_name, module, _, prim_param_name, prim_module, _), + ) in enumerate( + zip(self.flat_param._shared_params, self.flat_param._shared_param_infos) + ): + self._setattr_param(module, param_name, param) + prim_param = getattr(prim_module, prim_param_name) + param.data = prim_param # could be both empty and non-empty + if self._training_state == HandleTrainingState.BACKWARD_POST: + # Clear the saved `Tensor`s since they are unneeded now + for i in range(len(self.flat_param._tensors)): + self.flat_param._tensors[i] = None + + @no_type_check + @torch.no_grad() + def _use_sharded_grad_views(self) -> None: + """ + Set the original parameter variables' gradients to be flattened views into the sharded flat parameter's gradient. + + This is a no-op if there is no gradient. + + Parameters whose data is not present in the sharded flat parameter and + parameters with ``requires_grad=False`` have their gradients set to + ``None``. Since the gradient variables do not need to be preserved, + this method does not manipulate existing ``Tensor`` data directly and + creates new ``Tensor`` variables instead. + """ + flat_param = self.flat_param + self._check_sharded(flat_param) + grad = self.sharded_grad + if grad is None: + for param in chain(flat_param._params, flat_param._shared_params): + param.grad = None + return + self._check_sharded(grad) + for param, shard_param_info, is_grad_none in zip( + flat_param._params, + flat_param._shard_param_infos, + flat_param._is_grad_none_mask, + ): + if not shard_param_info.in_shard: + param.grad = None + else: + numel_in_shard = shard_param_info.numel_in_shard + if param.requires_grad and not is_grad_none: + offset = shard_param_info.offset_in_shard + if self._keep_low_precision_grads or param.dtype != grad.dtype: + # NOTE: This is a hack using `.data` to side step the + # check that parameter/gradient dtypes match. Here, + # `param` has full precision; `grad` has low precision. + if param.grad is None: + # `.grad` must have the same shape as `param` + param.grad = torch.empty_like(param) + param.grad.data = grad[ + offset : offset + numel_in_shard + ].reshape(param.shape) + else: + param.grad = grad[offset : offset + numel_in_shard].reshape( + param.shape + ) + else: + param.grad = None + assert flat_param._shared_params is not None + for param, (_, _, _, prim_param_name, prim_module, _) in zip( + flat_param._shared_params, flat_param._shared_param_infos + ): + in_sharded_flat_param = hasattr(prim_module, prim_param_name) + if in_sharded_flat_param and param.requires_grad: + prim_param = getattr(prim_module, prim_param_name) + param.grad = prim_param.grad # share the same reference + else: + param.grad = None + + @no_type_check + @torch.no_grad() + def _writeback_orig_params(self) -> bool: + """ + Write back any parameters that changed storage to the handle's ``FlatParameter``. + + Iterates over the original parameters and writes back any parameters + that changed storages (due to a non-inplace operator) to the handle's + ``FlatParameter``. This method preserves the ``FlatParameter` 's + device even if an original parameter's device changes. + + Raises: + RuntimeError: If an original parameter or gradient changes storages + but no longer has the expected flattened shape. + Returns: ``True`` if some writeback happened, and ``False`` otherwise. + """ + if ( + self.uses_sharded_strategy + and not self.is_sharded(self.flat_param) + and not self._skipped_use_sharded_views + ): + # For `NO_SHARD`, we may still need to writeback + return False + flat_param = self.flat_param + wroteback = False + if self._skipped_use_sharded_views and self.uses_sharded_strategy: + # NOTE: We must use the unsharded flat parameter from which the + # unsharded views were computed, not the one from the current + # calling context (`_get_padded_unsharded_flat_param()`) since that + # may be different (e.g. the model changed from train to eval). + flat_param_tensor = self._unsharded_flat_param_for_skipped_views + _p_assert( + _data_ptr_allocated(flat_param_tensor), + "If skipped using sharded views, the unsharded flat parameter " + "should be allocated", + ) + else: + flat_param_tensor = flat_param + # NOTE: Since this method is called in the pre-unshard, which is only + # called during computation in the pre-forward or pre-backward, the + # sharded gradient should be guaranteed to be in `.grad`, not in + # `._saved_grad_shard`. + flat_param_grad = ( + flat_param.grad + if self.uses_sharded_strategy or not self._offload_params + else flat_param._cpu_grad + ) + for i, ( + param, + (in_shard, offset_in_shard, numel_in_shard, _, _), + (param_name, module, _), + ) in enumerate( + zip( + flat_param._params, + flat_param._shard_param_infos, + flat_param._param_infos, + ) + ): + if not in_shard: + continue + if not hasattr(module, param_name): + # Do not writeback if original parameters are deregistered + # (e.g. during model checkpointing) + continue + + # Check for parameter writeback + if self._skipped_use_sharded_views: + param = flat_param._tensors[i] + _p_assert( + param is not None, + f"Expects to have saved tensor for {flat_param._fqns[i]}", + ) + param_changed = getattr(module, param_name) is not param + needs_param_writeback = ( + param_changed # changed parameter variable itself + or not _same_storage(param, flat_param_tensor) + ) + if self._skipped_use_sharded_views and ( + param_changed or needs_param_writeback + ): + raise AssertionError( + "FSDP does not support changing the parameters between " + f"forward and backward for {self._sharding_strategy}" + ) + if param_changed: + # NOTE: The gradient is not preserved after a parameter change. + param = getattr(module, param_name) + flat_param._params[i] = param + if needs_param_writeback: + expected_shape = torch.Size([numel_in_shard]) + self._writeback_tensor( + param, flat_param, i, expected_shape, offset_in_shard, True + ) + wroteback = True + + # Check for gradient writeback + if self._skipped_use_sharded_views: + # Skip the writeback check because we do not expose gradients + # when we skipped using sharded views + continue + if param.grad is None and flat_param.grad is not None: + expected_shape = torch.Size([numel_in_shard]) + self._writeback_tensor( + None, flat_param.grad, i, expected_shape, offset_in_shard, False + ) + elif param.grad is not None: + # For `NO_SHARD` + CPU offloading, `_cpu_grad` is always in + # memory and owns the gradient storage, so it will never + # require gradient writeback. + if not self.uses_sharded_strategy and self._offload_params: + # Explicitly continue to handle the case of `no_sync()`, + # where `param.grad` is a view into the GPU gradient + # referenced by `flat_param.grad`, while `flat_param_grad` + # is `flat_param._cpu_grad`, which is on CPU + continue + + needs_grad_writeback = flat_param_grad is None or not _same_storage( + param.grad, flat_param_grad + ) + if needs_grad_writeback: + if flat_param_grad is None: + flat_param_grad = torch.zeros_like(flat_param) + expected_shape = torch.Size([numel_in_shard]) + self._writeback_tensor( + param.grad, + flat_param_grad, + i, + expected_shape, + offset_in_shard, + False, + ) + flat_param.grad = flat_param_grad + flat_param_grad = flat_param.grad + + # TODO: If we want to handle shared parameters, we need to re-generate + # the shared parameter data structures in case sharedness changed. + for i, ( + param_name, + module, + _, + prim_param_name, + prim_module, + _, + ) in enumerate(flat_param._shared_param_infos): + if getattr(module, param_name) is not getattr(prim_module, prim_param_name): + raise NotImplementedError( + "Changing shared parameters is not supported yet" + ) + return wroteback + + def _writeback_tensor( + self, + src_tensor: Optional[Tensor], + dst_tensor: Tensor, + tensor_index: int, + expected_shape: torch.Size, + offset: int, + is_param: bool, # else gradient + ) -> None: + """ + Write back ``src_tensor`` to ``dst_tensor`` at offset ``offset``, where ``src_tensor`` should have shape ``expected_shape``. + + ``is_param`` indicates if the tensor is the parameter (if ``True``) or gradient (if + ``False``). If ``src_tensor`` is ``None``, then the effect is zeroing + instead of copying. ``tensor_index`` gives the index of ``src_tensor`` + in the metadata structures. + + Raises: + RuntimeError: If the ``src_tensor`` does not have the expected + shape. + """ + _p_assert( + len(expected_shape) == 1, + f"Expects a 1D expected shape but got {expected_shape}", + ) + if self._debug_level == dist.DebugLevel.INFO: + rank = self.rank if hasattr(self, "rank") else dist.get_rank() + src_shape = src_tensor.shape if src_tensor is not None else None + src_device = src_tensor.device if src_tensor is not None else None + warnings.warn( + f"[Rank {rank}] {'Parameter' if is_param else 'Gradient'} needs " + f"writeback in {self._training_state}\n" + f"expected shape={expected_shape} shape={src_shape} " + f"expected device={dst_tensor.device} device={src_device}" + ) + if src_tensor is not None and src_tensor.shape != expected_shape: + # NOTE: Gradient shape mismatch is not possible in practice since + # the gradient shape is enforced to match that of the parameter and + # we already check for parameter shape mismatch. + raise RuntimeError( + f"Cannot writeback when the {'parameter' if is_param else 'gradient'} " + f"shape changes\nExpects {expected_shape} but got {src_tensor.shape}" + ) + if src_tensor is not None: + dst_tensor[offset : offset + expected_shape.numel()].copy_(src_tensor) + else: + dst_tensor[offset : offset + expected_shape.numel()].zero_() + assert self.flat_param._is_grad_none_mask is not None + self.flat_param._is_grad_none_mask[tensor_index] = True + + def _reset_flat_param_grad_info_if_needed(self): + """ + Reset ``flat_param.grad`` if needed. + + When ``use_orig_params=True``: + (1) sets the underlying ``flat_param.grad`` to ``None`` if *all* of the + original parameters' ``.grad`` are ``None``, and + (2) sets ``flat_param.requires_grad=False`` if *none* of the original + parameters require gradient. + For (1), this is targeting ``optim.zero_grad(set_to_none=True)``, in + which case we want to free the gradients as soon after the + ``zero_grad()`` call as possible. + """ + if not self._use_orig_params: + return + flat_param = self.flat_param + assert flat_param._params is not None # mypy + all_grad_none = True + requires_grad = False + for param in flat_param._params: + all_grad_none &= param.grad is None + requires_grad |= param.requires_grad + if all_grad_none: + flat_param.grad = None + # As long as one parameter requires gradient, then the flat parameter + # must require gradient + flat_param.requires_grad = requires_grad + + def _deregister_orig_params(self): + for param_info in self.flat_param._param_infos: + param_name, module, _ = param_info + if hasattr(module, param_name): + delattr(module, param_name) + for param_name, module, _, _, _, _ in self.flat_param._shared_param_infos: + if hasattr(module, param_name): + delattr(module, param_name) + + ########### + # HELPERS # + ########### + def flat_param_to(self, *args, **kwargs): + """Wrap an in-place call to ``.to()`` for ``self.flat_param``.""" + self.flat_param.data = self.flat_param.to(*args, **kwargs) + if self._use_orig_params: + # Refresh the views because their storage may have changed + if self.is_sharded(self.flat_param): + self._use_sharded_views() + else: + self._use_unsharded_views(as_params=True) + + def _get_modules(self) -> set[nn.Module]: + """Return a :class:`set` of the modules whose parameters are included in this handle's flat parameter.""" + return {pi.module for pi in self.flat_param._param_infos}.union( + {spi.module for spi in self.flat_param._shared_param_infos} + ) + + def is_sharded(self, tensor: Tensor) -> bool: + """ + Return whether ``tensor`` is *currently* sharded. + + For ``NO_SHARD``, we choose to have this always return ``False`` for clarity. + """ + if ( + not hasattr(self.flat_param, "_sharded_size") + or not self.uses_sharded_strategy + ): + # `_sharded_size` is defined iff `handle.shard()` has been called + return False + sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] + return tensor.size() == sharded_size + + def param_module_names(self) -> Iterator[tuple[str, str]]: + shared_param_infos = [ + ParamInfo(param_name, module, module_name) + for ( + param_name, + module, + module_name, + _, + _, + _, + ) in self.flat_param._shared_param_infos + ] + for param_info in chain(self.flat_param._param_infos, shared_param_infos): + param_name, _, module_name = param_info # type: ignore[misc] + yield (param_name, module_name) + + def shared_param_module_names(self) -> Iterator[tuple[str, str]]: + for param_name, _, module_name in [ + ParamInfo(param_name, module, module_name) + for ( + param_name, + module, + module_name, + _, + _, + _, + ) in self.flat_param._shared_param_infos + ]: + yield (param_name, module_name) + + @property + def _fqns_in_shard(self) -> list[str]: + """Return the FQNs of the parameters present in this rank's shard.""" + fqns_in_shard: list[str] = [] + for fqn, shard_param_info in zip( + self.flat_param._fqns, + self.flat_param._shard_param_infos, # type: ignore[attr-defined] + ): + if shard_param_info.in_shard: + fqns_in_shard.append(fqn) + return fqns_in_shard + + @property + def sharded_grad(self) -> Optional[Tensor]: + """Return the handle's sharded gradient.""" + flat_param = self.flat_param + # Priority for non-`None`: `_cpu_grad` > `_saved_grad_shard` > `grad` + # - CPU offloading: `_cpu_grad` + # - No CPU offloading + sharded strategies: `_saved_grad_shard` + # - No CPU offloading + `NO_SHARD`: `grad` + grad: Optional[Tensor] + if hasattr(flat_param, "_cpu_grad"): + grad = flat_param._cpu_grad # type: ignore[attr-defined] + elif hasattr(flat_param, "_saved_grad_shard"): + # In the post-backward hook, the sharded gradient is still in + # `_saved_grad_shard`. + grad = flat_param._saved_grad_shard # type: ignore[attr-defined] + else: + # If in IDLE or in FORWARD states, then there may be an + # (accumulated) gradient. If accessed in IDLE, then this should + # be due to re-registering the original parameters (e.g. in state + # dict load). + _p_assert( + flat_param.grad is None + or not self.uses_sharded_strategy + or self._training_state + in (HandleTrainingState.FORWARD, HandleTrainingState.IDLE), + "Sharded strategies should use `_cpu_grad` or `_saved_grad_shard` " + "unless in IDLE or FORWARD", + ) + grad = flat_param.grad + return grad + + def _reset_is_grad_none(self) -> None: + """ + Reset ``_is_grad_none_mask`` as needed. + + This method should only be + called in the post-backward after gradient computation, in which case + if a parameter requires gradient, then it will surely receive a + gradient and we may reset its mask entry to ``False``. + """ + if not self._use_orig_params: + return + _p_assert( + self._training_state == HandleTrainingState.BACKWARD_POST, + "Expects to only be called in the post-backward after gradient computation", + ) + flat_param = self.flat_param + assert flat_param._params is not None # mypy + for i, param in enumerate(flat_param._params): # type: ignore[arg-type] + # As long as the parameter requires gradient, it should receive a + # meaningful gradient (even if the gradient happens to be zeros) + if param.requires_grad: + assert flat_param._is_grad_none_mask is not None # mypy + flat_param._is_grad_none_mask[i] = False + + ####################### + # CHECKS & INVARIANTS # + ####################### + def _check_sharded_strategy(self): + _p_assert(self.uses_sharded_strategy, "Expects sharded strategy") + + def _check_on_compute_device(self, tensor: Tensor): + _p_assert( + tensor.device == self.device, + f"Expects tensor to be on the compute device {self.device}, was on {tensor.device}", + ) + + def _check_on_cpu(self, tensor: Tensor): + _p_assert( + tensor.device == torch.device("cpu"), + f"Expects tensor to be on CPU but got {tensor.device}", + ) + + @staticmethod + def _check_storage_freed(tensor: Tensor): + # Compile does not resize during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + _p_assert( + _same_storage_size(tensor, 0), + "Expects storage to be freed but got storage with size > 0", + ) + + @staticmethod + def _check_storage_allocated(tensor: Tensor): + _p_assert(_storage_size_allocated(tensor), "Expects storage to be allocated") + + def _check_low_precision_shard(self): + _p_assert( + self._uses_param_mixed_precision, + "Not using low precision for parameters", + ) + _p_assert( + getattr(self.flat_param, "_mp_shard", None) is not None, + "Expects `_mp_shard` to exist", + ) + device = self.flat_param._mp_shard.device # type: ignore[attr-defined] + _p_assert( + device == self.device, + f"Expects the low precision shard to be on {self.device} but got {device}", + ) + + def _check_unsharded(self, tensor: Tensor): + msg_prefix = "Expects tensor to be unsharded " + _p_assert(tensor is not None, msg_prefix + "but got `None`") + unsharded_size = self.flat_param._unpadded_unsharded_size + _p_assert( + tensor.size() == unsharded_size, + msg_prefix + f"with size {unsharded_size} but got {tensor.size()}", + ) + + def _check_sharded(self, tensor: Tensor): + msg_prefix = "Expects tensor to be sharded " + _p_assert(tensor is not None, msg_prefix + "but got `None`") + sharded_size = self.flat_param._sharded_size # type: ignore[attr-defined] + _p_assert( + tensor.size() == sharded_size, + msg_prefix + f"with size {sharded_size} but got {tensor.size()}", + ) + + ############## + # PROPERTIES # + ############## + @property + def uses_sharded_strategy(self) -> bool: + return self._sharding_strategy != HandleShardingStrategy.NO_SHARD + + @property + def _uses_param_mixed_precision(self) -> bool: + return self._fwd_bwd_param_dtype != self._orig_param_dtype + + @property + def _uses_reduce_mixed_precision(self) -> bool: + return self._reduce_dtype != self._orig_param_dtype + + @property + def _force_full_precision(self) -> bool: + return ( + self._uses_param_mixed_precision or self._uses_reduce_mixed_precision + ) and ( + self._training_state == HandleTrainingState.SUMMON_FULL_PARAMS + or + # Also disable mixed precision in model eval mode, if configured + (not self._fully_sharded_module.training and self._use_full_prec_in_eval) + ) + + @property + def _skipped_use_sharded_views(self) -> bool: + """ + This property is used for sharding strategies that do not free after forward with ``use_orig_params=True``. + + This returns if this handle is + currently in a state where it has skipped using sharded views, in which + case it can restore view invariants via ``_use_sharded_views()``. + """ + return self._unsharded_flat_param_for_skipped_views is not None + + +# NOTE: These are hacks to bypass `nn.Module.__setattr__` checks. +def _unsafe_setattr_param( + module: nn.Module, param_name: str, param: nn.Parameter +) -> None: + module._parameters[param_name] = param + # This bypasses any overrides in case `module` is an instance of an + # `nn.Module` subclass + super(nn.Module, module).__setattr__(param_name, param) + + +def _unsafe_setattr_tensor(module: nn.Module, param_name: str, tensor: Tensor) -> None: + module._parameters.pop(param_name, None) + # This bypasses any overrides in case `module` is an instance of an + # `nn.Module` subclass + super(nn.Module, module).__setattr__(param_name, tensor) + + +def _safe_setattr_tensor_or_param( + module: nn.Module, param_name: str, tensor_or_param: Union[Tensor, nn.Parameter] +): + # Call `delattr()` and `setattr()` to go through `nn.Module` checks + if hasattr(module, param_name): + delattr(module, param_name) + setattr(module, param_name, tensor_or_param) + + +def _convert_to_params( + tensors: list[Union[torch.Tensor, nn.Parameter]], +) -> list[nn.Parameter]: + return [t if isinstance(t, nn.Parameter) else nn.Parameter(t) for t in tensors] + + +def _is_truly_contiguous(x: Tensor) -> bool: + # Special case: Pytorch thinks that 1x1 channels_last convolution weights are + # both contiguous and channels_last contiguous at the same time. + # CuDNN does not agree though and refuses to select faster kernels. + # It is the reason of having the extra check here. + return x.stride(-1) == 1 and x.is_contiguous() + + +def _detach_if_needed(param_or_tensor: Union[nn.Parameter, Tensor]) -> Tensor: + return ( + param_or_tensor.detach() + if isinstance(param_or_tensor, nn.Parameter) + else param_or_tensor + ) + + +def _get_aligned_numel(unsharded_dtype: torch.dtype): + # NOTE: This alignment constraint comes from TorchInductor. + ALIGNMENT = 16 # bytes + unsharded_dtype_size = _get_dtype_size(unsharded_dtype) + aligned_numel = ALIGNMENT // unsharded_dtype_size + return aligned_numel + + +@functools.lru_cache(8) +def _get_dtype_size(dtype): + return torch.empty((), dtype=dtype).element_size() + + +def _construct_padding_tensor( + padding_numel: int, dtype: torch.dtype, requires_grad: bool, device: torch.device +): + # NOTE: Set the padding value as a magic number for debuggability. The + # value itself should never be used in any user-facing computation. + return ( + torch.ones( + (padding_numel,), dtype=dtype, requires_grad=requires_grad, device=device + ) + * _FLAT_PARAM_PADDING_VALUE + ) + + +# Use `lru_cache(1)` to only log the warning once (assuming the fixed warning +# message is passed in) +@functools.lru_cache(1) +def _warn_skip_writeback_check(log: logging.Logger, warning: str): + logger.warning(warning) + + +# Use `lru_cache(1)` to only log the warning once +@functools.lru_cache(1) +def _warn_use_fake_all_gather(log: logging.Logger, warning: str): + logger.warning(warning) + + +# Use `lru_cache(1)` to only log the warning once +@functools.lru_cache(1) +def _warn_use_fake_reduce(log: logging.Logger, warning: str): + logger.warning(warning) + + +def _same_storage(a, b): + # Params are DTensors in backward + # with SHARD_GRAD_OP + TP + from torch.distributed.tensor import DTensor + + if isinstance(a, DTensor): + a = a._local_tensor + if isinstance(b, DTensor): + b = b._local_tensor + return a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr() + + +def _same_storage_size(a: torch.Tensor, b: int): + return a.untyped_storage().size() // a.element_size() == b + + +def _storage_size_allocated(tensor: Tensor): + storage_size: int = tensor.untyped_storage().size() + return storage_size > 0 diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fsdp_extensions.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fsdp_extensions.py new file mode 100644 index 0000000000000000000000000000000000000000..25730da77febd0c4737ae6740e8104561f681bcd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fsdp_extensions.py @@ -0,0 +1,179 @@ +from abc import ABC, abstractmethod +from typing import Any, Optional + +import torch +import torch.distributed as dist +from torch.distributed._shard.sharded_tensor.api import ShardedTensor +from torch.distributed._shard.sharded_tensor.shard import Shard +from torch.distributed.fsdp._shard_utils import ( + _all_gather_dtensor, + _create_chunk_dtensor, + _create_chunk_sharded_tensor, +) +from torch.distributed.tensor import DeviceMesh, DTensor + + +class FSDPExtensions(ABC): + """ + This enables some customizable hooks to enable composability with tensor + parallelism. To activate these hooks, use :func:`_set_fsdp_extensions` to + set a custom :class:`FSDPExtensions` that implements the hooks. + """ + + @abstractmethod + def pre_flatten_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, Optional[Any]]: + """E.g. converting ``DistributedTensor`` to local tensor.""" + ... + + @abstractmethod + def post_unflatten_transform( + self, + tensor: torch.Tensor, + param_extension: Any, + ) -> torch.Tensor: + """E.g. converting local tensor to ``DistributedTensor``.""" + ... + + @abstractmethod + def chunk_tensor( + self, + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + """Shards a tensor to chunks and returns the local chunk.""" + ... + + @abstractmethod + def chunk_dtensor( + self, + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + ) -> torch.Tensor: + """Shards a tensor/DTensor to DTensor and returns the local DTensor.""" + ... + + @abstractmethod + def pre_load_state_dict_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, list[Shard]]: + """ + This is to be called before loading a *sharded* model state dict and + should return the tensor and list of shards from which to load data. + """ + ... + + @abstractmethod + def all_gather_dtensor( + self, + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + ) -> torch.Tensor: + """ + This is to be called before loading a *sharded* DTensor state dict. + This gathers tensor in FSDP dimension and returns local tensor of + TP DTensor. + """ + ... + + +_extensions: Optional[FSDPExtensions] = None + + +def _set_fsdp_extensions(flattener: FSDPExtensions) -> None: + global _extensions + _extensions = flattener + + +def _ext_pre_flatten_transform( + tensor: torch.Tensor, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> tuple[torch.Tensor, Optional[Any]]: + if fsdp_extension is not None: + new_tensor, param_extension = fsdp_extension.pre_flatten_transform(tensor) + if param_extension is not None: + return new_tensor, param_extension + return tensor, None + + +def _ext_post_unflatten_transform( + tensor: torch.Tensor, + param_extension: Any, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + if fsdp_extension is not None and param_extension is not None: + return fsdp_extension.post_unflatten_transform(tensor, param_extension) + return tensor + + +def _ext_chunk_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + chunk_tensor_fn = ( + fsdp_extension.chunk_tensor + if fsdp_extension is not None + else _create_chunk_sharded_tensor + ) + return chunk_tensor_fn( + tensor, + rank, + world_size, + num_devices_per_node, + pg, + ) + + +def _ext_chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + chunk_dtensor_fn = ( + fsdp_extension.chunk_dtensor + if fsdp_extension is not None + else _create_chunk_dtensor + ) + return chunk_dtensor_fn( + tensor, + rank, + device_mesh, + ) + + +def _ext_pre_load_state_dict_transform( + tensor: torch.Tensor, + fsdp_extension: Optional[FSDPExtensions] = None, +) -> tuple[torch.Tensor, list[Shard]]: + if fsdp_extension is not None: + return fsdp_extension.pre_load_state_dict_transform(tensor) + + assert type(tensor) is ShardedTensor + shards = tensor.local_shards() + return (tensor, shards) + + +def _ext_all_gather_dtensor( + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + fsdp_extension: Optional[FSDPExtensions] = None, +) -> torch.Tensor: + all_gather_dtensor_fn = ( + fsdp_extension.all_gather_dtensor + if fsdp_extension is not None + else _all_gather_dtensor + ) + return all_gather_dtensor_fn(tensor, parent_mesh) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7aa0e06a4545e3c9134c19e026b23421bb6526a6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__init__.py @@ -0,0 +1,18 @@ +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fully_shard import ( + FSDPModule, + fully_shard, + register_fsdp_forward_method, + UnshardHandle, +) + + +__all__ = [ + "CPUOffloadPolicy", + "FSDPModule", + "fully_shard", + "MixedPrecisionPolicy", + "OffloadPolicy", + "register_fsdp_forward_method", + "UnshardHandle", +] diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..187b045242398c84be0998c75e10a0b330f0bb74 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..835ce847709b3815044107636552e7c8df7ac95b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3df00025200a47d737768ef7507867a23c500b0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_collectives.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c77c1cc04590e83b33835cb20fa5d10425084b16 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b5cbba08e2b782a6ceff6ba15ec48d1a07cab49 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_init.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..862ed5e45b4b60ec9d207e49d66c01408223cc8c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aae25f2fca423e1b8948b5c9c146c2c032e5f0c3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_param_group.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..454155ca18ed61fc188f3be87fe21a8df5a15302 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fsdp_state.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff30f384530e16aa5c173d41515f729339d05eb2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/__pycache__/_fully_shard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py new file mode 100644 index 0000000000000000000000000000000000000000..9ad802e57b97ee7675b2b1bf4f6362c856a2cbc3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_api.py @@ -0,0 +1,75 @@ +# mypy: allow-untyped-defs +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass(frozen=True) +class MixedPrecisionPolicy: + """ + This configures FSDP's mixed precision. Unlike autocast, this applies mixed + precision at the module level, not op level, which means low-precision + activations are saved for backward and high-to-low-precision casts are + incurred only at module boundaries. + + FSDP works well with module-level mixed precision since it keeps the + high-precision sharded parameters in memory anyway. In other words, FSDP + does not require any extra memory to keep a high-precision copy of the + parameters for the optimizer step. + + Attributes: + param_dtype (Optional[torch.dtype]): This specifies the dtype for + the unsharded parameter and hence the dtype for forward/backward + computation and the parameter all-gather. If this is ``None``, then + the unsharded parameter uses the original dtype. The optimizer step + uses the sharded parameter in the original dtype. (Default: + ``None``) + reduce_dtype (Optional[torch.dtype]): This specifies the dtype for + gradient reduction (i.e. reduce-scatter or all-reduce). If this is + ``None`` but ``param_dtype`` is not ``None``, then the reduction + uses the compute dtype. This can be used to run gradient reduction + in full precision while using low precision for compute. If also + gradient reduction is disabled via :meth:`set_requires_gradient_sync`, + then FSDP will accumulate gradients using ``reduce_dtype``. + (Default: ``None``) + output_dtype (Optional[torch.dtype]): This specifies the dtype for + casting floating-point forward outputs. This can be used to + help implement cases where different modules have different mixed + precision policies. (Default: ``None``) + cast_forward_inputs (bool): This specifies whether FSDP should cast the + forward's floating-point input tensors to ``param_dtype`` or not. + """ + + param_dtype: Optional[torch.dtype] = None + reduce_dtype: Optional[torch.dtype] = None + output_dtype: Optional[torch.dtype] = None + cast_forward_inputs: bool = True + + +@dataclass +class OffloadPolicy: + """ + This base class represents the policy of no offloading and is only used as + the default value for the ``offload_policy`` arg. + """ + + +@dataclass +class CPUOffloadPolicy(OffloadPolicy): + """ + This offload policy offloads parameters, gradients, and optimizer states to + CPU. Sharded parameters are copied host-to-device before all-gather. The + all-gathered parameters are freed according to ``reshard_after_forward``. + Sharded gradients are copied device-to-host in backward, and the optimizer + step runs on CPU with CPU optimizer states. + + Attributes: + pin_memory (bool): Whether to pin sharded parameter and gradient + memory. Pinning memory allows both more efficient H2D/D2H copies + and for the copies to overlap with compute. However, the pinned + memory cannot be used by other processes. Set this to ``False`` if + you have insufficient CPU memory. (Default: ``True``) + """ + + pin_memory: bool = True diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py new file mode 100644 index 0000000000000000000000000000000000000000..c796b6e00655d2b04f87c1fc78377e60db1b8d4f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py @@ -0,0 +1,661 @@ +from itertools import chain +from typing import Callable, cast, NamedTuple, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.distributed_c10d import _resolve_process_group, ReduceOp +from torch.distributed.tensor import DTensor + +from ._fsdp_common import ( + _get_dim0_padded_size, + _raise_assert_with_print, + _to_dtype_if_needed, + compiled_autograd_enabled, +) +from ._fsdp_param import FSDPParam, ShardedState + + +class AllGatherResult(NamedTuple): + all_gather_output: torch.Tensor + all_gather_event: Optional[torch.Event] + all_gather_work: Optional[dist.distributed_c10d.Work] + # For each parameter, the all-gather input dtype for each input + param_all_gather_input_dtypes: list[list[torch.dtype]] + # For each parameter, the all-gather input numel for each input + param_all_gather_input_numels: list[list[int]] + # 1D flattened version of `param_all_gather_input_numels` saved to avoid + # CPU overhead from recomputing + all_gather_input_split_sizes: list[int] + + +def allocate_memory( + size: int, + dtype: torch.dtype, + device: torch.device, + group: dist.ProcessGroup, + from_process_group: bool, +) -> torch.Tensor: + if from_process_group: + backend = group._get_backend(device) + if backend.supports_tensor_alloc(device): + return backend.allocate_tensor(size, dtype=dtype, device=device) + return torch.empty((size,), dtype=dtype, device=device) + + +lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define( + """ + all_gather_copy_in( + Tensor[] all_gather_inputs, + SymInt[] inp_split_sizes, + SymInt all_gather_input_numel, + SymInt world_size, + SymInt rank, + ScalarType dtype, + Device device, + str group_name, + bool allocate_memory_from_process_group + ) -> (Tensor, Tensor) + """ +) + + +@torch.library.impl(lib, "all_gather_copy_in", "Meta") +def all_gather_copy_in_meta( + all_gather_inputs: list[torch.Tensor], + inp_split_sizes: list[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + group_name: str, + allocate_memory_from_process_group: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + all_gather_output = torch.empty( + (all_gather_input_numel * world_size,), dtype=dtype, device="meta" + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + return all_gather_input, all_gather_output + + +@torch.library.impl(lib, "all_gather_copy_in", "CUDA") +@torch.library.impl(lib, "all_gather_copy_in", "XPU") +@torch.library.impl(lib, "all_gather_copy_in", "HPU") +@torch.library.impl(lib, "all_gather_copy_in", "CPU") +@torch.library.impl(lib, "all_gather_copy_in", "MTIA") +@torch.library.impl(lib, "all_gather_copy_in", "PrivateUse1") +def all_gather_copy_in_cuda( + all_gather_inputs: list[torch.Tensor], + inp_split_sizes: list[int], + all_gather_input_numel: int, + world_size: int, + rank: int, + dtype: torch.dtype, + device: torch.device, + group_name: str, + allocate_memory_from_process_group: bool, +) -> tuple[torch.Tensor, torch.Tensor]: + all_gather_output = allocate_memory( + all_gather_input_numel * world_size, + dtype=dtype, + device=device, + group=_resolve_process_group(group_name), + from_process_group=allocate_memory_from_process_group, + ) + all_gather_input = all_gather_output.narrow( + 0, all_gather_input_numel * rank, all_gather_input_numel + ) + foreach_copy_dsts = torch.split(all_gather_input, inp_split_sizes) + with torch.no_grad(): + torch._foreach_copy_(foreach_copy_dsts, all_gather_inputs) + return all_gather_input, all_gather_output + + +lib.define( + "split_with_sizes_copy(Tensor all_gather_output, SymInt[] all_gather_input_split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()" +) + + +@torch.library.impl(lib, "split_with_sizes_copy", "Meta") +@torch.library.impl(lib, "split_with_sizes_copy", "CUDA") +@torch.library.impl(lib, "split_with_sizes_copy", "XPU") +@torch.library.impl(lib, "split_with_sizes_copy", "HPU") +@torch.library.impl(lib, "split_with_sizes_copy", "CPU") +@torch.library.impl(lib, "split_with_sizes_copy", "MTIA") +@torch.library.impl(lib, "split_with_sizes_copy", "PrivateUse1") +def split_with_sizes_copy( + all_gather_output: torch.Tensor, + all_gather_input_split_sizes: list[int], + dim: int, + out: list[torch.Tensor], +) -> None: + torch.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=dim, out=out + ) + + +lib.define( + "chunk_cat(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> ()" +) + + +@torch.library.impl(lib, "chunk_cat", "Meta") +@torch.library.impl(lib, "chunk_cat", "CUDA") +@torch.library.impl(lib, "chunk_cat", "XPU") +@torch.library.impl(lib, "chunk_cat", "HPU") +@torch.library.impl(lib, "chunk_cat", "CPU") +@torch.library.impl(lib, "chunk_cat", "MTIA") +@torch.library.impl(lib, "chunk_cat", "PrivateUse1") +def chunk_cat( + tensors: list[torch.Tensor], + dim: int, + num_chunks: int, + out: torch.Tensor, +) -> None: + torch._chunk_cat(tensors, dim, num_chunks, out=out) + + +@torch.no_grad() +def foreach_all_gather( + fsdp_params: list[FSDPParam], + group: dist.ProcessGroup, + async_op: bool, + all_gather_copy_in_stream: torch.Stream, + all_gather_stream: torch.Stream, + device: torch.device, + allocate_memory_from_process_group: bool = False, +) -> Optional[AllGatherResult]: + world_size, rank = group.size(), group.rank() + device_handle = _get_device_handle(device.type) + with device_handle.stream(all_gather_copy_in_stream): + param_all_gather_inputs = _get_param_all_gather_inputs(fsdp_params) + ( + param_all_gather_input_dtypes, + param_all_gather_input_numels, + dtype, + ) = _get_all_gather_input_metadatas(param_all_gather_inputs) + if dtype == torch.uint8: + all_gather_inputs = [ + t.view(torch.uint8) for ts in param_all_gather_inputs for t in ts + ] + else: + all_gather_inputs = [*chain.from_iterable(param_all_gather_inputs)] + inp_split_sizes = [t.numel() for t in all_gather_inputs] + all_gather_input_numel = sum(inp_split_sizes) + all_gather_input, all_gather_output = torch.ops.fsdp.all_gather_copy_in( + all_gather_inputs, + inp_split_sizes, + all_gather_input_numel, + world_size, + rank, + dtype, + device, + group.group_name, + allocate_memory_from_process_group, + ) + del param_all_gather_inputs + all_gather_stream.wait_stream(all_gather_copy_in_stream) + with device_handle.stream(all_gather_stream): + all_gather_work = dist.all_gather_into_tensor( + output_tensor=all_gather_output, + input_tensor=all_gather_input, + group=group, + async_op=async_op, + ) + all_gather_event = all_gather_stream.record_event() + return AllGatherResult( + all_gather_output, + all_gather_event, + all_gather_work, + param_all_gather_input_dtypes, + param_all_gather_input_numels, + inp_split_sizes, + ) + + +@torch.no_grad() +def _get_param_all_gather_inputs( + fsdp_params: list[FSDPParam], +) -> list[list[torch.Tensor]]: + if compiled_autograd_enabled(): + return [fsdp_param.all_gather_inputs for fsdp_param in fsdp_params] + + # Intentionally try to run a fast-path that bypasses abstractions for the + # common FSDP case of bf16/fp32 mixed precision in order to use foreach + # copy for lower CPU overhead and more efficient copying in eager + def use_foreach_copy(fsdp_param: FSDPParam) -> bool: + return ( + fsdp_param.param_dtype is not None + and not fsdp_param.offload_to_cpu + and not hasattr(fsdp_param._sharded_local_tensor, "fsdp_pre_all_gather") + ) + + param_all_gather_inputs: list[list[torch.Tensor]] = [[] for _ in fsdp_params] + foreach_copy_indices: list[int] = [] + foreach_copy_inputs: list[torch.Tensor] = [] + foreach_copy_input_numels: list[int] = [] + + # 1st pass: for foreach-copy parameters, get inputs and metadata for the + # foreach copy, and for the others, actually get their all-gather inputs + for i, fsdp_param in enumerate(fsdp_params): + if use_foreach_copy(fsdp_param): + foreach_copy_indices.append(i) + all_gather_input = ( + fsdp_param._sharded_param_data + if fsdp_param.sharded_state == ShardedState.SHARDED + else cast(torch.Tensor, fsdp_param._sharded_post_forward_param_data) + ) + foreach_copy_inputs.append(all_gather_input) + foreach_copy_input_numels.append(all_gather_input.numel()) + else: + param_all_gather_inputs[i] = fsdp_param.all_gather_inputs + + # 2nd pass: use foreach copy to compute the remaining all-gather inputs + if foreach_copy_inputs: + fsdp_param_0 = fsdp_params[foreach_copy_indices[0]] + param_dtype, device = fsdp_param_0.param_dtype, fsdp_param_0.device + flat_foreach_copy_input = torch.empty( + (sum(foreach_copy_input_numels),), device=device, dtype=param_dtype + ) + splits = torch.split(flat_foreach_copy_input, foreach_copy_input_numels) + torch._foreach_copy_(splits, foreach_copy_inputs) + for i, split in zip(foreach_copy_indices, splits): + param_all_gather_inputs[i] = [split] + + return param_all_gather_inputs + + +@torch.no_grad() +def foreach_all_gather_copy_out( + all_gather_result: AllGatherResult, + fsdp_params: list[FSDPParam], + group: dist.ProcessGroup, +) -> None: + ( + all_gather_output, + all_gather_event, + all_gather_work, + param_all_gather_input_dtypes, + param_all_gather_input_numels, + all_gather_input_split_sizes, + ) = all_gather_result + _dtype, device = all_gather_output.dtype, all_gather_output.device + device_handle = _get_device_handle(device.type) + if all_gather_event is not None: # sync op + device_handle.current_stream().wait_event(all_gather_event) + if isinstance(all_gather_work, dist.distributed_c10d.Work): # async op + all_gather_work.wait() + world_size, device = group.size(), all_gather_output.device + + split_with_sizes_out: list[torch.Tensor] = [] + shard_i_copy_infos: list[tuple[FSDPParam, list[torch.Tensor]]] = [] + for all_gather_input_numels, all_gather_input_dtypes, fsdp_param in zip( + param_all_gather_input_numels, param_all_gather_input_dtypes, fsdp_params + ): + # NOTE: Under compile, make sure we always recreate all_gather_outputs + # per AllGather. See [Note: Invariants for torch.compile Traceable FSDP2]. + force_recreate = compiled_autograd_enabled() + fsdp_param.init_all_gather_outputs( + all_gather_input_numels, + all_gather_input_dtypes, + world_size, + device, + force_recreate=force_recreate, + ) + if not force_recreate: + fsdp_param.alloc_all_gather_outputs() + param_all_gather_outputs = fsdp_param.all_gather_outputs + if fsdp_param.fsdp_placement.dim != 0: + # Copy to a temporary and then chunk-cat into the final all-gather + # output tensors + param_all_gather_outputs = [ + torch.empty_like(t) for t in param_all_gather_outputs + ] + shard_i_copy_infos.append((fsdp_param, param_all_gather_outputs)) + split_with_sizes_out.extend(param_all_gather_outputs) + + all_gather_output = all_gather_output.view(world_size, -1) + if all_gather_output.dtype == torch.uint8: + out = [t.view(world_size, -1).view(torch.uint8) for t in split_with_sizes_out] + else: + out = [t.view(world_size, -1) for t in split_with_sizes_out] + + # only avoid VC bump if we are not in inference mode + if torch._dynamo.is_compiling(): + # For torch.compile, we turn off inference_mode for fake tensor + # propagation, and therefore graph break on is_inference. For `compile`, + # we don't care about VCs, so just skip the optimization. + non_inference_outs = [] + else: + non_inference_outs = [o for o in out if not o.is_inference()] + + if len(non_inference_outs) > 0: + with torch.autograd._unsafe_preserve_version_counter(tuple(non_inference_outs)): + torch.ops.fsdp.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=1, out=out + ) + else: + torch.ops.fsdp.split_with_sizes_copy( + all_gather_output, all_gather_input_split_sizes, dim=1, out=out + ) + + for fsdp_param, param_all_gather_outputs in shard_i_copy_infos: + # Chunk-cat from the temporary to the final all-gather output tensors + shard_dim = fsdp_param.fsdp_placement.dim + + with torch.autograd._unsafe_preserve_version_counter( + tuple(fsdp_param.all_gather_outputs) + ): + for param_all_gather_output, target_all_gather_output in zip( + param_all_gather_outputs, fsdp_param.all_gather_outputs + ): + padded_sharded_size = ( + fsdp_param.padded_sharded_param_size + if fsdp_param.sharded_state == ShardedState.SHARDED + else cast( + torch.Tensor, fsdp_param._sharded_post_forward_param_data + ).size() + ) + pre_param_size = list(padded_sharded_size) + pre_param_size[0] *= world_size + chunks = torch.chunk( + param_all_gather_output.view(pre_param_size), world_size, dim=0 + ) + post_param_size = list(padded_sharded_size) + post_param_size[shard_dim] *= world_size + cat_out = target_all_gather_output.view(post_param_size) + torch.cat(chunks, dim=shard_dim, out=cat_out) + + +@torch.no_grad() +def foreach_reduce( + fsdp_params: list[FSDPParam], + unsharded_grads: list[torch.Tensor], + reduce_scatter_group: dist.ProcessGroup, + reduce_scatter_stream: torch.Stream, + orig_dtype: Optional[torch.dtype], + reduce_dtype: Optional[torch.dtype], + device: torch.device, + gradient_divide_factor: Optional[float], + all_reduce_group: Optional[dist.ProcessGroup], # not `None` iff HSDP + all_reduce_stream: torch.Stream, + all_reduce_grads: bool, + partial_reduce_output: Optional[torch.Tensor], # only used for HSDP + all_reduce_hook: Optional[Callable[[torch.Tensor], None]], + allocate_memory_from_process_group: bool = False, + force_sum_reduction_for_comms: bool = False, +) -> tuple[ + torch.Tensor, + torch.Event, + torch.Event, + Optional[torch.Tensor], + Optional[torch.Event], + Optional[torch.Tensor], +]: + """ + ``unsharded_grads`` owns the references to the gradients computed by + autograd, so clearing the list frees the gradients. + """ + grad_dtypes = {grad.dtype for grad in unsharded_grads} + if len(grad_dtypes) != 1: + # Check this at runtime since it could be a real runtime error if e.g. + # fp8 weights do not produce the correct higher precision gradients + _raise_assert_with_print( + f"FSDP reduce-scatter expects uniform gradient dtype but got {grad_dtypes}" + ) + grad_dtype = unsharded_grads[0].dtype + reduce_dtype = reduce_dtype or grad_dtype + (predivide_factor, postdivide_factor, reduce_scatter_op, all_reduce_op) = ( + _get_gradient_divide_factors( + reduce_scatter_group, + all_reduce_group, + reduce_dtype, + device.type, + gradient_divide_factor, + force_sum_reduction_for_comms, + ) + ) + world_size = reduce_scatter_group.size() + for i, (fsdp_param, unsharded_grad) in enumerate(zip(fsdp_params, unsharded_grads)): + if (shard_dim := fsdp_param.fsdp_placement.dim) == 0: + continue + assert unsharded_grad.size(shard_dim) % world_size == 0, ( + f"Shard({shard_dim}) requires even sharding: {unsharded_grad.size()=} {world_size=}" + ) + chunks = torch.chunk(unsharded_grad, world_size, dim=shard_dim) + unsharded_grads[i] = torch.cat(chunks, dim=0) + padded_unsharded_sizes = tuple( + _get_dim0_padded_size(grad.size(), world_size) for grad in unsharded_grads + ) + reduce_scatter_input_numel = sum(s.numel() for s in padded_unsharded_sizes) + reduce_scatter_output_numel = reduce_scatter_input_numel // world_size + reduce_scatter_input = allocate_memory( + reduce_scatter_input_numel, + dtype=reduce_dtype, + device=device, + group=reduce_scatter_group, + from_process_group=allocate_memory_from_process_group, + ) + device_handle = _get_device_handle(device.type) + foreach_reduce_scatter_copy_in(unsharded_grads, reduce_scatter_input, world_size) + current_stream = device_handle.current_stream() + # Only after the copy-in finishes can we free the gradients + unsharded_grads.clear() + reduce_scatter_stream.wait_stream(current_stream) + all_reduce_input = None + all_reduce_event = None + with device_handle.stream(reduce_scatter_stream): + reduce_output = allocate_memory( + reduce_scatter_output_numel, + dtype=reduce_dtype, + device=device, + group=reduce_scatter_group, + from_process_group=allocate_memory_from_process_group, + ) + _div_if_needed(reduce_scatter_input, predivide_factor) + dist.reduce_scatter_tensor( + output=reduce_output, + input=reduce_scatter_input, + group=reduce_scatter_group, + op=reduce_scatter_op, + ) + reduce_scatter_event = reduce_scatter_stream.record_event() + post_reduce_stream = reduce_scatter_stream + if all_reduce_group is not None: # HSDP + # Accumulations must run in the reduce-scatter stream + if not all_reduce_grads: + if partial_reduce_output is not None: + partial_reduce_output += reduce_output + else: + partial_reduce_output = reduce_output + return ( + reduce_scatter_input, + reduce_scatter_event, + post_reduce_stream.record_event(), + all_reduce_input, + all_reduce_event, + partial_reduce_output, + ) + if partial_reduce_output is not None: + reduce_output += partial_reduce_output + post_reduce_stream = all_reduce_stream + all_reduce_stream.wait_stream(reduce_scatter_stream) + with device_handle.stream(all_reduce_stream): + dist.all_reduce( + reduce_output, + group=all_reduce_group, + op=all_reduce_op, + ) + all_reduce_input = reduce_output + all_reduce_event = all_reduce_stream.record_event() + # -- END: ops in reduce_scatter stream + + if all_reduce_hook is not None: + # Execute user-specified all reduce hook. + # If native HSDP is used, this is executed after the HSDP all reduce. + # If 1-d FSDP is used, this is executed post reduce-scatter. + post_reduce_stream = all_reduce_stream + all_reduce_stream.wait_stream(reduce_scatter_stream) + with device_handle.stream(all_reduce_stream): + all_reduce_hook(reduce_output) + # -- END: ops post reduce_scatter + + with device_handle.stream(post_reduce_stream): + _div_if_needed(reduce_output, postdivide_factor) + reduce_output = _to_dtype_if_needed(reduce_output, orig_dtype) + # View out and accumulate sharded gradients + flat_grad_offset = 0 # [0, reduce_scatter_output_numel - 1] + for padded_unsharded_size, fsdp_param in zip( + padded_unsharded_sizes, fsdp_params + ): + # Assume even sharding for Shard(i), i > 0; otherwise would require + # copy-out for contiguous strides + new_sharded_grad = torch.as_strided( + reduce_output, + size=fsdp_param.sharded_size, + stride=fsdp_param.contiguous_sharded_stride, + storage_offset=flat_grad_offset, + ) + to_accumulate_grad = fsdp_param.sharded_param.grad is not None + if fsdp_param.offload_to_cpu: + # Only overlap the D2H copy (copying to pinned memory) if not + # accumulating gradients since the CPU add kernel depends on + # the copy result and we cannot run the add as a callback + non_blocking = fsdp_param.pin_memory and not to_accumulate_grad + # Since the GPU sharded gradient is allocated in the RS stream, + # we can free it here by not keeping a ref without waiting for + # the D2H copy since future RS-stream ops run after the copy + new_sharded_grad = new_sharded_grad.to( + torch.device("cpu"), non_blocking=non_blocking + ) + if non_blocking: + # Record an event on which to block the CPU thread to + # ensure that the D2H copy finishes before the optimizer + fsdp_param.grad_offload_event = reduce_scatter_stream.record_event() + if to_accumulate_grad: + assert isinstance(fsdp_param.sharded_param.grad, DTensor) + fsdp_param.sharded_param.grad._local_tensor += new_sharded_grad + else: + new_sharded_dtensor_grad = fsdp_param.to_sharded_dtensor( + new_sharded_grad + ) + fsdp_param.sharded_param.grad = new_sharded_dtensor_grad + if not compiled_autograd_enabled(): + for hook in ( + getattr(fsdp_param.sharded_param, "_post_accumulate_grad_hooks", {}) + or {} + ).values(): + hook(fsdp_param.sharded_param) + padded_sharded_numel = padded_unsharded_size.numel() // world_size + flat_grad_offset += padded_sharded_numel + post_reduce_event = post_reduce_stream.record_event() + # The RS output is allocated in the RS stream and used in the default + # stream (for optimizer). To ensure its memory is not reused for later + # RSs, we do not need extra synchronization since the sharded parameters + # hold refs through the end of backward. + return ( + reduce_scatter_input, + reduce_scatter_event, + post_reduce_event, + all_reduce_input, + all_reduce_event, + None, + ) + + +def foreach_reduce_scatter_copy_in( + unsharded_grads: list[torch.Tensor], + reduce_scatter_input: torch.Tensor, + world_size: int, +) -> None: + reduce_scatter_input = reduce_scatter_input.view(world_size, -1) + torch.ops.fsdp.chunk_cat( + unsharded_grads, dim=0, num_chunks=world_size, out=reduce_scatter_input + ) + + +def _get_all_gather_input_metadatas( + param_all_gather_inputs: list[list[torch.Tensor]], +) -> tuple[list[list[torch.dtype]], list[list[int]], torch.dtype]: + param_all_gather_input_dtypes: list[list[torch.dtype]] = [] + param_all_gather_input_numels: list[list[int]] = [] + all_gather_dtype = param_all_gather_inputs[0][0].dtype + for all_gather_inputs in param_all_gather_inputs: + input_dtypes: list[torch.dtype] = [] + input_numels: list[int] = [] + for all_gather_input in all_gather_inputs: + if all_gather_input.dtype != all_gather_dtype: + all_gather_dtype = torch.uint8 + input_dtypes.append(all_gather_input.dtype) + input_numels.append(all_gather_input.numel()) + param_all_gather_input_dtypes.append(input_dtypes) + param_all_gather_input_numels.append(input_numels) + return ( + param_all_gather_input_dtypes, + param_all_gather_input_numels, + all_gather_dtype, + ) + + +def _get_gradient_divide_factors( + reduce_scatter_group: dist.ProcessGroup, + all_reduce_group: Optional[dist.ProcessGroup], + reduce_dtype: torch.dtype, + device_type: str = "", + factor: Optional[float] = None, + force_sum_reduction_for_comms: bool = False, +) -> tuple[ + Optional[float], + Optional[float], + Union[dist.ReduceOp, dist.ReduceOp.RedOpType], + Union[dist.ReduceOp, dist.ReduceOp.RedOpType], +]: + # MTIA appears to only support SUM reduction, hence we force it implicitly + if device_type == "mtia": + force_sum_reduction_for_comms = True + + # For fp32/bf16, we do not need to worry about overflow/underflow, so we + # use NCCL's built-in division to avoid separate div kernels + overflow_risk = reduce_dtype not in (torch.float32, torch.bfloat16) + + data_parallel_size = reduce_scatter_group.size() + if all_reduce_group is not None: + data_parallel_size *= all_reduce_group.size() + + if factor is None: + factor = float(data_parallel_size) + + if not overflow_risk and not force_sum_reduction_for_comms: + if factor == data_parallel_size: + # Warning: NCCL ReduceOp.AVG may produce incorrect results with + # world size 1. + return None, None, ReduceOp.AVG, ReduceOp.AVG + else: + reduce_scatter_op = torch.distributed._make_nccl_premul_sum(1 / factor) + return None, None, reduce_scatter_op, ReduceOp.SUM + + pre_factor: Optional[float] + if overflow_risk: + # Since fp16 has smaller dynamic range than fp32/bf16, we want to avoid + # overflow/underflow. For N data parallel workers, each worker computes + # g_i, and they collectively reduce (g_1 + ... + g_N) / N. To avoid + # overflow/underflow, we divide by ~sqrt(N) before/after the reduction. + pre_factor = 1 + while factor % pre_factor == 0 and factor / pre_factor > pre_factor: + pre_factor *= 2 + post_factor = factor / pre_factor + else: + # Prefer post-multiplying as it operates on less data and is thus faster + pre_factor, post_factor = None, factor + + return pre_factor, post_factor, ReduceOp.SUM, ReduceOp.SUM + + +def _div_if_needed(tensor: torch.Tensor, div_factor: Optional[float]) -> None: + if div_factor is not None and div_factor != 1: + tensor.div_(div_factor) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py new file mode 100644 index 0000000000000000000000000000000000000000..bf72a2437a3b718634ac3ad7589db7a37c593c1b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -0,0 +1,181 @@ +# mypy: allow-untyped-defs +import math +import traceback +from dataclasses import dataclass +from enum import auto, Enum +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed._composable.contract import _get_registry +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec + + +_compiled_autograd_enabled: bool = False + +if torch._running_with_deploy(): + + def detect_compiled_autograd(): + pass + + def compiled_autograd_enabled(): + return False + +else: + + def detect_compiled_autograd(): + assert not torch.compiler.is_compiling(), ( + "`detect_compiled_autograd()` is designed to be called in eager mode" + ) + global _compiled_autograd_enabled + import torch._dynamo.compiled_autograd as ca + + _compiled_autograd_enabled = ( + ca.compiled_autograd_enabled + or ca.compiled_autograd_enabled_force_eager + or ca.in_compiled_autograd_region + ) + + def compiled_autograd_enabled(): + global _compiled_autograd_enabled + return _compiled_autograd_enabled + + +@dataclass +class DataParallelMeshInfo: + mesh: DeviceMesh + shard_mesh_dim: Optional[int] = None + replicate_mesh_dim: Optional[int] = None + + def __post_init__(self): + if self.shard_mesh_dim is None and self.replicate_mesh_dim is None: + raise AssertionError( + "At least one of shard_mesh_dim and replicate_mesh_dim must not be None" + ) + + +@dataclass +class FSDPMeshInfo(DataParallelMeshInfo): + def __post_init__(self): + super().__post_init__() + if self.shard_mesh_dim is None: + raise AssertionError("Expects non-None shard_mesh_dim") + self.shard_mesh_size: int = self.mesh.size(self.shard_mesh_dim) + self.shard_process_group = self.mesh.get_group(self.shard_mesh_dim) + self.shard_mesh_rank: int = self.shard_process_group.rank() + + +@dataclass +class DDPMeshInfo(DataParallelMeshInfo): + def __post_init__(self): + super().__post_init__() + if self.replicate_mesh_dim is None: + raise AssertionError("Expects non-None replicate_mesh_dim") + self.replicate_mesh_size: int = self.mesh.size(self.replicate_mesh_dim) + self.replicate_process_group = self.mesh.get_group(self.replicate_mesh_dim) + self.replicate_mesh_rank: int = self.replicate_process_group.rank() + + +@dataclass +class HSDPMeshInfo(FSDPMeshInfo, DDPMeshInfo): + def __post_init__(self): + # Calls `FSDPMeshInfo` -> `DDPMeshInfo` -> `DataParallelMeshInfo` + super().__post_init__() + + +class TrainingState(Enum): + """Describes the training state of one FSDP state / parameter group.""" + + # Transition to forward starting pre-forward until post-forward + FORWARD = auto() + # Transition to pre-backward when unsharding in backward + PRE_BACKWARD = auto() + # Transition to post-backward when resharding and reducing gradients + POST_BACKWARD = auto() + # Idle before/after forward or before pre-backward/after post-backward + IDLE = auto() + + +def _raise_assert_with_print(*args: Any, **kwargs: Any): + print(f"[Rank {dist.get_rank()}] ", end="") + print(*args, **kwargs) + traceback.print_stack() + raise AssertionError(*args, **kwargs) + + +def _is_composable_with_fsdp(module: nn.Module) -> bool: + registry = _get_registry(module) + if registry is None: + return True + # Registry keys by function name + return "replicate" not in registry + + +def _get_dim0_padded_size(tensor_size: torch.Size, dim0_factor: int) -> torch.Size: + padded_dim0 = math.ceil(tensor_size[0] / dim0_factor) * dim0_factor + return torch.Size([padded_dim0]) + tensor_size[1:] + + +def _chunk_with_empty( + tensor: torch.Tensor, num_chunks: int, dim: int +) -> list[torch.Tensor]: + chunks = list(torch.chunk(tensor, num_chunks, dim=dim)) + while len(chunks) < num_chunks: + chunks.append(chunks[0].new_empty(0)) + return chunks + + +def _get_dim_chunked_size( + chunk: torch.Tensor, unchunked_size: torch.Size, dim: int +) -> torch.Size: + if chunk.numel() > 0: + return chunk.size() + # For 0 numel, we need to preserve nonzero-sized dims for DTensor APIs + return unchunked_size[:dim] + torch.Size([0]) + unchunked_size[dim + 1 :] + + +def _from_local_no_grad( + local_tensor: torch.Tensor, + sharding_spec: DTensorSpec, +) -> DTensor: + """ + This method is similar to ``DTensor.from_local()`` except that in eager mode + it avoids some CPU overhead by avoiding default args and not being differentiable. + """ + + if not compiled_autograd_enabled(): + return DTensor( + # Use the local tensor directly instead of constructing a new tensor + # variable, e.g. with `view_as()`, since this is not differentiable + local_tensor, + sharding_spec, + requires_grad=local_tensor.requires_grad, + ) + else: + return DTensor.from_local( + local_tensor, + sharding_spec.mesh, + sharding_spec.placements, + shape=sharding_spec.shape, + stride=sharding_spec.stride, + ) + + +def _to_dtype_if_needed( + tensor: torch.Tensor, dtype: Optional[torch.dtype] +) -> torch.Tensor: + if dtype is not None and tensor.dtype != dtype: + return tensor.to(dtype) + return tensor + + +def _cast_fp_tensor(dtype: torch.dtype, x: torch.Tensor) -> torch.Tensor: + if ( + not isinstance(x, torch.Tensor) + or not torch.is_floating_point(x) + or x.dtype == dtype + ): + return x + return x.to(dtype) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py new file mode 100644 index 0000000000000000000000000000000000000000..3d8cb22c7fdf3ea8695b7aab18f80faf9a573501 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_init.py @@ -0,0 +1,242 @@ +import itertools +import logging +from typing import Optional, Union + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch._logging import warning_once +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.tensor import DeviceMesh, DTensor, init_device_mesh +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + +from ._fsdp_common import _is_composable_with_fsdp, FSDPMeshInfo, HSDPMeshInfo +from ._fsdp_state import _get_module_fsdp_state + + +logger = logging.getLogger("torch.distributed.fsdp.fully_shard") + + +def _get_post_forward_mesh_info( + reshard_after_forward: Union[bool, int], mesh_info: FSDPMeshInfo +) -> Optional[FSDPMeshInfo]: + shard_mesh_size = mesh_info.shard_mesh_size + if not isinstance(reshard_after_forward, (bool, int)): + raise ValueError( + "reshard_after_forward should be a bool or an int representing the " + f"group size to reshard to, not {reshard_after_forward}" + ) + # NOTE: `isinstance(False, int)` returns `True`. + if not isinstance(reshard_after_forward, bool) and isinstance( + reshard_after_forward, int + ): + if ( + reshard_after_forward < 1 + or reshard_after_forward > shard_mesh_size + or shard_mesh_size % reshard_after_forward != 0 + ): + raise ValueError( + "If passing reshard_after_forward as an int, it should be a " + f"factor of {shard_mesh_size}, not {reshard_after_forward}" + ) + elif reshard_after_forward == 1: + msg = ( + "reshard_after_forward=1 (int) means resharding parameters to world size 1, " + "instead of reshard_after_forward=True (bool)" + ) + warning_once(logger, msg, stacklevel=2) + reshard_after_forward = False + elif reshard_after_forward == shard_mesh_size: + reshard_after_forward = True + post_forward_mesh_info = None + if reshard_after_forward is True: + post_forward_mesh_info = mesh_info + elif reshard_after_forward is not False: # int case + # For HSDP, we can flatten the two replicate dims into the 0th dim + post_forward_mesh_tensor = mesh_info.mesh.mesh.view(-1, reshard_after_forward) + post_forward_mesh = DeviceMesh( + mesh_info.mesh.device_type, post_forward_mesh_tensor + ) + post_forward_mesh_info = HSDPMeshInfo( + post_forward_mesh, shard_mesh_dim=1, replicate_mesh_dim=0 + ) + return post_forward_mesh_info + + +def _init_default_fully_shard_mesh() -> DeviceMesh: + """Default to global CUDA mesh if possible else global CPU mesh.""" + if not dist.distributed_c10d.is_initialized(): + dist.distributed_c10d.init_process_group() + default_pg = dist.distributed_c10d._get_default_group() + device = torch._C._get_accelerator() + mesh = init_device_mesh(device.type, mesh_shape=(default_pg.size(),)) + return mesh + + +def _get_device_from_mesh(mesh: DeviceMesh) -> torch.device: + if mesh.device_type == "cpu": + return torch.device("cpu") + device_handle = _get_device_handle(mesh.device_type) + return torch.device(mesh.device_type, device_handle.current_device()) + + +def _ignore_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignore_decision: dict[nn.Module, bool], +) -> bool: + """ + Decide if it is safe to ignore a module for applying fully_shard. + """ + if module in ignore_decision: + return ignore_decision[module] + + if len(list(module.buffers(recurse=False))) > 0: + # Cannot ignore a module with any buffer + ignore_decision[module] = False + return False + + for _, param in module.named_parameters(recurse=False): + if param not in ignored_params: + # at least one param is not ignored. So this module shouldn't be. + ignore_decision[module] = False + return False + + # Need to consider descendants of module + for child in list(module.children()): + ignore_child = _ignore_module(child, ignored_params, ignore_decision) + if not ignore_child: + # Cannot ignore module if one of its children is not ignored + ignore_decision[module] = False + return False + + # Safe to ignore module + ignore_decision[module] = True + return True + + +def _adjust_managed_modules( + modules: list[nn.Module], ignored_params: set[nn.Parameter] +) -> list[nn.Module]: + """ + Adjust the given list of managed modules by removing those with all parameters ignored. + """ + ignore_decision: dict[nn.Module, bool] = {} + new_modules = [] + for module in modules: + ignored = _ignore_module(module, ignored_params, ignore_decision) + if not ignored: + new_modules.append(module) + return new_modules + + +def _get_managed_modules( + root_modules: tuple[nn.Module, ...], + ignored_params: Optional[set[nn.Parameter]] = None, +) -> list[nn.Module]: + modules: list[nn.Module] = [] + root_modules_set = set(root_modules) + # Track visisted modules to avoid visiting shared modules multiple times + visited_modules: set[nn.Module] = set() + + def dfs(module: nn.Module) -> None: + """ + Runs a DFS to collect managed modules, not recursing into modules with + a non-composable API or ``fully_shard`` already applied. + """ + if not _is_composable_with_fsdp(module): + return + elif ( + module not in root_modules_set + and _get_module_fsdp_state(module) is not None + ): + return # nested `fully_shard` module + visited_modules.add(module) + for submodule in module.children(): + if submodule not in visited_modules: + dfs(submodule) + modules.append(module) + + for root_module in root_modules: + dfs(root_module) + + if ignored_params is None: + return modules + + adjusted_modules = _adjust_managed_modules(modules, ignored_params) + return adjusted_modules + + +def _verify_managed_param(name: str, param: nn.Parameter) -> None: + """ + Verify if the parameter is accepted by fully_shard. The only restriction now + is that the parameter cannot be a scalar tensor (param.numel == 0) since we + need at least one dim to shard. + """ + if len(param.shape) == 0: + raise ValueError( + "fully_shard doesn't support scalar parameters. " + f"Change {name} to a 1D tensor with numel equal to 1." + ) + + +def _get_managed_states( + modules: list[nn.Module], ignored_params: Optional[set[nn.Parameter]] = None +) -> tuple[list[nn.Parameter], list[torch.Tensor]]: + params: list[nn.Parameter] = [] + buffers: list[torch.Tensor] = [] + # Track visited parameters/buffers to avoid visiting shared parameters and + # buffers multiple times + visited_params: set[nn.Parameter] = set() + visited_buffers: set[torch.Tensor] = set() + if ignored_params is None: + ignored_params = set() + + for module in modules: + for name, param in module.named_parameters(recurse=False): + if param in ignored_params: + # do not include an ignored parameters + continue + if param not in visited_params: + _verify_managed_param(name, param) + params.append(param) + visited_params.add(param) + for buffer in module.buffers(recurse=False): + if buffer not in visited_buffers: + buffers.append(buffer) + visited_buffers.add(buffer) + return params, buffers + + +def _move_states_to_device( + params: list[nn.Parameter], + buffers: list[torch.Tensor], + device: torch.device, +) -> None: + """ + We have FSDP move states to device for simpler and faster initialization + since FSDP almost always uses CUDA for training. We move parameters/buffers + rather than modules since modules to support ignoring parameters/buffers in + the future. + """ + # Follow the logic in `nn.Module._apply` + for tensor in itertools.chain(params, buffers): + if tensor.device == device or tensor.device.type == "meta": + # Keep meta-device tensors on meta device for deferred init + continue + if isinstance(tensor, DTensor): + if (dtensor_mesh_type := tensor.device_mesh.device_type) != device.type: + raise ValueError( + "Requires DTensor to have mesh of the same type as the FSDP mesh " + f"but got {dtensor_mesh_type} for DTensor and {device.type} for FSDP" + ) + raise AssertionError( + f"Expects DTensor to be moved to {dtensor_mesh_type} but got {tensor.device}" + ) + tensor_ = tensor + if is_traceable_wrapper_subclass(tensor_): + with torch.no_grad(): # avoid autograd increasing C++ refcount by 1 + tensor_on_device = nn.Parameter(tensor.to(device)) + torch.utils.swap_tensors(tensor, tensor_on_device) + else: + tensor.data = tensor.to(device) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py new file mode 100644 index 0000000000000000000000000000000000000000..5e674919481483e385e0298140ee4a2607ebf3d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -0,0 +1,896 @@ +# mypy: allow-untyped-defs +import inspect +import itertools +from collections.abc import Sequence +from dataclasses import dataclass, field +from enum import auto, Enum +from typing import Any, Callable, cast, Optional + +import torch +import torch.nn as nn +from torch._prims_common import make_contiguous_strides_for +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.device_mesh import _mesh_resources +from torch.distributed.tensor.placement_types import _StridedShard, Placement + +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_common import ( + _chunk_with_empty, + _from_local_no_grad, + _get_dim_chunked_size, + _raise_assert_with_print, + _to_dtype_if_needed, + compiled_autograd_enabled, + FSDPMeshInfo, + HSDPMeshInfo, +) + + +""" +[Note: FSDP tensors] +FSDP considers the following tensors: +- Original parameter: parameter passed to :class:`FSDPParam`, i.e. the one + on the module when applying FSDP +- Sharded parameter: sharding the original parameter on dim-0 (or a + user-specified dim) as a DTensor over the main mesh +- All-gather inputs: the ``torch.Tensor`` or ``Tensor`` s passed to all-gather, + derived from the sharded parameter +- All-gather output: the ``torch.Tensor`` or ``Tensor`` s resulting from + all-gathering the all-gather inputs +- Unsharded parameter: parameter used for forward/backward computation, derived + from the all-gather output; autograd leaf + +We define these tensors to describe the general framework that can accommodate +extensions, where: +- all-gather-inputs = pre-all-gather-transform(sharded-parameter) +- unsharded-parameter = post-all-gather-transform(all-gather-outputs) + +For the default ``torch.Tensor`` case, there is only one all-gather input, and +it shares the same underlying tensor data as the sharded parameter, meaning +that they can be thought of as the same tensors. The same applies for the +all-gather output and unsharded parameter. For non-``torch.Tensor`` extensions, +these equivalences may no longer hold due to the pre/post-all-gather +transforms, and some may have multiple all-gather inputs/outputs (e.g. +quantized data and scales). + +[Note: FSDP and autograd] +FSDP dynamically frees and allocates the unsharded parameter. Since autograd +can pack a reference to it or a view to save for backward, we use storage +resizing to implement the freeing/allocation since that preserves the aliasing. +This implies that we construct the unsharded parameter object once and write to +it in-place thereafter. For the default ``torch.Tensor` original parameter +case, the all-gather output and unsharded parameter share the same +data, so we use storage resizing on the all-gather output. +""" + +lib = torch.library.Library("fsdp", "FRAGMENT") # noqa: TOR901 + +lib.define("copy_(Tensor(a!) tensor, Tensor data) -> ()") + + +@torch.library.impl(lib, "copy_", "Meta") +@torch.library.impl(lib, "copy_", "CUDA") +@torch.library.impl(lib, "copy_", "XPU") +@torch.library.impl(lib, "copy_", "HPU") +@torch.library.impl(lib, "copy_", "CPU") +@torch.library.impl(lib, "copy_", "MTIA") +def copy_(tensor, data): + tensor.copy_(data) + + +""" +[Note: Avoiding functionalization for fsdp.copy_ and inductor.resize_storage_bytes_] + +Currently we don't functionalize `fsdp.copy_` op or `inductor.resize_storage_bytes_` op +(i.e. they show up as a mutation op in the middle of the AOT joint graph). + +Reason: +Traceable FSDP2 compiled autograd BWD graph have the following traits: +(1) Two inputs of the graph were aliased to each other (one from hook closed-over tensors, one from FWD saved tensors). +(2) One of them is mutated (copy_ and resize_ to handle the all-gathered param). +(3) They are both subclasses. +The combination of these traits is not supported by AOTAutograd (it's difficult to reason about subclass aliasing). +So this doesn't work at all for Traceable FSDP2. + +The compromise we use is to avoid functionalization for the FSDP2 copy_ and resize_ ops. +This avoids the problem above, because from AOTAutograd point-of-view there are no mutations +that functionalization needs to handle. (Although we need to be careful not to DCE those mutable ops.) + +We can avoid this functionalization because: +(1) The nn.Parameter is never used before its .copy_() is called in eager code (i.e. no alias of it is created), +so it's safe to call .copy_() in the middle of the graph to update its content and start using the nn.Parameter downstream. +(2) We always re-allocate the buffer for nn.Parameter to store the AllGather output and to be used in downstream user ops. +So calling resize-to-0 in the middle of the graph to free nn.Parameter memory after use should always be okay +(since we always allocate anew next time we need it, we strictly don't need to keep the old tensor storage around anymore). + +Q: Wouldn't the extra resize_ and copy_ ops hurt both memory usage and performance? +A: Yes it would. As an optimization, we have an Inductor post-grad FX pass to remove those resize_ and copy_ ops +for unsharded params that have this pattern: resize_(full) -> copy_ -> resize_(0). + +TODO: +Now that we are maintaining the invariant of "no aliased + mutated graph inputs" in both the forward and backward, +it is now more feasible to functionalize all of the mutable FSDP ops. Some of the pros and cons are: + +Cons (of functionalizing those ops): +(1) By not functionalizing them as we are today, we are making it more likely that they will run at the "correct" time +in the generated code. If we start to functionalize them, we will need to make sure that Inductor reinplaces them +in a way where it properly moves the mutations back to exactly where they should have run, or we risk suffering worse +peak memory than eager. (We probably already need to do something similar in Inductor's reinplacing for copy_: +https://github.com/pytorch/pytorch/issues/135305#issuecomment-2334888089) + +Pros (of functionalizing): +(1) Better safety, we don't need to worry about the graph passes in inductor/partitioning handling input mutations +mid-graph quite as much (to be fair we've already done some amount of auditing, but we might have to do some more). +(2) Better perf: each mutation midway through the graph prevents Inductor from pattern matching across it. +But maybe there are few enough mutations induced by FSDP for this to matter. +""" + + +@torch.library.impl(lib, "copy_", "Functionalize") +def copy__functionalize(tensor, data): + torch._sync(tensor) + torch._sync(data) + tensor_inner = torch._from_functional_tensor(tensor) + data_inner = torch._from_functional_tensor(data) + with torch._C._ExcludeDispatchKeyGuard( + torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize) + ): + torch.ops.fsdp.copy_.default(tensor_inner, data_inner) + + +if not torch._running_with_deploy(): + torch.fx.node.has_side_effect(torch.ops.fsdp.copy_.default) + + +class ShardedState(Enum): + """ + - ``SHARDED``: The sharded parameter is registered to the module. It is the + only contributor to parameter memory. + - ``SHARDED_POST_FORWARD``: The unsharded parameter is resharded to a + smaller world size. Since this data should not be used for computation, + we do not register it to the module. Users should reshard the module + before any in-place modifications. Both it and the sharded parameter + contribute to parameter memory. + - ``UNSHARDED``: The unsharded parameter is registered to the module. Both + it and the sharded parameter contribute to parameter memory. + """ + + SHARDED = auto() + SHARDED_POST_FORWARD = auto() + UNSHARDED = auto() + + +@dataclass +class ParamModuleInfo: + """ + For a parameter, this stores the module and the parameter name to be able + to do a parameter swap via ``setattr(module, param_name, ...)`` or to get + the parameter via ``getattr(module, param_name)``. We additionally save + shared modules and shared parameter names to update them accordingly. + """ + + # Parameter names are unprefixed, e.g. "weight", not "lin.weight" + module: nn.Module + param_name: str + shared_modules: list[nn.Module] = field(default_factory=list) + shared_param_names: list[str] = field(default_factory=list) + + +@dataclass +class ExtensionsData: + # User-defined metadata passed from pre to post-all-gather + all_gather_metadata: Optional[Any] = None + # Save the all-gather input sizes to unflatten the all-gather outputs to ND + all_gather_input_sizes: Sequence[torch.Size] = () # ND + + def clear(self): + self.all_gather_metadata = None + self.all_gather_input_sizes = () + + +class FSDPParam: + """ + This class manages a parameter with FSDP or FSDP variants applied, + implementing dim-0 per-parameter sharding. + """ + + orig_dtype: torch.dtype + param_dtype: Optional[torch.dtype] + reduce_dtype: Optional[torch.dtype] + _orig_size: torch.Size # ND + sharded_size: torch.Size # ND + contiguous_sharded_stride: tuple[int, ...] + padded_sharded_param_size: torch.Size # ND + sharded_post_forward_size: torch.Size # ND + contiguous_sharded_post_forward_stride: tuple[int, ...] + _sharded_param_data: torch.Tensor # 1D + sharded_param: nn.Parameter # ND + _sharded_post_forward_param_data: Optional[torch.Tensor] # 1D + _sharded_post_forward_param: Optional[nn.Parameter] # ND + _unsharded_param: nn.Parameter # ND + unsharded_accumulated_grad: Optional[torch.Tensor] # ND + _sharding_spec: DTensorSpec + # DTensor attributes (only defined for DTensor `param`): + _tp_spec: DTensorSpec + all_gather_outputs: list[torch.Tensor] # 1D + # All-gather extension attributes + _extensions_data: ExtensionsData + _unsharded_inner_tensors: list[torch.Tensor] + + def __init__( + self, + param: nn.Parameter, + module_info: ParamModuleInfo, + mesh_info: FSDPMeshInfo, + post_forward_mesh_info: Optional[FSDPMeshInfo], + device: torch.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], + mp_policy: MixedPrecisionPolicy, + offload_policy: OffloadPolicy, + ): + self._module_info: ParamModuleInfo = module_info + self.mesh_info = mesh_info + self.post_forward_mesh_info = post_forward_mesh_info + self.device = device + self.mp_policy = mp_policy + self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy) + self.pin_memory = ( + self.offload_to_cpu and cast(CPUOffloadPolicy, offload_policy).pin_memory + ) + self.grad_offload_event: Optional[torch.Event] = None + self._init_sharded_param(param, device, shard_placement_fn) + if self.post_forward_mesh_info: + self._init_sharded_post_forward_param_metadata(param) + self._init_extensions() + self.all_gather_outputs: list[torch.Tensor] = [] + self.unsharded_accumulated_grad = None + self._param_fqn: Optional[str] = None # prefixed from root module + # TODO: Remove this padding logic once DTensor pads the local tensor: + # https://github.com/pytorch/pytorch/issues/113045 + self._post_load_hook_handle = ( + module_info.module.register_load_state_dict_post_hook( + lambda *args, **kwargs: self.reset_sharded_param() + ) + ) + + @torch.no_grad() + def _init_sharded_param( + self, + param: nn.Parameter, + device: torch.device, + shard_placement_fn: Optional[Callable], + ): + if param.device != device and param.device.type != "meta": + raise AssertionError( + f"Expects the parameter to already be moved to device {device} but got {param.device}" + ) + if not param.is_contiguous(): + raise NotImplementedError( + f"FSDP does not support non-contiguous parameters yet: {param.shape=} {param.stride()=}" + ) + fsdp_placement = shard_placement_fn(param) if shard_placement_fn else None + if fsdp_placement is None: + fsdp_placement = Shard(0) + elif fsdp_placement.dim < 0: + fsdp_placement = Shard(fsdp_placement.dim + param.ndim) + assert isinstance(fsdp_placement, Shard), f"{fsdp_placement}" + self.fsdp_placement = fsdp_placement + shard_dim = fsdp_placement.dim + # TODO: Replace the sharded DTensor parameter construction logic with + # `distribute_tensor` after https://github.com/pytorch/pytorch/issues/116101 + # TODO: Simplify the following sharded parameter padding logic after + # https://github.com/pytorch/pytorch/issues/113045 + self.is_dtensor = isinstance(param, DTensor) + if self.is_dtensor: + self._tp_spec = cast(DTensor, param)._spec + dp_mesh, tp_mesh = (self.mesh_info.mesh, self._tp_spec.mesh) + dp_global_mesh = _mesh_resources.get_root_mesh(dp_mesh) + tp_global_mesh = _mesh_resources.get_root_mesh(tp_mesh) + if dp_global_mesh != tp_global_mesh or ( + dp_global_mesh is None or tp_global_mesh is None + ): + raise AssertionError( + "FSDP requires the DP and TP mesh to have the same parent mesh but got: \n" + f"DP's global mesh: {dp_global_mesh}\nTP's global mesh: {tp_global_mesh}" + ) + name_dims_error = "FSDP requires named DeviceMesh dims for ND parallelism" + assert dp_mesh.mesh_dim_names is not None, name_dims_error + assert tp_mesh.mesh_dim_names is not None, name_dims_error + submesh_names = dp_mesh.mesh_dim_names + tp_mesh.mesh_dim_names + self._spmd_mesh = dp_global_mesh[submesh_names] + if len(self._tp_spec.placements) != 1: + raise NotImplementedError( + f"FSDP only supports 1D TP, not {self._tp_spec.placements}" + ) + split_factor = self._tp_spec.num_shards_map[shard_dim] + assert 2 <= self._spmd_mesh.ndim <= 3, ( + f"_spmd_mesh.ndim can only be 2 or 3 but got {self._spmd_mesh.ndim}." + ) + self._spmd_placements: tuple[Placement, ...] + dp_shard_tp_placement = ( + ( + _StridedShard(shard_dim, split_factor=split_factor) + if split_factor > 1 + else fsdp_placement + ), + self._tp_spec.placements[0], + ) + if self._spmd_mesh.ndim == 2: + self._spmd_placements = dp_shard_tp_placement + else: + assert self.mesh_info.replicate_mesh_dim == 0 + self._spmd_placements = (Replicate(),) + dp_shard_tp_placement + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=self._tp_spec.tensor_meta, + ) + param_data = cast(DTensor, param)._local_tensor + else: + self._spmd_mesh = self.mesh_info.mesh + if isinstance(self.mesh_info, HSDPMeshInfo): + self._spmd_placements = (Replicate(), fsdp_placement) + else: + self._spmd_placements = (fsdp_placement,) + self._sharding_spec = DTensorSpec( + self._spmd_mesh, + self._spmd_placements, + tensor_meta=TensorMeta(param.size(), param.stride(), param.dtype), + ) + param_data = param + assert param_data.is_contiguous(), f"{param_data.shape=} {param_data.stride()=}" + shard_dim = fsdp_placement.dim + if shard_dim >= param_data.ndim: + raise AssertionError( + f"Shard dim {shard_dim} is invalid for {param_data.ndim}D tensor: {param.shape}" + ) + self._orig_size = param_data.size() + self._contiguous_orig_stride = make_contiguous_strides_for(self._orig_size) + shard_rank = self.mesh_info.shard_mesh_rank + shard_world_size = self.mesh_info.shard_mesh_size + if shard_dim > 0 and param_data.size(shard_dim) % shard_world_size != 0: + # If sharding on nonzero dim, require even sharding for now because + # the uneven sharding (1) requires extra copies before/after FSDP + # collectives and (2) introduces extra complexity to handle padding + # and unpadding + raise NotImplementedError( + f"FSDP does not support uneven sharding on dim {shard_dim}: " + f"{param_data.size()} (world size: {shard_world_size})" + ) + chunks = _chunk_with_empty(param_data, shard_world_size, dim=shard_dim) + sharded_param = chunks[shard_rank] + self.sharded_size = _get_dim_chunked_size( + sharded_param, param_data.size(), dim=shard_dim + ) + self.contiguous_sharded_stride = make_contiguous_strides_for(self.sharded_size) + padded_sharded_size = chunks[0].size() # 0th always padded + self.padded_sharded_param_size = padded_sharded_size + # Pre-pad the sharded parameter to avoid padding before all-gather + padded_sharded_param = param_data.new_zeros(padded_sharded_size) + if sharded_param.numel() > 0: + padded_sharded_param.narrow( + dim=shard_dim, start=0, length=sharded_param.size(shard_dim) + ).copy_(sharded_param) + if self.offload_to_cpu and not padded_sharded_param.is_meta: + padded_sharded_param = padded_sharded_param.cpu() + if self.pin_memory: + padded_sharded_param = padded_sharded_param.pin_memory( + device=self.device + ) + self._sharded_param_data = padded_sharded_param.view(-1) + length = sharded_param.size(shard_dim) if sharded_param.numel() > 0 else 0 + sharded_param = padded_sharded_param.narrow( + dim=shard_dim, start=0, length=length + ) + assert sharded_param.is_contiguous(), f"{self.fsdp_placement=}" + self.sharded_param = nn.Parameter(self.to_sharded_dtensor(sharded_param)) + self.sharded_param.requires_grad_(param.requires_grad) + # Let `param_data` be freed normally when its ref count reaches 0 when + # the `fully_shard` call returns to allow provided parameters to alias + self._setattr_on_modules(self.sharded_param) + self.sharded_state = ShardedState.SHARDED + + def _init_sharded_post_forward_param_metadata(self, param: torch.Tensor) -> None: + mesh_info = self.post_forward_mesh_info + assert mesh_info is not None # mypy + param_data = param._local_tensor if isinstance(param, DTensor) else param + chunks = _chunk_with_empty(param_data, mesh_info.shard_mesh_size, dim=0) + self.sharded_post_forward_size = _get_dim_chunked_size( + chunks[mesh_info.shard_mesh_rank], + param_data.size(), + dim=self.fsdp_placement.dim, + ) + self.contiguous_sharded_post_forward_stride = make_contiguous_strides_for( + self.sharded_post_forward_size + ) + + def init_dtype_attrs(self, mp_policy: MixedPrecisionPolicy): + param_dtype, reduce_dtype = (mp_policy.param_dtype, mp_policy.reduce_dtype) + self.orig_dtype = self.sharded_param.dtype + # Clamp `reduce_dtype` to `None` if no casting is required: since + # gradients are computed in `param_dtype`, if `reduce_dtype` matches, + # then we do not need extra casting + if reduce_dtype == param_dtype: + reduce_dtype = None + # Clamp `param_dtype` to `None` if no casting is required + if param_dtype == self.orig_dtype: + param_dtype = None + self.param_dtype = param_dtype + self.reduce_dtype = reduce_dtype + # None indicates that the mixed precision is not enabled + + def _init_extensions(self) -> None: + inner_tensor = self._sharded_local_tensor + has_fsdp_pre_all_gather = hasattr(inner_tensor, "fsdp_pre_all_gather") + has_fsdp_post_all_gather = hasattr(inner_tensor, "fsdp_post_all_gather") + if has_fsdp_pre_all_gather != has_fsdp_post_all_gather: + raise AssertionError( + "Both fsdp_pre_all_gather and fsdp_post_all_gather should be defined " + f"if using all-gather extensions: {inner_tensor}" + ) + if has_fsdp_pre_all_gather: + self._extensions_data = ExtensionsData() + self._unsharded_inner_tensors: list[torch.Tensor] = [] + + def init_all_gather_outputs( + self, + all_gather_input_numels: list[int], + all_gather_input_dtypes: list[torch.dtype], + world_size: int, + device: torch.device, + force_recreate: bool = False, + ): + if not force_recreate and len(self.all_gather_outputs) > 0: + return # already initialized + self.all_gather_outputs = [ + torch.empty(torch.Size([numel * world_size]), dtype=dtype, device=device) + for numel, dtype in zip(all_gather_input_numels, all_gather_input_dtypes) + ] + + def init_unsharded_param(self): + """ + [Note: Invariants for torch.compile Traceable FSDP2] + 1. Under compile, we always re-populate the content of `self._unsharded_param` + per AllGather using the slow path. + 2. Under compile, we always recreate `self.all_gather_outputs` per AllGather. + This is to ensure the buffer creation is internal to the graph and + avoid `self.all_gather_outputs` being captured as a graph input. + 3. Under compile, at the end of `free_unsharded_param()`, we always clean up + `self.all_gather_outputs` and `self._unsharded_inner_tensors`, + to avoid them being captured as graph output. + + With these invariants, only these tensors will be inputs to the graph: + - Sharded parameters + - Placeholders for the `self._unsharded_param` nn.Parameter + """ + if not compiled_autograd_enabled() and hasattr( + self, "_unsharded_param" + ): # after the 1st all-gather + inner_tensor = self._sharded_local_tensor + if not hasattr(inner_tensor, "fsdp_post_all_gather"): + return # already initialized + for tensor in self._unsharded_inner_tensors: + alloc_storage(tensor) + all_gather_outputs = self._unflatten_all_gather_outputs() + inner_tensor.fsdp_post_all_gather( + all_gather_outputs, + self._extensions_data.all_gather_metadata, + self.param_dtype or self.orig_dtype, + out=self._unsharded_param, + ) + self._extensions_data.clear() + return + inner_tensor = self._sharded_local_tensor + if not compiled_autograd_enabled() and hasattr( + inner_tensor, "fsdp_post_all_gather" + ): + all_gather_outputs = self._unflatten_all_gather_outputs() + ( + unsharded_tensor, + self._unsharded_inner_tensors, + ) = inner_tensor.fsdp_post_all_gather( + all_gather_outputs, + self._extensions_data.all_gather_metadata, + self.param_dtype or self.orig_dtype, + ) + self._extensions_data.clear() + else: + # For the default path (no post-all-gather), the all-gather output + # gives the unsharded parameter data directly + assert len(self.all_gather_outputs) == 1, f"{len(self.all_gather_outputs)}" + unsharded_tensor = self.all_gather_outputs[0] + unsharded_param = torch.as_strided( + unsharded_tensor, + self._orig_size, + self._contiguous_orig_stride, + storage_offset=0, + ) + if self.is_dtensor: + unsharded_param = _from_local_no_grad(unsharded_param, self._tp_spec) + if hasattr(self, "_unsharded_param"): + assert compiled_autograd_enabled() + with ( + torch.no_grad(), + torch.autograd._unsafe_preserve_version_counter(self._unsharded_param), + ): + # NOTE: Under compile, if an unsharded param goes through + # resize_(full) -> copy_ -> resize_(0) pattern, we will remove those + # resize_ and copy_ ops in a compiler graph pass + # `remove_fsdp2_unsharded_param_graph_input_usage` to recover performance. + self._unsharded_param.untyped_storage().resize_( + self._unsharded_param.numel() * self._unsharded_param.itemsize + ) + torch.ops.fsdp.copy_(self._unsharded_param, unsharded_param) + else: + self._unsharded_param = nn.Parameter( + unsharded_param, requires_grad=self.sharded_param.requires_grad + ) + + def _unflatten_all_gather_outputs(self) -> tuple[torch.Tensor, ...]: + return tuple( + t.view(-1, *s[1:]) + for t, s in zip( + self.all_gather_outputs, self._extensions_data.all_gather_input_sizes + ) + ) + + def to_sharded(self) -> None: + self._setattr_on_modules(self.sharded_param) + self.free_unsharded_param() + self.sharded_state = ShardedState.SHARDED + + def to_sharded_post_forward(self) -> None: + if self.is_dtensor: + raise NotImplementedError( + "Resharding to smaller mesh with TP is not supported yet" + ) + self._assert_in_states(ShardedState.UNSHARDED) + assert self.post_forward_mesh_info is not None # mypy + assert len(self.all_gather_outputs) == 1 + shard_world_size = self.post_forward_mesh_info.shard_mesh_size + if (numel := self.all_gather_outputs[0].numel()) % shard_world_size != 0: + _raise_assert_with_print( + f"All-gather output size ({numel}) must be divisible by the shard " + f"world size ({shard_world_size})" + ) + shard_rank = self.post_forward_mesh_info.shard_mesh_rank + sharded_numel = numel // shard_world_size + self._sharded_post_forward_param_data = ( + self.all_gather_outputs[0].narrow( + 0, sharded_numel * shard_rank, sharded_numel + ) + ).clone() # clone to be able to free all-gather output + sharded_post_forward_tensor = torch.as_strided( + self._sharded_post_forward_param_data, + size=self.sharded_post_forward_size, + stride=self.contiguous_sharded_post_forward_stride, + storage_offset=0, + ) + self._sharded_post_forward_param = nn.Parameter( + self.to_sharded_post_forward_dtensor(sharded_post_forward_tensor) + ) + self._setattr_on_modules(self._sharded_post_forward_param) + self.free_unsharded_param() + self.sharded_state = ShardedState.SHARDED_POST_FORWARD + + def to_unsharded(self) -> None: + # Assume that the data has been allocated and all-gathered + set_requires_grad_if_needed(self.sharded_param, self._unsharded_param) + self._setattr_on_modules(self._unsharded_param) + if self.sharded_state == ShardedState.SHARDED_POST_FORWARD: + # The data is allocated in the default stream via the post-forward + # reshard and must be kept alive for the next all-gather copy-in. + # Since we call this method after the copy-out, the data's lifetime + # is ensured without further synchronization. + self._sharded_post_forward_param = None + self._sharded_post_forward_param_data = None # free + self.sharded_state = ShardedState.UNSHARDED + + def _setattr_on_modules(self, param: nn.Parameter) -> None: + unsafe_setattr_param( + self._module_info.module, self._module_info.param_name, param + ) + for shared_module, shared_param_name in zip( + self._module_info.shared_modules, self._module_info.shared_param_names + ): + unsafe_setattr_param(shared_module, shared_param_name, param) + + def to_sharded_dtensor(self, tensor: torch.Tensor) -> DTensor: + """ + Converts a local tensor representing either the sharded parameter or + sharded gradient to DTensor. + """ + if tensor.shape != self.sharded_size: + _raise_assert_with_print( + f"Expects size {self.sharded_size} but got {tensor.shape}" + ) + return _from_local_no_grad( + tensor, + self._sharding_spec, + ) + + def to_sharded_post_forward_dtensor(self, tensor: torch.Tensor) -> DTensor: + if tensor.shape != self.sharded_post_forward_size: + _raise_assert_with_print( + f"Expects size {self.sharded_post_forward_size} but got {tensor.shape}" + ) + assert isinstance(self.post_forward_mesh_info, HSDPMeshInfo) + # TODO: Prefer this DTensor to be read-only and generalize the + # placement once we support TP. + post_forward_sharding_spec = DTensorSpec( + self.post_forward_mesh_info.mesh, + (Replicate(), Shard(0)), + tensor_meta=self._sharding_spec.tensor_meta, + ) + return _from_local_no_grad(tensor, post_forward_sharding_spec) + + def to_accumulated_grad_if_needed(self) -> None: + # Access `_unsharded_param` to bypass the sharded state check since we + # prefer to reshard before upcasting the gradient to save memory + if ( + self.reduce_dtype is None + or self._unsharded_param.grad is None + or self._unsharded_param.grad.dtype == self.reduce_dtype + ): + return + unsharded_grad = self._unsharded_param.grad + self._unsharded_param.grad = None + self.unsharded_accumulated_grad = unsharded_grad.to(self.reduce_dtype) + + def accumulate_unsharded_grad_if_needed(self) -> None: + if ( + self.unsharded_accumulated_grad is not None + and self.unsharded_param.grad is not None + ): + self.unsharded_accumulated_grad += self.unsharded_param.grad + self.unsharded_param.grad = None + + def alloc_all_gather_outputs(self) -> None: + for tensor in self.all_gather_outputs: + alloc_storage(tensor) + + def free_unsharded_param(self) -> None: + if compiled_autograd_enabled(): + """ + Assumptions under compile: + - `self._unsharded_param` is NOT an alias of `self.all_gather_outputs`. + Instead, we resize `self._unsharded_param` storage size to full and then + explicitly *copy* the data from `self.all_gather_outputs` to `self._unsharded_param` + in `init_unsharded_param()`. (For full-graph FSDP2 case, we will then remove + the resize_ and copy_ ops in a compiler graph pass to recover performance.) + - `self.all_gather_outputs` and `self._unsharded_inner_tensors` are NOT + graph inputs. They are created within the graph and is guaranteed to be freed + by the end of the graph. They don't leak outside of the graph. + """ + self._unsharded_param.untyped_storage().resize_(0) + self.all_gather_outputs = [] + self._unsharded_inner_tensors = [] + else: + for tensor in itertools.chain( + self.all_gather_outputs, self._unsharded_inner_tensors + ): + free_storage(tensor) + + @property + def all_gather_inputs(self) -> list[torch.Tensor]: # 1D + self._assert_in_states(ShardedState.SHARDED, ShardedState.SHARDED_POST_FORWARD) + if self.sharded_state == ShardedState.SHARDED: + if not compiled_autograd_enabled() and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): + sharded_local_tensor = self._sharded_local_tensor + if self.offload_to_cpu: + sharded_local_tensor = sharded_local_tensor.to( + self.device, non_blocking=True + ) + pre_all_gather_signature = inspect.signature( + sharded_local_tensor.fsdp_pre_all_gather + ) + num_fn_params = len(pre_all_gather_signature.parameters) + # Old signature only passes mesh; keep for BC for now + assert num_fn_params in ( + 1, + 5, + ), ( + f"Invalid fsdp_pre_all_gather: {pre_all_gather_signature}\n" + "Expects fsdp_pre_all_gather(self, mesh: DeviceMesh, " + "module: nn.Module, mp_policy: MixedPrecisionPolicy)" + ) + if num_fn_params == 1: + ( + all_gather_inputs, + self._extensions_data.all_gather_metadata, + ) = sharded_local_tensor.fsdp_pre_all_gather( + self.shard_mesh_from_root + ) + else: + ( + all_gather_inputs, + self._extensions_data.all_gather_metadata, + ) = sharded_local_tensor.fsdp_pre_all_gather( + self.shard_mesh_from_root, + self._orig_size, + self._contiguous_orig_stride, + self._module_info.module, + self.mp_policy, + ) + if ( + sharded_local_tensor.size() != self.padded_sharded_param_size + and any( + all_gather_input.size() != self.padded_sharded_param_size + for all_gather_input in all_gather_inputs + ) + ): + # NOTE: Since this error can only be raised on the + # ranks that have padding, this can manifest as a NCCL + # watchdog timeout, as the other ranks will not error. + raise AssertionError( + "When a parameter is unevenly sharded by FSDP " + f"(orig size={self._orig_size}, FSDP world size={self.mesh_info.mesh.size()}), " + "fsdp_pre_all_gather must return all-gather inputs with the padded sharded size " + f"{self.padded_sharded_param_size} but got {[t.size() for t in all_gather_inputs]}" + ) + self._extensions_data.all_gather_input_sizes = [ + t.size() for t in all_gather_inputs + ] + return [t.view(-1) for t in all_gather_inputs] + sharded_param_data = self._sharded_param_data + if self.offload_to_cpu: + sharded_param_data = sharded_param_data.to( + self.device, non_blocking=True + ) + return [_to_dtype_if_needed(sharded_param_data, self.param_dtype)] + elif self.sharded_state == ShardedState.SHARDED_POST_FORWARD: + if not compiled_autograd_enabled() and hasattr( + self._sharded_local_tensor, "fsdp_pre_all_gather" + ): + raise NotImplementedError + all_gather_input = _to_dtype_if_needed( + cast(torch.Tensor, self._sharded_post_forward_param_data), + self.param_dtype, + ) + return [all_gather_input] + return [torch.empty(0)] # mypy + + @property + def unsharded_param(self) -> nn.Parameter: # ND + return self._unsharded_param + + @property + def unsharded_grad_data(self) -> torch.Tensor: + grad = self.unsharded_param.grad + assert grad is not None, "Expects unsharded_param.grad to not be None" + return self._get_grad_inner_tensor(grad) + + @property + def unsharded_accumulated_grad_data(self) -> torch.Tensor: + grad = self.unsharded_accumulated_grad + assert grad is not None, "Expects unsharded_accumulated_grad to not be None" + return self._get_grad_inner_tensor(grad) + + def _get_grad_inner_tensor(self, grad: torch.Tensor) -> torch.Tensor: + if self.is_dtensor: + if isinstance(grad, AsyncCollectiveTensor): + grad = grad.wait() + assert isinstance(grad, DTensor), f"{type(grad)}" + placements = self._tp_spec.placements + if placements != grad.placements: + assert len(self._tp_spec.placements) == len(grad.placements), ( + f"{self._tp_spec=} {grad.placements=}" + ) + grad = grad.redistribute(placements=placements) + grad = grad._local_tensor + return grad + + @property + def _sharded_local_tensor(self) -> torch.Tensor: + return cast(DTensor, self.sharded_param)._local_tensor + + @property + def shard_mesh(self): + mesh = self.mesh_info.mesh + if mesh.ndim == 1: + return mesh + elif mesh.ndim == 2: + assert mesh.mesh_dim_names is not None + return mesh[mesh.mesh_dim_names[-1]] + raise ValueError(f"Invalid mesh: {mesh}") + + @property + def shard_mesh_from_root(self): + mesh = self.mesh_info.mesh + + if mesh.ndim == 1: + return mesh + else: + assert mesh.mesh_dim_names is not None + shard_dim_name = mesh.mesh_dim_names[-1] + + root_mesh = _mesh_resources.get_root_mesh(mesh) + return root_mesh[shard_dim_name] + + def _assert_in_states(self, *states: ShardedState) -> None: + if self.sharded_state not in states: + _raise_assert_with_print( + f"Expects to be in one of {states}, not {self.sharded_state}" + ) + + def reset_sharded_param(self): + # For ops like `nn.Module._apply` or `load_state_dict(assign=True)` + # that change the sharded parameter tensor, we may need to re-pad the + # sharded local tensor and re-save the reference. + module_info = self._module_info + new_param = getattr(module_info.module, module_info.param_name) + if new_param is not self.sharded_param: + if torch.__future__.get_swap_module_params_on_conversion(): + raise AssertionError( + f"Expects swap_tensors to preserve object but got {new_param} " + f"instead of {self.sharded_param}" + ) + self.sharded_param = new_param + local_tensor = new_param._local_tensor + if local_tensor.is_meta: + return + updated_local_tensor = False + padded_sharded_size = self.padded_sharded_param_size + shard_dim = self.fsdp_placement.dim + length = local_tensor.size(shard_dim) if local_tensor.numel() > 0 else 0 + if local_tensor.size() != padded_sharded_size: + assert shard_dim == 0, ( + f"Shard({shard_dim}) requires even sharding: {local_tensor.size()=}" + ) + padded_local_tensor = local_tensor.new_zeros(padded_sharded_size) + padded_local_tensor.narrow(dim=shard_dim, start=0, length=length).copy_( + local_tensor + ) + local_tensor = padded_local_tensor + updated_local_tensor = True + if self.pin_memory and not local_tensor.is_pinned(): + local_tensor = local_tensor.cpu().pin_memory(device=self.device) + updated_local_tensor = True + self._sharded_param_data = local_tensor.view(-1) + assert isinstance(self.sharded_param, DTensor) # mypy + if updated_local_tensor: + # Only change the local tensor object if needed + self.sharded_param._local_tensor = local_tensor.narrow( + dim=shard_dim, start=0, length=length + ) + assert self.sharded_param._local_tensor.is_contiguous() + self._sharding_spec = self.sharded_param._spec + + def __repr__(self): + return f"FSDPParam(fqn={self._param_fqn}, orig_size={self._orig_size})" + + +def alloc_storage(tensor: torch.Tensor) -> None: + size = tensor.numel() * tensor.itemsize + if (storage := tensor.untyped_storage()).size() != size: + storage.resize_(size) + + +def free_storage(tensor: torch.Tensor) -> None: + if (storage := tensor.untyped_storage()).size() != 0: + storage.resize_(0) + + +# NOTE: These bypass `nn.Module.__setattr__` checks, which incur non-trivial +# CPU overhead, if the module did not override it. For FSDP, we know we do not +# need those checks when transitioning between sharded/unsharded parameters. +def unsafe_setattr_param( + module: nn.Module, param_name: str, param: nn.Parameter +) -> None: + if getattr(module.__setattr__, "__func__", None) is nn.Module.__setattr__: + module._parameters[param_name] = param + else: # slow path + setattr(module, param_name, param) + + +def set_requires_grad_if_needed( + src_tensor: torch.Tensor, dst_tensor: torch.Tensor +) -> None: + # Only call `requires_grad_` if needed to avoid the Python <> C++ context + # switch overhead + if src_tensor.requires_grad != dst_tensor.requires_grad: + dst_tensor.requires_grad_(src_tensor.requires_grad) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py new file mode 100644 index 0000000000000000000000000000000000000000..0d113a66b06710b9cbe5626fb324bb49197583dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -0,0 +1,769 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +from typing import Any, Callable, cast, NamedTuple, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.fsdp._common_utils import _named_parameters_with_duplicates +from torch.distributed.tensor import Shard +from torch.profiler import record_function +from torch.utils._pytree import tree_flatten, tree_unflatten +from torch.utils.hooks import RemovableHandle + +from ._fsdp_api import CPUOffloadPolicy, MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_collectives import ( + AllGatherResult, + foreach_all_gather, + foreach_all_gather_copy_out, + foreach_reduce, +) +from ._fsdp_common import ( + compiled_autograd_enabled, + FSDPMeshInfo, + HSDPMeshInfo, + TrainingState, +) +from ._fsdp_param import FSDPParam, ParamModuleInfo, ShardedState + + +logger = logging.getLogger("torch.distributed.fsdp.fully_shard") + +_ModuleToHandleDict = dict[nn.Module, RemovableHandle] # for state dict + + +""" +[Note: Overlapping all-gather copy-in and all-gather] +For implicit forward prefetching, we want to overlap the next copy-in with the +current all-gather. We do so using a separate copy-in stream. However, since +we have the all-gather input as a view into the output, we must make sure to +copy into different memory from the current all-gather's output. Thus, we keep +a reference to the current all-gather's output and have the next FSDP parameter +group free it after its copy-in. Finally, we have the last FSDP state flush the +reference to avoid holding onto memory after forward. +""" + + +class FSDPCommContext: + """This has the communication state shared across FSDP states/parameter groups.""" + + def lazy_init(self, device: torch.device): + self.device_handle = _get_device_handle(device.type) + # Setting the all-gather/reduce-scatter streams to be higher priority + # can help avoid some issues where their copies in/out are delayed and + # block computation (this is different from high-pri NCCL streams) + high_priority = -1 + # All-gather state and copy-in stream allow overlapping the next + # copy-in with the current all-gather in forward; copy-in overlaps with + # reduce-scatter in backward without the separate copy-in stream + self.all_gather_copy_in_stream = self.device_handle.Stream( + priority=high_priority + ) + # All-gather stream allows overlapping next all-gather with current + # forward compute + self.all_gather_stream = self.device_handle.Stream(priority=high_priority) + # Reduce-scatter stream gives separate execution "thread" for post- + # backward logic like pre/post-gradient division and reduce-scatter + self.reduce_scatter_stream = self.device_handle.Stream(priority=high_priority) + # Run the HSDP all-reduces concurrently with all-gather/reduce-scatter + # since collectives use different network resources and can overlap + # in the typical intra-node sharding / inter-node replication case + self.all_reduce_stream = self.device_handle.Stream() + # All-gather/reduce-scatter states keep references to collective + # tensors produced in one stream and used in another and accompanying + # CUDA events for synchronization + self.all_gather_state: Optional[AllGatherState] = None + self.reduce_scatter_state: Optional[ReduceScatterState] = None + # Post-forward order for explicit backward prefetching + self.post_forward_order: list[FSDPParamGroup] = [] # will cause ref cycles + + def get_all_gather_streams( + self, async_op: bool, training_state: TrainingState + ) -> tuple[torch.Stream, torch.Stream]: + if not async_op and training_state in ( + TrainingState.FORWARD, + TrainingState.PRE_BACKWARD, + ): + # Use separate streams for implicit prefetching + return self.all_gather_copy_in_stream, self.all_gather_stream + current_stream = self.device_handle.current_stream() + return current_stream, current_stream + + +# See [Note: Overlapping all-gather copy-in and all-gather] +class AllGatherState(NamedTuple): + all_gather_result: AllGatherResult + event: Optional[torch.Event] # all-gather copy-out + + +class ReduceScatterState(NamedTuple): + reduce_scatter_input: torch.Tensor + event: Optional[torch.Event] # reduce-scatter event + + +class AllReduceState(NamedTuple): + all_reduce_input: torch.Tensor + event: Optional[torch.Event] # all-reduce event + + +class FSDPParamGroup: + """This class represents a parameter group to communicate together.""" + + _orig_dtype: Optional[torch.dtype] + _reduce_dtype: Optional[torch.dtype] + + def __init__( + self, + params: list[nn.Parameter], + modules: tuple[nn.Module, ...], + mesh_info: FSDPMeshInfo, + post_forward_mesh_info: Optional[FSDPMeshInfo], + device: torch.device, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]], + mp_policy: MixedPrecisionPolicy, + offload_policy: OffloadPolicy, + ): + self.modules = modules # permit ref cycle because 1:1 lifetime + param_module_infos = _get_param_module_infos(params, modules) + + self.fsdp_params = [ + FSDPParam( + param, + module_info, + mesh_info, + post_forward_mesh_info, + device, + shard_placement_fn, + mp_policy, + offload_policy, + ) + for param, module_info in zip(params, param_module_infos) + ] + self.mesh_info = mesh_info + self.post_forward_mesh_info = post_forward_mesh_info + self.device = device + self.device_handle = _get_device_handle(device.type) + self.mp_policy = mp_policy + self.offload_policy = offload_policy + self._training_state = TrainingState.IDLE + # Group's sharded state always matches its parameters' sharded states + self._sharded_state = ShardedState.SHARDED + self._module_fqn: Optional[str] = None # prefixed from root module + # Only consider resetting sharded parameters once in lazy init since it + # can incur nontrivial overhead to reset them + self._reset_sharded_params: bool = False + + # - Hook state + self._module_to_pre_save_state_dict_hook_handle: _ModuleToHandleDict = {} + self._module_to_pre_load_state_dict_hook_handle: _ModuleToHandleDict = {} + self._all_reduce_hook: Optional[Callable[[torch.Tensor], None]] = None + # Optional stream to run the user-defined all-reduce hook in + # Saved here and not in the comm. context because we allow the user to + # specify it, possibly at construction time before lazy init + self._all_reduce_hook_stream: Optional[torch.cuda.Stream] = None + + # - Communication and communication/computation overlap + self.comm_ctx = FSDPCommContext() + # Group's indices in the shared post-forward order + self._post_forward_indices: list[int] = [] + # Whether to reduce gradients at all (whether for FSDP or HSDP) + self.reduce_grads: bool = True + # Whether to all-reduce gradients for HSDP; only used if + # `self.reduce_grads` is true, in which case setting this to false + # means reduce-scatter but no all-reduce + self.all_reduce_grads: bool = True + # Whether to reshard parameters after backward (only useful for + # gradient accumulation) + self.reshard_after_backward: bool = True + # Optional custom factor for the gradient reduction op (e.g. to divide + # by a factor other than the world size) + self.gradient_divide_factor: Optional[float] = None + # Whether reduce-scatter and all-reduce should be issued using only + # summations, potentially with separate pre-/post-scaling. + self.force_sum_reduction_for_comms: bool = False + # `async_op` arg used for pre-forward/pre-backward unshard; can be + # overridden to only do explicit prefetching and avoid inter-stream + # fragmentation from using separate unshard streams + self.unshard_async_op: bool = False + # Whether to unshard in backward: can be overridden by the user if the + # parameters in this group are not needed for backward (e.g. embedding) + self.unshard_in_backward: bool = True + # Whether to (try to) use the ProcessGroup's allocate_tensor method for + # the staging buffers for collective comms. + self.allocate_memory_from_process_group = False + + # - CUDA events for stream synchronization + # Holds the all-gather output buffer, sync objects, and metadata + self._all_gather_result: Optional[AllGatherResult] = None + # Holds the reduce-scatter/all-reduce view-out CUDA event that marks the end of + # the group's post-backward (e.g. reduce-scatter, all-reduce and div), which + # should be waited on at the end of backward + self._post_reduce_event: Optional[torch.Event] = None + # Holds the reshard-after-forward CUDA event when resharding to a + # different world size, which should be waited on in the next unshard + self._reshard_after_forward_event: Optional[torch.Event] = None + + # Only for HSDP, if accumulating gradients without all-reduce, save the + # partial reduce output (only reduce-scattered but not all-reduced) + self._partial_reduce_output: Optional[torch.Tensor] = None + # Holds the all-reduce input and all-reduce event to keep it alive + # until the end of backward (critical when doing bf16 reduction with + # fp32 parameters since the all-reduce input is allocated in the RS + # stream and will have no refs to it after being upcast to fp32) + self._all_reduce_state: Optional[AllReduceState] = None + + # Initialization # + def _init_mp_dtypes(self) -> None: + for fsdp_param in self.fsdp_params: + fsdp_param.init_dtype_attrs(self.mp_policy) + trainable_params: list[FSDPParam] = [ + p for p in self.fsdp_params if p.sharded_param.requires_grad + ] + orig_dtypes = {p.orig_dtype for p in trainable_params} + reduce_dtypes = {p.reduce_dtype for p in trainable_params} + if len(trainable_params) > 0 and len(orig_dtypes) != 1: + # Models may have no grad params + raise AssertionError( + f"FSDP expects uniform original parameter dtype but got {orig_dtypes}" + ) + self._orig_dtype = next(iter(orig_dtypes)) if len(trainable_params) else None + if len(trainable_params) > 0 and len(reduce_dtypes) != 1: + # This can be relaxed if we issue one reduce-scatter per reduce + # dtype (but we would need a way for users to specify multiple + # reduce dtypes) + raise AssertionError( + f"FSDP expects uniform reduce dtype but got {reduce_dtypes}" + ) + self._reduce_dtype = ( + next(iter(reduce_dtypes)) if len(trainable_params) else None + ) + + def lazy_init(self): + # Lazy init should be idempotent + # Users may change or register parameters after construction time. + # For example, DoRA (https://arxiv.org/abs/2402.09353) initializes linear magnitudes based on + # other parameters (e.g. loaded from the state dict). + if not hasattr(self.comm_ctx, "device_handle"): + self.comm_ctx.device_handle = _get_device_handle(self.device.type) + if self.is_sharded and not self._reset_sharded_params: + for fsdp_param in self.fsdp_params: + fsdp_param.reset_sharded_param() + fsdp_param._init_extensions() # allow monkey patch after init + self._reset_sharded_params = True + self._validate_no_meta_params() + self._validate_cpu_offload_params() + # Initialize mixed precision attributes lazily in case the user changes + # the parameter dtypes after construction time but before forward + self._init_mp_dtypes() + self._register_state_dict_hooks() + + # Runtime # + def unshard(self, async_op: bool = False): + if self._all_gather_result is not None: # already called, pending wait + return + if self.is_unsharded: + return # no-op + if ( + not self.unshard_in_backward + and self._training_state == TrainingState.PRE_BACKWARD + ): + return + if self._reshard_after_forward_event is not None: + # Resharded parameter data is allocated in the default stream and + # used in the all-gather streams + self._wait_all_gather_streams_on_event(self._reshard_after_forward_event) + self._reshard_after_forward_event = None + with record_function(self._with_fqn("FSDP::all_gather")): + self._all_gather_result = foreach_all_gather( + self.fsdp_params, + self._all_gather_process_group, + async_op, + *self.comm_ctx.get_all_gather_streams(async_op, self._training_state), + self.device, + self.allocate_memory_from_process_group, + ) + + def wait_for_unshard(self): + """ + 1. In forward with implicit prefetching, to overlap the current copy-out + with the next all-gather, we save a reference to the current all-gather + result to free after the next copy-out. + 2. Otherwise (explicit prefetching or in backward), we free the + all-gather result immediately after the current copy-out since we can + already overlap the current copy-out with the previous reduce-scatter. + """ + if not self._all_gather_result: + return # no preceding unshard + async_op = self._all_gather_result.all_gather_work is not None + if self._training_state == TrainingState.FORWARD: # implicit prefetch + if prev_all_gather_state := self.comm_ctx.all_gather_state: + self._wait_all_gather_streams_on_event(prev_all_gather_state.event) + self.comm_ctx.all_gather_state = None # free the all-gather result + with record_function(self._with_fqn("FSDP::all_gather_copy_out")): + foreach_all_gather_copy_out( + self._all_gather_result, + self.fsdp_params, + self._all_gather_process_group, + ) + for fsdp_param in self.fsdp_params: + fsdp_param.init_unsharded_param() + self._to_unsharded() + all_gather_copy_out_event = self.device_handle.Event() + all_gather_copy_out_event.record() + if not async_op and self._training_state == TrainingState.FORWARD: + # Defer free to allow for overlap of this copy-out with next + # all-gather collective + self.comm_ctx.all_gather_state = AllGatherState( + self._all_gather_result, all_gather_copy_out_event + ) + else: + self._wait_all_gather_streams_on_event(all_gather_copy_out_event) + self._all_gather_result = None # free unless saved in `all_gather_state` + + def _wait_all_gather_streams_on_event(self, event: Optional[torch.Event]): + # Calling `unshard` before lazy init means streams are not initialized + if hasattr(self.comm_ctx, "all_gather_copy_in_stream") and event is not None: + self.comm_ctx.all_gather_copy_in_stream.wait_event(event) + if hasattr(self.comm_ctx, "all_gather_stream") and event is not None: + self.comm_ctx.all_gather_stream.wait_event(event) + + def reshard(self): + if self._training_state == TrainingState.FORWARD: + if not self._reshard_after_forward: + return + if self._use_post_forward_mesh: + self._to_sharded_post_forward() + self._reshard_after_forward_event = self.device_handle.Event() + if self._reshard_after_forward_event is not None: + self._reshard_after_forward_event.record() + return + self._to_sharded() + + def pre_forward( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::pre_forward")) + with record_function(self._with_fqn("FSDP::pre_forward")): + self._training_state = TrainingState.FORWARD + self.unshard(self.unshard_async_op) + self.wait_for_unshard() + args, kwargs = self._register_post_backward_hook(args, kwargs) + return args, kwargs + + def post_forward(self, module: nn.Module, input: Any, output: Any): + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::post_forward")) + with record_function(self._with_fqn("FSDP::post_forward")): + self.reshard() + self._record_post_forward() + self._training_state = TrainingState.IDLE + return output + + def _record_post_forward(self) -> None: + # Since a group has one pre-backward unshard for each forward call + # before the backward, we record each usage (with multiplicity) + post_forward_index = len(self.comm_ctx.post_forward_order) + self.comm_ctx.post_forward_order.append(self) + self._post_forward_indices.append(post_forward_index) + + def pre_backward(self, default_prefetch: bool, *unused: Any): + if ( + compiled_autograd_enabled() + and self._training_state == TrainingState.PRE_BACKWARD + ): + # Traceable FSDP2 cannot trigger the param group's `post_backward` immediately after param usage; + # instead it relies on this to trigger the previously unexecuted `post_backward`. + self.post_backward() + if self._training_state == TrainingState.PRE_BACKWARD: + return + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::pre_backward")) + with record_function(self._with_fqn("FSDP::pre_backward")): + self._training_state = TrainingState.PRE_BACKWARD + self.unshard(self.unshard_async_op) # no-op if prefetched + self.wait_for_unshard() + if default_prefetch and not compiled_autograd_enabled(): + self._backward_prefetch() + + def post_backward(self, *unused: Any): + # This method should be idempotent and safe to call even when this + # FSDP parameter group was not used in backward (should be a no-op) + if not compiled_autograd_enabled(): + logger.debug("%s", self._with_fqn("FSDP::post_backward")) + self._training_state = TrainingState.POST_BACKWARD + with record_function(self._with_fqn("FSDP::post_backward_accumulate")): + for fsdp_param in self.fsdp_params: + fsdp_param.accumulate_unsharded_grad_if_needed() + with record_function(self._with_fqn("FSDP::post_backward_reshard")): + if not self.reduce_grads: + if self.reshard_after_backward: + self.reshard() + for fsdp_param in self.fsdp_params: + fsdp_param.to_accumulated_grad_if_needed() + return + # Save the autograd-computed gradients before resharding to only + # access the unsharded parameters when their data is present + fsdp_params_with_grad: list[FSDPParam] = [] + unsharded_grads: list[torch.Tensor] = [] + for fsdp_param in self.fsdp_params: + if not hasattr(fsdp_param, "_unsharded_param"): + continue + # May have an accumulated gradient of the reduce dtype if the + # previous backward did not reduce-scatter + if fsdp_param.unsharded_accumulated_grad is not None: + fsdp_params_with_grad.append(fsdp_param) + unsharded_grads.append(fsdp_param.unsharded_accumulated_grad_data) + fsdp_param.unsharded_accumulated_grad = None + elif fsdp_param.unsharded_param.grad is not None: + fsdp_params_with_grad.append(fsdp_param) + unsharded_grads.append(fsdp_param.unsharded_grad_data) + fsdp_param.unsharded_param.grad = None + if self.reshard_after_backward: + self.reshard() + if len(fsdp_params_with_grad) == 0: + return + with record_function(self._with_fqn("FSDP::post_backward_reduce")): + if ( + self.comm_ctx.reduce_scatter_state is not None + and self.comm_ctx.reduce_scatter_state.event is not None + ): + self.device_handle.current_stream().wait_event( + self.comm_ctx.reduce_scatter_state.event + ) + self.comm_ctx.reduce_scatter_state = None + all_reduce_pg = self._all_reduce_process_group if self._is_hsdp else None + all_reduce_stream: torch.cuda.Stream + if all_reduce_pg is None and self._all_reduce_hook_stream is not None: + # this means the native HSDP is not enabled, + # but user may want to have a custom HSDP setup + assert self._all_reduce_hook is not None, ( + "all reduce hook stream is specified but hook itself is missing." + ) + all_reduce_stream = self._all_reduce_hook_stream + else: + all_reduce_stream = self.comm_ctx.all_reduce_stream + + self._wait_for_post_backward() + ( + reduce_scatter_input, + reduce_scatter_event, + self._post_reduce_event, + all_reduce_input, + all_reduce_event, + self._partial_reduce_output, + ) = foreach_reduce( + fsdp_params_with_grad, + unsharded_grads, + self._reduce_scatter_process_group, + self.comm_ctx.reduce_scatter_stream, + self._orig_dtype, + self._reduce_dtype, + self.device, + self.gradient_divide_factor, + self._all_reduce_process_group if self._is_hsdp else None, + all_reduce_stream, + self.all_reduce_grads, + self._partial_reduce_output, + self._all_reduce_hook, + self.allocate_memory_from_process_group, + self.force_sum_reduction_for_comms, + ) + self.comm_ctx.reduce_scatter_state = ReduceScatterState( + reduce_scatter_input, reduce_scatter_event + ) + if all_reduce_input is not None: + if self.device.type != "cpu": + assert all_reduce_event is not None + self._all_reduce_state = AllReduceState( + all_reduce_input, all_reduce_event + ) + + def finalize_backward(self): + self._wait_for_post_backward() + for fsdp_param in self.fsdp_params: + if fsdp_param.grad_offload_event is not None: + fsdp_param.grad_offload_event.synchronize() + fsdp_param.grad_offload_event = None + if self._all_gather_result is not None: + # If there was a mistargeted unshard without a corresponding wait, + # then we wait here and clear the unshard + if (event := self._all_gather_result.all_gather_event) is not None: + torch.accelerator.current_stream().wait_event(event) + work = self._all_gather_result.all_gather_work + if isinstance(work, dist.distributed_c10d.Work): + work.wait() + self._all_gather_result = None + self._post_forward_indices.clear() + + def _wait_for_post_backward(self): + if self._post_reduce_event is not None: + self.device_handle.current_stream().wait_event(self._post_reduce_event) + self._post_reduce_event = None + if ( + self._all_reduce_state is not None + and self._all_reduce_state.event is not None + ): + self.device_handle.current_stream().wait_event(self._all_reduce_state.event) + self._all_reduce_state = None + + def _backward_prefetch(self) -> None: + if self._training_state == TrainingState.PRE_BACKWARD: + if not self._post_forward_indices: + # Can be cleared if running multiple `backward`s + return + curr_index = self._post_forward_indices.pop() + if (target_index := curr_index - 1) < 0: + return + # Prefetch naively using the reverse post-forward order, which may + # have mistargeted prefetches if not all modules used in forward + # are used in this backward + target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] + self._prefetch_unshard(target_fsdp_param_group, "backward") + + @staticmethod + def _prefetch_unshard( + target_fsdp_param_group: "FSDPParamGroup", pass_type: str + ) -> None: + if pass_type == "backward": + training_state = TrainingState.PRE_BACKWARD + elif pass_type == "forward": + training_state = TrainingState.FORWARD + else: + raise ValueError(f"Unknown pass type: {pass_type}") + target_fqn = target_fsdp_param_group._module_fqn + with ( + record_function(f"FSDP::{pass_type}_prefetch for {target_fqn}"), + target_fsdp_param_group.use_training_state(training_state), + ): + async_op = target_fsdp_param_group.unshard_async_op + target_fsdp_param_group.unshard(async_op) + + # Utilities # + def _to_sharded(self): + if not self.is_sharded: + for fsdp_param in self.fsdp_params: + fsdp_param.to_sharded() + self._sharded_state = ShardedState.SHARDED + + def _to_sharded_post_forward(self): + if not self.is_sharded_post_forward: + for fsdp_param in self.fsdp_params: + fsdp_param.to_sharded_post_forward() + self._sharded_state = ShardedState.SHARDED_POST_FORWARD + + def _to_unsharded(self): + if not self.is_unsharded: + for fsdp_param in self.fsdp_params: + fsdp_param.to_unsharded() + self._sharded_state = ShardedState.UNSHARDED + + @property + def is_sharded(self) -> bool: + return self._sharded_state == ShardedState.SHARDED + + @property + def is_sharded_post_forward(self) -> bool: + return self._sharded_state == ShardedState.SHARDED_POST_FORWARD + + @property + def is_unsharded(self) -> bool: + return self._sharded_state == ShardedState.UNSHARDED + + @contextlib.contextmanager + def use_training_state(self, training_state: TrainingState): + old_training_state = self._training_state + self._training_state = training_state + try: + yield + finally: + self._training_state = old_training_state + + # Hook Registration # + def _register_post_backward_hook( + self, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + # Traceable FSDP2 relies on `root_post_backward_callback` to call each + # `FSDPParamGroup.post_backward` + if (not torch._dynamo.config.skip_fsdp_hooks) or compiled_autograd_enabled(): + return args, kwargs + if not torch.is_grad_enabled(): + return args, kwargs + args_list, args_spec = tree_flatten(args) + kwargs_list, kwargs_spec = tree_flatten(kwargs) + args_kwargs_list = list(args_list) + list(kwargs_list) + inp_tensor_indices: list[int] = [] + inp_tensors: list[torch.Tensor] = [] + for i, obj in enumerate(args_kwargs_list): + if torch.is_tensor(obj) and obj.requires_grad: + inp_tensor_indices.append(i) + inp_tensors.append(obj) + if len(inp_tensors) == 0: + return args, kwargs # no tensors that require gradients + inp_tensors = RegisterPostBackwardFunction.apply(self, *inp_tensors) + for inp_tensor_idx, inp_tensor in zip(inp_tensor_indices, inp_tensors): + args_kwargs_list[inp_tensor_idx] = inp_tensor + args_list = args_kwargs_list[: len(args_list)] + kwargs_list = args_kwargs_list[len(args_list) :] + args = tree_unflatten(args_list, args_spec) + kwargs = tree_unflatten(kwargs_list, kwargs_spec) + return args, kwargs + + def _register_state_dict_hooks(self) -> None: + num_pre_save_hooks = len(self._module_to_pre_save_state_dict_hook_handle) + num_pre_load_hooks = len(self._module_to_pre_load_state_dict_hook_handle) + assert num_pre_save_hooks == num_pre_load_hooks, ( + f"Pre-save: {num_pre_save_hooks} pre-load: {num_pre_load_hooks}" + ) + if num_pre_save_hooks > 0: + return # already registered + modules_with_fsdp_params: set[nn.Module] = { + fsdp_param._module_info.module for fsdp_param in self.fsdp_params + } + + def to_sharded_hook(*args: Any, **kwargs: Any) -> None: + self._to_sharded() + + for module in modules_with_fsdp_params: + self._module_to_pre_save_state_dict_hook_handle[module] = ( + module.register_state_dict_pre_hook(to_sharded_hook) + ) + self._module_to_pre_load_state_dict_hook_handle[module] = ( + module._register_load_state_dict_pre_hook(to_sharded_hook) + ) + + # Properties # + @property + def _reshard_after_forward(self) -> bool: + return self.post_forward_mesh_info is not None + + @property + def _use_post_forward_mesh(self) -> bool: + return ( + self._reshard_after_forward + and self.mesh_info != self.post_forward_mesh_info + ) + + @property + def _is_hsdp(self) -> bool: + return isinstance(self.mesh_info, HSDPMeshInfo) + + @property + def _all_gather_process_group(self) -> dist.ProcessGroup: + mesh_info = ( + cast(FSDPMeshInfo, self.post_forward_mesh_info) + if self.is_sharded_post_forward + else self.mesh_info + ) + assert isinstance(mesh_info, FSDPMeshInfo) + return mesh_info.shard_process_group + + @property + def _reduce_scatter_process_group(self) -> dist.ProcessGroup: + assert isinstance(self.mesh_info, FSDPMeshInfo) + return self.mesh_info.shard_process_group + + @property + def _all_reduce_process_group(self) -> dist.ProcessGroup: + assert isinstance(self.mesh_info, HSDPMeshInfo) + return self.mesh_info.replicate_process_group + + def _with_fqn(self, label: str) -> str: + if self._module_fqn: + return f"{label} ({self._module_fqn})" + return label + + def __repr__(self): + return f"FSDPParamGroup(fqn={self._module_fqn})" + + def _validate_no_meta_params(self): + param_names_on_meta = [ + fsdp_param._param_fqn + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type == "meta" + ] + if param_names_on_meta: + raise RuntimeError( + "FSDP parameters should be materialized from meta device before training, " + f"but the following were still on meta device: {param_names_on_meta}\n" + "For example, call module.to_empty(device) to materialize to device and " + "call module.reset_parameters() on each module to initialize values." + ) + + def _validate_cpu_offload_params(self): + if not isinstance(self.offload_policy, CPUOffloadPolicy): + return + fsdp_params_not_on_cpu = [ + fsdp_param + for fsdp_param in self.fsdp_params + if fsdp_param.sharded_param.device.type != "cpu" + ] + if fsdp_params_not_on_cpu: + raise RuntimeError( + "FSDP parameters should be materialized on CPU when enabling CPU offloading. " + 'For example, load a CPU state dict or call module.to_empty(device="cpu"). ' + "Found following parameters on non-CPU device: " + f"{[(fsdp_param._param_fqn, fsdp_param.sharded_param.device) for fsdp_param in fsdp_params_not_on_cpu]}\n" + ) + + +def _get_param_module_infos( + params: list[nn.Parameter], modules: tuple[nn.Module, ...] +) -> list[ParamModuleInfo]: + """ + Shared parameter: lin1.weight = lin2.weight + Shared module: mlp.lin1 = mlp.lin2 + We do not remove duplicates when traversing both modules and parameters to + find shared modules' parameters and shared parameters within a module. + """ + params_set = set(params) + param_to_module_info: dict[nn.Parameter, ParamModuleInfo] = {} + for module in modules: + for _, submodule in module.named_modules(remove_duplicate=False): + for param_name, param in _named_parameters_with_duplicates( + submodule, recurse=False + ): + if param in params_set: + if param not in param_to_module_info: + param_to_module_info[param] = ParamModuleInfo( + submodule, param_name + ) + else: + param_to_module_info[param].shared_modules.append(submodule) + param_to_module_info[param].shared_param_names.append( + param_name + ) + if len(param_to_module_info) != len(params): + raise AssertionError(f"Some parameters are not in the module tree of {module}") + return [param_to_module_info[param] for param in params] + + +class RegisterPostBackwardFunction(torch.autograd.Function): + @staticmethod + def _assert_not_tracing_fsdp(): + if compiled_autograd_enabled(): + # TODO: Find a way to print the offending FSDP2 module. + msg = """\ +When Traceable FSDP2 is enabled, we should not be calling into `RegisterPostBackwardFunction`. +Instead, we rely on the param group's next `pre_backward` hook to trigger its previously unexecuted +`post_backward`, and we rely on FSDPState's `root_post_backward_callback` to trigger the resharding +of any leftover unsharded param groups. +If you are here, it means the forward part of this FSDP2 instance is not compiled, and you must also +compile the forward part if you want to use Traceable FSDP2.""" + torch._dynamo.comptime.comptime.print(msg) + raise RuntimeError(msg) + + @staticmethod + def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor): + # All tensors in `inputs` should require gradient + RegisterPostBackwardFunction._assert_not_tracing_fsdp() + ctx.param_group = param_group + return inputs + + @staticmethod + def backward(ctx, *grads: torch.Tensor): + RegisterPostBackwardFunction._assert_not_tracing_fsdp() + ctx.param_group.post_backward() + return (None,) + grads diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py new file mode 100644 index 0000000000000000000000000000000000000000..2eccab574aef54252d43c285f39f861be9a819ab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -0,0 +1,403 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import functools +import logging +from collections.abc import Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING + +import torch +import torch.nn as nn +from torch._logging import warning_once +from torch.autograd import Variable +from torch.autograd.graph import _MultiHandle +from torch.distributed._composable_state import ( + _get_module_state, + _insert_module_state, + _State, +) +from torch.distributed.device_mesh import _get_device_handle +from torch.distributed.utils import _apply_to_tensors, _to_kwargs +from torch.utils._pytree import tree_flatten + +from ._fsdp_api import MixedPrecisionPolicy +from ._fsdp_common import ( + _cast_fp_tensor, + compiled_autograd_enabled, + detect_compiled_autograd, + TrainingState, +) +from ._fsdp_param_group import FSDPCommContext, FSDPParamGroup + + +if TYPE_CHECKING: + from ._fsdp_param import FSDPParam + + +logger = logging.getLogger("torch.distributed.fsdp.fully_shard") + + +class FSDPStateContext: + """This has state shared across FSDP states.""" + + def __init__(self) -> None: + # All FSDP states in the root state's module tree + self.all_states: list[FSDPState] = [] + # Iteration's forward root runs the once-per-forward logic; this root + # may not be the overall root set by lazy initialization in cases where + # only a submodule runs forward (e.g. encoder-only for eval) + self.iter_forward_root: Optional[FSDPState] = None + # Final callback should only be queued once per backward + self.post_backward_final_callback_queued: bool = False + # Whether to finalize backward in this backward's final callback + self.is_last_backward: bool = True + # Optional user-provided event recorded after optimizer for the + # all-gather streams to wait on in the root pre-forward + self.post_optim_event: Optional[torch.Event] = None + + +def disable_if_config_true(func): + @functools.wraps(func) + def fsdp_hook_wrapper(*args, **kwargs): + if torch._dynamo.config.skip_fsdp_hooks: + return torch._dynamo.disable( + func, + recursive=True, + reason="skipping FSDP hooks since torch._dynamo.config.skip_fsdp_hooks is set", + )(*args, **kwargs) + else: + return func(*args, **kwargs) + + return fsdp_hook_wrapper + + +class FSDPState(_State): + def __init__(self) -> None: + super().__init__() + self._fsdp_param_group: Optional[FSDPParamGroup] = None + self._is_root: Optional[bool] = None # root set during lazy init + self._state_ctx = FSDPStateContext() + self._comm_ctx = FSDPCommContext() + self._training_state: TrainingState = TrainingState.IDLE + self._states_to_forward_prefetch: list[FSDPState] = [] + self._states_to_backward_prefetch: list[FSDPState] = [] + self._modules_to_run_forward: set[nn.Module] = set() + # ``False`` when user set reshard_after_forward + # through ``fully_shard`` or ``set_reshard_after_forward`` + self._auto_reshard_after_forward: Optional[bool] = True + + # Define a separate init since `__init__` is called in the contract + def init( + self, + modules: tuple[nn.Module, ...], + device: torch.device, + mp_policy: MixedPrecisionPolicy, + auto_reshard_after_forward: bool, + ) -> None: + for module in modules: + _insert_module_state(module, self) + self._modules = modules + self._device = device + self._device_handle = _get_device_handle(device.type) + self._mp_policy = mp_policy + self._auto_reshard_after_forward = auto_reshard_after_forward + if len(modules) == 1: + self._pre_forward_hook_handle = modules[0].register_forward_pre_hook( + self._pre_forward, prepend=True, with_kwargs=True + ) + self._post_forward_hook_handle = modules[0].register_forward_hook( + self._post_forward, prepend=False + ) + else: + hook_handle = _register_group_forward_hooks( + modules, + self._pre_forward, + self._post_forward, + self._modules_to_run_forward, + ) + self._pre_forward_hook_handle = hook_handle + self._post_forward_hook_handle = hook_handle + + def _root_pre_forward( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + self._lazy_init() + if self._state_ctx.iter_forward_root is not None: + return args, kwargs + if not compiled_autograd_enabled(): + logger.debug("FSDP::root_pre_forward") + self._state_ctx.iter_forward_root = self + with torch.profiler.record_function("FSDP::root_pre_forward"): + # Wait for optimizer before implicitly prefetched all-gathers + if (event := self._state_ctx.post_optim_event) is not None: + self._comm_ctx.all_gather_copy_in_stream.wait_event(event) + self._comm_ctx.all_gather_stream.wait_event(event) + self._state_ctx.post_optim_event = None + else: + current_stream = self._device_handle.current_stream() + self._comm_ctx.all_gather_copy_in_stream.wait_stream(current_stream) + self._comm_ctx.all_gather_stream.wait_stream(current_stream) + if self._device.type in [ + "cuda", + "hpu", + "xpu", + "mtia", + torch._C._get_privateuse1_backend_name(), + ]: + with torch.profiler.record_function("FSDP::inputs_to_device"): + args_tuple, kwargs_tuple = _to_kwargs( + args, kwargs, self._device, False + ) # same as DDP + args, kwargs = args_tuple[0], kwargs_tuple[0] + return args, kwargs + + def _lazy_init(self) -> None: + """ + Lazy initialization represents when all modules' parallelisms have + finalized (e.g. FSDP has been applied to all desired modules). This + means that we can determine which state is the root, and we do so by + the 1st state to run forward. + """ + if self._is_root is not None: + return # no-op: already initialized + self._is_root = True + if len(self._modules) > 1: + raise RuntimeError( + f"FSDP requires a single root module but got {self._modules}" + ) + detect_compiled_autograd() + root_module = self._modules[0] + visited_states: set[FSDPState] = set() + for module_name, module in root_module.named_modules(): + if (state := _get_module_fsdp_state(module)) is None: + continue + if module is not root_module: + if state not in visited_states and state._is_root is not None: + raise RuntimeError( + "FSDP state has already been lazily initialized for " + f"{module_name}\nFSDP requires running forward through " + "the root module first" + ) + state._is_root = False + self._state_ctx.all_states.append(state) + visited_states.add(state) + if self._fsdp_param_group and self._auto_reshard_after_forward: + # For the root, do not reshard after forward since for training, + # the parameters would be freed and all-gathered immediately + self._fsdp_param_group.post_forward_mesh_info = None + self._init_fqns() + self._init_shared_state() + # Run parameter group lazy inits after initializing FQNs for improved + # error messages + for state in self._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.lazy_init() + + def _init_shared_state(self) -> None: + self._comm_ctx.lazy_init(self._device) + for state in self._state_ctx.all_states: + state._state_ctx = self._state_ctx + state._comm_ctx = self._comm_ctx + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.comm_ctx = self._comm_ctx + + def _init_fqns(self) -> None: + """Sets module and parameter FQN attributes for debugging.""" + assert self._is_root + root_module = self._modules[0] + param_to_fsdp_param: dict[nn.Parameter, FSDPParam] = {} + module_to_fsdp_param_group: dict[nn.Module, FSDPParamGroup] = {} + for state in self._state_ctx.all_states: + if fsdp_param_group := state._fsdp_param_group: + for fsdp_param in fsdp_param_group.fsdp_params: + param_to_fsdp_param[fsdp_param.sharded_param] = fsdp_param + for module in fsdp_param_group.modules: + module_to_fsdp_param_group[module] = fsdp_param_group + for param_name, param in root_module.named_parameters(): + if param in param_to_fsdp_param: + param_to_fsdp_param[param]._param_fqn = param_name + for module_name, module in root_module.named_modules(): + if module in module_to_fsdp_param_group: + module_fqn = module_to_fsdp_param_group[module]._module_fqn + if module_fqn is None: + module_to_fsdp_param_group[module]._module_fqn = module_name + else: + assert isinstance(module_fqn, str), f"{module_fqn}" + module_fqn += f", {module_name}" + module_to_fsdp_param_group[module]._module_fqn = module_fqn + + @disable_if_config_true + def _pre_forward( + self, module: nn.Module, args: tuple[Any, ...], kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + # When composing with module-hook-based activation checkpointing, the + # the pre-backward hook is responsible for the unshard + if self._training_state == TrainingState.PRE_BACKWARD: + return args, kwargs + self._training_state = TrainingState.FORWARD + args, kwargs = self._root_pre_forward(module, args, kwargs) + if self._mp_policy.cast_forward_inputs and self._mp_policy.param_dtype: + with torch.profiler.record_function("FSDP::cast_forward_inputs"): + cast_fn = functools.partial( + _cast_fp_tensor, self._mp_policy.param_dtype + ) + args, kwargs = ( + _apply_to_tensors(cast_fn, args), + _apply_to_tensors(cast_fn, kwargs), + ) + if self._fsdp_param_group: + args, kwargs = self._fsdp_param_group.pre_forward(module, args, kwargs) + for fsdp_state in self._states_to_forward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "forward") + return args, kwargs + + @disable_if_config_true + def _post_forward(self, module: nn.Module, input: Any, output: Any) -> Any: + # When composing with module-hook-based activation checkpointing, the + # post-backward hook is responsible for the reshard + if self._training_state == TrainingState.PRE_BACKWARD: + return output + if self._fsdp_param_group: + output = self._fsdp_param_group.post_forward(module, input, output) + output = self._register_pre_backward_hook(output) + self._training_state = TrainingState.IDLE + if self._state_ctx.iter_forward_root is self: + if all_gather_state := self._comm_ctx.all_gather_state: + # Free the last all-gather result if needed; refer to + # [Note: Overlapping all-gather copy-in and all-gather] + self._comm_ctx.all_gather_copy_in_stream.wait_event( + all_gather_state.event + ) + self._comm_ctx.all_gather_stream.wait_event(all_gather_state.event) + self._comm_ctx.all_gather_state = None # free the all-gather result + self._state_ctx.iter_forward_root = None + if self._mp_policy.output_dtype is not None: + with torch.profiler.record_function("FSDP::cast_forward_outputs"): + output = _apply_to_tensors( + functools.partial(_cast_fp_tensor, self._mp_policy.output_dtype), + output, + ) + return output + + def _pre_backward(self, grad: torch.Tensor) -> torch.Tensor: + self._training_state = TrainingState.PRE_BACKWARD + self._register_root_post_backward_final_callback() + if self._fsdp_param_group: + default_prefetch = len(self._states_to_backward_prefetch) == 0 + self._fsdp_param_group.pre_backward(default_prefetch) + for fsdp_state in self._states_to_backward_prefetch: + if (target_param_group := fsdp_state._fsdp_param_group) is not None: + FSDPParamGroup._prefetch_unshard(target_param_group, "backward") + return grad + + def _root_post_backward_final_callback(self) -> None: + if not compiled_autograd_enabled(): + logger.debug("FSDP::root_post_backward") + with torch.profiler.record_function("FSDP::root_post_backward_callback"): + for state in self._state_ctx.all_states: + fsdp_param_group = state._fsdp_param_group + if ( + fsdp_param_group + and fsdp_param_group._training_state != TrainingState.POST_BACKWARD + ): + # Run post-backward in case forward inputs did not require + # gradient so the autograd backward did not run + fsdp_param_group.post_backward() + state._training_state = TrainingState.IDLE + if fsdp_param_group: + fsdp_param_group._training_state = TrainingState.IDLE + if self._state_ctx.is_last_backward: + state._finalize_backward() + if self._state_ctx.is_last_backward: + self._comm_ctx.post_forward_order.clear() + if self._comm_ctx.reduce_scatter_state is not None: + self._device_handle.current_stream().wait_event( + self._comm_ctx.reduce_scatter_state.event + ) + self._comm_ctx.reduce_scatter_state = None + self._state_ctx.post_backward_final_callback_queued = False + + def _finalize_backward(self) -> None: + if self._modules_to_run_forward: + msg = ( + f"{len(self._modules_to_run_forward)} of the {len(self._modules)} " + f"modules passed to fully_shard did not run forward before backward, " + "which is error-prone since FSDP post-forward/pre-backward logic " + "will not run for these modules. We recommend passing only modules " + "that run forward together. Modules that did not run forward: " + f"{list(self._modules_to_run_forward)}" + ) + warning_once(logger, msg, stacklevel=2) + # Clear since we want the next forward to run + self._modules_to_run_forward.clear() + if self._fsdp_param_group: + self._fsdp_param_group.finalize_backward() + + def _register_pre_backward_hook(self, output: Any) -> Any: + if not torch.is_grad_enabled(): + return output + flat_outputs, _ = tree_flatten(output) + for t in flat_outputs: + if torch.is_tensor(t) and t.requires_grad: + t.register_hook(self._pre_backward) + return output + + def _register_root_post_backward_final_callback(self): + if self._state_ctx.post_backward_final_callback_queued: + return + self._state_ctx.post_backward_final_callback_queued = True + Variable._execution_engine.queue_callback( + self._root_post_backward_final_callback + ) + + +def _get_module_fsdp_state(module: nn.Module) -> Optional[FSDPState]: + state = _get_module_state(module) + if isinstance(state, FSDPState): + return state + return None + + +def _register_group_forward_hooks( + modules: Sequence[nn.Module], + pre_hook: Callable, + post_hook: Callable, + modules_to_run: set[nn.Module], +): + """ + Registers group forward pre and post-hooks. The pre-hook runs upon the + first module pre-forward, and the post-hook runs upon the last. If at least + one module does not run forward, then the post-hook does not run. + """ + modules_set = set(modules) + + @disable_if_config_true + @functools.wraps(pre_hook) + def wrapped_pre_hook(*args: Any, **kwargs: Any): + if len(modules_to_run) == 0: # first to run + modules_to_run.update(modules_set) + return pre_hook(*args, **kwargs) + + @disable_if_config_true + def get_wrapped_post_hook(module: nn.Module): + @functools.wraps(post_hook) + def wrapped_post_hook(*args: Any, **kwargs: Any): + modules_to_run.discard(module) + if len(modules_to_run) == 0: + return post_hook(*args, **kwargs) + + return wrapped_post_hook + + pre_handles = [ + module.register_forward_pre_hook( + wrapped_pre_hook, prepend=True, with_kwargs=True + ) + for module in modules + ] + post_handles = [ + module.register_forward_hook( + get_wrapped_post_hook(module), prepend=False, always_call=True + ) + for module in modules + ] + return _MultiHandle(tuple(pre_handles + post_handles)) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py new file mode 100644 index 0000000000000000000000000000000000000000..023cf8045bc120376bf899c206fa91620e3b1a60 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -0,0 +1,672 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +from __future__ import annotations + +import functools +from typing import ( + Any, + Callable, + cast, + NoReturn, + Optional, + overload, + TYPE_CHECKING, + Union, +) +from typing_extensions import deprecated + +import torch +import torch.nn as nn +from torch.distributed._composable import contract +from torch.distributed.utils import _get_root_modules + +from ._fsdp_api import MixedPrecisionPolicy, OffloadPolicy +from ._fsdp_common import FSDPMeshInfo, HSDPMeshInfo +from ._fsdp_init import ( + _get_device_from_mesh, + _get_managed_modules, + _get_managed_states, + _get_post_forward_mesh_info, + _init_default_fully_shard_mesh, + _move_states_to_device, +) +from ._fsdp_param_group import FSDPParamGroup +from ._fsdp_state import _get_module_fsdp_state, FSDPState + + +if TYPE_CHECKING: + from collections.abc import Iterable + + from torch.distributed.tensor import DeviceMesh, Shard + +__all__ = [ + "fully_shard", + "FSDPModule", + "UnshardHandle", + "register_fsdp_forward_method", +] + + +cls_to_fsdp_cls: dict[type, type] = {} + + +@overload +def fully_shard( + module: nn.Module, + *, + mesh: Optional[DeviceMesh] = ..., + reshard_after_forward: Union[bool, int] = ..., + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: Optional[set[nn.Parameter]] = ..., +) -> FSDPModule: ... + + +@overload +def fully_shard( + module: list[nn.Module], + *, + mesh: Optional[DeviceMesh] = ..., + reshard_after_forward: Union[bool, int] = ..., + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = ..., + mp_policy: MixedPrecisionPolicy = ..., + offload_policy: OffloadPolicy = ..., + ignored_params: Optional[set[nn.Parameter]] = ..., +) -> list[FSDPModule]: ... + + +# The decorator adds a state object to `module` that can be accessed via +# `fully_shard.state(module)`. The state object and module are 1:1. +# [1] Python runtime decorator does not play well with static type checking +# so suppressing some type checks to support type overloads +# such that caller can still get correct return types based on input type +@contract(state_cls=FSDPState) # type: ignore[misc] # see [1] +def fully_shard( + module, + *, + mesh: Optional[DeviceMesh] = None, + reshard_after_forward: Optional[Union[bool, int]] = None, + shard_placement_fn: Optional[Callable[[nn.Parameter], Optional[Shard]]] = None, + mp_policy: MixedPrecisionPolicy = MixedPrecisionPolicy(), + offload_policy: OffloadPolicy = OffloadPolicy(), + ignored_params: Optional[set[nn.Parameter]] = None, +): + """ + Apply fully sharded data parallelism (FSDP) to ``module``, where FSDP + shards module parameters, gradients, and optimizer states across data + parallel workers to save memory at the cost of communication. + + At initialization, FSDP shards the module's parameters across the data + parallel workers given by ``mesh``. Before forward, FSDP all-gathers the + sharded parameters across the data-parallel workers to get the unsharded + parameters for forward computation. If ``reshard_after_forward`` is + ``True``, then FSDP frees the unsharded parameters after forward and + re-all-gathers them in backward before gradient computation. After gradient + computation, FSDP frees the unsharded parameters and reduce-scatters the + unsharded gradients across data-parallel workers. + + This implementation represents the sharded parameters as :class:`DTensor` s + sharded on dim-0, while the unsharded parameters will be like the original + parameters on ``module`` (e.g. :class:`torch.Tensor` if originally + :class:`torch.Tensor`). A module + `forward pre-hook `_ + on ``module`` all-gathers the parameters, and a module + `forward hook `_ + on ``module`` frees them (if needed). Similar backward hooks all-gather + parameters and later free parameters and reduce-scatter gradients. + + Since grouping multiple tensors together for one collective is critical for + communication efficiency, this implementation makes this grouping first + class. Calling :meth:`fully_shard` on ``module`` constructs one group that + includes the parameters in ``module.parameters()`` except those already + assigned to a group from an earlier call on a submodule. This means that + :meth:`fully_shard` should be called bottom-up on your model. Each group's + parameters are all-gathered in one collective, and its gradients are + reduce-scattered in one collective. Partitioning the model into multiple + groups ("layer by layer") allows for peak memory savings and communication/computation + overlap. Users generally should *not* call :meth:`fully_shard` only on the + topmost root module. + + Args: + module (Union[nn.Module, List[nn.Module]): The module or modules to + shard with FSDP and group together for communication. + mesh (Optional[DeviceMesh]): This data parallel mesh defines the + sharding and device. If 1D, then parameters are fully sharded + across the 1D mesh (FSDP) with ``(Shard(0),)`` placement. If 2D, + then parameters are sharded across the 1st dim and replicated + across the 0th dim (HSDP) with ``(Replicate(), Shard(0))`` + placement. The mesh's device type gives the device type used for + communication; if a CUDA or CUDA-like device type, then we use the + current device. + reshard_after_forward (Optional[Union[bool, int]]): This controls the parameter + behavior after forward and can trade off memory and communication: + + - If ``True``, then this reshards parameters after forward and + re-all-gathers in backward. + - If ``False``, then this keeps the unsharded parameters in memory + after forward and avoids the all-gather in backward. For best performance, + we usually set ``False`` for the root module, because the root module + is typically required immediately when the backward pass begins. + - If ``None``, it is set to ``True`` for non-root modules and ``False`` + for root modules. + - If an ``int``, then this represents the world size to reshard to + after forward. It should be a non-trivial divisor of the ``mesh`` + shard dim size (i.e. excluding 1 and the dim size itself). A + choice may be the intra-node size (e.g. ``torch.cuda.device_count()``). + This allows the all-gather in backward to be over a smaller world + size at the cost of higher memory usage than setting to ``True``. + - After forward, the parameters registered to the module depend on + to this: The registered parameters are the sharded parameters if + ``True``; unsharded parameters if ``False``; and the parameters + resharded to the smaller mesh otherwise. To modify the parameters + between forward and backward, the registered parameters must be + the sharded parameters. For ``False`` or an ``int``, this can be + done by manually resharding via :meth:`reshard`. + shard_placement_fn (Optional[Callable[[nn.Parameter], Optional[Shard]]]): + This callable can be used to override the sharding placement for a + parameter to shard a parameter on a dimension other than dim-0. If + this callable returns a :class:`Shard` placement (not ``None``), + then FSDP will shard according to that placement (e.g. ``Shard(1)``). + If sharding on a nonzero dim, we currently require even sharding, + i.e. the tensor dim size on that dim must be divisible by the FSDP + shard mesh size. + mp_policy (MixedPrecisionPolicy): This controls the mixed precision + policy, which offers parameter/reduction mixed precision for this + module. See :class:`MixedPrecisionPolicy` for details. + offload_policy (OffloadPolicy): This controls the offloading policy, + which offers parameter/gradient/optimizer state offloading. See + :class:`OffloadPolicy` and its subclasses for details. + ignored_params: Optional(Set[nn.Parameter]): The set of parameters to be + ignored by FSDP. They will not be sharded, nor moved to the device + during init, nor have their gradients reduced in backward. + + Returns: + FSDPModule: The module with FSDP applied (in-place). + """ + torch._C._log_api_usage_once("torch.distributed.fsdp.fully_shard") + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + raise ValueError( + f"fully_shard does not support containers that do not implement forward: {module}" + ) + mesh = mesh or _init_default_fully_shard_mesh() + if mesh.ndim not in (1, 2): + raise ValueError(f"fully_shard expects a 1D or 2D DeviceMesh but got {mesh}") + elif mesh.ndim == 1: + mesh_info = FSDPMeshInfo(mesh, shard_mesh_dim=0) + else: + if mesh.mesh_dim_names is None: + raise AssertionError( + "Please init the 2D mesh for HSDP with mesh_dim_names specified" + ) + mesh_info = HSDPMeshInfo(mesh, shard_mesh_dim=1, replicate_mesh_dim=0) + device = _get_device_from_mesh(mesh) + auto_reshard_after_forward = reshard_after_forward is None + # If the user does not provide ``reshard_after_forward``, we set it to True. + # During lazy_init, we identify which module is the root and override its value to False + post_forward_mesh_info = _get_post_forward_mesh_info( + reshard_after_forward if not auto_reshard_after_forward else True, # type: ignore[arg-type] + mesh_info, + ) + + arg_module = module + modules = ( + (module,) if isinstance(module, nn.Module) else tuple(_get_root_modules(module)) + ) + state = fully_shard.state(modules[0]) # type: ignore[attr-defined] # see [1] + state.init(modules, device, mp_policy, auto_reshard_after_forward) + + managed_modules = _get_managed_modules(modules, ignored_params) + params, buffers = _get_managed_states(managed_modules, ignored_params) + + _move_states_to_device(params, buffers, device) + if params: + state._fsdp_param_group = FSDPParamGroup( + params, + modules, + mesh_info, + post_forward_mesh_info, + device, + shard_placement_fn, + mp_policy, + offload_policy, + ) + + # For Dynamo + for managed_module in managed_modules: + managed_module._is_fsdp_managed_module = True # type: ignore[assignment] + managed_module._fsdp_use_orig_params = True # type: ignore[assignment] + + # Place FSDP leftmost for highest priority in the method resolution order + for module in modules: + cls = module.__class__ + new_cls = cls_to_fsdp_cls.get(cls, None) + if not new_cls: + dct = {"__deepcopy__": _unimplemented_deepcopy} + new_cls = type(f"FSDP{cls.__name__}", (FSDPModule, cls), dct) + cls_to_fsdp_cls[cls] = new_cls + module.__class__ = new_cls + return arg_module + + +def _unimplemented_deepcopy(*args: Any, **kwargs: Any) -> NoReturn: + raise AssertionError( + "FSDP does not support deepcopy. Please use state dict for serialization." + ) + + +class FSDPModule: + def __new__(cls, *args, **kwargs): + """ + Override ``__new__`` to remove the FSDP class and directly construct + the original class for cases like indexing into a container module. + """ + # Use index 2 since 0 is the dynamically constructed `FSDP<...>` class + # and index 1 is the `FSDPModule` class itself + orig_cls = cls.__mro__[2] + self = orig_cls.__new__(orig_cls, *args, **kwargs) + self.__init__(*args, **kwargs) + return self + + def reshard(self) -> None: + """ + Reshards the module's parameters, freeing the unsharded parameters if + they are allocated and registering the sharded parameters to the + module. This method is *not* recursive. + """ + state = self._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reshard() + + def unshard(self, async_op: bool = False) -> Optional[UnshardHandle]: + """ + Unshards the module's parameters by allocating memory and all-gathering + the parameters. This method is *not* recursive. The unshard follows the + :class:`MixedPrecisionPolicy`, so it will all-gather following + ``param_dtype`` if set. + + Args: + async_op (bool): If ``True``, then returns a :class:`UnshardHandle` + that has a :meth:`wait` method to wait on the unshard op. If + ``False``, then returns ``None`` and waits on the handle inside + this function. + + .. note:: If ``async_op=True``, then FSDP will wait on the pending + unshard in the module's pre-forward for the user. The user only + needs to call :meth:`wait` explicitly if the wait should happen + before pre-forward. + """ + state = self._get_fsdp_state() + fsdp_param_group = state._fsdp_param_group + if fsdp_param_group is not None: + fsdp_param_group.lazy_init() + fsdp_param_group.unshard(async_op=async_op) + handle = _UnshardHandleImpl(fsdp_param_group) + if async_op: + return handle + handle.wait() + return None + + def set_is_last_backward(self, is_last_backward: bool) -> None: + """ + Sets whether the next backward is the last one. On the last backward, + FSDP waits on pending gradient reduction and clears internal data + data structures for backward prefetching. This can be useful for + microbatching. + """ + state = self._get_fsdp_state() + state._state_ctx.is_last_backward = is_last_backward + + def set_requires_gradient_sync( + self, requires_gradient_sync: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should sync gradients. This can be used to implement + gradient accumulation *without communication*. For HSDP, this controls + both reduce-scatter and all-reduce together. This is the equivalence of + `no_sync` in FSDP1. + + Args: + requires_gradient_sync (bool): Whether to reduce gradients for the + module's parameters. + recurse (bool): Whether to set for all FSDP submodules or just the + passed-in module. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reduce_grads = requires_gradient_sync + fsdp_param_group.all_reduce_grads = requires_gradient_sync + + def set_requires_all_reduce( + self, requires_all_reduce: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should all-reduce gradients. This can be used to + implement gradient accumulation with only reduce-scatter but not + all-reduce for HSDP. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.all_reduce_grads = requires_all_reduce + + def set_reshard_after_forward( + self, reshard_after_forward: bool, recurse: bool = True + ) -> None: + """ + Sets if the module should reshard parameters after forward. This can be + used to change the ``reshard_after_forward`` FSDP arg at runtime. For + example, this can be used to set the FSDP root module's value to + ``True`` (since it is otherwise specially set to ``False``), or it can + set an FSDP module's value to ``False`` for running evals and set back + to ``True`` for training. + + Args: + reshard_after_forward (bool): Whether to reshard parameters after + forward. + recurse (bool): Whether to set for all FSDP submodules or just the + passed-in module. + """ + if not isinstance(reshard_after_forward, bool): + raise ValueError( + f"reshard_after_forward should be a bool, got {type(reshard_after_forward)}" + ) + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + state._auto_reshard_after_forward = False + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.post_forward_mesh_info = ( + _get_post_forward_mesh_info( + reshard_after_forward, fsdp_param_group.mesh_info + ) + ) + + def set_reshard_after_backward( + self, reshard_after_backward: bool, *, recurse: bool = True + ) -> None: + """ + Sets if the module should reshard parameters after backward. This can + be used during gradient accumulation to trade off higher memory for + reduced communication since the unsharded parameters do not need to be + re-all-gathered before the next forward. + + Args: + reshard_after_backward (bool): Whether to reshard parameters after + backward. + recurse (bool): Whether to set for all FSDP submodules or just the + passed-in module. + """ + self_module = cast(nn.Module, self) + modules = list(self_module.modules()) if recurse else [self_module] + for module in modules: + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.reshard_after_backward = reshard_after_backward + + def set_modules_to_forward_prefetch(self, modules: list[FSDPModule]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in forward. The prefetching runs after this + module's all-gather copy-out. + + Passing a singleton list containing the next FSDP module gives the same + all-gather overlap behavior as the default overlap behavior, except the + prefetched all-gather is issued earlier from the CPU. Passing a list + with at least length two is required for more aggressive overlap and + will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_forward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_modules_to_backward_prefetch(self, modules: list[FSDPModule]) -> None: + """ + Sets the FSDP modules for which this FSDP module should explicitly + prefetch all-gathers in backward. This overrides the default backward + pretching implementation that prefetches the next FSDP module based on + the reverse post-forward order. + + Passing a singleton list containing the previous FSDP module gives the + same all-gather overlap behavior as the default overlap behavior. + Passing a list with at least length two is required for more aggressive + overlap and will use more reserved memory. + + Args: + modules (List[FSDPModule]): FSDP modules to prefetch. + """ + _assert_all_fsdp_modules(modules) + self._get_fsdp_state()._states_to_backward_prefetch = [ + module._get_fsdp_state() for module in modules + ] + + def set_all_reduce_hook( + self, + hook: Callable[[torch.Tensor], None], + *, + stream: Optional[torch.cuda.Stream] = None, + ): + """ + Args: + hook (Callable[[torch.Tensor], None]): User-defined all-reduce hook + with expected signature ``hook(reduce_output: torch.Tensor) -> None`` + where ``reduce_output`` is the reduce-scatter output if only + using FSDP or the all-reduce output if using native HSDP. + stream (Optional[torch.cuda.Stream]): Stream to run the all-reduce + hook in. This should only be set if not using native HSDP. If + using native HSDP, the hook will run in the internally defined + all-reduce stream used by the native HSDP all-reduce. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group._all_reduce_hook = hook + if stream is not None: + if fsdp_param_group._is_hsdp: + raise ValueError("stream cannot be set when using native HSDP") + fsdp_param_group._all_reduce_hook_stream = stream + + def set_post_optim_event(self, event: torch.Event) -> None: + """ + Sets a post-optimizer-step event for the root FSDP module to wait the + all-gather streams on. + + By default, the root FSDP module waits the all-gather streams on the + current stream to ensure that the optimizer step has finished before + all-gathering. However, this may introduce false dependencies if + there is unrelated computation after the optimizer step. This API + allows the user to provide their own event to wait on. After the root + waits on the event, the event is discarded, so this API should be + called with a new event each iteration. + + Args: + event (torch.Event): Event recorded after the optimizer step + to wait all-gather streams on. + """ + self._get_fsdp_state()._state_ctx.post_optim_event = event + + @deprecated("Use `set_gradient_divide_factor` instead") + def set_reduce_scatter_divide_factor(self, factor: float) -> None: + """Use :py:meth:`set_gradient_divide_factor` instead""" + self.set_gradient_divide_factor(factor) + + def set_gradient_divide_factor(self, factor: float) -> None: + """ + Sets a custom divide factor for the gradient reduction. This might use + a custom reduce op using NCCL's PreMulSum, which allows multiplying by + the factor before reduction. + + Args: + factor (float): Custom divide factor. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.gradient_divide_factor = factor + + def set_force_sum_reduction_for_comms(self, enable: bool) -> None: + """ + Sets whether to require the low-level collective communication + primitives to exclusively use "sum"-type reductions, even if it comes + at the cost of separate additional pre- or post-scaling operations. + This is needed for example because NCCL currently supports zero-copy + transfers only for this kind of collectives. + + NB: for MTIA devices, this is always implicitly enabled. + + NB: if `set_all_reduce_hook` is used under FSDP setup, the caller needs + to ensure the custom all-reduce across FSDP units follow this strategy + as well, as FSDP can no longer automatically handle that. + + Args: + enable (bool): Whether to only ever use ReduceOp.SUM for comms. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.force_sum_reduction_for_comms = enable + + def set_unshard_in_backward(self, unshard_in_backward: bool) -> None: + """ + Sets whether the FSDP module's parameters need to be unsharded in + backward. This can be used in expert cases when the user knows that all + parameters in this FSDP module's parameter group are not needed for + backward computation (e.g. embedding). + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.unshard_in_backward = unshard_in_backward + + def set_allocate_memory_from_process_group_for_comm(self, enable: bool) -> None: + """ + Sets whether the temporary staging buffers used to send and receive data + over collective communications should be allocated using the custom + optimized allocator provided by the ProcessGroup itself (if any). This + might allow the ProcessGroup to be more efficient. For example, when + using NCCL, this enables it to leverage zero-copy transfers over SHARP + (for NVLink and/or InfiniBand). + + Args: + enable (bool): Whether to turn on ProcessGroup allocation. + """ + state = self._get_fsdp_state() + if (fsdp_param_group := state._fsdp_param_group) is not None: + fsdp_param_group.allocate_memory_from_process_group = enable + + def _set_unshard_async_op(self, async_op: bool): + """ + Sets whether to use ``async_op=True`` or ``False`` for the pre-forward + and pre-backward unshard op. This defaults to ``False`` but can be set + to ``True`` with this method. + + Setting this to ``True`` allows the all-gather allocations to happen in + the default stream, avoiding inter-stream memory fragmentation. + However, you must use explicit prefetching (e.g. via :meth:`unshard`) + in forward to still get overlap, and the pre-all-gather ops like dtype + casting and copy-in will not overlap with compute. + """ + self_module = cast(nn.Module, self) + for module in self_module.modules(): + if isinstance(module, FSDPModule): + state = module._get_fsdp_state() + if fsdp_param_group := state._fsdp_param_group: + fsdp_param_group.unshard_async_op = async_op + + def _get_fsdp_state(self) -> FSDPState: + if (state := _get_module_fsdp_state(cast(nn.Module, self))) is None: + raise AssertionError(f"No FSDP state found on {self}") + return state + + def _apply(self, *args: Any, **kwargs: Any) -> Any: + # Reshard to ensure that sharded parameters are registered + self.reshard() + ret = super()._apply(*args, **kwargs) # type: ignore[misc] + state = self._get_fsdp_state() + if not (fsdp_param_group := state._fsdp_param_group): + return ret + # TODO: Remove this padding logic once DTensor pads the local tensor: + # https://github.com/pytorch/pytorch/issues/113045 + with torch.no_grad(): + for fsdp_param in fsdp_param_group.fsdp_params: + fsdp_param.reset_sharded_param() + return ret + + +class UnshardHandle: + """ + A handle to wait on a :meth:`FSDPModule.unshard` op. + """ + + def wait(self) -> None: + """ + Waits on the unshard op. This ensures that the current stream can use + the unsharded parameters, which are now registered to the module. + """ + return + + +class _UnshardHandleImpl(UnshardHandle): + def __init__(self, fsdp_param_group: Optional[FSDPParamGroup]): + self._fsdp_param_group = fsdp_param_group + + def wait(self): + if self._fsdp_param_group is not None: + self._fsdp_param_group.wait_for_unshard() + # Avoid keeping a reference + self._fsdp_param_group = None + + +def register_fsdp_forward_method(module: nn.Module, method_name: str) -> None: + """ + Registers a method on ``module`` to be considered a forward method for + FSDP. + + FSDP all-gathers parameters pre-forward and optionally frees parameters + post-forward (depending on ``reshard_after_forward``). FSDP only knows to + do this for :meth:`nn.Module.forward` by default. This function patches a + user-specified method to run the pre/post-forward hooks before/after the + method, respectively. If ``module`` is not an :class:`FSDPModule`, then + this is a no-op. + + Args: + module (nn.Module): Module to register the forward method on. + method_name (str): Name of the forward method. + """ + if not isinstance(module, FSDPModule): + # Make no-op to allow including both when using/not using FSDP + return + if not hasattr(module, method_name): + raise ValueError(f"{type(module)} does not have a method {method_name}") + orig_method = getattr(module, method_name) + + @functools.wraps(orig_method) + def wrapped_method(self, *args, **kwargs): + fsdp_state = self._get_fsdp_state() + args, kwargs = fsdp_state._pre_forward(self, args, kwargs) + out = orig_method(*args, **kwargs) + return fsdp_state._post_forward(self, args, out) + + # Use `__get__` to make `wrapped_method` an instance method + setattr( + module, + method_name, + wrapped_method.__get__(module, type(module)), # type:ignore[attr-defined] + ) + + +def _assert_all_fsdp_modules(modules: Iterable[Any]) -> None: + for module in modules: + if not isinstance(module, FSDPModule): + raise ValueError(f"Expects FSDPModule but got {type(module)}: {module}") diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_init_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_init_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1276017de77a0a596fadd99f0c34fc6626dc8ed1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_init_utils.py @@ -0,0 +1,1186 @@ +# mypy: allow-untyped-defs +import collections +import itertools +import os +import warnings +from collections.abc import Generator, Iterable, Iterator +from typing import Any, Callable, no_type_check, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._exec_order_utils as exec_order_utils +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.distributed.fsdp.fully_sharded_data_parallel as fsdp_file +import torch.nn as nn +from torch.distributed.algorithms._comm_hooks import default_hooks +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.distributed_c10d import _get_default_group +from torch.distributed.fsdp._common_utils import ( + _FSDPDeviceHandle, + _FSDPState, + _get_module_fsdp_state, + _is_fsdp_flattened, + _named_parameters_with_duplicates, + clean_tensor_name, + TrainingState, +) +from torch.distributed.fsdp._flat_param import ( + _FSDP_USE_FULL_PREC_IN_EVAL, + FlatParameter, + FlatParamHandle, + HandleShardingStrategy, +) +from torch.distributed.fsdp._limiter_utils import _FreeEventQueue +from torch.distributed.fsdp.api import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + MixedPrecision, + ShardingStrategy, + StateDictConfig, + StateDictType, +) +from torch.distributed.fsdp.wrap import _Policy +from torch.distributed.tensor.parallel.fsdp import DTensorExtensions +from torch.distributed.utils import _sync_params_and_buffers +from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + +if TYPE_CHECKING: + from torch.utils.hooks import RemovableHandle + +_TORCHDISTX_AVAIL = True +try: + from torchdistx import deferred_init, fake # type: ignore[import] +except ImportError: + _TORCHDISTX_AVAIL = False + +PARAM_BROADCAST_BUCKET_SIZE = int(250 * 1024 * 1024) +FSDP_SYNCED = "_fsdp_synced" +# Specification of process groups for hybrid sharding strategies. +HybridShardProcessGroupType = tuple[dist.ProcessGroup, dist.ProcessGroup] +# Overall specification of process group. +ProcessGroupType = Optional[Union[dist.ProcessGroup, HybridShardProcessGroupType]] + + +# TODO (awgu): Refactor this later +SHARDING_STRATEGY_MAP = { + ShardingStrategy.NO_SHARD: HandleShardingStrategy.NO_SHARD, + ShardingStrategy.FULL_SHARD: HandleShardingStrategy.FULL_SHARD, + ShardingStrategy.SHARD_GRAD_OP: HandleShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy.HYBRID_SHARD: HandleShardingStrategy.HYBRID_SHARD, + ShardingStrategy._HYBRID_SHARD_ZERO2: HandleShardingStrategy._HYBRID_SHARD_ZERO2, +} +HYBRID_SHARDING_STRATEGIES = [ + ShardingStrategy.HYBRID_SHARD, + ShardingStrategy._HYBRID_SHARD_ZERO2, +] +NO_RESHARD_AFTER_FORWARD_STRATEGIES = ( + ShardingStrategy.SHARD_GRAD_OP, + ShardingStrategy._HYBRID_SHARD_ZERO2, +) + + +# NOTE: Since non-self attributes cannot be type annotated, several attributes +# on `state` are defined first as local variables before being assigned. + + +@no_type_check +def _init_process_group_state( + state: _FSDPState, + process_group: ProcessGroupType, + sharding_strategy: ShardingStrategy, + policy: Optional[_Policy], + device_mesh: Optional[DeviceMesh] = None, +) -> _FSDPState: + if process_group is not None and device_mesh is not None: + raise ValueError( + "Cannot pass both process_group and device_mesh at the " + "same time. Please just pass only one of them." + ) + is_hybrid_strategy = sharding_strategy in HYBRID_SHARDING_STRATEGIES + if is_hybrid_strategy: + if process_group is None and policy is None and device_mesh is None: + # Raise an error here, since this is manual wrapping with no process group + # passed in, there is no way to ensure all wrapped FSDP instances use the same + # process groups. + raise ValueError( + f"Manual wrapping with {sharding_strategy} " + "requires explicit specification of process group or device_mesh." + ) + else: + state = _init_process_group_state_for_hybrid_shard( + state, process_group, device_mesh + ) + else: + if device_mesh: + state._device_mesh = device_mesh + state.process_group = device_mesh.get_group(mesh_dim=0) + else: + state.process_group = ( + process_group if process_group is not None else _get_default_group() + ) + + state.rank = state.process_group.rank() + state.world_size = state.process_group.size() + data_parallel_world_size = state.world_size + if is_hybrid_strategy: + data_parallel_world_size *= state._inter_node_pg.size() + state._gradient_predivide_factor = ( + default_hooks.DefaultState._get_gradient_predivide_factor( + data_parallel_world_size + ) + ) + state._gradient_postdivide_factor = ( + data_parallel_world_size / state._gradient_predivide_factor + ) + return state + + +@no_type_check +def _init_process_group_state_for_hybrid_shard( + state: _FSDPState, + process_group: ProcessGroupType, + device_mesh: DeviceMesh, +) -> _FSDPState: + if device_mesh: + if _is_valid_hybrid_shard_device_mesh(device_mesh): + state._device_mesh = device_mesh + # We currently only allow _inter_node_pg to be the outermost dimension, and the + # process_group(intra_node) to be the innermost dimension. + state._inter_node_pg = device_mesh.get_group(mesh_dim=0) + state.process_group = device_mesh.get_group(mesh_dim=1) + else: + raise ValueError( + f"Expected device_mesh to have ndim=2 but got {device_mesh.ndim}" + ) + elif process_group is None: + default_group = _get_default_group() + intra_node_group, inter_node_group = _init_intra_and_inter_node_groups( + default_group, state._device_handle.device_count() + ) + # we shard across intra-node + state.process_group = intra_node_group + # save _inter_node_pg to allreduce across. + state._inter_node_pg = inter_node_group + else: + # Check type and assign state.process_group and state._inter_node_pg. + if _is_valid_hybrid_shard_pg_type(process_group): + # Assuming that user passed in as intra node group and inter node group + # as documented. + state.process_group, state._inter_node_pg = process_group + else: + raise ValueError( + "Expected process_group to be passed in as either None or " + f"Tuple[dist.ProcessGroup, dist.ProcessGroup] but got {type(process_group)}" + ) + # Create state for allreduce + state._inter_node_state = _get_default_comm_hook_state( + process_group=state._inter_node_pg, + ) + return state + + +@no_type_check +def _is_valid_hybrid_shard_pg_type(process_group: Any) -> bool: + return ( + isinstance(process_group, tuple) + and len(process_group) == 2 + and all(isinstance(pg, dist.ProcessGroup) for pg in process_group) + ) + + +@no_type_check +def _is_valid_hybrid_shard_device_mesh(device_mesh: DeviceMesh) -> bool: + return isinstance(device_mesh, DeviceMesh) and device_mesh.ndim == 2 + + +@no_type_check +def _init_intra_node_process_group(num_devices_per_node: int) -> dist.ProcessGroup: + """ + Return a process group across the current node. + + For example, given each row is a distinct node: + 0 1 2 3 4 5 6 7 + 8 9 10 11 12 13 14 15 + This API would return an intra-node subgroup across + [0, 1, ..., 7] or [8, 9, ..., 15] depending on the process's rank. + For example, rank 3 would get [0, 1, ..., 7]. + """ + intra_node_subgroup, _ = dist.new_subgroups(num_devices_per_node) + return intra_node_subgroup + + +@no_type_check +def _init_inter_node_process_group( + global_process_group: dist.ProcessGroup, + num_devices_per_node: int, +) -> dist.ProcessGroup: + """ + Return an inter-node process group where each contained rank has the same local rank. + + For example, given each row is a distinct node: + 0 1 2 3 4 5 6 7 + 8 9 10 11 12 13 14 15 + This API would return inter-node process group [0, 8], [1, 9], [2, 10], and so forth + depending on the process's rank. For example, rank 1 would get [1, 9], rank 5 + would get [5, 13]. + """ + # the inter-node pg that is returned + inter_node_pg = None + sharding_backend = dist.get_backend(global_process_group) + world_size = dist.get_world_size(global_process_group) + # Assuming fully homogeneous setup + num_nodes = world_size // num_devices_per_node + my_local_rank = dist.get_rank(global_process_group) % num_devices_per_node + for local_rank in range(num_devices_per_node): + ranks_for_inter_group = [ + local_rank + (i * num_devices_per_node) for i in range(num_nodes) + ] + # every rank always needs to call dist.new_group + grp = dist.new_group(ranks=ranks_for_inter_group, backend=sharding_backend) + if local_rank == my_local_rank: + inter_node_pg = grp + + assert inter_node_pg is not None, ( + f"{my_local_rank} expected to assign inter-node pg, but did not" + ) + return inter_node_pg + + +def _init_intra_and_inter_node_groups( + global_process_group: dist.ProcessGroup, + num_devices_per_node: int, +) -> tuple[dist.ProcessGroup, dist.ProcessGroup]: + """ + Initialize intra and inter-node process groups and return the ones corresponding to this process's rank. + + This function can be used to initialize process groups for ``HYBRID_SHARD`` or + ``_HYBRID_SHARD_ZERO2`` in FSDP. + This function assumes each node has an equal number of CUDA-enabled devices. + Returns: + Tuple[dist.ProcessGroup, dist.ProcessGroup]: Intra and inter-node process group. + """ + return ( + _init_intra_node_process_group(num_devices_per_node), + _init_inter_node_process_group(global_process_group, num_devices_per_node), + ) + + +@no_type_check +def _init_ignored_module_states( + state: _FSDPState, + module: nn.Module, + ignored_modules: Optional[Iterable[torch.nn.Module]], + ignored_states: Union[ + Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] + ] = None, +) -> _FSDPState: + if ignored_modules is not None and ignored_states is not None: + raise ValueError( + "Cannot pass both ignored_modules and ignored_states at the " + "same time. Please just pass ignored_states." + ) + ignored_parameters = None + passed_as_ignored_states = ignored_states is not None + if passed_as_ignored_states: + ignored_states_list = list(ignored_states) + _check_ignored_states(ignored_states_list, True) + else: + ignored_states_list = [] + _check_ignored_states( + list(ignored_modules) if ignored_modules is not None else [], False + ) + if len(ignored_states_list) > 0: + if isinstance(ignored_states_list[0], nn.Parameter): + ignored_parameters = ignored_states_list + else: + ignored_modules = ignored_states_list + state._ignored_modules = _get_ignored_modules(module, ignored_modules) + state._ignored_params = _get_ignored_params( + module, + state._ignored_modules, + ignored_parameters, + ) + state._ignored_buffer_names = _get_ignored_buffer_names( + module, + state._ignored_modules, + ) + # TODO: FSDP's contract for buffers is not well-defined. They are + # implicitly ignored for most functionality since they are not sharded; + # however, FSDP still imposes some semantics on buffers (e.g. buffer mixed + # precision). We should formalize this contract and decide if we need to + # compute and store `_ignored_buffers`. + return state + + +def _check_ignored_states( + ignored_states: list[Any], passed_as_ignored_states: bool +) -> None: + """ + Check that the ignored states are uniformly parameters or uniformly modules. + + We may remove this check in the future if we permit mixing. + """ + if len(ignored_states) == 0: + return + if passed_as_ignored_states: + all_params = all(isinstance(state, nn.Parameter) for state in ignored_states) + all_modules = all(isinstance(state, nn.Module) for state in ignored_states) + if not all_params and not all_modules: + # Sort for consistent ordering for unit test regex matching + sorted_types = sorted({type(state) for state in ignored_states}, key=repr) + raise ValueError( + "ignored_states expects all nn.Parameter or all nn.Module list " + f"elements but got types {sorted_types}" + ) + else: + if not all(isinstance(state, nn.Module) for state in ignored_states): + sorted_types = sorted({type(state) for state in ignored_states}, key=repr) + raise ValueError( + "ignored_modules expects nn.Module list elements but got " + f"types {sorted_types}" + ) + + +@no_type_check +def _init_device_handle( + state: _FSDPState, + module: nn.Module, + ignored_params: set[nn.Parameter], + device_id: Optional[Union[int, torch.device]], +) -> _FSDPState: + """ + Determine device handle used for initializing FSDP. + + If a device is specified by ``device_id``, + then returns device handle corresponds to that device type. Otherwise, If the + module is already on a non-CPU device, then the device type is that non-CPU device type. + If the module is on CPU or meta, then the device type is the current accelerator device. + See the :ref:`Accelerators` for details. + + + This method will be called once ignored parameters was determined, as the device handle maybe needed + for other initialization. + """ + determined_device = None + if device_id is not None: + determined_device = ( + device_id + if isinstance(device_id, torch.device) + else torch.device(device_id) + ) + if determined_device is None: + for param in _get_orig_params(module, ignored_params): + if param.device.type in {"cpu", "meta"}: + continue + if determined_device is None: + determined_device = param.device + else: + if param.device.type != determined_device.type: + raise RuntimeError( + f"FSDP does not support modules with different device types " + f"but got params on {determined_device.type} and {param.device.type}" + ) + determined_device = determined_device or torch._C._get_accelerator() + if determined_device.type == "cpu": + raise RuntimeError( + "FSDP needs a non-CPU accelerator device, but no accelerator device is detected." + ) + + state._device_handle = _FSDPDeviceHandle.from_device(determined_device) + return state + + +@no_type_check +def _init_buffer_state( + state: _FSDPState, + module: nn.Module, +) -> _FSDPState: + state._buffer_names = _get_buffer_names(module) + # Save a mapping from clean fully-qualified buffer name (starting from + # `module`) to its original dtype for restoring that dtype during model + # checkpointing when buffer mixed precision is enabled. The names should + # be clean since the casting happens in a `summon_full_params()` context. + _buffer_name_to_orig_dtype: dict[str, torch.dtype] = {} + for buffer_name, buffer in module.named_buffers(): + buffer_name = clean_tensor_name(buffer_name) + _buffer_name_to_orig_dtype[buffer_name] = buffer.dtype + state._buffer_name_to_orig_dtype = _buffer_name_to_orig_dtype + return state + + +@no_type_check +def _init_core_state( + state: _FSDPState, + sharding_strategy: Optional[ShardingStrategy], + mixed_precision: Optional[MixedPrecision], + cpu_offload: Optional[CPUOffload], + limit_all_gathers: bool, + use_orig_params: bool, + backward_prefetch_limit: int, + forward_prefetch_limit: int, +) -> _FSDPState: + # We clamp the strategy to `NO_SHARD` for world size of 1 since they are + # currently functionally equivalent. This may change if/when we integrate + # FSDP with MoE. + if state.world_size == 1: + if sharding_strategy != ShardingStrategy.NO_SHARD: + warnings.warn( + "FSDP is switching to use `NO_SHARD` instead of " + f"{sharding_strategy or ShardingStrategy.FULL_SHARD} since " + "the world size is 1." + ) + sharding_strategy = ShardingStrategy.NO_SHARD + elif sharding_strategy == ShardingStrategy.NO_SHARD: + warnings.warn( + "The `NO_SHARD` sharding strategy is deprecated. If having issues, " + "please use `DistributedDataParallel` instead.", + FutureWarning, + # Level 1 is here, level 2 is from `FullyShardedDataParallel`, and + # level 3 is from the true caller + stacklevel=3, + ) + state.sharding_strategy = sharding_strategy or ShardingStrategy.FULL_SHARD + state.mixed_precision = mixed_precision or MixedPrecision() + if mixed_precision is not None: + torch._C._log_api_usage_once( + f"torch.distributed.fsdp.mixed_precision.{str(state.mixed_precision)}" + ) + state._use_full_prec_in_eval = ( + os.environ.get(_FSDP_USE_FULL_PREC_IN_EVAL, "") == "1" + ) + state.cpu_offload = cpu_offload or CPUOffload() + state.limit_all_gathers = limit_all_gathers + state._use_orig_params = use_orig_params + state.training_state = TrainingState.IDLE + state._is_root = None + state._free_event_queue = _FreeEventQueue() + state._debug_level = dist.get_debug_level() + state._exec_order_data = exec_order_utils._ExecOrderData( + state._debug_level, + backward_prefetch_limit, + forward_prefetch_limit, + ) + state._unshard_event = None + # Mapping from fully sharded module to the handles it is responsible to + # unshard and reshard (see [Note: Fully Sharded Module]) + _fully_sharded_module_to_handle: dict[nn.Module, FlatParamHandle] = {} + state._fully_sharded_module_to_handle = _fully_sharded_module_to_handle + # Invariant: `state.params` contains exactly the `FlatParameter`s of the + # handles in `state._handle` + _handle: Optional[FlatParamHandle] = None + state._handle = _handle + params: list[FlatParameter] = [] + state.params = params + return state + + +@no_type_check +def _init_runtime_state( + state: _FSDPState, +) -> _FSDPState: + _root_pre_forward_handles: list[RemovableHandle] = [] + state._root_pre_forward_handles = _root_pre_forward_handles + _pre_forward_handles: list[RemovableHandle] = [] + state._pre_forward_handles = _pre_forward_handles + _post_forward_handles: list[RemovableHandle] = [] + state._post_forward_handles = _post_forward_handles + state._sync_gradients = True + state._comm_hook = None + state._comm_hook_state = None + # Used to prevent running the pre-backward hook multiple times + return state + + +@no_type_check +def _init_prefetching_state( + state: _FSDPState, + backward_prefetch: BackwardPrefetch, + forward_prefetch: bool, +) -> _FSDPState: + state.backward_prefetch = backward_prefetch + state.forward_prefetch = forward_prefetch + # The data structures use tuples of handles to generalize over the case + # where a module's forward involves multiple handles. + return state + + +@no_type_check +def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState: + # TODO: we need to add additional check once we support FSDP + PiPPy. + # This check is currently sufficient, since we only support FSDP + TP. + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if device_mesh and root_mesh != state._device_mesh: + state._fsdp_extension = DTensorExtensions(state._device_handle) + else: + # We need to explicitly set _fsdp_extension to None. + # Otherwise, we will run into an infinite recursion when getting the attribute. + state._fsdp_extension = None + return state + + +@no_type_check +def _init_state_dict_state(state: _FSDPState) -> _FSDPState: + state._state_dict_type = StateDictType.FULL_STATE_DICT + state_dict_config: StateDictConfig = FullStateDictConfig() + state._optim_state_dict_config = FullOptimStateDictConfig() + state._state_dict_config = state_dict_config + unshard_params_ctx: dict[nn.Module, Generator] = {} + state._unshard_params_ctx = unshard_params_ctx + + return state + + +def _verify_managed_params(module: nn.Module, params: list[nn.Parameter]) -> None: + """ + Verify if the parameters are accepted by FSDP. The only restriction now + is that the parameter cannot be a scalar tensor (param.shape == []). + """ + for param in params: + if len(param.shape) == 0: + param_name = "" + for name, param_ in module.named_parameters(): + if param is param_: + param_name = name + break + assert param_name + raise ValueError( + "FSDP doesn't support scalar parameters. " + f"Change {param_name} to a 1D tensor with numel equal to 1." + ) + + +@no_type_check +def _init_param_handle_from_module( + state: _FSDPState, + fully_sharded_module: nn.Module, + device_id: Optional[Union[int, torch.device]], + param_init_fn: Optional[Callable[[nn.Module], None]], + sync_module_states: bool, +) -> _FSDPState: + """Initialize a ``FlatParamHandle`` from a module ``fully_sharded_module``.""" + _check_single_device_module(fully_sharded_module, state._ignored_params, device_id) + device_from_device_id = _get_device_from_device_id( + device_id, state.rank, state._device_handle + ) + is_meta_module, is_torchdistX_deferred_init = _need_to_materialize_module( + fully_sharded_module, state._ignored_params, state._ignored_modules + ) + # Materialize the module if needed + if (is_meta_module or is_torchdistX_deferred_init) and param_init_fn is not None: + _materialize_with_param_init_fn( + fully_sharded_module, param_init_fn, state._ignored_modules + ) + elif is_meta_module: + _materialize_meta_module( + fully_sharded_module, + device_id, + state._ignored_modules, + state._device_handle, + ) + elif is_torchdistX_deferred_init: + deferred_init.materialize_module( + fully_sharded_module, + check_fn=lambda submodule: _get_module_fsdp_state(submodule) is None + and submodule not in state._ignored_modules, + ) + + ignored_buffers = { + buffer + for ignored_module in state._ignored_modules + for buffer in ignored_module.buffers() + } + + _move_module_to_device( + fully_sharded_module, + state._ignored_params, + ignored_buffers, + device_from_device_id, + ) + state.compute_device = _get_compute_device( + fully_sharded_module, + state._ignored_params, + device_from_device_id, + state.rank, + state._device_handle, + ) + + managed_params = list(_get_orig_params(fully_sharded_module, state._ignored_params)) + _verify_managed_params(fully_sharded_module, managed_params) + if sync_module_states: + _sync_module_params_and_buffers( + fully_sharded_module, managed_params, state.process_group + ) + if state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + _sync_module_params_and_buffers( + fully_sharded_module, managed_params, state._inter_node_pg + ) + _init_param_handle_from_params(state, managed_params, fully_sharded_module) + return state + + +@no_type_check +def _init_param_handle_from_params( + state: _FSDPState, + params: list[nn.Parameter], + fully_sharded_module: nn.Module, +): + if len(params) == 0: + return + handle = FlatParamHandle( + params, + fully_sharded_module, + state.compute_device, + SHARDING_STRATEGY_MAP[state.sharding_strategy], + state.cpu_offload.offload_params, + state.mixed_precision.param_dtype, + state.mixed_precision.reduce_dtype, + state.mixed_precision.keep_low_precision_grads, + state.process_group, + state._use_orig_params, + fsdp_extension=state._fsdp_extension, + ) + handle.shard() + assert not state._handle + state.params.append(handle.flat_param) + state._handle = handle + state._fully_sharded_module_to_handle[handle._fully_sharded_module] = handle + cpu_device = torch.device("cpu") + if state.cpu_offload.offload_params and handle.flat_param.device != cpu_device: + handle.flat_param_to(cpu_device) + + +def _get_ignored_modules( + root_module: nn.Module, + _ignored_modules: Optional[Iterable[torch.nn.Module]], +) -> set[nn.Module]: + """ + Check that ``_ignored_modules`` is an iterable of ``nn.Module`` s without any FSDP instances. + + Return the modules contained in their module + subtrees as a :class:`set`. Nested FSDP instances are excluded, but their + already-computed ignored modules are included. + + ``_ignored_modules`` represents the argument passed by the user to FSDP. + """ + msg_prefix = "`ignored_modules` should be an iterable of `torch.nn.Module`s " + try: + ignored_root_modules = ( + set(_ignored_modules) if _ignored_modules is not None else set() + ) + except TypeError as e: + raise TypeError(msg_prefix + f"but got {type(_ignored_modules)}") from e + for module in ignored_root_modules: + if not isinstance(module, torch.nn.Module): + raise TypeError(msg_prefix + f"but got an iterable with {type(module)}") + if _get_module_fsdp_state(module): + # TODO: We may relax this by taking the FSDP instance's wrapped + # module to provide more flexibility to the user. + raise ValueError("`ignored_modules` should not include FSDP modules") + # Treat modules that cannot compose with `fully_shard` as ignored modules, + # meaning that their subtrees are ignored + for module in root_module.modules(): + if not traversal_utils._composable(module): + ignored_root_modules.add(module) + # NOTE: Even if `ignored_root_modules` is empty, do not return early so + # that this FSDP instance can get any ignored modules from its children. + + # Include child modules and exclude nested FSDP modules themselves + ignored_modules = { + child + for module in ignored_root_modules + for child in module.modules() + if not isinstance(child, fsdp_file.FullyShardedDataParallel) + } + if root_module in ignored_modules: + warnings.warn( + "Trying to ignore the top-level module passed into the FSDP " + "constructor itself will result in all parameters being " + f"ignored and is not well-supported: {module}" + ) + # Include nested FSDP modules' ignored modules + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + assert hasattr(optional_fsdp_state, "_ignored_modules") + ignored_modules.update(optional_fsdp_state._ignored_modules) + return ignored_modules + + +def _get_ignored_params( + root_module: torch.nn.Module, + ignored_modules: set[torch.nn.Module], + ignored_parameters: Optional[Iterable[torch.nn.Parameter]] = None, +) -> set[torch.nn.Parameter]: + """ + Return the parameters of the modules in ``ignored_modules`` and the parameters in ``ignored_parameters``. + + :class:`FlatParameter` s are excluded from the result. + """ + all_ignored_params: set[torch.nn.Parameter] = set() + + params_in_ignored_modules = { + p for m in ignored_modules for p in m.parameters() if not _is_fsdp_flattened(p) + } + + all_ignored_params.update(params_in_ignored_modules) + + if ignored_parameters is not None: + params_in_ignored_parameters = { + p for p in ignored_parameters if not _is_fsdp_flattened(p) + } + all_ignored_params.update(params_in_ignored_parameters) + + # Always include nested FSDP modules' ignored parameters + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + assert hasattr(optional_fsdp_state, "_ignored_params") + all_ignored_params.update(optional_fsdp_state._ignored_params) + + return all_ignored_params + + +def _get_ignored_buffer_names( + root_module: torch.nn.Module, + ignored_modules: set[torch.nn.Module], +) -> set[str]: + """Return the cleaned buffer FQNs in ``ignored_modules``.""" + all_ignored_buffer_names: set[str] = set() + + buffers_in_ignored_modules = { + buffer for m in ignored_modules for buffer in m.buffers() + } + + all_ignored_buffer_names.update( + { + clean_tensor_name(buffer_name) + for buffer_name, buffer in root_module.named_buffers() + if buffer in buffers_in_ignored_modules + } + ) + + # Always include nested FSDP modules' ignored buffer names + for submodule in root_module.modules(): + optional_fsdp_state = _get_module_fsdp_state(submodule) + if optional_fsdp_state is not None: + assert hasattr(optional_fsdp_state, "_ignored_buffer_names") + all_ignored_buffer_names.update(optional_fsdp_state._ignored_buffer_names) + + return all_ignored_buffer_names + + +def _get_buffer_names(root_module: nn.Module) -> set[str]: + """Return the fully prefixed names of all buffers in the module hierarchy rooted at ``root_module`` as a class:`set`.""" + return { + clean_tensor_name(buffer_name) for buffer_name, _ in root_module.named_buffers() + } + + +def _check_single_device_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + device_id: Optional[Union[int, torch.device]], +) -> None: + """ + Raise an error if ``module`` has original parameters on multiple devices, ignoring the parameters in ``ignored_params``. + + Thus, after this method, the + module must be either fully on the CPU or fully on a non-CPU device. + """ + devices = {param.device for param in _get_orig_params(module, ignored_params)} + # We allow module to be partially on CPU and partially on GPU if device_id is not + # None, since the device_id arg will result in the CPU portion being moved to + # GPU. This is useful in cases where part of the module may be parallelized + # by another algorithm and may already be on GPU. We'd like to enforce device_id + # to not be None, otherwise we'd flatten parameters in a mixed module which is + # not supported. + if len(devices) == 2 and torch.device("cpu") in devices: + if device_id is None: + raise RuntimeError( + "To support a module with both CPU and GPU params, " + "please pass in device_id argument." + ) + elif len(devices) > 1: + raise RuntimeError( + f"FSDP only supports single device modules but got params on {devices}" + ) + + +def _get_device_from_device_id( + device_id: Optional[Union[int, torch.device]], + rank: int, + device_handle: _FSDPDeviceHandle, +) -> Optional[torch.device]: + """ + Return a ``torch.device`` for the specified ``device_id``. + + Processes ``device_id`` and returns either the corresponding device or + ``None`` if ``device_id`` is ``None``. + """ + if device_id is None: + return None + device = ( + device_id if isinstance(device_id, torch.device) else torch.device(device_id) + ) + if device.type != "cpu" and device.index is None: + warnings.warn( + f"FSDP got the argument `device_id` {device_id} on rank " + f"{rank}, which does not have an explicit index. " + f"FSDP will use the current device {device_handle.current_device()}. " + f"If this is incorrect, please explicitly call `torch.{device.type}.set_device()` " + "before FSDP initialization or pass in the explicit device " + "index as the `device_id` argument." + ) + device = torch.device(device_handle.current_device()) + return device + + +def _need_to_materialize_module( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignored_modules: set[nn.Module], +) -> tuple[bool, bool]: + """ + Return if ``module`` has parameters on meta device and if ``module`` is using torchdistX deferred initialization. + + At most of the returned bools can + be ``True``. If either is ``True``, then ``module`` needs to be + materialized. + """ + managed_params = list(_get_orig_params(module, ignored_params)) + is_meta_module = any(param.is_meta for param in managed_params) + # TODO: We need to establish a contract for FSDP and buffers. For now, we + # skip checking for meta buffers from ignored modules. We should consider + # refactoring the initialization holistically to avoid so many traversals. + for submodule in module.modules(): + if submodule in ignored_modules: + continue + for buf in submodule.buffers(recurse=False): + is_meta_module |= buf.is_meta + is_torchdistX_deferred_init = ( + not is_meta_module + and _TORCHDISTX_AVAIL + and any(fake.is_fake(param) for param in managed_params) + ) + return is_meta_module, is_torchdistX_deferred_init + + +def _materialize_with_param_init_fn( + root_module: nn.Module, + param_init_fn: Callable[[nn.Module], None], + ignored_modules: set[nn.Module], +) -> None: + if not callable(param_init_fn): + raise ValueError( + f"Expected {param_init_fn} to be callable but got {type(param_init_fn)}" + ) + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + for module in modules_to_materialize: + param_init_fn(module) + + +def _materialize_meta_module( + root_module: nn.Module, + device_from_device_id: Optional[torch.device], + ignored_modules: set[nn.Module], + device_handle: _FSDPDeviceHandle, +): + # Run default meta device initialization + materialization_device = device_from_device_id or torch.device( + device_handle.current_device() + ) + modules_to_materialize = _get_modules_to_materialize(root_module, ignored_modules) + module = None + try: + # Assume that each module's `reset_parameters()` only initializes its + # own parameters and not those of its children + with torch.no_grad(): + for module in modules_to_materialize: + # As a contract to the user, only call `reset_parameters()` if + # the module has directly managed parameters/buffers + module_state_iter = itertools.chain( + module.parameters(recurse=False), module.buffers(recurse=False) + ) + has_module_states = len(list(module_state_iter)) > 0 + if has_module_states: + module.to_empty(device=materialization_device, recurse=False) + module.reset_parameters() # type: ignore[operator] + except BaseException as e: + warnings.warn( + "Unable to call `reset_parameters()` for module on meta " + f"device with error {str(e)}. Please ensure that your module of" + f"type {type(module)} implements a `reset_parameters()` method." # type: ignore[possibly-undefined] + ) + raise e + + +def _get_modules_to_materialize( + root_module: nn.Module, ignored_modules: set[nn.Module] +) -> list[nn.Module]: + # Run BFS to collect the modules to materialize via `reset_parameters()`, + # stopping at any module with FSDP already applied or at ignored modules. + modules_to_materialize: list[nn.Module] = [] + queue = collections.deque([root_module]) + visited_modules: set[nn.Module] = {root_module} + while queue: + module = queue.popleft() + modules_to_materialize.append(module) + for child_module in module.children(): + if ( + child_module not in visited_modules + and _get_module_fsdp_state(child_module) is None + and child_module not in ignored_modules + ): + visited_modules.add(child_module) + queue.append(child_module) + return modules_to_materialize + + +def _move_module_to_device( + module: nn.Module, + ignored_params: set[nn.Parameter], + ignored_buffers: set[torch.Tensor], + device_from_device_id: Optional[torch.device], +) -> None: + """ + Move ``module`` depending on ``device_from_device_id`` and its current device. + + This includes moving ignored modules' parameters. + + - If ``device_from_device_id`` is not ``None``, then this moves + ``module`` to the device. + - If ``device_from_device_id`` is ``None``, then this does not move + ``module`` but warns the user if it is on CPU. + + Precondition: ``_check_single_device_module()``. + """ + cpu_device = torch.device("cpu") + if device_from_device_id is not None: + # BFS from `module` without traversing any nested FSDP instances to + # collect the parameters/buffers that have not yet been managed + queue: collections.deque[nn.Module] = collections.deque() + queue.append(module) + params: list[nn.Parameter] = [] + buffers: list[torch.Tensor] = [] + while queue: + curr_module = queue.popleft() + # NOTE: We include a check to only move parameters/buffers that are + # on CPU device. If they are on a CUDA device different from the + # one specified by `device_id`, then this does NOT move them. This + # is so that we can raise an error in `_get_compute_device()`. + params.extend( + param + for param in curr_module.parameters(recurse=False) + if param.device == cpu_device + ) + buffers.extend( + buffer + for buffer in curr_module.buffers(recurse=False) + if buffer.device == cpu_device + ) + for submodule in curr_module.children(): + if not isinstance(submodule, fsdp_file.FullyShardedDataParallel): + queue.append(submodule) + params_to_move = [p for p in params if p not in ignored_params] + bufs_to_move = [p for p in buffers if p not in ignored_buffers] + _move_states_to_device(params_to_move, bufs_to_move, device_from_device_id) + return + param = next(_get_orig_params(module, ignored_params), None) + if param is not None and param.device == cpu_device: + _warn_cpu_init() + + +def _move_states_to_device( + params: list[nn.Parameter], + buffers: list[torch.Tensor], + device_from_device_id: Optional[torch.device], +) -> None: + """ + Move states to the specified device. + + Precondition: ``_check_single_device_module()`` and module's parameters and + buffers have been materialized if needed. + """ + if len(params) == 0 and len(buffers) == 0: + return + if len(params) > 0: + current_device = params[0].device + elif len(buffers) > 0: + current_device = buffers[0].device + cpu_device = torch.device("cpu") + if device_from_device_id is not None: + # Move the parameters and buffers like the `.data` code path in + # `nn.Module._apply()`, which underlies `nn.Module.to()` + for param in params: + with torch.no_grad(): + param.data = param.to(device_from_device_id) + if param.grad is not None: + param.grad.data = param.grad.to(device_from_device_id) + for buffer in buffers: + buffer.data = buffer.to(device_from_device_id) + elif current_device == cpu_device: # type: ignore[possibly-undefined] + _warn_cpu_init() + + +def _warn_cpu_init(): + warnings.warn( + "The passed-in `module` is on CPU and will thus have FSDP's sharding " + "initialization run on CPU, which may be slower than on GPU. We " + "recommend passing in the `device_id` argument for FSDP to move " + "`module` to GPU for the sharding initialization. `module` must also " + "be on GPU device to work with the `sync_module_states=True` flag " + "since that requires GPU communication." + ) + + +def _get_compute_device( + module: nn.Module, + ignored_params: set[nn.Parameter], + device_from_device_id: Optional[torch.device], + rank: int, + device_handle: _FSDPDeviceHandle, +) -> torch.device: + """ + Determine and return this FSDP instance's compute device. + + If the module is already on a non-CPU device, then the compute device is that non-CPU + device. If the module is on CPU, then the compute device is the current + device. + + Since this method should be called after materializing the module, any + non-CPU device should not be meta device. For now, the compute device is + always a CUDA or CUDA-like device with its explicit index. + + Precondition: ``_check_single_device_module()`` and + ``_move_module_to_device()``. + """ + param = next(_get_orig_params(module, ignored_params), None) + if param is not None and param.device.type != "cpu": + compute_device = param.device # Determined by model param placement + else: + compute_device = torch.device(device_handle.current_device()) + if device_from_device_id is not None and compute_device != device_from_device_id: + raise ValueError( + f"Inconsistent compute device and `device_id` on rank {rank}: " + f"{compute_device} vs {device_from_device_id}" + ) + return compute_device + + +# TODO: See how to deprecate! +def _sync_module_params_and_buffers( + module: nn.Module, + params: list[nn.Parameter], + process_group: dist.ProcessGroup, +) -> None: + """ + Synchronize module states (i.e. parameters ``params`` and all not-yet-synced buffers) by broadcasting from rank 0 to all ranks. + + Precondition: ``sync_module_states == True`` and ``self.process_group`` has + been set. + """ + module_states: list[torch.Tensor] = [] + for buffer in module.buffers(): + # Avoid re-synchronizing buffers in case of nested wrapping + if not getattr(buffer, FSDP_SYNCED, False): + setattr(buffer, FSDP_SYNCED, True) + detached_buffer = buffer.detach() + if is_traceable_wrapper_subclass(detached_buffer): + # NOTE: Here we assume no nested subclasses, at most one level of subclass + # in both model's buffers and params + attrs, _ = detached_buffer.__tensor_flatten__() # type: ignore[attr-defined] + inner_buffers = [getattr(detached_buffer, attr) for attr in attrs] + module_states.extend(inner_buffers) + else: + module_states.append(detached_buffer) + + for param in params: + detached_param = param.detach() + if is_traceable_wrapper_subclass(detached_param): + attrs, _ = detached_param.__tensor_flatten__() # type: ignore[attr-defined] + inner_params = [getattr(detached_param, attr) for attr in attrs] + module_states.extend(inner_params) + else: + module_states.append(detached_param) + + _check_module_states_for_sync_module_states(module_states) + _sync_params_and_buffers( + process_group, + module_states, + PARAM_BROADCAST_BUCKET_SIZE, + src=0, + ) + + +def _check_module_states_for_sync_module_states( + module_states: list[torch.Tensor], +) -> None: + if module_states and any( + tensor.device == torch.device("cpu") for tensor in module_states + ): + raise ValueError( + "The module has CPU parameters or buffers when `sync_module_states=True`, " + "which requires them to be on GPU. Please specify the `device_id` argument " + "or move the module to GPU before passing it to FSDP." + ) + + +def _get_orig_params( + module: nn.Module, + ignored_params: set[nn.Parameter], +) -> Iterator[nn.Parameter]: + """ + Return an iterator over the original parameters in ``module``. + + The iterator does not return + the parameters in ``ignored_params``, any ``FlatParameter`` s (which may be + present due to nested FSDP wrapping), or any original parameters already + flattened (only relevant when ``use_orig_params=True``). + """ + param_gen = module.parameters() + try: + while True: + param = next(param_gen) + if param not in ignored_params and not _is_fsdp_flattened(param): + yield param + except StopIteration: + pass + + +def _check_orig_params_flattened( + fsdp_module, + ignored_params: set[nn.Parameter], +) -> None: + """ + Check that original parameters in ``fsdp_module`` have been flattened. + + The flattened parameters are made + invisible to ``named_parameters()`` for the module hierarchy rooted at + ``fsdp_module``. This should be called as a sanity check after flattening + the wrapped module's parameters. + """ + for param_name, param in _named_parameters_with_duplicates(fsdp_module): + if param not in ignored_params and not _is_fsdp_flattened(param): + raise RuntimeError( + f"Found an unflattened parameter: {param_name}; " + f"{param.size()} {param.__class__}" + ) + + +def _get_default_comm_hook(sharding_strategy: ShardingStrategy): + return ( + default_hooks.allreduce_hook + if sharding_strategy == ShardingStrategy.NO_SHARD + else default_hooks.reduce_scatter_hook + ) + + +def _get_default_comm_hook_state( + process_group: dist.ProcessGroup, +) -> default_hooks.DefaultState: + return default_hooks.DefaultState(process_group=process_group) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_limiter_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_limiter_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ac4487462a85fd1c8e2dbf155e672f8f0f8f1c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_limiter_utils.py @@ -0,0 +1,33 @@ +import collections +from typing import Optional + +import torch + + +class _FreeEventQueue: + """ + This tracks all pending frees corresponding to inflight all-gathers. The + queueing pattern is iterative enqueues with a single dequeue per iteration + once the limit ``_max_num_inflight_all_gathers`` is reached. + """ + + def __init__(self) -> None: + self._queue: collections.deque[torch.Event] = collections.deque() + self._max_num_inflight_all_gathers = 2 # empirically chosen + + def enqueue(self, free_event: torch.Event) -> None: + """Enqueues a free event.""" + self._queue.append(free_event) + + def dequeue_if_needed(self) -> Optional[torch.Event]: + """Dequeues a single event if the limit is reached.""" + if len(self._queue) >= self._max_num_inflight_all_gathers: + return self._dequeue() + return None + + def _dequeue(self) -> Optional[torch.Event]: + """Dequeues a free event if possible.""" + if self._queue: + event = self._queue.popleft() + return event + return None diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_optim_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_optim_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0827c7fa20e3b7b71e52080d1bbfe14a55a3e4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_optim_utils.py @@ -0,0 +1,2072 @@ +# mypy: allow-untyped-defs +import copy +import functools +import logging +import warnings +from collections.abc import Iterable, Iterator, Sequence +from contextlib import ExitStack +from dataclasses import dataclass, field +from itertools import chain +from typing import Any, cast, NamedTuple, no_type_check, Optional, TYPE_CHECKING, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed._state_dict_utils import _gather_state_dict +from torch.distributed.distributed_c10d import _get_pg_default_device +from torch.distributed.fsdp._common_utils import ( + _apply_to_modules, + _FSDPState, + _get_module_fsdp_state_if_fully_sharded_module, + _get_param_to_fqns, + _module_handle, + _named_parameters_with_duplicates, + clean_tensor_name, +) +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._flat_param import FlatParameter, FlatParamHandle +from torch.distributed.fsdp._fsdp_extensions import ( + _ext_chunk_dtensor, + _ext_chunk_tensor, +) +from torch.distributed.fsdp._runtime_utils import ( + _lazy_init, + _reset_flat_param_grad_info_if_needed, +) +from torch.distributed.fsdp.api import ( + ShardingStrategy, + StateDictSettings, + StateDictType, +) +from torch.distributed.tensor import DTensor, Replicate +from torch.utils._pytree import tree_map_only + + +if TYPE_CHECKING: + from torch.distributed._shard.sharded_tensor import ShardedTensor + + +logger = logging.getLogger(__name__) + + +@dataclass +class FSDPParamInfo: + state: _FSDPState + handle: FlatParamHandle + param_indices: dict[str, int] + param_requires_grad: list[bool] + + +def sorted_items(dictionary: dict[str, Any]) -> Iterator[tuple[str, Any]]: + keys = sorted(dictionary.keys()) + for k in keys: + yield k, dictionary[k] + + +@dataclass +class _ConsolidatedOptimState: + """ + This holds the consolidated optimizer state on the target rank. Positive- + dimension tensor state is communicated across ranks, while zero-dimension + tensor state and non-tensor state is taken directly from the target rank. + + PyTorch version 1.12 moved to using zero-dimension tensors for scalar + values, but user implemented optimizers may still use float (i.e. a + non-tensor). Thus, we support both and handle them identically. + + Attributes: + tensor_state (Dict[str, torch.Tensor]): Mapping from positive-dimension + tensor state name to the unsharded flat tensor representing the + state. + zero_dim_tensor_state (Dict[str, torch.Tensor]): Mapping from zero- + dimension tensor state name to its value. + non_tensor_state (Dict[str, Any]): Mapping from non-tensor state + name to its value. + """ + + tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + zero_dim_tensor_state: dict[str, torch.Tensor] = field(default_factory=dict) + non_tensor_state: dict[str, Any] = field(default_factory=dict) + + +class _PosDimTensorInfo(NamedTuple): + """ + Metadata for positive-dimension tensors used internally for + :meth:`scatter_full_optim_state_dict`. + + Attributes: + shape (torch.Size): Sharded tensor shape (which is equal to the + unsharded tensor shape if the tensor is optimizer state for a + non-FSDP parameter and is hence not sharded). + dtype (torch.dtype): Data type of the tensor. + """ + + shape: torch.Size + dtype: torch.dtype + + +class _OptimStateKey(NamedTuple): + """ + This represents an optimizer state key that may be used commonly across + ranks. It is based on the unflattened parameter names rather than parameter + IDs to make it independent of each rank's own optimizer construction. + """ + + unflat_param_names: tuple[str, ...] + is_fsdp_managed: bool + + +def _unflatten_optim_state( + fsdp_param_info: FSDPParamInfo, + flat_param_state: dict[str, Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool, +) -> list[dict[str, Any]]: + """ + Unflattens the optimizer state, consisting of the "state" part and the + "param_groups" part. Unflattening the "state" part involves consolidating + the state on the target rank and remapping from flattened to unflattened + parameter IDs, and the "param_groups" part only involves remapping from + flattened to unflattened parameter IDs. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + flat_param_state (Dict[str, Any]): Entry for the flat parameter in the + "state" part of the optimizer state dict. + to_save (bool): Whether to save the state on this rank. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flat parameter if on the target + rank or an empty :class:`list` otherwise. The final optimizer state + dict will need to map these entries using the proper unflattened + parameter IDs. + """ + assert not shard_state or to_save, ( + "If ``shard_state`` is True, ``to_save`` has to be True." + ) + consolidated_state = _communicate_optim_state( + fsdp_param_info, + flat_param_state, + ) + if to_save: + unflat_param_state = _unflatten_communicated_optim_state( + fsdp_param_info, + consolidated_state, + shard_state, + ) + for optim_state in unflat_param_state: + # We can't use .items() below cuz we'd run into a concurrent modification error + if cpu_offload: + for key in list(optim_state.keys()): + state = optim_state[key] + if not isinstance(state, torch.Tensor): + continue + optim_state[key] = state.cpu() + return unflat_param_state + else: + return [] + + +def _is_zero_dim_tensor(x: Any) -> bool: + return torch.is_tensor(x) and x.dim() == 0 + + +def _communicate_optim_state( + fsdp_param_info: FSDPParamInfo, + flat_param_state: dict[str, Any], +) -> _ConsolidatedOptimState: + """ + Communicates the optimizer state for a flat parameter across ranks. All + ranks will hold the entire non-sharded optimizer state on GPU. + + If ``N`` is the number of tensor optimizer states in the optimizer state + dict, then the communication complexity is 0 if ``N = 0`` and ``N + 1`` + otherwise (where the plus 1 comes from all-gathering the padding per rank). + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + flat_param_state (Dict[str, Any]): The entry in the "state" part of the + optimizer state dict corresponding to the flat parameter. + + Returns: + ConsolidatedOptimState: Consolidated optimizer state for the target + flat parameter. + """ + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.handle.flat_param + state = _ConsolidatedOptimState() + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + + for state_name, value in sorted_items(flat_param_state): + # Positive-dimension tensor state: communicate across ranks + if torch.is_tensor(value) and value.dim() > 0: + # If the parameter is not sharded, then neither is the + # positive-dimension tensor state, so no need to communicate it -- + # we take the target rank's value + if ( + fsdp_state.world_size == 1 + or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + ): + tensor_state[state_name] = value + continue + assert fsdp_state.compute_device is not None, ( + "compute_device has not been initialized" + ) + if value.device.type != fsdp_state.compute_device.type: + value = value.to(fsdp_state.compute_device) + # Assume that positive-dimension tensor optimizer state + # has the same shape as the sharded flat parameter + buffer_size = flat_param._full_param_padded.size() # type: ignore[attr-defined] + tensor_buffer = value.new_zeros(*buffer_size) + dist.all_gather_into_tensor( + tensor_buffer, value, group=fsdp_state.process_group + ) + fsdp_state._device_handle.synchronize() + unpadded_numel = cast( + nn.Parameter, flat_param._unpadded_unsharded_size + ).numel() + tensor_state[state_name] = tensor_buffer[:unpadded_numel] + # Zero-dimension tensor state and non-tensor state: take this rank's + # value directly + else: + if _is_zero_dim_tensor(value): + zero_dim_tensor_state[state_name] = value.detach().clone() + else: + non_tensor_state[state_name] = value + return state + + +def _unflatten_communicated_optim_state( + fsdp_param_info: FSDPParamInfo, + state: _ConsolidatedOptimState, + shard_state: bool, +) -> list[dict[str, Any]]: + """ + Unflattens the communicated optimizer state (given by ``tensor_state``, + ``non_tensor_state``, and ``zero_dim_tensor_state``) for a single flat + parameter. This should only be called on the target rank. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + state (_ConsolidatedOptimState): Consolidated optimizer state. + + Returns: + List[Dict[str, Any]]: A :class:`list` holding the entries in the + "state" part of the optimizer state dict corresponding to the + unflattened parameters comprising the flat parameter. The final + optimizer state dict will need to map these entries using the proper + unflattened parameter IDs. + """ + fsdp_state = fsdp_param_info.state + handle = fsdp_param_info.handle + flat_param = handle.flat_param + unflat_param_state: list[dict[str, Any]] = [] + flat_param_views: dict[str, Iterator] = {} + num_unflat_params = flat_param._num_params + tensor_state, zero_dim_tensor_state, non_tensor_state = ( + state.tensor_state, + state.zero_dim_tensor_state, + state.non_tensor_state, + ) + + for _ in range(num_unflat_params): + unflat_state_param = {} + # Add positive-dimension tensor state: unflatten with views + for state_name, flat_tensor in sorted_items(tensor_state): + views_generated = state_name in flat_param_views + if not views_generated: + views = handle._get_unflat_views(flat_tensor) + flat_param_views[state_name] = views + else: + views = flat_param_views[state_name] + optim_state: Union[torch.Tensor, ShardedTensor, DTensor] = next(views) + if shard_state: + osd_config = fsdp_state._optim_state_dict_config + if getattr(osd_config, "_use_dtensor", False): + assert fsdp_state._device_mesh is not None + optim_state = _ext_chunk_dtensor( + optim_state, + fsdp_state.rank, + fsdp_state._device_mesh, + fsdp_state._fsdp_extension, + ) + else: + assert fsdp_state.process_group is not None + optim_state = _ext_chunk_tensor( + optim_state, + fsdp_state.rank, + fsdp_state.world_size, + fsdp_state._device_handle.device_count(), + fsdp_state.process_group, + fsdp_state._fsdp_extension, + ) + unflat_state_param[state_name] = optim_state + + # Add zero-dimension tensor state: take the target rank's value + unflat_state_param.update(sorted_items(zero_dim_tensor_state)) + # Add non-tensor state: take the target rank's value + unflat_state_param.update(sorted_items(non_tensor_state)) + unflat_param_state.append(unflat_state_param) + return unflat_param_state + + +def _broadcast_processed_state( + fsdp_state: _FSDPState, + optim_state: dict[str, Any], + group: Optional[dist.ProcessGroup], +) -> dict[str, Any]: + objects: list[Any] = [None] + if dist.get_rank(group) == 0: + objects[0] = tree_map_only( + torch.Tensor, + lambda v: v.cpu() if v.dim() == 0 else _PosDimTensorInfo(v.shape, v.dtype), # type: ignore[union-attr] + optim_state, + ) + dist.broadcast_object_list(objects, src=0, group=group) + if dist.get_rank(group) == 0: + return optim_state + else: + return objects[0] + + +def _broadcast_state( + fsdp_state: _FSDPState, state: Any, group: Optional[dist.ProcessGroup] +) -> Any: + if dist.get_rank(group) == 0: + if not isinstance(state, torch.Tensor) or state.dim() == 0: + return state + tensor = state.to(fsdp_state.compute_device) + else: + if isinstance(state, torch.Tensor): + assert state.dim() == 0, ( + "For non-zero ranks, a tensor state should have zero dimension, " + "but got the state with shape {state.shape()}." + ) + return state + elif not isinstance(state, _PosDimTensorInfo): + return state + tensor = torch.zeros( + state.shape, dtype=state.dtype, device=fsdp_state.compute_device + ) + dist.broadcast(tensor, src=0, group=group) + return tensor + + +def _shard_orig_param_state( + fsdp_param_info: FSDPParamInfo, + fqn: str, + optim_state: dict[str, Any], +) -> dict[str, Any]: + """ + Shard the optimizer state for the original parameter with the name ``fqn``. + This API should only be used when ``use_orig_params`` is True. + """ + if not optim_state: + return {} + fsdp_state = fsdp_param_info.state + flat_param = fsdp_param_info.handle.flat_param + param_idx = fsdp_param_info.param_indices[fqn] + shard_param_info = flat_param._shard_param_infos[param_idx] # type: ignore[attr-defined] + optim_state = _gather_state_dict( + optim_state, pg=fsdp_state.process_group, device=fsdp_state.compute_device + ) + if not shard_param_info.in_shard: + return {} + # Flatten and shard the state. + new_optim_state: dict[str, Any] = {} + intra_param_start_idx = shard_param_info.intra_param_start_idx + intra_param_end_idx = shard_param_info.intra_param_end_idx + for state_name, value in optim_state.items(): + if ( + torch.is_tensor(value) + and value.dim() > 0 + and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD + ): + value = value.flatten()[ + intra_param_start_idx : intra_param_end_idx # type: ignore[operator] + + 1 + ].clone() + new_optim_state[state_name] = value + return new_optim_state + + +def _flatten_optim_state_dict( + optim_state_dict: dict[str, Any], + model: nn.Module, + use_orig_params: bool = False, + optim: Optional[torch.optim.Optimizer] = None, + rank0_only: bool = False, + group: Optional[dist.ProcessGroup] = None, +) -> dict[str, Any]: + """ + Flattens the full optimizer state dict, still keying by unflattened parameter + names. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP know how to aggregate them. However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- it is managed by other parallelism and FSDP does not + know ho to handle/aggregate them. + + Note that ``_flatten_tensor_optim_state`` does not need ``optim`` to + flatten/shard the state. However, NamedOptimizer and KeyedOptimizer require + all the states even if the corresponding parameters are empty. To this end, + ``optim`` will be used to to get the initial state of the empty parameters. + ``optim`` should only be non-None if the ``optim` is KeyedOptimizer or + NamedOptimizer. + + Returns: + Dict[str, Any]: The flattened optimizer state dict. + """ + SimpleProfiler.reset() + + unflat_osd = optim_state_dict + if "state" not in unflat_osd and not rank0_only: + raise ValueError( + '`optim_state_dict` must have the keys "state"' + "to be a valid optimizer state dict" + ) + param_to_fqns = _get_param_to_fqns(model) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + fsdp_state = next(iter(fqn_to_fsdp_param_info.values())).state + + # Broadcast unflat_osd without non-scalar tensor if rank0_only is True. + if rank0_only: + unflat_osd = _broadcast_processed_state(fsdp_state, unflat_osd, group=group) + + # Construct the "state" part + flat_osd_state: dict[Union[_OptimStateKey, str], Any] = {} + unflat_osd_state = unflat_osd["state"] + all_state_keys = set(unflat_osd_state.keys()) + + for param, fqns in param_to_fqns.items(): + fqn = fqns[0] + if fqn not in unflat_osd_state: + continue + all_state_keys.difference_update(fqns) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name in unflat_osd_state[fqn].keys(): + unflat_osd_state[fqn][state_name] = _broadcast_state( + fsdp_state, unflat_osd_state[fqn][state_name], group=group + ) + fqn = fqns[0] + if fqn in fqn_to_fsdp_param_info: + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + if use_orig_params: + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + flat_state = _shard_orig_param_state( + fsdp_param_info, + fqn, + unflat_osd_state[fqn], + ) + else: + flat_state = _flatten_optim_state( + fsdp_param_info, + unflat_osd_state, + fqns, + ) + key = _OptimStateKey(tuple(fqns), True) + # Only include non-empty states since as expected by + # `torch.optim.Optimizer` s unless the optimizer is KeyedOptimizer + # or NamedOptimizer. + if flat_state: + flat_osd_state[key] = flat_state + elif use_orig_params: + assert len(fqns) == 1, ( + f"use_orig_params is True but there are multiple FQNs, {fqns}." + ) + if optim is not None: # NamedOptimizer or KeyedOptimizer case. + state = optim.state.get(param, None) # type: ignore[call-overload] + if state is not None: + flat_osd_state[key] = copy.deepcopy(state) + else: + warnings.warn( + f"optim_state[{key}] is not on rank{fsdp_state.rank}." + ) + + else: + raise RuntimeError( + f"The state of {key} is empty. This should happen when " + "use_orig_params=True." + ) + else: # do not flatten non-FSDP parameters' states + assert len(fqns) == 1 + key = _OptimStateKey(tuple(fqns), False) + flat_osd_state[key] = copy.copy(unflat_osd_state[fqn]) + + if rank0_only: + for fqn in fqns: + if not unflat_osd_state[fqn]: + continue + for state_name, param_state in list(unflat_osd_state[fqn].items()): + if fsdp_state.rank > 0: + # Deference the tensor so that PyTorch can collect the memory. + del unflat_osd_state[fqn][state_name] + else: + # Move the tensor in the original osd back to CPU to make the + # original osd unaffected. + unflat_osd_state[fqn][state_name] = param_state.cpu() + + # Handle user-defined state, states that are not associated with parameters. + for key in all_state_keys: + user_state = unflat_osd_state[key] + if isinstance(user_state, torch.Tensor) and rank0_only and use_orig_params: + user_state = _broadcast_state(fsdp_state, user_state, group=group) + flat_osd_state[key] = copy.copy(user_state) + + SimpleProfiler.dump_and_reset("FSDP _flatten_optim_state_dict() profiling: ") + # Construct the "param_groups" part -- copy as is since it will be + # rekeyed later according to the target rank's optimizer + # Only copy param_groups if it exists in unflat_osd + if "param_groups" in unflat_osd: + flat_osd_param_groups = copy.deepcopy(unflat_osd["param_groups"]) + return {"state": flat_osd_state, "param_groups": flat_osd_param_groups} + else: + return {"state": flat_osd_state} + + +def _flatten_optim_state( + fsdp_param_info: FSDPParamInfo, + unflat_osd_state: dict[str, dict[str, Any]], + unflat_param_names: list[str], +) -> dict[str, Any]: + """ + Flattens the optimizer state in ``full_optim_state_dict`` for a single + flat parameter in ``fsdp_param_info`` corresponding to the unflattened + parameter names in ``unflat_param_names``. + + Args: + fsdp_param_info (FSDPParamInfo): The FSDP state, the handle, and a + mapping from FQN to original parameter index. + unflat_osd_state (Dict[str, Dict[str, Any]]): The "state" part of the + optimizer state dict corresponding to the unflattened parameters. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the flat parameter ``flat_param``. + + Returns: + Dict[str, Any]: A :class:`dict` mapping state names to their values for + a particular flat parameter. The sharded optimizer state dict's "state" + part will map a key to this returned value. + """ + fsdp_state = fsdp_param_info.state + handle = fsdp_param_info.handle + flat_param = handle.flat_param + num_unflat_params = len(unflat_param_names) + assert num_unflat_params > 0, ( + "Expects at least one unflattened parameter corresponding to the flat parameter" + ) + unflat_param_shapes = flat_param._shapes + num_unflat_param_shapes = len(unflat_param_shapes) + assert num_unflat_params == num_unflat_param_shapes, ( + f"Expects {num_unflat_params} shapes but got {num_unflat_param_shapes}" + ) + + # Check if these unflattened parameters have any optimizer state + has_state = [ + bool(unflat_param_name in unflat_osd_state) + for unflat_param_name in unflat_param_names + ] + # If none of the unflattened parameters comprising this flat parameter have + # any state, then we do not want an entry in the optimizer state dict + if not any(has_state): + return {} # no need to flatten any state + # There may still be some unflattened parameters with state and some + # without + unflat_param_states = [ + _gather_state_dict( + unflat_osd_state[unflat_param_name], + pg=fsdp_state.process_group, + device=fsdp_state.compute_device, + ) + if unflat_param_name in unflat_osd_state + else None + for unflat_param_name in unflat_param_names + ] + # Check that the unflattened parameters have the same state names + state_names = None + for unflat_param_state in unflat_param_states: + if unflat_param_state is None: + continue + if state_names is None: + state_names = set(unflat_param_state.keys()) + else: + if state_names != set(unflat_param_state.keys()): + raise ValueError( + "Differing optimizer state names for the unflattened " + f"parameters: {unflat_param_names}" + ) + assert state_names is not None + + # Flatten the state + flat_state: dict[str, Optional[torch.Tensor]] = {} + for state_name in state_names: + state_values = [ + unflat_param_state[state_name] if unflat_param_state is not None else None + for unflat_param_state in unflat_param_states + ] + non_none_state_values = [v for v in state_values if v is not None] + # If all ranks have None, this is a None value + if not non_none_state_values: + flat_state[state_name] = None + continue + are_pos_dim_tensors = are_zero_dim_tensors = are_non_tensors = True + for v in non_none_state_values: + are_pos_dim_tensors &= torch.is_tensor(v) and v.dim() > 0 + are_zero_dim_tensors &= _is_zero_dim_tensor(v) + are_non_tensors &= not torch.is_tensor(v) + types = {type(v) for v in non_none_state_values} + if len(types) != 1 or not ( + are_pos_dim_tensors or are_zero_dim_tensors or are_non_tensors + ): + raise ValueError( + f"Differing optimizer state types for state {state_name}, " + f"values {non_none_state_values}, and unflattened parameter " + f"names {unflat_param_names}" + ) + if are_pos_dim_tensors: + flat_tensor = _flatten_tensor_optim_state( + state_name, + state_values, # type: ignore[arg-type] + unflat_param_names, + unflat_param_shapes, + handle, + ) + # Shard the flattened tensor immediately to minimize max memory + # usage + if ( + fsdp_state.world_size != 1 + and fsdp_state.sharding_strategy != ShardingStrategy.NO_SHARD + ): + sharded_flat_tensor, _ = FlatParamHandle._get_shard( + flat_tensor, + fsdp_state.rank, + fsdp_state.world_size, + ) + else: + sharded_flat_tensor = flat_tensor + flat_state[state_name] = sharded_flat_tensor + elif are_zero_dim_tensors: + flat_state[state_name] = _flatten_zero_dim_tensor_optim_state( + state_name, + state_values, # type: ignore[arg-type] + unflat_param_names, + ) + else: + assert are_non_tensors + flat_state[state_name] = _flatten_non_tensor_optim_state( + state_name, + state_values, + unflat_param_names, + ) + + return flat_state + + +def _flatten_tensor_optim_state( + state_name: str, + pos_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], + unflat_param_shapes: Sequence[torch.Size], + handle: FlatParamHandle, +) -> torch.Tensor: + """ + Flattens the positive-dimension tensor optimizer state given by the values + ``tensors`` for the state ``state_name`` for a single flat parameter + from ``handle`` corresponding to the unflattened parameter names + ``unflat_param_names`` and unflatted parameter shapes + ``unflat_param_shapes``. This flattens each unflattened parameter's tensor + state into one tensor. + + NOTE: We use zero tensors for any unflattened parameters without state + since some value is required to fill those entries. This assumes that the + zero tensor is mathematically equivalent to having no state, which is true + for Adam's "exp_avg" and "exp_avg_sq" but may not be true for all + optimizers. + + Args: + state_name (str): Optimizer state name. + pos_dim_tensors (List[torch.Tensor]): Positive-dimension tensor + optimizer state values for the unflattened parameters corresponding + to the single flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + unflat_param_shapes (List[torch.Size]): Unflattened parameter shapes + corresponding to the single flat parameter. + handle (FlatParamHandle): The flat parameter's handle. + + Returns: + torch.Tensor: A flat tensor containing the optimizer state + corresponding to ``state_name`` constructed by concatenating the + unflattened parameter tensor states in ``pos_dim_tensors`` (using zero + tensors for any unflattened parameters without the state). + """ + flat_param = handle.flat_param + non_none_tensors = [t for t in pos_dim_tensors if t is not None] + # Check that all are tensors with the same dtype + dtypes = {t.dtype for t in non_none_tensors} + if len(dtypes) != 1: + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have positive-dimension tensor state with the " + f"same dtype but got dtypes {dtypes} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + dtype = next(iter(dtypes)) + # Check that each tensor state matches its parameter's shape + for tensor, shape in zip(pos_dim_tensors, unflat_param_shapes): + if tensor is None and len(shape) == 0: + raise ValueError("Flattening a zero-dimension parameter is not supported") + elif tensor is not None and tensor.shape != shape: + raise ValueError( + "Tensor optimizer state does not have same shape as its " + f"parameter: {tensor.shape} {shape}" + ) + # Flatten the tensor states: we do not need to add any right-hand-side + # padding since the flat optimizer state tensor is sharded via + # `_get_shard()`, which pads the shard as needed (just like for the flat + # parameter) + cpu_device = torch.device("cpu") + tensors_to_flatten = [ + torch.flatten(state_value.to(cpu_device)) + if state_value is not None + else torch.flatten( + torch.zeros( + size=shape, + dtype=dtype, + device=cpu_device, + ) + ) + for state_value, shape in zip(pos_dim_tensors, unflat_param_shapes) + ] + flat_tensor = handle.flatten_tensors(tensors_to_flatten, handle._aligned_numel) + flat_param_shape = flat_param._unpadded_unsharded_size # type: ignore[attr-defined] + assert flat_tensor.shape == flat_param_shape, ( + f"tensor optim state: {flat_tensor.shape} flat parameter: {flat_param_shape}" + ) + return flat_tensor + + +def _flatten_zero_dim_tensor_optim_state( + state_name: str, + zero_dim_tensors: list[torch.Tensor], + unflat_param_names: list[str], +) -> torch.Tensor: + """ + Flattens the zero-dimension tensor optimizer state given by the values + ``zero_dim_tensors`` for the state ``state_name`` for a single flat + parameter corresponding to the unflattened parameter names + ``unflat_param_names`` by enforcing that all tensors are the same and using + that common value. + + NOTE: The requirement that the tensors are the same across all unflattened + parameters comprising the flat parameter is needed to maintain the + invariant that FSDP performs the same computation as its non-sharded + equivalent. This means that none of the unflattened parameters can be + missing this state since imposing a value may differ from having no value. + For example, for Adam's "step", no value means maximum bias correction, + while having some positive value means less bias correction. + + Args: + state_name (str): Optimizer state name. + zero_dim_tensors (List[torch.Tensor]): Zero-dimension optimizer state + for the unflattened parameters corresponding to the single + flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + + Returns: + torch.Tensor: A zero-dimensional tensor giving the value of the state + ``state_name`` for all unflattened parameters corresponding to the + names ``unflat_param_names``. + """ + non_none_tensors = [t for t in zero_dim_tensors if t is not None] + # Enforce that all have the same value and dtype + values_set = {t.item() if t is not None else None for t in zero_dim_tensors} + dtypes = {t.dtype if t is not None else None for t in zero_dim_tensors} + if ( + len(non_none_tensors) != len(zero_dim_tensors) + or len(values_set) != 1 + or len(dtypes) != 1 + ): + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have scalar state with the same value and dtype " + f"but got values {values_set} and dtypes {dtypes} for state " + f"{state_name} and unflattened parameter names " + f"{unflat_param_names}" + ) + value = next(iter(values_set)) + dtype = next(iter(dtypes)) + return torch.tensor(value, dtype=dtype, device=torch.device("cpu")) + + +def _flatten_non_tensor_optim_state( + state_name: str, + non_tensors: list[Any], + unflat_param_names: list[str], +) -> Any: + """ + Flattens the non-tensor optimizer state given by the values ``non_tensors`` + for the state ``state_name`` for a single flat parameter corresponding + to the unflattened parameter names ``unflat_param_names`` by enforcing that + all values are the same and using that common value. + + See the note in :func:`_flatten_zero_dim_tensor_optim_state`. + + Args: + state_name (str): Optimizer state name. + non_tensors (List[Any]): Non-tensor optimizer state for the unflattened + parameters corresponding to the single flat parameter. + unflat_param_names (List[str]): A :class:`list` of unflattened + parameter names corresponding to the single flat parameter. + + Returns: + Any: A non-tensor giving the value of the state ``state_name`` for all + unflattened parameters corresponding to the names + ``unflat_param_names``. + """ + non_none_non_tensors = [nt for nt in non_tensors if nt is not None] + # Enforce that all have the same value (same type already checked) + non_tensor_set = set(non_tensors) + if len(non_none_non_tensors) != len(non_tensors) or len(non_tensor_set) != 1: + raise ValueError( + "All unflattened parameters comprising a single flat " + "parameter must have scalar state with the same value and dtype " + f"but got values {non_tensor_set} for state {state_name} and " + f"unflattened parameter names {unflat_param_names}" + ) + non_tensor = next(iter(non_tensor_set)) + return non_tensor + + +def _rekey_sharded_optim_state_dict( + sharded_osd: dict[str, Any], + model: nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ], + using_optim_input: bool, + is_named_optimizer: bool = False, +) -> dict[str, Any]: + """ + Rekeys the optimizer state dict from unflattened parameter names to flat + parameter IDs according to the calling rank's ``optim``, which may be + different across ranks. In particular, the unflattened parameter names are + represented as :class:`_OptimStateKey` s. + """ + param_to_fqns = _get_param_to_fqns(model) + flat_param_to_fqn = _get_flat_param_to_fqn(model) + param_to_param_key: dict[nn.Parameter, Union[int, str]] = cast( + dict[nn.Parameter, Union[int, str]], + ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_key( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + ), + ) + # All parameter keys in `param_to_param_key` should be in + # `param_to_fqns` -- strict inequality follows when not all parameters are + # passed to the optimizer + assert len(param_to_param_key) <= len(param_to_fqns) + + unflat_param_names_to_flat_param_key: dict[ + tuple[str, ...], Union[int, str] + ] = {} # for "state" + unflat_param_name_to_flat_param_key: dict[ + str, Union[int, str] + ] = {} # for "param_groups" + for param, unflat_param_names in param_to_fqns.items(): + if param not in param_to_param_key: + # This parameter was not passed to the optimizer + continue + flat_param_key = param_to_param_key[param] + unflat_param_names_to_flat_param_key[tuple(unflat_param_names)] = flat_param_key + for unflat_param_name in unflat_param_names: + unflat_param_name_to_flat_param_key[unflat_param_name] = flat_param_key + + sharded_osd_state = sharded_osd["state"] + rekeyed_osd_state: dict[Union[str, int], Any] = {} + for key, param_state in sharded_osd_state.items(): + if isinstance(key, str): + rekeyed_osd_state[key] = param_state + continue + flat_param_key = unflat_param_names_to_flat_param_key.get( + key.unflat_param_names, key.unflat_param_names + ) + rekeyed_osd_state[flat_param_key] = param_state + + # Only process param_groups if it exists in sharded_osd + if "param_groups" in sharded_osd: + rekeyed_osd_param_groups: list[dict[str, Any]] = [] + for unflat_param_group in sharded_osd["param_groups"]: + flat_param_group = copy.deepcopy(unflat_param_group) + flat_param_keys = sorted( + { + unflat_param_name_to_flat_param_key[unflat_param_name] + for unflat_param_name in unflat_param_group["params"] + } + ) + flat_param_group["params"] = flat_param_keys + rekeyed_osd_param_groups.append(flat_param_group) + return {"state": rekeyed_osd_state, "param_groups": rekeyed_osd_param_groups} + else: + return {"state": rekeyed_osd_state} + + +def _get_param_id_to_param_from_optim_input( + model: nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ] = None, +) -> dict[int, nn.Parameter]: + """ + Constructs a mapping from parameter IDs to parameters. This may be used + both for models with ``FlatParameter`` s and without. + + NOTE: This method is only preserved for backward compatibility. The method + :meth:`_get_param_key_to_param` is the preferred code path that does not + rely on ``optim_input``. + + NOTE: We critically assume that, whether the optimizer input is a list of + parameters or a list of parameter groups, :class:`torch.optim.Optimizer` + enumerates the parameter IDs in order. In other words, for a parameter list + input, the parameter IDs should be in that list order, and for a parameter + groups input, the parameter IDs should be in order within each parameter + group and in order across parameter groups. + + Args: + model (nn.Module): Model whose parameters are passed into the + optimizer. + optim_input (Optional[Union[List[Dict[str, Any]], + Iterable[nn.Parameter]]]): Input passed into the optimizer + representing either a :class:`list` of parameter groups or an + iterable of parameters; if ``None``, then this method assumes the + input was ``model.parameters()``. (Default: ``None``) + + Returns: + List[nn.Parameter]: Mapping from parameter IDs to parameters, + where the parameter ID is implicitly the index in the :class:`list`. + """ + # Assume the standard case of passing `model.parameters()` to the optimizer + # if `optim_input` is not specified + if optim_input is None: + return dict(enumerate(model.parameters())) + try: + params = cast(list[nn.Parameter], list(optim_input)) + except TypeError as e: + raise TypeError( + "Optimizer input should be an iterable of Tensors or dicts, " + f"but got {optim_input}" + ) from e + if len(params) == 0: + raise ValueError("Optimizer input should not be empty") + + # Check if the optimizer input represents tensors or parameter groups + all_tensors = True + all_dicts = True + for param in params: + all_tensors &= isinstance(param, torch.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError("Optimizer input should be an iterable of Tensors or dicts") + if all_tensors: + return dict(enumerate(params)) + assert all_dicts + param_id_to_param: list[nn.Parameter] = [] + for param_group in params: + has_params_key = "params" in param_group # type: ignore[operator] + assert has_params_key, ( + 'A parameter group should map "params" to a list of the ' + "parameters in the group" + ) + # Implicitly map `flat_param_id` (current length of the list) to + # `param` + param_id_to_param.extend(param_group["params"]) # type: ignore[index] + return dict(enumerate(param_id_to_param)) + + +def _get_flat_param_to_fqn(model: torch.nn.Module) -> dict[FlatParameter, str]: + """ + Constructs a mapping from ``FlatParameter`` to a cleaned (devoid of prefixes + from wrappers) fully qualified name (FQN). Note that this FQN is "non-canonical" + because ``FlatParameter`` s do not come from the original module but are + registered only after FSDP has been applied. This function returns the FSDP-given + name for the ``FlatParameter`` (usually module._flat_param) as opposed to the + canonical FQNs returned for ``FlatParameter`` s in ``_common_utils._get_param_to_fqns(...)``). + + Consequently, this function will only return a non-empty mapping if FSDP was + applied with ``use_orig_params=False`` as, otherwise, the original parameters + are used within the module and there would be no ``FlatParameter`` s in the module. + + """ + + def module_fn(module, prefix, tree_level, flat_param_to_fqn): + for param_name, param in _named_parameters_with_duplicates( + module, recurse=False + ): + if not isinstance(param, FlatParameter): + continue + fqn = clean_tensor_name(prefix + param_name) + flat_param_to_fqn[param] = fqn + + def return_fn(flat_param_to_fqn): + return flat_param_to_fqn + + flat_param_to_fqn_ret: dict[FlatParameter, str] = {} + return _apply_to_modules( + model, + module_fn, + return_fn, + [fqn for fqn, _ in _named_parameters_with_duplicates(model)], + flat_param_to_fqn_ret, + ) + + +def _get_param_key_to_param( + optim: torch.optim.Optimizer, + model: Optional[nn.Module] = None, + is_named_optimizer: bool = False, + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[Union[int, str], nn.Parameter]: + """ + Constructs a mapping from parameter keys to parameters. For the regular + optimizers, the keys are parameter IDs. For NamedOptimizer, the keys + are FQNs. This API may be used both for models with ``FlatParameter`` s and + without. + """ + clean_fqn_to_curr_fqn: dict[str, str] = {} + if is_named_optimizer: + assert param_to_fqns is not None and flat_param_to_fqn is not None, ( + "The optimizer is a NamedOptimizer, `param_to_fqns` must not be None." + ) + assert model is not None + for key, _ in _named_parameters_with_duplicates(model): + clean_fqn_to_curr_fqn[clean_tensor_name(key)] = key + + param_key_to_param: dict[Union[str, int], nn.Parameter] = {} + pid = 0 + for param_group in optim.param_groups: + if is_named_optimizer: + for param in param_group["params"]: + assert flat_param_to_fqn is not None + if param in flat_param_to_fqn: + # FlatParameter case + key = flat_param_to_fqn[param] + else: + assert param_to_fqns is not None + # use_orig_params case + assert len(param_to_fqns[param]) == 1 + key = param_to_fqns[param][0] + try: + key = clean_fqn_to_curr_fqn[key] + except KeyError as e: + raise KeyError( + f"Can't find {key} from {list(clean_fqn_to_curr_fqn.keys())}." + ) from e + param_key_to_param[key] = param + else: + for param in param_group["params"]: + param_key_to_param[pid] = param + pid += 1 + + return param_key_to_param + + +def _get_param_to_param_key( + optim: torch.optim.Optimizer, + model: Optional[nn.Module] = None, + is_named_optimizer: bool = False, + param_to_fqns: Optional[dict[nn.Parameter, list[str]]] = None, + flat_param_to_fqn: Optional[dict[FlatParameter, str]] = None, +) -> dict[nn.Parameter, Union[int, str]]: + """ + Constructs the inverse mapping of :func:`_get_param_key_to_param`. This API + only supports the case where `optim` is a regular optimizer, not NamedOptimizer. + So the parameter keys will be parameter ids. + """ + param_id_to_param = _get_param_key_to_param( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + return {param: param_id for param_id, param in param_id_to_param.items()} + + +def _get_param_to_param_id_from_optim_input( + model: nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ] = None, +) -> dict[nn.Parameter, int]: + """Constructs the inverse mapping of :func:`_get_param_id_to_param_from_optim_input`.""" + param_id_to_param = _get_param_id_to_param_from_optim_input(model, optim_input) + return {param: param_id for param_id, param in param_id_to_param.items()} + + +def _check_missing_keys_on_rank( + r0_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[str, int]], + param_key_to_param: dict[Union[str, int], nn.Parameter], + group: Optional[dist.ProcessGroup], +) -> None: + # Ensure that all ranks have at least the optimizer states needed by + # rank 0's optimizer + missing_keys: list[_OptimStateKey] = [] + for r0_optim_state_key in r0_optim_state_keys: + if r0_optim_state_key not in optim_state_key_to_param_key: + # A parameter from rank 0's optimizer does not exist for this + # rank's optimizer + missing_keys.append(r0_optim_state_key) + continue + param_key = optim_state_key_to_param_key[r0_optim_state_key] + if isinstance(param_key, int): + assert param_key >= 0 and param_key < len(param_key_to_param), ( + "Check the `param_key_to_param` construction" + ) + # We cannot use FSDPState.compute_device as this API is a global view. + device = _get_pg_default_device(group) + num_missing = torch.tensor([len(missing_keys)], dtype=torch.int32, device=device) + dist.all_reduce(num_missing, group=group) + if num_missing.item() > 0: + obj_list = [None for _ in range(dist.get_world_size(group))] + dist.all_gather_object(obj_list, missing_keys, group=group) + error_msg = ( + "FSDP currently requires each rank to have at least the " + "optimizer states needed by rank 0's optimizer but some ranks " + "are missing some of those states" + ) + for rank, keys in enumerate(obj_list): + keys = cast(list[_OptimStateKey], keys) + if len(keys) > 0: + error_msg += ( + f"\nRank {rank} is missing states for the parameters: " + f"{[key.unflat_param_names for key in keys]}" + ) + raise RuntimeError(error_msg) + + +def _map_param_key_to_optim_keys( + optim_state_dict: dict[str, Any], + group: Optional[dist.ProcessGroup], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + merge_keys: bool = False, +) -> tuple[list[_OptimStateKey], dict[_OptimStateKey, Union[int, str]]]: + """ + Construct the local mapping between the ``_OptimStateKey`` and parameter keys + and all the ``_OptimStateKey`` across ranks. If ``merge_keys`` is False, rank0 + must contain all the ``_OptimStateKey``, an exception will be raised otherwise. + Note that ``merge_keys`` should equal to ``use_orig_params``. + """ + rank = dist.get_rank(group) + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]] = {} # local + all_optim_state_keys: list[_OptimStateKey] = [] + + for param_key, param in param_key_to_param.items(): + # Do not include parameters without state to avoid empty mappings + # just like in normal `torch.optim.Optimizer.state_dict()` + if param_key not in optim_state_dict["state"]: + continue + fqns = param_to_fqns[param] + is_fsdp_managed = isinstance(param, FlatParameter) + if is_fsdp_managed: + assert fqns[0] in fqn_to_fsdp_param_info, ( + fqns[0], + list(fqn_to_fsdp_param_info.keys()), + ) + is_fsdp_managed = fqns[0] in fqn_to_fsdp_param_info + optim_state_key = _OptimStateKey( + unflat_param_names=tuple(fqns), + is_fsdp_managed=is_fsdp_managed, + ) + if rank == 0 or merge_keys: + all_optim_state_keys.append(optim_state_key) + optim_state_key_to_param_key[optim_state_key] = param_key + + if merge_keys: + all_keys: list[list[_OptimStateKey]] = [ + [] for _ in range(dist.get_world_size(group)) + ] + dist.all_gather_object(all_keys, all_optim_state_keys, group=group) + merge_all_optim_state_keys = [*chain.from_iterable(all_keys)] + all_optim_state_keys = sorted(set(merge_all_optim_state_keys)) + else: + key_obj_list: list[Optional[list[_OptimStateKey]]] = ( + [all_optim_state_keys] if rank == 0 else [None] + ) + dist.broadcast_object_list(key_obj_list, src=0, group=group) + assert key_obj_list[0] is not None + all_optim_state_keys = key_obj_list[0] + _check_missing_keys_on_rank( + all_optim_state_keys, + optim_state_key_to_param_key, + param_key_to_param, + group, + ) + + return all_optim_state_keys, optim_state_key_to_param_key + + +def _unflatten_param_groups( + state_dict: dict[str, Any], + param_key_to_param: dict[Union[int, str], nn.Parameter], + param_to_fqns: dict[nn.Parameter, list[str]], +) -> list[dict[str, Any]]: + param_groups: list[dict[str, Any]] = [] + for flat_param_group in state_dict["param_groups"]: + unflat_param_group = copy.deepcopy(flat_param_group) + param_group_params = [ + param_key_to_param[flat_param_key] + for flat_param_key in flat_param_group["params"] + ] + nested_unflat_param_names = [ + param_to_fqns[param] for param in param_group_params + ] + unflat_param_group["params"] = [ + *chain.from_iterable(nested_unflat_param_names) + ] # flatten the list of lists + param_groups.append(unflat_param_group) + return param_groups + + +def _is_named_optimizer(optim_state_dict: dict[str, Any]) -> bool: + """ + Returns whether the state_dict is from a NamedOptimizer. + This function checks that the keys in the state_dict['state'] are strings + (which usually are FQNs) versus integers (which usually refer to param_ids + from a vanilla torch.optim.Optimizer). + """ + state = optim_state_dict.get("state", None) + if not state: + # If we cannot find a state, assume it is not NamedOptimizer as + # NamedOptimizer has eager initialization. + return False + try: + key = next(iter(state.keys())) + except Exception as e: + raise Exception(optim_state_dict) from e # noqa: TRY002 + return isinstance(key, str) + + +@dataclass +class StateInfo: + # The key of these dictionaries are the state name, e.g., `exp_avg`. + tensors: dict[str, _PosDimTensorInfo] + scalar_tensors: dict[str, torch.Tensor] + non_tensors: dict[str, Any] + + +def _allgather_state_info( + fsdp_state: _FSDPState, + input_states: dict[str, Any], +) -> list[dict[str, StateInfo]]: + """ + Given the ``input_states``, allgather StateInfo for each state. The function + uses all_gather_object to gather StateInfo so no GPU tensors are sent. + """ + + processed_state_dict: dict[str, StateInfo] = {} + gathered_state_info: list[dict[str, StateInfo]] = [ + {} for _ in range(fsdp_state.world_size) + ] + + for fqn, optim_state in input_states.items(): + # Allgather the scalar tensor state, non-tensor states and tensors metadata. + processed_state = StateInfo({}, {}, {}) + for state_name, value in sorted_items(optim_state): + if torch.is_tensor(value): + if value.dim() == 0: + # Ensure that `step` is on CPU. + processed_state.scalar_tensors[state_name] = value.cpu() + else: + processed_state.tensors[state_name] = _PosDimTensorInfo( + value.shape, value.dtype + ) + else: + processed_state.non_tensors[state_name] = value + processed_state_dict[fqn] = processed_state + dist.all_gather_object( + gathered_state_info, + processed_state_dict, + group=fsdp_state.process_group, + ) + return gathered_state_info + + +def _convert_all_state_info( + fsdp_param_info: FSDPParamInfo, + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], + output_states: dict[str, dict[str, Any]], +) -> tuple[Optional[torch.dtype], dict[str, list[Optional[torch.Tensor]]]]: + """ + Given the ``gathered_state_info`` and ``input_states``, the API converted + the StateInfo into the original state if the state is not a non-scalar + tensor. For a multi-dimensional tensor, the local state will be stored in + ``state_buffer`` in a correct order for later allgather purpose. + """ + + state_buffers: dict[str, list[Optional[torch.Tensor]]] = {} + + for fqn, gathered_state in output_states.items(): + state_info = [s[fqn] for s in gathered_state_info] + all_tensor_states = sorted( + {n for state in state_info for n in state.tensors.keys()} + ) + empty_ranks: set[int] = set() + dtype: Optional[torch.dtype] = None + # First check all the non-scalar states and get the information of + # states on each rank. + for state_name in all_tensor_states: + numels = [] + _empty_ranks: set[int] = set() + for rank, object_state in enumerate(state_info): + numels.append(0) + info = object_state.tensors.get(state_name, None) + if info is not None: + numels[-1] = info.shape.numel() + if not dtype: + dtype = info.dtype + else: + assert dtype == info.dtype + if numels[-1] == 0: + _empty_ranks.add(rank) + + assert not empty_ranks or empty_ranks == _empty_ranks + empty_ranks = _empty_ranks + if state_name not in state_buffers: + state_buffers[state_name] = [ + None for _ in fsdp_param_info.param_indices + ] + local_state = input_states[fqn].get(state_name, None) + # N.B. We need to move the state to compute_device. The reason is + # not yet clear and we need to figure out why the state may be on a + # different device. + if local_state is not None: + local_state = local_state.to(fsdp_param_info.state.compute_device) + state_buffers[state_name][fsdp_param_info.param_indices[fqn]] = local_state + + # Restoring the scalar and non-tensor states. If the corresponding + # non-scalar states do not exist on the rank, we also skip the scalar + # non-tensor states on that rank. + for rank, object_state in enumerate(state_info): + if rank in empty_ranks: + continue + for name, non_tensor_value in object_state.non_tensors.items(): + curr_non_tensor_value = gathered_state.get(name, None) + assert ( + curr_non_tensor_value is None + or curr_non_tensor_value == non_tensor_value + ), ( + f"Rank {rank} has different values for {name}: {non_tensor_value}." + + f" Other ranks: {curr_non_tensor_value}" + ) + gathered_state[name] = non_tensor_value + + for name, scalar_tensor_value in object_state.scalar_tensors.items(): + curr_scalar_tensor_value = gathered_state.get(name, None) + assert curr_scalar_tensor_value is None or torch.equal( + scalar_tensor_value, curr_scalar_tensor_value + ), ( + f"Rank {rank} has different values for {name}: {scalar_tensor_value}." + + f" Other ranks: {curr_scalar_tensor_value}" + ) + gathered_state[name] = scalar_tensor_value + + return dtype, state_buffers # type: ignore[possibly-undefined] + + +def _unflatten_orig_param_states( + fsdp_param_info: FSDPParamInfo, + output_states: dict[str, dict[str, Any]], + state_name: str, + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> None: + """ + Given a output state dict, ``output_states``, which the keys are FQNs to the + original parameters (not FlatParameters nor parameter ID), and the values + are gathered states, unflatten the states to the original dimensions. + + This function performs the unflattening process in-place. + """ + if not to_save: + return + flat_param = fsdp_param_info.handle.flat_param + fsdp_state = fsdp_param_info.state + for fqn, gathered_state in output_states.items(): + value = gathered_state[state_name] + param_idx = fsdp_param_info.param_indices[fqn] + + # TODO: This solution is not general and only apply to PTD TP solution. + if isinstance(value, DTensor): + placement = value.placements[0] + # If gathered state is a DTensor and its TP placement is not Replicate(), we need to + # gather the tensor on its TP dimension before chunking them into DTensor again. + if placement != Replicate(): + placement_dim = placement.dim # type: ignore[attr-defined] + value.redistribute(placements=(Replicate(),)) + reshape_size = list(flat_param._shapes[param_idx]) + reshape_size[placement_dim] *= value.device_mesh.size(0) + reshape_size = torch.Size(reshape_size) + value = value.reshape(reshape_size) + # If gathered state is a replicate DTensor, we directly reshape it. + else: + value = value.reshape(flat_param._shapes[param_idx]) + else: + # If gathered state is a tensor, we directly reshape it into unflatten state. + value = value.reshape(flat_param._shapes[param_idx]) + + if shard_state: + osd_config = fsdp_state._optim_state_dict_config + if getattr(osd_config, "_use_dtensor", False): + assert fsdp_state._device_mesh is not None + value = _ext_chunk_dtensor( + value, + fsdp_state.rank, + fsdp_state._device_mesh, + fsdp_state._fsdp_extension, + ) + else: + assert fsdp_state.process_group is not None + value = _ext_chunk_tensor( + value, + fsdp_state.rank, + fsdp_state.world_size, + fsdp_state._device_handle.device_count(), + fsdp_state.process_group, + fsdp_state._fsdp_extension, + ) + elif not cpu_offload: + with SimpleProfiler.profile("clone"): + value = value.detach().clone() + + if cpu_offload: + with SimpleProfiler.profile(SimpleProfiler.Type.D2H): + value = value.cpu() + gathered_state[state_name] = value + + +def _allgather_orig_param_states( + fsdp_param_info: FSDPParamInfo, + gathered_state_info: list[dict[str, StateInfo]], + input_states: dict[str, Any], + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> dict[str, dict[str, Any]]: + """ + Given the ``gathered_state_info`` and ``input_states``, the API allgathers + all tensor states and restore non-tensor states from ``gathered_state_info``. + """ + fsdp_state = fsdp_param_info.state + if fsdp_state.rank == 0 and dist.get_debug_level() == dist.DebugLevel.DETAIL: + logger.info( + "Memory Summary before calling to _allgather_orig_param_states %s", + fsdp_state._device_handle.memory_summary(), + ) + + output_states: dict[str, dict[str, Any]] = {fqn: {} for fqn in input_states.keys()} + + dtype, state_buffers = _convert_all_state_info( + fsdp_param_info, gathered_state_info, input_states, output_states + ) + + if len(state_buffers) == 0: + return output_states + + has_state_params: list[bool] = [ + True if fqn in output_states else False + for fqn, idx in fsdp_param_info.param_indices.items() + ] + + # Loop through the ``state_buffers`` and construct the flattened, concatenated, + # sharded states. The size of the constructed state will be the same size as + # flat_param (also sharded). + # Then we perform an allgather_into_tensor to get the full flat_param state. + # The full flat_param state is the result of concatenation of multiple states + # the order of of flat_param._fqns. + # The final step is to split the flat_param state into original param states + # and return the result. + flat_param = fsdp_param_info.handle.flat_param + empty_func = functools.partial( + torch.empty, dtype=dtype, device=fsdp_state.compute_device + ) + gathered_tensor = empty_func(flat_param._padded_unsharded_size) + # Synchronize can be slow but this will be easier for us to debug. + fsdp_state._device_handle.synchronize() + for state_name, buffers in state_buffers.items(): + local_buffers: list[torch.Tensor] = [] + begin = fsdp_state.rank * flat_param._sharded_size.numel() + # End is inclusive. + end = begin + flat_param._sharded_size.numel() - 1 + # param_idx corresponds to the parameter index in the FlatParameter. + mem_offset, param_idx = 0, 0 + for numel, is_padding in zip( + flat_param._numels_with_padding, flat_param._is_padding_mask + ): + frozen_and_no_state = not is_padding and ( + not fsdp_param_info.param_requires_grad[param_idx] + and not has_state_params[param_idx] + ) + + if is_padding or frozen_and_no_state: + # This memory range is a padding or the param is frozen and does + # not require gradient. For the later case, we treat it as a + # padding and add empty values to the local_buffers. + + padding_begin, padding_end = mem_offset, mem_offset + numel - 1 + if padding_begin <= begin <= padding_end: + # The range is an align padding before the first parameter in + # the shard. The shard includes parts of this align padding. + padding_len = ( + padding_end - begin + 1 + if end >= padding_end + else end - begin + 1 + ) + elif padding_begin <= end <= padding_end: + # The range is an align padding after the last parameter in + # the shard. The shard includes parts of this align padding. + padding_len = ( + end - padding_begin + 1 + if begin <= padding_begin + else end - begin + 1 + ) + elif begin < padding_begin <= padding_end < end: + # The range is an align padding that is completely in the + # shard. + padding_len = numel + else: + padding_len = 0 + if padding_len: + local_buffers.append(empty_func(padding_len)) + + if not is_padding: + # This memory range is a parameter in FlatParameter. So there + # should be an corresponding state in the optimizer unless the + # parameter is frozen, which we treat it as a padding above. + + # We need to check if this rank owns the buffer. If this is None: + # 1.) the rank does not own any part of the original parameter. + # As a result, there is no corresponding optimizer state on + # the rank as well. + # 2.) the parameter is frozen AND no optimizer state for the + # parameter. If a parameter is frozen, there can still be + # optimizer state if the parameter is not frozen in the + # previous steps. + if buffers[param_idx] is not None: + local_buffers.append(cast(torch.Tensor, buffers[param_idx])) + param_idx += 1 + + mem_offset += numel + + shard_numel_padded = flat_param._sharded_size.numel() - ( + sum(t.numel() for t in local_buffers) + ) + + assert flat_param._shard_numel_padded == shard_numel_padded, ( + "Manually calculated _sharded_numel_padded is incorrect. " + f"_shard_numel_padded={flat_param._shard_numel_padded}, " + f"shard_numel_padded={shard_numel_padded}, " + f"_sharded_size.numel={flat_param._sharded_size.numel()}, " + f"_numels_with_padding={flat_param._numels_with_padding}, " + f"begin={begin}, end={end}," + ) + if shard_numel_padded > 0: + # Add right-handed padding. + local_buffers.append(empty_func(shard_numel_padded)) + local_shard = torch.cat(local_buffers) + assert local_shard.numel() * fsdp_state.world_size == gathered_tensor.numel(), ( + "The size of local shard times the world size should equal to the " + "gathered tensor size. The inconsistency may be from a bug of " + "FlatParameter's metadata or the reconstruction logic in optimizer " + "state dict." + ) + fsdp_state._device_handle.synchronize() + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): + dist.all_gather_into_tensor( + gathered_tensor, local_shard, group=fsdp_state.process_group + ) + # Synchronize can be slow but this will be easier for us to debug. + fsdp_state._device_handle.synchronize() + + unpadded_tensor = gathered_tensor[: flat_param._unpadded_unsharded_size.numel()] + flat_param_handle = fsdp_param_info.handle + orig_states = flat_param_handle._get_unflat_views_aligned(unpadded_tensor) + assert len(orig_states) == len(fsdp_param_info.param_indices), ( + "The number of parameters from FlatParameter is not consistent to " + "the number of states used by optimizer state dict reconstruction " + "logic." + ) + for fqn, idx in fsdp_param_info.param_indices.items(): + if fsdp_param_info.param_requires_grad[idx] or fqn in output_states: + output_states[fqn][state_name] = orig_states[idx] + + _unflatten_orig_param_states( + fsdp_param_info, + output_states, + state_name, + shard_state, + to_save, + cpu_offload, + ) + + del gathered_tensor + return output_states + + +def _gather_all_orig_param_state( + fsdp_param_info: FSDPParamInfo, + input_states: dict[str, Any], + shard_state: bool, + to_save: bool, + cpu_offload: bool, +) -> dict[str, Any]: + """ + Given a optimizer state dict, ``input_states``, which the keys are FQNs to the + original parameters (not FlatParameters nor parameter ID), gather all the + states and unflatten them to the original dimensions. Note that all the + params referred by the ``input_states`` must be managed by FSDP. + """ + fsdp_state = fsdp_param_info.state + if ( + fsdp_state.world_size == 1 + or fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + ): + return input_states if to_save else {} + + with SimpleProfiler.profile(SimpleProfiler.Type.RESHARDING): + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER_OBJ): + gathered_state_info = _allgather_state_info(fsdp_state, input_states) + output_states = _allgather_orig_param_states( + fsdp_param_info, + gathered_state_info, + input_states, + shard_state, + to_save, + cpu_offload, + ) + if to_save: + for key, idx in fsdp_param_info.param_indices.items(): + if key in output_states: + continue + if not fsdp_param_info.param_requires_grad[idx]: + continue + + raise RuntimeError( + f"{key} is not in the output state. " + "The FSDPParamInfo has the param keys " + f"{sorted(fsdp_param_info.param_indices.keys())} while " + "the output_states has the param keys " + f"{sorted(output_states.keys())}." + ) + return output_states + else: + return {} + + +def _convert_state_with_orig_params( + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool = True, +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} + # This variable is used to deduplicate the FSDPParamInfo as one FSDPParamInfo + # usually corresponds to multiple parameters. We could not use FSDPParamInfo + # as the key because FSDPParamInfo is not hashable. As a result, we fall back + # to `id(FSDPParamInfo)`, which the type is an integer. + all_states: dict[int, dict[str, Any]] = {} + # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers + # across ranks + for optim_state_key in all_optim_state_keys: + param_key: Union[str, int, None] = optim_state_key_to_param_key.get( + optim_state_key, None + ) + + if param_key is None and not optim_state_key.is_fsdp_managed: + continue + + if optim_state_key.is_fsdp_managed: + fqn = optim_state_key.unflat_param_names[0] + fsdp_param_info = fqn_to_fsdp_param_info.get(fqn, None) + if fsdp_param_info is None: + # This can happen if the not all FSDP instances have all the + # parameters. This can happen with FSDP + some MPMD style + # parallelism. + + # TODO: it is unclear if we need to do the same check with + # non-FSDP managed keys. + continue + state = {} if param_key is None else optim_state_dict[param_key] + if id(fsdp_param_info) not in all_states: + all_states[id(fsdp_param_info)] = {} + all_states[id(fsdp_param_info)][fqn] = state + + elif to_save: + assert len(optim_state_key.unflat_param_names) == 1 + unflat_param_name = optim_state_key.unflat_param_names[0] + with SimpleProfiler.profile("none_fsdp_managed_copy"): + param_key = cast(Union[str, int], param_key) + fsdp_osd_state[unflat_param_name] = copy.copy( + optim_state_dict[param_key] + ) + if cpu_offload: + for state_name, value in sorted_items( + fsdp_osd_state[unflat_param_name] + ): + if not torch.is_tensor(value): + continue + fsdp_osd_state[unflat_param_name][state_name] = value.cpu() + + # Instead of gathering the state of each parameter individually, we perform + # the gathering all at once to speed up the process. + for _all_states in all_states.values(): + fqn = next(iter(_all_states.keys())) + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + assert len(fsdp_param_info.param_requires_grad) > 0, ( + "With use_orig_params, FSDPParamInfo should have requires_grad " + "information. However, the length is zero." + ) + for key, idx in fsdp_param_info.param_indices.items(): + if key in _all_states: + continue + if not fsdp_param_info.param_requires_grad[idx]: + continue + raise RuntimeError( + f"{key} is not in the optimizer state. " + "The FSDPParamInfo has the param keys " + f"{sorted(fsdp_param_info.param_indices.keys())} while " + "the optimizer has the param keys " + f"{sorted(_all_states.keys())}." + ) + fsdp_osd_state.update( + _gather_all_orig_param_state( + fsdp_param_info, + _all_states, + shard_state, + to_save, + cpu_offload, + ) + ) + + return fsdp_osd_state + + +def _convert_state_with_flat_params( + all_optim_state_keys: list[_OptimStateKey], + optim_state_key_to_param_key: dict[_OptimStateKey, Union[int, str]], + fqn_to_fsdp_param_info: dict[str, FSDPParamInfo], + optim_state_dict: dict[Union[str, int], Any], + to_save: bool, + shard_state: bool, + cpu_offload: bool = True, +) -> dict[str, Any]: + fsdp_osd_state: dict[str, Any] = {} + # Iterate in rank 0's flat parameter ID order to ensure aligned all-gathers + # across ranks + for optim_state_key in all_optim_state_keys: + param_key: Union[str, int, None] = optim_state_key_to_param_key.get( + optim_state_key, None + ) + + assert param_key is not None, ( + "If use_orig_params is False, we must be able to find the " + f"corresponding param id. {optim_state_key} {param_key}" + ) + + if optim_state_key.is_fsdp_managed: + # If there are multiple unflat_param_names (not use_orig_params), + # they share the same FSDPParamInfo. So the first unflat_param_name + # is sufficient to fetch the FSDPParamInfo. + fqn = optim_state_key.unflat_param_names[0] + fsdp_param_info = fqn_to_fsdp_param_info[fqn] + unflat_state = _unflatten_optim_state( + fsdp_param_info, + optim_state_dict[param_key], + to_save, + shard_state, + cpu_offload, + ) + if to_save: + assert len(unflat_state) == len(optim_state_key.unflat_param_names) + fsdp_osd_state.update( + zip( + optim_state_key.unflat_param_names, + unflat_state, + ) + ) + elif to_save: + assert len(optim_state_key.unflat_param_names) == 1 + unflat_param_name = optim_state_key.unflat_param_names[0] + fsdp_osd_state[unflat_param_name] = copy.copy(optim_state_dict[param_key]) + if cpu_offload: + for state_name, value in sorted_items( + fsdp_osd_state[unflat_param_name] + ): + if not torch.is_tensor(value): + continue + fsdp_osd_state[unflat_param_name][state_name] = value.cpu() + + return fsdp_osd_state + + +@torch.no_grad() +def _optim_state_dict( + model: nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[nn.Parameter], + ] + ], + rank0_only: bool, + shard_state: bool, + group: Optional[dist.ProcessGroup], + using_optim_input: bool, + use_orig_params: bool = False, + cpu_offload: bool = True, +) -> dict[str, Any]: + """ + Consolidates the optimizer state and returns it as a :class:`dict` + following the convention of :meth:`torch.optim.Optimizer.state_dict`, + i.e. with keys ``"state"`` and ``"param_groups"``. + The flat parameters in ``FSDP`` modules contained in ``model`` are mapped + back to their unflattened parameters. + + Parameter keys are not well-defined. For a regular optimizer, the optimizer + state_dict contains a mapping from parameter IDs to parameter states. + Parameter IDs are the order of parameters in ``optim.param_groups()`` across + all the groups. This API also allows user to pass ``optim_input`` for the + mapping between parameters and parameter IDs. Using ``optim_input`` is being + deprecated. + + If the optimizer is a ``NamedOptimizer``, the optimizer state_dict does not + contain parameter IDs mapping but a mapping from parameter FQNs to parameter + states. This API finds the mapping from FQNs to parameters if the optimizer + is a ``NamedOptimizer``. + + If ``use_orig_params`` is True, each rank will have all FSDP-managed + parameters but some of these parameters may be empty due to the sharding. + For a regular optim.Optimizer, states for those empty parameters will + not be initialized. So, when aggregating the FQNs across ranks, no assert + will be raised on a rank even if it does not have all the states -- it is + valid and FSDP knows how to aggregate them. However, FSDP has to ignore + handling those parameters that are not managed by FSDP and do not exist on + the local rank -- those are managed by other parallelisms and FSDP does not + know how to handle/aggregate them. + + Args: + model (nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + shard_state (bool): If ``True``, shard and distribute all + non-zero-dimension states. + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=False``, + then nonzero ranks return an empty :class:`dict`. + """ + SimpleProfiler.reset() + cm = ExitStack() + cm.enter_context(SimpleProfiler.profile(SimpleProfiler.Type.ALL)) + _reset_flat_param_grad_info_if_needed(traversal_utils._get_fsdp_handles(model)) + to_save = not rank0_only or dist.get_rank(group) == 0 or shard_state + + with SimpleProfiler.profile("preprocessing"): + param_to_fqns = _get_param_to_fqns(model) + flat_param_to_fqn = _get_flat_param_to_fqn(model) + is_named_optimizer = _is_named_optimizer(optim_state_dict) + + param_key_to_param = cast( + dict[Union[int, str], nn.Parameter], + ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_key_to_param( + optim, model, is_named_optimizer, param_to_fqns, flat_param_to_fqn + ) + ), + ) + fqn_to_fsdp_param_info = _get_fqn_to_fsdp_param_info(model) + + with SimpleProfiler.profile("preprocessing_with_comm"): + ( + all_optim_state_keys, + optim_state_key_to_param_key, + ) = _map_param_key_to_optim_keys( + optim_state_dict, + group, + param_key_to_param, + param_to_fqns, + fqn_to_fsdp_param_info, + merge_keys=use_orig_params, + ) + + with SimpleProfiler.profile("state_converting"): + convert_fn = ( + _convert_state_with_orig_params + if use_orig_params + else _convert_state_with_flat_params + ) + fsdp_osd_state = convert_fn( + all_optim_state_keys, + optim_state_key_to_param_key, + fqn_to_fsdp_param_info, + optim_state_dict["state"], + to_save, + shard_state, + cpu_offload, + ) + + # At this point, communication is complete and ranks can return early if nothing + # will be saved on that rank. + if not to_save: + return {} + + fsdp_osd: dict[str, Any] = {"state": fsdp_osd_state} + + flat_param_fqns = set(flat_param_to_fqn.values()) + for key, value in optim_state_dict["state"].items(): + if key in fsdp_osd_state: + continue + if key in flat_param_fqns: + continue + if key in param_key_to_param: + continue + # This key is not recognized by FSDP. It may be a user-defined state + # or some parameters state that FSDP is unable to map from + # ``optim.param_groups``. + warnings.warn( + f"Found a optim state, {key}, that FSDP cannot process. FSDP " + "will directly copy everything to the returned state_dict. In " + "most cases, this is a user-defined state that is not " + "associated with any particular parameter. Another possible " + "case is this state is managed by TorchRec. Otherwise, there may " + " be a mismatched assumption of optim_state_dict of this mode." + ) + fsdp_osd_state[key] = value + + if "param_groups" in optim_state_dict: + fsdp_osd["param_groups"] = _unflatten_param_groups( + optim_state_dict, param_key_to_param, param_to_fqns + ) + + cm.close() + SimpleProfiler.dump_and_reset("FSDP _optim_state_dict() profiling: ") + + return fsdp_osd + + +def _get_fqn_to_fsdp_param_info(model: nn.Module) -> dict[str, FSDPParamInfo]: + """ + Construct the mapping from a param's fqn to its corresponding ``FSDPParamInfo`` + if the param is managed by FSDP. Shared parameters, or original parameters that + are shared across multiple nn.Modules, are required to belong to one and only + one FSDP instance and thus correspond to one ``FlatParameter``. Within the one + ``FlatParameter``, ``FlatParameter._fqns`` only stores the first FQN of a shared + parameter. Thus, the keys in the mapping are guaranteed to map to unique parameters. + """ + + def module_fn(module, prefix, tree_level, fqn_to_param_info): + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state is None: + return + _lazy_init(fsdp_state, module) + handle = _module_handle(fsdp_state, module) + if not handle: + return + flat_param = handle.flat_param + fsdp_param_info = FSDPParamInfo(fsdp_state, handle, {}, []) + # NOTE: `idx` indexes into the data structures *without* padding + # elements + for idx, local_fqn in enumerate(flat_param._fqns): + fqn = clean_tensor_name(prefix + local_fqn) + if fqn in fqn_to_param_info: + assert fqn_to_param_info[fqn].handle.flat_param is flat_param, fqn + fqn_to_param_info[fqn] = fsdp_param_info + fsdp_param_info.param_indices[fqn] = idx + if flat_param._params is not None: + fsdp_param_info.param_requires_grad.append( + flat_param._params[idx].requires_grad + ) + + def return_fn(fqn_to_param_info): + return fqn_to_param_info + + fqn_to_param_info: dict[str, FSDPParamInfo] = {} + # FlatParameter._fqns stores the local fqn, starting from the root of the + # FSDP. Using _apply_to_modules() with model (may not be the FSDP root + # module) allows us to construct the global fqn. + return _apply_to_modules( + model, + module_fn, + return_fn, + [fqn for fqn, _ in _named_parameters_with_duplicates(model)], + fqn_to_param_info, + ) + + +@no_type_check +def _set_optim_use_dtensor( + fsdp_state: _FSDPState, + state_dict_settings: StateDictSettings, +) -> None: + # If device_mesh is passed in when initializing FSDP, we automatically turn the + # _use_dtensor flag to be true for ShardedOptimStateDictConfig() if state_dict_type + # has to be set to SHARDED_STATE_DICT. + if getattr(fsdp_state, "_device_mesh", None): + state_dict_type = state_dict_settings.state_dict_type + if state_dict_type == StateDictType.LOCAL_STATE_DICT: + raise RuntimeError( + "Found state_dict_type LOCAL_STATE_DICT.", + "DeviceMesh is not compatible with LOCAL_STATE_DICT.", + "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.", + ) + else: + state_dict_settings.optim_state_dict_config._use_dtensor = True diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_runtime_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_runtime_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f0971160a1b45a9183d4ed4e3baa7e92062e9194 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_runtime_utils.py @@ -0,0 +1,1645 @@ +# mypy: allow-untyped-defs +import functools +import logging +from enum import auto, Enum +from typing import Any, Callable, no_type_check, Optional + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable +from torch.autograd.graph import register_multi_grad_hook +from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS +from torch.distributed.fsdp._common_utils import ( + _assert_in_training_states, + _FSDPState, + _get_module_fsdp_state, + _is_composable, + _log_post_backward_hook, + _no_dispatch_record_stream, + clean_tensor_name, + TrainingState, +) +from torch.distributed.fsdp._flat_param import ( + FlatParameter, + FlatParamHandle, + HandleShardingStrategy, + HandleTrainingState, + RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES, +) +from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES +from torch.distributed.fsdp.api import BackwardPrefetch +from torch.distributed.utils import ( + _apply_to_tensors, + _cast_forward_inputs, + _p_assert, + _to_kwargs, +) +from torch.utils import _pytree as pytree + + +logger = logging.getLogger(__name__) + +# Do not include "process_group" to enable hybrid shard and MoE cases +HOMOGENEOUS_ATTR_NAMES = ( + "_use_orig_params", + "limit_all_gathers", + "_use_full_prec_in_eval", +) + + +class _PrefetchMode(Enum): + BACKWARD = auto() + FORWARD = auto() + + +def _get_fsdp_root_states_with_modules( + module: nn.Module, +) -> tuple[list[_FSDPState], list[nn.Module]]: + """ + Returns a tuple containing: + 1. A list of the root ``_FSDPState`` instances in the module tree rooted at + ``module`` without any duplicates and following the ``module.modules()`` + traversal order (which is assumed to be depth-first). + 2. A corresponding list of the root modules owning the states in the first + list. + + This is similar to :func:`_get_fsdp_states_with_modules` except that we + must call :func:`_is_fsdp_root` to force a lazy initialization to determine + the FSDP root in case lazy initialization has not yet happened. + """ + fsdp_root_states: list[_FSDPState] = [] + fsdp_root_modules: list[nn.Module] = [] + visited_fsdp_states: set[_FSDPState] = set() + # NOTE: This function assumes that `module.modules()` proceeds top-down. + for submodule in module.modules(): + optional_state = _get_module_fsdp_state(submodule) + if ( + optional_state is not None + and optional_state not in visited_fsdp_states + and _is_fsdp_root(optional_state, submodule) + ): + visited_fsdp_states.add(optional_state) + fsdp_root_states.append(optional_state) + fsdp_root_modules.append(submodule) + return fsdp_root_states, fsdp_root_modules + + +def _get_fsdp_root_states(module: nn.Module) -> list[_FSDPState]: + """See :func:`_get_fsdp_root_states_with_modules`.""" + fsdp_root_states, _ = _get_fsdp_root_states_with_modules(module) + return fsdp_root_states + + +def _is_fsdp_root(state: _FSDPState, module: nn.Module) -> bool: + """ + Returns if ``state`` corresponds to that of an FSDP root. + + For the wrapper code path, ``state`` and ``module`` should be the same. For + the non-wrapper code path, ``state`` should be ``module`` 's state. + """ + # Force a lazy initialization to determine the FSDP root + _lazy_init(state, module) + assert state._is_root is not None # mypy + return state._is_root + + +@no_type_check +def _lazy_init( + state: _FSDPState, + root_module: nn.Module, +) -> _FSDPState: + """ + Performs initialization lazily, typically right before the first forward + pass. The laziness is needed to ensure that the parameter device/dtype and + the FSDP hierarchy have finalized. This method's actual logic only runs on + the root FSDP instance, which performs initialization for all non-root FSDP + instances to avoid partial initialization. + + For the non-composable code path, ``state`` and ``root_module`` should be + the same, namely the FSDP instance itself. + """ + if state._is_root is not None: + return # no-op: already lazily initialized + if not state._device_handle.is_available(): + # Allow the FSDP constructor to run even without CUDA but check this + # once we start real execution + raise RuntimeError("FSDP does not support CPU only execution") + # The following logic is only run on the root FSDP instance since it will + # set `_is_root=False` for the non-root instances + state._is_root = True + _assert_in_training_states(state, [TrainingState.IDLE]) + _check_flat_params_on_expected_device(state, root_module) + state._all_fsdp_states = traversal_utils._get_fsdp_states(root_module) + _init_streams(state) + buffers, buffer_dtypes = _get_buffers_and_dtypes_for_computation(state, root_module) + _cast_buffers_to_dtype_and_device(buffers, buffer_dtypes, state.compute_device) + state._exec_order_data.init(state, root_module, state.process_group) + _share_state_and_init_handle_attrs(state, root_module) + return state + + +def _check_flat_params_on_expected_device(state: _FSDPState, module: nn.Module): + """ + Checks that all ``FlatParameter``s in ``module`` 's tree managed by + ``state`` are on the expected device for *lazy initialization*. + """ + cpu_device = torch.device("cpu") + for handle in traversal_utils._get_fsdp_handles(module): + if ( + not handle._offload_params + and handle.flat_param.device != state.compute_device + ): + raise RuntimeError( + "An FSDP-managed module unexpectedly has parameters on " + f"{handle.flat_param.device}. Make sure to move the module to " + f"{state.compute_device} before training." + ) + elif handle._offload_params and handle.flat_param.device != cpu_device: + raise RuntimeError( + "An FSDP-managed module with parameter CPU offloading enabled " + f"has parameters on {handle.flat_param.device}. Make sure to " + f"not move the module from CPU when offloading parameters." + ) + + +@no_type_check +def _share_state_and_init_handle_attrs( + root_state: _FSDPState, + root_module: nn.Module, +) -> None: + """ + Shares data structure state from the ``root_state`` to all FSDP states in + ``root_module`` 's module tree, and initializes handle attributes. These + are done together to require a single loop over the states. + """ + handle = root_state._handle + if handle: + handle.init_flat_param_attributes() + attr_name_to_values: dict[str, set[Any]] = {} + for attr_name in HOMOGENEOUS_ATTR_NAMES: + attr_name_to_values[attr_name] = set() + root_state._all_handles = root_state._exec_order_data.all_handles # share reference + # Update _has_optim_in_backward for each handle. + for handle in root_state._all_handles: + flat_param = handle.flat_param + if hasattr(flat_param, "_in_backward_optimizers"): + raise RuntimeError( + "FSDP optimizer in backward only supported with use_orig_params=True!" + ) + handle._has_optim_in_backward = flat_param._params is not None and any( + hasattr(param, "_in_backward_optimizers") for param in flat_param._params + ) + if handle._has_optim_in_backward: + torch._C._log_api_usage_once("fsdp.optimizer_in_backward") + for fsdp_state in root_state._all_fsdp_states: + for attr_name in HOMOGENEOUS_ATTR_NAMES: + _p_assert( + hasattr(fsdp_state, attr_name), + f"FSDP state missing attribute {attr_name}", + ) + attr_name_to_values[attr_name].add(getattr(fsdp_state, attr_name)) + if fsdp_state is root_state: + continue + # Relax the assert for non-root FSDP instances in case the nested + # initialized module is wrapped again in FSDP later (e.g. after + # training to run inference) + _p_assert( + fsdp_state._is_root is None or not fsdp_state._is_root, + "Non-root FSDP instance's `_is_root` should not have been " + "set yet or should have been set to `False`", + ) + fsdp_state._is_root = False + fsdp_state._unshard_stream = root_state._unshard_stream + fsdp_state._post_backward_stream = root_state._post_backward_stream + fsdp_state._pre_unshard_stream = root_state._pre_unshard_stream + fsdp_state._all_reduce_stream = root_state._all_reduce_stream + fsdp_state._default_stream = root_state._default_stream + fsdp_state._exec_order_data = root_state._exec_order_data + fsdp_state._free_event_queue = root_state._free_event_queue + if fsdp_state._fsdp_extension is not None: + fsdp_state._fsdp_extension.compute_stream = root_state._default_stream + handle = fsdp_state._handle + if handle: + handle.init_flat_param_attributes() + for attr_name, attr_values in attr_name_to_values.items(): + if len(attr_values) != 1: + raise ValueError( + f"Expects one homogeneous value for {attr_name} but got {attr_values}" + ) + + +@no_type_check +def _init_streams( + state: _FSDPState, +) -> None: + """ + Initializes CUDA streams for overlapping communication, computation, and + data transfers. The streams should be shared across FSDP instances. + """ + assert state._is_root + assert state._device_handle.is_available() + uses_hybrid_sharding = any( + fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES + for fsdp_state in state._all_fsdp_states + ) + # Prioritize all-gathers/reduce-scatters over async all-reduce for HSDP and + # preserve the default priority of 0 otherwise + high_priority = -1 if state.limit_all_gathers and uses_hybrid_sharding else 0 + # Default stream for computation + state._default_stream = state._device_handle.current_stream() + if state._fsdp_extension is not None: + # set the compute stream to the FSDP extension + state._fsdp_extension.compute_stream = state._default_stream + + # Stream for unshard logic, including allocating the all-gather destination + # tensors and the all-gathers themselves + state._unshard_stream = state._device_handle.Stream(priority=high_priority) + # Stream for overlapping gradient reduction with the backward pass gradient + # computation + state._post_backward_stream = state._device_handle.Stream(priority=high_priority) + # Stream for pre-unshard logic, namely allocations and writes for CPU + # offloading (H2D copy) and mixed precision (low precision cast) + state._pre_unshard_stream = state._device_handle.Stream(priority=high_priority) + # Stream to run HSDP's all-reduce as async (if using HSDP) + state._all_reduce_stream = ( + state._device_handle.Stream() if uses_hybrid_sharding else state._default_stream + ) + + +@no_type_check +def _unshard( + state: _FSDPState, + handle: FlatParamHandle, + unshard_stream: torch.Stream, + pre_unshard_stream: torch.Stream, +) -> None: + """ + Unshards the handles in ``handles``. If the handles are in + :meth:`summon_full_params` and are using mixed precision, then they are + forced to full precision. + + Postcondition: handle's ``FlatParameter`` 's data is the padded + unsharded flat parameter on the compute device. + """ + if not handle: + return + with state._device_handle.stream(pre_unshard_stream): + ran_pre_unshard = handle.pre_unshard() + if ran_pre_unshard: + unshard_stream.wait_stream(pre_unshard_stream) + if state.limit_all_gathers: + event = state._free_event_queue.dequeue_if_needed() + if event: + with torch.profiler.record_function( + "FullyShardedDataParallel.rate_limiter" + ): + event.synchronize() + with state._device_handle.stream(unshard_stream): + handle.unshard() + handle.post_unshard() + + +@no_type_check +def _reshard( + state: _FSDPState, + handle: FlatParamHandle, + free_unsharded_flat_param: bool, +): + """ + Reshards the handle. ``free_unsharded_flat_param`` indicates whether to + free the handle's padded unsharded flat parameter. + """ + handle.reshard(free_unsharded_flat_param) + if state.limit_all_gathers and free_unsharded_flat_param: + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + # We don't run a even queue for freeing under torch compile atm + # But maybe we need to? TODO(voz): Look into this + free_event = state._device_handle.Event() + free_event.record() + state._free_event_queue.enqueue(free_event) + handle.post_reshard() + # Flat parameter freed or not, we always have to "unshard" the parameter + # upon next access to get its shape correct. + handle._prefetched = False + + +def _unshard_grads( + handle: Optional[FlatParamHandle], +) -> None: + if handle: + handle.unshard_grad() + + +def _reshard_grads( + handle: Optional[FlatParamHandle], +) -> None: + if handle: + handle.reshard_grad() + + +@no_type_check +def _pre_forward( + state: _FSDPState, + handle: Optional[FlatParamHandle], + unshard_fn: Callable, + module: nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> tuple[tuple[Any, ...], dict[str, Any]]: + """ + Runs the pre-forward logic. This includes an opportunity to unshard + currently sharded parameters such as those for the current forward and + registering post-backward hooks for these current parameters. This function + also converts forward ``args`` and ``kwargs`` to the given precision. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters used in + the current forward. + unshard_fn (Optional[Callable]): A callable to unshard any currently + sharded parameters or ``None`` to not do any unsharding. + module (nn.Module): Module whose forward this method runs right before; + expected by the hook signature. + args (Tuple[Any, ...]): Module forward ``args``. + kwargs (Dict[str, Any]): Module forward ``kwargs``. + """ + with torch.profiler.record_function("FullyShardedDataParallel._pre_forward"): + # For `fully_shard` + `checkpoint`, skip pre-forward logic in the + # recomputed forward + if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE: + # For both checkpoint implementations, we do not need to re-cast + # inputs here since they will be checkpointed in the low precision + # either by AC or normally by autograd as long as the AC region is + # nested within FSDP + return args, kwargs + state.training_state = TrainingState.FORWARD_BACKWARD + state._exec_order_data.record_pre_forward(handle, module.training) + if handle: + handle._training_state = HandleTrainingState.FORWARD + if unshard_fn is not None: + unshard_fn(state, handle) + # Register post-backward hooks to reshard the parameters and reduce-scatter + # their gradients. They must be re-registered every forward pass in case + # the `grad_fn` is mutated. + _register_post_backward_hook(state, handle) + # We have to reallocate the _cpu_grad if optimizer overlap + # set the grad to None in the backward pass. + if handle and handle._offload_params and handle.flat_param._cpu_grad is None: + handle.flat_param._cpu_grad = torch.zeros_like( + handle.flat_param._local_shard, device=torch.device("cpu") + ).pin_memory() + + should_cast_forward_inputs = ( + state._handle and not state._handle._force_full_precision + ) + + if should_cast_forward_inputs and state.mixed_precision.cast_forward_inputs: + # Recursively convert args and kwargs to specified precision. + input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype + args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs) + _register_post_backward_reshard_only_hook(state, handle, args, kwargs) + return args, kwargs + + +@no_type_check +def _pre_forward_unshard( + state: _FSDPState, + handle: Optional[FlatParamHandle], +) -> None: + """Unshards parameters in the pre-forward.""" + if not handle: + return + # If the handles have been prefetched, then there is no need to call + # `_unshard()` again + if not handle._prefetched: + _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream) + handle._needs_pre_forward_unshard = False + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + current_stream = state._device_handle.current_stream() + if state._unshard_event is not None: + current_stream.wait_event(state._unshard_event) + state._unshard_event = None + else: + current_stream.wait_stream(state._unshard_stream) + with torch.profiler.record_function( + "FullyShardedDataParallel._pre_forward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.FORWARD) + + +@no_type_check +def _post_forward( + state: _FSDPState, + handle: Optional[FlatParamHandle], + reshard_fn: Callable, + module: nn.Module, + input: Any, + output: Any, +) -> Any: + """ + Runs the post-forward logic. This includes an opportunity to reshard + currently unsharded parameters such as those used in the current forward + and registering pre-backward hooks on the forward outputs. + + Args: + handles (List[FlatParamHandle]): Handles giving the parameters used in + the current forward. + reshard_fn (Optional[Callable]): A callable to reshard any currently + unsharded parameters (e.g. from the current forward) or ``None`` to + not do any resharding. + module (nn.Module): Module whose forward just ran, which should be a + fully sharded module (see [Note: Fully Sharded Module]); expected + by the hook signature. + input (Any): Unused; expected by the hook signature. + output (Any): Forward pass output; pre-backward hooks are registered on + the tensors that require gradients in this output. + + Postcondition: Each ``FlatParameter`` 's data points to the sharded flat + parameter. + """ + with torch.profiler.record_function("FullyShardedDataParallel._post_forward"): + # For `fully_shard` + `checkpoint`, skip post-forward logic in the + # recomputed forward + if handle and handle._training_state == HandleTrainingState.BACKWARD_PRE: + return output + + state._exec_order_data.record_post_forward(handle) + if reshard_fn is not None: + reshard_fn(state, handle) + # Register pre-backward hooks to unshard the flat parameters for the + # gradient computation (if needed) + output = _register_pre_backward_hooks(state, module, output, handle) + state.training_state = TrainingState.IDLE + if handle: + handle._training_state = HandleTrainingState.IDLE + return output + + +@no_type_check +def _post_forward_reshard( + state: _FSDPState, + handle: FlatParamHandle, +) -> None: + """Reshards parameters in the post-forward.""" + if not handle: + return + # Do not free the root's parameters in the post-forward for `FULL_SHARD` + # with the intention that they are immediately used for backward + # computation (though this may not be true) + free_unsharded_flat_param = ( + not state._is_root + and handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + _reshard(state, handle, free_unsharded_flat_param) + + +@no_type_check +def _root_pre_forward( + state: _FSDPState, + module: nn.Module, + args, + kwargs, +) -> None: + """ + Runs pre-forward logic specific to the root FSDP instance, which should run + before any individual module's pre-forward. This starts with an attempt at + lazy initialization (which only runs non-vacuously once). Otherwise, if + this is called on a non-root FSDP instance, then it returns directly. + + Args: + module (nn.Module): Module for which this logic tries to run. It may or + may not be the root. If not, then this method does not do anything. + """ + with torch.profiler.record_function("FullyShardedDataParallel._root_pre_forward"): + _lazy_init(state, module) + _p_assert(state._is_root is not None, "Expects a root FSDP to have been set") + if not state._is_root: + # Always cast forward inputs in the root of this local FSDP unit for mixed + # precision, as this is where mixed precision could be configured. + # This is more useful for auto wrapping that is recommended in composable path. + # For manual wrapping, cast forward inputs on each local FSDP unit root will + # increase some overhead, so not turned on for model wrapper path right now where + # manual wrapping is more broadly used. + if _is_composable(state): + return _root_cast_forward_input(state, module, args, kwargs) + return args, kwargs + + # We cast buffers back to full precision if we're forcing full precision. Disjointly, we check if buffers + # are in full precision and if we should cast them back to lower precision, which happens when + # exiting eval() mode. + handle = state._handle + if handle: + should_cast_buffers_to_full_prec = handle._force_full_precision + else: + # If the root has no handle (no managed parameters), then we fall + # back to checking if any child wants to force full precision as a + # workaround + handles = traversal_utils._get_fsdp_handles(module) + should_cast_buffers_to_full_prec = any( + handle._force_full_precision for handle in handles + ) + + if should_cast_buffers_to_full_prec: + _cast_buffers_to_dtype_and_device( + buffers=dict(module.named_buffers()).values(), + buffer_dtypes=list(state._buffer_name_to_orig_dtype.values()), + device=state.compute_device, + ) + # This flag is only set when we cast buffers to full precision, to avoid the + # CPU overhead that can stem from retrieving all buffers and their types in the + # following else branch. + state._needs_buffer_dtype_restore_check = True + elif getattr(state, "_needs_buffer_dtype_restore_check", False): + # Check if buffers are in full precision and we need to cast them + # back down. + ( + buffers, + buffer_dtypes_for_computation, + ) = _get_buffers_and_dtypes_for_computation(state, module) + if len(buffers) > 0 and len(buffer_dtypes_for_computation) > 0: + if any( + buffer.dtype != buffer_dtype_for_computation + for buffer, buffer_dtype_for_computation in zip( + buffers, buffer_dtypes_for_computation + ) + ): + # Assume we have to cast everything if there is one mismatch + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes_for_computation, state.compute_device + ) + # We don't have to check this again until we cast buffers to full precision again. + state._needs_buffer_dtype_restore_check = False + + if state.forward_prefetch: + handles = [ + fsdp_state._handle + for fsdp_state in state._all_fsdp_states + if fsdp_state._handle + ] + for handle in handles: + handle._needs_pre_forward_unshard = True + handle._prefetched = False + _wait_for_computation_stream( + state._device_handle.current_stream(), + state._unshard_stream, + state._pre_unshard_stream, + ) + _reset_flat_param_grad_info_if_needed(state._all_handles) + + # Prepares the forward inputs by moving them to ``compute_device`` + # TODO: Do not use the side stream for tensor copies for now; investigate + # the perf with/without it. + with torch.profiler.record_function("FullyShardedDataParallel._to_kwargs"): + args_tuple, kwargs_tuple = _to_kwargs( + args, kwargs, state.compute_device, False + ) + args = args_tuple[0] if args_tuple else tuple() + kwargs = kwargs_tuple[0] if kwargs_tuple else {} + + return _root_cast_forward_input(state, module, args, kwargs) + + +@no_type_check +def _root_cast_forward_input( + state: _FSDPState, module: torch.nn.Module, args, kwargs +) -> tuple[Any, Any]: + if state._handle: + force_full_precision = not state._handle._force_full_precision + else: + force_full_precision = True + + should_cast_forward_inputs = ( + (module.training or not state._use_full_prec_in_eval) and force_full_precision + ) and state.mixed_precision.cast_root_forward_inputs + + if should_cast_forward_inputs: + input_dtype: Optional[torch.dtype] = state.mixed_precision.param_dtype + args, kwargs = _cast_forward_inputs(input_dtype, *args, **kwargs) + + return args, kwargs + + +@no_type_check +def _pre_backward_hook( + state: _FSDPState, + module: nn.Module, + handle: FlatParamHandle, + grad, + *unused: Any, +) -> Any: + """ + Prepares ``_handle`` 's ``FlatParameter`` s for gradient computation. + + Args: + module (nn.Module): Fully sharded module (see [Note: Fully Sharded + Module]). + """ + # Only run the pre-backward hook once per group of handles involved in the + # same module forward computation + if ( + handle + and hasattr(handle, "_ran_pre_backward_hook") + and handle._ran_pre_backward_hook + ): + return grad + + with torch.profiler.record_function("FullyShardedDataParallel._pre_backward_hook"): + # Queue the post-backward callback once for the root FSDP instance to + # attach it to the outermost backward graph task so that it is called + # after all backward calls complete + if state._is_root and not state._post_backward_callback_queued: + _register_post_backward_final_callback(state, module) + _reset_flat_param_grad_info_if_needed(state._all_handles) + elif handle: + allowed_states = [TrainingState.IDLE] + if _is_composable(state): + allowed_states.append(TrainingState.FORWARD_BACKWARD) + _assert_in_training_states(state, allowed_states) + state.training_state = TrainingState.FORWARD_BACKWARD + # Queueing the post-backward callback is the only logic that is not + # per-handle in the pre-backward hook, so we can return early here if + # there are no handles. + if not handle: + return grad + handle._training_state = HandleTrainingState.BACKWARD_PRE + + if handle._needs_pre_backward_unshard: + # If the handles have been prefetched, then there is no need to + # call `_unshard()` again + if not handle._prefetched: + _unshard( + state, + handle, + state._unshard_stream, + state._pre_unshard_stream, + ) + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._device_handle.current_stream().wait_stream(state._unshard_stream) + + # Set this to `False` to ensure that a mistargeted prefetch does not + # actually unshard these handles + handle._needs_pre_backward_unshard = False + with torch.profiler.record_function( + "FullyShardedDataParallel._pre_backward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.BACKWARD) + handle.prepare_gradient_for_backward() + handle._ran_pre_backward_hook = True + return grad + + +@no_type_check +@torch.no_grad() +def _post_backward_hook( + state: _FSDPState, + handle: FlatParamHandle, + flat_param, + *unused: Any, +): + """ + Reduce-scatters the gradient of ``handle`` 's ``FlatParameter``. + + Precondition: The ``FlatParameter`` 's ``.grad`` attribute contains the + unsharded gradient for the local batch. + + Postcondition: + - If using ``NO_SHARD``, then the ``.grad`` attribute is the reduced + unsharded gradient. + - Otherwise, the ``_saved_grad_shard`` attribute is the reduced sharded + gradient (accumulating with any existing gradient). + """ + _log_post_backward_hook(state, handle, logger) + flat_param = handle.flat_param + flat_param._post_backward_called = True + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel._post_backward_hook" + ): + _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD]) + # For multiple applications of reentrant AC across submodules sharing + # the same `FlatParameter`, the post-backward hook may run multiple + # times in one backward, in which case we permit the state to already + # be in `BACKWARD_POST`. + _p_assert( + handle._training_state + in (HandleTrainingState.BACKWARD_PRE, HandleTrainingState.BACKWARD_POST), + f"Expects `BACKWARD_PRE` or `BACKWARD_POST` state but got {handle._training_state}", + ) + handle._training_state = HandleTrainingState.BACKWARD_POST + + if flat_param.grad is None: + return + if flat_param.grad.requires_grad: + raise RuntimeError("FSDP does not support gradients of gradients") + + _post_backward_reshard(state, handle) + if not state._sync_gradients: + if handle._use_orig_params: + handle._use_unsharded_grad_views() + return + + # Wait for all ops in the current stream (e.g. gradient computation) to + # finish before reduce-scattering the gradient + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._post_backward_stream.wait_stream( + state._device_handle.current_stream() + ) + + with state._device_handle.stream(state._post_backward_stream): + autograd_computed_grad = flat_param.grad.data + if ( + not _low_precision_hook_enabled(state) + and flat_param.grad.dtype != handle._reduce_dtype + # If we are forcing full precision but communicating grads + # (i.e. model.eval() + full precision in eval was configured), don't downcast gradient. + and not handle._force_full_precision + ): + flat_param.grad.data = flat_param.grad.to(handle._reduce_dtype) + if handle.uses_sharded_strategy: + _reduce_grad(state, handle) + else: + _reduce_grad_no_shard(state, handle) + # Since the unsharded gradient is produced in the computation + # stream and consumed in the post-backward stream, inform the + # caching allocator (before it goes out of scope) + _no_dispatch_record_stream( + autograd_computed_grad, state._post_backward_stream + ) + + +def _post_backward_reshard_only_hook( + state: _FSDPState, + handle: FlatParamHandle, + *unused: Any, +) -> None: + with torch.profiler.record_function( + "FullyShardedDataParallel._post_backward_hook_reshard_only" + ): + # `_pre_backward_hook` may not get executed + # if forward output does not require grad + # overwrite IDLE state for post-backward prefetching + state.training_state = TrainingState.FORWARD_BACKWARD + handle._training_state = HandleTrainingState.BACKWARD_POST + _post_backward_reshard(state, handle) + + +def _post_backward_reshard( + state: _FSDPState, + handle: FlatParamHandle, + *unused: Any, +) -> None: + free_unsharded_flat_param = _should_free_in_backward(state, handle) + _reshard(state, handle, free_unsharded_flat_param) + + # TODO: Post-backward prefetching does not support the multiple handles + # per module case since the post-backward hook runs per handle, not per + # group of handles. + with torch.profiler.record_function( + "FullyShardedDataParallel._post_backward_prefetch" + ): + _prefetch_handle(state, handle, _PrefetchMode.BACKWARD) + + +@no_type_check +def _should_free_in_backward( + state: _FSDPState, + handle: FlatParamHandle, +) -> bool: + """ + Returns whether FSDP should free the unsharded flat parameter in the + post-backward or not. + """ + if not handle.uses_sharded_strategy: + return False + # If not syncing gradients, then we do not free for strategies that do not + # reshard after forward as a *heuristic* to tradeoff higher memory for + # higher throughput. + return ( + state._sync_gradients + or handle._sharding_strategy in RESHARD_AFTER_FORWARD_HANDLE_STRATEGIES + ) + + +@no_type_check +def _reduce_grad(state: _FSDPState, handle: FlatParamHandle) -> None: + """ + For sharded strategies, this runs gradient reduction, sharded gradient + accumulation if needed, and the post-reduction callback. + """ + flat_param = handle.flat_param + uses_hybrid_sharded_strategy = handle._sharding_strategy in ( + HandleShardingStrategy.HYBRID_SHARD, + HandleShardingStrategy._HYBRID_SHARD_ZERO2, + ) + # We clear `.grad` to permit multiple backwards. This avoids a race where + # the second backward pass computation precedes ahead of the first backward + # pass reduction, which is possible since the reduction is issued in a + # separate stream and is async and would result in reducing the wrong + # gradient. + unsharded_grad = flat_param.grad.data + flat_param.grad = None + padded_unsharded_grad, new_sharded_grad = _get_reduce_scatter_tensors( + state, unsharded_grad + ) + if state._comm_hook is None: # default path + _div_if_needed(padded_unsharded_grad, state._gradient_predivide_factor) + pg = ( + handle._fake_process_group + if handle._use_fake_reduce + else state.process_group + ) + dist.reduce_scatter_tensor( + new_sharded_grad, + padded_unsharded_grad, + group=pg, + ) + if uses_hybrid_sharded_strategy: + # Don't wait during trace + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._all_reduce_stream.wait_stream(state._post_backward_stream) + with state._device_handle.stream(state._all_reduce_stream): + # Since the new sharded gradient is produced in the post- + # backward stream and consumed in the all-reduce stream, + # inform the caching allocator + _no_dispatch_record_stream(new_sharded_grad, state._all_reduce_stream) + dist.all_reduce(new_sharded_grad, group=state._inter_node_pg) + _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor) + grad_to_offload = _accumulate_sharded_grad( + state, handle, new_sharded_grad + ) + _post_reduce_grad_callback(state, handle, grad_to_offload) + return + _div_if_needed(new_sharded_grad, state._gradient_postdivide_factor) + else: + state._comm_hook( + state._comm_hook_state, padded_unsharded_grad, new_sharded_grad + ) + # NOTE: HSDP variants do not support communication hook. + grad_to_offload = _accumulate_sharded_grad(state, handle, new_sharded_grad) + _post_reduce_grad_callback(state, handle, grad_to_offload) + + +@no_type_check +def _get_reduce_scatter_tensors( + state: _FSDPState, unsharded_grad: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Returns the input and output tensors to reduce-scatter, respectively. + """ + chunks = list(unsharded_grad.chunk(state.world_size)) + numel_to_pad = state.world_size * chunks[0].numel() - unsharded_grad.numel() + padded_unsharded_grad = ( + F.pad(unsharded_grad, [0, numel_to_pad]) if numel_to_pad > 0 else unsharded_grad + ) + new_sharded_grad = torch.empty_like(chunks[0]) # padded + return padded_unsharded_grad, new_sharded_grad + + +@no_type_check +def _accumulate_sharded_grad( + state: _FSDPState, + handle: FlatParamHandle, + sharded_grad: torch.Tensor, +) -> torch.Tensor: + """ + Accumulates the reduce-scattered sharded gradient with any existing sharded + gradient if needed, returning the gradient to offload (if CPU offloading is + enabled). + """ + flat_param = handle.flat_param + _cast_grad_to_param_dtype(state, sharded_grad, flat_param) + # Save the sharded gradient in `_saved_grad_shard` to support gradient + # accumulation -- for multiple backwards, the gradient reductions may + # happen in arbitrary order + accumulate_grad = hasattr(flat_param, "_saved_grad_shard") + if accumulate_grad: + _check_grad_to_accumulate(sharded_grad, flat_param._saved_grad_shard) + flat_param._saved_grad_shard += sharded_grad + else: + flat_param._saved_grad_shard = sharded_grad + grad_to_offload = flat_param._saved_grad_shard + return grad_to_offload + + +@no_type_check +def _reduce_grad_no_shard(state: _FSDPState, handle: FlatParamHandle) -> None: + """ + For no-shard, this runs gradient reduction (which directly covers any + gradient accumulation implicitly) and the post-reduction callback. + """ + flat_param = handle.flat_param + if state._comm_hook is None: # default path + _div_if_needed(flat_param.grad, state._gradient_predivide_factor) + dist.all_reduce(flat_param.grad, group=state.process_group) + _div_if_needed(flat_param.grad, state._gradient_postdivide_factor) + else: + state._comm_hook(state._comm_hook_state, flat_param.grad) + # For `NO_SHARD`, we can keep the low precision gradients by simply + # omitting the cast altogether + if not handle._keep_low_precision_grads: + _cast_grad_to_param_dtype(state, flat_param.grad, flat_param) + grad_to_offload = flat_param.grad.data + _post_reduce_grad_callback(state, handle, grad_to_offload) + + +@no_type_check +def _post_reduce_grad_callback( + state: _FSDPState, + handle: FlatParamHandle, + # Additional arguments needed for the callback logic + grad_to_offload: torch.Tensor, +): + """ + This callback captures any logic to run after the gradient reduction + finishes. Currently, this offloads the gradient to CPU if CPU offloading is + enabled and uses sharded gradient views if ``use_orig_params=True``. + """ + _offload_grad(state, handle, grad_to_offload) + _post_backward_use_sharded_grad_views(handle) + + +@no_type_check +def _offload_grad( + state: _FSDPState, + handle: FlatParamHandle, + grad_to_offload: torch.Tensor, +): + if not handle._offload_params: + return + # Offload the gradient to CPU to ensure parameters and gradients are on the + # same device as required by the optimizer + # TODO: Investigate why `NO_SHARD` breaks correctness when using + # `non_blocking=True` here. + # TODO (rohan-varma): When CPU offload and optimizer overlap, + # non_blocking=True won't work since the copy may have not finished before + # the optimizer step executes on CPU. If we want to use non-blocking=True + # here, we'll have to synchronize before using result on CPU. + non_blocking = handle.uses_sharded_strategy and not handle._has_optim_in_backward + handle.flat_param._cpu_grad.copy_( + grad_to_offload.detach(), non_blocking=non_blocking + ) # synchronized in the post-backward callback + # Since the gradient being offloaded may have been produced in the + # computation stream and is being consumed here in the post-backward + # stream, inform the caching allocator + _no_dispatch_record_stream(grad_to_offload.data, state._post_backward_stream) + + +@no_type_check +def _post_backward_use_sharded_grad_views(handle: FlatParamHandle): + if not handle._use_orig_params: + return + # Since the handle's `FlatParameter` completed its gradient computation, we + # should reset the gradient noneness mask + handle._reset_is_grad_none() + # Delay using sharded gradient views until after the reduce-scatter instead + # of immediately after resharding + handle._use_sharded_grad_views() + if handle._has_optim_in_backward: + handle.prepare_gradient_for_optim() + for orig_param in handle.flat_param._params: + # Check for `None` gradient to filter parameters not in the rank + if orig_param.grad is not None and hasattr( + orig_param, "_in_backward_optimizers" + ): + # TODO (rohan-varma): For CPU offload, this unfortunately + # operates on CPU because the parameters and gradients have + # already been offloaded. We should run this on GPU after + # refactoring. + for optim in orig_param._in_backward_optimizers: + optim.step() + + optim.zero_grad(set_to_none=True) + handle._reset_flat_param_grad_info_if_needed() + if handle._offload_params: + handle.flat_param._cpu_grad = None + + +def _div_if_needed(tensor: torch.Tensor, div_factor: float) -> None: + if div_factor > 1: + tensor.div_(div_factor) + + +@no_type_check +def _cast_grad_to_param_dtype( + state: _FSDPState, + sharded_grad: torch.Tensor, + param: FlatParameter, +): + """ + Casts ``sharded_grad`` back to the full parameter dtype so that the + optimizer step runs with that dtype. This performs an actual cast if + 1. parameters were in reduced precision during the forward since then + gradients would be in that reduced precision, or + 2. parameters were not in reduced precision but gradients were in + reduced precision for communication. + However, if a low precision communication hook is registered, then this + dtype cast happens in the hook instead. + """ + _assert_in_training_states(state, [TrainingState.FORWARD_BACKWARD]) + if not _low_precision_hook_enabled(state) and sharded_grad.dtype != param.dtype: + low_prec_grad_data = sharded_grad.data + sharded_grad.data = sharded_grad.data.to(dtype=param.dtype) + # Since for `NO_SHARD`, the gradient is produced in the computation + # stream and consumed here in the post-backward stream, inform the + # caching allocator; for the sharded strategies, the gradient is + # produced in the post-backward stream, so this `record_stream()` + # should be a no-op + _no_dispatch_record_stream( + low_prec_grad_data, state._device_handle.current_stream() + ) + + +def _check_grad_to_accumulate( + new_sharded_grad: torch.Tensor, + accumulated_grad: torch.Tensor, +) -> None: + _p_assert( + accumulated_grad.shape == new_sharded_grad.shape, + "Shape mismatch when accumulating gradients: " + f"existing gradient shape={accumulated_grad.shape} " + f"new gradient shape={new_sharded_grad.shape}", + ) + _p_assert( + accumulated_grad.device == new_sharded_grad.device, + "Device mismatch when accumulating gradients: " + f"existing gradient device={accumulated_grad.device} " + f"new gradient device={new_sharded_grad.device}", + ) + + +@no_type_check +def _low_precision_hook_enabled(state: _FSDPState) -> bool: + return state._comm_hook in LOW_PRECISION_HOOKS + + +@no_type_check +@torch.no_grad() +def _post_backward_final_callback( + state: _FSDPState, + module: nn.Module, +): + """ + This waits for the post-backward to finish and performs some final cleanup. + This runs at the end of the entire backward pass and should only be called + on the root FSDP instance. + """ + _p_assert( + state._is_root, + "The post-backward callback should only be called on the root FSDP instance", + ) + root_state = state + + if root_state._sync_gradients: + current_stream = state._device_handle.current_stream() + # TODO (rohan-varma): this also waits for the overlapped optimizer step to finish + # since it currently runs in the post-backward stream. That can be + # pushed to the next forward if run in a different stream + current_stream.wait_stream(root_state._post_backward_stream) + if root_state._all_reduce_stream is not current_stream: # uses HSDP + current_stream.wait_stream(root_state._all_reduce_stream) + if root_state.cpu_offload.offload_params: + # Wait for non-blocking GPU -> CPU sharded gradient copies from the + # post-backward hooks to finish explicitly since CPU gradients do + # not automatically synchronize with the GPU + state._device_handle.current_stream().synchronize() + root_state._exec_order_data.next_iter() + + for fsdp_state in state._all_fsdp_states: + _catch_all_reshard(fsdp_state) + _finalize_params(fsdp_state) + fsdp_state.training_state = TrainingState.IDLE + handle = fsdp_state._handle + if handle: + handle._ran_pre_backward_hook = False + handle._needs_pre_backward_unshard = False + handle._post_forward_index = None + handle._training_state = HandleTrainingState.IDLE + handle._prefetched = False + # Reset for cases like one forward and multiple backwards + root_state._post_backward_callback_queued = False + + +@no_type_check +def _catch_all_reshard( + state: _FSDPState, +) -> None: + """ + Reshards the parameters that may not have been resharded in the + post-backward hook. This can happen when a module's output is used in the + forward pass, meaning that its pre-backward hook runs (unsharding the + parameter), but the post-backward hook does not run because the output was + not jused in the loss computation corresponding to this backward pass. + """ + # Wrap with a try-except to provide a more informative traceback if an + # error is raised + try: + if state._handle: + # TODO: This already-resharded check is brittle: + # https://github.com/pytorch/pytorch/issues/83956 + already_resharded = ( + state._handle.flat_param.data_ptr() + == state._handle.flat_param._local_shard.data_ptr() + # If FSDP skipped using sharded views, then the flat parameter + # still points to the sharded data, so we need to reshard to + # use sharded views + and not state._handle._skipped_use_sharded_views + ) + if already_resharded: + return + free_unsharded_flat_param = _should_free_in_backward(state, state._handle) + _reshard(state, state._handle, free_unsharded_flat_param) + except Exception as e: + _p_assert( + False, + f"Got exception in the catch-all reshard for {state}: {str(e)}", + raise_assertion_error=False, + ) + raise e + + +@no_type_check +def _finalize_params( + state: _FSDPState, +) -> None: + """Finalizes the parameters before the next iteration.""" + handle = state._handle + if not handle: + return + flat_param = handle.flat_param + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + if hasattr(flat_param, "_post_backward_hook_handle"): + pbhs_handle = flat_param._post_backward_hook_handle + pbhs_handle.remove() + del flat_param._post_backward_hook_handle + else: + if hasattr(flat_param, "_post_backward_hook_state"): + post_backward_hook_state_len = len(flat_param._post_backward_hook_state) + expected_post_backward_hook_state_len = int(flat_param.requires_grad) + 1 + _p_assert( + post_backward_hook_state_len == expected_post_backward_hook_state_len, + f"Invalid: ``_post_backward_hook_state``: {flat_param._post_backward_hook_state}", + ) + flat_param._post_backward_hook_state[-1].remove() + delattr(flat_param, "_post_backward_hook_state") + if flat_param.requires_grad: + if not state._sync_gradients: + # Preserve the gradient accumulation state if not synchronizing + # gradients: `.grad` remains the unsharded gradient from prior + # `no_sync()` iterations, and `_saved_grad_shard` remains the + # sharded gradient from the last synchronized iteration + return + if not handle._has_optim_in_backward: + handle.prepare_gradient_for_optim() + _p_assert( + hasattr(flat_param, "_post_backward_called"), + "Expects `_post_backward_called` to be set on the `FlatParameter`", + ) + flat_param._post_backward_called = False + + +@no_type_check +def _prefetch_handle( + state: _FSDPState, + current_handle: Optional[FlatParamHandle], + prefetch_mode: _PrefetchMode, +) -> None: + """ + Prefetches the next handles if needed (without synchronization). An empty + handles key cannot prefetch. + """ + if not current_handle: + return + handle = _get_handle_to_prefetch(state, current_handle) + if not handle: + return + # Temporarily emulate the training state while calling `_unshard` to + # ensure the correct `as_params` for `_use_unsharded_views()` + prev_training_state = handle._training_state + if prefetch_mode == _PrefetchMode.BACKWARD: + handle._training_state = HandleTrainingState.BACKWARD_PRE + elif prefetch_mode == _PrefetchMode.FORWARD: + handle._training_state = HandleTrainingState.FORWARD + else: + raise ValueError(f"Invalid prefetch mode on rank {state.rank}: {prefetch_mode}") + # Prefetch the next set of handles without synchronizing to allow + # the sync to happen as late as possible to maximize overlap + _unshard(state, handle, state._unshard_stream, state._pre_unshard_stream) + handle._training_state = prev_training_state + handle._prefetched = True + + +@no_type_check +def _get_handle_to_prefetch( + state: _FSDPState, + current_handle: FlatParamHandle, +) -> FlatParamHandle: + """ + Returns a :class:`list` of the handles keys to prefetch for the next + module(s), where ``current_handle`` represents the current module. + + "Prefetching" refers to running the unshard logic early (without + synchronization), and the "next" modules depend on the recorded execution + order and the current training state. + """ + training_state = _get_training_state(current_handle) + valid_training_states = ( + HandleTrainingState.BACKWARD_PRE, + HandleTrainingState.BACKWARD_POST, + HandleTrainingState.FORWARD, + ) + _p_assert( + training_state in valid_training_states, + f"Prefetching is only supported in {valid_training_states} but " + f"currently in {training_state}", + ) + eod = state._exec_order_data + target_handle: Optional[FlatParamHandle] = None + if ( + training_state == HandleTrainingState.BACKWARD_PRE + and state.backward_prefetch == BackwardPrefetch.BACKWARD_PRE + ) or ( + training_state == HandleTrainingState.BACKWARD_POST + and state.backward_prefetch == BackwardPrefetch.BACKWARD_POST + ): + target_handle_candidate = eod.get_handle_to_backward_prefetch(current_handle) + if ( + target_handle_candidate + and target_handle_candidate._needs_pre_backward_unshard + and not target_handle_candidate._prefetched + ): + target_handle = target_handle_candidate + else: + target_handle = None + elif training_state == HandleTrainingState.FORWARD and state.forward_prefetch: + target_handle_candidate = eod.get_handle_to_forward_prefetch(current_handle) + if ( + target_handle_candidate + and target_handle_candidate._needs_pre_forward_unshard + and not target_handle_candidate._prefetched + ): + target_handle = target_handle_candidate + else: + target_handle = None + + return target_handle + + +def _get_training_state( + handle: FlatParamHandle, +) -> HandleTrainingState: + """Returns the training state of the handles in ``handle``.""" + _p_assert(handle, "Expects a non-empty handle") + return handle._training_state + + +@no_type_check +def _register_pre_forward_hook( + state: _FSDPState, + module: nn.Module, +) -> None: + """ + Registers a pre-forward hook on ``module``. + """ + for forward_handle in state._pre_forward_handles: + forward_handle.remove() + state._pre_forward_handles.clear() + module_param_handle = state._fully_sharded_module_to_handle.get(module, None) + hook = functools.partial( + _pre_forward, state, module_param_handle, _pre_forward_unshard + ) + state._pre_forward_handles.append( + module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True) + ) + + +@no_type_check +def _register_post_forward_hook( + state: _FSDPState, + module: nn.Module, +) -> None: + """ + Registers a post-forward hook on ``module``. Even if the module has no + handles, we should register the hook since it will register the module's + pre-backward hook. + """ + for forward_handle in state._post_forward_handles: + forward_handle.remove() + state._post_forward_handles.clear() + module_param_handle = state._fully_sharded_module_to_handle.get(module, None) + hook = functools.partial( + _post_forward, + state, + module_param_handle, + _post_forward_reshard, + ) + state._post_forward_handles.append(module.register_forward_hook(hook)) + + +@no_type_check +def _register_root_pre_forward_hook( + state: _FSDPState, + module: nn.Module, +): + """ + Registers root pre-forward hook on ``module``, which should be the local + FSDP root. + + NOTE: For the current composable FSDP design, we have each application of + ``fully_shard()`` to a module to indicate that that module is the local + FSDP root. We may remove this assumption in the future, in which case we + will need to register this root pre-forward hook on any candidate module + that may be the local FSDP root. + """ + for forward_handle in state._root_pre_forward_handles: + forward_handle.remove() + state._root_pre_forward_handles.clear() + hook = functools.partial(_root_pre_forward, state) + state._root_pre_forward_handles.append( + module.register_forward_pre_hook(hook, prepend=True, with_kwargs=True) + ) + + +@no_type_check +def _register_pre_backward_hooks( + state: _FSDPState, + module: nn.Module, + outputs: Any, + handle: FlatParamHandle, +) -> None: + """ + Registers pre-backward hooks on the tensors that require gradients in the + forward pass outputs ``outputs``, which were computed using the + ``FlatParameter`` s of ``handles``. + + Args: + module (nn.Module): Fully sharded module (see [Note: Fully Sharded + Module]). + + Returns: + Forward pass outputs with pre-backward hooks registered to tensors that + require gradients. + """ + # If there is no gradient computation, then there is no need for + # pre-backward logic + if not torch.is_grad_enabled(): + return outputs + if state._is_root: + state._post_backward_callback_queued = False # only defined on the root + + if handle: + handle._needs_pre_backward_unshard = False + # Since these handles' `FlatParameter`s participated in a forward, we + # conservatively assume that they will be used in the backward + handle._ran_pre_backward_hook = False + + def _register_hook(t: torch.Tensor) -> torch.Tensor: + if t.requires_grad: + t.register_hook( + torch.utils.hooks.unserializable_hook( + functools.partial(_pre_backward_hook, state, module, handle) + ) + ) + if handle: + handle._needs_pre_backward_unshard = True + return t + + return _apply_to_tensors(_register_hook, outputs) + + +def _register_post_backward_hook( + state: _FSDPState, + handle: Optional[FlatParamHandle], +) -> None: + """ + Registers post-backward hooks on the ``FlatParameter`` s' + ``AccumulateGrad`` objects to reshard and to reduce-scatter gradients. + + The ``AccumulateGrad`` object represents the last function that finalizes + the ``FlatParameter`` 's gradient, so it only runs after its entire + gradient computation has finished. + + We register the post-backward hook only once in the *first* forward that a + ``FlatParameter`` participates in. This relies on the ``AccumulateGrad`` + object being preserved through multiple forwards. + + NOTE: We follow this heuristic to prefer the *first* forward to target the + parameter mixed precision case, where there are *separate* + ``AccumulateGrad`` objects across the different forwards. (Without + parameter mixed precision, the ``AccumulateGrad`` objects are the same.) If + we instead prefer the *last* forward, then the hook runs early. + """ + # If there is no gradient computation, then there is no need for + # post-backward logic + if not torch.is_grad_enabled(): + return + if not handle: + return + flat_param = handle.flat_param + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_registered = hasattr(flat_param, "_post_backward_hook_handle") + if already_registered or not flat_param.requires_grad: + return + hook = functools.partial(_post_backward_hook, state, handle) + hook_handle = flat_param.register_post_accumulate_grad_hook(hook) + flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined] + else: + already_registered = hasattr(flat_param, "_post_backward_hook_state") + if already_registered or not flat_param.requires_grad: + return + # Get the `AccumulateGrad` object + temp_flat_param = flat_param.expand_as(flat_param) + _p_assert( + temp_flat_param.grad_fn is not None, + "The `grad_fn` is needed to access the `AccumulateGrad` and " + "register the post-backward hook", + ) + acc_grad = temp_flat_param.grad_fn.next_functions[0][0] # type: ignore[union-attr] + assert acc_grad is not None + hook_handle = acc_grad.register_hook( + functools.partial(_post_backward_hook, state, handle) + ) + flat_param._post_backward_hook_state = (acc_grad, hook_handle) # type: ignore[attr-defined] + + +def _register_post_backward_reshard_only_hook( + state: _FSDPState, + handle: Optional[FlatParamHandle], + args: tuple[Any, ...], + kwargs: dict[str, Any], +) -> None: + """ + Registers post-backward hooks to reshard flat parameters that do not + require gradient. We register these using multi-post-grad hooks on the + input activations to ensure that all gradients that may depend on the + parameters have been computed before resharding. + """ + # If there is no gradient computation, then there is no need for + # post-backward logic + if not torch.is_grad_enabled(): + return + # Construct `inp_tensors` lazily to avoid CPU overhead in typical case + # where each flat parameter requires gradient + inp_tensors: Optional[list[torch.Tensor]] = None + if not handle: + return + flat_param = handle.flat_param + + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + already_registered = hasattr(flat_param, "_post_backward_hook_handle") + else: + already_registered = hasattr(flat_param, "_post_backward_hook_state") + + if already_registered or flat_param.requires_grad: + return + if inp_tensors is None: + args_flat = pytree.arg_tree_leaves(*args, **kwargs) + inp_tensors = [ + obj for obj in args_flat if torch.is_tensor(obj) and obj.requires_grad + ] + assert inp_tensors is not None # mypy + hook_handle = register_multi_grad_hook( + inp_tensors, functools.partial(_post_backward_reshard_only_hook, state, handle) + ) + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + flat_param._post_backward_hook_handle = hook_handle # type: ignore[attr-defined, assignment] + else: + flat_param._post_backward_hook_state = (hook_handle,) # type: ignore[attr-defined, assignment] + + +@no_type_check +def _register_post_backward_final_callback( + state: _FSDPState, module: nn.Module +) -> None: + """ + Registers the post-backward final callback that runs at the end of the + backward pass. This should be called from the root FSDP instance at the + beginning of the pre-backward. + """ + _p_assert( + state._is_root, + "Only the root FSDP instance should register the post-backward callback", + ) + if state._post_backward_callback_queued: + return + _assert_in_training_states(state, [TrainingState.IDLE]) + # Trace does not need this callback + if not torch.distributed._functional_collectives.is_torchdynamo_compiling(): + state._post_backward_callback_queued = True + Variable._execution_engine.queue_callback( + functools.partial(_post_backward_final_callback, state, module) + ) + + +def _wait_for_computation_stream( + computation_stream: torch.Stream, + unshard_stream: torch.Stream, + pre_unshard_stream: torch.Stream, +): + """ + Has the unshard and pre-unshard streams wait for the computation stream. + For example, this should be called in the FSDP root's pre-forward to + respect optimizer step computation. + """ + # Tracing does not need to wait + if torch.distributed._functional_collectives.is_torchdynamo_compiling(): + return + unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] + # Having the pre-all-gather stream wait for the current stream even if we + # do not leverage the pre-all-gather stream is tolerable since this only + # runs once per iteration + pre_unshard_stream.wait_stream(computation_stream) # type: ignore[attr-defined] + + +def _reset_flat_param_grad_info_if_needed( + handles: list[FlatParamHandle], +): + """ + Clears the original parameters' gradients if needed. This method's CPU + overhead is minimal, so we may call it throughout FSDP methods, which serve + as callsites to free the gradient memory earlier. + """ + if not isinstance(handles, list): + handles = [handles] + for handle in handles: + if handle._use_orig_params: + handle._reset_flat_param_grad_info_if_needed() + + +@no_type_check +def _get_buffers_and_dtypes_for_computation( + state: _FSDPState, + root_module: nn.Module, +) -> tuple[list[torch.Tensor], list[Optional[torch.dtype]]]: + """ + Returns all buffers in the module tree rooted at ``root_module`` and a + corresponding list of the buffer dtypes for computation. Each buffer dtype + is either ``None`` if buffer mixed precision is not enabled or the buffer + low precision dtype otherwise. + """ + _p_assert(state._is_root, "Expects the root to cast buffers") + buffers: list[torch.Tensor] = [] + buffer_dtypes: list[Optional[torch.dtype]] = [] + visited_buffers: set[torch.Tensor] = set() + # Traverse the FSDP states bottom-up so that we prefer the owning FSDP + # instance's mixed precision setting for each buffer + fsdp_states, fsdp_modules = traversal_utils._get_fsdp_states_with_modules( + root_module + ) + for fsdp_state, fsdp_module in zip(reversed(fsdp_states), reversed(fsdp_modules)): + for buffer_name, buffer in fsdp_module.named_buffers(): + if buffer in visited_buffers: + continue + visited_buffers.add(buffer) + if clean_tensor_name(buffer_name) in fsdp_state._ignored_buffer_names: + continue + buffers.append(buffer) + buffer_dtypes.append(fsdp_state.mixed_precision.buffer_dtype) + assert len(buffers) == len(buffer_dtypes), f"{len(buffers)} {len(buffer_dtypes)}" + return buffers, buffer_dtypes + + +@no_type_check +def _get_orig_buffer_dtypes( + state: _FSDPState, + buffer_names: list[str], +) -> list[torch.dtype]: + """ + Returns the original buffer types of the given buffer names. + """ + buffer_dtypes: list[torch.dtype] = [] + for buffer_name in buffer_names: + _p_assert( + buffer_name in state._buffer_name_to_orig_dtype, + f"{buffer_name} is missing from pre-computed dict on rank " + f"{state.rank}, which only has keys " + f"{state._buffer_name_to_orig_dtype.keys()}", + ) + buffer_dtypes.append(state._buffer_name_to_orig_dtype[buffer_name]) + return buffer_dtypes + + +def _cast_buffers_to_dtype_and_device( + buffers: list[torch.Tensor], + buffer_dtypes: list[Optional[torch.dtype]], + device: torch.device, +) -> None: + """ + Casts ``buffers`` to the dtypes given by ``buffer_dtypes`` and moves them + to ``device``. If an element in ``buffer_dtypes`` is ``None``, then the + corresponding buffer is only moved to ``device``. + """ + _p_assert( + buffer_dtypes is None or len(buffers) == len(buffer_dtypes), + f"Expects `buffers` and `buffer_dtypes` to have the same length if " + f"`buffer_dtypes` is specified but got {len(buffers)} and " + f"{len(buffer_dtypes)}", + ) + for buffer, buffer_dtype in zip(buffers, buffer_dtypes): + if not torch.is_floating_point(buffer) or buffer_dtype is None: + buffer.data = buffer.to(device=device) + else: + buffer.data = buffer.to(device=device, dtype=buffer_dtype) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_shard_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_shard_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..18933cf83a2f31ddd3f2e0f1e8e169240a4d54b6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_shard_utils.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +import copy +import itertools +import math +from typing import Optional + +import torch +import torch.distributed as dist +from torch._utils import _get_device_module +from torch.distributed import distributed_c10d +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) +from torch.distributed._shard.sharding_spec import ShardMetadata +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard + + +def _get_remote_device_str(rank, device_type, num_devices_per_node): + if device_type.lower() == "cpu": + return f"rank:{rank}/{device_type}" + elif device_type.lower() == "hpu": + return f"rank:{rank}/{device_type}:{_get_device_module(device_type).current_device()}" + else: + return f"rank:{rank}/{device_type}:{rank % num_devices_per_node}" + + +def _create_chunk_sharded_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, +) -> ShardedTensor: + """ + Shard a tensor to chunks along the first dimension. The local rank will gets its + corresponding chunk as the local shard to create a ShardedTensor. + """ + chunks = tensor.chunk(world_size, dim=0) + if len(chunks) > rank: + local_shard = chunks[rank].clone() + offsets = [0 for _ in tensor.size()] + offsets[0] = math.ceil(tensor.size()[0] / world_size) * rank + local_shards = [Shard.from_tensor_and_offsets(local_shard, offsets, rank)] + else: + local_shards = [] + + # Create a ShardedTensor without invoking communication. + chunk_sizes = [list(chunk.size()) for chunk in chunks] + dim0_offsets = [0] + list( + itertools.accumulate([chunk_size[0] for chunk_size in chunk_sizes]) + )[:-1] + offsets = [0] * (len(chunk_sizes[0]) - 1) + chunk_offsets = [[d0] + offsets for d0 in dim0_offsets] + device_type = ( + distributed_c10d._get_pg_default_device(pg).type + if device is None + else device.type + ) + placements = [ + _get_remote_device_str( + dist.get_global_rank(pg, r), + device_type, + num_devices_per_node, + ) + for r in range(len(chunk_sizes)) + ] + assert len(chunk_sizes) == len(chunk_offsets) == len(placements) + shard_metadata = [ + ShardMetadata(offset, size, placement) + for offset, size, placement in zip(chunk_offsets, chunk_sizes, placements) + ] + sharded_tensor_metadata = ShardedTensorMetadata( + shards_metadata=shard_metadata, + size=tensor.size(), + tensor_properties=TensorProperties( + dtype=tensor.dtype, + layout=tensor.layout, + requires_grad=False, + memory_format=torch.contiguous_format, + pin_memory=tensor.is_pinned(), + ), + ) + return ShardedTensor._init_from_local_shards_and_global_metadata( + local_shards, sharded_tensor_metadata=sharded_tensor_metadata, process_group=pg + ) + + +def _create_chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, +) -> DTensor: + """ + Shard a tensor to chunks along the first dimension. The local rank will gets its + corresponding chunk as the local tensor to create a DTensor. + """ + # We need to explicitly call .detach() to return a new tensor detached from the current graph. + tensor = tensor.detach().clone() + + # FSDP placements: [Shard(0)] + # HSDP placements: [Replicate(), Shard(0)] + replicate_placements = [Replicate() for _ in range(device_mesh.ndim)] + shard_placements = [Replicate() for _ in range(device_mesh.ndim)] + shard_placements[-1] = DShard(0) # type: ignore[call-overload] + + return DTensor.from_local( + tensor, device_mesh, replicate_placements, run_check=False + ).redistribute( + placements=shard_placements, + ) + + +def _all_gather_dtensor( + tensor: DTensor, + root_mesh: Optional[DeviceMesh], +) -> torch.Tensor: + """ + All gather a DTensor in its sharded dimension and return the local tensor. + """ + assert root_mesh == tensor.device_mesh, ( + "The device mesh of a tensor should be a root mesh." + ) + + placements = list(copy.deepcopy(tensor.placements)) + # FSDP placements: [Shard(0)] -> [Replicate()] + # HSDP placements: [Replicate(), Shard(0)] -> [Replicate(), Replicate()] + placements[-1] = Replicate() + tensor = tensor.redistribute( + device_mesh=tensor.device_mesh, + placements=placements, + ) + + return tensor.to_local() diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_state_dict_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_state_dict_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..70e317e896caaaa1606f2e476a369b14a1ca9a94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_state_dict_utils.py @@ -0,0 +1,919 @@ +# mypy: allow-untyped-defs +import contextlib +import logging +import math +import warnings +from collections.abc import Generator, Iterator +from typing import Any, Callable, cast, no_type_check + +import torch +import torch.distributed as dist +import torch.distributed.algorithms._checkpoint.checkpoint_wrapper as checkpoint_wrapper +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed._shard.sharded_tensor import ( + init_from_local_shards, + Shard, + ShardedTensor, +) +from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_module_fsdp_state_if_fully_sharded_module, + _has_fsdp_params, + _is_composable, + _module_handle, + clean_tensor_name, + FSDP_PREFIX, + FSDP_WRAPPED_MODULE, +) +from torch.distributed.fsdp._debug_utils import SimpleProfiler +from torch.distributed.fsdp._runtime_utils import ( + _cast_buffers_to_dtype_and_device, + _get_orig_buffer_dtypes, + _lazy_init, + _reset_flat_param_grad_info_if_needed, +) +from torch.distributed.fsdp.api import ( + FullStateDictConfig, + ShardingStrategy, + StateDictType, +) +from torch.distributed.tensor import DTensor +from torch.distributed.utils import _replace_by_prefix + +from ._fsdp_extensions import ( + _ext_all_gather_dtensor, + _ext_chunk_dtensor, + _ext_chunk_tensor, + _ext_post_unflatten_transform, + _ext_pre_load_state_dict_transform, +) +from ._unshard_param_utils import _unshard_fsdp_state_params, FLAT_PARAM + + +logger = logging.getLogger(__name__) + + +def _should_unshard_params(fsdp_state: _FSDPState) -> bool: + return not ( + fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD + and (_is_composable(fsdp_state) or fsdp_state._use_orig_params) + ) + + +def _convert_to_wrapped_module_name(module_name: str) -> str: + module_name = module_name.replace(f"{FSDP_PREFIX}", "") + module_name = module_name.replace(f"{FSDP_WRAPPED_MODULE}", "") + if module_name: + module_name = f"{module_name}." + # `CheckpointWrapper` adds a prefix that has to be removed as well. + module_name = module_name.replace(checkpoint_wrapper._CHECKPOINT_PREFIX, "") + return module_name + + +def _param_name_infos( + module: nn.Module, fsdp_state: _FSDPState +) -> Iterator[tuple[str, str, str]]: + if not _has_fsdp_params(fsdp_state, module): + return + for param_name, module_name in _module_handle( + fsdp_state, module + ).param_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +def _shared_param_name_infos( + module: nn.Module, fsdp_state +) -> Iterator[tuple[str, str, str]]: + for param_name, module_name in _module_handle( + fsdp_state, module + ).shared_param_module_names(): + module_name = _convert_to_wrapped_module_name(module_name) + fqn = f"{module_name}{param_name}" + yield fqn, param_name, module_name + + +@no_type_check +def _enter_unshard_params_ctx( + module: nn.Module, + fsdp_state: _FSDPState, + writeback: bool = False, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, +) -> None: + """ + state_dict hooks cannot use the pure context call as the checkpoint flow + requires to enter the context in the pre-hook but leave the context in the + post-hook. This API enters the context of ``_unshard_fsdp_state_params``. + """ + assert module not in fsdp_state._unshard_params_ctx, ( + "Entering the ``_unshard_fsdp_state_params`` context but _unshard_params_ctx[module] " + "is not None." + ) + fsdp_state._unshard_params_ctx[module] = _unshard_fsdp_state_params( + module, + fsdp_state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + fsdp_state._unshard_params_ctx[module].__enter__() + + +@no_type_check +def _exit_unshard_params_ctx(module: nn.Module, fsdp_state: _FSDPState) -> None: + """A helper function to exit ``_unshard_fsdp_state_params`` context.""" + fsdp_state._unshard_params_ctx[module].__exit__(None, None, None) + fsdp_state._unshard_params_ctx.pop(module) + + +def _common_pre_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, +) -> None: + """Performs the pre-state_dict tasks shared by all state_dict types.""" + if fsdp_state._device_handle.is_available(): + fsdp_state._device_handle.synchronize() + # TODO: need to check if this is always correct for composable FSDP. + _lazy_init(fsdp_state, module) + if fsdp_state._is_root: + _reset_flat_param_grad_info_if_needed(fsdp_state._all_handles) + + +def _common_unshard_pre_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + offload_to_cpu: bool, + rank0_only: bool, +) -> None: + """ + Performs the pre-state_dict tasks shared by all state_dict types that require + ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this hook. + """ + # For composable `fully_shard`, it does not need to unshard parameters for `NO_SHARD` cases. + if not _should_unshard_params(fsdp_state): + return + _enter_unshard_params_ctx( + module, + fsdp_state, + writeback=False, + offload_to_cpu=offload_to_cpu, + rank0_only=rank0_only, + ) + + +@no_type_check +def _common_unshard_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, + param_hook: Callable, +) -> dict[str, Any]: + """ + The post-state_dict flow that shared by all state_dict types that require + ``_unshard_fsdp_state_params()``. FULL_STATE_DICT and SHARDED_STATE_DICT use this + hook. + """ + _replace_by_prefix(state_dict, prefix + f"{FSDP_PREFIX}", prefix) + # Return early for trivial cases + if not state_dict or not _has_fsdp_params(fsdp_state, module): + if _should_unshard_params(fsdp_state): + _exit_unshard_params_ctx(module, fsdp_state) + return state_dict + + # If a rank does not have unsharded parameters(when `rank0_only=True` + # and `rank != 0`), then the rank only needed to participate in the + # all-gather and does not need to save the # state dict. We simply check + # rank0_only to ensure this issue. + rank0_only = ( + fsdp_state._state_dict_type == StateDictType.FULL_STATE_DICT + and cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only + ) + # no_fsdp_return means the state_dict returned by this rank should contain + # only non-FSDP controlled parameters and buffers. + no_fsdp_return = rank0_only and fsdp_state.rank != 0 + if no_fsdp_return and not fsdp_state._use_orig_params: + for clean_key in fsdp_state._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_key.replace( + f"{checkpoint_wrapper._CHECKPOINT_PREFIX}.", "" + ) + state_dict.pop(f"{prefix}{clean_key}", None) + # Non-zero ranks have flat_param key when rank0_only=True, because rank0_only=True is + # passed in to unshard context, but nonzero ranks reshard early, causing this flat_param + # to appear in state_dict. + state_dict.pop(f"{prefix}{FLAT_PARAM}") + _exit_unshard_params_ctx(module, fsdp_state) + return state_dict + + # Loop only the parameters saved in this instance's wrapped module to + # avoid processing buffers. + for fqn, param_name, module_name in _param_name_infos(module, fsdp_state): + fqn = f"{prefix}{fqn}" + if no_fsdp_return: + state_dict.pop(fqn) + continue + assert fqn in state_dict, ( + f"FSDP assumes {fqn} is in the state_dict but the state_dict only " + f"has {state_dict.keys()}. " + f"prefix={prefix}, module_name={module_name}, " + f"param_name={param_name} rank={fsdp_state.rank}." + ) + + param_hook(state_dict, prefix, fqn) + + if _should_unshard_params(fsdp_state): + _exit_unshard_params_ctx(module, fsdp_state) + + cpu_device = torch.device("cpu") + buffer_clean_fqns = [] + buffers = [] + for clean_key in fsdp_state._buffer_names: + # This is a hack to support activation checkpoint. + clean_key = clean_tensor_name(clean_key) + fqn = f"{prefix}{clean_key}" + if fqn not in state_dict: + # A buffer can be registered as non-persistent. + continue + if no_fsdp_return: + state_dict.pop(fqn) + else: + buffer = state_dict[fqn] + if ( + fsdp_state._state_dict_config.offload_to_cpu + and buffer.device != cpu_device + ): + state_dict[fqn] = buffer.to(cpu_device) + # skip upcasting for ignored buffers + if clean_key not in fsdp_state._ignored_buffer_names: + buffer_clean_fqns.append(clean_key) + buffers.append(state_dict[fqn]) + + if buffers: + mixed_precision_enabled_for_buffers = ( + fsdp_state._mixed_precision_enabled_for_buffers() + if not _is_composable(fsdp_state) + else (fsdp_state.mixed_precision.buffer_dtype is not None) + ) + if mixed_precision_enabled_for_buffers: + buffer_dtypes = _get_orig_buffer_dtypes(fsdp_state, buffer_clean_fqns) + _cast_buffers_to_dtype_and_device( + buffers, buffer_dtypes, fsdp_state.compute_device + ) + for buffer, clean_fqn in zip(buffers, buffer_clean_fqns): + fqn = f"{prefix}{clean_fqn}" + logger.info("FSDP is casting the dtype of %s to %s", fqn, buffer.dtype) + state_dict[fqn] = buffer.clone() + return state_dict + + +@no_type_check +def _full_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. pre-state_dict hook is + not actually supported by ``nn.Module``. As a result, this API is called + from ``_full_post_state_dict_hook()`` to simulate the case. Once pre-state_dict + is supported in ``nn.Module``, this hook will be registered as a hook in + ``nn.Module``. + """ + if getattr(fsdp_state, "_device_mesh", False): + _mesh_resources.get_root_mesh(fsdp_state._device_mesh) + + _common_pre_state_dict_hook(module, fsdp_state) + _common_unshard_pre_state_dict_hook( + module, + fsdp_state, + offload_to_cpu=fsdp_state._state_dict_config.offload_to_cpu, + rank0_only=cast(FullStateDictConfig, fsdp_state._state_dict_config).rank0_only, + ) + + +@no_type_check +def _full_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + Hook that runs after model.state_dict() is called before returning result to + user. For FSDP, we may have to clone the tensors in state_dict as params go + back to sharded version after _unshard_fsdp_state_params ends, and also remove + the ``FSDP_WRAPPED_MODULE`` prefix. + """ + + def param_hook( + state_dict: dict[str, Any], + prefix: str, + fqn: str, + ) -> None: + clean_key = fqn + clean_prefix = clean_tensor_name(prefix) + # Strip prefix out of key if needed as buffer names and param names + # do not have prefix considered as they are not computed in `state_dict` + # call. + clean_key = clean_key.removeprefix(clean_prefix) + + # Clone parameters before exiting the `_unshard_fsdp_state_params()` context. + if not getattr(state_dict[fqn], "_has_been_cloned", False): + try: + state_dict[fqn] = state_dict[fqn].detach().clone() + state_dict[fqn]._has_been_cloned = True # type: ignore[attr-defined] + except BaseException as e: + warnings.warn( + f"Failed to clone() tensor with name {fqn} on rank {fsdp_state.rank}. " + "This may mean that this state_dict entry could point to invalid " + "memory regions after returning from state_dict() call if this " + "parameter is managed by FSDP. Please check clone " + f"implementation of {fqn}. Error: {str(e)}" + ) + + return _common_unshard_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) + + +def _full_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + _lazy_init(fsdp_state, module) + if _should_unshard_params(fsdp_state): + with SimpleProfiler.profile("_enter_unshard_params_ctx"): + _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + # Add FSDP_PREFIX only for wrapper-based FSDP. + if not _is_composable(fsdp_state): + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") + + +def _full_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if _should_unshard_params(fsdp_state): + with SimpleProfiler.profile("_exit_unshard_params_ctx"): + _exit_unshard_params_ctx(module, fsdp_state) + + +def _local_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. Right now, pre-state_dict + hook is not supported by the PyTorch core. So this API is called from + `_local_post_state_dict_hook()` to simulate the case. + """ + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handle(fsdp_state, module).uses_sharded_strategy + ): + raise RuntimeError( + "``local_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, fsdp_state) + + +@no_type_check +def _local_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + This hook create a ShardedTensor from the local flat_param and replace + the state_dict[f"{prefix}{FLAT_PARAM}] with the ShardedTensor. No copy + will happen. The underlying storage is the same. + """ + + _replace_by_prefix(state_dict, f"{prefix}{FSDP_PREFIX}", prefix) + if not _has_fsdp_params(fsdp_state, module): + return state_dict + + # state_dict[f"{prefix}{FLAT_PARAM}"] exists and has the same tensor + # value as the flat_param but it is a pure Tensor because + # nn.Module.state_dict() will detach the parameter. Therefore, we need + # to get flat_param to get the metadata. + assert _module_handle(fsdp_state, module), "Should have returned early" + flat_param = _module_handle(fsdp_state, module).flat_param + # Constructs a ShardedTensor from the flat_param "without" padding. + # Removing the padding allows users to change the number of ranks + # when loading the local_state_dict. + full_numel = flat_param._unpadded_unsharded_size.numel() # type: ignore[attr-defined] + shard_offset = flat_param.numel() * fsdp_state.rank + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + if valid_data_size > 0: + # If FlatParameter is returned, FlatParameter._local_shard cause a + # pickling issue (can be torch.save but not torch.load). Since there + # is no benefit for state_dict to return the actual FlatParameter class, + # a view (which is a tensor) of the FlatParameter will be returned. + flat_param = flat_param[:valid_data_size].view(valid_data_size) + local_shards = [ + Shard.from_tensor_and_offsets(flat_param, [shard_offset], fsdp_state.rank) + ] + else: + local_shards = [] + sharded_tensor = init_from_local_shards( + local_shards, full_numel, process_group=fsdp_state.process_group + ) # type: ignore[assignment] + # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT. + if fsdp_state._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[f"{prefix}{FLAT_PARAM}"] = sharded_tensor + return state_dict + + +def _local_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + pass + + +def _local_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + """ + This hook finds the local flat_param for this FSDP module from the + state_dict. The flat_param should be a ShardedTensor. This hook converts + the ShardedTensor to a tensor. No copy happen unless padding is required. + """ + _lazy_init(fsdp_state, module) + _replace_by_prefix(state_dict, prefix, f"{prefix}{FSDP_PREFIX}") + fqn = f"{prefix}{FSDP_PREFIX}{FLAT_PARAM}" + if fqn not in state_dict: + assert not _has_fsdp_params(fsdp_state, module), ( + "No `FlatParameter` in `state_dict` for this FSDP instance " + "but it has parameters" + ) + return + load_tensor = state_dict[fqn] + assert isinstance(load_tensor, ShardedTensor), ( + "Tensors in local_state_dict should be ShardedTensor." + ) + + # Convert the ShardedTensor to a Tensor. + flat_param = _module_handle(fsdp_state, module).flat_param + assert flat_param is not None + valid_data_size = flat_param.numel() - flat_param._shard_numel_padded + shards = load_tensor.local_shards() + if valid_data_size > 0: + assert len(shards), "load_local_state_dict assume one shard per ShardedTensor." + load_tensor = shards[0].tensor + + # Get the metadata of the flat_param to decide whether to pad the loaded + # tensor. + if flat_param._shard_numel_padded > 0: + assert load_tensor.numel() < flat_param.numel(), ( + f"Local shard size = {flat_param.numel()} and the tensor in " + f"the state_dict is {load_tensor.numel()}." + ) + load_tensor = F.pad(load_tensor, [0, flat_param._shard_numel_padded]) + else: + load_tensor = flat_param + # TODO: Add DTensor state_dict support for LOCAL_STATE_DICT. + state_dict[fqn] = load_tensor + + +def _sharded_pre_state_dict_hook( + fsdp_state: _FSDPState, + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + Hook that runs before model.state_dict() is called. Check + ``_full_pre_load_state_dict_hook`` for the detail. + """ + if ( + _has_fsdp_params(fsdp_state, module) + and not _module_handle(fsdp_state, module).uses_sharded_strategy + ): + raise RuntimeError( + "``sharded_state_dict`` can only be used when parameters are flatten " + "and sharded." + ) + _common_pre_state_dict_hook(module, fsdp_state) + # Setting offload_to_cpu here does not work even if offload_to_cpu is True. + # We have to create ShardedTensor first then move it to CPU. + _common_unshard_pre_state_dict_hook( + module, + fsdp_state, + offload_to_cpu=False, + rank0_only=False, + ) + + +@no_type_check +def _sharded_post_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> dict[str, Any]: + """ + The hook replaces the unflattened, unsharded parameter in the state_dict + with a unflattened, sharded parameter (a ShardedTensor). + """ + + def param_hook(state_dict: dict[str, Any], prefix: str, fqn: str): + param = state_dict[fqn] + if not fsdp_state._state_dict_config._use_dtensor: + sharded_tensor = _ext_chunk_tensor( + tensor=param, + rank=fsdp_state.rank, + world_size=fsdp_state.world_size, + num_devices_per_node=fsdp_state._device_handle.device_count(), + pg=fsdp_state.process_group, + fsdp_extension=fsdp_state._fsdp_extension, + ) + else: + sharded_tensor = _ext_chunk_dtensor( + tensor=param, + rank=fsdp_state.rank, + device_mesh=fsdp_state._device_mesh, + fsdp_extension=fsdp_state._fsdp_extension, + ) + if fsdp_state._state_dict_config.offload_to_cpu: + sharded_tensor = sharded_tensor.cpu() + state_dict[fqn] = sharded_tensor + + return _common_unshard_post_state_dict_hook( + module, fsdp_state, state_dict, prefix, param_hook + ) + + +@no_type_check +def _sharded_post_load_state_dict_hook( + module: nn.Module, fsdp_state: _FSDPState, *args, **kwargs +) -> None: + if _has_fsdp_params(fsdp_state, module): + with SimpleProfiler.profile("_exit_unshard_params_ctx"): + _exit_unshard_params_ctx(module, fsdp_state) + + +@no_type_check +def _sharded_pre_load_state_dict_hook( + module: nn.Module, + fsdp_state: _FSDPState, + state_dict: dict[str, Any], + prefix: str, +) -> None: + """ + The hook combines the unflattened, sharded parameters (ShardedTensor) to + a new FlatParameter and shards the new FlatParameter to the local chunk. + """ + _lazy_init(fsdp_state, module) + if not _is_composable(fsdp_state): + _replace_by_prefix(state_dict, prefix, prefix + f"{FSDP_PREFIX}") + if not _has_fsdp_params(fsdp_state, module): + return + + handle = _module_handle(fsdp_state, module) + if not handle.uses_sharded_strategy: + raise RuntimeError( + "load_sharded_state_dict can only be called when parameters " + "are flattened and sharded." + ) + fqn_to_param_ext = dict( + zip(handle.flat_param._fqns, handle.flat_param._param_extensions) + ) + + for fqn, _, _ in _param_name_infos(module, fsdp_state): + if not _is_composable(fsdp_state): + fqn_from_global_root = f"{prefix}{FSDP_PREFIX}{fqn}" + else: + fqn_from_global_root = f"{prefix}{fqn}" + try: + param = state_dict.pop(fqn_from_global_root) + except KeyError: + logger.warning( + f"Did not find param with FQN {fqn_from_global_root}, skipping it. " # noqa: G004 + "The weight will not be filled if you expect it to be." + ) + continue # TODO: Improve unittesting for state_dict finetuning + # cases: https://github.com/pytorch/pytorch/issues/109134 + + if not fsdp_state._state_dict_config._use_dtensor: + # All-gather the param (ShardedTensor) + param, shards = _ext_pre_load_state_dict_transform( + param, fsdp_state._fsdp_extension + ) + + assert len(shards) < 2, ( + "Expects 0 or 1 shard per rank " + f"but got {len(shards)} shards on rank {fsdp_state.rank}." + ) + param_numel = param.size().numel() + dim_0_size = param.size()[0] + chunk_size = ( + math.ceil(dim_0_size / fsdp_state.world_size) + * param_numel + // dim_0_size + ) + if len(shards) == 1: + local_tensor = shards[0].tensor.flatten() + with SimpleProfiler.profile(SimpleProfiler.Type.H2D): + local_tensor = local_tensor.to(fsdp_state.compute_device) + num_padding = chunk_size - local_tensor.numel() + if num_padding > 0: + local_tensor = F.pad(local_tensor, [0, num_padding]) + else: + local_tensor = torch.zeros( + chunk_size, dtype=param.dtype, device=fsdp_state.compute_device + ) + tensor = torch.empty( + chunk_size * fsdp_state.world_size, + dtype=local_tensor.dtype, + device=fsdp_state.compute_device, + ) + with SimpleProfiler.profile(SimpleProfiler.Type.ALLGATHER): + dist.all_gather_into_tensor( + tensor, local_tensor, group=fsdp_state.process_group + ) + tensor = tensor.narrow(0, 0, param_numel).reshape(param.size()) + state_dict[fqn_from_global_root] = tensor + else: + if param.device != fsdp_state._device_mesh.device_type: + param = param.to(fsdp_state._device_mesh.device_type) + + root_mesh = _mesh_resources.get_root_mesh(fsdp_state._device_mesh) + local_tensor = _ext_all_gather_dtensor( + param, root_mesh, fsdp_state._fsdp_extension + ) + + if fqn_to_param_ext.get(fqn) is not None: + ext = fqn_to_param_ext[fqn] + local_tensor = _ext_post_unflatten_transform( + local_tensor, ext, fsdp_state._fsdp_extension + ) + state_dict[fqn_from_global_root] = local_tensor + + with SimpleProfiler.profile("_enter_unshard_params_ctx"): + _enter_unshard_params_ctx(module, fsdp_state, writeback=True) + + +@contextlib.contextmanager +def _replace_with_full_state_dict_type(fsdp_state: _FSDPState) -> Generator: + old_state_dict_config = fsdp_state._state_dict_config + old_state_dict_type = fsdp_state._state_dict_type + fsdp_state._state_dict_config = FullStateDictConfig() + fsdp_state._state_dict_type = StateDictType.FULL_STATE_DICT + yield + fsdp_state._state_dict_config = old_state_dict_config + fsdp_state._state_dict_type = old_state_dict_type + + +@no_type_check +@torch.no_grad() +def _post_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, +) -> dict[str, Any]: + """ + _post_state_dict_hook() is called after the state_dict() of this + FSDP module is executed. ``fsdp_state._state_dict_type`` is used to decide + what postprocessing will be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " + "be returned." + ) + else: + context = contextlib.nullcontext() + + with context: + _post_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_post_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_state_dict_hook, + } + processed_state_dict = _post_state_dict_hook_fn[fsdp_state._state_dict_type]( + module, fsdp_state, state_dict, prefix + ) + + if fsdp_state._is_root: + logger.info("FSDP finished processing state_dict(), prefix=%s", prefix) + for key, tensor in sorted(processed_state_dict.items()): + if key.startswith(prefix) and isinstance(tensor, torch.Tensor): + local_shape = tensor.shape + device = None + if isinstance(tensor, ShardedTensor): + local_shape = None + shards = tensor.local_shards() + if shards: + local_shape = shards[0].tensor.shape + device = shards[0].tensor.device + elif isinstance(tensor, DTensor): + local_shape = tensor.to_local().shape + device = tensor.device + else: + device = tensor.device + logger.info( + "FQN=%s: type=%s, shape=%s, local_shape=%s, dtype=%s, device=%s", + key, + type(tensor), + tensor.shape, + local_shape, + tensor.dtype, + device, + ) + + return processed_state_dict + + +@no_type_check +@torch.no_grad() +def _pre_state_dict_hook( + module: nn.Module, + *args, + **kwargs, +) -> None: + """ + This is called before the core state dict saving logic of ``module``. + ``fsdp_state._state_dict_type`` is used to decide what postprocessing will + be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will " + "be returned." + ) + else: + _set_use_dtensor(fsdp_state) + context = contextlib.nullcontext() + + with context: + _pre_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_pre_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_state_dict_hook, + } + _pre_state_dict_hook_fn[fsdp_state._state_dict_type]( + fsdp_state, + module, + *args, + **kwargs, + ) + + +@no_type_check +def _set_use_dtensor(fsdp_state: _FSDPState) -> None: + # If device_mesh is passed in when initializing FSDP, we automatically turn the + # _use_dtensor flag to be true for ShardedStateDictConfig(). + if getattr(fsdp_state, "_device_mesh", None): + state_dict_type = fsdp_state._state_dict_type + if state_dict_type == StateDictType.LOCAL_STATE_DICT: + raise RuntimeError( + "Found state_dict_type LOCAL_STATE_DICT", + "DeviceMesh is not compatible with LOCAL_STATE_DICT.", + "Please set state_dict_type to SHARDED_STATE_DICT to get DTensor state_dict.", + ) + else: + fsdp_state._state_dict_config._use_dtensor = True + + +@no_type_check +@torch.no_grad() +def _pre_load_state_dict_hook( + module: nn.Module, + state_dict: dict[str, Any], + prefix: str, + *args: Any, +) -> None: + """ + This is called before ``module._load_from_state_dict()``. + ``fsdp_state._state_dict_type`` is used to decide what preprocessing will + be done. + """ + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" + "be returned." + ) + else: + _set_use_dtensor(fsdp_state) + context = contextlib.nullcontext() + + _lazy_init(fsdp_state, module) + if fsdp_state._is_root: + SimpleProfiler.reset() + + with context: + _pre_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_pre_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_pre_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_pre_load_state_dict_hook, + } + # Code that is common for all state_dict impls + if fsdp_state._device_handle.is_available(): + fsdp_state._device_handle.synchronize() + # Dispatch into state_dict specific implementation of pre-hook. + _pre_load_state_dict_hook_fn[fsdp_state._state_dict_type]( + module, fsdp_state, state_dict, prefix + ) + + +@no_type_check +@torch.no_grad() +def _post_load_state_dict_hook( + module: nn.Module, + incompatible_keys: tuple[list[str], list[str]], + *args: Any, +) -> None: + fsdp_state = _get_module_fsdp_state_if_fully_sharded_module(module) + if fsdp_state.sharding_strategy == ShardingStrategy.NO_SHARD: + context = _replace_with_full_state_dict_type(fsdp_state) + warnings.warn( + "When using ``NO_SHARD`` for ``ShardingStrategy``, full_state_dict will" + "be returned." + ) + else: + context = contextlib.nullcontext() + + with context: + _post_load_state_dict_hook_fn = { + StateDictType.FULL_STATE_DICT: _full_post_load_state_dict_hook, + StateDictType.LOCAL_STATE_DICT: _local_post_load_state_dict_hook, + StateDictType.SHARDED_STATE_DICT: _sharded_post_load_state_dict_hook, + } + # Code that is common for all state_dict impls + # Dispatch into state_dict type specific implementation of post-hook for + # loading state_dict. + _post_load_state_dict_hook_fn[fsdp_state._state_dict_type](module, fsdp_state) + + # When reporting incompatible keys, trim FSDP prefixes. + missing_keys = incompatible_keys[0] + unexpected_keys = incompatible_keys[1] + for i in range(len(missing_keys)): + missing_keys[i] = clean_tensor_name(missing_keys[i]) + + for i in range(len(unexpected_keys)): + unexpected_keys[i] = clean_tensor_name(unexpected_keys[i]) + + if fsdp_state._is_root: + SimpleProfiler.dump_and_reset("FSDP model load_state_dict profiling: ") + + +def _register_all_state_dict_hooks(state: _FSDPState): + """ + Registers pre-save, post-save, pre-load, and post-load state dict hooks. + """ + for hook_registration_fn_str, hook, hook_registration_fn_kwargs in ( + ("register_state_dict_pre_hook", _pre_state_dict_hook, {}), + ("_register_state_dict_hook", _post_state_dict_hook, {}), + ( + "_register_load_state_dict_pre_hook", + _pre_load_state_dict_hook, + {"with_module": True}, + ), + ("register_load_state_dict_post_hook", _post_load_state_dict_hook, {}), + ): + _register_state_dict_hooks_base( + state, hook_registration_fn_str, hook, hook_registration_fn_kwargs + ) + + +@no_type_check +def _register_state_dict_hooks_base( + state: _FSDPState, + hook_registration_fn_name: str, + hook: Callable, + hook_registration_fn_kwargs: dict[str, Any], +) -> None: + """Registers ``hook`` using ``hook_registration_fn``.""" + if not _is_composable(state): + getattr(state, hook_registration_fn_name)(hook, **hook_registration_fn_kwargs) + else: + handle = state._handle + if handle: + getattr(handle._fully_sharded_module, hook_registration_fn_name)( + hook, **hook_registration_fn_kwargs + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_trace_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_trace_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..662fab3864e45f1f3a97c922f7d4f077d44cc4ac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_trace_utils.py @@ -0,0 +1,238 @@ +# mypy: allow-untyped-defs +import functools +from contextlib import contextmanager +from dataclasses import dataclass, field +from typing import Any, Callable, NamedTuple, Optional + +import torch +import torch.nn as nn + + +@dataclass +class TracingConfig: + """ + This represents a symbolic tracing configuration. + + Args: + tracer (torch.fx.Tracer): An instance of :class:`torch.fx.Tracer` to + use for symbolic tracing. The default value is the native + :class:`torch.fx.Tracer` constructed with default arguments. + However, the user may want to pass a different value such as the + ``HFTracer`` for models in the HuggingFace Transformers_ library. + .. _Transformers: https://huggingface.co/docs/transformers/index + concrete_args (Optional[Dict[str, Any]]): Concrete arguments that + should not be treated as ``torch.fx.Proxy`` when tracing the + module ``forward()``. Passing ``concrete_args`` allows partially + specializing the forward, e.g. to remove control flow or data + structures. This ``concrete_args`` here is the same argument used + in :meth:`~torch.fx.Tracer.trace`. + """ + + tracer: torch.fx.Tracer = field(default_factory=torch.fx.Tracer) + concrete_args: Optional[dict[str, Any]] = None + + +class _ParamUsageInfo(NamedTuple): + """ + This is used for ``_ExecutionInfo.module_to_param_usage_infos`` to record + execution information. The ``dict`` maps modules to a list of these + ``_ParamUsageInfo`` instances, where each instance represents a group of + parameters used together. + + Specifically, for each module key in the ``dict``, each instance of this + class represents either: + (1) the module and some sublist of its ``named_parameters()`` used + together in execution (see ``_patched_create_proxy()``), or + (2) a submodule and all of ``submodule.named_parameters()`` (see + ``_patched_call_module()``). + + Type (1) corresponds to directly using parameters in ops without calling + ``forward()``, and type (2) corresponds to calling ``forward()``. The + mapped-to lists in the ``dict`` follow the execution order. + """ + + module: nn.Module + named_params: list[tuple[str, nn.Parameter]] + + +class _ExecutionInfo: + """ + This represents the execution order information from the forward pass. + + Attributes: + curr_module (nn.Module): Current module being traced. + module_forward_order (List[nn.Module]): The modules in (pre-)forward + order, i.e. the order in which their ``forward()`` methods are + called. Each call to a module's ``forward()`` corresponds to one + element in the list. + module_to_param_usage_infos (Dict[nn.Module, List[_ParamUsageInfo]]): + Maps a module to a list of module execution infos. See + :class:`_ParamUsageInfo` for details. + param_forward_order (List[nn.Parameter]): The parameters in forward + execution order, where only a parameter's first participation is + included. + visited_params (Set[nn.Parameter]): The parameters visited so far + during the trace. This is only used during tracing for fast + membership check. Invariant: The parameters in + ``param_forward_order`` are exactly those in ``visited_params``. + """ + + def __init__(self, root_module: nn.Module) -> None: + self.curr_module: nn.Module = root_module + self.module_forward_order: list[nn.Module] = [root_module] + self.module_to_param_usage_infos: dict[nn.Module, list[_ParamUsageInfo]] = { + root_module: [] + } + self.param_forward_order: list[nn.Parameter] = [] + self.visited_params: set[nn.Parameter] = set() + + +class _ExecOrderTracer: + def __init__(self) -> None: + self.exec_info: Optional[_ExecutionInfo] = None + + @contextmanager + def patch_tracer(self, tracer: torch.fx.Tracer, root_module: nn.Module): + self.exec_info = _ExecutionInfo(root_module) + orig_call_module = tracer.call_module + orig_create_proxy = tracer.create_proxy + tracer.call_module = functools.partial( # type: ignore[method-assign] + self._patched_call_module, orig_call_module, self.exec_info + ) + fqn_to_param = dict(root_module.named_parameters()) + tracer.create_proxy = functools.partial( # type: ignore[method-assign] + self._patched_create_proxy, + orig_create_proxy, + self.exec_info, + fqn_to_param, + ) + try: + yield + finally: + tracer.call_module = orig_call_module # type: ignore[method-assign] + tracer.create_proxy = orig_create_proxy # type: ignore[method-assign] + + def _patched_call_module( + self, + call_module: Callable, + exec_info: _ExecutionInfo, + # Below are the expected arguments to `call_module()` + module: nn.Module, + forward: Callable, + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + """ + Overrides ``call_module`` to save execution information to + ``exec_info``. Note that ``call_module`` is called during symbolic + tracing for each non-root module. + + Args: + call_module (Callable): Original ``call_module`` to override. + exec_info (_ExecutionInfo): Used to record execution information. + module (nn.Module): Module corresponding to this ``call_module``. + forward (Callable): ``forward()`` method of ``module`` to be called + for this ``call_module``. + args (Tuple[Any, ...]): Positional arguments for ``forward``. + kwargs (Dict[str, Any]): Keyword arguments for ``forward``. + + Returns: + Same return value as ``call_module``. + """ + exec_info.module_forward_order.append(module) + named_params = list(module.named_parameters()) + curr_module = exec_info.curr_module + if named_params: + assert curr_module in exec_info.module_to_param_usage_infos, ( + "The current module should have already been processed by a patched `call_module`" + ) + exec_info.module_to_param_usage_infos[exec_info.curr_module].append( + _ParamUsageInfo(module, named_params) + ) + prev_curr_module = curr_module + exec_info.curr_module = module + exec_info.module_to_param_usage_infos[module] = [] + output = call_module(module, forward, args, kwargs) + exec_info.curr_module = prev_curr_module + return output + + def _patched_create_proxy( + self, + create_proxy: Callable, + exec_info: _ExecutionInfo, + fqn_to_param: dict[str, nn.Parameter], + # Below are the expected arguments to `create_proxy()` + kind: str, + target: torch.fx.node.Target, + args: tuple[Any, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + proxy_factory_fn: Optional[Callable[[torch.fx.Node], torch.fx.Proxy]] = None, + ) -> torch.fx.Proxy: + """ + Overrides ``create_proxy`` to save execution information to + ``exec_info``. Note that ``create_proxy`` is called during symbolic + tracing for each leaf function/method/module. + + Args: + create_proxy (Callable): Original ``create_proxy`` to override. + exec_info (_ExecutionInfo): Used to record execution information. + fqn_to_param (Dict[str, nn.Parameter]): ``dict`` version of the + root module's ``named_parameters()`` with FQN as key and + parameter as value. + kind (str): Kind of the target method ('call_function', + 'call_method', 'get_attr', 'call_module', 'placeholder', or + 'output'). See :class:`torch.fx.Graph` for details. This is + passed to ``create_proxy``. + target (torch.fx.node.Target): Contains the string name of the + function/method/module. This is passed to ``create_proxy``. + args (Tuple[Any, ...]): Positional arguments for the function/ + method/module. This is passed to ``create_proxy``. + kwargs (Dict[str, Any]): Keyword arguments for the function/method/ + module. This is passed to ``create_proxy`` + name (Optional[str]): An optional string name for the ``Node`` + created in ``create_proxy``. This is passed to + ``create_proxy``. + type_expr (Optional[Any]): An optional type annotation representing + the Python type that the output of the node has. This is passed + to ``create_proxy``. + proxy_factory_fn (Callable[[torch.fx.Node], torch.fx.Proxy]): + An alternative proxy constructor used in ``create_proxy``. This + is passed to ``create_proxy``. + + Returns: + torch.fx.Proxy: Created ``Node`` wrapped in a ``Proxy`` object. + """ + proxy = create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + curr_module = exec_info.curr_module + if kind in ("call_function", "call_method"): + if args is not None: + named_params: list[tuple[str, nn.Parameter]] = [] + for arg in args: + if ( + isinstance(arg, torch.fx.Proxy) + and arg.node.target in fqn_to_param + ): + param = fqn_to_param[arg.node.target] # type: ignore[index] + named_params.append((arg.node.target, param)) # type: ignore[arg-type] + if param not in exec_info.visited_params: + exec_info.visited_params.add(param) + exec_info.param_forward_order.append(param) + if named_params: + exec_info.module_to_param_usage_infos[curr_module].append( + _ParamUsageInfo(curr_module, named_params) + ) + elif kind == "call_module": + named_params = list(curr_module.named_parameters()) + if named_params: + exec_info.module_to_param_usage_infos[curr_module].append( + _ParamUsageInfo(curr_module, named_params) + ) + for _, param in named_params: + if param not in exec_info.visited_params: + exec_info.visited_params.add(param) + exec_info.param_forward_order.append(param) + return proxy diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_traversal_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_traversal_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9fcaa35e3c4662a274a819e06a4754dbe77eda --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_traversal_utils.py @@ -0,0 +1,112 @@ +""" +NOTE: This file must be imported like +``import torch.distributed.fsdp._traversal_utils`` and not like +``from torch.distributed.fsdp._traversal_utils import ...`` to avoid circular +imports. For brevity, we may import the file as ``traversal_utils``. +""" + +import collections + +import torch.nn as nn +from torch.distributed._composable.contract import _get_registry +from torch.distributed.fsdp._common_utils import _FSDPState, _get_module_fsdp_state + + +""" +[Note: FSDP State Traversal] +For the wrapper code path, ``_FSDPState`` is the ``FullyShardedDataParallel`` +module wrapping a fully sharded module, and for the non-wrapper code path, +``_FSDPState`` is an object that gets embedded on a fully sharded module. +See [Note: Fully Sharded Module] for the definition. + +There are three common traversal idioms: Given a root module, +- ``_get_fsdp_states()`` returns all ``_FSDPState`` s in the tree. +- ``get_fsdp_root_states()`` returns all local root ``_FSDPState`` s in the +tree (i.e. those with ``_is_root == True``). +- ``_get_fsdp_handles()``returns all ``FlatParamHandle`` s in the tree. + +All of these methods must take in the root module (i.e. an ``nn.Module``) and +not a general ``_FSDPState`` because ``_FSDPState`` does not support a graph +traversal, whereas ``nn.Module`` has ``nn.Module.modules()`` for traversal. +""" + + +def _composable(module: nn.Module) -> bool: + """ + Returns if ``module`` can compose with ``fully_shard``. + """ + # TODO: Add any other composable APIs that are mutually exclusive. + registry = _get_registry(module) + if registry is None: + return True + return "replicate" not in registry + + +# TODO (awgu): We may be able to remove this function if we retired the +# `use_orig_params=False` code path since so far we only need the module for +# `FlatParameter` registration, which is not needed for `use_orig_params=True`. +def _get_fsdp_states_with_modules( + module: nn.Module, +) -> tuple[list[_FSDPState], list[nn.Module]]: + """ + Returns a tuple containing: + 1. A list of the ``_FSDPState`` instances in the module tree rooted at + ``module`` without any duplicates and following the ``module.modules()`` + traversal order (which is assumed to be depth-first). + 2. A corresponding list of the modules owning the states in the first list. + + For the wrapper code path, both returned lists are the same, each + containing all ``FullyShardedDataParallel`` instances. For the composable + code path, this returns a list of all composable state instances and a list + of the corresponding fully sharded modules. See [Note: Fully Sharded + Module]. + + NOTE: The traversal does not proceed into any module annotated by an + incompatible API (e.g. ``replicate``). + """ + fsdp_states: list[_FSDPState] = [] + fsdp_modules: list[nn.Module] = [] + # Track the visited FSDP states since multiple modules may share the same + # one and we want to return a de-duplicated list + visited_fsdp_states: set[_FSDPState] = set() + # Track the visited modules in case of shared modules, which implies the + # module graph is no longer a tree + visited_modules: set[nn.Module] = set() + + # Perform depth-first search from `module` to ensure that we do not + # traverse into an incompatible API's subtree (use DFS instead of BFS to + # match `.modules()` order) + deque: collections.deque[nn.Module] = collections.deque([module]) + while deque: + submodule = deque.popleft() + visited_modules.add(submodule) + if not _composable(submodule): + continue + for child_module in reversed(list(submodule.children())): + if child_module not in visited_modules: + deque.appendleft(child_module) + optional_state = _get_module_fsdp_state(submodule) + if optional_state is not None and optional_state not in visited_fsdp_states: + visited_fsdp_states.add(optional_state) + fsdp_states.append(optional_state) + fsdp_modules.append(submodule) + return fsdp_states, fsdp_modules + + +def _get_fsdp_states(module: nn.Module) -> list[_FSDPState]: + """See :func:`_get_fsdp_states_with_modules`.""" + fsdp_states, _ = _get_fsdp_states_with_modules(module) + return fsdp_states + + +def _get_fsdp_handles(module: nn.Module) -> list: + """ + Returns all ``FlatParamHandle`` s in the module tree rooted at ``module`` + following the rules in :func:`_get_fsdp_state`. + """ + handles = [ + fsdp_state._handle + for fsdp_state in _get_fsdp_states(module) + if fsdp_state._handle is not None + ] + return handles diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_unshard_param_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_unshard_param_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b265cab272f9e26a31ef3b188881f52c3d21c18a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_unshard_param_utils.py @@ -0,0 +1,337 @@ +# mypy: allow-untyped-defs +import contextlib +import warnings +from collections.abc import Generator +from typing import cast + +import torch +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_module_fsdp_state, + _has_fsdp_params, + _module_handle, + HandleTrainingState, + TrainingState, +) +from torch.distributed.fsdp._runtime_utils import ( + _lazy_init, + _reset_flat_param_grad_info_if_needed, + _reshard, + _reshard_grads, + _unshard, + _unshard_grads, +) +from torch.distributed.utils import _p_assert + +from ._flat_param import FlatParamHandle + + +FLAT_PARAM = "_flat_param" + + +@torch.no_grad() +def _writeback_to_local_shard( + handle: FlatParamHandle, + writeback_grad: bool, +): + """ + For the handle, writes back the this rank's shard of the unsharded + flattened parameter to the sharded flattened parameter. If + ``writeback_grad=True``, then writes back to the sharded gradient as + well. + + Precondition: The handle's ``FlatParameter`` 's data points to the + padded unsharded flattened parameter. + """ + + def _get_shard(flat_param_or_grad: torch.Tensor) -> torch.Tensor: + if handle.uses_sharded_strategy: + # For sharded strategies, get the *unpadded* shard instead of + # the *padded* shard to persist user changes to the padding + # (though FSDP does not explicitly support this) + shard, _ = FlatParamHandle._get_unpadded_shard( + flat_param_or_grad, + handle.rank, + handle.world_size, + ) + return shard + # For `NO_SHARD`, the `flat_param` or its gradient may be modified, + # so we write it back directly + return flat_param_or_grad + + param_shard = _get_shard(handle.flat_param) + handle.flat_param._local_shard[: param_shard.numel()].copy_(param_shard) # type: ignore[attr-defined] + if writeback_grad: + existing_grad = handle.sharded_grad + if existing_grad is not None: + assert handle.flat_param.grad is not None + grad_shard = _get_shard(handle.flat_param.grad) + existing_grad[: grad_shard.numel()].copy_(grad_shard) + + +def _deregister_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + De-registers the flattened parameter from the wrapped module, hiding it + from ``nn.Module`` methods. + + We do not use ``del`` because we want ``FLAT_PARAM`` to always be an + attribute but dynamically change whether it is visible to ``nn.Module`` + methods. + """ + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters.pop(FLAT_PARAM, None) + + +def _register_flat_param(state: _FSDPState, module: nn.Module) -> None: + """ + Registers the flattened parameter to the wrapped module, making it + visible to ``nn.Module`` methods. + + We do not use :meth:`nn.Module.register_parameter` because we want + ``FLAT_PARAM`` to always be an attribute but dynamically change whether + it is visible to ``nn.Module`` methods. + """ + handle = _module_handle(state, module) + if _has_fsdp_params(state, module): + # TODO: figure out the case for the composable APIs. + cast(nn.Module, module.module)._parameters[FLAT_PARAM] = handle.flat_param + + +@contextlib.contextmanager +def _unflatten_as_params(state: _FSDPState, module: nn.Module) -> Generator: + """ + Assumes that the flattened parameter is unsharded. When in the context, + de-registers the flattened parameter and unflattens the original + parameters as ``nn.Parameter`` views into the flattened parameter. + After the context, re-registers the flattened parameter and restores + the original parameters as ``Tensor`` views into the flattened + parameter. + """ + handle = _module_handle(state, module) + if not handle: + yield + else: + _deregister_flat_param(state, module) + try: + with handle.unflatten_as_params(): + yield + finally: + if not handle._use_orig_params: + _register_flat_param(state, module) + + +def _validate_unshard_params_args( + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +) -> None: + if with_grads and (offload_to_cpu or not state._use_orig_params): + raise NotImplementedError( + f"with_grads={with_grads}, " + f"use_orig_params={state._use_orig_params}, " + f"offload_to_cpu={offload_to_cpu} " + f"is not supported yet" + ) + if offload_to_cpu and state._handle and (not state._handle.uses_sharded_strategy): + raise NotImplementedError( + "offload_to_cpu=True and NO_SHARD is not supported yet" + ) + if writeback and rank0_only: + # TODO: Rank 0 can broadcast the `FlatParameter` to allow all ranks to + # persist the changes. + raise NotImplementedError( + "writeback=True and rank0_only=True is not supported yet" + ) + if offload_to_cpu and not rank0_only: + warnings.warn( + "offload_to_cpu=True and rank0_only=False may result in the" + "unsharded parameters being redundantly copied to CPU memory for " + "GPUs sharing the same CPU memory, which risks CPU OOM. We " + "recommend using offload_to_cpu=True with rank0_only=True." + ) + + +@contextlib.contextmanager +def _unshard_fsdp_state_params( + module: nn.Module, + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + """ + This unshards the parameters for a single FSDP state ``state`` that + corresponds to ``module``. + """ + _validate_unshard_params_args( + state, writeback, rank0_only, offload_to_cpu, with_grads + ) + state._device_handle.synchronize() + # If handles are shared by other module(s), the handle may be already unsharded. + maybe_handle = _module_handle(state, module) + handle = None + if ( + maybe_handle + and maybe_handle._training_state != HandleTrainingState.SUMMON_FULL_PARAMS + ): + handle = maybe_handle + if not handle: + yield + return + + assert handle._training_state == HandleTrainingState.IDLE, ( + f"Expects the handle training to be IDLE but got {handle._training_state}" + ) + + handle._training_state = HandleTrainingState.SUMMON_FULL_PARAMS + + _reset_flat_param_grad_info_if_needed(handle) + free_unsharded_flat_param = handle.needs_unshard() + # No need to call `wait_stream()` since we unshard in the computation + # stream directly + computation_stream = state._device_handle.current_stream() + _unshard(state, handle, computation_stream, computation_stream) + if with_grads: + _unshard_grads(handle) + + if rank0_only and state.rank != 0: + # Free the unsharded flattened parameter early + _reshard(state, handle, free_unsharded_flat_param) + if with_grads: + _reshard_grads(handle) + try: + yield + finally: + handle._training_state = HandleTrainingState.IDLE + else: + # Unflatten the unsharded flattened parameters + with contextlib.ExitStack() as stack: + # Invariant: rank == 0 or !rank0_only + if offload_to_cpu and handle.uses_sharded_strategy: + stack.enter_context(handle.to_cpu()) + # NOTE: Since PyTorch enforces that a parameter and its + # gradients need to match metadata (e.g. device), we must + # move gradients to CPU *after* we move parameters. + # NOTE: This assumes 1 `FlatParameter` + if not state._use_orig_params: + stack.enter_context(_unflatten_as_params(state, module)) + try: + yield + finally: + stack.close() + if writeback: + _writeback_to_local_shard(handle, with_grads) + _reshard(state, handle, free_unsharded_flat_param) + if with_grads: + _reshard_grads(handle) + handle._training_state = HandleTrainingState.IDLE + + +@contextlib.contextmanager +def _unshard_params_for_summon( + module: nn.Module, + state: _FSDPState, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + _validate_unshard_params_args( + state, writeback, rank0_only, offload_to_cpu, with_grads + ) + _lazy_init(state, module) + if state.training_state == TrainingState.FORWARD_BACKWARD: + raise AssertionError( + "Cannot manually unshard parameters during forward/backward" + ) + elif state.training_state == TrainingState.SUMMON_FULL_PARAMS: + raise AssertionError( + "Cannot manually unshard parameters when already unsharding parameters" + ) + with _unshard_fsdp_state_params( + module=module, + state=state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ): + try: + state.training_state = TrainingState.SUMMON_FULL_PARAMS + yield + finally: + state.training_state = TrainingState.IDLE + + +@contextlib.contextmanager +def _unshard_params( + module: nn.Module, + recurse: bool, + writeback: bool, + rank0_only: bool, + offload_to_cpu: bool, + with_grads: bool, +): + """ + This unshards FSDP-managed parameters for all modules with FSDP applied in + the module tree rooted at ``module``. + """ + if not recurse: + optional_state = _get_module_fsdp_state(module) + if optional_state is None: + with contextlib.nullcontext(): + yield + return + states_and_modules = ([optional_state], [module]) + else: + states_and_modules = traversal_utils._get_fsdp_states_with_modules(module) + with contextlib.ExitStack() as stack: + for state, module in zip(*states_and_modules): + stack.enter_context( + _unshard_params_for_summon( + module=module, + state=state, + writeback=writeback, + rank0_only=rank0_only, + offload_to_cpu=offload_to_cpu, + with_grads=with_grads, + ) + ) + yield + + +def _deregister_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the original parameters; registers the ``FlatParameter``. + """ + handle = _module_handle(state, module) + if not handle: + return + _p_assert( + handle._use_orig_params, + f"Inconsistent `_use_orig_params` -- FSDP: {state._use_orig_params} " + f"handle: {handle._use_orig_params}", + ) + handle._deregister_orig_params() + _register_flat_param(state, module) + + +def _register_orig_params(state: _FSDPState, module: nn.Module) -> None: + """ + Deregisters the ``FlatParameter``; registers the original parameters. + """ + handle = _module_handle(state, module) + if not handle: + return + _deregister_flat_param(state, module) + if handle.is_sharded(handle.flat_param): + handle._use_sharded_views() + handle._use_sharded_grad_views() + else: + handle._use_unsharded_views(as_params=True) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/_wrap_utils.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/_wrap_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..851e856c04bcbf96b48e4d2a26fbef46d7e3e2c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/_wrap_utils.py @@ -0,0 +1,262 @@ +# mypy: allow-untyped-defs +import collections +import functools +import inspect +import warnings +from functools import partial +from typing import Any, Callable, Union + +import torch.nn as nn +from torch.distributed.fsdp._common_utils import ( + _get_module_fsdp_state, + _override_module_mixed_precision, +) +from torch.distributed.fsdp.wrap import ( + _construct_wrap_fn, + _or_policy, + _Policy, + _post_order_apply, + _recursive_wrap, + _run_mixed_precision_override_policy, + _wrap_module_cls_individually, +) + + +def _auto_wrap( + root_module: nn.Module, + policy: Union[Callable, _Policy], + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + root_kwargs: dict[str, Any], + fsdp_fn: Callable, # e.g. `FullyShardedDataParallel` or `fully_shard` +): + """ + Auto wraps modules in ``root_module`` 's tree according to ``policy`` + following a post-order traversal. + + Precondition: ``root_kwargs`` should contain all arguments except + ``module``. This function accepts the kwargs dict directly since it gets + forwarded into the post-order traversal function. + """ + mixed_precision = root_kwargs["mixed_precision"] + is_wrapper = inspect.isclass(fsdp_fn) + # TODO: We may relax this no-nested-wrapping constraint to support manual + # wrapping followed by auto wrapping. + _check_nested_wrapping(root_module) + + if isinstance(policy, _Policy): + root_kwargs["auto_wrap_policy" if is_wrapper else "policy"] = None + target_module_to_kwargs = policy._run_policy( + root_module, ignored_modules, root_kwargs + ) + if mixed_precision is not None: + target_module_to_kwargs = _run_mixed_precision_override_policy( + root_module, + mixed_precision._module_classes_to_ignore, + ignored_modules, + root_kwargs, + target_module_to_kwargs, + ) + overridden_module_classes = _override_module_mixed_precision( + root_module, mixed_precision._module_classes_to_ignore + ) + _warn_on_overridden_mixed_precision(overridden_module_classes) + use_orig_params = root_kwargs.get("use_orig_params", False) + _validate_frozen_params( + root_module, + set(target_module_to_kwargs.keys()), + ignored_params, + use_orig_params, + ) + wrap_fn = _construct_wrap_fn(root_module, target_module_to_kwargs, fsdp_fn) + _post_order_apply(root_module, wrap_fn) + return + + recursive_wrap_kwargs = { + "module": root_module, + "auto_wrap_policy": policy, + "wrapper_cls": fsdp_fn, + "ignored_modules": ignored_modules, + "ignored_params": ignored_params, + "only_wrap_children": True, + } + if mixed_precision is not None: + # Wrap modules of the ignored types separately and register forward + # hooks to cast to fp32 and back to the original dtype, respectively + overridden_module_classes = _override_module_mixed_precision( + root_module, mixed_precision._module_classes_to_ignore + ) + policy = functools.partial( + _or_policy, + policies=[ + policy, + partial( + _wrap_module_cls_individually, + module_classes=mixed_precision._module_classes_to_ignore, + ), + ], + ) + recursive_wrap_kwargs["auto_wrap_policy"] = policy + _warn_on_overridden_mixed_precision(overridden_module_classes) + _recursive_wrap(**recursive_wrap_kwargs, **root_kwargs) # type: ignore[arg-type] + + +def _check_nested_wrapping(root_module: nn.Module): + for module_name, module in root_module.named_modules(): + if _get_module_fsdp_state(module) is not None: + raise ValueError( + "FSDP auto wrapping requires modules to not already have " + f"FSDP applied but found {module_name} in\n{root_module}" + ) + + +def _warn_on_overridden_mixed_precision( + overridden_module_classes: set[type[nn.Module]], +): + if len(overridden_module_classes) == 0: + return + warnings.warn( + "Both mixed precision and an auto_wrap_policy were specified to FSDP, " + f"where the wrapped module has submodules of type:\n{overridden_module_classes}\n" + "These modules will be wrapped as separate FSDP instacnes with mixed " + "precision disabled." + ) + + +def _validate_frozen_params( + root_module: nn.Module, + modules_to_wrap: set[nn.Module], + ignored_params: set[nn.Parameter], + use_orig_params: bool, +): + """ + This checks that, given ``modules_to_wrap``, each module would manage + parameters that are uniformly frozen or non-frozen. This uniformity + requirement is strict for ``use_orig_params=False`` (hard error) and highly + recommended for ``use_orig_params=True`` (user warning). + """ + post_order_named_modules = _get_post_order_named_modules(root_module) + visited_modules: set[nn.Module] = set() + for module_name, module in post_order_named_modules: + if module in modules_to_wrap: + param_to_fqn = _get_managed_param_to_fqn( + module, ignored_params, visited_modules, module_name + ) + frozen_param_fqns: list[str] = [] + frozen_param_numel = 0 + nonfrozen_param_fqns: list[str] = [] + nonfrozen_param_numel = 0 + for param, fqn in param_to_fqn.items(): + if param.requires_grad: + nonfrozen_param_fqns.append(fqn) + nonfrozen_param_numel += param.numel() + else: + frozen_param_fqns.append(fqn) + frozen_param_numel += param.numel() + if len(frozen_param_fqns) > 0 and len(nonfrozen_param_fqns) > 0: + msg = f"{module_name} has both parameters with requires_grad=True and False." + if use_orig_params: + total_param_numel = frozen_param_numel + nonfrozen_param_numel + msg += ( + " We do not recommend wrapping such modules since " + "the gradient memory usage will be higher than expected " + f"({total_param_numel} numel instead of {nonfrozen_param_numel} numel " + "before sharding via reduce-scatter). " + ) + else: + msg += " FSDP does not support wrapping such modules when use_orig_params=False. " + msg += "If possible, wrap the frozen parameters with FSDP separately.\n" + msg += ( + f"The following parameters have requires_grad=True:\n{nonfrozen_param_fqns}\n" + f"The following parameters have requires_grad=False:\n{frozen_param_fqns}" + ) + if use_orig_params: + warnings.warn(msg) + else: + raise ValueError(msg) + + +def _get_post_order_named_modules( + root_module: nn.Module, +) -> list[tuple[str, nn.Module]]: + """ + This returns the named modules following a post-order traversal, which is a + valid reverse topological sort. We achieve this using the reverse of a + stack-based DFS order instead of reversing ``root_module.named_modules()`` + since the former gives the modules in registration order at each level in + the module tree (as opposed to the reverse), which allows us to error/warn + on the first registered module that violates the condition. + + For example, consider the following module structure: + M( + S1(), + S2( + SS1(), + SS2(), + ), + S3(), + ) + The reverse DFS order is [S1, SS1, SS2, S2, S3, M], while the reverse + ``named_modules()`` order is [S3, SS2, SS1, S2, S1, M]. + """ + visited_modules = {root_module} + stack = [("", root_module)] + # Append and reverse at the end for linear-time algorithm + reverse_post_order_named_modules: list[tuple[str, nn.Module]] = [] + while stack: + module_name, module = stack.pop() + reverse_post_order_named_modules.append((module_name, module)) + for child_module_name, child_module in module.named_children(): + if child_module is None: # only for overrides of `named_children()` + continue + if child_module not in visited_modules: + visited_modules.add(child_module) + if module_name != "": + child_module_name = module_name + "." + child_module_name + stack.append((child_module_name, child_module)) + post_order_named_modules = list(reversed(reverse_post_order_named_modules)) + return post_order_named_modules + + +def _get_managed_param_to_fqn( + module_to_wrap: nn.Module, + ignored_params: set[nn.Parameter], + visited_modules: set[nn.Module], + root_prefix: str, +) -> dict[nn.Parameter, str]: + """ + This returns a dict that maps managed parameter to its FQN for the given + ``module_to_wrap``. The dict's keys are exactly the parameters that would + be managed by the module, where this is achieved by calling this function + on the modules to wrap in reverse topological order, destructively updating + ``visited_modules``, and not traversing into those modules. The FQNs are + prefixed from the root (via ``root_prefix``) to be more informative. + + NOTE: This function is meant to be called pre-wrapping and iteratively in + reverse topological order to cover the full module tree. This differs from + the ``_get_param_to_fqn()`` function meant to be called post-wrapping and + on the full module tree in one shot. Given those differences, we do not try + to unify the two. + """ + param_to_fqn: dict[nn.Parameter, str] = {} + # Run BFS (or any tree traversal works) + queue = collections.deque([(module_to_wrap, root_prefix)]) + visited_modules.add(module_to_wrap) + while queue: + module, prefix = queue.popleft() + for param_name, param in module.named_parameters(recurse=False): + if param not in ignored_params: + fqn = param_name if prefix == "" else prefix + "." + param_name + param_to_fqn[param] = fqn + for child_module_name, child_module in module.named_children(): + if child_module is None: # only for overrides of `named_children()` + continue + if child_module not in visited_modules: + visited_modules.add(child_module) + child_prefix = ( + child_module_name + if prefix == "" + else prefix + "." + child_module_name + ) + queue.append((child_module, child_prefix)) + return param_to_fqn diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/api.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/api.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9d9b7e51ce21dc16d53e90a5c9effa2728bcdb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/api.py @@ -0,0 +1,417 @@ +""" +This file includes public APIs for FSDP such as the classes used for the +constructor arguments. +""" + +from collections.abc import Sequence +from dataclasses import dataclass +from enum import auto, Enum +from typing import Optional + +import torch +from torch.nn.modules.batchnorm import _BatchNorm + + +__all__ = [ + "ShardingStrategy", + "BackwardPrefetch", + "MixedPrecision", + "CPUOffload", + "StateDictType", + "StateDictConfig", + "FullStateDictConfig", + "LocalStateDictConfig", + "ShardedStateDictConfig", + "OptimStateDictConfig", + "FullOptimStateDictConfig", + "LocalOptimStateDictConfig", + "ShardedOptimStateDictConfig", + "StateDictSettings", +] + + +class ShardingStrategy(Enum): + """ + This specifies the sharding strategy to be used for distributed training by + :class:`FullyShardedDataParallel`. + + - ``FULL_SHARD``: Parameters, gradients, and optimizer states are sharded. + For the parameters, this strategy unshards (via all-gather) before the + forward, reshards after the forward, unshards before the backward + computation, and reshards after the backward computation. For gradients, + it synchronizes and shards them (via reduce-scatter) after the backward + computation. The sharded optimizer states are updated locally per rank. + - ``SHARD_GRAD_OP``: Gradients and optimizer states are sharded during + computation, and additionally, parameters are sharded outside + computation. For the parameters, this strategy unshards before the + forward, does not reshard them after the forward, and only reshards them + after the backward computation. The sharded optimizer states are updated + locally per rank. Inside ``no_sync()``, the parameters are not resharded + after the backward computation. + - ``NO_SHARD``: Parameters, gradients, and optimizer states are not sharded + but instead replicated across ranks similar to PyTorch's + :class:`DistributedDataParallel` API. For gradients, this strategy + synchronizes them (via all-reduce) after the backward computation. The + unsharded optimizer states are updated locally per rank. + - ``HYBRID_SHARD``: Apply ``FULL_SHARD`` within a node, and replicate parameters across + nodes. This results in reduced communication volume as expensive all-gathers and + reduce-scatters are only done within a node, which can be more performant for medium + -sized models. + - ``_HYBRID_SHARD_ZERO2``: Apply ``SHARD_GRAD_OP`` within a node, and replicate parameters across + nodes. This is like ``HYBRID_SHARD``, except this may provide even higher throughput + since the unsharded parameters are not freed after the forward pass, saving the + all-gathers in the pre-backward. + """ + + FULL_SHARD = auto() + SHARD_GRAD_OP = auto() + NO_SHARD = auto() + HYBRID_SHARD = auto() + _HYBRID_SHARD_ZERO2 = auto() + + +class BackwardPrefetch(Enum): + """ + This configures explicit backward prefetching, which improves throughput by + enabling communication and computation overlap in the backward pass at the + cost of slightly increased memory usage. + + - ``BACKWARD_PRE``: This enables the most overlap but increases memory + usage the most. This prefetches the next set of parameters *before* the + current set of parameters' gradient computation. This overlaps the *next + all-gather* and the *current gradient computation*, and at the peak, it + holds the current set of parameters, next set of parameters, and current + set of gradients in memory. + - ``BACKWARD_POST``: This enables less overlap but requires less memory + usage. This prefetches the next set of parameters *after* the current + set of parameters' gradient computation. This overlaps the *current + reduce-scatter* and the *next gradient computation*, and it frees the + current set of parameters before allocating memory for the next set of + parameters, only holding the next set of parameters and current set of + gradients in memory at the peak. + - FSDP's ``backward_prefetch`` argument accepts ``None``, which disables + the backward prefetching altogether. This has no overlap and does not + increase memory usage. In general, we do not recommend this setting since + it may degrade throughput significantly. + + For more technical context: For a single process group using NCCL backend, + any collectives, even if issued from different streams, contend for the + same per-device NCCL stream, which implies that the relative order in which + the collectives are issued matters for overlapping. The two backward + prefetching values correspond to different issue orders. + """ + + # NOTE: For both modes, the ordering that defines "current" and "next" is + # not always exact in the current implementation. A mistargeted prefetch + # simply means that the parameter memory is allocated earlier than needed, + # possibly increasing peak memory usage, but does not affect correctness. + BACKWARD_PRE = auto() + BACKWARD_POST = auto() + + +@dataclass +class MixedPrecision: + """ + This configures FSDP-native mixed precision training. + + Attributes: + param_dtype (Optional[torch.dtype]): This specifies the dtype for model + parameters during forward and backward and thus the dtype for + forward and backward computation. Outside forward and backward, the + *sharded* parameters are kept in full precision (e.g. for the + optimizer step), and for model checkpointing, the parameters are + always saved in full precision. (Default: ``None``) + reduce_dtype (Optional[torch.dtype]): This specifies the dtype for + gradient reduction (i.e. reduce-scatter or all-reduce). If this is + ``None`` but ``param_dtype`` is not ``None``, then this takes on + the ``param_dtype`` value, still running gradient reduction in low + precision. This is permitted to differ from ``param_dtype``, e.g. + to force gradient reduction to run in full precision. (Default: + ``None``) + buffer_dtype (Optional[torch.dtype]): This specifies the dtype for + buffers. FSDP does not shard buffers. Rather, FSDP casts them to + ``buffer_dtype`` in the first forward pass and keeps them in that + dtype thereafter. For model checkpointing, the buffers are saved + in full precision except for ``LOCAL_STATE_DICT``. (Default: + ``None``) + keep_low_precision_grads (bool): If ``False``, then FSDP upcasts + gradients to full precision after the backward pass in preparation + for the optimizer step. If ``True``, then FSDP keeps the gradients + in the dtype used for gradient reduction, which can save memory if + using a custom optimizer that supports running in low precision. + (Default: ``False``) + cast_forward_inputs (bool): If ``True``, then this FSDP module casts + its forward args and kwargs to ``param_dtype``. This is to ensure + that parameter and input dtypes match for forward computation, as + required by many ops. This may need to be set to ``True`` when only + applying mixed precision to some but not all FSDP modules, in which + case a mixed-precision FSDP submodule needs to recast its inputs. + (Default: ``False``) + cast_root_forward_inputs (bool): If ``True``, then the root FSDP module + casts its forward args and kwargs to ``param_dtype``, overriding + the value of ``cast_forward_inputs``. For non-root FSDP modules, + this does not do anything. (Default: ``True``) + _module_classes_to_ignore: (Sequence[Type[nn.Module]]): This specifies + module classes to ignore for mixed precision when using an + ``auto_wrap_policy``: Modules of these classes will have FSDP + applied to them separately with mixed precision disabled (meaning + that the final FSDP construction would deviate from the specified + policy). If ``auto_wrap_policy`` is not specified, then this does + not do anything. This API is experimental and subject to change. + (Default: ``(_BatchNorm,)``) + + .. note:: This API is experimental and subject to change. + + .. note:: Only floating point tensors are cast to their specified dtypes. + + .. note:: In ``summon_full_params``, parameters are forced to full + precision, but buffers are not. + + .. note:: Layer norm and batch norm accumulate in ``float32`` even when + their inputs are in a low precision like ``float16`` or ``bfloat16``. + Disabling FSDP's mixed precision for those norm modules only means that + the affine parameters are kept in ``float32``. However, this incurs + separate all-gathers and reduce-scatters for those norm modules, which + may be inefficient, so if the workload permits, the user should prefer + to still apply mixed precision to those modules. + + .. note:: By default, if the user passes a model with any ``_BatchNorm`` + modules and specifies an ``auto_wrap_policy``, then the batch norm + modules will have FSDP applied to them separately with mixed precision + disabled. See the ``_module_classes_to_ignore`` argument. + + .. note:: ``MixedPrecision`` has ``cast_root_forward_inputs=True`` and + ``cast_forward_inputs=False`` by default. For the root FSDP instance, + its ``cast_root_forward_inputs`` takes precedence over its + ``cast_forward_inputs``. For non-root FSDP instances, their + ``cast_root_forward_inputs`` values are ignored. The default setting is + sufficient for the typical case where each FSDP instance has the same + ``MixedPrecision`` configuration and only needs to cast inputs to the + ``param_dtype`` at the beginning of the model's forward pass. + + .. note:: For nested FSDP instances with different ``MixedPrecision`` + configurations, we recommend setting individual ``cast_forward_inputs`` + values to configure casting inputs or not before each instance's + forward. In such a case, since the casts happen before each FSDP + instance's forward, a parent FSDP instance should have its non-FSDP + submodules run before its FSDP submodules to avoid the activation dtype + being changed due to a different ``MixedPrecision`` configuration. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)) + >>> model[1] = FSDP( + >>> model[1], + >>> mixed_precision=MixedPrecision(param_dtype=torch.float16, cast_forward_inputs=True), + >>> ) + >>> model = FSDP( + >>> model, + >>> mixed_precision=MixedPrecision(param_dtype=torch.bfloat16, cast_forward_inputs=True), + >>> ) + + The above shows a working example. On the other hand, if ``model[1]`` + were replaced with ``model[0]``, meaning that the submodule using + different ``MixedPrecision`` ran its forward first, then ``model[1]`` + would incorrectly see ``float16`` activations instead of ``bfloat16`` + ones. + + """ + + param_dtype: Optional[torch.dtype] = None + reduce_dtype: Optional[torch.dtype] = None + buffer_dtype: Optional[torch.dtype] = None + keep_low_precision_grads: bool = False + cast_forward_inputs: bool = False + cast_root_forward_inputs: bool = True + _module_classes_to_ignore: Sequence[type[torch.nn.Module]] = (_BatchNorm,) + + +@dataclass +class CPUOffload: + """ + This configures CPU offloading. + + Attributes: + offload_params (bool): This specifies whether to offload parameters to + CPU when not involved in computation. If ``True``, then this + offloads gradients to CPU as well, meaning that the optimizer step + runs on CPU. + """ + + offload_params: bool = False + + +class StateDictType(Enum): + """ + This enum indicates that which type of ``state_dict`` the FSDP module is + currently processing (returning or loading). + The default value is FULL_STATE_DICT to comply the PyTorch convention. + + .. note:: + FSDP currently supports three types of ``state_dict``: + 1. ``state_dict/load_state_dict`: this pair of APIs return and load + the non-sharded, unflattened parameters. The semantics is the + same as using DDP. + 2. ``_local_state_dict/_load_local_state_dict``: this pair of APIs return + and load local sharded, flattened parameters. The values returned + by ``_local_state_dict`` can be directly used by FSDP and is only + meaningful to FSDP (because parameters are flattened). Note that + these APIs are meant for use via the :func:`state_dict_type` + context manager as follows: + >>> # xdoctest: +SKIP("undefined variables") + >>> with fsdp.state_dict_type(StateDictType.LOCAL_STATE_DICT): + ... state = fsdp.state_dict() # loads local state dict + 3. ``_sharded_state_dict/_load_sharded_state_dict``: this pair of APIs + return and load sharded, unflattened parameters. The ``state_dict`` + return by ``sharded_state_dict`` can be used by all other parallel + schemes (resharding may be required). + """ + + FULL_STATE_DICT = auto() + LOCAL_STATE_DICT = auto() + SHARDED_STATE_DICT = auto() + + +@dataclass +class StateDictConfig: + """ + ``StateDictConfig`` is the base class for all ``state_dict`` configuration + classes. Users should instantiate a child class (e.g. + ``FullStateDictConfig``) in order to configure settings for the + corresponding ``state_dict`` type supported by FSDP. + + Attributes: + offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict + values to CPU, and if ``False``, then FSDP keeps them on GPU. + (Default: ``False``) + """ + + offload_to_cpu: bool = False + + +@dataclass +class FullStateDictConfig(StateDictConfig): + """ + ``FullStateDictConfig`` is a config class meant to be used with + ``StateDictType.FULL_STATE_DICT``. We recommend enabling both + ``offload_to_cpu=True`` and ``rank0_only=True`` when saving full state + dicts to save GPU memory and CPU memory, respectively. This config class + is meant to be used via the :func:`state_dict_type` context manager as + follows: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> fsdp = FSDP(model, auto_wrap_policy=...) + >>> cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True) + >>> with FSDP.state_dict_type(fsdp, StateDictType.FULL_STATE_DICT, cfg): + >>> state = fsdp.state_dict() + >>> # `state` will be empty on non rank 0 and contain CPU tensors on rank 0. + >>> # To reload checkpoint for inference, finetuning, transfer learning, etc: + >>> model = model_fn() # Initialize model in preparation for wrapping with FSDP + >>> if dist.get_rank() == 0: + >>> # Load checkpoint only on rank 0 to avoid memory redundancy + >>> state_dict = torch.load("my_checkpoint.pt") + >>> model.load_state_dict(state_dict) + >>> # All ranks initialize FSDP module as usual. `sync_module_states` argument + >>> # communicates loaded checkpoint states from rank 0 to rest of the world. + >>> fsdp = FSDP( + ... model, + ... device_id=torch.cuda.current_device(), + ... auto_wrap_policy=..., + ... sync_module_states=True, + ... ) + >>> # After this point, all ranks have FSDP model with loaded checkpoint. + + Attributes: + rank0_only (bool): If ``True``, then only rank 0 saves the full state + dict, and nonzero ranks save an empty dict. If ``False``, then all + ranks save the full state dict. (Default: ``False``) + """ + + rank0_only: bool = False + + +@dataclass +class LocalStateDictConfig(StateDictConfig): + pass + + +@dataclass +class ShardedStateDictConfig(StateDictConfig): + """ + ``ShardedStateDictConfig`` is a config class meant to be used with + ``StateDictType.SHARDED_STATE_DICT``. + + Attributes: + _use_dtensor (bool): If ``True``, then FSDP saves the state dict values + as ``DTensor``, and if ``False``, then FSDP saves them as + ``ShardedTensor``. (Default: ``False``) + + .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedStateDictConfig` + and it is used by FSDP to determine the type of state dict values. Users should not + manually modify ``_use_dtensor``. + """ + + _use_dtensor: bool = False + + +@dataclass +class OptimStateDictConfig: + """ + ``OptimStateDictConfig`` is the base class for all ``optim_state_dict`` + configuration classes. Users should instantiate a child class (e.g. + ``FullOptimStateDictConfig``) in order to configure settings for the + corresponding ``optim_state_dict`` type supported by FSDP. + + Attributes: + offload_to_cpu (bool): If ``True``, then FSDP offloads the state dict's + tensor values to CPU, and if ``False``, then FSDP keeps them on the + original device (which is GPU unless parameter CPU offloading is + enabled). (Default: ``True``) + """ + + offload_to_cpu: bool = True + + +@dataclass +class FullOptimStateDictConfig(OptimStateDictConfig): + """ + Attributes: + rank0_only (bool): If ``True``, then only rank 0 saves the full state + dict, and nonzero ranks save an empty dict. If ``False``, then all + ranks save the full state dict. (Default: ``False``) + """ + + rank0_only: bool = False + + +@dataclass +class LocalOptimStateDictConfig(OptimStateDictConfig): + offload_to_cpu: bool = False + + +@dataclass +class ShardedOptimStateDictConfig(OptimStateDictConfig): + """ + ``ShardedOptimStateDictConfig`` is a config class meant to be used with + ``StateDictType.SHARDED_STATE_DICT``. + + Attributes: + _use_dtensor (bool): If ``True``, then FSDP saves the state dict values + as ``DTensor``, and if ``False``, then FSDP saves them as + ``ShardedTensor``. (Default: ``False``) + + .. warning:: ``_use_dtensor`` is a private field of :class:`ShardedOptimStateDictConfig` + and it is used by FSDP to determine the type of state dict values. Users should not + manually modify ``_use_dtensor``. + """ + + _use_dtensor: bool = False + + +@dataclass +class StateDictSettings: + state_dict_type: StateDictType + state_dict_config: StateDictConfig + optim_state_dict_config: OptimStateDictConfig diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..8e22ad3e94f749f2e5fb8aa2a5fdeea3fd63d485 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py @@ -0,0 +1,2175 @@ +# mypy: ignore-errors + +import contextlib +import copy +import functools +import math +import traceback +import warnings +from collections.abc import Generator, Iterable, Iterator +from contextlib import contextmanager +from enum import auto, Enum +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +import torch.distributed.fsdp._traversal_utils as traversal_utils +import torch.nn as nn +from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( + _CHECKPOINT_WRAPPED_MODULE, + ActivationWrapper, +) +from torch.distributed.algorithms._comm_hooks import LOW_PRECISION_HOOKS +from torch.distributed.fsdp._common_utils import ( + _FSDPState, + _get_param_to_fqns, + FSDP_PREFIX, + FSDP_WRAPPED_MODULE, + HandleTrainingState, + TrainingState, +) +from torch.distributed.fsdp._dynamo_utils import _annotate_modules_for_dynamo +from torch.distributed.fsdp._init_utils import ( + _check_orig_params_flattened, + _init_buffer_state, + _init_core_state, + _init_device_handle, + _init_extension, + _init_ignored_module_states, + _init_param_handle_from_module, + _init_prefetching_state, + _init_process_group_state, + _init_runtime_state, + _init_state_dict_state, + HYBRID_SHARDING_STRATEGIES, + ProcessGroupType, +) +from torch.distributed.fsdp._runtime_utils import ( + _get_fsdp_root_states, + _is_fsdp_root, + _lazy_init, + _post_forward, + _post_forward_reshard, + _pre_forward, + _pre_forward_unshard, + _root_pre_forward, + _unshard, + _wait_for_computation_stream, +) +from torch.distributed.fsdp._wrap_utils import _auto_wrap +from torch.distributed.fsdp.api import ( + BackwardPrefetch, + CPUOffload, + FullOptimStateDictConfig, + FullStateDictConfig, + LocalOptimStateDictConfig, + LocalStateDictConfig, + MixedPrecision, + OptimStateDictConfig, + ShardedOptimStateDictConfig, + ShardedStateDictConfig, + ShardingStrategy, + StateDictConfig, + StateDictSettings, + StateDictType, +) +from torch.distributed.tensor import DeviceMesh +from torch.distributed.utils import _p_assert + +from ._flat_param import FlatParameter, FlatParamHandle +from ._optim_utils import ( + _flatten_optim_state_dict, + _get_param_id_to_param_from_optim_input, + _get_param_key_to_param, + _get_param_to_param_id_from_optim_input, + _get_param_to_param_key, + _optim_state_dict, + _rekey_sharded_optim_state_dict, + _set_optim_use_dtensor, +) +from ._state_dict_utils import _register_all_state_dict_hooks +from ._unshard_param_utils import ( + _deregister_orig_params, + _register_flat_param, + _register_orig_params, + _unshard_params, + _unshard_params_for_summon, +) +from .wrap import CustomPolicy, ModuleWrapPolicy + + +__all__ = [ + "FullyShardedDataParallel", + "OptimStateKeyType", +] + + +FLAT_PARAM = "_flat_param" + + +class OptimStateKeyType(Enum): + """Represents the type of key in an optimizer state-dict.""" + + PARAM_NAME = auto() + PARAM_ID = auto() + + +class FullyShardedDataParallel(nn.Module, _FSDPState): + """A wrapper for sharding module parameters across data parallel workers. + + This is inspired by `Xu et al. `_ as + well as the ZeRO Stage 3 from `DeepSpeed `_. + FullyShardedDataParallel is commonly shortened to FSDP. + + To understand FSDP internals, refer to the + :ref:`fsdp_notes`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> torch.cuda.set_device(device_id) + >>> sharded_module = FSDP(my_module) + >>> optim = torch.optim.Adam(sharded_module.parameters(), lr=0.0001) + >>> x = sharded_module(x, y=3, z=torch.Tensor([1])) + >>> loss = x.sum() + >>> loss.backward() + >>> optim.step() + + Using FSDP involves wrapping your module and then initializing your + optimizer after. This is required since FSDP changes the parameter + variables. + + When setting up FSDP, you need to consider the destination CUDA + device. If the device has an ID (``dev_id``), you have three options: + + * Place the module on that device + * Set the device using ``torch.cuda.set_device(dev_id)`` + * Pass ``dev_id`` into the ``device_id`` constructor argument. + + This ensures that the FSDP instance's compute device is the + destination device. For option 1 and 3, the FSDP initialization + always occurs on GPU. For option 2, the FSDP initialization + happens on module's current device, which may be a CPU. + + If you're using the ``sync_module_states=True`` flag, you need to + ensure that the module is on a GPU or use the ``device_id`` + argument to specify a CUDA device that FSDP will move the module + to in the FSDP constructor. This is necessary because + ``sync_module_states=True`` requires GPU communication. + + FSDP also takes care of moving input tensors to the forward method + to the GPU compute device, so you don't need to manually move them + from CPU. + + For ``use_orig_params=True``, + ``ShardingStrategy.SHARD_GRAD_OP`` exposes the unsharded + parameters, not the sharded parameters after forward, unlike + ``ShardingStrategy.FULL_SHARD``. If you want + to inspect the gradients, you can use the ``summon_full_params`` + method with ``with_grads=True``. + + With ``limit_all_gathers=True``, you may see a gap in the FSDP + pre-forward where the CPU thread is not issuing any kernels. This is + intentional and shows the rate limiter in effect. Synchronizing the CPU + thread in that way prevents over-allocating memory for subsequent + all-gathers, and it should not actually delay GPU kernel execution. + + FSDP replaces managed modules' parameters with ``torch.Tensor`` + views during forward and backward computation for autograd-related + reasons. If your module's forward relies on saved references to + the parameters instead of reacquiring the references each + iteration, then it will not see FSDP's newly created views, + and autograd will not work correctly. + + Finally, when using ``sharding_strategy=ShardingStrategy.HYBRID_SHARD`` + with the sharding process group being intra-node and the + replication process group being inter-node, setting + ``NCCL_CROSS_NIC=1`` can help improve the all-reduce times over + the replication process group for some cluster setups. + + **Limitations** + + There are several limitations to be aware of when using FSDP: + + * FSDP currently does not support gradient accumulation outside + ``no_sync()`` when using CPU offloading. This is because FSDP + uses the newly-reduced gradient instead of accumulating with any + existing gradient, which can lead to incorrect results. + + * FSDP does not support running the forward pass of a submodule + that is contained in an FSDP instance. This is because the + submodule's parameters will be sharded, but the submodule itself + is not an FSDP instance, so its forward pass will not all-gather + the full parameters appropriately. + + * FSDP does not work with double backwards due to the way it + registers backward hooks. + + * FSDP has some constraints when freezing parameters. + For ``use_orig_params=False``, each FSDP instance must manage + parameters that are all frozen or all non-frozen. For + ``use_orig_params=True``, FSDP supports mixing frozen and + non-frozen parameters, but it's recommended to avoid doing so to + prevent higher than expected gradient memory usage. + + * As of PyTorch 1.12, FSDP offers limited support for shared + parameters. If enhanced shared parameter support is needed for + your use case, please post in + `this issue `__. + + * You should avoid modifying the parameters between forward and + backward without using the ``summon_full_params`` context, as + the modifications may not persist. + + Args: + module (nn.Module): + This is the module to be wrapped with FSDP. + process_group (Optional[Union[ProcessGroup, Tuple[ProcessGroup, ProcessGroup]]]): + This is the process group over which the model is sharded and thus + the one used for FSDP's all-gather and reduce-scatter collective + communications. If ``None``, then FSDP uses the default process + group. For hybrid sharding strategies such as + ``ShardingStrategy.HYBRID_SHARD``, users can pass in a tuple of + process groups, representing the groups over which to shard and + replicate, respectively. If ``None``, then FSDP constructs process + groups for the user to shard intra-node and replicate inter-node. + (Default: ``None``) + sharding_strategy (Optional[ShardingStrategy]): + This configures the sharding strategy, which may trade off memory + saving and communication overhead. See :class:`ShardingStrategy` + for details. (Default: ``FULL_SHARD``) + cpu_offload (Optional[CPUOffload]): + This configures CPU offloading. If this is set to ``None``, then + no CPU offloading happens. See :class:`CPUOffload` for details. + (Default: ``None``) + auto_wrap_policy (Optional[Union[Callable[[nn.Module, bool, int], bool], ModuleWrapPolicy, CustomPolicy]]): + This specifies a policy to apply FSDP to submodules of ``module``, + which is needed for communication and computation overlap and thus + affects performance. If ``None``, then FSDP only applies to + ``module``, and users should manually apply FSDP to parent modules + themselves (proceeding bottom-up). For convenience, this accepts + ``ModuleWrapPolicy`` directly, which allows users to specify the + module classes to wrap (e.g. the transformer block). Otherwise, + this should be a callable that takes in three arguments + ``module: nn.Module``, ``recurse: bool``, and + ``nonwrapped_numel: int`` and should return a ``bool`` specifying + whether the passed-in ``module`` should have FSDP applied if + ``recurse=False`` or if the traversal should continue into the + module's subtree if ``recurse=True``. Users may add additional + arguments to the callable. The ``size_based_auto_wrap_policy`` in + ``torch.distributed.fsdp.wrap.py`` gives an example callable that + applies FSDP to a module if the parameters in its subtree exceed + 100M numel. We recommend printing the model after applying FSDP + and adjusting as needed. + + Example:: + + >>> def custom_auto_wrap_policy( + >>> module: nn.Module, + >>> recurse: bool, + >>> nonwrapped_numel: int, + >>> # Additional custom arguments + >>> min_num_params: int = int(1e8), + >>> ) -> bool: + >>> return nonwrapped_numel >= min_num_params + >>> # Configure a custom `min_num_params` + >>> my_auto_wrap_policy = functools.partial(custom_auto_wrap_policy, min_num_params=int(1e5)) + + backward_prefetch (Optional[BackwardPrefetch]): + This configures explicit backward prefetching of all-gathers. If + ``None``, then FSDP does not backward prefetch, and there is no + communication and computation overlap in the backward pass. See + :class:`BackwardPrefetch` for details. (Default: ``BACKWARD_PRE``) + mixed_precision (Optional[MixedPrecision]): + This configures native mixed precision for FSDP. If this is set to + ``None``, then no mixed precision is used. Otherwise, parameter, + buffer, and gradient reduction dtypes can be set. See + :class:`MixedPrecision` for details. (Default: ``None``) + ignored_modules (Optional[Iterable[torch.nn.Module]]): Modules whose + own parameters and child modules' parameters and buffers are + ignored by this instance. None of the modules directly in + ``ignored_modules`` should be :class:`FullyShardedDataParallel` + instances, and any child modules that are already-constructed + :class:`FullyShardedDataParallel` instances will not be ignored if + they are nested under this instance. This argument may be used to + avoid sharding specific parameters at module granularity when using an + ``auto_wrap_policy`` or if parameters' sharding is not managed by + FSDP. (Default: ``None``) + param_init_fn (Optional[Callable[[nn.Module], None]]): + A ``Callable[torch.nn.Module] -> None`` that + specifies how modules that are currently on the meta device should + be initialized onto an actual device. As of v1.12, FSDP detects + modules with parameters or buffers on meta device via ``is_meta`` + and either applies ``param_init_fn`` if specified or calls + ``nn.Module.reset_parameters()`` otherwise. For both cases, the + implementation should *only* initialize the parameters/buffers of + the module, not those of its submodules. This is to avoid + re-initialization. In addition, FSDP also supports deferred + initialization via torchdistX's (https://github.com/pytorch/torchdistX) + ``deferred_init()`` API, where the deferred modules are initialized + by calling ``param_init_fn`` if specified or torchdistX's default + ``materialize_module()`` otherwise. If ``param_init_fn`` is + specified, then it is applied to all meta-device modules, meaning + that it should probably case on the module type. FSDP calls the + initialization function before parameter flattening and sharding. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> module = MyModule(device="meta") + >>> def my_init_fn(module: nn.Module): + >>> # E.g. initialize depending on the module type + >>> ... + >>> fsdp_model = FSDP(module, param_init_fn=my_init_fn, auto_wrap_policy=size_based_auto_wrap_policy) + >>> print(next(fsdp_model.parameters()).device) # current CUDA device + >>> # With torchdistX + >>> module = deferred_init.deferred_init(MyModule, device="cuda") + >>> # Will initialize via deferred_init.materialize_module(). + >>> fsdp_model = FSDP(module, auto_wrap_policy=size_based_auto_wrap_policy) + + device_id (Optional[Union[int, torch.device]]): An ``int`` or + ``torch.device`` giving the CUDA device on which FSDP + initialization takes place, including the module initialization + if needed and the parameter sharding. This should be specified to + improve initialization speed if ``module`` is on CPU. If the + default CUDA device was set (e.g. via ``torch.cuda.set_device``), + then the user may pass ``torch.cuda.current_device`` to this. + (Default: ``None``) + sync_module_states (bool): If ``True``, then each FSDP module will + broadcast module parameters and buffers from rank 0 to ensure that + they are replicated across ranks (adding communication overhead to + this constructor). This can help load ``state_dict`` checkpoints + via ``load_state_dict`` in a memory efficient way. See + :class:`FullStateDictConfig` for an example of this. (Default: + ``False``) + forward_prefetch (bool): If ``True``, then FSDP *explicitly* prefetches + the next forward-pass all-gather before the current forward + computation. This is only useful for CPU-bound workloads, in which + case issuing the next all-gather earlier may improve overlap. This + should only be used for static-graph models since the prefetching + follows the first iteration's execution order. (Default: ``False``) + limit_all_gathers (bool): If ``True``, then FSDP explicitly + synchronizes the CPU thread to ensure GPU memory usage from only + *two* consecutive FSDP instances (the current instance running + computation and the next instance whose all-gather is prefetched). + If ``False``, then FSDP allows the CPU thread to issue all-gathers + without any extra synchronization. (Default: ``True``) We often + refer to this feature as the "rate limiter". This flag should only + be set to ``False`` for specific CPU-bound workloads with low + memory pressure in which case the CPU thread can aggressively issue + all kernels without concern for the GPU memory usage. + use_orig_params (bool): Setting this to ``True`` has FSDP use + ``module`` 's original parameters. FSDP exposes those original + parameters to the user via :meth:`nn.Module.named_parameters` + instead of FSDP's internal :class:`FlatParameter` s. This means + that the optimizer step runs on the original parameters, enabling + per-original-parameter hyperparameters. FSDP preserves the original + parameter variables and manipulates their data between unsharded + and sharded forms, where they are always views into the underlying + unsharded or sharded :class:`FlatParameter`, respectively. With the + current algorithm, the sharded form is always 1D, losing the + original tensor structure. An original parameter may have all, + some, or none of its data present for a given rank. In the none + case, its data will be like a size-0 empty tensor. Users should not + author programs relying on what data is present for a given + original parameter in its sharded form. ``True`` is required to + use ``torch.compile()``. Setting this to ``False`` exposes FSDP's + internal :class:`FlatParameter` s to the user via + :meth:`nn.Module.named_parameters`. (Default: ``False``) + ignored_states (Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]]): + Ignored parameters or modules that will not be managed by this FSDP + instance, meaning that the parameters are not sharded and their + gradients are not reduced across ranks. This argument unifies with + the existing ``ignored_modules`` argument, and we may deprecate + ``ignored_modules`` soon. For backward compatibility, we keep both + ``ignored_states`` and `ignored_modules``, but FSDP only allows one + of them to be specified as not ``None``. + device_mesh (Optional[DeviceMesh]): DeviceMesh can be used as an alternative to + process_group. When device_mesh is passed, FSDP will use the underlying process + groups for all-gather and reduce-scatter collective communications. Therefore, + these two args need to be mutually exclusive. For hybrid sharding strategies such as + ``ShardingStrategy.HYBRID_SHARD``, users can pass in a 2D DeviceMesh instead + of a tuple of process groups. For 2D FSDP + TP, users are required to pass in + device_mesh instead of process_group. For more DeviceMesh info, please visit: + https://pytorch.org/tutorials/recipes/distributed_device_mesh.html + """ + + def __init__( + self, + module: nn.Module, + process_group: ProcessGroupType = None, + sharding_strategy: Optional[ShardingStrategy] = None, + cpu_offload: Optional[CPUOffload] = None, + auto_wrap_policy: Optional[ + Union[Callable, ModuleWrapPolicy, CustomPolicy] + ] = None, + backward_prefetch: Optional[BackwardPrefetch] = BackwardPrefetch.BACKWARD_PRE, + mixed_precision: Optional[MixedPrecision] = None, + ignored_modules: Optional[Iterable[torch.nn.Module]] = None, + param_init_fn: Optional[Callable[[nn.Module], None]] = None, + device_id: Optional[Union[int, torch.device]] = None, + sync_module_states: bool = False, + forward_prefetch: bool = False, + limit_all_gathers: bool = True, + use_orig_params: bool = False, + ignored_states: Union[ + Optional[Iterable[torch.nn.Parameter]], Optional[Iterable[torch.nn.Module]] + ] = None, + device_mesh: Optional[DeviceMesh] = None, + ): + torch._C._log_api_usage_once("torch.distributed.fsdp") + super().__init__() + if isinstance(module, (nn.ModuleList, nn.ModuleDict)): + warnings.warn( + "FSDP will not all-gather parameters for containers that do " + f"not implement forward: {module}", + stacklevel=2, + ) + _init_ignored_module_states(self, module, ignored_modules, ignored_states) + _init_device_handle(self, module, self._ignored_params, device_id) + + # Add module annotations for Dynamo support (see function for details) + _annotate_modules_for_dynamo(module, self._ignored_modules, use_orig_params) + + # Initializes self.process_group, along with rank and world size. This will + # also set another attribute, _inter_node_pg, to control the process group + # over which sharding occurs, if sharding_strategy is {HYBRID_SHARD, _HYBRID_SHARD_ZERO2}. + # Note that this is done before auto_wrapping, so that child FSDP modules simply pick up + # the same process group state as the root FSDP module. + self._device_mesh = device_mesh + _init_process_group_state( + self, + process_group, + sharding_strategy, + auto_wrap_policy, + device_mesh, + ) + if auto_wrap_policy is not None: + root_kwargs = { + "process_group": process_group, + "sharding_strategy": sharding_strategy, + "cpu_offload": cpu_offload, + "backward_prefetch": backward_prefetch, + "mixed_precision": mixed_precision, + "param_init_fn": param_init_fn, + "device_id": device_id, + "sync_module_states": sync_module_states, + "forward_prefetch": forward_prefetch, + "limit_all_gathers": limit_all_gathers, + "use_orig_params": use_orig_params, + "ignored_states": self._ignored_params, + "device_mesh": device_mesh, + } + if sharding_strategy in HYBRID_SHARDING_STRATEGIES and device_mesh is None: + # Share root process groups with children to maintain + # the invariant that all FSDP modules will have the same + # process groups. + root_kwargs["process_group"] = (self.process_group, self._inter_node_pg) + + _auto_wrap( + module, + auto_wrap_policy, + self._ignored_modules, + self._ignored_params, + root_kwargs, + FullyShardedDataParallel, + ) + + backward_prefetch_limit = 1 + forward_prefetch_limit = 1 + _init_core_state( + self, + sharding_strategy, + mixed_precision, + cpu_offload, + limit_all_gathers, + use_orig_params, + backward_prefetch_limit, + forward_prefetch_limit, + ) + _init_runtime_state(self) + _init_prefetching_state(self, backward_prefetch, forward_prefetch) + _init_buffer_state(self, module) + # extension needs to be set before `_init_param_handle_from_module()` + _init_extension(self, device_mesh) + _init_param_handle_from_module( + self, + module, + device_id, + param_init_fn, + sync_module_states, + ) + self._fsdp_wrapped_module = module + if not use_orig_params: + _check_orig_params_flattened(self, self._ignored_params) + _register_flat_param(self, self) + + # `_state_dict_type` controls the `state_dict()` behavior, which is + # implemented using post-save and pre-load hooks + _init_state_dict_state(self) + _register_all_state_dict_hooks(self) + self._zero_scalar = None + + @property + def module(self) -> nn.Module: + """Return the wrapped module.""" + # FSDP's `.module` must refer to the innermost wrapped module when + # composing with other module wrappers in order for state dict to work + if isinstance(self._fsdp_wrapped_module, ActivationWrapper): + return getattr(self._fsdp_wrapped_module, _CHECKPOINT_WRAPPED_MODULE) + return self._fsdp_wrapped_module + + @property + def _has_params(self) -> bool: + """Returns whether this FSDP instance manages any parameters.""" + return hasattr(self, "_handle") and self._handle is not None + + @property + def _flat_param(self) -> Optional[FlatParameter]: + return self._handle.flat_param if self._handle else None + + def __getattr__(self, name: str) -> Any: + """Forward missing attributes to the wrapped module.""" + try: + return super().__getattr__(name) # defer to nn.Module's logic + except AttributeError: + return getattr(self._fsdp_wrapped_module, name) + + def __getitem__(self, key: int) -> Any: + """Forward indexing calls in case the module is an ``nn.Sequential``.""" + if hasattr(self, FSDP_WRAPPED_MODULE): + return self._fsdp_wrapped_module.__getitem__(key) # type: ignore[operator] + return super().__getitem__(key) + + def check_is_root(self) -> bool: + """Check if this instance is a root FSDP module.""" + return _is_fsdp_root(self, self) + + @staticmethod + def fsdp_modules( + module: nn.Module, + root_only: bool = False, + ) -> list["FullyShardedDataParallel"]: + """Return all nested FSDP instances. + + This possibly includes ``module`` itself and only includes FSDP root modules if ``root_only=True``. + + Args: + module (torch.nn.Module): Root module, which may or may not be an + ``FSDP`` module. + root_only (bool): Whether to return only FSDP root modules. + (Default: ``False``) + + Returns: + List[FullyShardedDataParallel]: FSDP modules that are nested in + the input ``module``. + """ + if root_only: + return _get_fsdp_root_states(module) + return traversal_utils._get_fsdp_states(module) + + def apply(self, fn: Callable[[nn.Module], None]) -> "FullyShardedDataParallel": + r"""Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self. + + Typical use includes initializing the parameters of a model (see also :ref:`nn-init-doc`). + + Compared to ``torch.nn.Module.apply``, this version additionally gathers + the full parameters before applying ``fn``. It should not be called from + within another ``summon_full_params`` context. + + Args: + fn (:class:`Module` -> None): function to be applied to each submodule + + Returns: + Module: self + """ + uninitialized = self._is_root is None + self._assert_state(TrainingState.IDLE) + # Use `_unshard_params_for_summon()` with `recurse=False` instead of + # `_unshard_fsdp_state_params()` directly to perform lazy + # initialization, which is needed to initialize `FlatParameter` + # parameter attributes as required by the unshard logic + with _unshard_params_for_summon( + self, + self, + writeback=True, + rank0_only=False, + offload_to_cpu=False, + with_grads=False, + ): + ret = super().apply(fn) + + # Reset lazy init called in `_unshard_params_for_summon()` since + # `apply()` may have been called on FSDP instance that is not truly a + # root, in which case it will be incorrectly marked as one. + if uninitialized and self._is_root: + for module in traversal_utils._get_fsdp_states(self): + module._reset_lazy_init() + + return ret + + def _mixed_precision_enabled_for_buffers(self) -> bool: + """Return whether the user explicitly enabled buffer mixed precision. + + NOTE: Unlike parameters and gradient reduction, buffer mixed precision + is applied at the FSDP instance level, not the ``FlatParameter`` level, + which may be different for the composable code path. + """ + return self.mixed_precision.buffer_dtype is not None + + def _low_precision_hook_enabled(self) -> bool: + """Whether a low precision hook is registered or not.""" + return self._comm_hook is not None and self._comm_hook in LOW_PRECISION_HOOKS + + def _reset_lazy_init(self) -> None: + """Reset instance so :func:`_lazy_init` will run on the next forward.""" + self._is_root: Optional[bool] = None + + @staticmethod + def set_state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + optim_state_dict_config: Optional[OptimStateDictConfig] = None, + ) -> StateDictSettings: + """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. + + Also takes (optional) configuration for the model's and optimizer's state dict. + The target module does not have to be a FSDP module. If the target + module is a FSDP module, its ``state_dict_type`` will also be changed. + + .. note:: This API should be called for only the top-level (root) + module. + + .. note:: This API enables users to transparently use the conventional + ``state_dict`` API to take model checkpoints in cases where the + root FSDP module is wrapped by another ``nn.Module``. For example, + the following will ensure ``state_dict`` is called on all non-FSDP + instances, while dispatching into `sharded_state_dict` implementation + for FSDP: + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = DDP(FSDP(...)) + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.SHARDED_STATE_DICT, + >>> state_dict_config = ShardedStateDictConfig(offload_to_cpu=True), + >>> optim_state_dict_config = OptimStateDictConfig(offload_to_cpu=True), + >>> ) + >>> param_state_dict = model.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict(model, optim) + + Args: + module (torch.nn.Module): Root module. + state_dict_type (StateDictType): the desired ``state_dict_type`` to set. + state_dict_config (Optional[StateDictConfig]): the configuration for the + target ``state_dict_type``. + optim_state_dict_config (Optional[OptimStateDictConfig]): the configuration + for the optimizer state dict. + + Returns: + A StateDictSettings that include the previous state_dict type and + configuration for the module. + """ + warnings.warn( + "FSDP.state_dict_type() and FSDP.set_state_dict_type() are being " + "deprecated. Please use APIs, get_state_dict() and set_state_dict(), " + "which can support different parallelisms, FSDP1, FSDP2, DDP. " + "API doc: https://pytorch.org/docs/stable/distributed.checkpoint.html" + "#torch.distributed.checkpoint.state_dict.get_state_dict ." + "Tutorial: https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html .", + FutureWarning, + ) + _state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedStateDictConfig, + } + _optim_state_dict_type_to_config = { + StateDictType.FULL_STATE_DICT: FullOptimStateDictConfig, + StateDictType.LOCAL_STATE_DICT: LocalOptimStateDictConfig, + StateDictType.SHARDED_STATE_DICT: ShardedOptimStateDictConfig, + } + + # Use the default config if a state_dict config is not set. + state_dict_config_type = _state_dict_type_to_config[state_dict_type] + optim_state_dict_config_type = _optim_state_dict_type_to_config[state_dict_type] + if state_dict_config is None: + state_dict_config = state_dict_config_type() + if optim_state_dict_config is None: + optim_state_dict_config = optim_state_dict_config_type() + if state_dict_config_type != type(state_dict_config): + raise RuntimeError( + f"Expected state_dict_config of type {state_dict_config_type} " + f"but got {type(state_dict_config)}" + ) + if optim_state_dict_config_type != type(optim_state_dict_config): + raise RuntimeError( + f"Expected optim_state_dict_config of type {optim_state_dict_config_type} " + f"but got {type(optim_state_dict_config)}" + ) + + # Set the state_dict type and configurations. + prev_state_dict_type = None + prev_state_dict_config = None + prev_optim_state_dict_config = None + for submodule in traversal_utils._get_fsdp_states(module): + if prev_state_dict_type is None: + prev_state_dict_type = submodule._state_dict_type + else: + assert prev_state_dict_type == submodule._state_dict_type, ( + "All FSDP modules should have the same state_dict_type." + ) + if prev_state_dict_config is None: + prev_state_dict_config = submodule._state_dict_config + else: + assert isinstance( + submodule._state_dict_config, type(prev_state_dict_config) + ), "All FSDP modules must have the same type of state_dict_config." + if prev_optim_state_dict_config is None: + prev_optim_state_dict_config = submodule._optim_state_dict_config + else: + assert isinstance( + submodule._optim_state_dict_config, + type(prev_optim_state_dict_config), + ), ( + "All FSDP modules must have the same type of optim_state_dict_config." + ) + + submodule._state_dict_type = state_dict_type + submodule._state_dict_config = state_dict_config + submodule._optim_state_dict_config = optim_state_dict_config + + return StateDictSettings( + prev_state_dict_type, prev_state_dict_config, prev_optim_state_dict_config + ) + + @staticmethod + def get_state_dict_type(module: nn.Module) -> StateDictSettings: + """Get the state_dict_type and the corresponding configurations for the FSDP modules rooted at ``module``. + + The target module does not have to be an FSDP module. + + Returns: + A ``StateDictSettings`` containing the state_dict_type and + state_dict / optim_state_dict configs that are currently set. + + Raises: + ``AssertionError`` if the ``StateDictSettings`` for different + FSDP submodules differ. + """ + state_dict_settings: Optional[StateDictSettings] = None + for submodule in FullyShardedDataParallel.fsdp_modules(module): + if state_dict_settings is None: + state_dict_settings = StateDictSettings( + state_dict_type=submodule._state_dict_type, + state_dict_config=submodule._state_dict_config, + optim_state_dict_config=submodule._optim_state_dict_config, + ) + _set_optim_use_dtensor(submodule, state_dict_settings) + else: + submodule_settings = StateDictSettings( + submodule._state_dict_type, + submodule._state_dict_config, + submodule._optim_state_dict_config, + ) + assert state_dict_settings == submodule_settings, ( + "All FSDP modules must have the same state dict settings." + f"Got {submodule_settings} and {state_dict_settings}." + ) + _set_optim_use_dtensor(submodule, submodule_settings) + return state_dict_settings + + @staticmethod + @contextlib.contextmanager + def state_dict_type( + module: nn.Module, + state_dict_type: StateDictType, + state_dict_config: Optional[StateDictConfig] = None, + optim_state_dict_config: Optional[OptimStateDictConfig] = None, + ) -> Generator: + """Set the ``state_dict_type`` of all the descendant FSDP modules of the target module. + + This context manager has the same functions as :meth:`set_state_dict_type`. Read the document of + :meth:`set_state_dict_type` for the detail. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = DDP(FSDP(...)) + >>> with FSDP.state_dict_type( + >>> model, + >>> StateDictType.SHARDED_STATE_DICT, + >>> ): + >>> checkpoint = model.state_dict() + + Args: + module (torch.nn.Module): Root module. + state_dict_type (StateDictType): the desired ``state_dict_type`` to set. + state_dict_config (Optional[StateDictConfig]): the model ``state_dict`` + configuration for the target ``state_dict_type``. + optim_state_dict_config (Optional[OptimStateDictConfig]): the optimizer + ``state_dict`` configuration for the target ``state_dict_type``. + """ + prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type( + module, + state_dict_type, + state_dict_config, + optim_state_dict_config, + ) + yield + FullyShardedDataParallel.set_state_dict_type( + module, + prev_state_dict_settings.state_dict_type, + prev_state_dict_settings.state_dict_config, + prev_state_dict_settings.optim_state_dict_config, + ) + + def forward(self, *args: Any, **kwargs: Any) -> Any: + """Run the forward pass for the wrapped module, inserting FSDP-specific pre- and post-forward sharding logic.""" + handle = self._handle + with torch.autograd.profiler.record_function( + "FullyShardedDataParallel.forward" + ): + args, kwargs = _root_pre_forward(self, self, args, kwargs) + unused = None + args, kwargs = _pre_forward( + self, + handle, + _pre_forward_unshard, + self._fsdp_wrapped_module, + args, + kwargs, + ) + if handle: + _p_assert( + handle.flat_param.device == self.compute_device, + "Expected `FlatParameter` to be on the compute device " + f"{self.compute_device} but got {handle.flat_param.device}", + ) + output = self._fsdp_wrapped_module(*args, **kwargs) + return _post_forward( + self, handle, _post_forward_reshard, self, unused, output + ) + + @staticmethod + @contextlib.contextmanager + def summon_full_params( + module: nn.Module, + recurse: bool = True, + writeback: bool = True, + rank0_only: bool = False, + offload_to_cpu: bool = False, + with_grads: bool = False, + ) -> Generator: + r"""Expose full params for FSDP instances with this context manager. + + Can be useful *after* forward/backward for a model to get + the params for additional processing or checking. It can take a non-FSDP + module and will summon full params for all contained FSDP modules as + well as their children, depending on the ``recurse`` argument. + + .. note:: This can be used on inner FSDPs. + .. note:: This can *not* be used within a forward or backward pass. Nor + can forward and backward be started from within this context. + .. note:: Parameters will revert to their local shards after the context + manager exits, storage behavior is the same as forward. + .. note:: The full parameters can be modified, but only the portion + corresponding to the local param shard will persist after the + context manager exits (unless ``writeback=False``, in which case + changes will be discarded). In the case where FSDP does not shard + the parameters, currently only when ``world_size == 1``, or ``NO_SHARD`` + config, the modification is persisted regardless of ``writeback``. + .. note:: This method works on modules which are not FSDP themselves but + may contain multiple independent FSDP units. In that case, the given + arguments will apply to all contained FSDP units. + + .. warning:: Note that ``rank0_only=True`` in conjunction with + ``writeback=True`` is not currently supported and will raise an + error. This is because model parameter shapes would be different + across ranks within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + + .. warning:: Note that ``offload_to_cpu`` and ``rank0_only=False`` will + result in full parameters being redundantly copied to CPU memory for + GPUs that reside on the same machine, which may incur the risk of + CPU OOM. It is recommended to use ``offload_to_cpu`` with + ``rank0_only=True``. + + Args: + recurse (bool, Optional): recursively summon all params for nested + FSDP instances (default: True). + writeback (bool, Optional): if ``False``, modifications to params are + discarded after the context manager exits; + disabling this can be slightly more efficient (default: True) + rank0_only (bool, Optional): if ``True``, full parameters are + materialized on only global rank 0. This means that within the + context, only rank 0 will have full parameters and the other + ranks will have sharded parameters. Note that setting + ``rank0_only=True`` with ``writeback=True`` is not supported, + as model parameter shapes will be different across ranks + within the context, and writing to them can lead to + inconsistency across ranks when the context is exited. + offload_to_cpu (bool, Optional): If ``True``, full parameters are + offloaded to CPU. Note that this offloading currently only + occurs if the parameter is sharded (which is only not the case + for world_size = 1 or ``NO_SHARD`` config). It is recommended + to use ``offload_to_cpu`` with ``rank0_only=True`` to avoid + redundant copies of model parameters being offloaded to the same CPU memory. + with_grads (bool, Optional): If ``True``, gradients are also + unsharded with the parameters. Currently, this is only + supported when passing ``use_orig_params=True`` to the FSDP + constructor and ``offload_to_cpu=False`` to this method. + (Default: ``False``) + """ + with _unshard_params( + module, recurse, writeback, rank0_only, offload_to_cpu, with_grads + ): + yield + + @contextlib.contextmanager + def _deregister_orig_params_ctx(self): + """Deregister the original parameters and expose the :class:`FlatParameter`. + + If a :class:`FlatParameter` is sharded, then + this refreshes the sharded views before exiting. This method should + only be called when using the original parameters. + """ + _p_assert( + self._use_orig_params, + "`_deregister_orig_params_ctx()` should only be called when " + "`_use_orig_params=True`", + ) + for fsdp_module in traversal_utils._get_fsdp_states(self): + _deregister_orig_params(fsdp_module, fsdp_module) + try: + yield + finally: + for fsdp_module in traversal_utils._get_fsdp_states(self): + _register_orig_params(fsdp_module, fsdp_module) + + def _apply(self, *args, **kwargs): + """Deregister the original parameters and expose the :class:`FlatParameter` s before calling ``_apply()``.""" + # When using the original parameters: Since (1) the `FlatParameter`s + # own the storage and (2) `_apply()` is the subroutine underlying the + # most common storage-changing ops like `to()` and `cuda()`, we + # override `_apply()` to have the storage change directly performed on + # the `FlatParameter`s instead of applying to the original parameters + # and then writing back to the `FlatParameter`s. + context = ( + self._deregister_orig_params_ctx() + if self._use_orig_params + else contextlib.nullcontext() + ) + with context: + return super()._apply(*args, **kwargs) + + def named_buffers( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Return an iterator over module buffers, yielding both the name of the buffer and the buffer itself. + + Intercepts buffer names and removes all occurrences of the FSDP-specific flattened buffer prefix + when inside the :meth:`summon_full_params` context manager. + """ + should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS + for buffer_name, buffer in super().named_buffers(*args, **kwargs): + if should_clean_name: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + buffer_name = buffer_name.replace(FSDP_PREFIX, "") + yield (buffer_name, buffer) + + def named_parameters( + self, + *args, + **kwargs, + ) -> Iterator[tuple[str, torch.nn.Parameter]]: + """Return an iterator over module parameters, yielding both the name of the parameter and the parameter itself. + + Intercepts parameter names and removes all occurrences of the FSDP-specific flattened parameter prefix + when inside the :meth:`summon_full_params` context manager. + """ + should_clean_name = self.training_state == TrainingState.SUMMON_FULL_PARAMS + for param_name, param in super().named_parameters(*args, **kwargs): + if should_clean_name: + # Remove any instances of the FSDP-specific prefix; there can + # be multiple in the case of nested FSDP modules + param_name = param_name.replace(FSDP_PREFIX, "") + yield (param_name, param) + + def _assert_state(self, state: Union[TrainingState, list[TrainingState]]) -> None: + """Assert we are in the given state.""" + # Since assert can be turned off and this error checking + # is really important, we use explicit error checking + # and raise a ValueError if needed. + if isinstance(state, TrainingState): + state = [state] + if self.training_state not in state: + msg = ( + f"expected to be in states {state} but current state " + f"is {self.training_state}" + ) + # In case we are failing in the context of autograd hook, asserting + # may not generate useful msg. So, let's print it to be sure. + if self.rank == 0: + print(f"Asserting FSDP instance is: {self}") + print(f"ERROR: {msg}") + traceback.print_stack() + raise ValueError(msg) + + @contextmanager + def no_sync(self) -> Generator: + """Disable gradient synchronizations across FSDP instances. + + Within this context, gradients will be accumulated in module + variables, which will later be synchronized in the first + forward-backward pass after exiting the context. This should only be + used on the root FSDP instance and will recursively apply to all + children FSDP instances. + + .. note:: This likely results in higher memory usage because FSDP will + accumulate the full model gradients (instead of gradient shards) + until the eventual sync. + + .. note:: When used with CPU offloading, the gradients will not be + offloaded to CPU when inside the context manager. Instead, they + will only be offloaded right after the eventual sync. + """ + _lazy_init(self, self) + if not self._is_root: + raise RuntimeError( + "`no_sync()` on inner FSDP instances is not supported. Please call `no_sync()` on root FSDP module." + ) + self._assert_state(TrainingState.IDLE) + old_flags = [] + for m in self.modules(): + if isinstance(m, FullyShardedDataParallel): + old_flags.append((m, m._sync_gradients)) + m._sync_gradients = False + try: + yield + finally: + for m, old_flag in old_flags: + assert not m._sync_gradients, ( + "`_sync_gradients` was incorrectly set to " + "`True` while in the `no_sync()` context manager" + ) + m._sync_gradients = old_flag + + @torch.no_grad() + def clip_grad_norm_( + self, max_norm: Union[float, int], norm_type: Union[float, int] = 2.0 + ) -> torch.Tensor: + """Clip the gradient norm of all parameters. + + The norm is computed over all parameters' gradients as viewed as a single vector, and the + gradients are modified in-place. + + Args: + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` + for infinity norm. + + Returns: + Total norm of the parameters (viewed as a single vector). + + If every FSDP instance uses ``NO_SHARD``, meaning that no + gradients are sharded across ranks, then you may directly use + :func:`torch.nn.utils.clip_grad_norm_`. + + If at least some FSDP instance uses a sharded strategy (i.e. + one other than ``NO_SHARD``), then you should use this method + instead of :func:`torch.nn.utils.clip_grad_norm_` since this method + handles the fact that gradients are sharded across ranks. + + The total norm returned will have the "largest" dtype across + all parameters/gradients as defined by PyTorch's type promotion + semantics. For example, if *all* parameters/gradients use a low + precision dtype, then the returned norm's dtype will be that low + precision dtype, but if there exists at least one parameter/ + gradient using FP32, then the returned norm's dtype will be FP32. + + .. warning:: This needs to be called on all ranks since it uses + collective communications. + """ + _lazy_init(self, self) + if not self._is_root: + raise RuntimeError( + "`clip_grad_norm_()` should only be called on the root FSDP instance" + ) + if self._zero_scalar is None: + self._zero_scalar = torch.tensor(0.0, device=self.compute_device) + self._assert_state(TrainingState.IDLE) + # If every FSDP instance uses `NO_SHARD`, then we can directly use + # the normal `nn.utils` one targeting local gradients + all_no_shard = all( + not handle.uses_sharded_strategy for handle in self._all_handles + ) + if all_no_shard: + return torch.nn.utils.clip_grad_norm_( + self.parameters(), max_norm, norm_type + ) + # Otherwise, there exists some FSDP instance using a sharded strategy, + # where sharded and non-sharded parameters must be handled separately + max_norm = float(max_norm) + norm_type = float(norm_type) + sharded_params_set = set() + nonsharded_params_set = set() # `NO_SHARD` or not FSDP-managed + # Make sure to compute the local norm using lists for deterministic + # iteration order and hence deterministic total norm computation + sharded_params = [] + nonsharded_params = [] + grads: list[torch.Tensor] = [] + for handle in self._all_handles: + if handle.uses_sharded_strategy: + target_set = sharded_params_set + target_list = sharded_params + else: + target_set = nonsharded_params_set + target_list = nonsharded_params + if handle._use_orig_params: + for param in handle.flat_param._params: + if param not in target_set: + target_set.add(param) + target_list.append(param) + if param.grad is not None: + grads.append(param.grad) + else: + if handle.flat_param not in target_set: + target_set.add(handle.flat_param) + target_list.append(handle.flat_param) + if handle.flat_param.grad is not None: + grads.append(handle.flat_param.grad) + for param in self.parameters(): + not_fsdp_managed = ( + param not in sharded_params_set and param not in nonsharded_params_set + ) + if not_fsdp_managed: + nonsharded_params_set.add(param) + nonsharded_params.append(param) + if param.grad is not None: + grads.append(param.grad) + # Compute local norms (forced to be in FP32) + local_sharded_norm = _get_grad_norm( + sharded_params, norm_type, self._zero_scalar, self.compute_device + ) + local_nonsharded_norm = ( + _get_grad_norm( + nonsharded_params, norm_type, self._zero_scalar, self.compute_device + ) + if nonsharded_params + else None + ) + # Reconstruct the total gradient norm depending on the norm type + if norm_type == math.inf: + total_norm = ( + torch.maximum(local_sharded_norm, local_nonsharded_norm) + if local_nonsharded_norm is not None + else local_sharded_norm + ) + dist.all_reduce( + total_norm, op=torch.distributed.ReduceOp.MAX, group=self.process_group + ) + else: + total_norm = local_sharded_norm**norm_type + dist.all_reduce(total_norm, group=self.process_group) + # All-reducing the local non-sharded norm would count it an extra + # world-size-many times + if local_nonsharded_norm is not None: + total_norm += local_nonsharded_norm**norm_type + total_norm = total_norm ** (1.0 / norm_type) + if self.cpu_offload.offload_params: + total_norm = total_norm.cpu() + + clip_coef = max_norm / (total_norm + 1e-6) + # Multiplying by the clamped coefficient is meaningless when it is + # equal to 1, but it avoids the host-device sync that would result from + # `if clip_coef < 1` + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for grad in grads: + grad.mul_(clip_coef_clamped.to(grad.device, grad.dtype)) + # Use the "largest" dtype by type promotion semantics to use the same + # dtype as if we did not force local norm computation to be in FP32 + if len(grads) == 0: + # If this rank has no gradients, then we must default to FP32 + # unless we use additional communication, which we prefer to avoid + # since `clip_grad_norm_()` is called in the training loop + warnings.warn( + f"Called FSDP.clip_grad_norm_() on rank {self.rank} with no " + "gradients -- returning the total norm in the default dtype " + f"{total_norm.dtype}" + ) # warn since this is generally unexpected + return total_norm + total_norm_dtype = functools.reduce( + torch.promote_types, + [grad.dtype for grad in grads], + ) + return total_norm.to(total_norm_dtype) + + @staticmethod + def _warn_optim_input(optim_input, *, stacklevel: int = 1): + if optim_input is not None: + warnings.warn( + "The `optim_input` argument is deprecated and will be removed after PyTorch 1.13. " + "You may remove it from your code without changing its functionality.", + FutureWarning, + stacklevel=stacklevel + 1, + ) + + @staticmethod + def _is_using_optim_input(optim_input, optim) -> bool: + if optim_input is None and optim is None: + # Use the default behavior of `optim_input`` + return True + if optim_input is not None: + # Use the `optim_input` code path + return True + # Use the `optim` code path + return False + + @staticmethod + def _warn_legacy_optim_state_dict(curr: str, new: str, *, stacklevel: int = 1): + warnings.warn( + f"``FullyShardedDataParallel.{curr}``is being deprecated and is " + f"replaced by ``FullyShardedDataParallel.{new}``. " + f"``FullyShardedDataParallel.{curr}`` may be removed after PyTorch 2.2.", + FutureWarning, + stacklevel=stacklevel + 1, + ) + + @staticmethod + def _optim_state_dict_impl( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + rank0_only: bool = True, + full_state_dict: bool = True, + group: Optional[dist.ProcessGroup] = None, + cpu_offload: bool = True, + *, + _stacklevel: int = 1, + ) -> dict[str, Any]: + """Transform the state-dict of an optimizer corresponding to a sharded model. + + This is the internal API that is used by all the optim_state_dict implementations. + Given model, optim, the original optim_state_dict, this API removes the + FSDP internal information and internal sharding from the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input( + optim_input, stacklevel=_stacklevel + 1 + ) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + assert optim_input is None and not rank0_only + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[ + 0 + ]._use_orig_params + assert all( + use_orig_params == m._use_orig_params + for m in FullyShardedDataParallel.fsdp_modules(model) + ), "Not all FSDP modules have the same _use_orig_params value" + + return _optim_state_dict( + model=model, + optim=optim, + optim_state_dict=optim_state_dict, + optim_input=optim_input, + rank0_only=rank0_only, + shard_state=not full_state_dict, + group=group, + using_optim_input=using_optim_input, + use_orig_params=use_orig_params, + cpu_offload=cpu_offload, + ) + + @staticmethod + def _optim_state_dict_to_load_impl( + optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + full_state_dict: bool = True, + rank0_only: bool = False, + is_named_optimizer: bool = False, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. + + This is the internal API that is used by all the load optim_state_dict implementations. + Given model, optim, and the saved optim_state_dict, this API adds the FSDP + internal information and internal sharding to the optim_state_dict. + """ + if full_state_dict: + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + else: + using_optim_input = False + assert optim_input is None and not rank0_only + + use_orig_params = FullyShardedDataParallel.fsdp_modules(model)[ + 0 + ]._use_orig_params + assert all( + use_orig_params == m._use_orig_params + for m in FullyShardedDataParallel.fsdp_modules(model) + ), "Not all FSDP modules have the same _use_orig_params value" + + if rank0_only and dist.get_rank(group) > 0: + optim_state_dict = {} + sharded_osd = _flatten_optim_state_dict( + optim_state_dict, + model=model, + use_orig_params=use_orig_params, + optim=(optim if is_named_optimizer else None), + rank0_only=rank0_only, + group=group, + ) + return _rekey_sharded_optim_state_dict( + sharded_osd, + model=model, + optim=optim, + optim_input=optim_input, + using_optim_input=using_optim_input, + is_named_optimizer=is_named_optimizer, + ) + + @staticmethod + def full_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + rank0_only: bool = True, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """Return the full optimizer state-dict. + + Consolidates the full optimizer state on rank 0 and returns it + as a :class:`dict` following the convention of + :meth:`torch.optim.Optimizer.state_dict`, i.e. with keys ``"state"`` + and ``"param_groups"``. The flattened parameters in ``FSDP`` modules + contained in ``model`` are mapped back to their unflattened parameters. + + This needs to be called on all ranks since it uses + collective communications. However, if ``rank0_only=True``, then + the state dict is only populated on rank 0, and all other ranks + return an empty :class:`dict`. + + Unlike ``torch.optim.Optimizer.state_dict()``, this method + uses full parameter names as keys instead of parameter IDs. + + Like in :meth:`torch.optim.Optimizer.state_dict`, the tensors + contained in the optimizer state dict are not cloned, so there may + be aliasing surprises. For best practices, consider saving the + returned optimizer state dict immediately, e.g. using + ``torch.save()``. + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer ``optim`` representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + rank0_only (bool): If ``True``, saves the populated :class:`dict` + only on rank 0; if ``False``, saves it on all ranks. (Default: + ``True``) + group (dist.ProcessGroup): Model's process group or ``None`` if using + the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model`` 's original unflattened parameters and including keys + "state" and "param_groups" following the convention of + :meth:`torch.optim.Optimizer.state_dict`. If ``rank0_only=True``, + then nonzero ranks return an empty :class:`dict`. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "full_optim_state_dict", + "optim_state_dict", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim.state_dict(), + optim_input=optim_input, + rank0_only=rank0_only, + group=group, + full_state_dict=True, + _stacklevel=2, + ) + + @staticmethod + def sharded_optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """Return the optimizer state-dict in its sharded form. + + The API is similar to :meth:`full_optim_state_dict` but this API chunks + all non-zero-dimension states to :class:`ShardedTensor` to save memory. + This API should only be used when the model ``state_dict`` is derived + with the context manager ``with state_dict_type(SHARDED_STATE_DICT):``. + + For the detailed usage, refer to :meth:`full_optim_state_dict`. + + .. warning:: The returned state dict contains ``ShardedTensor`` and + cannot be directly used by the regular ``optim.load_state_dict``. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "sharded_optim_state_dict", + "optim_state_dict", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim.state_dict(), + optim_input=None, + rank0_only=False, + full_state_dict=False, + group=group, + _stacklevel=2, + ) + + @staticmethod + def shard_full_optim_state_dict( + full_optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> dict[str, Any]: + """Shard a full optimizer state-dict. + + Remaps the state in ``full_optim_state_dict`` to flattened parameters instead of unflattened + parameters and restricts to only this rank's part of the optimizer state. + The first argument should be the return value of :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) + >>> torch.save(full_osd, PATH) + >>> # Define new model with possibly different world size + >>> new_model, new_optim = ... + >>> full_osd = torch.load(PATH) + >>> sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, new_model) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + full non-sharded optimizer state. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "shard_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + is_named_optimizer=False, + ) + + @staticmethod + def flatten_sharded_optim_state_dict( + sharded_optim_state_dict: dict[str, Any], + model: torch.nn.Module, + optim: torch.optim.Optimizer, + ) -> dict[str, Any]: + """Flatten a sharded optimizer state-dict. + + The API is similar to :meth:`shard_full_optim_state_dict`. The only + difference is that the input ``sharded_optim_state_dict`` should be + returned from :meth:`sharded_optim_state_dict`. Therefore, there will + be all-gather calls on each rank to gather ``ShardedTensor`` s. + + Args: + sharded_optim_state_dict (Dict[str, Any]): Optimizer state dict + corresponding to the unflattened parameters and holding the + sharded optimizer state. + model (torch.nn.Module): + Refer to :meth:`shard_full_optim_state_dict`. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + + Returns: + Refer to :meth:`shard_full_optim_state_dict`. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "flatten_sharded_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=sharded_optim_state_dict, + model=model, + optim_input=None, + optim=optim, + full_state_dict=False, + is_named_optimizer=False, + ) + + @staticmethod + def scatter_full_optim_state_dict( + full_optim_state_dict: Optional[dict[str, Any]], + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + group: Optional[Any] = None, + ) -> dict[str, Any]: + """Scatter the full optimizer state dict from rank 0 to all other ranks. + + Returns the sharded optimizer state dict on each rank. + The return value is the same as :meth:`shard_full_optim_state_dict`, and on rank + 0, the first argument should be the return value of + :meth:`full_optim_state_dict`. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> model, optim = ... + >>> full_osd = FSDP.full_optim_state_dict(model, optim) # only non-empty on rank 0 + >>> # Define new model with possibly different world size + >>> new_model, new_optim, new_group = ... + >>> sharded_osd = FSDP.scatter_full_optim_state_dict(full_osd, new_model, group=new_group) + >>> new_optim.load_state_dict(sharded_osd) + + .. note:: Both :meth:`shard_full_optim_state_dict` and + :meth:`scatter_full_optim_state_dict` may be used to get the + sharded optimizer state dict to load. Assuming that the full + optimizer state dict resides in CPU memory, the former requires + each rank to have the full dict in CPU memory, where each rank + individually shards the dict without any communication, while the + latter requires only rank 0 to have the full dict in CPU memory, + where rank 0 moves each shard to GPU memory (for NCCL) and + communicates it to ranks appropriately. Hence, the former has + higher aggregate CPU memory cost, while the latter has higher + communication cost. + + Args: + full_optim_state_dict (Optional[Dict[str, Any]]): Optimizer state + dict corresponding to the unflattened parameters and holding + the full non-sharded optimizer state if on rank 0; the argument + is ignored on nonzero ranks. + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + correspond to the optimizer state in ``full_optim_state_dict``. + optim_input (Optional[Union[List[Dict[str, Any]], Iterable[torch.nn.Parameter]]]): + Input passed into the optimizer representing either a + :class:`list` of parameter groups or an iterable of parameters; + if ``None``, then this method assumes the input was + ``model.parameters()``. This argument is deprecated, and there + is no need to pass it in anymore. (Default: ``None``) + optim (Optional[torch.optim.Optimizer]): Optimizer that will load + the state dict returned by this method. This is the preferred + argument to use over ``optim_input``. (Default: ``None``) + group (dist.ProcessGroup): Model's process group or ``None`` if + using the default process group. (Default: ``None``) + + Returns: + Dict[str, Any]: The full optimizer state dict now remapped to + flattened parameters instead of unflattened parameters and + restricted to only include this rank's part of the optimizer state. + """ + FullyShardedDataParallel._warn_legacy_optim_state_dict( + "scatter_full_optim_state_dict", + "optim_state_dict_to_load", + stacklevel=2, + ) + return FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=full_optim_state_dict, + model=model, + optim_input=optim_input, + optim=optim, + full_state_dict=True, + rank0_only=True, + is_named_optimizer=False, + group=group, + ) + + @staticmethod + def rekey_optim_state_dict( + optim_state_dict: dict[str, Any], + optim_state_key_type: OptimStateKeyType, + model: torch.nn.Module, + optim_input: Optional[ + Union[ + list[dict[str, Any]], + Iterable[torch.nn.Parameter], + ] + ] = None, + optim: Optional[torch.optim.Optimizer] = None, + ) -> dict[str, Any]: + """Re-keys the optimizer state dict ``optim_state_dict`` to use the key type ``optim_state_key_type``. + + This can be used to achieve compatibility between optimizer state dicts from models with FSDP + instances and ones without. + + To re-key an FSDP full optimizer state dict (i.e. from + :meth:`full_optim_state_dict`) to use parameter IDs and be loadable to + a non-wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> wrapped_model, wrapped_optim = ... + >>> full_osd = FSDP.full_optim_state_dict(wrapped_model, wrapped_optim) + >>> nonwrapped_model, nonwrapped_optim = ... + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(full_osd, OptimStateKeyType.PARAM_ID, nonwrapped_model) + >>> nonwrapped_optim.load_state_dict(rekeyed_osd) + + To re-key a normal optimizer state dict from a non-wrapped model to be + loadable to a wrapped model:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> nonwrapped_model, nonwrapped_optim = ... + >>> osd = nonwrapped_optim.state_dict() + >>> rekeyed_osd = FSDP.rekey_optim_state_dict(osd, OptimStateKeyType.PARAM_NAME, nonwrapped_model) + >>> wrapped_model, wrapped_optim = ... + >>> sharded_osd = FSDP.shard_full_optim_state_dict(rekeyed_osd, wrapped_model) + >>> wrapped_optim.load_state_dict(sharded_osd) + + Returns: + Dict[str, Any]: The optimizer state dict re-keyed using the + parameter keys specified by ``optim_state_key_type``. + """ + FullyShardedDataParallel._warn_optim_input(optim_input) + using_optim_input = FullyShardedDataParallel._is_using_optim_input( + optim_input, + optim, + ) + assert optim_state_key_type in ( + OptimStateKeyType.PARAM_NAME, + OptimStateKeyType.PARAM_ID, + ) + osd = optim_state_dict # alias + # Validate that the existing parameter keys are uniformly typed + uses_param_name_mask = [type(param_key) is str for param_key in osd["state"]] + uses_param_id_mask = [type(param_key) is int for param_key in osd["state"]] + if (any(uses_param_name_mask) and not all(uses_param_name_mask)) or ( + any(uses_param_id_mask) and not all(uses_param_id_mask) + ): + error_msg = f"Invalid parameter keys: {osd['state'].keys()}" + raise ValueError(error_msg) + # Return directly if the existing key type matches the target key type + if ( + optim_state_key_type == OptimStateKeyType.PARAM_NAME + and all(uses_param_name_mask) + ) or ( + optim_state_key_type == OptimStateKeyType.PARAM_ID + and all(uses_param_id_mask) + ): + return osd + # Otherwise, actually perform the re-keying + new_osd = {} + if optim_state_key_type == OptimStateKeyType.PARAM_NAME: # ID -> name + param_id_to_param = ( + _get_param_id_to_param_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_key_to_param(optim) + ) + param_to_param_name = _get_param_to_fqn(model) + param_id_to_param_name: list[str] = [ + param_to_param_name[param] for param in param_id_to_param.values() + ] + new_osd["state"] = { + param_id_to_param_name[param_id]: param_state + for param_id, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted( + [ + param_id_to_param_name[param_id] + for param_id in param_group["params"] + ] + ) + return new_osd + elif optim_state_key_type == OptimStateKeyType.PARAM_ID: # name -> ID + param_name_to_param = _get_fqn_to_param(model) + param_to_param_id = ( + _get_param_to_param_id_from_optim_input(model, optim_input) + if using_optim_input + else _get_param_to_param_key(optim) + ) + # Because not all model parameters may be passed as the optimizer + # input, we may need to drop some parameters from this mapping + param_name_to_param_id = { + param_name: param_to_param_id[param] + for param_name, param in param_name_to_param.items() + if param in param_to_param_id + } + new_osd["state"] = { + param_name_to_param_id[param_name]: param_state + for param_name, param_state in osd["state"].items() + } + new_osd["param_groups"] = copy.deepcopy(osd["param_groups"]) + for param_group in new_osd["param_groups"]: + param_group["params"] = sorted( + [ + param_name_to_param_id[param_name] + for param_name in param_group["params"] + ] + ) + return new_osd + return new_osd # should never reach here + + @staticmethod + def optim_state_dict( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: Optional[dict[str, Any]] = None, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Transform the state-dict of an optimizer corresponding to a sharded model. + + The given state-dict can be transformed to one of three types: + 1) full optimizer state_dict, 2) sharded optimizer state_dict, 3) local optimizer state_dict. + + For full optimizer state_dict, all states are unflattened and not sharded. + Rank0 only and CPU only can be specified via :meth:`state_dict_type` to + avoid OOM. + + For sharded optimizer state_dict, all states are unflattened but sharded. + CPU only can be specified via :meth:`state_dict_type` to further save + memory. + + For local state_dict, no transformation will be performed. But a state + will be converted from nn.Tensor to ShardedTensor to represent its sharding + nature (this is not supported yet). + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> from torch.distributed.fsdp import FullStateDictConfig + >>> from torch.distributed.fsdp import FullOptimStateDictConfig + >>> # Save a checkpoint + >>> model, optim = ... + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> state_dict = model.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict(model, optim) + >>> save_a_checkpoint(state_dict, optim_state_dict) + >>> # Load a checkpoint + >>> model, optim = ... + >>> state_dict, optim_state_dict = load_a_checkpoint() + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> model.load_state_dict(state_dict) + >>> optim_state_dict = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state_dict + >>> ) + >>> optim.load_state_dict(optim_state_dict) + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_state_dict (Dict[str, Any]): the target optimizer state_dict to + transform. If the value is None, optim.state_dict() will be used. ( + Default: ``None``) + group (dist.ProcessGroup): Model's process group across which parameters + are sharded or ``None`` if using the default process group. ( + Default: ``None``) + + Returns: + Dict[str, Any]: A :class:`dict` containing the optimizer state for + ``model``. The sharding of the optimizer state is based on + ``state_dict_type``. + """ + state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model) + if optim_state_dict is None: + optim_state_dict = optim.state_dict() + return FullyShardedDataParallel._optim_state_dict_impl( + model=model, + optim=optim, + optim_state_dict=optim_state_dict, + optim_input=None, + rank0_only=getattr( + state_dict_settings.optim_state_dict_config, "rank0_only", False + ), + full_state_dict=state_dict_settings.state_dict_type + == StateDictType.FULL_STATE_DICT, + group=group, + cpu_offload=getattr( + state_dict_settings.optim_state_dict_config, "offload_to_cpu", True + ), + _stacklevel=2, + ) + + @staticmethod + def optim_state_dict_to_load( + model: torch.nn.Module, + optim: torch.optim.Optimizer, + optim_state_dict: dict[str, Any], + is_named_optimizer: bool = False, + load_directly: bool = False, + group: Optional[dist.ProcessGroup] = None, + ) -> dict[str, Any]: + """ + Convert an optimizer state-dict so that it can be loaded into the optimizer associated with the FSDP model. + + Given a ``optim_state_dict`` that is transformed through + :meth:`optim_state_dict`, it gets converted to the flattened optimizer + state_dict that can be loaded to ``optim`` which is the optimizer for + ``model``. ``model`` must be sharded by FullyShardedDataParallel. + + >>> # xdoctest: +SKIP("undefined variables") + >>> from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + >>> from torch.distributed.fsdp import StateDictType + >>> from torch.distributed.fsdp import FullStateDictConfig + >>> from torch.distributed.fsdp import FullOptimStateDictConfig + >>> # Save a checkpoint + >>> model, optim = ... + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> state_dict = model.state_dict() + >>> original_osd = optim.state_dict() + >>> optim_state_dict = FSDP.optim_state_dict( + >>> model, + >>> optim, + >>> optim_state_dict=original_osd + >>> ) + >>> save_a_checkpoint(state_dict, optim_state_dict) + >>> # Load a checkpoint + >>> model, optim = ... + >>> state_dict, optim_state_dict = load_a_checkpoint() + >>> FSDP.set_state_dict_type( + >>> model, + >>> StateDictType.FULL_STATE_DICT, + >>> FullStateDictConfig(rank0_only=False), + >>> FullOptimStateDictConfig(rank0_only=False), + >>> ) + >>> model.load_state_dict(state_dict) + >>> optim_state_dict = FSDP.optim_state_dict_to_load( + >>> model, optim, optim_state_dict + >>> ) + >>> optim.load_state_dict(optim_state_dict) + + Args: + model (torch.nn.Module): Root module (which may or may not be a + :class:`FullyShardedDataParallel` instance) whose parameters + were passed into the optimizer ``optim``. + optim (torch.optim.Optimizer): Optimizer for ``model`` 's + parameters. + optim_state_dict (Dict[str, Any]): The optimizer states to be loaded. + is_named_optimizer (bool): Is this optimizer a NamedOptimizer or + KeyedOptimizer. Only set to True if ``optim`` is TorchRec's + KeyedOptimizer or torch.distributed's NamedOptimizer. + load_directly (bool): If this is set to True, this API will also + call optim.load_state_dict(result) before returning the result. + Otherwise, users are responsible to call ``optim.load_state_dict()`` + (Default: ``False``) + group (dist.ProcessGroup): Model's process group across which parameters + are sharded or ``None`` if using the default process group. ( + Default: ``None``) + """ + state_dict_settings = FullyShardedDataParallel.get_state_dict_type(model) + result = FullyShardedDataParallel._optim_state_dict_to_load_impl( + optim_state_dict=optim_state_dict, + model=model, + optim_input=None, + optim=optim, + full_state_dict=( + state_dict_settings.state_dict_type == StateDictType.FULL_STATE_DICT + ), + rank0_only=getattr( + state_dict_settings.optim_state_dict_config, "rank0_only", False + ), + is_named_optimizer=is_named_optimizer, + group=group, + ) + if load_directly: + optim.load_state_dict(result) + return result + + def register_comm_hook(self, state: object, hook: callable): + """Register a communication hook. + + This is an enhancement that provides a flexible hook to users where they can specify how FSDP aggregates + gradients across multiple workers. + This hook can be used to implement several algorithms like + `GossipGrad `_ and gradient compression + which involve different communication strategies for + parameter syncs while training with :class:`FullyShardedDataParallel`. + + .. warning :: + FSDP communication hook should be registered before running an initial forward pass + and only once. + + Args: + state (object): Passed to the hook to maintain any state information during the training process. + Examples include error feedback in gradient compression, + peers to communicate with next in `GossipGrad `_, etc. + It is locally stored by each worker + and shared by all the gradient tensors on the worker. + hook (Callable): Callable, which has one of the following signatures: + 1) ``hook: Callable[torch.Tensor] -> None``: + This function takes in a Python tensor, which represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). + It then performs all necessary processing and returns ``None``; + 2) ``hook: Callable[torch.Tensor, torch.Tensor] -> None``: + This function takes in two Python tensors, the first one represents + the full, flattened, unsharded gradient with respect to all variables + corresponding to the model this FSDP unit is wrapping + (that are not wrapped by other FSDP sub-units). The latter + represents a pre-sized tensor to store a chunk of a sharded gradient after + reduction. + In both cases, callable performs all necessary processing and returns ``None``. + Callables with signature 1 are expected to handle gradient communication for a `NO_SHARD` case. + Callables with signature 2 are expected to handle gradient communication for sharded cases. + + """ + if not self.check_is_root(): + raise AssertionError( + "register_comm_hook can only be called on a root instance." + ) + for fsdp_state in traversal_utils._get_fsdp_states(self): + if fsdp_state.sharding_strategy in HYBRID_SHARDING_STRATEGIES: + raise AssertionError( + f"Communication hook is not supported for hybrid strategies: {fsdp_state.sharding_strategy}" + ) + if fsdp_state._comm_hook is not None: + raise AssertionError("A communication hook is already registered") + if not callable(hook): + raise ValueError( + f"The communication hook must be callable but got {hook}" + ) + fsdp_state._comm_hook = hook + fsdp_state._comm_hook_state = state + + def _unshard(self, async_op: bool = False): + class UnshardHandle: + def __init__( + self, + flat_param_handle: Optional[FlatParamHandle], + unshard_event: torch.Event, + ): + self._flat_param_handle = flat_param_handle + self._unshard_event = unshard_event + + def wait(self): + if self._flat_param_handle is not None: + current_stream = ( + self._flat_param_handle._device_handle.current_stream() + ) + current_stream.wait_event(self._unshard_event) + self._flat_param_handle = None + + if self._handle: + with self._use_training_state( + TrainingState.FORWARD_BACKWARD, HandleTrainingState.FORWARD + ): + _unshard( + self, self._handle, self._unshard_stream, self._pre_unshard_stream + ) + self._unshard_event = self._unshard_stream.record_event() + self._handle._prefetched = True + unshard_handle = UnshardHandle(self._handle, self._unshard_stream) + if async_op: + return unshard_handle + unshard_handle.wait() + return None + + def _wait_unshard_streams_on_current_stream(self): + _wait_for_computation_stream( + self._device_handle.current_stream(), + self._unshard_stream, + self._pre_unshard_stream, + ) + + @contextlib.contextmanager + def _use_training_state( + self, training_state: TrainingState, handle_training_state: HandleTrainingState + ): + prev_training_state = self.training_state + self.training_state = training_state + if self._handle: + prev_handle_training_state = self._handle._training_state + self._handle._training_state = handle_training_state + try: + yield + finally: + self.training_state = prev_training_state + if self._handle: + self._handle._training_state = prev_handle_training_state + + +def _get_grad_norm( + params: Iterable[nn.Parameter], + norm_type: float, + zero: torch.Tensor, + device: torch.device, +) -> torch.Tensor: + """ + Return the gradient norm of parameters ``param`` s, where the gradients are viewed as a single vector. + + The returned norm is in FP32 even if parameters/gradients are in a low precision. This is because the downstream + use of this return value is a reduction across ranks. + """ + params_with_grad = [param for param in params if param.grad is not None] + if len(params_with_grad) == 0: + # Reuse a tensor for zero to avoid a GPU sync + return zero + grads = [param.grad for param in params_with_grad] + grad_dtypes = {grad.dtype for grad in grads} + if len(grad_dtypes) != 1: + raise ValueError( + f"Requires uniform dtype across all gradients but got {grad_dtypes}" + ) + # Compute the gradient norm in FP32, where we treat the gradients as a + # single vector + grad_norm = torch.linalg.vector_norm( + torch.stack( + [ + torch.linalg.vector_norm(grad.detach(), norm_type, dtype=torch.float32) + for grad in grads + ], + ), + norm_type, + dtype=torch.float32, + ) + return grad_norm.to(device=device) + + +def _get_param_to_fqn( + model: torch.nn.Module, +) -> dict[torch.nn.Parameter, str]: + """ + Construct a mapping from parameters to their parameter names. + + The ``model`` should not contain any :class:`FullyShardedDataParallel` instances, which + means that none of the parameters should be ``FlatParameter`` s. As a + result, compared to :meth:`_get_param_to_fqns`, the mapped + values may be flattened from singleton :class:`list` s to the contained + names themselves. + + Args: + model (torch.nn.Module): Root module, which should not contain any + :class:`FullyShardedDataParallel` instances. + """ + param_to_param_names = _get_param_to_fqns(model) + for param_names in param_to_param_names.values(): + assert len(param_names) > 0, ( + "`_get_param_to_fqns()` should not construct empty lists" + ) + if len(param_names) > 1: + raise RuntimeError( + "Each parameter should only map to one parameter name but got " + f"{len(param_names)}: {param_names}" + ) + param_to_param_name = { + param: param_names[0] for param, param_names in param_to_param_names.items() + } + return param_to_param_name + + +def _get_fqn_to_param( + model: torch.nn.Module, +) -> dict[str, torch.nn.Parameter]: + """Construct the inverse mapping of :meth:`_get_param_to_fqn`.""" + param_to_param_name = _get_param_to_fqn(model) + return dict(zip(param_to_param_name.values(), param_to_param_name.keys())) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..964be6efc600ef2ae4c37ef3b485e3e81340ed66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/sharded_grad_scaler.py @@ -0,0 +1,359 @@ +# mypy: allow-untyped-defs +import logging +from collections import abc, defaultdict +from collections.abc import Iterable +from typing import Any, Optional, overload, Union + +import torch +import torch.distributed as dist +from torch.amp.grad_scaler import _MultiDeviceReplicator, GradScaler, OptState +from torch.distributed.distributed_c10d import ProcessGroup + + +logger = logging.getLogger(__name__) + + +def _refresh_per_optimizer_state() -> dict[str, Any]: + return {"stage": OptState.READY, "found_inf_per_device": {}} + + +def _is_supported_device(tensor: torch.Tensor) -> bool: + return tensor.is_cuda or tensor.device.type in ( + "xla", + "cpu", + "hpu", + "mtia", + "xpu", + torch._C._get_privateuse1_backend_name(), + ) + + +class _GeneralMultiDeviceReplicator(_MultiDeviceReplicator): + """ + Lazily serves tensor to request device. This class extends + _MultiDeviceReplicator to allow support for "cpu" as a device. + """ + + def __init__(self, master_tensor: torch.Tensor) -> None: + assert _is_supported_device(master_tensor) + self.master = master_tensor + self._per_device_tensors: dict[torch.device, torch.Tensor] = {} + + +class ShardedGradScaler(GradScaler): + """ + ShardedGradScaler helps perform gradient scaling in a shard aware manner. It extends + functionality from GradScaler: + * Supports Pytorch DDP and FSDP implementations + * Support CPU offloaded tensors (as used in fully sharded data parallel[FSDP]) + * Supports the custom Mixed Precision loss dtype (fp16, bf16) that FSDP returns + * Sync inf/nan for scaled gradient tensors on any torch.device (where tensors are placed) across + nodes + + Example:: + + # Creates a ShardedGradScaler once at the beginning of training. + scaler = ShardedGradScaler() + + for epoch in epochs: + for input, target in data: + optimizer.zero_grad() + output = model(input) + loss = loss_fn(output, target) + + # Scales loss. Calls backward() on scaled loss to create scaled gradients. + scaler.scale(loss).backward() + + # scaler.step() first unscales gradients of the optimizer's params. + # If gradients don't contain infs/NaNs, optimizer.step() is then called, + # otherwise, optimizer.step() is skipped. + scaler.step(optimizer) + + # Updates the scale for next iteration. + scaler.update() + + See :class:`GradScaler` for explanation of scaling/unscaling and more use cases. + + Args: + init_scale (float, optional, default=2.**16): Initial scale factor. + growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during + :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations. + backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during + :meth:`update` if inf/NaN gradients occur in an iteration. + growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients + that must occur for the scale to be multiplied by ``growth_factor``. + enabled (bool, optional): If ``False``, disables gradient scaling. :meth:`step` simply + invokes the underlying ``optimizer.step()``, and other methods become no-ops. + Default: ``True`` + process_group (ProcessGroup, optional, default=torch.distributed.group.WORLD): + process group for sharding + """ + + def __init__( + self, + device: str = "cuda", + init_scale: float = 2.0**16, + backoff_factor: float = 0.5, + growth_factor: float = 2.0, + growth_interval: int = 2000, + enabled: bool = True, + process_group: Optional[ProcessGroup] = dist.group.WORLD, + ) -> None: + super().__init__( + device, + init_scale=init_scale, + backoff_factor=backoff_factor, + growth_factor=growth_factor, + growth_interval=growth_interval, + enabled=enabled, + ) + if self._enabled: + self.process_group = process_group + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) + + @overload + def scale(self, outputs: torch.Tensor) -> torch.Tensor: ... + + @overload + def scale(self, outputs: list[torch.Tensor]) -> list[torch.Tensor]: ... + + @overload + def scale(self, outputs: tuple[torch.Tensor, ...]) -> tuple[torch.Tensor, ...]: ... + + @overload + def scale(self, outputs: Iterable[torch.Tensor]) -> Iterable[torch.Tensor]: ... + + def scale( + self, outputs: Union[torch.Tensor, Iterable[torch.Tensor]] + ) -> Union[torch.Tensor, Iterable[torch.Tensor]]: + if not self._enabled: + return outputs + + if isinstance(outputs, torch.Tensor): + assert _is_supported_device(outputs) + if self._scale is None: + self._lazy_init_scale_growth_tracker(outputs.device) + assert self._scale is not None + scaled_output = outputs * self._scale.to( + device=outputs.device, non_blocking=True + ) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_output.type(outputs.dtype) + + stash: list[_GeneralMultiDeviceReplicator] = [] + + def apply_scale(val: Union[torch.Tensor, Iterable[torch.Tensor]]): + if isinstance(val, torch.Tensor): + assert _is_supported_device(val) + if len(stash) == 0: + if self._scale is None: + self._lazy_init_scale_growth_tracker(val.device) + assert self._scale is not None + stash.append(_GeneralMultiDeviceReplicator(self._scale)) + scaled_val = val * stash[0].get(val.device) + # Here we ensure the return dtype is the same as the outputs dtype. + # For the FSDP + Mixed Precision use case, the loss output is in the Mixed Precision + # format (fp16, bf16) and so the scaled loss should be of the same dtype. + return scaled_val.type(val.dtype) + if isinstance(val, abc.Iterable): + iterator = map(apply_scale, val) + if isinstance(val, (list, tuple)): + return type(val)(iterator) + return iterator + raise ValueError("outputs must be a Tensor or an iterable of Tensors") + + return apply_scale(outputs) + + def _unscale_grads_( + self, + optimizer: torch.optim.Optimizer, + inv_scale: torch.Tensor, + found_inf: torch.Tensor, + allow_fp16: bool = True, + ) -> dict[torch.device, torch.Tensor]: + per_device_inv_scale = _GeneralMultiDeviceReplicator(inv_scale) + per_device_found_inf = _GeneralMultiDeviceReplicator(found_inf) + + # To set up _amp_foreach_non_finite_check_and_unscale_, split grads by device and dtype. + # There could be thousands of grads, so we'd like to iterate through them just once. + # However, we don't know their devices or dtypes in advance. + + # https://stackoverflow.com/questions/5029934/defaultdict-of-defaultdict + # Google says mypy struggles with defaultdicts type annotations. + per_device_and_dtype_grads = defaultdict(lambda: defaultdict(list)) # type: ignore[var-annotated] + with torch.no_grad(): + for group in optimizer.param_groups: + for param in group["params"]: + if param.grad is None: + continue + if (not allow_fp16) and param.grad.dtype == torch.float16: + raise ValueError("Attempting to unscale FP16 gradients.") + if param.grad.is_sparse: + # is_coalesced() == False means the sparse grad has values with duplicate indices. + # coalesce() deduplicates indices and adds all values that have the same index. + # For scaled fp16 values, there's a good chance coalescing will cause overflow, + # so we should check the coalesced _values(). + if param.grad.dtype is torch.float16: + # coalesce is not supported in torch.float16 + param_grad_fp32 = param.grad.type(torch.float32).coalesce() + param.grad = param_grad_fp32.type(torch.float16) + to_unscale = param.grad._values() + else: + to_unscale = param.grad + + per_device_and_dtype_grads[to_unscale.device][ + to_unscale.dtype + ].append(to_unscale) + + for device, per_dtype_grads in per_device_and_dtype_grads.items(): + for grads in per_dtype_grads.values(): + torch._amp_foreach_non_finite_check_and_unscale_( + grads, + per_device_found_inf.get(device), + per_device_inv_scale.get(device), + ) + # There exist contexts (e.g. w/ `use_orig_params=True`) wherein some + # ranks may have no (non-zero sized) parameter shards, necessitating the + # initialization of `per_device_found_inf._per_device_tensors` here + if not per_device_found_inf._per_device_tensors: + assert self._scale is not None + per_device_found_inf.get(self._scale.device) + return per_device_found_inf._per_device_tensors + + def unscale_(self, optimizer: torch.optim.Optimizer) -> None: + if not self._enabled: + return + + self._check_scale_growth_tracker("unscale_") + + optimizer_state = self._per_optimizer_states[id(optimizer)] + + if optimizer_state["stage"] is OptState.UNSCALED: + raise RuntimeError( + "unscale_() has already been called on this optimizer since the last update()." + ) + elif optimizer_state["stage"] is OptState.STEPPED: + raise RuntimeError("unscale_() is being called after step().") + + # FP32 division can be imprecise for certain compile options, so we carry out the reciprocal in FP64. + assert self._scale is not None + inv_scale = self._scale.double().reciprocal().float() + found_inf = torch.full( + (1,), 0.0, dtype=torch.float32, device=self._scale.device + ) + + optimizer_state["found_inf_per_device"] = self._unscale_grads_( + optimizer, inv_scale, found_inf, True + ) + optimizer_state["stage"] = OptState.UNSCALED + + # Synchronize the detected inf across the ranks + optimizer_state = self._per_optimizer_states[id(optimizer)] + works = [] + found_inf_on_cpus = [] + found_inf_on_devices = [] + + for found_inf in optimizer_state["found_inf_per_device"].values(): + if self._device != "cpu" and found_inf.device.type == "cpu": + found_inf_on_cpus.append(found_inf) + found_inf_on_device = found_inf.to(self._device) + found_inf_on_devices.append(found_inf_on_device) + works.append( + dist.all_reduce( + found_inf_on_device, async_op=True, group=self.process_group + ) + ) + else: + works.append( + dist.all_reduce(found_inf, async_op=True, group=self.process_group) + ) + for work in works: + work.wait() + if found_inf_on_cpus: + torch._foreach_copy_(found_inf_on_cpus, found_inf_on_devices) + + def _amp_update_scale_cpu_(self, found_inf: torch.Tensor) -> None: + """ + If found_inf is 1.0 (True), then scale is multiplied by backoff_factor and growth_tracker is set to zero. + Otherwise, scale is multiplied by the growth factor when the growth interval is reached. + """ + assert self._scale is not None and self._growth_tracker is not None + + if found_inf.item() >= 1.0: + self._scale *= self._backoff_factor + self._growth_tracker.fill_(0) + else: + successful = self._growth_tracker + 1 + if successful == self._growth_interval: + self._scale *= self._growth_factor + self._growth_tracker.fill_(0) + else: + self._growth_tracker = successful + + def update(self, new_scale: Optional[Union[float, torch.Tensor]] = None) -> None: + """ + Updates the scale factor. + If any optimizer steps were skipped the scale is multiplied by ``backoff_factor`` + to reduce it. If ``growth_interval`` unskipped iterations occurred consecutively, + the scale is multiplied by ``growth_factor`` to increase it. + Passing ``new_scale`` sets the new scale value manually. (``new_scale`` is not + used directly, it's used to fill GradScaler's internal scale tensor. So if + ``new_scale`` was a tensor, later in-place changes to that tensor will not further + affect the scale GradScaler uses internally.) + Args: + new_scale (float or :class:`torch.Tensor`, optional, default=None): New scale factor. + .. warning:: + :meth:`update` should only be called at the end of the iteration, after ``scaler.step(optimizer)`` has + been invoked for all optimizers used this iteration. + """ + + if not self._enabled: + return + + _scale, _growth_tracker = self._check_scale_growth_tracker("update") # type: ignore[var-annotated] + + if new_scale is not None: + # Accept a new user-defined scale. + if isinstance(new_scale, float): + self._scale.fill_(new_scale) # type: ignore[union-attr] + else: + reason = ( + "new_scale should be a float or a 1-element torch.cuda.FloatTensor or \ + torch.FloatTensor with requires_grad=False." + ) + assert new_scale.device.type == self._device, reason + assert new_scale.numel() == 1, reason + assert new_scale.requires_grad is False, reason + self._scale.copy_(new_scale) # type: ignore[union-attr] + else: + # Consume shared inf/nan data collected from optimizers to update the scale. + # If all found_inf tensors are on the same device as self._scale, this operation is asynchronous. + found_infs = [ + found_inf.to(device=_scale.device, non_blocking=True) + for state in self._per_optimizer_states.values() + for found_inf in state["found_inf_per_device"].values() + ] + + assert len(found_infs) > 0, "No inf checks were recorded prior to update." + + found_inf_combined = found_infs[0] + if len(found_infs) > 1: + for i in range(1, len(found_infs)): + found_inf_combined += found_infs[i] + + if _scale.device.type == "cpu": + self._amp_update_scale_cpu_(found_inf_combined) + else: + torch._amp_update_scale_( + self._scale, # type: ignore[arg-type] + self._growth_tracker, # type: ignore[arg-type] + found_inf_combined, + self._growth_factor, # type: ignore[arg-type] + self._backoff_factor, # type: ignore[arg-type] + self._growth_interval, # type: ignore[arg-type] + ) + + # To prepare for next iteration, clear the data collected from optimizers this iteration. + self._per_optimizer_states = defaultdict(_refresh_per_optimizer_state) diff --git a/phivenv/Lib/site-packages/torch/distributed/fsdp/wrap.py b/phivenv/Lib/site-packages/torch/distributed/fsdp/wrap.py new file mode 100644 index 0000000000000000000000000000000000000000..ec9027686d4862c00a2e0a2944b9f6d4a4725b1a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/fsdp/wrap.py @@ -0,0 +1,596 @@ +# mypy: allow-untyped-defs +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import contextlib +import copy +from abc import ABC, abstractmethod +from collections.abc import Generator, Iterable, Sequence +from typing import Any, Callable, cast, Optional, Union + +import torch.nn as nn + + +__all__ = [ + "always_wrap_policy", + "lambda_auto_wrap_policy", + "transformer_auto_wrap_policy", + "size_based_auto_wrap_policy", + "enable_wrap", + "wrap", + "CustomPolicy", + "ModuleWrapPolicy", +] + + +# NOTE: We intentionally keep this function simple and isolate the complexity +# to `fn` to enable using this function generically. We may move this to a +# non-FSDP-specific folder and/or make it public in the future. +def _post_order_apply( + root_module: nn.Module, + fn: Callable[[nn.Module], Optional[nn.Module]], +): + """ + This applies ``fn`` to every module in the module tree of ``root_module`` + following a post-order traversal. If ``fn`` returns an :class:`nn.Module`, + then this replaces the original module with the newly returned one in the + tree. Otherwise, ``fn`` should return ``None``, in which case the module is + not changed. + """ + # Track visited modules to avoid visiting shared modules multiple times + visited_modules: set[nn.Module] = {root_module} + + def _post_order_apply_inner( + module: nn.Module, + module_name: str, + parent_module: Optional[nn.Module], + ): + for child_module_name, child_module in module.named_children(): + if child_module not in visited_modules: + visited_modules.add(child_module) + _post_order_apply_inner(child_module, child_module_name, module) + optional_module = fn(module) + if optional_module is not None: + assert isinstance(parent_module, nn.Module), ( + "Non-root modules should have their parent module set but got " + f"{parent_module} for {module}" + ) + assert module_name, ( + "Non-root modules should have their module name set but got " + f"an empty module name for {module}" + ) + assert isinstance(optional_module, nn.Module), ( + f"fn should return None or an nn.Module but got {optional_module}" + ) + setattr(parent_module, module_name, optional_module) + + _post_order_apply_inner(root_module, "", None) + + +def _construct_wrap_fn( + root_module: nn.Module, + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], + fsdp_fn: Callable, +) -> Callable[[nn.Module], Optional[nn.Module]]: + """ + This constructs the "wrap" function to pass to :func:`_post_order_apply` + based on ``target_module_to_kwargs``, which should be constructed from the + wrapping policy. + """ + + def fn(module: nn.Module) -> Optional[nn.Module]: + # Explicitly avoid wrapping the root module since for FSDP, it is + # handled by the caller + if module in target_module_to_kwargs and module is not root_module: + kwargs = target_module_to_kwargs[module] + return fsdp_fn(module, **kwargs) + return None + + return fn + + +def _run_mixed_precision_override_policy( + root_module: nn.Module, + module_classes: Iterable[type[nn.Module]], + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + target_module_to_kwargs: dict[nn.Module, dict[str, Any]], +): + module_classes_tuple = tuple(set(module_classes)) + for module in root_module.modules(): + if module in ignored_modules: + continue + elif isinstance(module, module_classes_tuple): + # This policy overrides any existing policy + if module not in target_module_to_kwargs: + # Only inherit from the root kwargs if not already specified + target_module_to_kwargs[module] = root_kwargs + target_module_to_kwargs[module]["mixed_precision"] = None + return target_module_to_kwargs + + +def always_wrap_policy(*args, **kwargs) -> bool: + """ + A simple recursive wrap policy that always returns ``True``. This means + that every submodule is wrapped by the wrapper class in + :func:`_recursive_wrap`. + """ + return True + + +class _Policy(ABC): + """ + This defines an abstract base class that represents a policy for applying + a module-level API. + """ + + @abstractmethod + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + """ + This should return a dict ``target_module_to_kwargs`` that maps from + each target module to wrap to its kwargs. + """ + ... + + +def _module_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + module_classes: set[type[nn.Module]], +) -> bool: + """ + This auto wrap policy wraps every module that is an instance of any type in + ``module_classes`` as its own FSDP instance. The root module given by + ``module`` is always wrapped as an FSDP instance regardless. Since the + wrapping proceeds bottom up, each FSDP instance manages the parameters in + its subtree excluding any already managed by a child FSDP instance. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + module_classes (Set[Type[nn.Module]]): Set of module classes that are + wrapped as FSDP instances. + + Returns: + ``True`` if ``recurse=True``, and whether ``module`` should be wrapped + if ``recurse=False``. + """ + if recurse: + return True # always recurse + return isinstance(module, tuple(module_classes)) + + +class ModuleWrapPolicy(_Policy): + """ + This policy applies to every module of the specified module classes, + passing in the kwargs given to the root. + """ + + def __init__(self, module_classes: Iterable[type[nn.Module]]): + module_classes_set = set(module_classes) + self._module_classes = module_classes_set + self._module_classes_str = str(module_classes_set) + + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + module_classes = tuple(self._module_classes) + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} + for module in root_module.modules(): + if module in ignored_modules: + continue + elif isinstance(module, module_classes): + # Shallow copy to avoid coupling changes across modules + target_module_to_kwargs[module] = copy.copy(root_kwargs) + return target_module_to_kwargs + + def __call__(self, module, recurse, *args, **kwargs): + # nonwrapped_numel is not used. + return _module_wrap_policy( + module, recurse, nonwrapped_numel=-1, module_classes=self._module_classes + ) + + def __repr__(self) -> str: + return super().__repr__() + f"({self._module_classes_str})" + + +class CustomPolicy(_Policy): + """ + This policy takes in a lambda function that maps a given ``nn.Module`` to + either ``False``, ``True``, or a kwarg dictionary. + - If the function returns ``False`` or an empty dictionary, then the module + does not have the API applied. + - If the function returns ``True``, then the module has the API applied + with the root's kwargs. + - If the function returns a non-empty dictionary, then the module has the + API applied, and the dictionary overrides the root's kwargs. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> model = init_transformer_model(...) + >>> def lambda_fn(module: nn.Module): + >>> if module is model.lm_head: + >>> return {"sharding_strategy": ShardingStrategy.SHARD_GRAD_OP} + >>> elif isinstance(module, TransformerBlock): + >>> return True + >>> return False + >>> policy = CustomPolicy(lambda_fn) + >>> fsdp_model = FSDP(model, auto_wrap_policy=policy) + """ + + def __init__(self, lambda_fn: Callable[[nn.Module], Union[bool, dict[str, Any]]]): + self._lambda_fn = lambda_fn + + def _run_policy( + self, + root_module: nn.Module, + ignored_modules: set[nn.Module], + root_kwargs: dict[str, Any], + ) -> dict[nn.Module, dict[str, Any]]: + target_module_to_kwargs: dict[nn.Module, dict[str, Any]] = {} + for module in root_module.modules(): + if module in ignored_modules: + continue + res = self._lambda_fn(module) + if not isinstance(res, (dict, bool)): + raise ValueError( + "The lambda_fn passed to CustomPolicy should return " + f"False/True or a kwarg dict, but it returned {res}" + ) + if not res: + continue + kwargs = copy.copy(root_kwargs) + if isinstance(res, dict): + # Override the root kwargs with the ones specified by the + # lambda function + kwargs.update(res) + target_module_to_kwargs[module] = kwargs + return target_module_to_kwargs + + +def lambda_auto_wrap_policy( + module: nn.Module, recurse: bool, nonwrapped_numel: int, lambda_fn: Callable +) -> bool: + """ + A convenient auto wrap policy to wrap submodules based on an arbitrary user + function. If `lambda_fn(submodule) == True``, the submodule will be wrapped as + a `wrapper_cls` unit. + + Return if a module should be wrapped during auto wrapping. + + The first three parameters are required by :func:`_recursive_wrap`. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + lambda_fn (Callable[[nn.Module], bool]): If this returns ``True``, then + this module will be wrapped. + """ + if recurse: + return True # always recurse + return lambda_fn(module) + + +def transformer_auto_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + transformer_layer_cls: set[type[nn.Module]], +) -> bool: + """ + See :func:`_module_wrap_policy`, where ``transformer_layer_cls`` is the + same as ``module_classes``. Note that shared parameters must be wrapped in + the same FSDP instance, so this auto wrap policy can help wrap shared + embeddings into the same FSDP instance for transformer models. + """ + return _module_wrap_policy(module, recurse, nonwrapped_numel, transformer_layer_cls) + + +def _wrap_module_cls_individually( + module: nn.Module, module_classes: Sequence[type], recurse: bool, *args, **kwargs +): + if recurse: + # always recurse + return True + else: + # if not recursing, decide whether we should wrap based on whether the type of module + # is in `module_classes`. + return isinstance(module, tuple(module_classes)) + + +def _or_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + policies, +) -> bool: + """ + A policy that wraps ``module`` if any policy in the passed in iterable of + ``policies`` returns ``True``. + """ + return any( + policy(module=module, recurse=recurse, nonwrapped_numel=nonwrapped_numel) + for policy in policies + ) + + +def size_based_auto_wrap_policy( + module: nn.Module, + recurse: bool, + nonwrapped_numel: int, + # Additional custom arguments + min_num_params: int = int(1e8), + force_leaf_modules: Optional[set[type[nn.Module]]] = None, + exclude_wrap_modules: Optional[set[type[nn.Module]]] = None, +) -> bool: + """ + A size-based auto wrap policy. + + Args: + module (nn.Module): Current module being considered. + recurse (bool): If ``False``, then this function must decide whether + ``module`` should be wrapped as an FSDP instance or not. If + ``True``, then the function is still recursing down the module + tree as a part of the DFS. + nonwrapped_numel (int): Parameter numel not yet wrapped. + + min_num_params (int): Customizable policy input that controls the size + threshold over which a module is ready to be wrapped. This is in + units of numel. + force_leaf_modules (Optional[set[type[nn.Module]]]): Set of module types to keep + as leaves, i.e. their children will never be wrapped. + exclude_wrap_modules (Optional[set[type[nn.Module]]]): Set of module types to be + excluded in wrapping. + + Returns: + Whether ``module`` should be wrapped. + """ + force_leaf_modules = ( + size_based_auto_wrap_policy.FORCE_LEAF_MODULES # type: ignore[attr-defined] + if force_leaf_modules is None + else force_leaf_modules + ) + exclude_wrap_modules = ( + size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES # type: ignore[attr-defined] + if exclude_wrap_modules is None + else exclude_wrap_modules + ) + + # Keep the argument `min_num_params` for BC for now, but it represents the + # minimum non-wrapped *numel* before triggering a wrapping + min_nonwrapped_numel = min_num_params + is_large = nonwrapped_numel >= min_nonwrapped_numel + if recurse: + # We should recurse if the module is big enough but not in force_leaf_modules list. + return is_large and not isinstance(module, tuple(force_leaf_modules)) + else: + # If we are not recursing, determine if we should wrap. + return is_large and not isinstance(module, tuple(exclude_wrap_modules)) + + +# Set those defaults to the size_based_auto_wrap_policy function. Make them easy to be imported. +size_based_auto_wrap_policy.EXCLUDE_WRAP_MODULES = {nn.ModuleList, nn.ModuleDict} # type: ignore[attr-defined] +size_based_auto_wrap_policy.FORCE_LEAF_MODULES = {nn.MultiheadAttention} # type: ignore[attr-defined] + + +@contextlib.contextmanager +def enable_wrap( + *, wrapper_cls: Any, **wrapper_kwargs: Any +) -> Generator[None, None, None]: + """ + Context manager to wrap modules using a wrapper. + + Useful for when you'd like to apply the same configuration arguments to all + child modules that you wrap. A particularly important use case is wrapping + large layers so that they get sharded (in-place) during initialization, to + avoid running out of system memory. Large layers can indicate that they + should be sharded via the ``wrap`` annotation and this context manager can + provide the exact configuration for these nested instances. + + Usage:: + + with enable_wrap(wrapper_cls, **params): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + wrapper_cls: + Class that `wrap` annotation will `wrap` modules with, such as + `FullyShardedDataParallel`. + **wrapper_kwargs: + Configuration settings that will be passed to all ``wrap`` + instances inside the context + """ + kwargs = { + "wrapper_cls": wrapper_cls, + **wrapper_kwargs, + } + with _ConfigAutoWrap(**kwargs): + yield + + +def wrap(module: nn.Module, **wrap_overrides: Any) -> nn.Module: + """ + Annotate that a module should be wrapped. Annotated modules will only be + wrapped if inside of an :func:`enable_wrap` context manager. This allows + a module to be initialized both with and without a wrapper without code + change. + + The class that this function wraps the passed in ``nn.Module`` with is the + passed in ``wrapper_cls`` argument into ``enable_wrap``. Both + ``enable_wrap`` and ``wrap`` can take in kwargs specifying how to construct + the ``wrapper_cls`` instance. In the case of duplicate kwargs in + ``enable_wrap`` and ``wrap``, the argument passed into ``wrap`` will be + respected. + + Usage:: + + with enable_wrap(wrapper_cls=FSDP, **fsdp_config): + # Wraps layer in FSDP by default if within context + self.l1 = wrap(torch.nn.Linear(5, 5)) + + Args: + module (nn.Module): module to wrap (if in :func:`enable_wrap` context) + **wrap_overrides: configuration overrides that will take priority over + the values provided by the :func:`enable_wrap` context + """ + if _ConfigAutoWrap.in_autowrap_context: + assert _ConfigAutoWrap.wrapper_cls is not None + + wrap_overrides = {**_ConfigAutoWrap.kwargs, **wrap_overrides} + return _wrap( + module, + _ConfigAutoWrap.wrapper_cls, + **wrap_overrides, + ) + return module + + +def _wrap(module: nn.Module, wrapper_cls: Callable, **kwargs) -> nn.Module: + assert wrapper_cls is not None + if hasattr(module, "_wrap_overrides"): + # If module has a _wrap_overrides attribute, we force overriding the + # FSDP config with these attributes for this module. Currently this + # is only used to disable mixed precision for BatchNorm when + # auto_wrapping. + overrides = {**kwargs, **module._wrap_overrides} # type: ignore[arg-type, dict-item] + return wrapper_cls(module, **overrides) + + return wrapper_cls(module, **kwargs) + + +def _recursive_wrap( + module: nn.Module, + auto_wrap_policy: Callable, + wrapper_cls: Callable, + ignored_modules: set[nn.Module], + ignored_params: set[nn.Parameter], + only_wrap_children: bool = False, + **kwargs: Any, +) -> tuple[nn.Module, int]: + """ + Wraps submodules of ``module`` for which ``auto_wrap_policy`` returns + ``True`` with ``wrapper_cls``. + + Args: + module (nn.Module): Module to recursively wrap. + auto_wrap_policy (Callable): A callable representing a policy that + determines which modules to recursively wrap with ``wrapper_cls``. + ignored_modules (set[torch.nn.Module]): Modules to ignore when + wrapping. + ignored_params (set[torch.nn.Parameter]): Parameters to ignore when + wrapping; these should be the parameters contained in the modules + in ``ignored_modules``. + Returns: + (nn.Module, int): + ``module`` after wrapping and the numel recursively wrapped. + """ + assert auto_wrap_policy is not None, "Must specify auto_wrap_policy." + assert wrapper_cls is not None, "Must specify wrapper_cls" + # Make sure no child is already wrapped. + for _, child in module.named_modules(): + if child in ignored_modules: + continue + try: + assert not isinstance(child, cast(type, wrapper_cls)) + except TypeError: + # wrapper_cls is a function as opposed to a class type, just bypass above check. + pass + + # We count all params, assuming none of them are already wrapped. + nonwrapped_numel = sum( + p.numel() for p in module.parameters() if p not in ignored_params + ) + + assert auto_wrap_policy is not None + if auto_wrap_policy(module=module, recurse=True, nonwrapped_numel=nonwrapped_numel): + total_wrapped_numel = 0 + # Iterate through the children, recursively wrap if necessary + for name, child in module.named_children(): + if child in ignored_modules: + continue + wrapped_child, num_wrapped_params = _recursive_wrap( + module=child, + auto_wrap_policy=auto_wrap_policy, + wrapper_cls=wrapper_cls, + ignored_modules=ignored_modules, + ignored_params=ignored_params, + **kwargs, + ) + setattr(module, name, wrapped_child) + # Keep track of how many parameters have been wrapped + total_wrapped_numel += num_wrapped_params + # decide if we need to wrap the current module, + # since the left over parameters exceed the number of params to wrap + remainder = nonwrapped_numel - total_wrapped_numel + if not only_wrap_children and auto_wrap_policy( + module=module, recurse=False, nonwrapped_numel=remainder + ): + # Leaf node or final wrapping of the remainder both happen here. + return _wrap(module, wrapper_cls, **kwargs), nonwrapped_numel + else: + return module, total_wrapped_numel + return module, 0 + + +class _ConfigAutoWrap: + """ + Helper class to wrap modules based on default config args via a context manager. + See :func:`enable_wrap` for more information. + """ + + in_autowrap_context: bool = False # Context flag + wrapper_cls: Optional[Callable] = None # The wrapper class + kwargs: dict[str, Any] = {} # Wrapper's args + + def __init__(self, **kwargs: dict[str, Any]): + self.kwargs = kwargs + + @staticmethod + def enable_autowrap_context(kwargs: Any) -> None: + if _ConfigAutoWrap.in_autowrap_context: + raise NotImplementedError( + "You are already within an autowrap context and we currently do not supported nested autowrap." + ) + _ConfigAutoWrap.in_autowrap_context = True + # Get and save the wrapper cls for the context. + assert "wrapper_cls" in kwargs.keys(), ( + "Expected to pass in wrapper_cls arg into _ConfigAutoWrap." + ) + _ConfigAutoWrap.wrapper_cls = cast(Callable, kwargs["wrapper_cls"]) + del kwargs["wrapper_cls"] + # Save the rest. + _ConfigAutoWrap.kwargs = kwargs + + @staticmethod + def disable_autowrap_context() -> None: + _ConfigAutoWrap.in_autowrap_context = False + _ConfigAutoWrap.wrapper_cls = None + _ConfigAutoWrap.kwargs = {} + + def __enter__(self) -> None: + self.enable_autowrap_context(self.kwargs) + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + self.disable_autowrap_context() diff --git a/phivenv/Lib/site-packages/torch/distributed/launcher/__init__.py b/phivenv/Lib/site-packages/torch/distributed/launcher/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1f1bb694f17a678be6f4902ba44f1b1f23449f5d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/launcher/__init__.py @@ -0,0 +1,14 @@ +#!/usr/bin/env/python3 + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from torch.distributed.launcher.api import ( # noqa: F401 + elastic_launch, + launch_agent, + LaunchConfig, +) diff --git a/phivenv/Lib/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a6854075f57592680e02331d59a0943a78d93120 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/launcher/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/launcher/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/launcher/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a04049defa1c0fff7bbb0f1ae20a10cbf1f0e475 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/launcher/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/launcher/api.py b/phivenv/Lib/site-packages/torch/distributed/launcher/api.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc70f7c166146832a398433fc55f7f651726dad --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/launcher/api.py @@ -0,0 +1,297 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import sys +import uuid +from dataclasses import dataclass, field +from typing import Any, Callable, Optional, Union + +import torch.distributed.elastic.rendezvous.registry as rdzv_registry +from torch.distributed.elastic import events, metrics +from torch.distributed.elastic.agent.server.api import WorkerSpec +from torch.distributed.elastic.agent.server.local_elastic_agent import LocalElasticAgent +from torch.distributed.elastic.multiprocessing import ( + DefaultLogsSpecs, + LogsSpecs, + SignalException, +) +from torch.distributed.elastic.multiprocessing.errors import ChildFailedError +from torch.distributed.elastic.rendezvous import RendezvousParameters +from torch.distributed.elastic.rendezvous.utils import parse_rendezvous_endpoint +from torch.distributed.elastic.utils.logging import get_logger + + +__all__ = ["LaunchConfig", "elastic_launch", "launch_agent"] + +logger = get_logger(__name__) + + +@dataclass +class LaunchConfig: + """ + Creates a rendezvous config. + + Args: + min_nodes: Minimum amount of nodes that the user function will + be launched on. Elastic agent ensures that the user + function start only when the min_nodes amount enters + the rendezvous. + max_nodes: Maximum amount of nodes that the user function + will be launched on. + nproc_per_node: On each node the elastic agent will launch + this amount of workers that will execute user + defined function. + rdzv_backend: rdzv_backend to use in the rendezvous (zeus-adapter, etcd). + rdzv_endpoint: The endpoint of the rdzv sync. storage. + rdzv_configs: Key, value pair that specifies rendezvous specific configuration. + rdzv_timeout: Legacy argument that specifies timeout for the rendezvous. It is going + to be removed in future versions, see the note below. The default timeout is 900 seconds. + run_id: The unique run id of the job (if not passed a unique one will be + deduced from run environment - flow workflow id in flow - or auto generated). + role: User defined role of the worker (defaults to "trainer"). + max_restarts: The maximum amount of restarts that elastic agent will conduct + on workers before failure. + monitor_interval: The interval in seconds that is used by the elastic_agent + as a period of monitoring workers. + start_method: The method is used by the elastic agent to start the + workers (spawn, fork, forkserver). + metrics_cfg: configuration to initialize metrics. + local_addr: address of the local node if any. If not set, a lookup on the local + machine's FQDN will be performed. + local_ranks_filter: ranks for which to show logs in console. If not set, show from all. + event_log_handler: name of the event logging handler as registered in + `elastic/events/handlers.py `_. + + + .. note:: + `rdzv_timeout` is a legacy argument that will be removed in future. + Set the timeout via `rdzv_configs['timeout']` + + """ + + min_nodes: int + max_nodes: int + nproc_per_node: int + logs_specs: Optional[LogsSpecs] = None + run_id: str = "" + role: str = "default_role" + rdzv_endpoint: str = "" + rdzv_backend: str = "etcd" + rdzv_configs: dict[str, Any] = field(default_factory=dict) + rdzv_timeout: int = -1 + max_restarts: int = 3 + monitor_interval: float = 0.1 + start_method: str = "spawn" + log_line_prefix_template: Optional[str] = None + metrics_cfg: dict[str, str] = field(default_factory=dict) + local_addr: Optional[str] = None + event_log_handler: str = "null" + + def __post_init__(self): + default_timeout = 900 + if self.rdzv_timeout != -1: + self.rdzv_configs["timeout"] = self.rdzv_timeout + elif "timeout" not in self.rdzv_configs: + self.rdzv_configs["timeout"] = default_timeout + + # Post-processing to enable refactoring to introduce logs_specs due to non-torchrun API usage + if self.logs_specs is None: + self.logs_specs = DefaultLogsSpecs() + + +class elastic_launch: + """ + Launches an torchelastic agent on the container that invoked the entrypoint. + + 1. Pass the ``entrypoint`` arguments as non ``kwargs`` (e.g. no named parameters)/ + ``entrypoint`` can be a function or a command. + 2. The return value is a map of each worker's output mapped + by their respective global rank. + + Usage + + :: + + def worker_fn(foo): + # ... + + def main(): + # entrypoint is a function. + outputs = elastic_launch(LaunchConfig, worker_fn)(foo) + # return rank 0's output + return outputs[0] + + # entrypoint is a command and ``script.py`` is the python module. + outputs = elastic_launch(LaunchConfig, "script.py")(args) + outputs = elastic_launch(LaunchConfig, "python")("script.py") + """ + + def __init__( + self, + config: LaunchConfig, + entrypoint: Union[Callable, str, None], + ): + self._config = config + self._entrypoint = entrypoint + + def __call__(self, *args): + return launch_agent(self._config, self._entrypoint, list(args)) + + +def _get_entrypoint_name( + entrypoint: Union[Callable, str, None], args: list[Any] +) -> str: + """Retrieve entrypoint name with the rule: + 1. If entrypoint is a function, use ``entrypoint.__qualname__``. + 2. If entrypoint is a string, check its value: + 2.1 if entrypoint equals to ``sys.executable`` (like "python"), use the first element from ``args`` + which does not start with hifen letter (for example, "-u" will be skipped). + 2.2 otherwise, use ``entrypoint`` value. + 3. Otherwise, return empty string. + """ + if isinstance(entrypoint, Callable): # type: ignore[arg-type] + return entrypoint.__name__ # type: ignore[union-attr] + elif isinstance(entrypoint, str): + if entrypoint == sys.executable: + return next((arg for arg in args if arg[0] != "-"), "") + else: + return entrypoint + else: + return "" + + +def _get_addr_and_port( + rdzv_parameters: RendezvousParameters, +) -> tuple[Optional[str], Optional[int]]: + if rdzv_parameters.backend != "static": + return (None, None) + endpoint = rdzv_parameters.endpoint + endpoint = endpoint.strip() + if not endpoint: + raise ValueError( + "Endpoint is missing in endpoint. Try to add --master-addr and --master-port" + ) + master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1) + if master_port == -1: + raise ValueError( + f"port is missing in endpoint: {endpoint}. Try to specify --master-port" + ) + return (master_addr, master_port) + + +def launch_agent( + config: LaunchConfig, + entrypoint: Union[Callable, str, None], + args: list[Any], +) -> dict[int, Any]: + if not config.run_id: + run_id = str(uuid.uuid4().int) + logger.warning("config has no run_id, generated a random run_id: %s", run_id) + config.run_id = run_id + + entrypoint_name = _get_entrypoint_name(entrypoint, args) + + logger.info( + "Starting elastic_operator with launch configs:\n" + " entrypoint : %(entrypoint)s\n" + " min_nodes : %(min_nodes)s\n" + " max_nodes : %(max_nodes)s\n" + " nproc_per_node : %(nproc_per_node)s\n" + " run_id : %(run_id)s\n" + " rdzv_backend : %(rdzv_backend)s\n" + " rdzv_endpoint : %(rdzv_endpoint)s\n" + " rdzv_configs : %(rdzv_configs)s\n" + " max_restarts : %(max_restarts)s\n" + " monitor_interval : %(monitor_interval)s\n" + " log_dir : %(log_dir)s\n" + " metrics_cfg : %(metrics_cfg)s\n" + " event_log_handler : %(event_log_handler)s\n", + { + "entrypoint": entrypoint_name, + "min_nodes": config.min_nodes, + "max_nodes": config.max_nodes, + "nproc_per_node": config.nproc_per_node, + "run_id": config.run_id, + "rdzv_backend": config.rdzv_backend, + "rdzv_endpoint": config.rdzv_endpoint, + "rdzv_configs": config.rdzv_configs, + "max_restarts": config.max_restarts, + "monitor_interval": config.monitor_interval, + "log_dir": config.logs_specs.root_log_dir, # type: ignore[union-attr] + "metrics_cfg": config.metrics_cfg, + "event_log_handler": config.event_log_handler, + }, + ) + + rdzv_parameters = RendezvousParameters( + backend=config.rdzv_backend, + endpoint=config.rdzv_endpoint, + run_id=config.run_id, + min_nodes=config.min_nodes, + max_nodes=config.max_nodes, + local_addr=config.local_addr, + **config.rdzv_configs, + ) + + master_addr, master_port = _get_addr_and_port(rdzv_parameters) + + spec = WorkerSpec( + role=config.role, + local_world_size=config.nproc_per_node, + entrypoint=entrypoint, + args=tuple(args), + rdzv_handler=rdzv_registry.get_rendezvous_handler(rdzv_parameters), + max_restarts=config.max_restarts, + monitor_interval=config.monitor_interval, + master_addr=master_addr, + master_port=master_port, + local_addr=config.local_addr, + event_log_handler=config.event_log_handler, + ) + + agent = LocalElasticAgent( + spec=spec, + logs_specs=config.logs_specs, # type: ignore[arg-type] + start_method=config.start_method, + log_line_prefix_template=config.log_line_prefix_template, + ) + + shutdown_rdzv = True + try: + metrics.initialize_metrics(metrics.MetricsConfig(config.metrics_cfg)) + + result = agent.run() + # records that agent.run() has succeeded NOT that workers have succeeded + events.record(agent.get_event_succeeded(), config.event_log_handler) + + if result.is_failed(): + # ChildFailedError is treated specially by @record + # if the error files for the failed children exist + # @record will copy the first error (root cause) + # to the error file of the launcher process. + raise ChildFailedError( + name=entrypoint_name, + failures=result.failures, + ) + + return result.return_values + except ChildFailedError: + raise + except SignalException: + # when the agent dies with a signal do NOT shutdown the rdzv_handler + # since this closes the rendezvous on this rdzv_id permanently and + # prevents any additional scaling events + shutdown_rdzv = False + events.record(agent.get_event_failed(), config.event_log_handler) + raise + except Exception: + events.record(agent.get_event_failed(), config.event_log_handler) + raise + finally: + if shutdown_rdzv: + spec.rdzv_handler.shutdown() diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/__init__.py b/phivenv/Lib/site-packages/torch/distributed/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c2a8b7da0dc1c1127eff964a85dd8e31fcab952d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/nn/__init__.py @@ -0,0 +1,7 @@ +import torch + +from .functional import * # noqa: F403 + + +if torch.distributed.rpc.is_available(): + from .api.remote_module import RemoteModule diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17f70280deb33e720be9603643ec6cfeaf167815 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/__pycache__/functional.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/__pycache__/functional.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66e0357c2866a95c49de95c8b9844884dfb53f4a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/__pycache__/functional.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/api/__init__.py b/phivenv/Lib/site-packages/torch/distributed/nn/api/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ac1d5eabb60f35bee0c1fe54ef12f28c34a3422 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/api/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfeb2547508503550391208fbb688b224ae9f81a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/api/__pycache__/remote_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/api/remote_module.py b/phivenv/Lib/site-packages/torch/distributed/nn/api/remote_module.py new file mode 100644 index 0000000000000000000000000000000000000000..dada346a471425679ab5a2e68fa699e5fd87bd79 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/nn/api/remote_module.py @@ -0,0 +1,754 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs +import collections +import io +import sys +import types +from collections.abc import Iterator, Mapping +from typing import Any, Callable, Optional, TypeVar, Union +from typing_extensions import Self + +import torch +import torch.distributed.rpc as rpc +from torch import device, dtype, nn, Tensor +from torch.distributed import _remote_device +from torch.distributed.nn.jit import instantiator +from torch.distributed.rpc.internal import _internal_rpc_pickler +from torch.nn import Module +from torch.nn.parameter import Parameter +from torch.utils.hooks import RemovableHandle + + +__all__ = ["RemoteModule"] + +_grad_t = Union[tuple[Tensor, ...], Tensor] +# See https://mypy.readthedocs.io/en/latest/generics.html#generic-methods-and-generic-self for the use +# of `T` to annotate `self`. Many methods of `Module` return `self` and we want those return values to be +# the type of the subclass, not the looser type of `Module`. +T = TypeVar("T", bound="Module") + +_NON_SCRIPTABLE_REMOTE_MODULE_MODULE = ( + instantiator.instantiate_non_scriptable_remote_module_template() +) + +_REMOTE_MODULE_PICKLED_ATTRIBUTES = ( + "on", + "device", + "is_device_map_set", + "is_scriptable", + "generated_methods", + "module_rref", +) + +_SerializedRemoteModule = collections.namedtuple( # type: ignore[misc] + "_SerializedRemoteModule", + _REMOTE_MODULE_PICKLED_ATTRIBUTES, +) + +# These attributes are mostly from RemoteModule's parent class and are intentionally not pickled. +# A new attribute of RemoteModule should be either in _REMOTE_MODULE_PICKLED_ATTRIBUTES +# or _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. +# Otherwise, it will not be pickled. +_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING = ( + "training", + "_parameters", + "_buffers", + "_non_persistent_buffers_set", + "_backward_hooks", + "_backward_pre_hooks", + "_is_full_backward_hook", + "_forward_hooks", + "_forward_hooks_with_kwargs", + "_forward_hooks_always_called", + "_forward_pre_hooks", + "_forward_pre_hooks_with_kwargs", + "_state_dict_hooks", + "_state_dict_pre_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_state_dict_pre_hooks", + "_modules", + # The two attributes below are generated methods, not available at pickling time. + "forward_async", + "forward", +) + + +# RPC handler. +def _instantiate_template(module_interface_cls, enable_moving_cpu_tensors_to_cuda): + instantiator.instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + + +def _create_module(module_cls, args, kwargs, device): + module = module_cls(*args, **kwargs) + if not isinstance(module, nn.Module): + raise ValueError( + "Expect `module_cls(*args, **kwargs)` returns an instance of , " + f"but it returns an instance of {type(module)}." + ) + module.to(device) + return module + + +def _create_module_with_interface( + module_cls, args, kwargs, device, module_interface_cls +): + module = _create_module(module_cls, args, kwargs, device) + if module_interface_cls is not None: + module = torch.jit.script(module) + return rpc.RRef(module, module_interface_cls) + + +def _param_rrefs(module_rref, recurse) -> list[rpc.RRef[Parameter]]: + ret: list[rpc.RRef[Parameter]] = [ + rpc.RRef(param) for param in module_rref.local_value().parameters(recurse) + ] + return ret + + +def _raise_not_supported(name: str) -> None: + raise ValueError(f"Method ``{name}`` not supported for RemoteModule") + + +class _RemoteModule(nn.Module): + def __new__(cls, *args, **kwargs): + # Use __new__ for logging purposes. + torch._C._log_api_usage_once("torch.distributed.nn.api.remote_module") + return super().__new__(cls) + + def __init__( + self, + remote_device: str, + module_cls: type[nn.Module], + args: Optional[tuple] = None, + kwargs: Optional[dict[str, Any]] = None, + _module_interface_cls: Any = None, + ): + """ + RemoteModule instance can only be created after RPC initialization. + + It creates a user-specified module on a specified remote node. + It behaves like a regular ``nn.Module`` except that the ``forward`` method is + executed on the remote node. + It takes care of autograd recording to ensure the backward pass propagates + gradients back to the corresponding remote module. + It can be shared across processors using `RPC framework `__, + without incurring any overheads of copying the actual module, + which is equivalent to an :class:`~torch.distributed.rpc.RRef` + pointing to the remote module. + + The arguments of ``forward_async`` and ``forward`` are the same as + the ``forward`` method of the module returned by the ``module_cls``. + + Apart from ``forward_async`` and ``forward``, no other methods are supported from nn.Module for now. + + Particularly, to create a hybrid model, typically the local modules should be + created outside of remote modules, rather than as submodules of any remote module (by calling ``add_module``). + Hybrid Example: + >>> class HybridModel(nn.Module): + >>> def __init__(self) -> None: + >>> nn.Module.__init__(self) + >>> self.remote_embedding = RemoteModule(...) + >>> self.local_linear = nn.Linear(...) + + For example, if ``module_cls`` returns an instance of ``nn.Linear``, + that has ``forward`` method signature, ``def forward(input: Tensor) -> Tensor:``, + the generated ``RemoteModule`` will have 2 methods in signature of + ``def forward(input: Tensor) -> Tensor:`` and + ``def forward_async(input: Tensor) -> Future[Tensor]:``. + + .. note:: + If the remote module is placed on a cuda device, + any input CPU tensors will be automatically moved to the same cuda device, + and GPU tensors are returned over the wire according to the device map of the remote worker on TensorPipe RPC backend. + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + + In addition, the device field can be optional and the default value is "cpu". + module_cls (nn.Module): For example, + >>> class MyModule(nn.Module): + >>> def forward(input): + >>> return input + 1 + >>> + >>> module_cls = MyModule + args (Sequence, optional): args to be passed to ``module_cls``. + kwargs (Dict, optional): kwargs to be passed to ``module_cls``. + _module_interface_cls (type, optional): The TorchScript interface type for the module + to be created. The type object should be decorated by @torch.jit.interface. + If not provided, the generated RemoteModule is not torchscript-able. + Warning, this is an experimental API and susceptible to frequent changes. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_cls``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> from torch import nn, Tensor + >>> from torch.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_linear_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> input = torch.randn(128, 20) + >>> ret_fut = remote_linear_module.forward_async(input) + >>> ret = ret_fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + super().__init__() + + enable_moving_cpu_tensors_to_cuda = self._prepare_init(remote_device) + + # Default arguments preparation. + args = args if args is not None else () + kwargs = kwargs if kwargs is not None else {} + + if _module_interface_cls is not None: + # Users reply on this field to know if this generated RemoteModule is TorchScript-able. + self.is_scriptable = True + + # Instantiate template on remote side. + fut = rpc.rpc_async( + self.on, + _instantiate_template, + (_module_interface_cls, enable_moving_cpu_tensors_to_cuda), + ) + + self._init_template( + _module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + + # Instantiate template on remote side. + fut = rpc.rpc_async( + self.on, + _instantiate_template, + (_module_interface_cls, enable_moving_cpu_tensors_to_cuda), + ) + + # Create the module on the remote side. + fut.wait() # Ensure remote_module_cls is available on remote side. + + # TODO: We need to change this to rpc.remote, and make it async (see the else branch below). + # For that we need to be able to apply _module_interface_cls to the RRef returned by rpc.remote + # See https://github.com/pytorch/pytorch/issues/58098 for more context. + self.module_rref = rpc.rpc_sync( + self.on, + _create_module_with_interface, + (module_cls, args, kwargs, self.device, _module_interface_cls), + ) + else: + self.is_scriptable = False + self.generated_methods = ( + _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods + ) + # Create the module on the remote side. + self.module_rref = rpc.remote( + self.on, + _create_module, + (module_cls, args, kwargs, self.device), + ) + + self._install_generated_methods() + self._check_attribute_picklability() + + def remote_parameters(self, recurse: bool = True) -> list[rpc.RRef[Parameter]]: + """ + Return a list of :class:`~torch.distributed.rpc.RRef` pointing to the remote module's parameters. + + This can typically be used in conjunction + with :class:`~torch.distributed.optim.DistributedOptimizer`. + + Args: + recurse (bool): if True, then returns parameters of the remote + module and all submodules of the remote module. Otherwise, + returns only parameters that are direct members of the + remote module. + + Returns: + A list of :class:`~torch.distributed.rpc.RRef` (``List[RRef[nn.Parameter]]``) + to remote module's parameters. + """ + return rpc.rpc_sync(self.on, _param_rrefs, args=(self.module_rref, recurse)) + + def get_module_rref(self) -> rpc.RRef[nn.Module]: + """Return an :class:`~torch.distributed.rpc.RRef` (``RRef[nn.Module]``) pointing to the remote module.""" + return self.module_rref + + @torch.jit.export + def __getstate__(self): + raise RuntimeError( + "Cannot pickle RemoteModule in python pickler. RemoteModule can only be pickled when using RPC" + ) + + @torch.jit.export + def __setstate__(self, state): + raise RuntimeError( + "Cannot unpickle RemoteModule in python pickler. RemoteModule can only be unpickled when using RPC" + ) + + def register_buffer( + self, name: str, tensor: Optional[Tensor], persistent: bool = True + ) -> None: + _raise_not_supported(self.register_buffer.__name__) + + def register_parameter(self, name: str, param: Optional[Parameter]) -> None: + _raise_not_supported(self.register_parameter.__name__) + + def add_module(self, name: str, module: Optional[Module]) -> None: + _raise_not_supported(self.add_module.__name__) + + def apply(self, fn: Callable[[Module], None]) -> Self: # type: ignore[return] + _raise_not_supported(self.apply.__name__) + + def cuda(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return] + _raise_not_supported(self.cuda.__name__) + + def ipu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return] + _raise_not_supported(self.ipu.__name__) + + def xpu(self, device: Optional[Union[int, device]] = None) -> Self: # type: ignore[return] + _raise_not_supported(self.xpu.__name__) + + def cpu(self) -> Self: # type: ignore[return] + _raise_not_supported(self.cpu.__name__) + + def type(self, dst_type: Union[dtype, str]) -> Self: # type: ignore[return] + _raise_not_supported(self.type.__name__) + + def float(self) -> Self: # type: ignore[return] + _raise_not_supported(self.float.__name__) + + def double(self) -> Self: # type: ignore[return] + _raise_not_supported(self.double.__name__) + + def half(self) -> Self: # type: ignore[return] + _raise_not_supported(self.half.__name__) + + def bfloat16(self) -> Self: # type: ignore[return] + _raise_not_supported(self.bfloat16.__name__) + + def to(self, *args, **kwargs) -> T: # type: ignore[misc, return, type-var] + _raise_not_supported(self.to.__name__) + + def register_backward_hook( # type: ignore[return] + self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]] + ) -> RemovableHandle: + _raise_not_supported(self.register_backward_hook.__name__) + + def register_forward_pre_hook( # type: ignore[return] + self, + hook: Union[ + Callable[[T, tuple[Any, ...]], Optional[Any]], + Callable[ + [T, tuple[Any, ...], dict[str, Any]], + Optional[tuple[Any, dict[str, Any]]], + ], + ], + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + _raise_not_supported(self.register_forward_pre_hook.__name__) + + def register_forward_hook( # type: ignore[return, override] + self, + hook: Union[ + Callable[[T, tuple[Any, ...], Any], Optional[Any]], + Callable[[T, tuple[Any, ...], dict[str, Any], Any], Optional[Any]], + ], + prepend: bool = False, + with_kwargs: bool = False, + ) -> RemovableHandle: + _raise_not_supported(self.register_forward_hook.__name__) + + def state_dict(self, *args, **kwargs): + _raise_not_supported(self.state_dict.__name__) + + def load_state_dict( + self, + state_dict: Mapping[str, Any], + strict: bool = True, + assign: bool = False, + ): + _raise_not_supported(self.load_state_dict.__name__) + + def parameters(self, recurse: bool = True) -> Iterator[Parameter]: + raise ValueError( + "Method ``parameters`` not supported for RemoteModule. Please use ``remote_parameters`` instead." + ) + + def named_parameters( # type: ignore[return] + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[tuple[str, Parameter]]: + _raise_not_supported(self.named_parameters.__name__) + + def buffers(self, recurse: bool = True) -> Iterator[Tensor]: # type: ignore[return] + _raise_not_supported(self.buffers.__name__) + + def named_buffers( # type: ignore[return] + self, prefix: str = "", recurse: bool = True, remove_duplicate: bool = True + ) -> Iterator[tuple[str, Tensor]]: + _raise_not_supported(self.named_buffers.__name__) + + def children(self) -> Iterator[Module]: # type: ignore[return] + _raise_not_supported(self.children.__name__) + + def named_children(self) -> Iterator[tuple[str, Module]]: # type: ignore[return] + _raise_not_supported(self.named_children.__name__) + + def modules(self) -> Iterator[Module]: # type: ignore[return] + _raise_not_supported(self.modules.__name__) + + def named_modules( + self, + memo: Optional[set[Module]] = None, + prefix: str = "", + remove_duplicate: bool = True, + ): + _raise_not_supported(self.named_modules.__name__) + + def train(self, mode: bool = True) -> Self: + return self.module_rref.rpc_sync().train() # type: ignore[operator, union-attr] + + def eval(self) -> Self: + return self.module_rref.rpc_sync().eval() # type: ignore[operator, union-attr] + + def requires_grad_(self, requires_grad: bool = True) -> Self: # type: ignore[return] + _raise_not_supported(self.requires_grad_.__name__) + + def zero_grad(self, set_to_none: bool = True) -> None: + _raise_not_supported(self.zero_grad.__name__) + + def share_memory(self) -> Self: # type: ignore[return] + _raise_not_supported(self.share_memory.__name__) + + def extra_repr(self) -> str: # type: ignore[return] + _raise_not_supported(self.extra_repr.__name__) + + def _prepare_init(self, remote_device_str: str) -> bool: + """Prepare the initialization and returns whether to enable automatically moving CPU tensors to CUDA devices.""" + # Sanity check. + assert rpc._is_current_rpc_agent_set(), "RemoteModule only works in RPC." + + remote_device = _remote_device(remote_device_str) + self.on = ( + remote_device.worker_name() + if remote_device.worker_name() is not None + else remote_device.rank() + ) + self.device = str(remote_device.device()) + agent = rpc._get_current_rpc_agent() + # If the device map of the remote worker is set, + # then enable moving any input CPU tensors to the same cuda device. + self.is_device_map_set = bool( + agent._get_device_map(agent.get_worker_info(self.on)) # type: ignore[arg-type] + ) + # ``enable_moving_cpu_tensors_to_cuda`` is less strict than ``is_device_map_set``: + # If ``enable_moving_cpu_tensors_to_cuda`` is true, but the device map is not set, + # then any CPU tensors can still be moved to a cuda device to run forward, + # but the output must be moved back to CPU before being sent over the wire. + enable_moving_cpu_tensors_to_cuda = torch.device(self.device).type == "cuda" + return enable_moving_cpu_tensors_to_cuda + + def _init_template(self, module_interface_cls, enable_moving_cpu_tensors_to_cuda): + """Instantiate template on local side.""" + generated_module = instantiator.instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + self.generated_methods = generated_module._generated_methods + + def _check_attribute_picklability(self): + """Check if all the attribute has explicitly defined whether to be pickled (i.e., picklability).""" + for k in self.__dict__.keys(): + if ( + k not in _REMOTE_MODULE_PICKLED_ATTRIBUTES + and k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING + ): + raise AttributeError( + f"Attribute {k} must be either in ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` or " + "``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``." + ) + + def _install_generated_methods(self): + for method in self.generated_methods: + method_name = method.__name__ + method = torch.jit.export(method) + setattr(self, method_name, types.MethodType(method, self)) + + @staticmethod + def init_from_module_rref( + remote_device: str, + module_rref: rpc.RRef[nn.Module], + _module_interface_cls: Any = None, + ): + """ + Besides the constructor, a RemoteModule instance can also be initialized given a module RRef. + + This alternate initialization method can be particularly useful if we want to create multiple + RemoteModule instances that share the same underlying module and reduce memory consumption. + + Moreover, this also provides a workaround for passing script RemoteModule over RPC, + which is not supported. The recommended way is as follows: + + 1. the sender creates a RemoteModule; + 2. the sender sends its ``module_rref`` over RPC; + 3. the receiver calls this method to initialize another RemoteModule using the same ``module_rref``. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> from torch import nn, Tensor + >>> from torch.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> + >>> remote_module1 = rpc.rpc_sync( + >>> "worker1/cpu", + >>> RemoteModule.init_from_module_rref, + >>> ("worker1/cpu", remote_module1.get_module_rref()), + >>> ) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The device can be a local device or a remote device specified by one of the following remote + formats: + + 1. "rank:/" (ex: "rank:0/cuda:0"). + 2. "/" (ex: "trainer0/cuda:0"). + + In addition, the device field can be optional and the default value is "cpu". + module_rref (RRef[nn.Module]): The module reference shared by both the caller and + the created remote module. + _module_interface_cls (type, optional): The TorchScript interface type for the module + to be created. The type object should be decorated by @torch.jit.interface. + If not provided, the generated RemoteModule is not torchscript-able. + Warning, this is an experimental API and susceptible to frequent changes. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_rref``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + """ + # NOTE: if a new attribute is added to this class, also need to add it + # to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES`` for pickling/unpickling. + + remote_module = object.__new__(RemoteModule) + + enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device) + + if _module_interface_cls is not None: + # Users reply on this field to know if this generated RemoteModule is TorchScript-able. + remote_module.is_scriptable = True + + remote_module._init_template( + _module_interface_cls, enable_moving_cpu_tensors_to_cuda + ) + else: + remote_module.is_scriptable = False + remote_module.generated_methods = ( + _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods + ) + remote_module.module_rref = module_rref + + remote_module._install_generated_methods() + remote_module._check_attribute_picklability() + + return remote_module + + +class RemoteModule(_RemoteModule): + """ + A RemoteModule instance can only be created after RPC initialization. + + It creates a user-specified module on a specified remote node. + It behaves like a regular ``nn.Module`` except that the ``forward`` method is + executed on the remote node. + It takes care of autograd recording to ensure the backward pass propagates + gradients back to the corresponding remote module. + + It generates two methods ``forward_async`` and ``forward`` based on the + signature of the ``forward`` method of ``module_cls``. ``forward_async`` + runs asynchronously and returns a Future. The arguments of ``forward_async`` + and ``forward`` are the same as the ``forward`` method of the module + returned by the ``module_cls``. + + For example, if ``module_cls`` returns an instance of ``nn.Linear``, + that has ``forward`` method signature: ``def forward(input: Tensor) -> Tensor:``, + the generated ``RemoteModule`` will have 2 methods with the signatures: + + | ``def forward(input: Tensor) -> Tensor:`` + | ``def forward_async(input: Tensor) -> Future[Tensor]:`` + + Args: + remote_device (str): Device on the destination worker where we'd like to place this module. + The format should be "/", where the device field can be parsed as torch.device type. + E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0". + In addition, the device field can be optional and the default value is "cpu". + module_cls (nn.Module): Class for the module to be created remotely. For example, + + >>> class MyModule(nn.Module): + >>> def forward(input): + >>> return input + 1 + >>> + >>> module_cls = MyModule + + args (Sequence, optional): args to be passed to ``module_cls``. + kwargs (Dict, optional): kwargs to be passed to ``module_cls``. + + Returns: + A remote module instance which wraps the :class:`~nn.Module` created by the + user-provided ``module_cls``, it has a blocking ``forward`` method and an + asynchronous ``forward_async`` method that returns a future of the ``forward`` call + on the user-provided module on the remote side. + + Example:: + Run the following code in two different processes: + + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> from torch import nn, Tensor + >>> from torch.distributed.nn.api.remote_module import RemoteModule + >>> + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> remote_linear_module = RemoteModule( + >>> "worker1/cpu", nn.Linear, args=(20, 30), + >>> ) + >>> input = torch.randn(128, 20) + >>> ret_fut = remote_linear_module.forward_async(input) + >>> ret = ret_fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Furthermore, a more practical example that is combined with + `DistributedDataParallel `__ (DDP) + can be found in this `tutorial `__. + """ + + def __init__( + self, + remote_device: str, + module_cls: type[nn.Module], + args: Optional[tuple] = None, + kwargs: Optional[dict[str, Any]] = None, + ): + super().__init__(remote_device, module_cls, args, kwargs) + + +def _remote_module_receiver( + *remote_module_pickled_attrs, +): + """Deserializes a RemoteModule.""" + serialized_remote_module = _SerializedRemoteModule._make( + remote_module_pickled_attrs + ) + m = object.__new__(RemoteModule) + m.__dict__.update(serialized_remote_module._asdict()) + + # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method. + m.module_rref = rpc.PyRRef._deserialize(m.module_rref) + + # Install generated methods when unpickled. + for method in m.generated_methods: + method_name = method.__name__ + method = torch.jit.export(method) + setattr(m, method_name, types.MethodType(method, m)) + + return m + + +def _remote_module_reducer(remote_module): + """Serialize a RemoteModule.""" + pickled_attrs = {} + for k, v in remote_module.__dict__.items(): + # Pickling the attribute `module_rref` must invoke RRef's `_serialize()` method. + if k == "module_rref": + pickled_attrs[k] = v._serialize() + elif k in _REMOTE_MODULE_PICKLED_ATTRIBUTES: + pickled_attrs[k] = v + # Check if unpickled attributes are all in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING. + elif k not in _REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING: + print( + f"The new attribute ``{k}`` of RemoteModule is ignored during RPC pickling. " + "To pickle this attribute, please add it to ``_REMOTE_MODULE_PICKLED_ATTRIBUTES``. " + "Otherwise, please explicitly add it to ``_REMOTE_MODULE_ATTRIBUTES_IGNORE_FOR_PICKLING``.", + file=sys.stderr, + ) + + return ( + _remote_module_receiver, + tuple(pickled_attrs.values()), + ) + + +def _recursive_script_module_receiver( + recursive_script_module_serialized, +): + """Deserializes a RecursiveScriptModule that does not contain a script RemoteModule.""" + f = io.BytesIO(recursive_script_module_serialized) + m = torch.jit.load(f) + return m + + +def _recursive_script_module_reducer(recursive_script_module): + """Serialize a RecursiveScriptModule that does not contain a script RemoteModule, and raises an error otherwise.""" + if hasattr(recursive_script_module._c, "module_rref"): + raise RuntimeError( + "Passing a script RemoteModule over RPC is not supported. Please create a RemoteModule in the sender, " + "send the `module_rref` to the receiver, and create a new instance on the receiver end by passing this `module_rref`." + ) + + f = io.BytesIO() + torch.jit.save(recursive_script_module, f) + return (_recursive_script_module_receiver, (f.getvalue(),)) + + +_internal_rpc_pickler._register_reducer(RemoteModule, _remote_module_reducer) +_internal_rpc_pickler._register_reducer( + torch.jit.RecursiveScriptModule, _recursive_script_module_reducer +) diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/functional.py b/phivenv/Lib/site-packages/torch/distributed/nn/functional.py new file mode 100644 index 0000000000000000000000000000000000000000..aa50b80c09b9611aa3d11687ffbb05bdf085af48 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/nn/functional.py @@ -0,0 +1,452 @@ +# mypy: allow-untyped-defs +import torch +import torch.distributed as dist +from torch.autograd import Function + +# The two imports below are not always available depending on the +# USE_DISTRIBUTED compile flag. Make sure they raise import error +# if we're trying to use them. +from torch.distributed import group, ReduceOp + + +def broadcast(tensor, src, group=group.WORLD): + """ + Broadcasts the tensor to the whole group. + + ``tensor`` must have the same number of elements in all processes + participating in the collective. + + Arguments: + tensor (Tensor): Data to be sent if ``src`` is the rank of current + process. + src (int): Source rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Received tensor from the broadcast op. + + """ + return _Broadcast.apply(src, group, tensor) + + +def gather(tensor, dst=0, group=group.WORLD): + """ + Gathers a list of tensors in a single process. + + Arguments: + tensor (Tensor): Input tensor. + dst (int, optional): Destination rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple[Tensor]: List of appropriately-sized tensors with the gathered data. + """ + return _Gather.apply(dst, group, tensor) + + +def scatter(tensors, src=0, group=group.WORLD): + """ + Scatters a list of tensors to all processes in a group. + + Each process will receive exactly one tensor and store its data in the + ``tensor`` argument. + + Arguments: + tensors (list[Tensor]): List of tensors to scatter on the source rank. + Receivers must pass ``None`. + src (int, optional): Source rank (default is 0). + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output tensor from the scatter operation. + + """ + return _Scatter.apply(src, group, *tensors) + + +def reduce(tensor, dst, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines. + + Only the process with rank ``dst`` is going to receive the final result. + + Arguments: + tensor (Tensor): Input of the collective. + dst (int): Destination rank. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce.apply(dst, op, group, tensor) + + +def reduce_scatter(output, input_list, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces, then scatters a list of tensors to all processes in a group. + + Arguments: + output (Tensor): Output tensor. + input_list (list[Tensor]): List of tensors to reduce and scatter. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective. + + """ + return _Reduce_Scatter.apply(op, group, output, *input_list) + + +def all_gather(tensor, group=group.WORLD): + """ + Gathers tensors from the whole group in a list. + + Arguments: + tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AllGather.apply(group, tensor) + + +def _all_gather_base(output_tensor, input_tensor, group=group.WORLD): + """ + Single tensor all gather. Gathers a single tensor from all ranks, and puts them in a single output tensor. + + Args: + output_tensor (Tensor): Output tensor. It should contain + correctly-sized tensors to be used for output of the collective. + input_tensor (Tensor): Tensor to be broadcast from current process. + group (ProcessGroup, optional): The process group to work on. If None, + the default process group will be used. + + Examples: + >>> # All tensors below are of torch.int64 dtype. + >>> # We have 2 process groups, 2 ranks. + >>> # xdoctest: +SKIP("incorrect want text") + >>> output_tensor = torch.zeros(2, dtype=torch.int64) + >>> output_tensor + [tensor([0, 0])] # Rank 0 and 1 + >>> tensor = torch.arange(1, dtype=torch.int64) + 1 + rank + >>> tensor + tensor([1]) # Rank 0 + tensor([2]) # Rank 1 + >>> dist.all_gather_base(output_tensor, tensor) + >>> output_tensor + tensor([1,2]) # Rank 0 + tensor([1,2]) # Rank 1 + + .. warning:: + `_all_gather_base` is experimental and subject to change. + It is the caller's responsibility to ensure the output_tensor + is correctly sized. + + """ + return _AllGatherBase.apply(output_tensor, input_tensor, group) + + +def all_to_all(output_tensor_list, input_tensor_list, group=group.WORLD): + """ + Each process scatters list of input tensors to all processes in a group and return gathered list of tensors in output list. + + Arguments: + output_tensor_list (list[Tensor]): list of tensors to gather one per rank. + input_tensor_list (list[Tensor]): List of tensors to scatter one per rank. + group (ProcessGroup, optional): The process group to work on. + + Returns: + tuple([Tensor]): Output of the collective. + + """ + return _AlltoAll.apply(group, output_tensor_list, *input_tensor_list) + + +def all_to_all_single( + output, + input, + output_split_sizes=None, + input_split_sizes=None, + group=group.WORLD, +): + """ + Each process splits input tensor and then scatters the split list to all processes in a group. + + Then concatenate the received tensors from all the processes in the group and return single output tensor. + + Arguments: + output (Tensor): Gathered concatenated output tensor. + input (Tensor): Input tensor to scatter. + output_split_sizes: (list[Int], optional): Output split sizes for dim 0 + if specified None or empty, dim 0 of ``output`` tensor must divide + equally by ``world_size``. + input_split_sizes: (list[Int], optional): Input split sizes for dim 0 + if specified None or empty, dim 0 of ``input`` tensor must divide + equally by ``world_size``. + + Returns: + Tensor: Output of the collective. + + """ + return _AlltoAllSingle.apply( + group, output, output_split_sizes, input_split_sizes, input + ) + + +def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): + """ + Reduces the tensor data across all machines in such a way that all get the final result. + + After the call the returned tensor is going to be bitwise + identical in all processes. + + Arguments: + tensor (Tensor): Input of the collective. + op (optional): One of the values from + ``torch.distributed.ReduceOp`` + enum. Specifies an operation used for element-wise reductions. + group (ProcessGroup, optional): The process group to work on. + + Returns: + Tensor: Output of the collective + + """ + return _AllReduce.apply(op, group, tensor) + + +class _Broadcast(Function): + @staticmethod + def forward(ctx, src, group, tensor): + ctx.src = src + ctx.group = group + ctx.rank = dist.get_rank(group=group) + # torch.distributed makes all the calls in place + # we allocate new tensors to avoid this + tensor = tensor.clone() + dist.broadcast(tensor, src, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output) + if ctx.src != ctx.rank: + gx.zero_() + return (None, None, gx) + + +class _Gather(Function): + @staticmethod + def forward(ctx, dst, group, tensor): + ctx.dst = dst + ctx.group = group + # Need to create a list of tensors here to do the + # aggregation, get it from the group size + # tensor should be correctly sized for the method + # gathering + tensor_list = [ + torch.zeros_like(tensor) for i in range(dist.get_world_size(group=group)) + ] + + tensor = tensor.contiguous() + if dist.get_rank(group=group) == dst: + dist.gather(tensor, tensor_list, dst, group=group) + else: + dist.gather(tensor, None, dst, group=group) + return tuple(tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + return (None, None) + (_Scatter.apply(ctx.dst, ctx.group, *grad_outputs),) + + +class _Scatter(Function): + @staticmethod + def forward(ctx, src, group, *tensors): + ctx.src = src + ctx.group = group + assert all(t.size() == tensors[0].size() for t in tensors) + output = torch.zeros_like(tensors[0]) + if dist.get_rank(group=group) == src: + dist.scatter(output, list(tensors), src, group=group) + else: + dist.scatter(output, None, src, group=group) + return output + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) + + +class _Reduce(Function): + @staticmethod + def forward(ctx, src, op, group, tensor): + ctx.src = src + ctx.group = group + tensor = tensor.clone() + dist.reduce(tensor, src, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) + + +class _Reduce_Scatter(Function): + @staticmethod + def forward(ctx, op, group, tensor, *input_tensor_list): + ctx.group = group + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + input_tensor_list = tuple(t.contiguous() for t in input_tensor_list) + dist.reduce_scatter(tensor, list(input_tensor_list), op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None, None) + _AllGather.apply(ctx.group, grad_output) + + +class _AllGather(Function): + @staticmethod + def forward(ctx, group, tensor): + # Need contiguous tensors for collectives. + tensor = tensor.contiguous() + + ctx.group = group + out_tensor_list = [ + torch.empty_like(tensor) for _ in range(dist.get_world_size(group=group)) + ] + + dist.all_gather(out_tensor_list, tensor, group=group) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + rank = dist.get_rank(group=ctx.group) + gx = torch.empty_like(grad_outputs[rank]) + gx = _Reduce_Scatter.apply(ReduceOp.SUM, ctx.group, gx, *grad_outputs) + else: + # As many backends doesn't support ReduceScatter, we use AlltoAll with .sum() + # to emulate the ReduceScatter behavior + tensor_list = [torch.empty_like(tensor) for tensor in grad_outputs] + gxs = _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + gx = torch.sum(torch.stack(gxs), dim=0) + return (None, gx) + + +class _AllGatherBase(Function): + @staticmethod + def forward(ctx, output_tensor, input_tensor, group): + ctx.group = group + dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group) + return output_tensor + + @staticmethod + def backward(ctx, grad_output): + if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: + world_size = dist.get_world_size(group=ctx.group) + out_size = list(grad_output.size()) + if out_size[0] % world_size != 0: + raise RuntimeError( + f"Tensor with dimensions: {out_size} does " + f"not have first dimension divisible by world_size: {world_size}" + ) + out_size[0] = out_size[0] // dist.get_world_size(group=ctx.group) + gx = torch.empty( + out_size, device=grad_output.device, dtype=grad_output.dtype + ) + dist._reduce_scatter_base(gx, grad_output, ReduceOp.SUM, ctx.group) + else: + raise RuntimeError("Backend not supported!") + return (None, gx, None) + + +class _AlltoAll(Function): + @staticmethod + def forward(ctx, group, out_tensor_list, *tensors): + ctx.group = group + ctx.input_tensor_size_list = [ + tensors[i].size() for i in range(dist.get_world_size(group=group)) + ] + my_rank = dist.get_rank(group=group) + tensors = tuple(t.contiguous() for t in tensors) + # Implement it on means of scatter/gather, send/recv async operations have issues + if dist.get_backend(group=group) is dist.Backend.GLOO: + for i in range(dist.get_world_size(group=group)): + to_send = None + if i == my_rank: + to_send = list(tensors) + dist.scatter(out_tensor_list[i], to_send, i, group=group) + else: + dist.all_to_all( + out_tensor_list, + list(tensors), + group=group, + ) + return tuple(out_tensor_list) + + @staticmethod + def backward(ctx, *grad_outputs): + tensor_list = [ + torch.empty( + size, device=grad_outputs[0].device, dtype=grad_outputs[0].dtype + ) + for size in ctx.input_tensor_size_list + ] + return (None, None) + _AlltoAll.apply(ctx.group, tensor_list, *grad_outputs) + + +class _AlltoAllSingle(Function): + @staticmethod + def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): + ctx.group = group + ctx.input_size = input.size() + ctx.output_split_sizes = input_split_sizes + ctx.input_split_sizes = output_split_sizes + dist.all_to_all_single( + output, + input, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=group, + ) + return output + + @staticmethod + def backward(ctx, grad_output): + tensor = torch.empty( + ctx.input_size, device=grad_output.device, dtype=grad_output.dtype + ) + return (None, None, None, None) + ( + _AlltoAllSingle.apply( + ctx.group, + tensor, + ctx.output_split_sizes, + ctx.input_split_sizes, + grad_output.contiguous(), + ), + ) + + +class _AllReduce(Function): + @staticmethod + def forward(ctx, op, group, tensor): + ctx.group = group + ctx.op = op + tensor = tensor.clone(memory_format=torch.contiguous_format) + dist.all_reduce(tensor, op=op, group=group) + return tensor + + @staticmethod + def backward(ctx, grad_output): + return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/__init__.py b/phivenv/Lib/site-packages/torch/distributed/nn/jit/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de89a0f70b2b8bb772206085fd86d915d70d6cfb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/jit/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..744d38af1d4a74e107056e31dd1ccb7c3354850a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/jit/__pycache__/instantiator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/instantiator.py b/phivenv/Lib/site-packages/torch/distributed/nn/jit/instantiator.py new file mode 100644 index 0000000000000000000000000000000000000000..e4746fa2ac22ddfca5ca9f1a535c359dd67aa771 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/nn/jit/instantiator.py @@ -0,0 +1,156 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs +import atexit +import importlib +import logging +import os +import sys +import tempfile +from typing import Optional + +import torch +from torch.distributed.nn.jit.templates.remote_module_template import ( + get_remote_module_template, +) + + +logger = logging.getLogger(__name__) + + +_FILE_PREFIX = "_remote_module_" +_TEMP_DIR = tempfile.TemporaryDirectory() +INSTANTIATED_TEMPLATE_DIR_PATH = _TEMP_DIR.name +atexit.register(_TEMP_DIR.cleanup) +logger.info("Created a temporary directory at %s", INSTANTIATED_TEMPLATE_DIR_PATH) +sys.path.append(INSTANTIATED_TEMPLATE_DIR_PATH) + + +def get_arg_return_types_from_interface(module_interface): + assert getattr(module_interface, "__torch_script_interface__", False), ( + "Expect a TorchScript class interface decorated by @torch.jit.interface." + ) + qualified_name = torch._jit_internal._qualified_name(module_interface) + cu = torch.jit._state._python_cu + module_interface_c = cu.get_interface(qualified_name) + assert "forward" in module_interface_c.getMethodNames(), ( + f"Expect forward in interface methods, while it has {module_interface_c.getMethodNames()}" + ) + method_schema = module_interface_c.getMethod("forward") + + arg_str_list = [] + arg_type_str_list = [] + assert method_schema is not None + for argument in method_schema.arguments: + arg_str_list.append(argument.name) + + if argument.has_default_value(): + default_value_str = f" = {argument.default_value}" + else: + default_value_str = "" + arg_type_str = f"{argument.name}: {argument.type}{default_value_str}" + arg_type_str_list.append(arg_type_str) + + arg_str_list = arg_str_list[1:] # Remove "self". + args_str = ", ".join(arg_str_list) + + arg_type_str_list = arg_type_str_list[1:] # Remove "self". + arg_types_str = ", ".join(arg_type_str_list) + + assert len(method_schema.returns) == 1 + argument = method_schema.returns[0] + return_type_str = str(argument.type) + + return args_str, arg_types_str, return_type_str + + +def _write(out_path, text): + old_text: Optional[str] + try: + with open(out_path) as f: + old_text = f.read() + except OSError: + old_text = None + if old_text != text: + with open(out_path, "w") as f: + logger.info("Writing %s", out_path) + f.write(text) + else: + logger.info("Skipped writing %s", out_path) + + +def _do_instantiate_remote_module_template( + generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda +): + generated_code_text = get_remote_module_template( + enable_moving_cpu_tensors_to_cuda + ).format(**str_dict) + out_path = os.path.join( + INSTANTIATED_TEMPLATE_DIR_PATH, f"{generated_module_name}.py" + ) + _write(out_path, generated_code_text) + + # From importlib doc, + # > If you are dynamically importing a module that was created since + # the interpreter began execution (e.g., created a Python source file), + # you may need to call invalidate_caches() in order for the new module + # to be noticed by the import system. + importlib.invalidate_caches() + generated_module = importlib.import_module(f"{generated_module_name}") + return generated_module + + +def instantiate_scriptable_remote_module_template( + module_interface_cls, enable_moving_cpu_tensors_to_cuda=True +): + if not getattr(module_interface_cls, "__torch_script_interface__", False): + raise ValueError( + f"module_interface_cls {module_interface_cls} must be a type object decorated by " + "@torch.jit.interface" + ) + + # Generate the template instance name. + module_interface_cls_name = torch._jit_internal._qualified_name( + module_interface_cls + ).replace(".", "_") + generated_module_name = f"{_FILE_PREFIX}{module_interface_cls_name}" + + # Generate type annotation strs. + assign_module_interface_cls_str = ( + f"from {module_interface_cls.__module__} import " + f"{module_interface_cls.__name__} as module_interface_cls" + ) + args_str, arg_types_str, return_type_str = get_arg_return_types_from_interface( + module_interface_cls + ) + kwargs_str = "" + arrow_and_return_type_str = f" -> {return_type_str}" + arrow_and_future_return_type_str = f" -> Future[{return_type_str}]" + + str_dict = dict( + assign_module_interface_cls=assign_module_interface_cls_str, + arg_types=arg_types_str, + arrow_and_return_type=arrow_and_return_type_str, + arrow_and_future_return_type=arrow_and_future_return_type_str, + args=args_str, + kwargs=kwargs_str, + jit_script_decorator="@torch.jit.script", + ) + return _do_instantiate_remote_module_template( + generated_module_name, str_dict, enable_moving_cpu_tensors_to_cuda + ) + + +def instantiate_non_scriptable_remote_module_template(): + generated_module_name = f"{_FILE_PREFIX}non_scriptable" + str_dict = dict( + assign_module_interface_cls="module_interface_cls = None", + args="*args", + kwargs="**kwargs", + arg_types="*args, **kwargs", + arrow_and_return_type="", + arrow_and_future_return_type="", + jit_script_decorator="", + ) + # For a non-scriptable template, always enable moving CPU tensors to a cuda device, + # because there is no syntax limitation on the extra handling caused by the script. + return _do_instantiate_remote_module_template(generated_module_name, str_dict, True) diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__init__.py b/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bdd0538b93bb66877452ce0428607a8e3ca14484 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f5c4b4fe4554a68c3e8fc938a465de98ce27464 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/__pycache__/remote_module_template.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py b/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py new file mode 100644 index 0000000000000000000000000000000000000000..42e4f1279769bedd46c33337f8d91b4ec2ecfa9f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/nn/jit/templates/remote_module_template.py @@ -0,0 +1,108 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs + + +def get_remote_module_template(enable_moving_cpu_tensors_to_cuda: bool): + return _TEMPLATE_PREFIX + ( + _REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA + if enable_moving_cpu_tensors_to_cuda + else _REMOTE_FORWARD_TEMPLATE + ) + + +_TEMPLATE_PREFIX = """from typing import * + +import torch +import torch.distributed.rpc as rpc +from torch import Tensor +from torch._jit_internal import Future +from torch.distributed.rpc import RRef +from typing import Tuple # pyre-ignore: unused import + + +{assign_module_interface_cls} + + +def forward_async(self, {arg_types}){arrow_and_future_return_type}: + args = (self.module_rref, self.device, self.is_device_map_set, {args}) + kwargs = {{{kwargs}}} + return rpc.rpc_async( + self.module_rref.owner(), + _remote_forward, + args, + kwargs, + ) + + +def forward(self, {arg_types}){arrow_and_return_type}: + args = (self.module_rref, self.device, self.is_device_map_set, {args}) + kwargs = {{{kwargs}}} + ret_fut = rpc.rpc_async( + self.module_rref.owner(), + _remote_forward, + args, + kwargs, + ) + return ret_fut.wait() + + +_generated_methods = [ + forward_async, + forward, +] + + +{jit_script_decorator} +""" + +# This template may cause typing error (the mismatch between ``Tuple[()]`` and ``Tuple[Any]``) +# even if the code is only used for instantiation but not execution. +# Therefore, only include handling moving CPU tensors to a cuda device if necessary. +# TODO: Merge these two templates together in the future once TorchScript syntax is improved. +_REMOTE_FORWARD_TEMPLATE_ENABLE_MOVING_CPU_TENSORS_TO_CUDA = """ +def _remote_forward( + module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}: + module = module_rref.local_value() + device = torch.device(device) + + if device.type != "cuda": + return module.forward({args}, {kwargs}) + + # If the module is on a cuda device, + # move any CPU tensor in args or kwargs to the same cuda device. + # Since torch script does not support generator expression, + # have to use concatenation instead of + # ``tuple(i.to(device) if isinstance(i, Tensor) else i for i in *args)``. + args = ({args},) + out_args: Tuple[()] = () + for arg in args: + arg = (arg.to(device),) if isinstance(arg, Tensor) else (arg,) + out_args = out_args + arg + + kwargs = {{{kwargs}}} + for k, v in kwargs.items(): + if isinstance(v, Tensor): + kwargs[k] = kwargs[k].to(device) + + if is_device_map_set: + return module.forward(*out_args, {kwargs}) + + # If the device map is empty, then only CPU tensors are allowed to send over wire, + # so have to move any GPU tensor to CPU in the output. + # Since torch script does not support generator expression, + # have to use concatenation instead of + # ``tuple(i.cpu() if isinstance(i, Tensor) else i for i in module.forward(*out_args, {kwargs}))``. + ret: Tuple[()] = () + for i in module.forward(*out_args, {kwargs}): + i = (i.cpu(),) if isinstance(i, Tensor) else (i,) + ret = ret + i + return ret +""" + +_REMOTE_FORWARD_TEMPLATE = """ +def _remote_forward( + module_rref: RRef[module_interface_cls], device: str, is_device_map_set: bool, {arg_types}){arrow_and_return_type}: + module = module_rref.local_value() + + return module.forward({args}, {kwargs}) +""" diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__init__.py b/phivenv/Lib/site-packages/torch/distributed/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dd9c263a2dcf50f418d3d1fe517b0256c296e8b2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/__init__.py @@ -0,0 +1,44 @@ +""" +:mod:`torch.distributed.optim` exposes DistributedOptimizer, which takes a list +of remote parameters (:class:`~torch.distributed.rpc.RRef`) and runs the +optimizer locally on the workers where the parameters live. The distributed +optimizer can use any of the local optimizer :ref:`optimizer-algorithms` to +apply the gradients on each worker. +""" + +import warnings + +import torch +from torch import optim + +from .apply_optimizer_in_backward import ( + _apply_optimizer_in_backward, + _get_in_backward_optimizers, +) +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD +from .named_optimizer import _NamedOptimizer +from .utils import as_functional_optim + + +# DistributedOptimizer imports torch.distributed.rpc names, so gate availability +# based on RPC being available. +if hasattr(torch._C, "_rpc_init"): + from .optimizer import DistributedOptimizer + +from .post_localSGD_optimizer import PostLocalSGDOptimizer +from .zero_redundancy_optimizer import ZeroRedundancyOptimizer + + +__all__ = [ + "as_functional_optim", + "DistributedOptimizer", + "PostLocalSGDOptimizer", + "ZeroRedundancyOptimizer", +] diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eccb65928491e3984458a74708f005805493c907 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/_deprecation_warning.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/_deprecation_warning.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6da06c3394e5201b96d13cf23acf5691d8ade2cb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/_deprecation_warning.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1393ca24ce9c2274e271302701e8c1e495aaa515 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/apply_optimizer_in_backward.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f0962a12aa363adab88dc4b5b70df0f566c5bc5e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adadelta.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc9ce673426f107e68f35bdf9e9ee52ad80ba9dd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adagrad.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a93a871ef70a0bfd2bece267b0b05417601134df Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adam.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95338f62fa8c4777a2154cba8f0b23bb1c684901 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamax.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12ae9bb810e92d3586dc86e379d532b05f0bfd61 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_adamw.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ab19cab596602958dc9c521aa058cf2fdda7814 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rmsprop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcbfb3bc68c3435be662662d82ac321d4427a859 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_rprop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffe3f6b56e515ca8d8f3d9f6553a5ad0c8fcba33 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/functional_sgd.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..843e472e4ca4c6e2c52c0519608bb17f44937c2d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/named_optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf5c7dbd0ae3cb2751136b655560bf34a8ad8066 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..553446f2c2ba68303c2c86a0cdf4132b00aabb97 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/post_localSGD_optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..946c3469ec567ed676caf456dbd1d17e18ed9c0b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6b4be175691fa709f621134495e9878034c6aeb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/optim/__pycache__/zero_redundancy_optimizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/_deprecation_warning.py b/phivenv/Lib/site-packages/torch/distributed/optim/_deprecation_warning.py new file mode 100644 index 0000000000000000000000000000000000000000..d8424059c1e4068a1606eaf5bcbdfe7a820a8756 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/_deprecation_warning.py @@ -0,0 +1,16 @@ +import warnings + +import torch + + +@torch.jit.ignore # type: ignore[misc] +def _scripted_functional_optimizer_deprecation_warning(stacklevel: int = 0) -> None: + with warnings.catch_warnings(): + warnings.simplefilter("always") + warnings.warn( + "`TorchScript` support for functional optimizers is deprecated " + "and will be removed in a future PyTorch release. " + "Consider using the `torch.compile` optimizer instead.", + DeprecationWarning, + stacklevel=stacklevel + 2, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py b/phivenv/Lib/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..6d09a1448ad600bc5238dd47853ad993f2eb9e87 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/apply_optimizer_in_backward.py @@ -0,0 +1,121 @@ +from collections.abc import Iterable +from typing import Any, no_type_check + +import torch + + +__all__: list[str] = [] + +# WeakTensorKeyDictionary to store relevant meta-data for the Tensor/Parameter +# without changing it's life-time. +# NOTE: Alternative is to add the meta-data as an attribute to the tensor, +# but that will serialize the meta-data if Tensor is serialized. +param_to_optim_hook_handle_map = torch.utils.weak.WeakTensorKeyDictionary() +param_to_acc_grad_map = torch.utils.weak.WeakTensorKeyDictionary() + + +@no_type_check +def _apply_optimizer_in_backward( + optimizer_class: type[torch.optim.Optimizer], + params: Iterable[torch.nn.Parameter], + optimizer_kwargs: dict[str, Any], + register_hook: bool = True, +) -> None: + """ + Upon ``backward()``, the optimizer specified for each parameter will fire after + the gradient has been accumulated into the parameter. + + Note - gradients for these parameters will be set to None after ``backward()``. + This means that any other optimizer not specified via `_apply_optimizer_in_backward` + over this parameter will be a no-op. + + Args: + optimizer_class: (Type[torch.optim.Optimizer]): Optimizer to apply to parameter + params: (Iterator[nn.Parameter]): parameters to apply optimizer state to + optimizer_kwargs: (Dict[str, Any]): kwargs to pass to optimizer constructor + register_hook: (bool): whether to register a hook that runs the optimizer + after gradient for this parameter is accumulated. This is the default + way that optimizer in backward is implemented, but specific use cases + (such as DDP) may wish to override this to implement custom behavior. + (Default = True) + + Example:: + params_generator = model.parameters() + param_1 = next(params_generator) + remainder_params = list(params_generator) + + apply_optimizer_in_backward(torch.optim.SGD, [param_1], {"lr": 0.02}) + apply_optimizer_in_backward(torch.optim.Adam, remainder_params, {"lr": 0.04}) + + model(...).sum().backward() # after backward, parameters will already + # have their registered optimizer(s) applied. + + """ + torch._C._log_api_usage_once("torch.distributed.optim.apply_optimizer_in_backward") + + @no_type_check + def _apply_optimizer_in_backward_to_param(param: torch.nn.Parameter) -> None: + # view_as creates a node in autograd graph that allows us access to the + # parameter's AccumulateGrad autograd function object. We register a + # hook on this object to fire the optimizer when the gradient for + # this parameter is ready (has been accumulated into .grad field) + + # Don't create a new acc_grad if we already have one + # i.e. for shared parameters or attaching multiple optimizers to a param. + if param not in param_to_acc_grad_map: + param_to_acc_grad_map[param] = param.view_as(param).grad_fn.next_functions[ + 0 + ][0] + + optimizer = optimizer_class([param], **optimizer_kwargs) + + if not hasattr(param, "_in_backward_optimizers"): + param._in_backward_optimizers = [] # type: ignore[attr-defined] + # TODO: Remove these attributes once we have a better way of accessing + # optimizer classes and kwargs for a parameter. + param._optimizer_classes = [] # type: ignore[attr-defined] + param._optimizer_kwargs = [] # type: ignore[attr-defined] + + param._in_backward_optimizers.append(optimizer) # type: ignore[attr-defined] + param._optimizer_classes.append(optimizer_class) # type: ignore[attr-defined] + param._optimizer_kwargs.append(optimizer_kwargs) # type: ignore[attr-defined] + + if not register_hook: + return + + def optimizer_hook(*_unused) -> None: + for opt in param._in_backward_optimizers: # type: ignore[attr-defined] + opt.step() + + param.grad = None + + handle = param_to_acc_grad_map[param].register_hook(optimizer_hook) # type: ignore[attr-defined] + if param not in param_to_optim_hook_handle_map: + param_to_optim_hook_handle_map[param] = [] + param_to_optim_hook_handle_map[param].append(handle) + + for param in params: + _apply_optimizer_in_backward_to_param(param) + + +def _get_in_backward_optimizers(module: torch.nn.Module) -> list[torch.optim.Optimizer]: + """ + Return a list of in-backward optimizers applied to ``module``'s parameters. Note that these + optimizers are not intended to directly have their ``step`` or ``zero_grad`` methods called + by the user and are intended to be used for things like checkpointing. + + Args: + module: (torch.nn.Module): model to retrieve in-backward optimizers for + + Returns: + List[torch.optim.Optimizer]: the in-backward optimizers. + + Example:: + _apply_optimizer_in_backward(torch.optim.SGD, model.parameters(), {"lr": 0.01}) + optims = _get_optimizers_in_backward(model) + """ + optims: list[torch.optim.Optimizer] = [] + for param in module.parameters(): + optims.extend(getattr(param, "_in_backward_optimizers", [])) + + return optims diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_adadelta.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adadelta.py new file mode 100644 index 0000000000000000000000000000000000000000..c8fba5e0b7e7c03f5881de1a046d250344d14e84 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adadelta.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adadelta Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdadelta: + def __init__( + self, + params: list[Tensor], + lr: float = 1.0, + rho: float = 0.9, + eps: float = 1e-6, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "rho": rho, + "eps": eps, + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + acc_deltas = [] + state_steps = [] + lr = self.defaults["lr"] + rho = self.defaults["rho"] + eps = self.defaults["eps"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["square_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["acc_delta"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + acc_deltas.append(state["acc_delta"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adadelta( + params_with_grad, + grads, + square_avgs, + acc_deltas, + state_steps, + lr=lr, + rho=rho, + eps=eps, + weight_decay=weight_decay, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_adagrad.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..28bd9e374468e91c303ec913197095064a208240 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adagrad.py @@ -0,0 +1,115 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adagrad Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly let the user pass gradients to the `step` function +# this is so that we could separate the gradients and parameters +# and allow multithreaded trainer to update the parameters +# without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdagrad: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + lr_decay: float = 0.0, + weight_decay: float = 0.0, + initial_accumulator_value: float = 0.0, + warmup_lr_multiplier: float = 1.0, + warmup_num_iters: float = 0.0, + eps: float = 1e-10, + coalesce_grad: bool = True, + foreach: bool = False, + fused: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "lr_decay": lr_decay, + "eps": eps, + "weight_decay": weight_decay, + "initial_accumulator_value": initial_accumulator_value, + "warmup_lr_multiplier": warmup_lr_multiplier, + "warmup_num_iters": warmup_num_iters, + } + self.coalesce_grad = coalesce_grad + self.foreach = foreach + self.fused = fused + self.maximize = maximize + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + # TODO: no union or any types in TorchScript, make step a scalar tensor instead + # This is also needed by if we want to share_memory on the step across processes + for p in self.param_group["params"]: + self.state[p] = { + "sum": torch.full_like(p.data, initial_accumulator_value), + "step": torch.tensor(0.0), + } + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + state_sums = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad, has_complex = False, False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_sparse_grad |= gradient.is_sparse + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + state = self.state[param] + state_sums.append(state["sum"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adagrad( + params, + grads, + state_sums, + state_steps, + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + lr_decay=self.defaults["lr_decay"], + eps=self.defaults["eps"], + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_adam.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..0c2e9cb40c2c29363a6794a0b4287230f78d4092 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adam.py @@ -0,0 +1,202 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adam Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdam: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Optional[Tensor]): + """ + Similar to step, but operates on a single parameter and optionally a + gradient tensor. + """ + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = torch.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with torch.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = False + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adam( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + has_complex=has_complex, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_adamax.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adamax.py new file mode 100644 index 0000000000000000000000000000000000000000..eb423973925bde76b253c1baba87295d943364f5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adamax.py @@ -0,0 +1,123 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Adamax Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdamax: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 0.0, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.foreach = foreach + self.maximize = maximize + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_infs = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_inf"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_infs.append(state["exp_inf"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adamax( + params_with_grad, + grads, + exp_avgs, + exp_infs, + state_steps, + eps=self.defaults["eps"], + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_adamw.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adamw.py new file mode 100644 index 0000000000000000000000000000000000000000..7368fb3053c846c505e76ab5e0e82728256ca3a5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_adamw.py @@ -0,0 +1,203 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional AdamW Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalAdamW: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-3, + betas: tuple[float, float] = (0.9, 0.999), + eps: float = 1e-8, + weight_decay: float = 1e-2, + amsgrad: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + if not 0.0 <= lr: + raise ValueError(f"Invalid learning rate: {lr}") + if not 0.0 <= eps: + raise ValueError(f"Invalid epsilon value: {eps}") + if not 0.0 <= betas[0] < 1.0: + raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}") + if not 0.0 <= betas[1] < 1.0: + raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}") + if not 0.0 <= weight_decay: + raise ValueError(f"Invalid weight_decay value: {weight_decay}") + + self.defaults = { + "lr": lr, + "eps": eps, + "beta1": betas[0], + "beta2": betas[1], + "weight_decay": weight_decay, + } + self.amsgrad = amsgrad + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Optional[Tensor]): + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + has_complex = torch.is_complex(param) + if grad is not None: + params_with_grad.append(param) + grads.append(grad) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + with torch.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + exp_avgs = [] + exp_avg_sqs = [] + max_exp_avg_sqs = [] + state_steps: list[Tensor] = [] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(self.param_group["params"], gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + # Exponential moving average of gradient values + state["exp_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + # Exponential moving average of squared gradient values + state["exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.amsgrad: + # Maintains max of all exp. moving avg. of sq. grad. values + state["max_exp_avg_sq"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + + exp_avgs.append(state["exp_avg"]) + exp_avg_sqs.append(state["exp_avg_sq"]) + + if self.amsgrad: + max_exp_avg_sqs.append(state["max_exp_avg_sq"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.adamw( + params_with_grad, + grads, + exp_avgs, + exp_avg_sqs, + max_exp_avg_sqs, + state_steps, + amsgrad=self.amsgrad, + maximize=self.maximize, + beta1=self.defaults["beta1"], + beta2=self.defaults["beta2"], + lr=self.defaults["lr"], + weight_decay=self.defaults["weight_decay"], + eps=self.defaults["eps"], + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_rmsprop.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_rmsprop.py new file mode 100644 index 0000000000000000000000000000000000000000..f8a7489bcf04126fcea31faef1b1a4fb4900f433 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_rmsprop.py @@ -0,0 +1,130 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional RMSprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalRMSprop: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + alpha: float = 0.99, + eps: float = 1e-8, + weight_decay: float = 0.0, + momentum: float = 0.0, + centered: bool = False, + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "alpha": alpha, + "eps": eps, + "weight_decay": weight_decay, + "momentum": momentum, + } + self.centered = centered + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + square_avgs = [] + grad_avgs = [] + momentum_buffer_list = [] + state_steps = [] + lr = self.defaults["lr"] + alpha = self.defaults["alpha"] + eps = self.defaults["eps"] + momentum = self.defaults["momentum"] + weight_decay = self.defaults["weight_decay"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["square_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if momentum > 0: + state["momentum_buffer"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + if self.centered: + state["grad_avg"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + + state = self.state[param] + square_avgs.append(state["square_avg"]) + if momentum > 0: + momentum_buffer_list.append(state["momentum_buffer"]) + if self.centered: + grad_avgs.append(state["grad_avg"]) + + state_steps.append(state["step"]) + + with torch.no_grad(): + F.rmsprop( + params_with_grad, + grads, + square_avgs, + grad_avgs, + momentum_buffer_list, + state_steps, + lr=lr, + alpha=alpha, + eps=eps, + weight_decay=weight_decay, + momentum=momentum, + centered=self.centered, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_rprop.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_rprop.py new file mode 100644 index 0000000000000000000000000000000000000000..5839efeb7c019a2ee8db475c65a70e68617520aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_rprop.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional Rprop Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalRprop: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + etas: tuple[float, float] = (0.5, 1.2), + step_sizes: tuple[float, float] = (1e-6, 50), + foreach: bool = False, + maximize: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + } + self.etas = etas + self.step_sizes = step_sizes + self.foreach = foreach + self.maximize = maximize + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + prevs = [] + step_sizes = [] + state_steps = [] + lr = self.defaults["lr"] + etaminus, etaplus = self.etas + step_size_min, step_size_max = self.step_sizes + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_complex = False + for param, gradient in zip(params, gradients): + if gradient is not None: + has_complex |= torch.is_complex(param) + params_with_grad.append(param) + grads.append(gradient) + # Lazy state initialization + if param not in self.state: + self.state[param] = {} + state = self.state[param] + state["step"] = torch.tensor(0.0) + state["prev"] = torch.zeros_like( + param, memory_format=torch.preserve_format + ) + state["step_size"] = torch.full_like(gradient, lr) + + state = self.state[param] + prevs.append(state["prev"]) + step_sizes.append(state["step_size"]) + state_steps.append(state["step"]) + + with torch.no_grad(): + F.rprop( + params_with_grad, + grads, + prevs, + step_sizes, + state_steps, + step_size_min=step_size_min, + step_size_max=step_size_max, + etaminus=etaminus, + etaplus=etaplus, + foreach=self.foreach, + maximize=self.maximize, + has_complex=has_complex, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/functional_sgd.py b/phivenv/Lib/site-packages/torch/distributed/optim/functional_sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..5aa98bcaa9c0db40c036fbac27da323ec2679d32 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/functional_sgd.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +import torch.optim._functional as F +from torch import Tensor +from torch.distributed.optim._deprecation_warning import ( + _scripted_functional_optimizer_deprecation_warning, +) + + +__all__: list[str] = [] + + +# Define a TorchScript compatible Functional SGD Optimizer +# where we use these optimizer in a functional way. +# Instead of using the `param.grad` when updating parameters, +# we explicitly allow the distributed optimizer pass gradients to +# the `step` function. In this way, we could separate the gradients +# and parameters and allow multithreaded trainer to update the +# parameters without data traces on accumulating to the same .grad. +# NOTE: This should be only used by distributed optimizer internals +# and not meant to expose to the user. +@torch.jit.script +class _FunctionalSGD: + def __init__( + self, + params: list[Tensor], + lr: float = 1e-2, + momentum: float = 0.0, + dampening: float = 0.0, + weight_decay: float = 0.0, + nesterov: bool = False, + maximize: bool = False, + foreach: bool = False, + fused: bool = False, + _allow_empty_param_list: bool = False, + ): + _scripted_functional_optimizer_deprecation_warning(stacklevel=2) + self.defaults = { + "lr": lr, + "momentum": momentum, + "dampening": dampening, + "weight_decay": weight_decay, + } + self.nesterov = nesterov + self.maximize = maximize + self.foreach = foreach + self.fused = fused + self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {}) + + if len(params) == 0 and not _allow_empty_param_list: + raise ValueError("optimizer got an empty parameter list") + + # NOTE: we only have one param_group and don't allow user to add additional + # param group as it's not a common use case. + self.param_group = {"params": params} + + def step_param(self, param: Tensor, grad: Optional[Tensor]): + """Similar to self.step, but operates on a single parameter and + its gradient. + """ + # TODO: Once step_param interface is robust, refactor step to call + # step param on each param. + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + lr = self.defaults["lr"] + params = [param] + momentum_buffer_list: list[Optional[Tensor]] = [] + grads = [] + + has_sparse_grad = False + if grad is not None: + grads.append(grad) + if grad.is_sparse: + has_sparse_grad = True + if param not in self.state: + self.state[param] = {} + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with torch.no_grad(): + F.sgd( + params, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + # update momentum_buffer in state + state = self.state[param] + momentum_buffer = momentum_buffer_list[0] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer + + def step(self, gradients: list[Optional[Tensor]]): + params = self.param_group["params"] + params_with_grad = [] + grads = [] + momentum_buffer_list: list[Optional[Tensor]] = [] + lr = self.defaults["lr"] + weight_decay = self.defaults["weight_decay"] + momentum = self.defaults["momentum"] + dampening = self.defaults["dampening"] + + if len(params) != len(gradients): + raise ValueError( + "the gradients passed in does not equal to the size of the parameters!" + + f"Params length: {len(params)}. " + + f"Gradients length: {len(gradients)}" + ) + + has_sparse_grad = False + for param, gradient in zip(params, gradients): + if gradient is not None: + params_with_grad.append(param) + grads.append(gradient) + if gradient.is_sparse: + has_sparse_grad = True + + if param not in self.state: + self.state[param] = {} + + state = self.state[param] + if "momentum_buffer" not in state: + momentum_buffer_list.append(None) + else: + momentum_buffer_list.append(state["momentum_buffer"]) + + with torch.no_grad(): + F.sgd( + params_with_grad, + grads, + momentum_buffer_list, + weight_decay=weight_decay, + momentum=momentum, + lr=lr, + dampening=dampening, + nesterov=self.nesterov, + maximize=self.maximize, + has_sparse_grad=has_sparse_grad, + foreach=self.foreach, + fused=self.fused, + grad_scale=None, + found_inf=None, + ) + + # update momentum_buffers in state + for i, p in enumerate(params_with_grad): + state = self.state[p] + momentum_buffer = momentum_buffer_list[i] + if momentum_buffer is not None: + state["momentum_buffer"] = momentum_buffer diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/named_optimizer.py b/phivenv/Lib/site-packages/torch/distributed/optim/named_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..baee65d61cbe7247b01e471896ade2818ecd36a7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/named_optimizer.py @@ -0,0 +1,327 @@ +import logging +import warnings +from collections.abc import Collection, Mapping +from copy import deepcopy +from typing import Any, Callable, Optional, overload, Union + +import torch +import torch.nn as nn +from torch import optim +from torch.distributed._shard.sharded_tensor import ShardedTensor +from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + + +__all__: list[str] = [] + +logger = logging.getLogger(__name__) + + +class _NamedOptimizer(optim.Optimizer): + """ + ``_NamedOptimizer`` takes a dict of parameters and exposes ``state_dict`` by parameter key. + + We replace the original key (number) in an optim to the + fully qualified name (FQN) string. User can initialize the optim as they + initialize a PyTorch optim, the only difference is that they also need to + pass in the FQN of each parameters. + + Args: + named_parameters (Mapping[str, Union[torch.Tensor, ShardedTensor]]): + Mapping from FQN to parameter. + optimizer_class (optim.Optimizer): + The class of optimizer to instantiate. + param_groups (Collection[Mapping[str, Any]]): + `param_groups` to pass to optimizer if specified. + The key of the inner map needs to be FQNs. + Default: None + module (nn.Module): the module whose parameters to updated + by the optimizer. + args: arguments to pass to the optimizer constructor. + kwargs: arguments to pass to the optimizer constructor. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from torch import optim + >>> from torch.distributed.optim import _NamedOptimizer + >>> + >>> # Define the named optimizer. + >>> m = Model(...) + >>> named_optim = _NamedOptimizer(m.named_parameters(), optim.SGD) + >>> # Forward pass + backward pass. + >>> named_optim.step() + >>> ... + >>> # Call state_dict for the named optimizer returns a FQN state_dict. + >>> named_optim.state_dict() + + Warning: This API is still in development and subject to change. + + TODO: Add tutorial for _NamedOptimizer. + TODO: Add documentation in the docstring for the public attributes + like self.param_groups and self.named_parameters. + """ + + def __init__( + self, + named_parameters: Mapping[str, Union[torch.Tensor, ShardedTensor]], + optimizer_class: optim.Optimizer, + param_groups: Optional[Collection[Mapping[str, Any]]] = None, + module: Optional[nn.Module] = None, + *args: tuple[Any, ...], + **kwargs: dict[str, Any], + ) -> None: + torch._C._log_api_usage_once("torch.distributed.optim._NamedOptimizer") + self.param_groups: Collection[Mapping[str, Any]] = param_groups # type: ignore[assignment] + self._param_groups_check() + self.named_parameters = dict(named_parameters) + params_for_optimizer = ( + self.named_parameters.values() if param_groups is None else param_groups + ) + self._optimizer = optimizer_class( # type: ignore[operator] + params_for_optimizer, + *args, + **kwargs, + ) + self.module = module + if param_groups is None: + self.ordered_param_keys = list(self.named_parameters.keys()) + else: + warnings.warn( + "Since we pass in param_groups, we will use param_groups to " + "initialize the optimizer, not all parameters of the module." + ) + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + ordered_param_keys = [] + for group in param_groups: + for param in group["params"]: + if param not in param_to_key: + raise ValueError( + f"Expect param name {param} found in param group but is missing." + ) + ordered_param_keys.append(param_to_key[param]) + self.ordered_param_keys = ordered_param_keys + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def _param_groups_check(self) -> None: + if self.param_groups is not None: + for param_group in self.param_groups: + assert isinstance(param_group, dict), "param group must be a dict" + assert "params" in param_group, "param group must contain key params" + params = param_group["params"] + if isinstance(params, torch.Tensor): + params = [params] + params = list(params) + for param in params: + if not isinstance(param, torch.Tensor): + raise TypeError( + "optimizer can only optimize Tensors, " + "but one of the params is " + torch.typename(param) + ) + param_group["params"] = params + + def state_dict(self) -> dict[str, Any]: + """ + Return the ``state_dict`` of the optimizer. + + Instead of using number to index + parameters, we will use module fully qualified name (FQN) as the key. + """ + state_dict = self._optimizer.state_dict() + param_groups = state_dict["param_groups"] + + ret_state = { + self.ordered_param_keys[st_key]: state_val + for st_key, state_val in state_dict["state"].items() + } + + ret_groups = [] + for group in param_groups: + param_keys = [self.ordered_param_keys[param] for param in group["params"]] + ret_group = {"params": sorted(param_keys)} + for k, v in group.items(): + if k != "params": + ret_group[k] = deepcopy(v) + ret_groups.append(ret_group) + + return self._post_state_dict({"state": ret_state, "param_groups": ret_groups}) + + @overload + def step(self, closure: None = None) -> None: ... + + @overload + def step(self, closure: Callable[[], float]) -> float: ... + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + """ + Perform a single optimization step. + + This will call :meth:`torch.optim.Optimizer.step` on the wrapped + optimizer. + """ + return self._optimizer.step(closure=closure) + + @property + def state(self) -> Mapping[torch.Tensor, Any]: # type: ignore[override] + return self._optimizer.state + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + """ + Define the default behavior to load a state_dict for ``_NamedOptimizer``. + + Sample Code + ``` + my_model = MyModule() + optimizer = _NamedOptimizer(my_model.named_parameters(), Adagrad) + ... + + optim_state_dict = optimizer.state_dict() + ... + ... + + optimizer.load_state_dict(optim_state_dict) + ... + ``` + Args: + state_dict (dict[str, Any]) : A ``state_dict`` to load into the optimizer. + Note that this state dict update is performed in place. + + .. note:: PyTorch is using lazy init to initialize the optim states. + So it is possible that there is no optim state when user call + ``load_state_dict`` and for ``_NamedOptimizer`` we make it stricter + that users can only call ``load_state_dict`` after the state is initialized. + By doing this, we can validate the optim ``state_dict`` to be loaded. + """ + new_state_dict = self._optimizer.state_dict() + state_dict = self._pre_load_state_dict(state_dict) + state = state_dict["state"] + new_state = new_state_dict["state"] + if len(new_state) == 0: + raise ValueError( + "Expects the optim to be initialized before load but found not initialized." + ) + + for idx, param_key in enumerate(self.ordered_param_keys): + # When the conditional training is performed, not all parameters are updated in the optim. + if param_key not in state.keys(): + continue + if len(state[param_key]) != len(new_state[idx]): + raise ValueError( + f"Expects equal length as {len(new_state[idx])} for parameter {param_key} but found: {len(state[param_key])}" + ) + # Iterate through all optimizer states. + for state_key, state_val in new_state[idx].items(): + if state_key not in state[param_key]: + raise ValueError( + f"Expects state {state_key} for parameter {param_key} but not found." + ) + + src_state_val = state[param_key][state_key] + if isinstance(state_val, ShardedTensor): + assert isinstance(src_state_val, ShardedTensor) + num_shards = len(state_val.local_shards()) + num_new_shards = len(src_state_val.local_shards()) + if num_shards != num_new_shards: + raise ValueError( + f"Expects equal number of shards as {num_new_shards} but found {num_shards} for {param_key}/{state_key}" + ) + for shard, src_shard in zip( + state_val.local_shards(), src_state_val.local_shards() + ): + shard.tensor.detach().copy_(src_shard.tensor) + elif isinstance(state_val, torch.Tensor): + assert isinstance(src_state_val, torch.Tensor) + state_val.detach().copy_(src_state_val) + else: + new_state[idx][state_key] = deepcopy(src_state_val) + + # Load param_groups of state_dict + src_param_groups = state_dict["param_groups"] + new_param_groups = new_state_dict["param_groups"] + + src_group_map = {} + for group in src_param_groups: + param_keys = list(group["params"]) + src_group_map[_gen_param_group_key(param_keys)] = group + new_group_map = {} + for new_group in new_param_groups: + param_keys = [] + for param_key in new_group["params"]: + param_keys.append(self.ordered_param_keys[param_key]) # type: ignore[call-overload] + new_group_map[_gen_param_group_key(param_keys)] = new_group + for group_key, new_group in new_group_map.items(): + # When not all parameters are used in training or receive gradient, aka., not all parameters + # would be in the param_group. Thus we skip the group_key here. + if group_key not in src_group_map: + continue + src_group = src_group_map[group_key] + if len(src_group) != len(new_group): + raise ValueError( + f"Expects equal param_group size as {len(new_group)} for group {group_key} but found {len(src_group)}." + ) + for k in src_group: + if k not in new_group: + raise ValueError( + f"Expects group key {k} to be in group {group_key} in `state_dict` but is missing." + ) + if k != "params": + new_group[k] = deepcopy(src_group[k]) + + self._optimizer.load_state_dict(new_state_dict) + + def add_param_group(self, param_group: Mapping[str, Any]) -> None: + """ + Add a param group to the :class:`_NamedOptimizer` s `param_groups`. + + Warning: This API is still in development and subject to change. + """ + assert isinstance(param_group, dict), "param group must be a dict" + + params = param_group["params"] + if isinstance(params, torch.Tensor): + param_group["params"] = [params] + else: + param_group["params"] = list(params) + + param_to_key = {param: key for key, param in self.named_parameters.items()} # type: ignore[misc, has-type] + for param in param_group["params"]: + if param not in param_to_key: + raise ValueError("some parameters are not in the module") + self.ordered_param_keys.append(param_to_key[param]) + + self._optimizer.add_param_group(param_group) + # Update param_groups from optimizer. + self.param_groups = self._optimizer.param_groups + + def init_state(self) -> None: + """ + Run a dummy optimizer step, which allows to initialize optimizer state because we do lazy init for most optimizers. + + This allows doing in-place loading of optimizer state from a checkpoint. + """ + for param in self.named_parameters.values(): + if param.requires_grad: + t = torch.zeros_like(param) + param.grad = torch.autograd.Variable(t) + # Calling ``step`` will load the initial state for optimizer states. + self.step(closure=None) + + def _pre_load_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + return FSDP.optim_state_dict_to_load( + self.module, self._optimizer, state_dict, is_named_optimizer=True + ) + return state_dict + + def _post_state_dict(self, state_dict: dict[str, Any]) -> dict[str, Any]: + # TODO(chienchin): This API should be FSDP agnostic and should support + # general user hooks. + if isinstance(self.module, FSDP): + FSDP.optim_state_dict(self.module, self._optimizer, state_dict) + return state_dict + + +def _gen_param_group_key(param_keys: list[str]) -> str: + """Concatenate all param keys as a unique identifier for one param group.""" + return "/".join(sorted(param_keys)) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/optimizer.py b/phivenv/Lib/site-packages/torch/distributed/optim/optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c05df90da796eebb207fadd144c756eccdb658a7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/optimizer.py @@ -0,0 +1,255 @@ +# mypy: allow-untyped-defs +import logging +from collections import defaultdict +from threading import Lock +from typing import Optional + +import torch +import torch.distributed.autograd as dist_autograd +import torch.distributed.rpc as rpc +import torch.jit as jit +import torch.nn as nn +from torch import Tensor +from torch.distributed.rpc import RRef + +from .utils import functional_optim_map + + +__all__ = ["DistributedOptimizer"] + +logger = logging.getLogger(__name__) + + +# XXX: we define a _ScriptModuleOptimizer here to explicitly +# compile the FunctionalOptimizer class into TorchScript +# This is because ScriptClass instance still lives in +# python unless you explicitly compile it as an attribute +# in ScriptModule or pass it to a ScriptFunction +# _ScriptLocalOptimizerInterface serves as a common +# interface type for Optimizer ScriptModules. +# +# TODO (wanchaol): remove this once we added TorchScript +# class reference semantics +@jit.interface +class _ScriptLocalOptimizerInterface: + def step(self, autograd_ctx_id: int) -> None: + pass + + +class _ScriptLocalOptimizer(nn.Module): + # TorchScript does not support multithread concurrent compiling. + # request_callback might invoke concurrent compiling, so we + # serialize the compiling with a lock + compile_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + super().__init__() + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + @jit.export + def step(self, autograd_ctx_id: int): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + # apply functional optimizer step with a list of gradients + grads: list[Optional[Tensor]] = [ + all_local_grads[p] if p in all_local_grads else None + for p in self._local_params + ] + + self.optim.step(grads) + + +# TODO (wanchaol): remove/merge this with ScriptLocalOptimizer once +# we have converted all to functional optimizer in distributed.optim +class _LocalOptimizer: + # Ideally we would only need to share a lock for instances of + # _LocalOptimizer that deal with the same parameters. We are + # making a simplifying assumption here that if there is more + # than one instance of _LocalOptimizer per worker, they will + # be optimizing the same parameters (e.g. each data parallel + # trainer will create its own instance of _LocalOptimizer but + # they will all optimize the same parameters on each worker) + global_lock = Lock() + + def __init__(self, optim_cls, local_params_rref, *args, **kwargs): + self._local_params = [rref.local_value() for rref in local_params_rref] + self.optim = optim_cls(self._local_params, *args, **kwargs) + + def step(self, autograd_ctx_id): + all_local_grads = dist_autograd.get_gradients(autograd_ctx_id) + + with _LocalOptimizer.global_lock: + for param, grad in all_local_grads.items(): + param.grad = grad + self.optim.step() + + +def _new_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + return rpc.RRef(_LocalOptimizer(optim_cls, local_params_rref, *args, **kwargs)) + + +def _local_optimizer_step(local_optim_rref, autograd_ctx_id): + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +# new/step functions combined with _ScriptLocalOptimizer to provide GIL-free optimizer +def _new_script_local_optimizer(optim_cls, local_params_rref, *args, **kwargs): + optim = _ScriptLocalOptimizer(optim_cls, local_params_rref, *args, **kwargs) + + with _ScriptLocalOptimizer.compile_lock: + script_optim = jit.script(optim) + return rpc.RRef(script_optim, _ScriptLocalOptimizerInterface) + + +@jit.script +def _script_local_optimizer_step( + local_optim_rref: RRef[_ScriptLocalOptimizerInterface], autograd_ctx_id: int +) -> None: + local_optim = local_optim_rref.local_value() + local_optim.step(autograd_ctx_id) + + +def _wait_for_all(rpc_futs): + # TODO: improve error propagation + exception = None + results = [] + for fut in rpc_futs: + try: + results.append(fut.wait()) + except Exception as e: + results.append(e) + exception = e + if exception is not None: + raise exception + return results + + +class DistributedOptimizer: + """ + DistributedOptimizer takes remote references to parameters scattered + across workers and applies the given optimizer locally for each parameter. + + This class uses :meth:`~torch.distributed.autograd.get_gradients` in order + to retrieve the gradients for specific parameters. + + Concurrent calls to + :meth:`~torch.distributed.optim.DistributedOptimizer.step`, + either from the same or different clients, will + be serialized on each worker -- as each worker's optimizer can only work + on one set of gradients at a time. However, there is no guarantee that + the full forward-backward-optimizer sequence will execute for one client + at a time. This means that the gradients being applied may not correspond + to the latest forward pass executed on a given worker. Also, there is no + guaranteed ordering across workers. + + `DistributedOptimizer` creates the local optimizer with TorchScript enabled + by default, so that optimizer updates are not blocked by the Python Global + Interpreter Lock (GIL) in the case of multithreaded training (e.g. Distributed + Model Parallel). This feature is currently enabled for most optimizers. You + can also follow `the recipe`__ in PyTorch tutorials to enable TorchScript support + for your own custom optimizers. + + Args: + optimizer_class (optim.Optimizer): the class of optimizer to + instantiate on each worker. + params_rref (list[RRef]): list of RRefs to local or remote parameters + to optimize. + args: arguments to pass to the optimizer constructor on each worker. + kwargs: arguments to pass to the optimizer constructor on each worker. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> import torch.distributed.autograd as dist_autograd + >>> import torch.distributed.rpc as rpc + >>> from torch import optim + >>> from torch.distributed.optim import DistributedOptimizer + >>> + >>> with dist_autograd.context() as context_id: + >>> # Forward pass. + >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) + >>> loss = rref1.to_here() + rref2.to_here() + >>> + >>> # Backward pass. + >>> dist_autograd.backward(context_id, [loss.sum()]) + >>> + >>> # Optimizer. + >>> dist_optim = DistributedOptimizer( + >>> optim.SGD, + >>> [rref1, rref2], + >>> lr=0.05, + >>> ) + >>> dist_optim.step(context_id) + + __ https://github.com/pytorch/tutorials/pull/1465 + """ + + def __init__(self, optimizer_class, params_rref, *args, **kwargs): + torch._C._log_api_usage_once("torch.distributed.optim.DistributedOptimizer") + per_worker_params_rref = defaultdict(list) + for param in params_rref: + per_worker_params_rref[param.owner()].append(param) + + if optimizer_class in functional_optim_map and jit._state._enabled: + optim_ctor = functional_optim_map.get(optimizer_class) + else: + optim_ctor = optimizer_class + self.is_functional_optim = optim_ctor != optimizer_class + + if self.is_functional_optim: + optimizer_new_func = _new_script_local_optimizer + else: + logger.warning( + "Creating the optimizer %s without TorchScript support, " + "this might result in slow computation time in multithreading environment" + "(i.e. Distributed Model Parallel training on CPU) due to the Python's " + "Global Interpreter Lock (GIL). Please file an issue if you need this " + "optimizer in TorchScript. ", + optimizer_class, + ) + optimizer_new_func = _new_local_optimizer + + remote_optim_futs = [] + for worker, param_rrefs in per_worker_params_rref.items(): + remote_optim_rref_fut = rpc.rpc_async( + worker, + optimizer_new_func, + args=(optim_ctor, param_rrefs) + args, + kwargs=kwargs, + ) + remote_optim_futs.append(remote_optim_rref_fut) + + self.remote_optimizers = _wait_for_all(remote_optim_futs) + + def step(self, context_id): + """ + Performs a single optimization step. + + This will call :meth:`torch.optim.Optimizer.step` on each worker + containing parameters to be optimized, and will block until all workers + return. The provided ``context_id`` will be used to retrieve the + corresponding :class:`~torch.distributed.autograd.context` that + contains the gradients that should be applied to the parameters. + + Args: + context_id: the autograd context id for which we should run the + optimizer step. + """ + dist_autograd._is_valid_context(context_id) + + optimizer_step_func = ( + _script_local_optimizer_step + if self.is_functional_optim + else _local_optimizer_step + ) + + rpc_futs = [ + rpc.rpc_async( + optimizer.owner(), + optimizer_step_func, + args=(optimizer, context_id), + ) + for optimizer in self.remote_optimizers + ] + _wait_for_all(rpc_futs) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/post_localSGD_optimizer.py b/phivenv/Lib/site-packages/torch/distributed/optim/post_localSGD_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..c023282dfc71722746e82d840472dec0b405661f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/post_localSGD_optimizer.py @@ -0,0 +1,110 @@ +# mypy: allow-untyped-defs +import warnings + +import torch +import torch.distributed.algorithms.model_averaging.averagers as averagers + + +class PostLocalSGDOptimizer(torch.optim.Optimizer): + r""" + Wraps an arbitrary :class:`torch.optim.Optimizer` and runs `post-local SGD `_, + This optimizer runs local optimizer at every step. + After the warm-up stage, it averages parameters periodically after the local optimizer is applied. + + Args: + optim: The local optimizer. + averager: A model averager instance to run post-localSGD algorithm. + + Example:: + + >>> # xdoctest: +SKIP("undefined variables") + >>> import torch + >>> import torch.distributed as dist + >>> import torch.distributed.algorithms.model_averaging.averagers as averagers + >>> import torch.nn as nn + >>> from torch.distributed.optim import PostLocalSGDOptimizer + >>> from torch.distributed.algorithms.ddp_comm_hooks.post_localSGD_hook import ( + >>> PostLocalSGDState, + >>> post_localSGD_hook, + >>> ) + >>> + >>> model = nn.parallel.DistributedDataParallel( + >>> module, device_ids=[rank], output_device=rank + >>> ) + >>> + >>> # Register a post-localSGD communication hook. + >>> state = PostLocalSGDState(process_group=None, subgroup=None, start_localSGD_iter=100) + >>> model.register_comm_hook(state, post_localSGD_hook) + >>> + >>> # Create a post-localSGD optimizer that wraps a local optimizer. + >>> # Note that ``warmup_steps`` used in ``PostLocalSGDOptimizer`` must be the same as + >>> # ``start_localSGD_iter`` used in ``PostLocalSGDState``. + >>> local_optim = torch.optim.SGD(params=model.parameters(), lr=0.01) + >>> opt = PostLocalSGDOptimizer( + >>> optim=local_optim, + >>> averager=averagers.PeriodicModelAverager(period=4, warmup_steps=100) + >>> ) + >>> + >>> # In the first 100 steps, DDP runs global gradient averaging at every step. + >>> # After 100 steps, DDP runs gradient averaging within each subgroup (intra-node by default), + >>> # and post-localSGD optimizer runs global model averaging every 4 steps after applying the local optimizer. + >>> for step in range(0, 200): + >>> opt.zero_grad() + >>> loss = loss_fn(output, labels) + >>> loss.backward() + >>> opt.step() + """ + + def __init__(self, optim: torch.optim.Optimizer, averager: averagers.ModelAverager): + self.optim = optim + self.param_groups = self.optim.param_groups + self.averager = averager + + @property + def state(self): # type: ignore[override] + return self.optim.state + + def __repr__(self): + return self.optim.__repr__() + + def state_dict(self): + r""" + This is the same as :class:`torch.optim.Optimizer` :meth:`state_dict`, + but adds an extra entry to record model averager's step to the checkpoint + to ensure reload does not cause unnecessary warm up again. + """ + optim_state_dict = self.optim.state_dict() + optim_state_dict["step"] = self.averager.step + return optim_state_dict + + def load_state_dict(self, state_dict): + r""" + This is the same as :class:`torch.optim.Optimizer` :meth:`load_state_dict`, + but also restores model averager's step value to the one + saved in the provided ``state_dict``. + + If there is no ``"step"`` entry in ``state_dict``, + it will raise a warning and initialize the model averager's step to 0. + """ + self.optim.load_state_dict(state_dict) + if "step" in state_dict: + self.averager.step = state_dict["step"] + else: + warnings.warn( + "Loaded state dict does not contain a step counter for an averager. " + "Setting step counter to 0." + ) + self.averager.step = 0 + + def step(self): # type: ignore[override] + r""" + Performs a single optimization step (parameter update). + """ + self.optim.step() + self.averager.average_parameters(params=self.param_groups) + + def zero_grad(self, set_to_none: bool = True): # type: ignore[override] + self.optim.zero_grad(set_to_none=set_to_none) + + def add_param_group(self, param_group): + self.optim.add_param_group(param_group) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/utils.py b/phivenv/Lib/site-packages/torch/distributed/optim/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..36e201618c17480f597ff2567ed425b9d2baaf8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/utils.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs + +from torch import optim + +from .functional_adadelta import _FunctionalAdadelta +from .functional_adagrad import _FunctionalAdagrad +from .functional_adam import _FunctionalAdam +from .functional_adamax import _FunctionalAdamax +from .functional_adamw import _FunctionalAdamW +from .functional_rmsprop import _FunctionalRMSprop +from .functional_rprop import _FunctionalRprop +from .functional_sgd import _FunctionalSGD + + +# dict to map a user passed in optimizer_class to a functional +# optimizer class if we have already defined inside the +# distributed.optim package, this is so that we hide the +# functional optimizer to user and still provide the same API. +functional_optim_map = { + optim.Adagrad: _FunctionalAdagrad, + optim.Adam: _FunctionalAdam, + optim.AdamW: _FunctionalAdamW, + optim.SGD: _FunctionalSGD, + optim.Adadelta: _FunctionalAdadelta, + optim.RMSprop: _FunctionalRMSprop, + optim.Rprop: _FunctionalRprop, + optim.Adamax: _FunctionalAdamax, +} + + +def register_functional_optim(key, optim): + """ + Interface to insert a new functional optimizer to functional_optim_map + ``fn_optim_key`` and ``fn_optimizer`` are user defined. The optimizer and key + need not be of :class:`torch.optim.Optimizer` (e.g. for custom optimizers) + Example:: + >>> # import the new functional optimizer + >>> # xdoctest: +SKIP + >>> from xyz import fn_optimizer + >>> from torch.distributed.optim.utils import register_functional_optim + >>> fn_optim_key = "XYZ_optim" + >>> register_functional_optim(fn_optim_key, fn_optimizer) + """ + if key not in functional_optim_map: + functional_optim_map[key] = optim + + +def as_functional_optim(optim_cls: type, *args, **kwargs): + try: + functional_cls = functional_optim_map[optim_cls] + except KeyError as e: + raise ValueError( + f"Optimizer {optim_cls} does not have a functional counterpart!" + ) from e + + return _create_functional_optim(functional_cls, *args, **kwargs) + + +def _create_functional_optim(functional_optim_cls: type, *args, **kwargs): + return functional_optim_cls( + [], + *args, + **kwargs, + _allow_empty_param_list=True, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py b/phivenv/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..0cd38c9f658a88ec9da6e0e716176be8126ba12a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.py @@ -0,0 +1,1657 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +r"""Zero Redundancy Optimizer.""" + +import collections +import copy +import enum +import inspect +import io +import logging +from itertools import chain +from typing import Any, Callable, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed.algorithms.join import Join, Joinable, JoinHook +from torch.distributed.optim.utils import functional_optim_map +from torch.optim import Optimizer + + +__all__ = ["ZeroRedundancyOptimizer"] + + +logger = logging.getLogger(__name__) + + +# Credits: classy_vision/generic/distributed_util.py +def _recursive_copy_to_device( + value: Any, + non_blocking: bool, + device: torch.device, +) -> Any: + r""" + Recursively searches lists, tuples, dicts and copies tensors to device if possible. + + Non-tensor values are passed as-is in the result. + + .. note:: + These are all copies, so if there are two objects that reference + the same object, then after this call, there will be two different objects + referenced on the device. + """ + if isinstance(value, torch.Tensor): + return value.to(device, non_blocking=non_blocking) + + if isinstance(value, (list, tuple)): + values = [ + _recursive_copy_to_device(val, non_blocking=non_blocking, device=device) + for val in value + ] + return values if isinstance(value, list) else tuple(values) + + if isinstance(value, collections.abc.Mapping): + return { + key: _recursive_copy_to_device( + val, non_blocking=non_blocking, device=device + ) + for key, val in value.items() + } + + return value + + +def _is_trainable(param: torch.Tensor) -> bool: + r"""Return if a parameter is trainable, where trainability is equivalent to requiring a gradient.""" + return param.requires_grad + + +def _broadcast_object( + obj: Any, + src_rank: int, + group: object = dist.group.WORLD, + device: torch.device = torch.device("cpu"), +) -> Any: + r""" + Broadcasts an object to the given group. + + It will be sending the object if called from the source rank and receiving + the object otherwise. + + Arguments: + obj: object to broadcast; only used if called on the source rank. + src_rank (int): source rank. + group (``ProcessGroup``, optional): group used for the broadcast + (default: ``dist.group.WORLD``). + device (``torch.device``, optional): device to send from or receive + to (default: ``torch.device("cpu")``). + + Returns: + The broadcasted object. + """ + if dist.get_rank() == src_rank: + # Send the object + buffer = io.BytesIO() + torch.save(obj, buffer) + data = bytearray(buffer.getbuffer()) + length_tensor = torch.LongTensor([len(data)]).to(device) + data_send_tensor = torch.ByteTensor(data).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) + else: + # Receive the object + length_tensor = torch.LongTensor([0]).to(device) + dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) + data_recv_tensor = torch.empty( + [int(length_tensor.item())], dtype=torch.uint8, device=device + ) + dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) + buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) + obj = torch.load(buffer, map_location=device, weights_only=False) + return obj + + +class _ZeROJoinHook(JoinHook): + def __init__(self, zero): + assert isinstance(zero, ZeroRedundancyOptimizer), ( + "ZeRO join hook requires passing in a ZeroRedundancyOptimizer " + "instance as the state" + ) + self.zero = zero + super().__init__() + + def main_hook(self): + """ + Perform an optimizer step. + + This step updates the joined process's shard of + the parameters and broadcasts those parameters. + """ + self.zero.step() + + +class _DDPBucketAssignment: + r""" + Represent a :class:`DistributedDataParallel` bucket assignment. + + This means that a (possibly non-strict) subset of the parameters corresponding to + a DDP bucket assigned to a rank to update. + + Attributes: + bucket_index (int): index of the bucket determined by the DDP gradient + bucket all-reduce order. + parameters (List[torch.Tensor]): model parameters in the bucket + assigned to this rank. + offset (int): offset into the :class:`GradBucket` 's :meth:`parameters` + giving the index of the first element in the passed-in + ``parameters``; this equivalently indexes into the + :class:`GradBucket` 's :meth:`gradients`. + device (torch.device): device on which the parameters are stored. + tensor (torch.Tensor): flattened tensor giving the data of the + parameter subset assigned to the rank. + """ + + def __init__( + self, + bucket_index: int, + parameters: list[torch.Tensor], + offset: int, + ): + self.bucket_index = bucket_index + self.parameters = parameters + self.offset = offset + if len(self.parameters) == 0: + raise ValueError("Empty bucket assignment") + # DDP guarantees all parameters in the bucket have the same device + self.device: torch.device = self.parameters[0].device + self.tensor: Optional[torch.Tensor] = None + + +class _OverlapStatus(enum.IntEnum): + r""" + Define possible statuses that :class:`ZeroRedundancyOptimizer` can be in when overlapping with :class:`DistributedDataParallel`. + + Attributes: + ``UNINITIALIZED``: The ZeRO instance is effectively uninitialized and + is waiting for DDP to finalize its bucketing. + ``DDP_HAS_REBUILT_BUCKETS``: DDP has rebuilt its buckets, meaning that + its bucketing is finalized. The ZeRO instance can now collect the + necessary information about the DDP bucketing. + ``INITIALIZED``: The ZeRO instance is fully initialized and can now + optimize parameters. + """ + + UNINITIALIZED = 0 + DDP_HAS_REBUILT_BUCKETS = 1 + INITIALIZED = 2 + + +class _OverlapInfo: + r""" + Information needed by :class:`ZeroRedundancyOptimizer` to overlap with :class:`DistributedDataParallel`. + + Arguments: + world_size (int): world size of the process group being used. + + Attributes: + shard_buckets (bool): if ``True``, then the assignment of each + :class:`DistributedDataParallel` bucket is partitioned across + possibly multiple :class:`ZeroRedundancyOptimizer` instances (i.e. + across possibly multiple ranks) to approximate uniformity following + a threshold given by the total parameter size divided by the world + size; if ``False``, then each bucket is wholly assigned to a single + :class:`ZeroRedundancyOptimizer` instance (i.e. to a single rank); + this should be set to the value passed into the hook constructor. + status (_OverlapStatus): current status; see :class:`_OverlapStatus` + for more information. + params_per_bucket (List[List[torch.Tensor]]): ``params_per_bucket[i]`` + gives the model parameters in the ``i``th bucket. + params_per_rank (List[List[torch.Tensor]]): ``params_per_rank[i]`` + gives the model parameters assigned to the ``i``th rank, where the + parameters are grouped by increasing bucket indices. + offsets (Dict[int, int]): maps from bucket index to the offset in + ``self.params_per_rank[rank]`` giving the index of the first + parameter in that bucket, where ``rank`` is this process's own + rank; the keys of this :class:`dict` are the bucket indices + assigned to this rank. + num_bucket_assignments (int): total number of bucket assignments across + all ranks; this is equal to the number of + :class:`DistributedDataParallel` gradient buckets if + ``shard_buckets=False`` and possibly greater otherwise. + total_size (int, optional): total size of all buckets (i.e. sum of + ``param.numel()`` for all ``param`` across all buckets) if + ``shard_buckets=True``; otherwise, ``None``. + broadcast_handles (List[Work]): :class:`list` of async work handles for + the parameter broadcasts. + bucket_index_to_future (Dict[int, torch.futures.Future]): + :class:`dict` mapping bucket index to the corresponding all-reduce + future. + bucket_index_to_bucket (Dict[int, dist.GradBucket]): :class:`dict` + mapping bucket index to the corresponding bucket. + bucket_indices_seen (List[int]): :class:`list` of the bucket indices + seen on this iteration. + """ + + def __init__(self, world_size) -> None: + self.status: _OverlapStatus = _OverlapStatus.UNINITIALIZED + self.shard_buckets: bool = False + + # Modified per bucket reconstruction + self.params_per_bucket: list[list[torch.Tensor]] = [] + self.params_per_rank: list[list[torch.Tensor]] = [[] for _ in range(world_size)] + self.offsets: dict[int, int] = {} + # Group Ranks + self.assigned_ranks_per_bucket: list[set[int]] = [] + self.num_bucket_assignments: int = 0 + self.total_size: Optional[int] = None + + # Modified per iteration + self.broadcast_handles: list[Any] = [] + self.bucket_indices_seen: list[int] = [] + # Used by `hook_with_zero_step()` + self.bucket_index_to_future: dict[int, torch.futures.Future] = {} + self.bucket_index_to_bucket: dict[int, dist.GradBucket] = {} + + def wait_for_broadcasts(self) -> None: + r""" + Wait for all parameter broadcasts. + + This function should be called once all broadcasts have been scheduled, + meaning ``self.broadcast_handles`` is filled. This clears ``self.broadcast_handles`` + in preparation for the next iteration. + """ + assert len(self.broadcast_handles) == self.num_bucket_assignments, ( + f"Missing at least one broadcast handle on rank {dist.get_rank()}" + ) + _ = [x.wait() for x in self.broadcast_handles] + self.broadcast_handles.clear() + + def clear_per_iter_info(self) -> None: + r""" + Clear the data structures that are modified per-iteration. + + This function should be called at the end of an iteration. + """ + self.bucket_indices_seen.clear() + self.bucket_index_to_future.clear() + self.bucket_index_to_bucket.clear() + + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + r""" + Wrap an arbitrary :class:`optim.Optimizer ` and shards its states across ranks in the group. + + The sharing is done as described by `ZeRO `_. + + The local optimizer instance in each rank is only + responsible for updating approximately ``1 / world_size`` parameters and + hence only needs to keep ``1 / world_size`` optimizer states. After + parameters are updated locally, each rank will broadcast its parameters to + all other peers to keep all model replicas in the same state. + ``ZeroRedundancyOptimizer`` can be used in conjunction with + :class:`torch.nn.parallel.DistributedDataParallel` to reduce per-rank peak + memory consumption. + + ``ZeroRedundancyOptimizer`` uses a sorted-greedy algorithm to pack a number + of parameters at each rank. Each parameter belongs to a single rank and is + not divided among ranks. The partition is arbitrary and might not match the + the parameter registration or usage order. + + Arguments: + params (``Iterable``): an ``Iterable`` of :class:`torch.Tensor` s + or :class:`dict` s giving all parameters, which will be sharded + across ranks. + + Keyword Args: + optimizer_class (:class:`torch.nn.Optimizer`): the class of the local + optimizer. + process_group (``ProcessGroup``, optional): ``torch.distributed`` + ``ProcessGroup`` (default: ``dist.group.WORLD`` initialized by + :meth:`torch.distributed.init_process_group`). + parameters_as_bucket_view (bool, optional): if ``True``, parameters are + packed into buckets to speed up communication, and ``param.data`` + fields point to bucket views at different offsets; if ``False``, + each individual parameter is communicated separately, and each + ``params.data`` stays intact (default: ``False``). + overlap_with_ddp (bool, optional): if ``True``, :meth:`step` is + overlapped with :class:`DistributedDataParallel` 's gradient + synchronization; this requires (1) either a functional optimizer + for the ``optimizer_class`` argument or one with a functional + equivalent and (2) registering a DDP communication hook + constructed from one of the functions in ``ddp_zero_hook.py``; + parameters are packed into buckets matching those in + :class:`DistributedDataParallel`, meaning that the + ``parameters_as_bucket_view`` argument is ignored. + If ``False``, :meth:`step` runs disjointly after the backward pass + (per normal). + (default: ``False``) + **defaults: any trailing arguments, which are forwarded to the local + optimizer. + + Example:: + + >>> # xdoctest: +SKIP + >>> import torch.nn as nn + >>> from torch.distributed.optim import ZeroRedundancyOptimizer + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> model = nn.Sequential(*[nn.Linear(2000, 2000).to(rank) for _ in range(20)]) + >>> ddp = DDP(model, device_ids=[rank]) + >>> opt = ZeroRedundancyOptimizer( + >>> ddp.parameters(), + >>> optimizer_class=torch.optim.Adam, + >>> lr=0.01 + >>> ) + >>> ddp(inputs).sum().backward() + >>> opt.step() + + .. warning:: + Currently, ``ZeroRedundancyOptimizer`` requires that all of the + passed-in parameters are the same dense type. + + .. warning:: + If you pass ``overlap_with_ddp=True``, be wary of the following: Given + the way that overlapping :class:`DistributedDataParallel` with + :class:`ZeroRedundancyOptimizer` is currently implemented, the first + two or three training iterations do not perform parameter updates in + the optimizer step, depending on if ``static_graph=False`` or + ``static_graph=True``, respectively. This is because it needs + information about the gradient bucketing strategy used by + :class:`DistributedDataParallel`, which is not finalized until the + second forward pass if ``static_graph=False`` or until the third + forward pass if ``static_graph=True``. To adjust for this, one option + is to prepend dummy inputs. + + .. warning:: ZeroRedundancyOptimizer is experimental and subject to change. + """ + + def __init__( + self, + params, + optimizer_class: type[Optimizer], + process_group: Optional[Any] = None, + parameters_as_bucket_view: bool = False, + overlap_with_ddp: bool = False, + **defaults: Any, + ): + r"""Init.""" + # Perform type and assumption checks on the input parameters + params = self._verify_and_init_params(params) + self._verify_same_dense_param_type() + + # NOTE: The parent constructor uses `add_param_group()` which is + # partially overloaded in ZeroRedundancyOptimizer, so we use the + # `initialized` flag to dissociate the behaviour of `add_param_group()` + # between the parent and child. + self.initialized = False + + Optimizer.__init__(self, params, defaults) + Joinable.__init__(self) + # Now, all parameters are held in both `self._all_params` and + # `self.param_groups` + + # Internal data structures (`_cache` indicates lazily evaluated) + self._param_to_rank_cache: dict[torch.Tensor, int] = {} + self._param_to_index_cache: dict[torch.Tensor, int] = {} + self._partition_parameters_cache: list[list[dict]] = [] + self._index_to_param_cache: list[torch.Tensor] = [] + self._device_to_params_per_rank_cache: dict[ + torch.device, list[list[torch.Tensor]] + ] = {} + self._bucket_assignments_per_rank_cache: list[ + dict[int, _DDPBucketAssignment] + ] = [] + self._is_trainable_mask = self._get_is_trainable_mask() + + # Default device for collective communication and buckets + self._default_device = self._all_params[0].device + + self.process_group = ( + process_group if process_group is not None else dist.group.WORLD + ) + self.world_size: int = dist.get_world_size(self.process_group) + self.rank: int = dist.get_rank(self.process_group) + self.global_rank: int = dist.distributed_c10d.get_global_rank( + self.process_group, self.rank + ) + + self._overlap_with_ddp: bool = overlap_with_ddp + self._optim_defaults = defaults + self._optim_constructor = self._get_optimizer_constructor(optimizer_class) + + # If `overlap_with_ddp=True`, local optimizer initialization is delayed + # to run time after the necessary information has been collected + if not overlap_with_ddp: + self._init_local_optimizer() + else: + self._overlap_info: _OverlapInfo = _OverlapInfo(self.world_size) + if parameters_as_bucket_view: + logger.warning( + "`parameters_as_bucket_view=True` will be ignored since " + "`overlap_with_ddp=True`; instead, a different bucketing " + "strategy will be used" + ) + + # `self._buckets` is used if `parameters_as_bucket_view=True`, in + # which case parameter data is flattened into contiguous bucket tensors + self.parameters_as_bucket_view = parameters_as_bucket_view + self._buckets: list[list[torch.Tensor]] = [] + self._build_param_buckets() + + # Optional consolidated optimizer state, only populated if this rank + # is the target in `consolidate_state_dict()` + self._all_state_dicts: list[dict[str, Any]] = [] + + self.initialized = True + + def _clear_cache(self) -> None: + r"""Clear the cached data structures giving partition information.""" + self._partition_parameters_cache.clear() + self._param_to_rank_cache.clear() + self._index_to_param_cache.clear() + self._param_to_index_cache.clear() + self._device_to_params_per_rank_cache.clear() + self._bucket_assignments_per_rank_cache.clear() + + def add_param_group(self, param_group: dict[str, Any]) -> None: + r""" + Add a parameter group to the :class:`Optimizer` 's ``param_groups``. + + This can be useful when fine tuning a pre-trained network, as frozen + layers can be made trainable and added to the :class:`Optimizer` as + training progresses. + + Arguments: + param_group (dict): specifies the parameters to be optimized and + group-specific optimization options. + + .. warning:: This method handles updating the shards on all partitions + but needs to be called on all ranks. Calling this on a subset of + the ranks will cause the training to hang because communication + primitives are called depending on the managed parameters and + expect all the ranks to participate on the same set of parameters. + """ + if self.initialized and self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` only " + "supports a single parameter group" + ) + + super().add_param_group(param_group) + # NOTE: The rest of the method assumes that the call to the parent's + # `add_param_group()` appends the new parameter group and preserves + # the previous parameter-group ordering + + if self.initialized: + # Force a re-partitioning of the parameters + self._clear_cache() + param_groups = self._partition_parameters()[self.rank] + # NOTE: All parameters in the old parameter groups should be + # assigned to the same ranks so that the local optimizers do not + # need to be reinitialized + + # Add the parameters assigned to this rank from the new parameter + # group to the local optimizer, if any + if len(param_groups) == len(self.optim.param_groups) + 1: + self.optim.add_param_group(param_groups[-1]) + + # Update the bucketing strategy accordingly + if self.parameters_as_bucket_view: + self._build_param_buckets() + + def consolidate_state_dict(self, to: int = 0) -> None: + r""" + Consolidate a list of ``state_dict`` s (one per rank) on the target rank. + + Arguments: + to (int): the rank that receives the optimizer states (default: 0). + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + + .. warning:: This needs to be called on all ranks. + """ + self._check_overlap_initialized() + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Pull the sharded state from all ranks and store them in rank order + empty_messenger = torch.tensor( + [0], dtype=torch.uint8, device=self._default_device + ) + + # NOTE: We wastefully use `broadcast()` (e.g. instead of `gather()`) + # due to compatibility issues with NCCL backend; a possible follow-up + # is to move all sharded state management to RPC RRef + self._all_state_dicts = [] + for rank in range(self.world_size): + global_rank = dist.distributed_c10d.get_global_rank( + self.process_group, rank + ) + if self.rank == to: + # Consolidate all local `state_dict`s on this rank, storing on + # CPU to save GPU memory + if rank == self.rank: + # Directly append own optimizer state + self._all_state_dicts.append( + _recursive_copy_to_device( + self.optim.state_dict(), + non_blocking=True, + device=torch.device("cpu"), + ) + ) + else: + # Receive the optimizer state from the source rank + local_state_dict = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + self._all_state_dicts.append( + _recursive_copy_to_device( + local_state_dict, + non_blocking=True, + device=torch.device("cpu"), + ) + ) + else: + if rank == self.rank: + # Send the optimizer state to the target rank + _ = _broadcast_object( + self.optim.state_dict(), + src_rank=self.global_rank, + group=self.process_group, + device=self._default_device, + ) + elif rank != to: + # Discard the received object; `broadcast()` is used for + # compatibility reasons + _ = _broadcast_object( + empty_messenger, + src_rank=global_rank, + group=self.process_group, + device=self._default_device, + ) + + def _verify_params_per_rank( + self, + params_per_rank: list[list[torch.Tensor]], + ) -> None: + r""" + Verify ``params_per_rank`` for :meth:`_partition_parameters`. + + The verification is done by checking that ``params_per_rank`` has length equal + to the world size and that it does not contain any parameters not passed into the + :class:`ZeroRedundancyOptimizer` constructor. + + The parameters in ``params_per_rank`` being a strict subset of those + passed into the constructor is valid since some parameters may be + frozen. + + Raises: + ValueError: if ``params_per_rank`` does not have length equal to + the world size or if it contains a parameter that was not + passed into the :class:`ZeroRedundancyOptimizer` constructor. + """ + if len(params_per_rank) != self.world_size: + raise ValueError( + "`params_per_rank` must have length equal to the world size" + ) + all_params_set = set(self._all_params) + for params in params_per_rank: + for param in params: + if param not in all_params_set: + raise ValueError( + "Passing a new parameter in `params_per_rank` that " + "was not passed into the ZeroRedundancyOptimizer " + "constructor" + ) + + def _partition_param_group( + self, param_group: dict[str, Any], params_per_rank: list[list[torch.Tensor]] + ) -> None: + r""" + Partition the parameter group ``param_group`` according to ``params_per_rank``. + + The partition will modify the ``self._partition_parameters_cache``. This method should + only be used as a subroutine for :meth:`_partition_parameters`. + + Arguments: + param_group (dict[str, Any]): a parameter group as normally defined + in an optimizer state. + params_per_rank (list[list[torch.Tensor]]): a :class:`list` of + length world size containing :class:`list` s of parameters to + assign to each rank. + """ + for rank, params in enumerate(params_per_rank): + rank_param_group = copy.copy(param_group) + rank_param_group["params"] = params + self._partition_parameters_cache[rank].append(rank_param_group) + + def _partition_parameters( + self, + params_per_rank: Optional[list[list[torch.Tensor]]] = None, + ) -> list[list[dict]]: + r""" + Partitions parameters across distributed data parallel ranks. + + Arguments: + params_per_rank (list[list[torch.Tensor]], optional): a + :class:`list` of length world size containing :class:`list` s + of parameters to assign to each rank; this provides a way to + specify a partition manually. + If ``None``, the parameters are partitioned according to an + internal algorithm. + (default: ``None``) + + Returns: + A :class:`list` where each element of the list contains the + ``param_groups`` for a rank (which itself is a :class:`list` of + :class:`dict`); element 0 corresponds to rank 0, etc.; each rank + stores the ``param_groups`` for all ranks for the collective + communication in :meth:`step`. + + Raises: + ValueError: see :meth:`_validate_params_per_rank`. + RuntimeError: if ``params_per_rank`` is not ``None`` and this + :class:`ZeroRedundancyOptimizer` instance is using more than + one parameter group. + """ + if params_per_rank is None: + # Partition the parameters optimizing for uniformity + if len(self._partition_parameters_cache) == 0: + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + sizes = [0] * self.world_size + for param_group in self.param_groups: + param_group_params_per_rank: list[list] = [ + [] for _ in range(self.world_size) + ] + # Sort the parameters by size (largest first) + params_sorted = sorted( + param_group["params"], key=lambda t: t.numel(), reverse=True + ) + for param in params_sorted: + # Greedily add the parameter to rank with smallest size so far + rank = self._get_min_index(sizes) + param_group_params_per_rank[rank].append(param) + sizes[rank] += param.numel() + # Apply the constructed partition of the parameter group + self._partition_param_group( + param_group, param_group_params_per_rank + ) + + return self._partition_parameters_cache + + # Partition the parameters according to `params_per_rank` + assert len(self._partition_parameters_cache) == 0, ( + "Specifying `params_per_rank` should only be done when the " + "parameters have not been partitioned yet" + ) + if len(self.param_groups) != 1: + raise RuntimeError( + "Specifying `params_per_rank` only supports a single parameter group" + ) + self._verify_params_per_rank(params_per_rank) + self._partition_parameters_cache = [[] for _ in range(self.world_size)] + + # Apply the passed-in partition of the parameter group + param_group = self.param_groups[0] + self._partition_param_group(param_group, params_per_rank) + + return self._partition_parameters_cache + + @property + def _param_to_rank(self) -> dict[torch.Tensor, int]: + r""":class:`dict` mapping parameters to their assigned data parallel rank in the partition.""" + if len(self._param_to_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + self._param_to_rank_cache[param] = rank + return self._param_to_rank_cache + + @property + def _param_to_index(self) -> dict[torch.Tensor, int]: + r""" + :class:`dict` mapping parameters to their indices in the global optimizer state. + + NOTE: This assumes that the global optimizer state's indexing (in + ``state_dict``) follows a linear ordering over the parameter groups. + """ + if len(self._param_to_index_cache) == 0: + self._param_to_index_cache = { + p: i + for i, p in enumerate( + chain.from_iterable(g["params"] for g in self.param_groups) + ) + } + return self._param_to_index_cache + + @property + def _index_to_param(self) -> list[torch.Tensor]: + r"""List mapping parameter indices in the global optimizer scheme to the actual params.""" + if len(self._index_to_param_cache) == 0: + self._index_to_param_cache = list( + chain.from_iterable(g["params"] for g in self.param_groups) + ) + return self._index_to_param_cache + + def _broadcast_params_from_rank(self, rank: int): + r""" + Broadcast the shard of parameters from a given rank to all other ranks asynchronously. + + Arguments: + rank (int): the source rank. + + Returns: + A :class:`list` of async work handles for the ``broadcast()`` s + performed to synchronize the parameters. + """ + assert not self._overlap_with_ddp, ( + "`_broadcast_params_from_rank()` should not be used if " + "`overlap_with_ddp=True`; instead, the broadcasting should " + "happen in the DDP communication hook" + ) + handles = [] + if self.parameters_as_bucket_view: + for dev_i_buckets in self._buckets: + bucket = dev_i_buckets[rank] + global_rank = dist.distributed_c10d.get_global_rank( + self.process_group, rank + ) + handles.append( + dist.broadcast( + tensor=bucket, + src=global_rank, + group=self.process_group, + async_op=True, + ) + ) + else: + param_groups = self._partition_parameters()[rank] + global_rank = dist.distributed_c10d.get_global_rank( + self.process_group, rank + ) + for param_group in param_groups: + handles.extend( + dist.broadcast( + tensor=param.data, + src=global_rank, + group=self.process_group, + async_op=True, + ) + for param in param_group["params"] + ) + return handles + + def _sync_params(self): + r""" + Sync all parameter shards across the ranks. + + This rank sends its shard of the parameters to all other ranks and + receives a shard from each other rank. This is done using + ``broadcast()``. Parameters are sent bucket-by-bucket if + ``parameters_as_bucket_view=True``and sent parameter-by-parameter + otherwise. + """ + handles = [] + for rank in range(self.world_size): + handles.extend(self._broadcast_params_from_rank(rank)) + _ = [x.wait() for x in handles] + + @property + def _device_to_params_per_rank( + self, + ) -> dict[torch.device, list[list[torch.Tensor]]]: + r""" + Return device parameters assigned per rank. + + :class:`dict` mapping each device to a :class:`list` of the per-rank parameter + lists filtered to only include the parameters stored on that device. + Each per-rank parameter list gives the parameters assigned to that rank + to update. + + This is used for constructing the parameter buckets if + ``parameters_as_bucket_view=True``. + + Let ``dev_i`` denote the ``i``th device for this rank. Then: + ``dev_0`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_0``, + rank 1's assigned parameters stored on ``dev_0``, + ... + ``dev_1`` maps to a list containing: + rank 0's assigned parameters stored on ``dev_1``, + rank 1's assigned parameters stored on ``dev_1``, + ... + ... + """ + assert self.parameters_as_bucket_view, ( + "`_device_to_params_per_rank` should only be used if " + "`parameters_as_bucket_view=True`" + ) + if len(self._device_to_params_per_rank_cache) == 0: + for rank, param_groups in enumerate(self._partition_parameters()): + for param_group in param_groups: + for param in param_group["params"]: + device = param.device + if device not in self._device_to_params_per_rank_cache: + self._device_to_params_per_rank_cache[device] = [ + [] for _ in range(self.world_size) + ] + self._device_to_params_per_rank_cache[device][rank].append( + param + ) + return self._device_to_params_per_rank_cache + + def _get_min_index( + self, + values: list[int], + disallowed_indices: Optional[set[int]] = None, + ) -> int: + r""" + Return ``values.index(min(values))``, except only uses one pass. + + It also excludes any indices in ``disallowed_indices`` if provided. + + Arguments: + values: (List[int]): :class:`list` of values. + disallowed_indices (Optional[set[int]]): indices that are + disallowed from being the returned min index. + """ + min_index = -1 + min_value = float("inf") + for i, value in enumerate(values): + if disallowed_indices and i in disallowed_indices: + continue + if value < min_value: + min_value = value + min_index = i + assert min_index >= 0, "All indices are disallowed" + return min_index + + def _assign_bucket_subset_to_rank( + self, + bucket_index: int, + bucket_params: list[torch.Tensor], + bucket_offset: int, + assigned_rank: int, + assigned_ranks_per_bucket: list[set[int]], + ) -> None: + r""" + Assign ``bucket_params`` to the rank with the least size assigned so far and collects relevant information. + + The model parameters given by ``bucket_params`` represents a (possibly non-strict) + subset of the parameters corresponding to a :class:`DistributedDataParallel` bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + gradient bucket. + bucket_params (List[torch.Tensor]): subset of the parameters + corresponding to the bucket to assign. + bucket_offset (int): offset giving the index of the first element + in ``bucket_params`` in the bucket's full parameter list. + assigned_rank (int): group rank to assign to. + assigned_ranks_per_bucket (list[set[int]]): :class:`set` of group ranks + assigned to each bucket. + """ + overlap_info = self._overlap_info + if len(bucket_params) == 0: + raise ValueError("Empty bucket assignment") + params_per_rank = overlap_info.params_per_rank + offsets = overlap_info.offsets + + self._bucket_assignments_per_rank_cache[assigned_rank][bucket_index] = ( + _DDPBucketAssignment(bucket_index, bucket_params, bucket_offset) + ) + if self.global_rank == assigned_rank: + offsets[bucket_index] = len(params_per_rank[assigned_rank]) + params_per_rank[assigned_rank].extend(bucket_params) + assigned_ranks_per_bucket[bucket_index].add(assigned_rank) + self._overlap_info.num_bucket_assignments += 1 + + @property + def _bucket_assignments_per_rank(self) -> list[dict[int, _DDPBucketAssignment]]: + r""" + Return DDP bucket parameters assigned per rank. + + :class:`list` of length world size consisting of :class:`dict` s + mapping bucket indices to :class:`_DDPBucketAssignment` s for each + rank. + """ + assert self._overlap_with_ddp, ( + "`_bucket_assignments_per_rank` only be used if `overlap_with_ddp=True`" + ) + if len(self._bucket_assignments_per_rank_cache) > 0: + return self._bucket_assignments_per_rank_cache + + overlap_info = self._overlap_info + assert overlap_info.status == _OverlapStatus.INITIALIZED + + self._bucket_assignments_per_rank_cache = [{} for _ in range(self.world_size)] + params_per_bucket = overlap_info.params_per_bucket + + if overlap_info.shard_buckets: + # Define the assignment threshold to approximate uniformity + assert overlap_info.total_size is not None, "`total_size` was not computed" + threshold = overlap_info.total_size / self.world_size # type: ignore[operator] + size_per_rank = [0 for _ in range(self.world_size)] + + num_buckets = len(params_per_bucket) + overlap_info.assigned_ranks_per_bucket = [set() for _ in range(num_buckets)] + assigned_ranks_per_bucket = overlap_info.assigned_ranks_per_bucket + if not overlap_info.shard_buckets: + # Assign each DDP bucket entirely to a single rank + for bucket_index, bucket_params in enumerate(params_per_bucket): + assert len(bucket_params) > 0, "Empty bucket" + assigned_rank = self._get_assigned_rank(bucket_index) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params, + 0, + assigned_rank, + assigned_ranks_per_bucket, + ) + else: + # Assign each DDP bucket to possibly multiple ranks + # Specifically, sort the DDP buckets by increasing size, and for + # each bucket, iteratively assign the maximal unassigned subset + # with size less than `threshold` to the rank with the least total + # size so far -- each such assignment is represented by a + # `_DDPBucketAssignment` instance and only contains parameters from + # a single DDP bucket + params_per_bucket_enum = sorted( + enumerate(params_per_bucket), key=lambda x: sum(p.numel() for p in x[1]) + ) + for bucket_index, bucket_params in params_per_bucket_enum: + assert len(bucket_params) > 0, "Empty bucket" + bucket_offset = 0 + assignment_size = 0 + for param_index, param in enumerate(bucket_params): + param_numel = param.numel() + if ( + assignment_size + param_numel >= threshold + and param_index > bucket_offset + ): + assigned_rank = self._get_min_index( + size_per_rank, assigned_ranks_per_bucket[bucket_index] + ) + # Include up to but not including the parameter that + # exceeded the threshold + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:param_index], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + size_per_rank[assigned_rank] += assignment_size + bucket_offset = param_index + assignment_size = 0 + assignment_size += param_numel + # Assign the remainder of the bucket so that no assignment + # spans across two buckets + assigned_rank = self._get_min_index( + size_per_rank, assigned_ranks_per_bucket[bucket_index] + ) + self._assign_bucket_subset_to_rank( + bucket_index, + bucket_params[bucket_offset:], + bucket_offset, + assigned_rank, + assigned_ranks_per_bucket, + ) + size_per_rank[assigned_rank] += assignment_size + + return self._bucket_assignments_per_rank_cache + + def _local_step( + self, + gradients: Optional[list[Optional[torch.Tensor]]] = None, + closure: Optional[Callable[[], float]] = None, + **kwargs: Any, + ) -> Optional[float]: + r""" + Perform a single optimizer step without syncing parameters across ranks. + + Arguments: + gradients (list[Optional[torch.Tensor]], optional): a :class:`list` + of length equal to the number of parameters assigned to this + rank containing gradient tensors or ``None`` as its elements; + a ``None`` in the :class:`list` indicates that the + corresponding parameter should not be updated. + If the argument itself is ``None``, then all parameters are + updated, and the gradients are assumed to be already populated. + (default: ``None``) + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers and should be + ``None`` if ``gradients`` is not ``None``; (default: ``None``) + Returns: + Optional loss depending on the underlying local optimizer. + + .. warning:: + The argument ``gradients`` should only be specified (i.e. not + ``None``) if ``overlap_with_ddp=True``, in which case + :class:`ZeroRedundancyOptimizer` wraps a functional optimizer. + """ + Join.notify_join_context(self) + # Check if the model trainability has changed + is_trainable_mask = self._get_is_trainable_mask() + if is_trainable_mask != self._is_trainable_mask: + if self._overlap_with_ddp: + raise RuntimeError( + "ZeroRedundancyOptimizer with `overlap_with_ddp=True` " + "does not support changing parameter trainability at run " + "time" + ) + logger.warning( + "ZeroRedundancyOptimizer detected that the trainable " + "parameters changed; rebuilding the parameter buckets if " + "enabled" + ) + self._build_param_buckets() + self._is_trainable_mask = is_trainable_mask + + # Sync the exposed `param_groups` attributes to the local optimizer in + # case they have been updated + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + # Run the optimizer step on this shard only + if gradients is None: + loss = ( + self.optim.step(**kwargs) + if closure is None + else self.optim.step(closure=closure, **kwargs) + ) + else: + assert self._overlap_with_ddp, ( + "Specifying `gradients` should not " + "be used when `overlap_with_ddp=False`" + ) + assert closure is None, ( + "`closure` is not supported when using a local functional optimizer" + ) + loss = self.optim.step(gradients=gradients) + + # Sync any updated attributes in the local optimizer to the exposed + # `param_groups` + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + return loss + + def step( + self, + closure: Optional[Callable[[], float]] = None, + **kwargs: Any, + ) -> Optional[float]: + r""" + Perform a single optimizer step and syncs parameters across all ranks. + + Arguments: + closure (Callable): a closure that re-evaluates the model and + returns the loss; optional for most optimizers. + Returns: + Optional loss depending on the underlying local optimizer. + + .. note:: Any extra parameters are passed to the base optimizer as-is. + """ + if self._overlap_with_ddp: + logger.warning( + "`step()` should not be included in the training loop when " + "`overlap_with_ddp=True`" + ) + return None + + # Perform the local optimizer step + loss = self._local_step(closure=closure, **kwargs) + + # Sync all of the updated parameter shards across the ranks + self._sync_params() + + return loss + + def join_hook(self, **kwargs): + r""" + Return the ZeRO join hook. + + It enables training on uneven inputs by + shadowing the collective communications in the optimizer step. + + Gradients must be properly set before this hook is called. + + Arguments: + kwargs (dict): a :class:`dict` containing any keyword arguments + to modify the behavior of the join hook at run time; all + :class:`Joinable` instances sharing the same join context + manager are forwarded the same value for ``kwargs``. + + This hook does not support any keyword arguments; i.e. ``kwargs`` is + unused. + """ + return _ZeROJoinHook(self) + + @property + def join_device(self) -> torch.device: + r"""Return default device.""" + return self._default_device + + @property + def join_process_group(self) -> Any: + r"""Return process group.""" + return self.process_group + + def load_state_dict(self, state_dict: dict[str, Any]) -> None: + r""" + Load the state pertaining to the given rank from the input ``state_dict``, updating the local optimizer as needed. + + Arguments: + state_dict (dict): optimizer state; should be an object returned + from a call to :meth:`state_dict`. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt. + """ + self._check_overlap_initialized() + + for index, value in state_dict["state"].items(): + param = self._index_to_param[index] + if self._param_to_rank[param] != self.rank: + # Clear any state irrelevant to this rank + state_dict["state"][index] = None + else: + # Load the parameter state to the local optimizer + self.optim.state[param] = _recursive_copy_to_device( + value, non_blocking=True, device=param.device + ) + # Force zero-dimensional tensors (like Adam "step") on CPU + for state_name, state_value in self.optim.state[param].items(): + if torch.is_tensor(state_value) and state_value.dim() == 0: + self.optim.state[param][state_name] = state_value.cpu() + + super().load_state_dict(state_dict) + + # Sync the input state with the exposed and local optimizer states + self._sync_param_groups(state_dict["param_groups"], self.param_groups) + self._sync_param_groups(self.param_groups, self.optim.param_groups) + + def state_dict(self) -> dict[str, Any]: + r""" + Return the last global optimizer state known to this rank. + + .. warning: + If the state has not been consolidated to this rank, this raises a + runtime error, and even if it has, the state may not be up-to-date, + depending on when :meth:`consolidate_state_dict` was last called. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and this method is + called before this :class:`ZeroRedundancyOptimizer` instance + has been fully initialized, which happens once + :class:`DistributedDataParallel` gradient buckets have been + rebuilt; or if this method is called without a preceding call + to :meth:`consolidate_state_dict`. + """ + self._check_overlap_initialized() + + if len(self._all_state_dicts) == 0: + raise RuntimeError( + "Optimizer state has not been consolidated on this rank. " + f"Please call `consolidate_state_dict(to={self.rank})` on " + "all ranks beforehand if you meant to save the global state." + ) + + # Get the possibly-stale global optimizer state that uses global + # parameter indexing + state_dict = super().state_dict() + + # Update the global optimizer state with local state information, + # factoring in the translation from local to global indexing + for rank, local_state_dict in enumerate(self._all_state_dicts): + local_param_groups = local_state_dict["param_groups"] + global_param_groups = self._partition_parameters()[rank] + assert len(local_param_groups) == len(global_param_groups), ( + "Mismatch between number of local and global parameter groups" + ) + + for local_param_group, global_param_group in zip( + local_param_groups, global_param_groups + ): + # `local_param_group` stores local indices, while + # `global_param_group` stores the tensors directly + local_param_indices = local_param_group["params"] + global_params = global_param_group["params"] + + assert len(local_param_indices) == len(global_params), ( + "Mismatch between number of local and global parameters in parameter group" + ) + for local_param_index, global_param in zip( + local_param_indices, global_params + ): + # Update the global parameter state, if any + if local_param_index in local_state_dict["state"]: + global_param_index = self._param_to_index[global_param] + state_dict["state"][global_param_index] = local_state_dict[ + "state" + ][local_param_index] + + # Sort the parameters in the state + state_dict["state"] = dict(sorted(state_dict["state"].items())) + return state_dict + + @staticmethod + def _sync_param_groups( + src_param_groups: list[dict[Any, Any]], + dst_param_groups: list[dict[Any, Any]], + ) -> None: + r""" + Sync the attributes from the source parameter groups to the destination parameter groups. + + Example attributes include learning rate or scheduler attributes. The + two parameter groups should have the same length (i.e. same number of + parameter groups). + + Arguments: + src_param_groups (list[dict]): parameter groups giving the + attribute settings to copy. + dst_param_groups (list[dict]): parameter groups giving the + attribute settings to set. + """ + assert len(src_param_groups) == len(dst_param_groups), ( + "Mismatch between number of source and destination parameter groups" + ) + for src_param_group, dst_param_group in zip(src_param_groups, dst_param_groups): + # Sync all attributes except the parameters + for attr in filter(lambda x: x != "params", src_param_group.keys()): + dst_param_group[attr] = src_param_group[attr] + + def _build_param_buckets(self) -> None: + r""" + Build parameter buckets if ``parameters_as_bucket_view=True``. + + For each device that stores this rank's parameters, there is a + bucket (represented as a tensor) containing all of the parameters on + that device that are assigned to a given rank in the parameter update + partition. + + This method is called in the constructor and any time parameter + trainability is changed. + + .. warning:: + The current implementation assumes that all of the parameters in a + bucket are of the same dense type when allocating the bucket's + tensor. + + .. warning:: + If the model parameters are stored across more than one device, + then the storage partitioning must be the same across all + processes in order for parameter synchronization to work. + """ + if not self.parameters_as_bucket_view or self._overlap_with_ddp: + return + + # `self._buckets[i][j]` are the parameters stored on device i and + # assigned to rank j + num_devices = len(self._device_to_params_per_rank) + self._buckets = [[] for _ in range(num_devices)] # type: ignore[assignment] + + for dev_i, (device, params_per_rank) in enumerate( + self._device_to_params_per_rank.items() + ): + for params in params_per_rank: + bucket_size = 0 + dtype = None + trainable_params = [] + for param in params: + if not _is_trainable(param): + # Clone in case the parameter was previously part of + # a bucket to avoid the data from being destroyed + param.data = param.data.detach().clone() + else: + bucket_size += param.numel() + trainable_params.append(param) + dtype = param.dtype # assumes all same dtype + + if bucket_size == 0: + # Create a dummy bucket if there are no parameters + bucket = torch.zeros(1, device=device) + else: + # Construct the bucket (assuming all dense and same dtype) + bucket = torch.empty(bucket_size, dtype=dtype, device=device) + offset = 0 + for param in trainable_params: + offset_next = offset + param.numel() + bucket[offset:offset_next].copy_(param.data.flatten()) + param.data = bucket[offset:offset_next].view_as(param.data) + offset = offset_next + self._buckets[dev_i].append(bucket) # type: ignore[arg-type] + + def _build_ddp_param_buckets(self) -> None: + r""" + Build the DDP bucket with parameters assigned to this rank. + + For each DDP bucket with parameters assigned to this rank, flattens the + data of those parameters into a single tensor and saves the tensor to + the ``tensor`` attribute in the corresponding + :class:`_DDPBucketAssignment` instance stored in + ``self._bucket_assignments_per_rank``. + + :class:`DistributedDataParallel` guarantees that the parameters + corresponding to a gradient bucket have the same device and the same + dtype. + """ + for bucket_assignments in self._bucket_assignments_per_rank: + for bucket_assignment in bucket_assignments.values(): + params = bucket_assignment.parameters + bucket_size = 0 + dtype = None + for param in params: + assert _is_trainable(param), ( + "Model parameter " + "corresponding to a gradient in a DDP bucket should " + "require a gradient" + ) + bucket_size += param.numel() + dtype = param.dtype # assumes all same dtype + assert bucket_size > 0, "Empty bucket" + + # Construct the bucket tensor (assuming all dense and same dtype) + tensor = torch.empty( + bucket_size, dtype=dtype, device=bucket_assignment.device + ) + offset = 0 + for param in params: + offset_next = offset + param.numel() + tensor[offset:offset_next].copy_(param.data.flatten()) + param.data = tensor[offset:offset_next].view_as(param.data) + offset = offset_next + bucket_assignment.tensor = tensor + + def _verify_and_init_params( + self, + params: Any, + ) -> Union[list[torch.Tensor], list[dict]]: + r""" + Verify the type of ``params`` and initializes ``self._all_params`` as a :class:`list` of all parameters. + + The initializagtion will first make sure that provided ``params`` is valid. + + Arguments: + params (Any): Candidate parameter list or parameter groups to verify. + + Raises: + TypeError: ``params`` has an invalid type. + ValueError: ``params`` is empty. + + Returns: + The persistent form of ``params`` to be passed into the parent + :class:`Optimizer` constructor -- i.e. returns ``params`` as a + :class:`list` to ensure that it can be iterated over again. + """ + if isinstance(params, torch.Tensor): + raise TypeError( + "`params` argument should be an iterable of " + f"Tensors, but got {torch.typename(params)}" + ) + try: + all_params = list(params) + except TypeError as e: + raise TypeError( + "`params` argument should be an iterable of Tensors" + f" or dicts, but got {torch.typename(params)}" + ) from e + if len(all_params) == 0: + raise ValueError("ZeroRedundancyOptimizer got an empty parameter list") + all_tensors = True + all_dicts = True + for param in all_params: + all_tensors &= isinstance(param, torch.Tensor) + all_dicts &= isinstance(param, dict) + if not all_tensors and not all_dicts: + raise TypeError( + "`params` argument should be an iterable of Tensors or dicts" + ) + # Ensure that `self._all_params` contains a list of all parameters + if all_tensors: + self._all_params = all_params + elif all_dicts: + self._all_params = [] + # `all_params` contains parameter groups (not parameters) + for param_group in all_params: + if "params" not in param_group: + raise ValueError( + "Each parameter group passed-in via `params` must " + "have a 'params' key mapping to the parameters in " + "the group" + ) + self._all_params.extend(param_group["params"]) + return all_params + + def _verify_same_dense_param_type(self) -> None: + r""" + Verify that all parameters are of the same dense type. + + The method assumes that ``self._all_params`` has been initialized + and is non-empty. + + Raises: + ValueError: ``params`` contains sparse parameters or parameters + of varying dense types. + + NOTE: This method can be removed once support for sparse parameters + and varying parameter types is added. + """ + typename = torch.typename(self._all_params[0]) + if self._all_params[0].is_sparse: + raise ValueError( + "ZeroRedundancyOptimizer only supports using " + "the same dense type for all parameters but got " + f"{typename}" + ) + for param in self._all_params[1:]: + other_typename = torch.typename(param) + if other_typename != typename: + raise ValueError( + "ZeroRedundancyOptimizer only supports " + "using the same dense type for all " + f"parameters but got both {typename} and " + f"{other_typename}" + ) + + def _get_is_trainable_mask(self) -> list[bool]: + r"""Return a boolean mask indicating if each parameter is trainable (``requires_grad``) or not.""" + return list(map(_is_trainable, self._all_params)) + + def _init_local_optimizer(self) -> None: + r""" + Initialize this rank's local optimizer, responsible for its subset of the parameters. + + The local optimizer is saved in ``self.optim``. + """ + assert self._optim_constructor is not None, ( + "The local optimizer class has not been set" + ) + + param_groups = self._partition_parameters()[self.rank] + # `overlap_with_ddp=True` requires a local functional optimizer + if self._overlap_with_ddp: + # Functional optimizers only support a single parameter group and + # require passing in the parameters as a list + assert len(param_groups) == 1, ( + "Initializing the local " + "functional optimizer with more than one parameter group" + ) + params = param_groups[0]["params"] + # Try to pass `_allow_empty_param_list=True` to avoid erroring + if ( + "_allow_empty_param_list" + in inspect.signature(self._optim_constructor).parameters + ): + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults, _allow_empty_param_list=True + ) + else: + logger.warning( + "%s does not support the argument " + "`_allow_empty_param_list`; ZeroRedundancyOptimizer may " + "error due to an empty parameter list", + self._optim_constructor, + ) + self.optim: Any = self._optim_constructor( + params, **self._optim_defaults + ) # type: ignore[no-redef] + + # Log information about the DDP and ZeRO bucketing + if dist.get_debug_level() != dist.DebugLevel.OFF: + local_numel = sum(p.numel() for p in params) + num_assigned_buckets = len( + self._bucket_assignments_per_rank[self.global_rank] + ) + logger.info( + "rank %s with %s parameters across %s buckets", + self.global_rank, + local_numel, + num_assigned_buckets, + ) + if self.global_rank == 0: + logger.info( + "%s DDP buckets and %s bucket assignments", + len(self._overlap_info.params_per_bucket), + self._overlap_info.num_bucket_assignments, + ) + else: + # NOTE: Passing `param_groups` into the local optimizer constructor + # bypasses the empty parameter list check + self.optim: Optimizer = self._optim_constructor( + param_groups, **self._optim_defaults + ) # type: ignore[no-redef] + + # TODO: Manually add `self.param_groups` if using a functional + # optimizer; remove this if/when the functional optimizers support + # multiple parameter groups + if self._overlap_with_ddp and not hasattr(self.optim, "param_groups"): + assert hasattr(self.optim, "param_group"), ( + "The functional optimizer should set at least one of the " + "attributes `param_group` or `param_groups`" + ) + self.optim.param_groups = [self.optim.param_group] # type: ignore[attr-defined] + + self._sync_param_groups(self.optim.param_groups, self.param_groups) + + def _init_zero_for_overlap(self) -> None: + r"""Perform a delayed initialization of the local optimizer and the supporting data structures.""" + assert self._overlap_with_ddp, ( + "`_init_zero_for_overlap()` should only be called when " + "`overlap_with_ddp=True`" + ) + self._overlap_info.status = _OverlapStatus.INITIALIZED + self._clear_cache() + self._partition_parameters(self._overlap_info.params_per_rank) + self._build_ddp_param_buckets() + self._init_local_optimizer() + + def _get_assigned_rank(self, bucket_index: int) -> int: + r""" + Return the single rank assigned to a :class:`DistributedDataParallel` gradient bucket. + + Arguments: + bucket_index (int): index of the :class:`DistributedDataParallel` + bucket for which to get the assigned rank. + """ + assert not self._overlap_info.shard_buckets, ( + "The bucket assignment requires global bucket information and " + "will be computed later; there should be no need to use this " + "method" + ) + return bucket_index % self.world_size + + def _check_overlap_initialized(self): + r""" + Check the delayed initialization depending on the value of ``overlap_with_ddp``. + + The delayed initialization has occurred (see + :meth:`_init_zero_for_overlap`) if ``overlap_with_ddp=True``, and + raises a ``RuntimeError`` if not. This should preface methods that + should not be run before that delayed initialization. + + Raises: + RuntimeError: if ``overlap_with_ddp=True`` and + :meth:`_init_zero_for_overlap` has not been called. + """ + if ( + self._overlap_with_ddp + and self._overlap_info.status != _OverlapStatus.INITIALIZED + ): + raise RuntimeError( + "This method should not be called until this " + "ZeroRedundancyOptimizer instance has been fully " + "initialized" + ) + + def _get_optimizer_constructor(self, optimizer_class: Any) -> Any: + r""" + Return the optimizer constructor using validation and transformation depending on ``overlap_with_ddp``. + + Returns: + - ``optimizer_class`` if ``overlap_with_ddp=False`` and + ``optimizer_class`` is not a functional optimizer. + - ``optimizer_class`` if ``overlap_with_ddp=True`` and + ``optimizer_class`` is already a functional optimizer. + - The functional equivalent of ``optimizer_class`` if + ``overlap_with_ddp=True`` and ``optimizer_class`` is not + already a functional optimizer (assuming the equivalent + exists). + + Raises: + ValueError: + + - if ``overlap_with_ddp=True`` but ``optimizer_class`` is + neither a functional optimizer nor translatable to a + functional optimizer. + - if ``overlap_with_ddp=False`` and ``optimizer_class`` is a + functional optimizer. + """ + functional_optims = functional_optim_map.values() + if not self._overlap_with_ddp: + if optimizer_class in functional_optims: + # Using a functional optimizer is only supported when + # `overlap_with_ddp=True` + raise ValueError( + f"Passing in a functional optimizer {optimizer_class} " + "when `overlap_with_ddp=False`" + ) + else: + return optimizer_class + else: + if optimizer_class in functional_optims: + # Already a functional optimizer + return optimizer_class + elif optimizer_class in functional_optim_map: + # Translate the passed-in optimizer class to its functional + # equivalent if `overlap_with_ddp=True` + optim_constructor = functional_optim_map[optimizer_class] + logger.info( + "Using the functional optimizer %s " + "instead of %s since " + "`overlap_with_ddp=True`", + optim_constructor, + optimizer_class, + ) + return optim_constructor + else: + raise ValueError( + "Using `ddp_with_overlap=True` requires using a " + "functional optimizer, but there is no supported functional " + f"optimizer equivalent for {optimizer_class}" + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi b/phivenv/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi new file mode 100644 index 0000000000000000000000000000000000000000..735a67306744a953ef1e5795b64eb8171c5d5b3b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/optim/zero_redundancy_optimizer.pyi @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +import enum +from typing import Any, Callable, overload + +import torch +from torch.distributed.algorithms.join import Joinable, JoinHook +from torch.optim import Optimizer + +class _ZeROJoinHook(JoinHook): + zero: Any = ... + def __init__(self, zero: Any) -> None: ... + def main_hook(self) -> None: ... + +class _DDPBucketAssignment: + bucket_index: int + parameters: list[torch.Tensor] + offset: int + device: torch.device + tensor: torch.Tensor | None + +class _OverlapStatus(enum.IntEnum): + UNINITIALIZED = ... + DDP_HAS_REBUILT_BUCKETS = ... + INITIALIZED = ... + +class _OverlapInfo: + status: Any = ... + params_per_bucket: Any = ... + params_per_rank: Any = ... + offsets: Any = ... + broadcast_handles: Any = ... + bucket_index_to_future: Any = ... + bucket_index_to_bucket: Any = ... + bucket_indices_seen: Any = ... + assigned_ranks_per_bucket: list[set[int]] = ... + total_size: int = ... + shard_buckets: bool = ... + def __init__(self) -> None: ... + def wait_for_broadcasts(self) -> None: ... + def clear_per_iter_info(self) -> None: ... + +class ZeroRedundancyOptimizer(Optimizer, Joinable): + functional_optim_map: Any = ... + initialized: bool = ... + process_group: Any = ... + world_size: int = ... + rank: int = ... + global_rank: int = ... + parameters_as_bucket_view: bool = ... + optim: Any = ... + _device_to_device_index: dict[torch.device, int] = ... + _overlap_with_ddp: bool = ... + _overlap_info: _OverlapInfo = ... + _buckets: list[list[torch.Tensor]] = ... + _bucket_assignments_per_rank: list[dict[int, _DDPBucketAssignment]] = ... + def __init__( + self, + params: Any, + optimizer_class: type[Optimizer], + process_group: Any | None = ..., + parameters_as_bucket_view: bool = ..., + overlap_with_ddp: bool = ..., + **defaults: Any, + ) -> None: ... + def add_param_group(self, param_group: dict[str, Any]) -> None: ... + def consolidate_state_dict(self, to: int = ...) -> None: ... + @overload + def step(self, closure: None = None, **kwargs: Any) -> None: ... + @overload + def step(self, closure: Callable[[], float], **kwargs: Any) -> float: ... + def load_state_dict(self, state_dict: dict[str, Any]) -> None: ... + def state_dict(self) -> dict[str, Any]: ... + def _local_step( + self, + gradients: list[torch.Tensor | None] | None = None, + closure: Callable[[], float] | None = None, + **kwargs: Any, + ) -> float | None: ... + def _get_assigned_rank(self, bucket_index: int) -> int: ... + def _init_zero_for_overlap(self) -> None: ... + def join_hook(self, **kwargs): ... + @property + def join_device(self) -> torch.device: ... + def join_process_group(self) -> Any: ... diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/_IR.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/_IR.py new file mode 100644 index 0000000000000000000000000000000000000000..01b9825f196e8f463fc8b1b155a017628a0aec64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/_IR.py @@ -0,0 +1,1246 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import copy +import logging +import operator +from collections import defaultdict +from enum import Enum +from inspect import Parameter, Signature, signature +from types import MethodType +from typing import Any, Callable, Optional, Union + +import torch +import torch.fx as fx +from torch.distributed import ProcessGroup +from torch.export import ExportedProgram +from torch.export.unflatten import ( + _assign_attr, + _AttrKind, + _sink_params, + InterpreterModule, +) +from torch.fx.node import map_aggregate +from torch.fx.passes.split_module import split_module + +from ._backward import _null_coalesce_accumulate, stage_backward +from ._unflatten import _outline_submodules +from ._utils import PipeInfo +from .stage import _PipelineStage + + +logger = logging.getLogger(__name__) + +# TODO: +# 1. investigate gradient sync for shared parameters. how does DDP do it? +# 2. Add parameter movement to split_module + + +def _find_loss_from_output_and_spec(output_val, spec_val): + if spec_val is False: + return None + if spec_val is True: + if not isinstance(output_val, fx.Node): + raise RuntimeError( + f"Loss spec must specify a dynamic value but got {output_val}" + ) + return output_val + + if isinstance(spec_val, (tuple, list)): + if not isinstance(output_val, (tuple, list)): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if len(output_val) != len(spec_val): + raise RuntimeError( + f"Output value {output_val} must match length of loss specification " + f"{spec_val}" + ) + for out, spec in zip(output_val, spec_val): + loss_val = _find_loss_from_output_and_spec(out, spec) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + if isinstance(spec_val, dict): + if not isinstance(output_val, dict): + raise RuntimeError( + f"Output value {output_val} must match type of loss specification " + f"{spec_val}" + ) + if set(output_val.keys()) != set(spec_val.keys()): + raise RuntimeError( + f"Output value {output_val} must match keys of loss specification " + f"{spec_val}" + ) + for k in spec_val: + loss_val = _find_loss_from_output_and_spec(output_val[k], spec_val[k]) + if loss_val is not None: + return loss_val + raise RuntimeError(f"Did not find loss value in specification {spec_val}") + + raise RuntimeError(f"Unsupported type {type(spec_val)} in loss specification") + + +def _find_loss_output(mod: torch.nn.Module, g: fx.Graph, output_loss_value_spec): + output_nodes = [n for n in g.nodes if n.op == "output"] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + output_val = output_node.args[0] + generated_spec: Any = None + + if isinstance(mod, TrivialLossWrapper): + # TrivialLossWrapper is pre-defined by PiPPy. + # It has loss as the only output so we can safely assume the first output arg is the loss. + assert len(output_node.args) == 1 + loss_node = output_val + generated_spec = TrivialLossWrapper.loss_spec + elif output_loss_value_spec is None: + # Use default spec, i.e. search for "loss" in output values + if isinstance(output_val, dict) and "loss" in output_val.keys(): + loss_node = output_val["loss"] + generated_spec = {k: k == "loss" for k in output_val} + else: + loss_node = None + generated_spec = None + else: + loss_node = _find_loss_from_output_and_spec(output_val, output_loss_value_spec) + generated_spec = output_loss_value_spec + + return loss_node, output_node, generated_spec + + +def _insert_stage_symbolic_backward( + g: fx.Graph, + loss_node: fx.Node, + output_node: fx.Node, +): + # Collect metadata about tuple output values. TODO: move this to split_module or FX IR + tuples: dict[fx.Node, tuple] = {} + for node in reversed(g.nodes): + if node.op == "call_function": + # In the forward pass, only emit placeholder, module calls, and + # getitem calls. If we have a target other than getitem in this + # (forward-only) code, there is a bug. + assert node.target == operator.getitem, ( + "Found non-getitem call in forward pass. Please report a bug to PiPPy" + ) + assert len(node.args) == 2, ( + "Found malformed getitem call. Please report a bug to PiPPy" + ) + indexed_value, node_idx = tuple(node.args) + + # indexed_value is a collection that we are indexing into. It could + # exist in the tuples map if we've processed another `getitem` + # already. + existing_list_size = ( + len(tuples[indexed_value]) if indexed_value in tuples else -1 + ) + new_list_size = max(node_idx + 1, existing_list_size) + + reconstructed_list = [None for _ in range(new_list_size)] + + # Copy over existing elements if present + if indexed_value in tuples: + for i, val in enumerate(tuples[indexed_value]): + reconstructed_list[i] = val + + # Populate value represented by this node + reconstructed_list[node_idx] = node + + tuples[indexed_value] = tuple(reconstructed_list) + + # Keep track of nodes that dominate the loss node. + # We will only emit backward operations for nodes that can contribute + # to the specified loss value. + live_nodes = {loss_node: None} + val_to_grad: dict[fx.Node, Optional[fx.Node]] = {loss_node: None} + + def assign_or_accumulate_grad(forward_node, grad_value): + if forward_node in val_to_grad and forward_node.op != "placeholder": + grad_value = g.call_function( + _null_coalesce_accumulate, + (val_to_grad[forward_node], grad_value), + ) + val_to_grad[forward_node] = grad_value + + with g.inserting_before(output_node): + for node in reversed(g.nodes): + if node not in live_nodes: + continue + + def add_to_live_nodes(n): + live_nodes.setdefault(n, None) + + fx.node.map_arg(node.args, add_to_live_nodes) + fx.node.map_arg(node.kwargs, add_to_live_nodes) + if node.op == "call_module": + output_grads: Union[tuple[Optional[fx.Node], ...], Optional[fx.Node]] + if node in tuples: + stage_output = tuples[node] + output_grads = tuple(val_to_grad.get(n, None) for n in tuples[node]) + outputs_with_grads_idxs = [ + i for i, n in enumerate(tuples[node]) if n in live_nodes + ] + else: + stage_output = (node,) + output_grads = val_to_grad[node] + outputs_with_grads_idxs = [0] + + output_grads = ( + (output_grads,) + if not isinstance(output_grads, tuple) + else output_grads + ) + + grad_call = g.call_function( + stage_backward, + kwargs={ + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": list(node.all_input_nodes), + "outputs_with_grads_idxs": outputs_with_grads_idxs, + }, + ) + # Insert backward stage debug info + kwargs_copy = dict(grad_call.kwargs) + grad_call.kwargs = kwargs_copy + + grad_call_proxy = fx.Proxy(grad_call) + grads = grad_call_proxy.node + + input_nodes = list(node.all_input_nodes) + grads_proxy = fx.Proxy(grads) + for i, input_node in enumerate(input_nodes): + assign_or_accumulate_grad(input_node, grads_proxy[i].node) # type: ignore[index] + + return g + + +class PipeSequential(torch.nn.Sequential): + @staticmethod + def from_sequential(sequential_instance: torch.nn.Sequential): + return PipeSequential(*[copy.copy(m) for m in sequential_instance]) + + def forward(self, input): + for i, module in enumerate(self): + input = module(input) + if i != len(self) - 1: + pipe_split() + return input + + +class LossWrapper(torch.nn.Module): + """ + LossWrapper is a convenient abstract class that allows you to wrap up both + your model as well as its loss function and specify the connectivity between + the inputs, model, loss function, and output value. Example:: + + class MyModelWrapper(LossWrapper): + def forward(self, x, targets): + model_out = self.module(x) + loss_value = self.loss_fn(model_out, targets) + return loss_value + + The above example defines a connectivity where we expect the forward/loss/backward + training procedure to take two arguments (x and targets), pass x into the module + to get the output of the feedforward computation, pass the model output and the + targets value into the loss function, and get and return the loss value, which will + be backpropagated by PiPPy. The above class would then be instantiated like:: + + model = ... # instantiate the model + loss_fn = torch.nn.MSELoss() # for the sake of demonstration + + wrapper = MyModelWrapper(model, loss_fn) + pipe = Pipe.from_tracing(wrapper, ...) + + """ + + def __init__(self, module, loss_fn): + super().__init__() + self.module = module + self.loss_fn = loss_fn + + def forward(self, *args, **kwargs): + raise NotImplementedError( + "This instance of LossWrapper does not have an overridden" + "forward(). Please implement forward() to specify the arguments, " + "connection between the module and loss, and loss output " + "value." + ) + + +class TrivialLossWrapper(LossWrapper): + def forward(self, x, targets): + model_out = self.module(x) + return self.loss_fn(model_out, targets) + + loss_spec = True + + +# Pipe model representation +# +# Pipe can be thought of as an `nn.Sequential++`. That is to say: it specifies +# a single topological ordering of pipeline "stages" that, when run in series, +# constitutes all of the operations of the program. However, unlike `nn.Sequential`, +# Pipe allows non-local usages of values, so long as those uses still respect +# topological ordering. In particular: +# +# 1. Non-local activations. This type of usage can appear in, for example, skip +# connections. These values will be directly transmitted from the "def" stage +# to all stages that use them skipping intermediate stages. During autograd, +# gradients will be propagated back through this skip connection reverse +# to how activations propagated in the forward pass. +# 2. Non-local parameter/module invocations. This occurs when a parameter is used +# in a stage downstream of where it is resident. These values can be carried +# forward similarly to (1), but in addition one might want to replicate the +# value on multiple stages. Gradients for these shared parameters will be +# accumulated separately on each stage, but there will be an additional +# gradient accumulation before the optimizer step. + + +# Register `_pipe_split()` as an ATen operator. This is required for Export to +# preserve this marker in the graph. +torch.library.define("pippy::_pipe_split", "() -> ()") + + +@torch.library.impl("pippy::_pipe_split", "BackendSelect") +def _pipe_split(): + return None + + +@torch.library.register_fake("pippy::_pipe_split") # type: ignore[no-redef] +def _pipe_split(): # noqa: F811 + return None + + +# Add an alias for convenience +aten_pipe_split_alias = torch.ops.pippy._pipe_split.default + +# Ask Export to preserve the `_pipe_split` op. +# See examples in pytorch/torch/fx/node.py +fx.node._side_effectful_functions.add(aten_pipe_split_alias) + + +# User facing API +def pipe_split(): + """ + pipe_split is a special operator that is used to mark the boundary between + stages in a module. It is used to split the module into stages. It is a + no-op if your annotated module is run eagerly. + + Example: + >>> # xdoctest: +SKIP + >>> def forward(self, x): + >>> x = torch.mm(x, self.mm_param) + >>> x = torch.relu(x) + >>> pipe_split() + >>> x = self.lin(x) + >>> return x + + The above example will be split into two stages. + """ + return torch.ops.pippy._pipe_split() + + +class MultiUseParameterConfig(Enum): + TRANSMIT = 1 + REPLICATE = 2 + + +MultiUseParamSpec = Union[MultiUseParameterConfig, dict[str, MultiUseParameterConfig]] + + +class DetachExecutor(fx.Interpreter): + """ + Special interpreter to run the split_gm in testing that detaches all inputs to + a module invocation. This is needed so that the values at the boundary are + leaf modules in autograd execution. + """ + + def __init__(self, module, garbage_collect_values=True): + garbage_collect_values = False + super().__init__(module, garbage_collect_values) + self.value_remap = {} + + def run(self, *args, initial_env=None): # type: ignore[override] + self.value_remap = {} + return super().run(*args, initial_env=initial_env) + + def call_module(self, target, args, kwargs): + def detach_tensors(a): + if isinstance(a, torch.Tensor) and a.requires_grad: + if a not in self.value_remap: + new_val = a.detach().requires_grad_(True) + self.value_remap[a] = new_val + return self.value_remap[a] + else: + return a + + """ + def dont_traverse_size(a): + return type(a) != torch.Size + """ + + args = map_aggregate( + args, + detach_tensors, # dont_traverse_size + ) + kwargs = map_aggregate( + kwargs, + detach_tensors, # dont_traverse_size + ) + + return super().call_module(target, args, kwargs) + + def call_function(self, target, args, kwargs): + # HACK to reroute saved input tensors to point to the detach()ed version + if target == stage_backward: + kwargs = dict(kwargs) + kwargs["input_values"] = [ + self.value_remap.get(v, v) for v in kwargs["input_values"] + ] + return super().call_function(target, args, kwargs) + + +class _NodeReference: + def __init__(self, name): + self.name = name + + name: str + + +class _LinearNodeList: + def __init__(self, node_list): + self.serialize_node_list = [] + for node in node_list: + node_args = fx.node.map_arg(node.args, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + node_kwargs = fx.node.map_arg(node.kwargs, lambda n: _NodeReference(n.name)) # type: ignore[arg-type,return-value] + serialize_node = fx.Node( + graph=None, # type: ignore[arg-type] + name=node.name, + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + return_type=node.type, + ) + serialize_node.meta = copy.copy(node.meta) + self.serialize_node_list.append(serialize_node) + + def to_graph(self): + graph = fx.Graph() + + ref_str_to_node: dict[str, fx.Node] = {} + + def ref_to_node(arg): + if isinstance(arg, _NodeReference): + return ref_str_to_node[arg.name] + else: + return arg + + for node in self.serialize_node_list: + node_args = map_aggregate(node.args, ref_to_node) + node_kwargs = map_aggregate(node.kwargs, ref_to_node) + deser_node = graph.create_node( + op=node.op, + target=node.target, + args=node_args, # type: ignore[arg-type] + kwargs=node_kwargs, # type: ignore[arg-type] + name=node.name, + type_expr=node.type, + ) + ref_str_to_node[node.name] = deser_node + + return graph + + +def _direct_serialization_deserialize(body, nodes): + """ + Custom `__reduce__` method for serialization. + DO AS I SAY -- NOT AS I DO. This violates the principle that + GraphModules serialize via code export & re-tracing. We allow + for this here because **PIPE STAGES SHOULD NOT BE PERSISTED + TO DISK -- THIS IS ONLY FOR TRANSMISSION VIA RPC**. Persisting + these instances to disk will expose internal implementation + details of `fx.Graph` and related data structures and is + NOT advised. + """ + + class DummyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__.update(body) + + dummy = DummyModule(body) + + return fx.GraphModule(dummy, nodes.to_graph()) + + +def _direct_serialization_reduce(self): + serialization_dict = dict(self.__dict__) + serialization_dict.pop("_graph") + return ( + _direct_serialization_deserialize, + (serialization_dict, _LinearNodeList(self.graph.nodes)), + ) + + +def _modify_graph_op_device( + gm: torch.fx.GraphModule, + new_device: torch.device, +): + """ + Modify the device argument of all "call_function" nodes in the graph. This + is useful for moving the graph to a different device. In particular for + generator ops, like torch.ones. + """ + modified = False + for node in gm.graph.nodes: + if node.op == "call_function": + if "device" in node.kwargs and node.kwargs["device"] != new_device: + logger.debug( + f"Changing device of Node {node.name} from {node.kwargs['device']} to {new_device}" # noqa: G004 + ) + node.update_kwarg("device", new_device) + modified = True + elif node.op == "call_module": + # Recursively modify "device" in submodules + submod = gm.get_submodule(node.target) + if isinstance(submod, torch.fx.GraphModule): + _modify_graph_op_device(submod, new_device) + elif isinstance(submod, InterpreterModule): + # If unflattening has been performed, we need to access its graph module by `.graph_module` + _modify_graph_op_device(submod.graph_module, new_device) # type: ignore[arg-type] + else: + logger.warning( + f"Skipping device modification for submodule {node.target} because it is a {type(submod)}" # noqa: G004 + ) + + if modified: + gm.recompile() + + +class Pipe(torch.nn.Module): + def __init__( + self, + split_gm: fx.GraphModule, + num_stages: int, + has_loss_and_backward: bool, + loss_spec, + ): + # TODO: is there a way not to hard wire init? + torch.nn.Module.__init__(self) + self.split_gm: fx.GraphModule = split_gm + self.executor: DetachExecutor = DetachExecutor(self.split_gm) + self.num_stages: int = num_stages + self.has_loss_and_backward = has_loss_and_backward + self.loss_spec = loss_spec + + for node in split_gm.graph.nodes: + assert ( + node.op in {"call_module", "placeholder", "output"} + or (node.op, node.target) == ("call_function", operator.getitem) + or (node.op, node.target) == ("call_method", "backward") + or (node.op, node.target) == ("call_function", stage_backward) + or (node.op, node.target) + == ("call_function", _null_coalesce_accumulate) + ), node + + # Detect replicated parameters so we know that we have to do an additional allreduce + # before applying the optimizer + # + # Note that this also handles the case where there were multiple calls to a single + # module from different stages, regardless of whether that module invocation + # was handled by the logic above. + + # Map parameter value to a dictionary that maps the user pipeline module + # to the local qualname within that module + params_to_users: dict[torch.nn.Parameter, dict[str, str]] = {} + + for m_qualname, mod in self.split_gm.named_children(): + for p_qualname, param in mod.named_parameters(): + params_to_users.setdefault(param, {}) + params_to_users[param][m_qualname] = p_qualname + + self.replicated_params: list[dict[str, str]] = [ + use_mapping + for _, use_mapping in params_to_users.items() + if len(use_mapping) > 1 + ] + + # We must break the aliasing relationship between the replicated parameters for correct + # numerics in reference runs. If we do not do this, the autograd tape in separate stages + # will have a reference to the same tensor value and will erroneously apply gradient + # updates multiple times. Therefore, for each replicated parameter set, we deepcopy the + # values so that we have separate instances. + for param_mapping in self.replicated_params: + for submod_name, param_qualname in param_mapping.items(): + submod = getattr(self.split_gm, submod_name) + atoms = param_qualname.split(".") + for atom in atoms[:-1]: + submod = getattr(submod, atom) + setattr(submod, atoms[-1], copy.deepcopy(getattr(submod, atoms[-1]))) + + def throw(self, *args, **kwargs): + raise RuntimeError( + "To run pipeline locally, invoke the Pipe object directly, not `split_gm`" + ) + + self.split_gm.forward = throw + + # Make submodules use custom direct-serialized GraphModule + i = 0 + while True: + try: + name = f"submod_{i}" + submod = getattr(self.split_gm, name) + submod.__class__.__reduce__ = _direct_serialization_reduce + i += 1 + except AttributeError: + break + + def forward(self, *args, **kwargs): + executor_args = args + if len(kwargs) > 0: + parameters = [] + for node in self.split_gm.graph.nodes: + if node.op == "placeholder": + if node.args and len(node.args) > 0: + parameters.append( + Parameter( + node.target, + Parameter.POSITIONAL_OR_KEYWORD, + default=node.args[0], + ) + ) + else: + parameter_kind = Parameter.POSITIONAL_OR_KEYWORD + param_name = node.target + if node.target.startswith("**"): + parameter_kind = Parameter.VAR_KEYWORD # type: ignore[assignment] + param_name = param_name[2:] + elif node.target.startswith("*"): + parameter_kind = Parameter.VAR_POSITIONAL # type: ignore[assignment] + param_name = param_name[1:] + parameters.append(Parameter(param_name, parameter_kind)) + signature = Signature(parameters) + ba = signature.bind(*args, **kwargs) + ba.apply_defaults() + executor_args = ba.arguments.values() # type: ignore[assignment] + + res = self.executor.run(*executor_args) + + return res + + def get_stage_module(self, stage_idx: int) -> torch.nn.Module: + """ + Return a stage module corresponding to `stage_idx` of the `pipe`. + """ + if stage_idx < 0 or stage_idx >= self.num_stages: + raise ValueError(f"Invalid stage index {stage_idx}!") + return getattr(self.split_gm, f"submod_{stage_idx}") + + @staticmethod + def _number_and_count_forward_stages(gm: fx.GraphModule): + num_stages = 0 + found_idxs: dict[int, None] = {} + for node in gm.graph.nodes: + if node.op == "call_module" and node.target.startswith("submod_"): + node.meta["stage_idx"] = int(node.target[len("submod_") :]) + found_idxs.setdefault(node.meta["stage_idx"]) + num_stages += 1 + + # this assert will fail if a split point is inserted before the first layer, which creates empty first submodule + # Update: the following assert may fail against some torch versions >= + # 2.2.0, as: + # submod_0, submod_1, submod_2, ... + # may be named as + # submod_0, submod_2, submod_4, ... + # TODO: investigate + # assert all(i in found_idxs for i in range(num_stages)) + + return num_stages + + @staticmethod + def _from_traced( + mod: torch.nn.Module, + exported_program: ExportedProgram, + multi_use_param_spec: Optional[MultiUseParamSpec] = None, + output_loss_value_spec=None, + split_policy: Optional[ + Callable[[torch.fx.GraphModule], torch.fx.GraphModule] + ] = None, + ): + """ + Additionally, the ``output_loss_value_spec`` value can be specified to disambiguate + which value in the output of `forward` is the loss value on which PiPPy should apply + backpropagation. For example, if your ``forward`` returns a tuple ``(loss, model_out)``, + you can specify ``output_loss_value_spec=(True, False)``. Or, if your ``forward`` returns + a dict ``{'loss': loss_value, 'model_out': model_out}``, you can specify + ``output_loss_value_spec={'loss': True, 'model_out': False}`` + """ + + traced = exported_program.module() + + if split_policy is not None: + logger.info("Auto-splitting model") + traced = split_policy(traced) # type: ignore[arg-type] + + logger.debug(traced.print_readable(print_output=False)) # type: ignore[operator] + + # Deduplicate `get_attr` nodes that refer to the same parameter . Downstream code for moving + # parameters relies on the invariant that parameter accesses happen once. This is not necessarily + # the case (especially with custom tracers), so fix that up here. + get_attr_nodes: dict[str, fx.Node] = {} + for node in traced.graph.nodes: # type: ignore[union-attr] + if node.op == "get_attr": + get_attr_nodes.setdefault(node.target, node) + + if get_attr_nodes[node.target] != node: + node.replace_all_uses_with(get_attr_nodes[node.target]) + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + # avoid looking at next node by keeping track of previous pipe_split + prev_pipe_split_idx = -1 + pipe_split_nodes_to_erase = set() + for i, node in enumerate(traced.graph.nodes): # type: ignore[arg-type, union-attr] + if (node.op, node.target) == ("call_function", pipe_split): + if prev_pipe_split_idx == i - 1: + pipe_split_nodes_to_erase.add(node) + prev_pipe_split_idx = i + + for node in pipe_split_nodes_to_erase: + traced.graph.erase_node(node) # type: ignore[operator, union-attr] + + traced.recompile() # type: ignore[operator] + + part_idx = 0 + + def split_callback(n: fx.Node): + nonlocal part_idx + if (n.op, n.target) == ( + "call_function", + aten_pipe_split_alias, + ): + logger.debug(f"Found pipe_split {part_idx}") # noqa: G004 + part_idx += 1 + return part_idx + + # TODO: what does split do with module invocations? does it move the modules + # into the submodules? + split = split_module(traced, mod, split_callback) # type: ignore[arg-type] + # a (custom) tracer can produce dead code like orphan get_attr nodes + split.graph.eliminate_dead_code() + + # peephole to remove pipe_split + for submodule in split.modules(): + if isinstance(submodule, fx.GraphModule): + for node in submodule.graph.nodes: + if (node.op, node.target) == ( + "call_function", + aten_pipe_split_alias, + ): + submodule.graph.erase_node(node) + submodule.recompile() + + for name, submodule in split.named_children(): + if isinstance(submodule, fx.GraphModule): + new_submod = _outline_submodules(submodule.graph) + # Replace old submod + split.register_module(name, new_submod) + + # TODO: backport this into split_module + def delete_user_reference(node, user): + """ + Delete reference of `node` from `user`'s arg list. + Args: + - node: a `get_attr` node at root. + - user: a submodule node that uses `node`. + """ + assert len(user.kwargs) == 0 + use_idxs = [i for i, arg in enumerate(user.args) if arg == node] + assert len(use_idxs) == 1 + args_copy = list(user.args) + args_copy.pop(use_idxs[0]) + user.args = tuple(args_copy) + logger.debug( + f"Deleted {node} from user {user}, arg index = {use_idxs[0]}" # noqa: G004 + ) + + # A list of param referrals for deferred deletion. + # To be accumulated in `move_param_to_callee`. + to_delete = [] + + def _recursive_getattr_with_parent(mod, fqn): + # Returns getattr call given a nested FQN, and the last parent + atoms = fqn.split(".") + for atom in atoms[:-1]: + if not hasattr(mod, atom): + return None, None + mod = getattr(mod, atom) + if not hasattr(mod, atoms[-1]): + return mod, None + attr = getattr(mod, atoms[-1]) + return mod, attr + + def move_param_to_callee( + root, + callee_name, + param_fqn, + ): + """ + Move a parameter from the root module to a submodule. + Args: + root: The root module. + callee_name: The name of the submodule to move the parameter to. + param_fqn: The fully qualified name of the parameter to move. + """ + # `atoms` is a list of strings representing the path to the + # parameter in the original model + atoms = param_fqn.split(".") + mod_itr, param_val = _recursive_getattr_with_parent(split, param_fqn) + # Check whether the parameter is a buffer or a parameter + is_buffer = atoms[-1] in mod_itr._buffers + + # Check whether the parameter is a tensor + assert isinstance(param_val, torch.Tensor), ( + f"Expected '{param_fqn}' to be {torch.Tensor} but got {type(param_val)}." + + ( + f" It might happen if module '{param_fqn}' was passed to some 'leaf function'" + f"(see https://pytorch.org/docs/stable/fx.html#fx.wrap). Please inspect " + f"usages of '{param_fqn}' in the traced graph." + if isinstance(param_val, torch.nn.Module) + else "" + ) + ) + + # Get submodule + callee = root.get_submodule(callee_name) + assert not hasattr(callee, param_fqn), ( + f"Module {callee_name} already has a parameter named {param_fqn}" + ) + + # Assign the parameter to the submodule + if is_buffer: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.BUFFER, + persistent=True, # TODO: handle non-persistent buffer + ) + else: + _assign_attr( + param_val, + callee, + param_fqn, + attr_kind=_AttrKind.PARAMETER, + ) + logger.debug(f"Moved parameter {param_fqn} to {callee_name}") # noqa: G004 + + # Next step is to replace placeholder of submodule with a get_attr. + # Those placeholders are created by `split_module` inside each + # submodule. + # Update: this step is now moved to `_sink_params` because + # `_sink_params` can do it recursively (i.e. for modules inside + # submodule) + + to_delete.append((mod_itr, atoms[-1])) + + # Get the list of all parameters in the root module + attr_nodes = list(filter(lambda n: n.op == "get_attr", split.graph.nodes)) + for node in attr_nodes: + # Check whether the parameter is used in only one submodule + if len(node.users) > 1: + logger.info( + f"Parameter {node.target} used in multiple stages: {node.users}." # noqa: G004 + ) + for user in node.users: + assert user.op == "call_module" + # Move parameter into submodule + move_param_to_callee( + split, + user.target, + node.target, + ) + + # [aliasing] store tensor id -> list of FQNs, built from state dict + # Also assign non-persistent buffers + id_to_fqns: dict[int, set[str]] = defaultdict(set) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + id_to_fqns[id(tensor)].add(fqn) + for fqn, tensor in mod.named_buffers(): + id_to_fqns[id(tensor)].add(fqn) + + # After moving the params to their corresponding hierarchies, we also + # need to move the `get_attr` nodes from the root of the graph to those + # hierarchies. + # [aliasing] use id -> fqn mapping to list out all valid FQNs + inputs_to_state: dict[str, list[str]] = {} + for attr in attr_nodes: + _, tensor = _recursive_getattr_with_parent(mod, attr.target) + fqns = list(id_to_fqns[id(tensor)]) + if fqns: + inputs_to_state[attr.name] = fqns + elif attr.target in exported_program.constants: # lifted constants + inputs_to_state[attr.name] = [attr.target] + + # [aliasing] for each submodule split, assign attributes on FQNs that may be used. + # We determine this based on whether or not the FQN attribute parent exists. + # i.e. if the last submodule exists, assign the attribute. + added_attributes: dict[str, list[str]] = defaultdict(list) + for fqn, tensor in mod.state_dict(keep_vars=True).items(): + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + parent, child = _recursive_getattr_with_parent(submod, fqn) + if ( + parent and child is None + ): # parent exists, attribute doesn't -> assign + added_attributes[name].append(fqn) + setattr(parent, fqn.split(".")[-1], tensor) + + # Deferral deletion: Remove the original attributes (to params) from the + # root GraphModule + for mod_itr, last_atom in to_delete: + try: + delattr(mod_itr, last_atom) + except AttributeError: + # This is expected if the parameter is used in multiple stages + pass + + # This is done by (1) `_sink_params` at each submodule; + for name, submod in split.named_children(): + if isinstance(submod, fx.GraphModule): + _sink_params(submod, inputs_to_state, []) + submod.graph.lint() + submod.recompile() + + # [aliasing] This step is not super necessary, but helps reduce parameter usage/memory. + # After _sink_params() routine has run, clean up unused attributes that we previously added. + # Determine this based on the get_attr nodes - if not used, remove it. + for name, attributes in added_attributes.items(): + submod = getattr(split, name) + unused_attributes = set(attributes) + # track used attributes in the submodule, running DFS on subgraph hierarchy + stack = [("", submod)] # (scope, submodule) + while stack: + scope, _mod = stack.pop() + if isinstance(_mod, (fx.GraphModule, InterpreterModule)): + for node in _mod.graph.nodes: + if node.op == "get_attr": + # get_attr might get access deeper level attribute + fqn = scope + "." + node.target if scope else node.target + unused_attributes.discard(fqn) + for _name, _submod in _mod.named_children(): + stack.append((scope + "." + _name if scope else _name, _submod)) + # delete unused attributes + for attr in unused_attributes: + mod_itr, atoms = submod, attr.split(".") + for atom in atoms[:-1]: + mod_itr = getattr(mod_itr, atom) + delattr(mod_itr, atoms[-1]) + + for node in attr_nodes: + # And (2): remove `get_attr` node from submod's arg list + for user in copy.copy(node.users): + assert user.op == "call_module" + delete_user_reference(node, user) + # And (3): remove the `get_attr` node from the root graph. + split.graph.erase_node(node) + + split.delete_all_unused_submodules() + split.graph.lint() + split.recompile() + + num_stages = Pipe._number_and_count_forward_stages(split) + + has_loss_and_backward = False + generated_loss_spec = output_loss_value_spec + + if output_loss_value_spec is not None: + loss_node, output_node, generated_loss_spec = _find_loss_output( + mod, split.graph, output_loss_value_spec + ) + if loss_node is not None: + _insert_stage_symbolic_backward( + split.graph, + loss_node, + output_node, + ) + split.recompile() + has_loss_and_backward = True + logger.debug("Pipeline is in training mode, backward pass generated") + else: + raise RuntimeError( + f"Did not find any loss value according to {output_loss_value_spec=}" + ) + else: + logger.debug("Pipeline is in inference mode, backward pass not generated") + + logger.debug(f"Full pipe model:\n{split}") # noqa: G004 + + return Pipe( + split, + num_stages, + has_loss_and_backward, + generated_loss_spec, + ) + + def print_readable(self): + """ + Print the pipe in a human-readable format. + This will print both the root pipe and each stage module. + """ + self.split_gm.print_readable() + + @staticmethod + def _trace_with_export( + mod: torch.nn.Module, + example_args: tuple[Any, ...], + example_kwargs: Optional[dict[str, Any]] = None, + ) -> ExportedProgram: + logger.info("Tracing model ...") + try: + ep = torch.export.export_for_training( + mod, example_args, example_kwargs, strict=True + ) + except Exception as e: + raise RuntimeError( + "It seems that we cannot capture your model as a full graph. " + "Typical reasons include graph breaks, data/shape-dependent " + "control flow, or missing meta kernels for custom operators. " + "You can use our manual pipeline interfaces, or try to fix the " + "graph breaks, see https://pytorch.org/docs/stable/export.html" + ) from e + + return ep + + @staticmethod + def from_tracing( + mod: torch.nn.Module, + example_args: tuple[Any, ...], + example_kwargs: Optional[dict[str, Any]] = None, + split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, + ): + # If a param will be used in multiple pipeline stages, we default the strategy to REPLICATE'ing the param across + # stages instead of TRANSMIT'ting it + multi_use_param_spec = MultiUseParameterConfig.REPLICATE + + # Figure out which output is loss from output_chunk_spec + output_loss_value_spec: Any = None + # Deprecated + """ + if output_chunk_spec is not None: + output_loss_value_spec = map_aggregate( + output_chunk_spec, lambda v: isinstance(v, _LossReducer) + ) + """ + + # Trace with export + exported_program = Pipe._trace_with_export( + mod, + example_args, + example_kwargs, + ) + + pipe = Pipe._from_traced( + mod, + exported_program, + multi_use_param_spec, + output_loss_value_spec=output_loss_value_spec, + split_policy=split_policy, + ) + + # Users want the first pipeline stage to accept kwargs if the original + # program does. This is controlled by the `_codegen` field of the graph, + # so we make a copy here. Note: we only want the input spec and not the + # output spec, because the output spec is for the last stage. Maybe a + # TODO? Not sure yet. + split = pipe.split_gm + traced = exported_program.module() + submod0 = next(iter(split.children())) + submod0_sign = signature(submod0.forward) + model_sign = signature(traced.forward) + if len(model_sign.parameters) != len(submod0_sign.parameters): + # We don't change the signature of the first stage if it takes + # different number of args than original model + logger.info( + f"Original model takes {len(model_sign.parameters)} args but the " # noqa: G004 + f"first pipeline stage takes {len(submod0_sign.parameters)}. " + "Please provide args to respective pipeline stages." + ) + else: + # Support kwargs for the first stage + submod0.graph._codegen = copy.deepcopy(traced.graph._codegen) # type: ignore[union-attr] + # `_replace` is actually not "private" or internal. based on this doc: + # To prevent conflicts with field names, the method and attribute names + # start with an underscore + submod0.graph._codegen.pytree_info = ( # type: ignore[union-attr] + submod0.graph._codegen.pytree_info._replace(out_spec=None) # type: ignore[operator, union-attr] + ) + submod0.recompile() + + return pipe + + def __str__(self): + return self.split_gm.__str__() + + def __repr__(self): + return self.split_gm.__repr__() + + def info(self) -> PipeInfo: + """ + Get information about the pipe. + + Returns + ------- + PipeInfo + A dataclass containing information about the pipe. + """ + return PipeInfo( + graph=self.split_gm.graph, + num_stages=self.num_stages, + has_loss_and_backward=self.has_loss_and_backward, + ) + + def build_stage( + self, + stage_index: int, + device: torch.device, + group: Optional[ProcessGroup] = None, + ) -> _PipelineStage: + """ + Create a `PipelineStage` given a stage index and distributed group. + The `PipelineStage` can run with `PipelineSchedule`s. + """ + # Find stage module + stage_module = self.get_stage_module(stage_index) + + # Move ops argument to device + # Today PT2 tracer does not treat `x.device` as a symbolic device; + # instead, the device of tracing time got burned into the generated + # code. Here we provide a workaround for users to manually modify the + # "device" kwarg of operations. Such operation may include: + # `torch.ones`, `torch.zeros`, `torch.rand`, etc. + if isinstance(stage_module, torch.fx.GraphModule): + _modify_graph_op_device(stage_module, device) + else: + logger.warning( + f"Expected a `torch.fx.GraphModule` but got {type(stage_module)}" # noqa: G004 + ) + + # Detach pipe info + # Note: be careful what's included in `pipe_info`. We don't want to keep + # a reference to `Pipe` or `Pipe.split_gm` which stops python from + # recycling them. When python recycles them, other stage modules (which + # are irrelevant to current rank) can be automatically freed. + pipe_info = self.info() + return _PipelineStage(stage_module, stage_index, pipe_info, device, group) + + +class SplitPoint(Enum): + """ + Enum representing the points at which a split can occur in the execution of a submodule. + Attributes: + BEGINNING: Represents adding a split point *before* the execution of a certain submodule in the `forward` function. + END: Represents adding a split point *after* the execution of a certain submodule in the `forward` function. + """ + + BEGINNING = 1 + END = 2 + + +# For backward compatibility, we kept the PipeSplitWrapper class because `class +# SplitPoint` used to be defined in this class. +class PipeSplitWrapper: + # Create a class alias for BC + SplitPoint = SplitPoint + + +def _split_before_forward(self, *args, **kwargs): + pipe_split() + return self._orig_forward(*args, **kwargs) + + +def _split_after_forward(self, *args, **kwargs): + try: + return self._orig_forward(*args, **kwargs) + finally: + pipe_split() + + +def annotate_split_points(mod: torch.nn.Module, spec: dict[str, SplitPoint]): + # TODO: make this implementation out-of-place? + for qualname, split_type in spec.items(): + atoms = qualname.split(".") + predecessor_module = mod + for i, atom in enumerate(atoms[:-1]): + try: + predecessor_module = getattr(predecessor_module, atom) + except AttributeError as e: + raise AttributeError( + f"Specified target {qualname} referenced " + f"nonexistent module {'.'.join(atoms[: i + 1])}" + ) from e + + mod_to_wrap = getattr(predecessor_module, atoms[-1]) + mod_to_wrap._orig_forward = mod_to_wrap.forward + if split_type == SplitPoint.BEGINNING: + mod_to_wrap.forward = MethodType(_split_before_forward, mod_to_wrap) + elif split_type == SplitPoint.END: + mod_to_wrap.forward = MethodType(_split_after_forward, mod_to_wrap) + else: + raise ValueError("Unknown split point type.") + + +def pipeline( + module: torch.nn.Module, + mb_args: tuple[Any, ...], + mb_kwargs: Optional[dict[str, Any]] = None, + split_spec: Optional[dict[str, SplitPoint]] = None, + split_policy: Optional[Callable[[fx.GraphModule], fx.GraphModule]] = None, +) -> Pipe: + """ + Split a module based on a specification. + + See `Pipe` for more details. + + Arguments + --------- + module: + The module to be split. + mb_args: + Example positional inputs, in micro-batch form. + mb_kwargs: + Example keyword inputs, in micro-batch form. (default: `None`) + split_spec: + A dictionary using submodule names as split marker. (default: `None`) + split_policy: + The policy to use for splitting the module. (default: `None`) + + Returns + ------- + A pipeline representation of class `Pipe`. + """ + if split_spec is not None and split_policy is not None: + raise ValueError( + "Cannot specify both `split_spec` and `split_policy`. Please use only one of them." + ) + + if split_spec is not None: + # Annotate split points in the module based on user spec + annotate_split_points(module, split_spec) + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + ) + else: + # Use split policy + return Pipe.from_tracing( + mod=module, + example_args=mb_args, + example_kwargs=mb_kwargs, + split_policy=split_policy, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__init__.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7e618354fc38ee9aa187aeee82b11354c151732f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from ._IR import Pipe, pipe_split, pipeline, SplitPoint +from .schedules import ( + _ScheduleForwardOnly, + Schedule1F1B, + ScheduleGPipe, + ScheduleInterleaved1F1B, + ScheduleInterleavedZeroBubble, + ScheduleLoopedBFS, + ScheduleZBVZeroBubble, +) +from .stage import build_stage, PipelineStage + + +__all__ = [ + "Pipe", + "pipe_split", + "SplitPoint", + "pipeline", + "PipelineStage", + "build_stage", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", + "ScheduleZBVZeroBubble", +] diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30df8a26659e7b4d3aa0f331f699c07e2b26e2f1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_IR.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32f716fd5093310f5860927ff6c5c8c292c1d089 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a740740456647f9d2400d3f975360d8f076eb4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_backward.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11cf4e2cbdaac5f67f7ecc1f10f037aac03d071e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_debug.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0702f88acb37a329b18d25280b64f71a0b01f6f1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_schedule_visualizer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ca817d6743e4951279d34607128712c3fab6e74 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_unflatten.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..688820d9d6725bc9d00bda27714763e92c5c1649 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cad99d6f638a19ffe9dc3835e6ca0043ffad1595 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/microbatch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/schedules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/schedules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..824f12d9482e7a456b93708796b9e246ed2802c5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/schedules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11cf8158e78e46a356a2318197b64798f7122bd8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/pipelining/__pycache__/stage.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/_backward.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/_backward.py new file mode 100644 index 0000000000000000000000000000000000000000..ab1706d6db1fa53fd07b12cdc3841579ac3debc0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/_backward.py @@ -0,0 +1,404 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import collections +import logging +from collections.abc import Iterator +from typing import Any, Optional, Union + +import torch +from torch.autograd.graph import GradientEdge, Node +from torch.nn import Parameter + +from ._debug import map_debug_info + + +logger = logging.getLogger(__name__) + + +def _get_grad_fn_or_grad_acc(t: torch.Tensor) -> Union[Node, None]: + """ + Get the grad function or grad accumulator for a tensor. + + Accumulate grad nodes are lazily created, so we need to a + dummy view in order to trigger its creation. + """ + if t.requires_grad and t.grad_fn is None: + # if no grad function (leaf tensors) we use view + viewed_t = t.view_as(t) + grad_fn = viewed_t.grad_fn + if grad_fn is not None: + return grad_fn.next_functions[0][0] + else: + raise RuntimeError( + "Attempted to get grad_fn, but got None." + "Is this being created in a no-grad context?" + ) + else: + return t.grad_fn + + +def reverse_closure( + roots: list[Node], target_nodes: set[Node], reverse_edges_dict +) -> tuple[set[Node], set[Node]]: + """ + This function returns the reverse closure of the given roots, + i.e. the set of nodes that can be reached from the roots by following the + reverse edges of the graph. The target_nodes are the nodes that we want to + include in the closure. + """ + # Recurse until we reach a target node + closure: set[Node] = set() + visited_target_nodes = set() + q: collections.deque[Node] = collections.deque() + for node in roots: + if node is not None and node not in closure: + closure.add(node) + q.append(node) + while q: + node = q.popleft() + reverse_edges = reverse_edges_dict[node] + for fn in reverse_edges: + if fn in closure or fn is None: + continue + if fn in target_nodes: + visited_target_nodes.add(fn) + continue + closure.add(fn) + q.append(fn) + return closure, visited_target_nodes + + +def construct_reverse_graph(roots: list[Node]) -> dict[Node, list[Node]]: + q: collections.deque[Node] = collections.deque() + root_seen: set[Node] = set() + reverse_edges_dict: dict[Node, list[Node]] = collections.defaultdict(list) + for node in roots: + if node is not None and node not in root_seen: + q.append(node) + root_seen.add(node) + while q: + node = q.popleft() + for fn, _ in node.next_functions: + if fn is not None: + if len(reverse_edges_dict[fn]) == 0: + q.append(fn) + reverse_edges_dict[fn].append(node) + return reverse_edges_dict + + +def get_param_groups( + inputs: list[Node], params: list[Node], reverse_edges_dict +) -> list[dict[str, Any]]: + """ + Given a list of inputs and a list of parameters, return a list of parameter + groups, where each group contains the parameters and the intermediates that + are connected to the parameters. + + The returned list of parameter groups is a list of dictionaries, where each + dictionary contains the following keys: + - "params": a set of parameters + - "intermediates": a set of intermediates + + The returned list of parameter groups is a list of dictionaries, + """ + # reverse graph that starts with inputs, and goes up to the dOutput or the loss, + # but omits weights and any subgraphs connecting weights to this closure + inputs_closure, _ = reverse_closure(inputs, set(), reverse_edges_dict) + param_groups: dict[Node, dict[str, set]] = dict() # keyed on intermediates + for param in params: + closure, intersected = reverse_closure( + [param], inputs_closure, reverse_edges_dict + ) + param_group: dict[str, set] = { + "params": {param}, + "intermediates": intersected, + } + for input_node in intersected: + existing = param_groups.get(input_node, None) + if existing is not None: + existing["params"] = existing["params"].union(param_group["params"]) + existing["intermediates"] = existing["intermediates"].union( + param_group["intermediates"] + ) + param_group = existing + else: + param_groups[input_node] = param_group + + # Sanity check: union of all param_groups params should be equal to all params + union_params: set[Node] = set() + seen_ids: set[int] = set() + unique_param_groups = [] + for param_group in param_groups.values(): + if id(param_group) not in seen_ids: + seen_ids.add(id(param_group)) + unique_param_groups.append(param_group) + union_params = union_params.union(param_group["params"]) + + # The assert will only be true if the input tensor requires gradients, + # otherwise the autograd graph will miss the first layer of inputs + # assert union_params == set(params) + return unique_param_groups + + +def stage_backward_input( + stage_outputs_or_loss: list[torch.Tensor], + output_grads: Optional[list[torch.Tensor]], + input_values: list[torch.Tensor], + weights: Iterator[Parameter], +) -> tuple[tuple[Optional[torch.Tensor], ...], list[dict[str, Any]]]: + """ + Compute the gradients for only the stage inputs with + respect to the stage outputs (if non-last stage) or loss (if last stage) + + After computing input gradients, we save the intermediate nodes in `param_groups` + for later use in stage_backward_weight. We don't need to save any other intermediate nodes + that aren't needed for dW because when we do dW calculation, we start from saved intermediates. + Detaching the stage_outputs_or_loss at the end of this function is important as + it frees up the memory that the autograd graph is anticipating to be used later (but doesn't actually need). + """ + stage_output_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, stage_outputs_or_loss)) + ) + stage_input_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, input_values)) + ) + weight_grad_fns: list[Node] = list( + filter(None, map(_get_grad_fn_or_grad_acc, weights)) + ) + + reverse_edges_dict = construct_reverse_graph(stage_output_grad_fns) + param_groups = get_param_groups( + stage_input_grad_fns, weight_grad_fns, reverse_edges_dict + ) + + handles = [] + for param_group in param_groups: + for i, intermediate in enumerate(param_group["intermediates"]): + + def get_hook(param_group, i): + def hook(grad_inputs): + if param_group.get("grads", None) is None: + param_group["grads"] = [None] * len( + param_group["intermediates"] + ) + param_group["grads"][i] = grad_inputs + + return hook + + # These are always "split" nodes that we need to recompute, so + # save their inputs. + handle = intermediate.register_prehook(get_hook(param_group, i)) + handles.append(handle) + + if output_grads is None: + # In case this is the loss and there are no output_grads, then we just use 1s + output_grads = [ + torch.ones_like(stage_output) for stage_output in stage_outputs_or_loss + ] + + # Some inputs may not be used or may not require gradients, so we filter them out + input_values = [inp for inp in input_values if inp.requires_grad] + dinputs = torch.autograd.grad( + stage_outputs_or_loss, + inputs=input_values, + grad_outputs=output_grads, + retain_graph=True, + ) + # Update the gradients for inputs + for inp, dinput in zip(input_values, dinputs): + if inp.grad is None: + inp.grad = dinput + else: + inp.grad += dinput + + # stage_outputs_or_loss are not used in backwards after this point, so we can safely remove it from the autograd graph + # this allows autograd to clear up the graph dedicated for this tensor and free up significant memory + for t in stage_outputs_or_loss: + t.detach_() + + # hooks are no longer necessary, clean up for consistency + for handle in handles: + handle.remove() + + return dinputs, param_groups + + +def stage_backward_weight( + weights: Iterator[Parameter], param_groups: list[dict[str, Any]], retain_graph=False +) -> tuple[Optional[torch.Tensor], ...]: + # map weights to param_group_weights + grad_acc_to_weight = {} + weight_grads: list[Optional[torch.Tensor]] = [] + for index, weight in enumerate(weights): + grad_acc = _get_grad_fn_or_grad_acc(weight) + grad_acc_to_weight[grad_acc] = weight, index + weight_grads.append(weight.grad) + + for param_group in param_groups: + # TODO: Handle case where intermediate can have multiple outputs + intermediate_edges = tuple( + GradientEdge(i, 0) for i in param_group["intermediates"] + ) + weights_edges = tuple(GradientEdge(w, 0) for w in param_group["params"]) + + # Break a reference cycle caused inside stage_backward_input->get_hook->hook + # The summarized cycle is: + # `hook` -> cell -> param_group -> intermediates -> `hook` + # because we install the hook function onto each of the intermediate autograd nodes. + # We need to keep intermediates alive up until backward_weight, but we can free it now. + del param_group["intermediates"] + + assert all(len(g) == 1 for g in param_group["grads"]) + # [NEW!] Able to pass a GradientEdge to autograd.grad as output + # We do not need to retain_graph because... guarantee no overlap? + # print("trying to execute: ", intermediate_edges, weights_edges) + dweights = torch.autograd.grad( + intermediate_edges, + weights_edges, + grad_outputs=sum(param_group["grads"], tuple()), + retain_graph=retain_graph, + ) + # release grad memory early after use + del param_group["grads"] + + for grad_acc, dw in zip(param_group["params"], dweights): + weight, index = grad_acc_to_weight[grad_acc] + if weight.grad is None: + weight.grad = dw + else: + weight.grad += dw + # return grads in the original order weights were provided in + return tuple(weight_grads) + + +def stage_backward( + stage_output, + output_grads, + input_values, + outputs_with_grads_idxs: Optional[list[int]] = None, # deprecated, not used +) -> tuple[Optional[torch.Tensor], ...]: + """ + This is a helper function to: + 1. compute the gradients for the stage inputs, and + 2. accumulate gradients for the stage module's parameters. + + Given the input value(s) and the corresponding gradient for the output + value(s), compute and accumulate gradients for all parameter values (leaves + in the autograd trace) as well as return a list of the gradients for the + input values + """ + if outputs_with_grads_idxs is not None: + # Deprecated, not used in runtime calls, only exists in compiler + stage_output = [stage_output[i] for i in outputs_with_grads_idxs] + output_grads = [output_grads[i] for i in outputs_with_grads_idxs] + + try: + # stage_output may be a composite datatype like dict. Extract all individual + # tensor values here + stage_output_tensors: list[torch.Tensor] = [] + output_grad_tensors: list[Optional[torch.Tensor]] = [] + + def extract_tensors_with_grads( + output_val, + grad_val, + # Don't delete me- see [Note: ref cycle] + extract_tensors_with_grads, + ): + if isinstance(output_val, torch.Tensor): + if not output_val.requires_grad and output_val.grad_fn is None: + return + assert isinstance(grad_val, (torch.Tensor, type(None))), ( + f"Expected Tensor or None gradient but got {type(grad_val)}" + ) + stage_output_tensors.append(output_val) + output_grad_tensors.append(grad_val) + elif isinstance(output_val, (tuple, list)): + if grad_val is None: + return + assert isinstance(grad_val, (tuple, list)), ( + f"grad_value expected to have type {type(output_val)} but got {type(grad_val)}" + ) + assert len(output_val) == len(grad_val) + for ov, gv in zip(output_val, grad_val): + extract_tensors_with_grads( + ov, + gv, + extract_tensors_with_grads, + ) + elif isinstance(output_val, dict): + if grad_val is None: + return + assert isinstance(grad_val, dict) + assert set(output_val.keys()) == set(grad_val.keys()) + for k in output_val.keys(): + extract_tensors_with_grads( + output_val[k], grad_val[k], extract_tensors_with_grads + ) + else: + # Output is a non-tensor type; just ignore it + pass + + # Note: ref cycle + # break a ref cycle that would keep tensors alive until GC runs + # 1. extract_tensors_with_grads refers to a cell that holds refs to any vars defined in stage_backward + # and used in extract_tensors_with_grads + # 2. extract_tensors_with_grads referred to both stage_output_tensors, output_grad_tensors, + # and to itself (extract_tensors_with_grads) since it makes a recursive call + # 3. stage_output_tensors was kept alive by the above refcycle, and it holds activation tensors, which is bad + # fix -> explicitly pass in the ref to the fn, so there is no gc cycle anymore + extract_tensors_with_grads( + stage_output, output_grads, extract_tensors_with_grads + ) + + torch.autograd.backward( + stage_output_tensors, + grad_tensors=output_grad_tensors, # type: ignore[arg-type] + ) + + # Extract gradients wrt the input values + grad_inputs: list[Optional[torch.Tensor]] = [] + for val in input_values: + if isinstance(val, torch.Tensor): + grad_inputs.append(val.grad) + else: + grad_inputs.append(None) + + # Alternative impl: `torch.autograd.grad`. + # Note that `torch.autograd.grad` will not accumulate gradients into the + # model's parameters. + """ + inputs_with_grad = [] + for val in input_values: + if isinstance(val, torch.Tensor) and val.requires_grad: + inputs_with_grad.append(val) + + grad_inputs = torch.autograd.grad( + stage_output_tensors, inputs_with_grad, output_grad_tensors, # type: ignore[arg-type] + ) + """ + + except Exception as e: + exc_msg = f""" + Failed to run stage backward: + Stage output: {map_debug_info(stage_output)} + Output gradient: {map_debug_info(output_grads)} + Input: {map_debug_info(input_values)} + """ + raise RuntimeError(exc_msg) from e + + return tuple(grad_inputs) + + +# TODO: handling requires_grad=False dynamically. Can we analyze this during initial +# IR emission? +def _null_coalesce_accumulate(lhs, rhs): + """ + Coalesce two values, even if one of them is null, returning the non-null + value. + """ + if lhs is None: + return rhs + elif rhs is None: + return lhs + else: + return torch.add(lhs, rhs) diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/_debug.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/_debug.py new file mode 100644 index 0000000000000000000000000000000000000000..a2a973e656afeb29325d63cf62af42d91bd06542 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/_debug.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +from torch.fx.node import Argument + + +def friendly_debug_info(v: object) -> Argument: + """ + Helper function to print out debug info in a friendly way. + """ + if isinstance(v, torch.Tensor): + return f"Tensor({v.shape}, grad={v.requires_grad}, dtype={v.dtype})" + else: + return str(v) + + +def map_debug_info(a: Argument) -> Argument: + """ + Helper function to apply `friendly_debug_info` to items in `a`. + `a` may be a list, tuple, or dict. + """ + return torch.fx.node.map_aggregate(a, friendly_debug_info) diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/_schedule_visualizer.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/_schedule_visualizer.py new file mode 100644 index 0000000000000000000000000000000000000000..5a5ae36999d81f7a5de754bb5aa6c1a1603e934b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/_schedule_visualizer.py @@ -0,0 +1,189 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +""" +This visualizer requires matplotlib to be installed. + +Example usage: + +ops = get_schedule_ops("InterleavedZeroBubble", 4, 8) +visualize_schedule(ops, "test.png") +""" + +from typing import Optional, Union +from unittest import mock + +from torch.distributed.pipelining.schedules import ( + _Action, + _ComputationType, + _PipelineSchedule, + get_schedule_class, + PipelineScheduleMulti, + PipelineScheduleSingle, +) +from torch.distributed.pipelining.stage import PipelineStage + + +def get_schedule_ops( + schedule: Union[str, _PipelineSchedule], + pp_degree: int, + num_microbatches: int, + num_stages_per_rank: Optional[int] = None, +) -> list[list[Optional[_Action]]]: + """ + Get all actions for a given schedule, pp_degree, and num_microbatches. The actions are returned in a list of lists + where each inner list represents a rank and each element in the inner list represents an action. + + The schedule can be specified as a string which is passed into get_schedule_class() or a _PipelineSchedule instance. + """ + + if isinstance(schedule, str): + schedule_class = get_schedule_class(schedule) + elif type(schedule) == _PipelineSchedule: + schedule_class = schedule + else: + raise ValueError(f"Invalid schedule: {schedule}") + + # Create a mock of the PipelineStage class + mock_pipeline_stage = mock.create_autospec(PipelineStage, instance=True) + # Set the return values for group_rank and group_size methods + mock_pipeline_stage.group_rank = 0 + mock_pipeline_stage.group_size = pp_degree + mock_pipeline_stage.submod = None + + # Check num_stages_per_rank is valid + if issubclass(schedule_class, PipelineScheduleSingle): + if num_stages_per_rank is None: + num_stages_per_rank = 1 + assert num_stages_per_rank == 1 + stages = mock_pipeline_stage + stages.num_stages = num_stages_per_rank * pp_degree + elif issubclass(schedule_class, PipelineScheduleMulti): + if num_stages_per_rank is None: + num_stages_per_rank = 2 + assert num_stages_per_rank >= 2 + stages = [mock_pipeline_stage for _ in range(num_stages_per_rank)] + for stage in stages: + stage.num_stages = num_stages_per_rank * pp_degree + + else: + raise ValueError(f"Invalid schedule: {schedule_class}") + + # Instantiate the schedule class + schedule_instance = schedule_class(stages, num_microbatches) + + # Convert to List[List[_Action]] + all_actions = [] + for rank in range(pp_degree): + all_actions.append(schedule_instance.pipeline_order[rank]) + + # Return the pipeline order + return all_actions + + +class _ComputationTypeColor: + def __init__( + self, + color: str, + text: str = "", + width: int = 1, + ): + self.color = color + self.width = width + self.text = text + + +# Update the mapping to use _ComputationTypeColor instances +action_type_to_color_mapping = { + _ComputationType.FORWARD: _ComputationTypeColor("blue", "Forward"), + _ComputationType.BACKWARD_INPUT: _ComputationTypeColor("teal", "Backward Input"), + _ComputationType.BACKWARD_WEIGHT: _ComputationTypeColor("green", "Backward Weight"), + _ComputationType.FULL_BACKWARD: _ComputationTypeColor("orange", "Full Backward", 2), +} + + +def visualize_schedule( + schedule: list[list[Optional[_Action]]], filename: Optional[str] = None +) -> None: + """ + Visualize the schedule using matplotlib. + The schedule is a list of lists where each inner list represents a rank and each element in the inner list represents an action. + The actions are represented as rectangles with different colors based on their computation type. + The filename is optional and if provided, the plot will be saved to that file. + """ + + import matplotlib.pyplot as plt + from matplotlib.patches import Rectangle + + plt.rcParams["font.family"] = ( + "DejaVu Sans" # or any other font available on your system + ) + num_ranks = len(schedule) + max_actions = max(len(rank) for rank in schedule) + + # Increase the figure size to provide more space for the legend + fig, ax = plt.subplots(figsize=(max_actions + 2, num_ranks + 2)) + max_draw_position = -1 + # Calculate dynamic font size based on figure size + font_size = min(max_actions, num_ranks) + 4 + used_computation = set() + for rank_idx, actions in enumerate(schedule): + draw_position = 0 # Initialize drawing position for each rank + for action in actions: + if action is not None: + comp_type_color = action_type_to_color_mapping.get( + action.computation_type, _ComputationTypeColor("black") + ) + used_computation.add(action.computation_type) + color = comp_type_color.color + width = comp_type_color.width + # Draw the rectangle to represent the action duration + rect = Rectangle( + (draw_position, num_ranks - rank_idx - 1), + width, + 1, + facecolor=color, + edgecolor="black", + ) + ax.add_patch(rect) + # Draw the text centered within the rectangle + ax.text( + draw_position + width / 2, + num_ranks - rank_idx - 1 + 0.5, + str(action), + ha="center", + va="center", + fontsize=font_size, + color="white", + ) + # Increment the drawing position by the width of the current action + draw_position += width + else: + draw_position += 1 # Move to the next + max_draw_position = max(max_draw_position, draw_position) + ax.set_xlim(-0.5, max_draw_position + 1) + ax.set_ylim(-0.5, num_ranks + 0.5) # Add extra space at the top + # Set y-ticks to be in the middle of each rank's row + ax.set_yticks([num_ranks - rank_idx - 0.5 for rank_idx in range(num_ranks)]) + ax.set_yticklabels([f"Rank {i}" for i in range(num_ranks)], fontsize=font_size) + ax.set_xticklabels([]) + + # Remove grid lines and ticks + ax.grid(False) + # Add legend with larger font size + legend_elements = [ + Rectangle( + (0, 0), + 1, + 1, + facecolor=action_type_to_color_mapping[comp_type].color, + edgecolor="black", + label=action_type_to_color_mapping[comp_type].text, + ) + for comp_type in used_computation + ] + ax.legend(handles=legend_elements, loc="upper right", fontsize=font_size) + # Save to file if filename is provided, otherwise display the plot + if filename: + plt.savefig(filename, bbox_inches="tight") + else: + plt.show() diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/_unflatten.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/_unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..df11c8a36f5f6ac4865e3f18c64395b1f96b5640 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/_unflatten.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections import defaultdict + +import torch +from torch.export.unflatten import _ModuleFrame, _SubmoduleEntry + + +def _outline_submodules(orig_graph: torch.fx.Graph) -> torch.fx.GraphModule: + # Create an empty GraphModule to hold the outlined modules + new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph()) + seen_nodes: dict[str, torch.fx.Node] = {} + seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) + seen_attrs: dict[str, set[str]] = defaultdict(set) + created_modules: dict[str, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + seen_attrs, + created_modules, + None, + [("", None, 0)], + "", + {}, + module=new_module, + ).run_outer() + new_module.graph.lint() + new_module.recompile() + return new_module diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/_utils.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0cfaa683837397f5864c53e74f6b5ba66bad53a6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/_utils.py @@ -0,0 +1,133 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +from dataclasses import dataclass +from typing import Union + +import torch +from torch import fx + + +logger = logging.getLogger(__name__) + + +def flatten_args_detach(args): + """ + Flatten the args into a list form and detach the tensors from computational graph. + """ + flat_detached_args = [] + + def extract_tensor_args(a): + nonlocal flat_detached_args + if isinstance(a, torch.Tensor): + val = a.detach().requires_grad_(a.requires_grad) + flat_detached_args.append(val) + return val + else: + flat_detached_args.append(a) + return a + + new_args = fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return new_args, flat_detached_args + + +def flatten_args(args): + """ + Flatten the args into a list form. + """ + flat_args = [] + + def extract_tensor_args(a): + nonlocal flat_args + flat_args.append(a) + return a + + fx.node.map_aggregate( + args, + extract_tensor_args, + ) + + return flat_args + + +class PipeliningShapeError(RuntimeError): + """Shape mismatch between configured and runtime values.""" + + +def validate_tensor_metadata(desc, expected, given): + if not expected.shape == given.shape: + raise PipeliningShapeError( + f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}" + ) + if not expected.dtype == given.dtype: + raise PipeliningShapeError( + f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}" + ) + if not expected.stride() == given.stride(): + raise PipeliningShapeError( + f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}" + ) + + +def validate_tensors_metadata( + desc, + expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], + actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]], +): + if len(expected_tensors) != len(actual_tensors): + raise PipeliningShapeError( + f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})" + ) + for i in range(len(expected_tensors)): + validate_tensor_metadata( + f"{desc}: value {i}", expected_tensors[i], actual_tensors[i] + ) + + +def generate_stage_to_rank_mapping( + pp_size: int, num_stages: int, style: str = "loop" +) -> dict[int, int]: + """ + Compute the stage id to rank mapping for either a looped or V-style schedule. + + Most commonly num_stages == pp_size * 2, but this function can be used to + compute the mapping for any number of stages per rank. + """ + mapping = {} + if style == "loop": + for stage_index in range(num_stages): + mapping[stage_index] = stage_index % pp_size + elif style == "v": + if num_stages % pp_size != 0: + raise ValueError( + f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules" + ) + + rank_index = 0 + for stage_index in range(num_stages): + mapping[stage_index] = rank_index + # dont change rank if we are on the border (to keep v shape) + if (stage_index + 1) % pp_size == 0: + continue + if (stage_index // pp_size) % 2 == 0: + rank_index += 1 + else: + rank_index -= 1 + else: + raise ValueError(f"Style {style} is not supported.") + return mapping + + +@dataclass +class PipeInfo: + """ + Captures information for a pipeline (`Pipe` object). + """ + + graph: fx.Graph + num_stages: int + has_loss_and_backward: bool diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/microbatch.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/microbatch.py new file mode 100644 index 0000000000000000000000000000000000000000..e841a00a43e008ae01fe3c2b4ccb1b13cd6bea04 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/microbatch.py @@ -0,0 +1,469 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from typing import Any, Optional + +import torch +from torch.fx.node import map_aggregate +from torch.utils._pytree import tree_flatten, tree_unflatten + + +__all__ = [ + "TensorChunkSpec", + "split_args_kwargs_into_chunks", + "merge_chunks", +] + +logger = logging.getLogger(__name__) + +""" +_debug_mask_minibatches specifies to send masked versions of the mini-batch +through instead of micro-batch slices--this can be used for more stable +numerical testing (see [A Note About Correctness Testing]) +""" +_debug_mask_minibatches = False + + +class _CustomReducer: + """ + Custom reducer class that can be used to specify a custom operation that + reduces losses of multiple microbatches into one value. + + Example: + >>> # xdoctest: +SKIP + >>> sum_reducer = _CustomReducer( + >>> torch.tensor(0.0), + >>> lambda a, b: a + b + >>> ) + """ + + def __init__(self, init_value, reduce_fn): + self.init_value = init_value + self.reduce_fn = reduce_fn + + +class _LossReducer(_CustomReducer): + pass + + +sum_reducer = _LossReducer(torch.tensor(0.0), operator.add) + +# Default chunking dimension is 0. This is used for the case where the user did +# not specify a chunking dimension. +DEFAULT_CHUNK_DIM = 0 + + +class TensorChunkSpec: + """ + Class used to specify chunking of inputs + """ + + def __init__(self, split_dim): + self.split_dim = split_dim + + split_dim: int + + def __repr__(self): + return ( + f"{self.__class__.__module__}.{self.__class__.__name__}({self.split_dim})" + ) + + def __str__(self): + return f"TensorChunkSpec({self.split_dim})" + + @staticmethod + def from_tuple( + chunk_dims: tuple[int, ...], + ): + """ + A helper for creating a tuple of `TensorChunkSpec` from a tuple of chunk + dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # There are three positional arguments to the model, and + >>> # we are chunking them along dimension 0, 0 and 1, respectively + >>> args_chunk_spec = TensorChunkSpec.from_tuple((0, 0, 1)) + """ + args_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return args_chunk_spec + + @staticmethod + def from_dict( + chunk_dims: dict[str, int], + ): + """ + A helper for creating a dictionary of `TensorChunkSpec` from a + dictionary of chunk dimensions (int's). + Example: + >>> # xdoctest: +SKIP + >>> # Chunk dimension 0 for the "id" argument, 1 for the "mask" argument + >>> kwargs_chunk_spec = TensorChunkSpec.from_dict({"id": 0, "mask": 1}) + """ + kwargs_chunk_spec = map_aggregate( + chunk_dims, + lambda dim: TensorChunkSpec(dim), # type: ignore[arg-type,return-value] + ) + return kwargs_chunk_spec + + +# Class used to specify replication of inputs +class _Replicate: + pass + + +def _shard_dict_of_args( + args_dict, + args_chunk_spec, + num_chunks, +): + """ + Given a dictionary of args, and a dictionary of chunking specs, shard the + args according to the chunking specs. + + Args: + args_dict: Dictionary of args + args_chunk_spec: Dictionary of chunking specs + num_chunks: Number of chunks to shard the args into + + Returns: + args_split: List of sharded args + """ + # Stage 1+2: flatten and shard/replicate + + # args_sharded_replicated : [num args, num flat values, num chunks] + args_sharded_replicated = {} + arg_specs = [] + + real_num_chunks = num_chunks + first_tensor = True + + assert len(args_dict) == len(args_chunk_spec), ( + f"args_dict.keys() = {list(args_dict.keys())} args_chunk_spec.keys() = {list(args_chunk_spec.keys())}" + ) + + for arg_key, arg in args_dict.items(): + flat, spec = tree_flatten(arg) + arg_specs.append(spec) + + chunk_spec = args_chunk_spec[arg_key] + assert chunk_spec is not None # Should have been set by caller + chunk_spec_flat, _ = tree_flatten(chunk_spec) + if len(flat) != len(chunk_spec_flat): + raise ValueError( + f"Argument value {arg} did not have the same number of " + f"values as as chunk spec {chunk_spec}" + ) + + sharded_arg_flat = [] + + for v, chunk_v in zip(flat, chunk_spec_flat): + if chunk_v is _Replicate or not isinstance(v, torch.Tensor): + sharded_arg_flat.append([v] * real_num_chunks) + elif isinstance(chunk_v, TensorChunkSpec): + # TODO: check type of v. If it's a tensor, use chunk (or debug mask). + # If it's a collection type, split it as you would expect. Otherwise, + # Throw an error + assert isinstance(v, torch.Tensor), f"{v} is not a tensor" + + v_split_dim_size = v.size(chunk_v.split_dim) + if v_split_dim_size < real_num_chunks: + if first_tensor: + # We can only adjust number of chunks when we hit this + # issue at the first tensor encountered + logger.warning( + f"Tensor size on chunking dimension is {v_split_dim_size}, " # noqa: G004 + f"downsizing the number of chunks from {num_chunks} to {v_split_dim_size}." + ) + real_num_chunks = v_split_dim_size + else: + raise RuntimeError( + f"Arg {arg_key} on chunking dimension has a size of {v_split_dim_size}, " + f"smaller than the number of chunks {num_chunks}. " + "PiPPy cannot reduce the number of chunks because " + "other arguments have bigger chunk-dimension sizes. " + "Please adjust your num_chunks setting." + ) + + chunk_tensors = torch.tensor_split( + v, real_num_chunks, chunk_v.split_dim + ) + + if _debug_mask_minibatches: + expanded_chunks = [] + + split_dim_idx = 0 + for chunk_tensor in chunk_tensors: + new_val = torch.zeros_like(v) + upper_idx = split_dim_idx + chunk_tensor.size(chunk_v.split_dim) + + slice_indices = [slice(None, None, None)] * new_val.ndim + slice_indices[chunk_v.split_dim] = slice( + split_dim_idx, upper_idx + ) + new_val[slice_indices] = chunk_tensor + + expanded_chunks.append(new_val) + + split_dim_idx += chunk_tensor.size(chunk_v.split_dim) + + sharded_arg_flat.append(expanded_chunks) + else: + sharded_arg_flat.append(chunk_tensors) # type: ignore[arg-type] + + first_tensor = False + else: + raise TypeError(f"Unrecognized chunk spec: {chunk_v}") + + args_sharded_replicated[arg_key] = sharded_arg_flat + + # chunks_flat : [num chunks, num args, num flat values] + chunks_flat = [] + for chunk_idx in range(real_num_chunks): + chunk_args = {} + for key, arg in args_sharded_replicated.items(): + arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg] + chunk_args[key] = arg_single_chunk + chunks_flat.append(chunk_args) + + # args_split : [num chunks, num args] + args_split = [] + + for chunk in chunks_flat: + per_chunk_args = {} + assert len(arg_specs) == len(chunk) + for (key, arg), arg_spec in zip(chunk.items(), arg_specs): + per_chunk_args[key] = tree_unflatten(arg, arg_spec) + args_split.append(per_chunk_args) + + return args_split + + +def split_args_kwargs_into_chunks( + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]], + chunks: int, + args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, +) -> tuple[list[tuple], list[dict]]: + """ + Given a sequence of args and kwargs, split them into a number of chunks + according to their respective chunking specs. + + Args: + args: Tuple of args + kwargs: Dict of kwargs + chunks: Number of chunks to split the args and kwargs into + args_chunk_spec: chunking specs for args, in same shape as args + kwargs_chunk_spec: chunking specs for kwargs, in same shape as kwargs + + Returns: + args_split: List of sharded args + kwargs_split: List of sharded kwargs + """ + # Given `args` and `kwargs`, we want to yield a set of `chunks` args and kwargs such that + # the constituent Tensor values have been sharded/replicated according to the `args_chunk_spec` + # and `kwargs_chunk_spec` specifications. The steps are as follows: + # + # 1. Use pytree.tree_flatten to flatten each arg and its spec into nto a 1d array of values. + # To use a running example: suppose our inputs look like + # + # args = ([A, [B, C]], D) args_spec = ([None, [None, TensorChunkSpec]], None) + # (kwargs not shown but it's a similar process) + # + # Then for this step we would end up with + # + # args = ([A, B, C], D) args_spec = ([None, None, TensorChunkSpec], None) + # + # 2. Shard or replicate the arguments subject to the policy in the spec. Suppose chunks = 2 + # + # args = ([[A, A], [B, B], [C_1, C_2]], [D, D]) + # + # 3. Rotate the nesting order such that chunks are the outer dimension + # + # args_chunks = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 4. Unflatten each chunk according to the spec + # + # args_chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + + # TODO: _debug_mask_minibatches + # Handle the case where kwargs is None + if kwargs is None: + kwargs = {} + + # If user did not provide args_chunk_spec or kwargs_chunk_spec, we extend + # their format and use default chunking along dim 0 + if args_chunk_spec is None: + args_chunk_spec = (TensorChunkSpec(DEFAULT_CHUNK_DIM),) * len(args) + + if kwargs_chunk_spec is None: + kwargs_chunk_spec = dict.fromkeys(kwargs, TensorChunkSpec(DEFAULT_CHUNK_DIM)) + + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + chunks, + ) + real_num_chunks = len(args_split_dict) + + kwargs_split = _shard_dict_of_args( + kwargs, + kwargs_chunk_spec, + real_num_chunks, + ) + + if len(kwargs_split) < real_num_chunks: + # In case kwargs are sharded into less chunks + # e.g. when `args` has no tensor, just values + real_num_chunks = len(kwargs_split) + # Re-shard args + args_split_dict = _shard_dict_of_args( + dict(enumerate(args)), + dict(enumerate(args_chunk_spec)), + real_num_chunks, + ) + + if len(args_split_dict) != len(kwargs_split): + raise RuntimeError( + "args and kwargs are split into different number of chunks: " + f"{len(args_split_dict)}, {len(kwargs_split)}" + ) + + args_split = [ + tuple(chunk_args[i] for i in range(len(chunk_args))) + for chunk_args in args_split_dict + ] + + return args_split, kwargs_split + + +def merge_chunks( + chunks: list[Any], + chunk_spec, +): + """ + Given a list of chunks, merge them into a single value according to + the chunk spec. + + Args: + chunks: list of chunks + chunk_spec: Chunking spec for the chunks + + Returns: + value: Merged value + """ + # This is essentially the inverse of `split_args_kwargs_into_chunks`, so the + # steps are similar to the steps in that function but in reverse. Given the + # input values: + # + # chunks = [ + # ([A, [B, C_1]], D), + # ([A, [B, C_2]], D), + # ] + # args_spec = ([None, [None, TensorChunkSpec]], None) + # + # 1. Flatten the chunks according to the chunk_spec + # + # chunks_flat = [ + # ([A, B, C_1], D), + # ([A, B, C_2], D), + # ] + # + # 2. Rotate the nesting order such that chunks are the inner dimension + # + # value_inner = ([A, B, [C_1, C_2]], D) + # + # 3. Concatenate sharded arguments + # + # value_combined = ([A, B, C], D) + # + # 4. Unflatten the combined args given the spec + # + # value = ([A, [B, C]], D) + + # Preliminary: flatten the chunk spec + if chunk_spec is not None: + spec_flattened, flatten_spec = tree_flatten(chunk_spec) + else: + # If chunk_spec is not provided, we will merge chunks along the default dimension (0), for all output fields + # We obtain the output structure by flattening chunk 0 and generate the chunk_spec + chunk0_flat, flatten_spec = tree_flatten(chunks[0]) + spec_flattened = [TensorChunkSpec(DEFAULT_CHUNK_DIM)] * len(chunk0_flat) + + # Stage 1: flatten chunks + # chunks_flattened : [num chunks, num args] + chunks_flattened = [] + + for chunk in chunks: + chunk_flattened, _ = tree_flatten(chunk) + if len(chunk_flattened) != len(spec_flattened): + raise ValueError(f"Chunk {chunk} did not match chunk spec {chunk_spec}") + + chunks_flattened.append(chunk_flattened) + + # Stage 2 and 3: Rotate nesting order s.t. chunks are inner dimension and + # concatenate sharded operands + # args_flattened : [num args] + args_flattened = [] + for arg_idx, arg in enumerate(spec_flattened): + if isinstance(arg, TensorChunkSpec): + partial_values = [ + chunks_flattened[chunk_idx][arg_idx] + for chunk_idx in range(len(chunks_flattened)) + ] + + if _debug_mask_minibatches: + # Infer size of individual chunks by running `tensor_split` again + overall_shape = partial_values[0].shape + for val in partial_values[1:]: + assert val.shape == overall_shape + meta_chunks = torch.tensor_split( + torch.empty(*overall_shape, device="meta"), + sections=len(partial_values), + dim=arg.split_dim, + ) + + values_to_cat = [] + chunk_start_idx = 0 + assert len(partial_values) == len(meta_chunks) + for partial_value, meta_chunk in zip(partial_values, meta_chunks): + chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim) + + slice_indices = [slice(None, None, None)] * partial_value.ndim + slice_indices[arg.split_dim] = slice(chunk_start_idx, chunk_end_idx) + sliced = partial_value[slice_indices] + values_to_cat.append(sliced) + + chunk_start_idx = chunk_end_idx + + else: + values_to_cat = partial_values + + args_flattened.append(torch.cat(values_to_cat, dim=arg.split_dim)) + elif isinstance(arg, _CustomReducer): + reduced_val = arg.init_value + + for chunk_idx in range(len(chunks_flattened)): + reduced_val = arg.reduce_fn( + reduced_val, chunks_flattened[chunk_idx][arg_idx] + ) + + args_flattened.append(reduced_val) + else: + value = chunks_flattened[0][arg_idx] + for chunk_idx in range(1, len(chunks_flattened)): + assert chunks_flattened[chunk_idx][arg_idx] == value + args_flattened.append(value) + + # Stage 4: Unflatten combined args + return tree_unflatten(args_flattened, flatten_spec) diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/schedules.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/schedules.py new file mode 100644 index 0000000000000000000000000000000000000000..0b227aa21162c272cdf1d995f63d5711d2f6475f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/schedules.py @@ -0,0 +1,2773 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +import copy +import csv +import itertools +import logging +import re +from abc import ABC, abstractmethod +from collections import Counter, defaultdict +from enum import Enum +from typing import Any, Callable, NamedTuple, Optional, Union + +import torch +import torch.distributed as dist +from torch._dynamo import OptimizedModule +from torch.distributed.fsdp import FSDPModule, UnshardHandle +from torch.nn.modules.loss import _Loss +from torch.profiler import record_function + +from ._utils import generate_stage_to_rank_mapping +from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec +from .stage import _PipelineStageBase + + +__all__ = [ + "get_schedule_class", + "PipelineScheduleSingle", + "PipelineScheduleMulti", + "Schedule1F1B", + "ScheduleGPipe", + "ScheduleInterleaved1F1B", + "ScheduleLoopedBFS", + "ScheduleInterleavedZeroBubble", + "ScheduleZBVZeroBubble", +] + +logger = logging.getLogger(__name__) + + +class _ComputationType(Enum): + # TODO(whc) rename to _ActType? + FORWARD = 1 + BACKWARD_INPUT = 2 + BACKWARD_WEIGHT = 3 + UNSHARD = 4 + RESHARD = 5 + SEND_F = 6 + RECV_F = 7 + SEND_B = 8 + RECV_B = 9 + FULL_BACKWARD = 10 + + def __str__(self): + str_map = { + _ComputationType.FORWARD: "F", + _ComputationType.BACKWARD_INPUT: "I", + _ComputationType.BACKWARD_WEIGHT: "W", + _ComputationType.UNSHARD: "UNSHARD", + _ComputationType.RESHARD: "RESHARD", + _ComputationType.SEND_F: "SEND_F", + _ComputationType.RECV_F: "RECV_F", + _ComputationType.SEND_B: "SEND_B", + _ComputationType.RECV_B: "RECV_B", + _ComputationType.FULL_BACKWARD: "B", + } + return str_map[self] + + @staticmethod + def from_str(action): + if action == "F": + return _ComputationType.FORWARD + elif action == "I": + return _ComputationType.BACKWARD_INPUT + elif action == "W": + return _ComputationType.BACKWARD_WEIGHT + elif action == "UNSHARD": + return _ComputationType.UNSHARD + elif action == "RESHARD": + return _ComputationType.RESHARD + elif action == "SEND_F": + return _ComputationType.SEND_F + elif action == "RECV_F": + return _ComputationType.RECV_F + elif action == "SEND_B": + return _ComputationType.SEND_B + elif action == "RECV_B": + return _ComputationType.RECV_B + elif action == "B": + return _ComputationType.FULL_BACKWARD + else: + raise RuntimeError(f"Invalid computation type {action}") + + +FORWARD = _ComputationType.FORWARD +BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT +BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT +UNSHARD = _ComputationType.UNSHARD +RESHARD = _ComputationType.RESHARD +SEND_F = _ComputationType.SEND_F +RECV_F = _ComputationType.RECV_F +SEND_B = _ComputationType.SEND_B +RECV_B = _ComputationType.RECV_B +FULL_BACKWARD = _ComputationType.FULL_BACKWARD + +# Convenience shorthand for compute actions only since they are used in 'simple schedule format' +F = FORWARD +I = BACKWARD_INPUT +W = BACKWARD_WEIGHT +B = FULL_BACKWARD + +# Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index) +_action_regex = re.compile( + r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)" +) + + +class _Action(NamedTuple): + stage_index: int + computation_type: _ComputationType + microbatch_index: Optional[int] = None + + def __repr__(self): + repr = str(self.stage_index) + repr += str(self.computation_type) + if self.microbatch_index is not None: + repr += str(self.microbatch_index) + return repr + + @staticmethod + def from_str(action_string: str): + """ + Reverse of __repr__ + + String should be formatted as [stage][action type][(microbatch)] + e.g. `2F0`, `1UNSHARD`, `3SEND_F1` + """ + action_string = action_string.strip() + if match := _action_regex.match(action_string): + stage_index, computation_type, microbatch_index = match.groups() + return _Action( + int(stage_index), + _ComputationType.from_str(computation_type), + int(microbatch_index) if len(microbatch_index) else None, + ) + elif action_string == "": + return None + raise RuntimeError( + f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0" + ) + + +def _format_pipeline_order( + pipeline_order: dict[int, list[Optional[_Action]]], + error_step_number: Optional[int] = None, +) -> str: + """ + Formats the pipeline order in a timestep (row) x rank (column) grid of actions + and returns the formatted string. + + If `error_step_number` is passed in, an additional label will be added to signify which step + that it is erroring on. + """ + + # don't mutate the original + pipeline_order = copy.deepcopy(pipeline_order) + + # Replace None with "" + for rank in pipeline_order: + for i in range(len(pipeline_order[rank])): + if pipeline_order[rank][i] is None: + # TODO make a real 'None action' that prints as empty string and make mypy happy + pipeline_order[rank][i] = "" # type: ignore[call-overload] + + # Calculate the maximum number of steps across all ranks + num_steps = max(len(actions) for actions in pipeline_order.values()) + step_labels = [ + "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps) + ] + # Sorting the dictionary by keys and retrieving values in that order + rank_actions = [ + pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) + ] + # Transpose the list of lists (rows to columns) + transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) + # Generate column labels for ranks + num_ranks = len(pipeline_order) + rank_labels = ["Rank " + str(i) for i in range(num_ranks)] + # Calculate the maximum length of each column, considering labels + max_lengths = [ + max(len(str(item)) if item is not None else 0 for item in col) + for col in zip(step_labels, *transposed_actions) + ] + # Format the header row with rank labels + header_row = " " * (len(step_labels[0]) + 2) + " ".join( + f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels) + ) + # Format each row with its corresponding label + formatted_rows = [ + f"{label}: " + + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row)) + + ( + " <-- ERROR HERE" + if error_step_number is not None + and int(label.split()[1]) == error_step_number + else "" + ) + for label, row in zip(step_labels, transposed_actions) + ] + # Join the rows into a single string + formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n" + return formatted_table + + +class _PipelineSchedule(ABC): + def __init__( + self, + n_microbatches: int, + loss_fn: Optional[Callable[..., torch.Tensor]] = None, + args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + scale_grads: bool = True, + ): + # From arguments + self._n_microbatches = n_microbatches + self._loss_fn = loss_fn + + # See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti` + self.scale_grads = scale_grads + + # Chunking specification for positional inputs. (default: `None`) + self._args_chunk_spec = args_chunk_spec + # Chunking specification for keyword inputs. (default: `None`) + self._kwargs_chunk_spec = kwargs_chunk_spec + self._output_merge_spec = output_merge_spec + """ + # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs. + # They are used to convert batch to microbatches in `step(x)`. See + # `TensorChunkSpec` for helper methods for creating them. + """ + + # Derived + self._has_backward = self._loss_fn is not None + + # Holds the losses for each microbatch. + self._internal_losses: list[torch.Tensor] = [] + logger.info("Using %s", self.__class__.__name__) + + def _maybe_compute_loss(self, stage, output, target_mbs, mb_index): + if stage.is_last and self._has_backward: + loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index] + self._internal_losses.append(loss) + + def _maybe_get_loss(self, stage, mb_index): + valid_index = 0 <= mb_index < len(self._internal_losses) + if stage.is_last and self._has_backward and valid_index: + return self._internal_losses[mb_index] + elif len(self._internal_losses) != 0 and not valid_index: + raise RuntimeError( + f"Loss for microbatch {mb_index} is not available. " + f"Available losses for microbatches: {self._internal_losses}" + ) + else: + return None + + def _update_losses(self, stages, losses): + """ + Update the losses to those in the internal state + """ + # if stages not a list turn into a list + if not isinstance(stages, list): + stages = [stages] + contains_last_stage = any(stage.is_last for stage in stages) + + # Return losses if there is a container passed in + if contains_last_stage and losses is not None: + if len(self._internal_losses) != self._n_microbatches: + raise RuntimeError( + f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}" + ) + + # Clean external container first + losses.clear() + # Copy internal losses to external container + losses.extend(self._internal_losses) + + self._internal_losses.clear() + + @abstractmethod + def _step_microbatches( + self, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the schedule + implementation. + + Args: + microbatches: list of microbatch args. + """ + raise NotImplementedError + + @abstractmethod + def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + raise NotImplementedError + + def _check_inputs( + self, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, + ): + """ + Pre-process/check inputs + """ + + def check_type_and_len(mbs, name: str): + if not isinstance(mbs, list): + raise TypeError(f"{name} must be a list but got a {type(mbs)}") + if len(mbs) != self._n_microbatches: + raise ValueError( + f"Expecting {self._n_microbatches} {name} but got {len(mbs)}" + ) + + if arg_mbs is not None: + check_type_and_len(arg_mbs, "arg_mbs") + else: + arg_mbs = [()] * self._n_microbatches + + if kwarg_mbs is not None: + check_type_and_len(kwarg_mbs, "kwarg_mbs") + else: + kwarg_mbs = [{}] * self._n_microbatches + + if target_mbs is not None: + check_type_and_len(target_mbs, "target_mbs") + + if losses is not None: + if not isinstance(losses, list): + raise TypeError(f"losses must be a list but got a {type(losses)}") + + return arg_mbs, kwarg_mbs + + def _compute_loss(self, output, target): + return self._loss_fn(output, target) # type: ignore[misc] + + def _split_inputs( + self, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + ): + """ + Splits a full-batch input into chunks (i.e. microbatches) and returns + the chunks + """ + if args or kwargs: + args_split, kwargs_split = split_args_kwargs_into_chunks( + args, + kwargs, + self._n_microbatches, + self._args_chunk_spec, + self._kwargs_chunk_spec, + ) + return args_split, kwargs_split + else: + # Empty inputs (e.g. when called on middle stages) + # Return a list of empty tuples/dicts with matching length as chunks + return [()] * self._n_microbatches, [{}] * self._n_microbatches + + def _merge_outputs(self, output_chunks: list[Any]) -> Any: + """ + Merge output chunks back to a batch state. + If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim). + """ + return merge_chunks( + output_chunks, + self._output_merge_spec, + ) + + +def _batch_p2p( + p2p_ops: list[dist.P2POp], desc: Optional[str] = None +) -> list[dist.Work]: + """ + Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top. + """ + if len(p2p_ops) == 0: + return [] + desc_str = f"{desc}, " if desc else "" + logger.debug("batch_p2p %s%s", desc_str, p2p_ops) + return dist.batch_isend_irecv(p2p_ops) + + +def _sorted_batch_p2p( + p2p_ops: list[dist.P2POp], desc: Optional[str] = None +) -> dict[int, list[dist.Work]]: + """ + Sorts the list of P2P ops by the peer rank, and then calls + batch_isend_irecv. Return a dictionary of works by peer rank. This function + helps us avoid hangs in case of skip connections. + """ + # Arrange p2p_ops by peer rank: + # int is the peer rank; + # List is the list of ops towards the peer + ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list) + work_by_peer: dict[int, list[dist.Work]] = {} + if len(p2p_ops) == 0: + return work_by_peer + + # Classify the ops by peer rank + for op in p2p_ops: + ops_by_peer[op.peer].append(op) + + # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs) + for peer, ops in sorted(ops_by_peer.items()): + work_by_peer[peer] = _batch_p2p(ops, desc=desc) + + return work_by_peer + + +def _wait_batch_p2p(work: list[dist.Work]): + """ + Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p). + """ + for w in work: + w.wait() + + +class PipelineScheduleSingle(_PipelineSchedule): + """ + Base class for single-stage schedules. + Implements the `step` method. + Derived classes should implement `_step_microbatches`. + + Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting + should match the configuration of your loss_fn, which may either average losses (scale_grads=True) + or sum losses (scale_grads=False). + """ + + def __init__( + self, + stage: _PipelineStageBase, + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + scale_grads: bool = True, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + # Self attributes + self._stage = stage + self._num_stages = stage.num_stages + # Set the same has_backward flag for stage object + self._stage.has_backward = self._has_backward + self._stage_initialized = False + + if n_microbatches < self._num_stages: + raise ValueError( + f"Number of microbatches ({n_microbatches}) must be greater than \ +or equal to the number of stages ({self._num_stages})." + ) + + self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = ( + self._get_pipeline_order() + ) + + def _initialize_stage(self, args, kwargs): + self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs) + if self._has_backward: + self._stage._prepare_backward_infra(self._n_microbatches) + self._stage_initialized = True + + def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + + # Clean per iteration + self._stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + if self._stage.is_last: + return self._merge_outputs(self._stage.output_chunks) + else: + return None + + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + """ + Returns the pipeline execution order as a schedule IR. + + The returned IR is a dictionary mapping rank IDs to lists of actions. + Each action is either an _Action object representing computation to perform, + or None representing a deliberate idle step. + + The None values are used to represent pipeline bubbles where a rank + must wait for dependencies from other ranks before proceeding. However + during execution, with the _PipelineScheduleRuntime, these Nones are + skipped since the relevant communication (send/recv) will be scheduled and waited on. + + Returns: + A dictionary mapping rank -> list of actions + """ + return None + + +class _ScheduleForwardOnly(PipelineScheduleSingle): + """ + The forward-only schedule. + Will go through all the microbatches and perform only the forward pass + """ + + def _step_microbatches( + self, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, + ): + """ + Run one iteration of the pipeline schedule + """ + if target_mbs is not None or losses is not None: + raise RuntimeError( + "Forward-only schedule does not support loss computation" + ) + + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: list[list[dist.Work]] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + _wait_batch_p2p(work) + + +class ScheduleGPipe(PipelineScheduleSingle): + """ + The GPipe schedule. + Will go through all the microbatches in a fill-drain manner. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the GPipe schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Delay send waits + fwd_sends_to_wait: list[list[dist.Work]] = [] + + # Run microbatches + for i in range(self._n_microbatches): + with record_function(f"Forward {i}"): + ops = self._stage.get_fwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index] + + ops = self._stage.get_fwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="fwd_send") + fwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i) + + self._maybe_compute_loss(self._stage, output, target_mbs, i) + + # Wait for all forward sends to finish + # This should not have performance impact because by the time the first + # backward arrives all the forward sends should have been finished. + for work in fwd_sends_to_wait: + _wait_batch_p2p(work) + + # No loss function, no need to run backward + if not self._has_backward: + return + + # Run backward + # Delay send waits + bwd_sends_to_wait: list[list[dist.Work]] = [] + for i in range(self._n_microbatches): + with record_function(f"Backward {i}"): + ops = self._stage.get_bwd_recv_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_recv") + for work in works.values(): + _wait_batch_p2p(work) + + loss = self._maybe_get_loss(self._stage, i) + self._stage.backward_one_chunk( + i, + loss=loss, + last_backward=i == self._n_microbatches - 1, + ) + + ops = self._stage.get_bwd_send_ops(i) + works = _sorted_batch_p2p(ops, desc="bwd_send") + bwd_sends_to_wait.extend(works.values()) + + logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i) + + self._stage.scale_grads( + grad_scale_factor=self._n_microbatches if self.scale_grads else 1 + ) + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + # Wait for all backward sends to finish + for work in bwd_sends_to_wait: + _wait_batch_p2p(work) + + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + """ + Returns the pipeline order for GPipe schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[Optional[_Action]] = [] + + # 1. Initial delay based on rank position + warmup_delay = rank + actions.extend([None] * warmup_delay) + + # 2. Forward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx)) + + # 3. Wait period before backward passes can begin + backward_delay = 3 * (pp_group_size - 1 - rank) + actions.extend([None] * backward_delay) + + # 4. Backward passes for all microbatches + for mb_idx in range(self._n_microbatches): + actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx)) + + pipeline_order[rank] = actions + + return pipeline_order + + +class Schedule1F1B(PipelineScheduleSingle): + """ + The 1F1B schedule. + Will perform one forward and one backward on the microbatches in steady state. + """ + + def _step_microbatches( + self, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, + ): + """ + Run one iteration of the pipeline schedule with list of microbatches. + Will go through all the microbatches according to the 1F1B schedule. + + Args: + microbatches: list of microbatch args. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + if not self._stage_initialized: + self._initialize_stage(arg_mbs[0], kwarg_mbs[0]) + + # Last stage has 1 warmup, second-to-last 2 warmups, ... + # first stage `num_stages` warmups + warmup_chunks = min( + self._n_microbatches, + self._num_stages - self._stage.stage_index, + ) + + # Chunk counters + fwd_mb_index = 0 + bwd_mb_index = 0 + + # Warmup phase + send_work: list[dist.Work] = [] + fwd_sends = [] + for _ in range(warmup_chunks): + # Receive activations + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + _wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv")) + + # Compute + output = self._stage.forward_one_chunk( + fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index] + ) # type: ignore[index] + + # Clear previous chunk's forward sends (hopefully they have well + # finished, otherwise, we are heavily communication bound, in which + # case it doesn't create a lot of benefit to compute next chunk + # eagerly either) + _wait_batch_p2p(send_work) + + # Send activations + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + if fwd_mb_index != warmup_chunks - 1: + # Safe to fire + send_work = _batch_p2p(fwd_sends, desc="fwd_send") + # otherwise: + # The last forward send is left for fuse with first 1B in 1B1F below + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + fwd_mb_index += 1 + + # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below. + + # 1B1F phase + while True: # Don't worry, we have a break inside + # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + + # Now, we need to fire the fwd_sends and bwd_recvs together + _wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv")) + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Get the bwd send ops, but don't fire, to be fused with the 1F below + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + bwd_mb_index += 1 + + if fwd_mb_index == self._n_microbatches: + # We are done with 1B1F, so break with some left-over bwd_sends + break + + # We prepare 1F of the `1B1F` + fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index) + + # Fuse it with bwd_sends above + _wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv")) + + # Now do the fwd + output = self._stage.forward_one_chunk( + fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index] + ) # type: ignore[index] + + # Compute loss + self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index) + + # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around) + fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index) + fwd_mb_index += 1 + + # Remember we still have some bwd_sends left over after the break? Now it is time to fire it + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + + # Cooldown + while bwd_mb_index < self._n_microbatches: + # prepare bwd recv ops + bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index) + _wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv")) + + # Backward one chunk + loss = self._maybe_get_loss(self._stage, bwd_mb_index) + self._stage.backward_one_chunk( + bwd_mb_index, + loss=loss, + last_backward=bwd_mb_index == self._n_microbatches - 1, + ) + + # Clear previous chunk's backward sends (hopefully they have well finished) + _wait_batch_p2p(send_work) + + # Get the bwd send ops, fire it + bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index) + send_work = _batch_p2p(bwd_sends, desc="bwd_send") + bwd_mb_index += 1 + + self._stage.scale_grads( + grad_scale_factor=self._n_microbatches if self.scale_grads else 1 + ) + + # Wait for the last backward send to finish + _wait_batch_p2p(send_work) + + # Return losses if there is a container passed in + self._update_losses(self._stage, losses) + + def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]: + """ + Returns the pipeline order for 1F1B schedule. + + See base method in PipelineScheduleSingle for details on the schedule IR format. + """ + pipeline_order = {} + pp_group_size = self._num_stages + + for rank in range(pp_group_size): + actions: list[Optional[_Action]] = [] + + # 1. Warmup phase: initial delay based on rank + actions.extend([None] * rank) + + # 2. Initial forward passes before 1F1B phase + num_forward = (pp_group_size - 1) - rank + forward_mb = 0 + for i in range(num_forward): + actions.append(_Action(rank, _ComputationType.FORWARD, i)) + forward_mb = i + + # 3. Wait for backward to be ready + wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank)) + actions.extend([None] * wait_for_1f1b) + + # 4. 1F1B steady state phase + backward_mb = 0 + remaining_forward = self._n_microbatches - num_forward + + while remaining_forward > 0: + # One forward + forward_mb += 1 + actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb)) + remaining_forward -= 1 + + # One backward + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + + # 5. Cooldown phase: remaining backward passes + remaining_backward = self._n_microbatches - backward_mb + + while remaining_backward > 0: + # Add None and backward actions in alternating pattern + # based on distance from the last stage + if (pp_group_size - rank) > 0: + actions.append(None) + # Decrement the wait counter only if we still have backward passes to do + if remaining_backward > 0: + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + else: + # If we're at the last stage, just add backward actions without None + actions.append( + _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb) + ) + backward_mb += 1 + remaining_backward -= 1 + + pipeline_order[rank] = actions + return pipeline_order + + +def _add_unshard_reshard( + compute_actions: list[Optional[_Action]], + max_active_stages: int = 3, +) -> list[_Action]: + """Given a basic schedule involving only compute actions (F,B,W), add UNSHARD/RESHARD actions for FSDP. + + UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation. + RESHARD does the opposite, releasing memory (but doing no communication) + + We abandon the "timestep lock" during lowering + + max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice + 3 stages is probably the thing we want? + (to account for having one f and one b active, and something else prefetching?) + """ + + def next_stage_indices( + count: int, next_actions: list[Optional[_Action]] + ) -> list[int]: + """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute.""" + seen: set[int] = set() + ret: list[int] = [] + + for a in next_actions: + if a is not None and a.stage_index not in seen: + seen.add(a.stage_index) + ret.append(a.stage_index) + if len(ret) == count: + break + return ret + + active_stages: set[int] = set() + fsdp_aware_actions: list[_Action] = [] + + def _unshard(stage_index: int): + active_stages.add(stage_index) + fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None)) + + def _reshard(stage_index: int): + active_stages.remove(stage_index) + fsdp_aware_actions.append(_Action(stage_index, RESHARD, None)) + + for i, action in enumerate(compute_actions): + if action is None: + continue + + # We prefetch the next N stages we'll see, dropping existing stages to make room + next_n = next_stage_indices(max_active_stages, compute_actions[i:]) + # Fetch needs to be ordered correctly, so don't use a set + fetch = list(filter(lambda s: s not in active_stages, next_n)) + # Unclear what the best policy is for eviction, but we can maintain order so we do + evict = list(filter(lambda s: s not in next_n, active_stages)) + + # logger.debug( + # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s", + # i, + # active_stages, + # fetch, + # evict, + # ) + + for stage in evict: + _reshard(stage) + for stage in fetch: + _unshard(stage) + fsdp_aware_actions.append(action) + + return fsdp_aware_actions + + +def _merge_bw( + compute_actions: list[Optional[_Action]], +) -> list[_Action]: + """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops. + (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD) + + B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient + in some cases. + """ + merged_actions = [] + while compute_actions: + action = compute_actions.pop(0) + if action is None: + continue + + while len(compute_actions) and (next_action := compute_actions[0]) is None: + # remove any None actions between 'action' and 'next_action' + compute_actions.pop(0) + + if ( + action.computation_type == BACKWARD_INPUT + and next_action is not None + and next_action.computation_type == BACKWARD_WEIGHT + and action.stage_index == next_action.stage_index + and action.microbatch_index == next_action.microbatch_index + ): + merged_actions.append( + _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index) + ) + compute_actions.pop(0) + else: + merged_actions.append(action) + return merged_actions + + +def _add_send_recv( + compute_actions: dict[int, list[_Action]], + stage_to_rank: Callable[[int], int], + num_stages: int, +) -> dict[int, list[_Action]]: + comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions} + prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions} + + def _has_comms(action: _Action) -> bool: + if action.computation_type == F: + return action.stage_index != num_stages - 1 and stage_to_rank( + action.stage_index + 1 + ) != stage_to_rank(action.stage_index) + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + return action.stage_index != 0 and stage_to_rank( + action.stage_index - 1 + ) != stage_to_rank(action.stage_index) + return False + + def _get_comms(action: _Action) -> tuple[_Action, _Action]: + assert _has_comms(action), f"{action} is not a valid comm action" + stage_idx = action.stage_index + ctype = action.computation_type + mb_idx = action.microbatch_index + send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx) + recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1 + recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx) + return send, recv + + def _ready_to_schedule( + action: Optional[_Action], prev_actions: set[_Action] + ) -> bool: + """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place. + This helps ensure a sane (non-hanging) ordering of sends and recvs. + But it also means we might not be able to schedule our next compute action yet. + """ + if action is None: + return True + elif action.computation_type == F and not action.stage_index == 0: + if ( + _Action(action.stage_index, RECV_F, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) + in prev_actions + ): + return True + return False + elif ( + action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD) + and not action.stage_index == num_stages - 1 + ): + if ( + _Action(action.stage_index, RECV_B, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_actions + ): + return True + elif ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_actions + ): + return True + return False + else: + return True + + while compute_actions: + progress = False + # go in order of ranks even if dict keys aren't ordered + for rank in sorted(compute_actions): + assert len(compute_actions[rank]) > 0, ( + f"{rank=}, {len(compute_actions[rank])=}" + ) + action = compute_actions[rank][0] + + if not _ready_to_schedule(action, prev_actions[rank]): + continue + + if action is not None: + comm_actions[rank].append(action) + prev_actions[rank].add(action) + if _has_comms(action): + send, recv = _get_comms(action) + # TODO we can avoid send/recv if the 2 stages are on the same rank. + # should we avoid that in the runtime or here? + comm_actions[rank].append(send) + prev_actions[rank].add(send) + comm_actions[stage_to_rank(recv.stage_index)].append(recv) + prev_actions[stage_to_rank(recv.stage_index)].add(recv) + + compute_actions[rank].pop(0) + if len(compute_actions[rank]) == 0: + del compute_actions[rank] + progress = True + assert progress, "Malformed compute schedule, can't schedule sends/recvs" + return comm_actions + + +def _validate_schedule( + actions: dict[int, list[Optional[_Action]]], + pp_group_size: int, + num_stages: int, + num_microbatches: int, +) -> dict[int, int]: + assert len(actions) == pp_group_size, ( + f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}" + ) + for rank in range(pp_group_size): + assert rank in actions, f"Schedule is missing actions for rank {rank}" + + # We will count all the actions per stage and ensure they happen in a valid order + # (e.g. F before (B, I) before W for a given microbatch) + stage_actions: dict[int, dict[_ComputationType, set]] = { + stage_id: { + F: set(), + B: set(), + I: set(), + W: set(), + } + for stage_id in range(num_stages) + } + stage_index_to_rank_mapping = {} + for rank in actions: + for action in actions[rank]: + if action is None: + continue + assert isinstance(action, _Action), ( + f"Got an invalid action: {action}, expected instance of _Action" + ) + s_id = action.stage_index + ctype = action.computation_type + mb_id = action.microbatch_index + if ctype == F: + stage_actions[s_id][F].add(mb_id) + elif ctype == B: + assert mb_id in stage_actions[s_id][F], ( + f"Running Full Backward for stage {s_id}, microbatch {mb_id} without first running Forward" + ) + stage_actions[s_id][B].add(mb_id) + elif ctype == I: + assert mb_id in stage_actions[s_id][F], ( + f"Running Backward Input for stage {s_id}, microbatch {mb_id} without first running Forward" + ) + stage_actions[s_id][I].add(mb_id) + elif ctype == W: + assert mb_id in stage_actions[s_id][I], ( + f"Running Backward Weight for stage {s_id}, microbatch {mb_id} without first running Backward Input" + ) + stage_actions[s_id][W].add(mb_id) + if s_id not in stage_index_to_rank_mapping: + stage_index_to_rank_mapping[s_id] = rank + else: + existing_rank = stage_index_to_rank_mapping[s_id] + assert rank == existing_rank, ( + f"Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}" + ) + + for s_id in stage_actions: + f_mb = len(stage_actions[s_id][F]) + b_mb = len(stage_actions[s_id][B]) + i_mb = len(stage_actions[s_id][I]) + w_mb = len(stage_actions[s_id][W]) + + assert f_mb == num_microbatches, ( + f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}" + ) + + assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, ( + f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \ + but got B={b_mb}, I={i_mb}, W={w_mb}" + ) + return stage_index_to_rank_mapping + + +class PipelineScheduleMulti(_PipelineSchedule): + """ + Base class for multi-stage schedules. + Implements the `step` method. + + Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting + should match the configuration of your loss_fn, which may either average losses (scale_grads=True) + or sum losses (scale_grads=False). + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + use_full_backward: Optional[bool] = None, + scale_grads: bool = True, + ): + # Init parent + super().__init__( + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + # Self attributes + self._stages = stages + self._num_stages = stages[0].num_stages + self.pp_group_size = stages[0].group_size + self.rank = stages[0].group_rank + # Set the pipeline stage states + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + # Set the same has_backward flag for stage object + for stage in self._stages: + stage.has_backward = self._has_backward + self._stages_initialized = False + + # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle + has_loss: bool = self._loss_fn is not None + self._should_compute_loss = lambda stage: stage.is_last and has_loss + + # This will be set during init of derived schedules + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + + if use_full_backward is not None: + logger.warning( + "Deprecation warning: 'use_full_backward' is no longer supported. " + "Simply stop passing it, and everything should still work fine." + ) + + def _initialize_stages(self, args: tuple[Any, ...], kwargs): + # may be 'none' value (if this stage sends its output shapes to the next stage via P2P) + # or real value (if this stage and next stage are on the same device) + next_stage_args: tuple[Any, ...] = tuple() + for stage in self._stages: + if stage.is_first: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, args, kwargs + ) + else: + next_stage_args = stage._prepare_forward_infra( + self._n_microbatches, next_stage_args, kwargs + ) + + if self._has_backward: + stage._prepare_backward_infra(self._n_microbatches) + self._stages_initialized = True + + def _validate_and_set_stage_mapping( + self, actions: dict[int, list[Optional[_Action]]] + ) -> None: + """ + Allocates the stage index to rank mapping which is needed for communication + """ + self.stage_index_to_group_rank = _validate_schedule( + actions, + self.pp_group_size, + self._num_stages, + self._n_microbatches, + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + def _dump_csv(self, filename): + """Dump a CSV representation of the schedule into a file with the provided filename.""" + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order: + writer.writerow(self.pipeline_order[rank]) + + def _load_csv(self, filename, format="compute_only"): + """Load a CSV representation of the schedule from a file with the provided filename. + This API will most likely get renamed/refactored so is marked as internal for now. + + format must be "compute_only" for PipelineScheduleMulti. + """ + assert format == "compute_only" + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + self.pipeline_order[rank] = [_Action.from_str(s) for s in row] + + # Validates the order of the pipeline actions and infers the stage_to_rank_mapping. + # This will overwrite the default stage_to_rank_mapping created in the constructor + self._validate_and_set_stage_mapping(self.pipeline_order) + + def step(self, *args, target=None, losses: Optional[list] = None, **kwargs): + """ + Run one iteration of the pipeline schedule with *whole-batch* input. + Will chunk the input into microbatches automatically, and go through the + microbatches according to the schedule implementation. + + args: positional arguments to the model (as in non-pipeline case). + kwargs: keyword arguments to the model (as in non-pipeline case). + target: target for the loss function. + losses: a list to store the losses for each microbatch. + """ + # Clean per iteration + for stage in self._stages: + stage.clear_runtime_states() + + # Split inputs into microbatches + args_split, kwargs_split = self._split_inputs(args, kwargs) + + # Split target into microbatches + if target is not None: + targets_split = list(torch.tensor_split(target, self._n_microbatches)) + else: + targets_split = None + + # Run microbatches + self._step_microbatches(args_split, kwargs_split, targets_split, losses) + + # Return merged results per original format + for stage in self._stages: + if stage.is_last: + return self._merge_outputs(stage.output_chunks) + # Does not contain the last stage + return None + + def _step_microbatches( + self, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + + if not self._stages_initialized: + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + # determine prev_rank and next_rank based on which ranks are next to + # the stages in the pipeline_order + all_prev_ranks: set[int] = set() + all_next_ranks: set[int] = set() + for stage_index in stage_index_to_stage.keys(): + # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections) + if stage_index > 0: + all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1]) + if stage_index < self._num_stages - 1: + all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1]) + # count either full_backward or backward_weight together, to determine when to sync DP grads + backward_counter: Counter[int] = Counter() + for time_step, action in enumerate(self.pipeline_order[self.rank]): + try: + ops: list[dist.P2POp] = [] + if action is not None: + computation_type = action.computation_type + mb_index = action.microbatch_index + stage_index = action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + if computation_type == _ComputationType.FORWARD: + # perform forward computation + stage = stage_index_to_stage[stage_index] + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + ops.extend(stage.get_fwd_send_ops(mb_index)) + elif computation_type == _ComputationType.FULL_BACKWARD: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_index] += 1 + last_backward = ( + backward_counter[stage_index] == self._n_microbatches + ) + grad_scale_factor = ( + self._n_microbatches if self.scale_grads else 1 + ) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_INPUT: + # perform backward computation + stage = stage_index_to_stage[stage_index] + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + ops.extend(stage.get_bwd_send_ops(mb_index)) + elif computation_type == _ComputationType.BACKWARD_WEIGHT: + # perform weight update + stage = stage_index_to_stage[stage_index] + backward_counter[stage_index] += 1 + last_backward = ( + backward_counter[stage_index] == self._n_microbatches + ) + grad_scale_factor = ( + self._n_microbatches if self.scale_grads else 1 + ) + stage.backward_weight_one_chunk( + mb_index, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + else: + raise ValueError(f"Unknown computation type {computation_type}") + + # Look at the neighboring ranks for this current timestep and determine whether + # this current rank needs to do any recv communication + for prev_rank in all_prev_ranks: + prev_rank_ops = self.pipeline_order[prev_rank] + prev_rank_action = None + if time_step < len(prev_rank_ops): + prev_rank_action = prev_rank_ops[time_step] + if prev_rank_action is not None: + computation_type = prev_rank_action.computation_type + mb_index = prev_rank_action.microbatch_index + stage_index = prev_rank_action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + # Only handle sends for the forward from a previous rank + if computation_type == _ComputationType.FORWARD: + # If not the last stage, then receive fwd activations + if stage_index + 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage-1 + # however that is not necessarily true of get_fwd_recv_ops + stage = stage_index_to_stage[stage_index + 1] + ops.extend(stage.get_fwd_recv_ops(mb_index)) + elif computation_type in ( + FULL_BACKWARD, + BACKWARD_INPUT, + BACKWARD_WEIGHT, + ): + # Previous rank doing backward has no influence for the current rank forward recv + pass + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + for next_rank in all_next_ranks: + next_rank_ops = self.pipeline_order[next_rank] + next_rank_action = None + if time_step < len(next_rank_ops): + next_rank_action = next_rank_ops[time_step] + if next_rank_action is not None: + computation_type = next_rank_action.computation_type + mb_index = next_rank_action.microbatch_index + stage_index = next_rank_action.stage_index + assert mb_index is not None, ( + "All currently supported action types require valid microbatch_index" + ) + # Only handle receives for the backwards from a next rank + if computation_type in (FORWARD, BACKWARD_WEIGHT): + # Next rank doing forward or weight update has no influence for the current rank backward recv + pass + elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + # If not the first stage, then receive bwd gradients + if stage_index - 1 in stage_index_to_stage: + # TODO: We are assuming that stage will always receive from stage+1 + # however that is not necessarily true of get_bwd_recv_ops + stage = stage_index_to_stage[stage_index - 1] + ops.extend(stage.get_bwd_recv_ops(mb_index)) + else: + raise ValueError( + f"Unknown computation type {computation_type}" + ) + + # do the communication + _wait_batch_p2p(_batch_p2p(ops)) + except Exception as e: + logger.error( + "[Rank %s] pipeline schedule %s caught the following exception \ + at time_step %s when running action %s", + self.rank, + self.__class__.__name__, + time_step, + action, + ) + logger.error( + "%s", + _format_pipeline_order( + self.pipeline_order, error_step_number=time_step + ), + ) + raise e + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class _PipelineScheduleRuntime(PipelineScheduleMulti): + """ + Provides a simple runtime that requires a 'schedule IR' including specified communication operations. + + Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be + subclassed and the subclass can be responsible for creating a schedule IR. + """ + + def _load_actions( + self, + actions: dict[int, list[Optional[_Action]]], + format: str = "compute_only", + ): + """ + Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including + communication actions. Stores the schedule in self, and must be called before running step_mo() + """ + # validate the provided actions are valid and overrides the default stage_index_to_group_rank + super()._validate_and_set_stage_mapping(actions) + + self.pipeline_order_with_comms: dict[int, list[_Action]] = {} + if format == "compute_comms": + for rank in actions: + self.pipeline_order_with_comms[rank] = [] + for action in actions[rank]: + assert action is not None + self.pipeline_order_with_comms[rank].append(action) + # TODO what level of validation should we offer for compute+comms schedule? + elif format == "compute_only": + # Perform schedule lowering + for rank in actions: + self.pipeline_order_with_comms[rank] = _add_unshard_reshard( + actions[rank] + ) + + self.pipeline_order_with_comms = _add_send_recv( + self.pipeline_order_with_comms, + stage_to_rank=lambda s: self.stage_index_to_group_rank[s], + num_stages=self._num_stages, + ) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _load_csv(self, filename: str, format: str = "compute_only"): + """Loads a csv in simple format and then lowers it to include communication actions + + format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes + will automatically be run to generate a compute_comms schedule. + """ + if format == "compute_only": + # this will populate self.pipeline_order + super()._load_csv(filename) + # this will populate self.pipeline_order_with_comms + self._load_actions(self.pipeline_order) + elif format == "compute_comms": + actions = {} + with open(filename, newline="") as csvfile: + reader = csv.reader(csvfile) + for rank, row in enumerate(reader): + actions[rank] = [_Action.from_str(s) for s in row] + self._load_actions(actions, format=format) + else: + raise NotImplementedError(f"{format=} is not implemented") + + def _dump_csv(self, filename: str): + """Dump a CSV representation of the compute + comms schedule into a file with the provided filename.""" + # TODO should there be an option to dump the compute_only schedule from PipelineScheduleRuntime? It's possible + # that it does not exist if it was created from a compute_comms schedule. + assert self.pipeline_order_with_comms is not None, ( + "Must initialize compute_comms schedule before dump_csv" + ) + with open(filename, "w", newline="") as csvfile: + writer = csv.writer(csvfile) + for rank in self.pipeline_order_with_comms: + writer.writerow(self.pipeline_order_with_comms[rank]) + + def _simulate(self): + return _simulate_comms_compute( + self.pipeline_order_with_comms, + lambda s: self.stage_index_to_group_rank[s], + self._num_stages, + ) + + def _step_microbatches( + self, + arg_mbs: Optional[list] = None, + kwarg_mbs: Optional[list] = None, + target_mbs: Optional[list] = None, + losses: Optional[list] = None, + ): + """ + Operate on the microbatches for looped schedules (multiple stages on each rank). + + TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does + not support models with skip connections. + """ + arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses) + if not self._stages_initialized: + self._initialize_stages(arg_mbs[0], kwarg_mbs[0]) + + # Based on the plan in Step 1 created in __init__: + # 2. Perform communication based on the pipeline_order + stage_index_to_stage: dict[int, _PipelineStageBase] = { + stage.stage_index: stage for stage in self._stages + } + + assert self.pipeline_order_with_comms is not None, ( + "Must call _load_actions() before calling _step_microbatches()" + ) + + # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use + bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} + fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {} + + # send ops should be waited on before step() exists, mainly for hygiene + send_ops: list[list[dist.Work]] = [] + + # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages + unshard_ops: dict[int, UnshardHandle] = {} + unsharded_stages = set() + + def _assert_unsharded(stage_idx: int): + """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared.""" + if stage_idx in unshard_ops: + unshard_ops[stage_idx].wait() + del unshard_ops[stage_idx] + unsharded_stages.add(stage_idx) + assert stage_idx in unsharded_stages, ( + f"Attempted to compute on sharded {stage_idx=}" + ) + + # count either full_backward or backward_weight together, to determine when to sync DP grads + backward_counter: Counter[int] = Counter() + for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]): + try: + comp_type = action.computation_type + mb_index: int = ( + action.microbatch_index + if action.microbatch_index is not None + else -1 + ) + assert mb_index >= 0 or comp_type in ( + UNSHARD, + RESHARD, + ), f"{action=} missing mb_index" + stage_idx = action.stage_index + stage = stage_index_to_stage[stage_idx] + stage_uses_fsdp = isinstance(stage.submod, FSDPModule) + # see [Note: V-schedule special case] + is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage + is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage + + logger.debug( + "_PipelineScheduleRuntime running time_step %d, action %s", + time_step, + action, + ) + + # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections, + # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be + # safe to use instead. + # However, I was wondering if I should avoid calling batched operators at all in the case that there is + # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them. + if comp_type == SEND_F: + send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index))) + elif comp_type == SEND_B: + send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index))) + elif comp_type == RECV_F: + assert ( + stage_idx, + mb_index, + ) not in fwd_recv_ops, ( + "Recv twice for {stage_idx=} {mb_index=} without executing forward" + ) + fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_fwd_recv_ops(mb_index) + ) + elif comp_type == RECV_B: + assert ( + stage_idx, + mb_index, + ) not in bwd_recv_ops, ( + "Recv twice for {stage_idx=} {mb_index=} without executing backward" + ) + bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p( + stage.get_bwd_recv_ops(mb_index) + ) + elif comp_type == UNSHARD: + if stage_uses_fsdp: + assert ( + stage_idx not in unsharded_stages + and stage_idx not in unshard_ops + ), f"Unsharding the same {stage_idx=} twice" + unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator] + elif comp_type == RESHARD: + if stage_uses_fsdp: + assert stage_idx in unsharded_stages, ( + f"Resharding {stage_idx=} without unsharding" + ) + assert stage_idx not in unshard_ops, ( + f"Resharding {stage_idx=} before finishing unshard" + ) + stage.submod.reshard() # type: ignore[operator] + elif comp_type == FORWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if ( + not stage.is_first + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_prev_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in fwd_recv_ops, f"Computing {action=} before receiving input" + _wait_batch_p2p(fwd_recv_ops.pop((stage_idx, mb_index))) + + output = stage.forward_one_chunk( + mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index] + ) + self._maybe_compute_loss(stage, output, target_mbs, mb_index) + + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_next_stage_on_this_rank: + stage_index_to_stage[stage_idx + 1].set_local_fwd_input( + output, mb_index + ) + + elif comp_type == FULL_BACKWARD: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if ( + not stage.is_last + # no recv op expected for V-schedule special case (see [Note: V-schedule special case]) + and not is_next_stage_on_this_rank + ): + assert ( + stage_idx, + mb_index, + ) in bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + backward_counter[stage_idx] += 1 + last_backward = backward_counter[stage_idx] == self._n_microbatches + grad_scale_factor = self._n_microbatches if self.scale_grads else 1 + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=True, + last_backward=last_backward, + ) + if last_backward: + stage.scale_grads(grad_scale_factor) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_INPUT: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + + if not stage.is_last and not is_next_stage_on_this_rank: + assert ( + stage_idx, + mb_index, + ) in bwd_recv_ops, ( + f"Attempted to run compute {action=} before receiving input" + ) + _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index))) + loss = self._maybe_get_loss(stage, mb_index) + stage.backward_one_chunk( + mb_index, + loss=loss, + full_backward=False, + last_backward=False, + ) + # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank + # see [Note: V-schedule special case] + if is_prev_stage_on_this_rank: + stage_index_to_stage[stage_idx - 1].set_local_bwd_input( + stage.get_local_bwd_output(mb_index), mb_index + ) + elif comp_type == BACKWARD_WEIGHT: + if stage_uses_fsdp: + _assert_unsharded(stage_idx) + backward_counter[stage_idx] += 1 + stage.backward_weight_one_chunk( + mb_index, + last_backward=backward_counter[stage_idx] + == self._n_microbatches, + ) + else: + raise ValueError(f"{action=} is unknown or unsupported") + except Exception as e: + logger.error( + "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:", + time_step, + action, + ) + # TODO(whc) what is the best practice for printing a multiline log? + # logger will split it into multiple log lines, but this makes it hard to read (too wide) + print( + _format_pipeline_order( + self.pipeline_order_with_comms, # type: ignore[arg-type] + error_step_number=time_step, + ) + ) + raise e + + # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them + while len(send_ops): + _wait_batch_p2p(send_ops.pop()) + + assert len(unshard_ops) == 0, "Unused unshard operations" + + # Return losses if there is a container passed in + self._update_losses(self._stages, losses) + + +class ScheduleLoopedBFS(PipelineScheduleMulti): + """ + Breadth-First Pipeline Parallelism. + See https://arxiv.org/abs/2211.05953 for details. + Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank. + What is different is that when microbatches are ready for multiple local + stages, Loops BFS will prioritizes the earlier stage, running all available + microbatches at once. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Union[Callable, _Loss]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + scale_grads: bool = True, + ): + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + # ======================================================================== + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank): + n_local_stages = len(self._stages) + stage_indices = range( + rank, self.pp_group_size * n_local_stages, self.pp_group_size + ) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] + + for stage_index in stage_indices: + rank_ops.extend( + _Action(stage_index, _ComputationType.FORWARD, mb_index) + for mb_index in range(self._n_microbatches) + ) + + # wait for the first backward to trickle up + # which is 2 for every hop away + post_warmup_ops = 2 * (self.pp_group_size - 1 - rank) + rank_ops.extend([None] * post_warmup_ops) + + for stage_index in reversed(stage_indices): + rank_ops.extend( + _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index) + for mb_index in reversed(range(self._n_microbatches)) + ) + return rank_ops + + +def _get_1f1b_rank_ops( + n_local_stages, + pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches=0, + enable_zero_bubble=False, +): + # All stages start with handling microbatch 0 + fwd_stage_mb_index: dict[int, int] = defaultdict(int) + bwd_stage_mb_index: dict[int, int] = defaultdict(int) + weight_stage_mb_index: dict[int, int] = defaultdict(int) + + # Store the list of operations used for that rank + # Pre-padding, rank starts with no-ops based on the warmup. + rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] + # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup + # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks. + # Formula: + # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward + # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding) + # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)] + # warmup_ops = calculated above + post_warmup_ops = ( + n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank) + ) - (warmup_ops + rank) + + if enable_zero_bubble: + post_warmup_ops = pp_group_size - rank - 1 + + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + + backward_op_ids = [] + weight_op_count = 0 + + FULL_BACKWARD_OR_BACKWARD_INPUT = ( + BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD + ) + + for op in range(total_ops): + # Warmup phase + if op < warmup_ops: + fwd_stage_index = forward_stage_index(op) + # This will assign the current microbatch index and update it as well + fwd_stage_mb_index[fwd_stage_index] = ( + mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index) + ) + if op == warmup_ops - 1: + # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up + rank_ops.extend([None] * post_warmup_ops) + # 1F1B Phase (forward and backward) + elif warmup_ops <= op < warmup_ops + fwd_bwd_ops: + fwd_stage_index = forward_stage_index(op) + fwd_stage_mb_index[fwd_stage_index] = ( + fwd_mb_index := fwd_stage_mb_index[fwd_stage_index] + ) + 1 + rank_ops.append( + _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index) + ) + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + # Cooldown phase + else: + # During cooldown phase, we need steps to align with 1f1b happening in other ranks + # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None + if not enable_zero_bubble: + rank_ops.append(None) + + bwd_stage_index = backward_stage_index(op) + bwd_stage_mb_index[bwd_stage_index] = ( + bwd_mb_index := bwd_stage_mb_index[bwd_stage_index] + ) + 1 + rank_ops.append( + _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index) + ) + backward_op_ids.append(op) + + if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches: + weight_stage_index = backward_stage_index( + backward_op_ids[weight_op_count] + ) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, + _ComputationType.BACKWARD_WEIGHT, + weight_mb_index, + ) + ) + weight_op_count += 1 + + while enable_zero_bubble and weight_op_count < len(backward_op_ids): + weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count]) + weight_stage_mb_index[weight_stage_index] = ( + weight_mb_index := weight_stage_mb_index[weight_stage_index] + ) + 1 + rank_ops.append( + _Action( + weight_stage_index, _ComputationType.BACKWARD_WEIGHT, weight_mb_index + ) + ) + weight_op_count += 1 + + return rank_ops + + +class ScheduleInterleaved1F1B(PipelineScheduleMulti): + """ + The Interleaved 1F1B schedule. + See https://arxiv.org/pdf/2104.04473 for details. + Will perform one forward and one backward on the microbatches in steady + state and supports multiple stages per rank. When microbatches are ready for + multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch + (also called "depth first"). + + This schedule is mostly similar to the original paper. + It differs by being relaxing the requirement of num_microbatch % pp_size == 0. + Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and + it works as long as n_microbatches % num_rounds is 0. As a few examples, support + + 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0. + 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + scale_grads: bool = True, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Interleaved 1F1B requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 2 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + ) + + +class ScheduleInterleavedZeroBubble(PipelineScheduleMulti): + """ + The Interleaved Zero Bubble schedule. + See https://arxiv.org/pdf/2401.10241 for details. + Will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses the backward for weights to fill in + the pipeline bubble. + + In particular this is implementing the ZB1P schedule in the paper. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + scale_grads: bool = True, + ): + # TODO: we don't support Zero Bubble with torch.compile so we + # should disable it for now + for stage in stages: + if isinstance(stage.submod, OptimizedModule): + raise RuntimeError( + "The Zero Bubble schedule is not supported with \ +stage modules that have used torch.compile" + ) + + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + self.n_local_stages = len(stages) + self.rank = stages[0].group_rank + self.number_of_rounds = max(1, n_microbatches // self.pp_group_size) + self.microbatches_per_round = n_microbatches // self.number_of_rounds + if n_microbatches % self.number_of_rounds != 0: + raise ValueError( + "Zero bubble requires the number of microbatches to be a " + f"multiple of the number of rounds ({self.number_of_rounds}), " + f"but got {n_microbatches}." + ) + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + # This function add bubbles to the generated schedule based on dependencies of actions + # Note that the ZB1P schedule will not require bubbles to be manually added and it is + # only useful when n_microbatches <= microbatches_per_round + self.pipeline_order = self._add_bubbles_to_actions( + self.n_local_stages * self.pp_group_size, + ) + + def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: + def get_rank_warmup_ops(rank): + # Warms up operations for last stage + warmups_ops_last_stage = ( + self.n_local_stages - 1 + ) * self.microbatches_per_round + # Increment warmup operations by 2 for each hop away from the last stage + multiply_factor = 1 + warmup_ops = warmups_ops_last_stage + multiply_factor * ( + (self.pp_group_size - 1) - rank + ) + + # We cannot have more warmup operations than there are number of microbatches, so cap it there + return min(warmup_ops, self._n_microbatches * self.n_local_stages) + + warmup_ops = get_rank_warmup_ops(rank) + microbatch_ops = self.n_local_stages * self._n_microbatches + # fwd_bwd_ops should encompass the remaining forwards + fwd_bwd_ops = microbatch_ops - warmup_ops + # cooldown_ops should encompass the remaining backwards + cooldown_ops = microbatch_ops - fwd_bwd_ops + # total ops encompass both forward and backward ops + total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops + # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2 + logger.debug( + "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s", + rank, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + total_ops, + ) + + # Calculates the stage index based on step and pp_group_size + + def forward_stage_index(step): + # Get the local index from 0 to n_local_stages-1 + local_index = (step // self.microbatches_per_round) % self.n_local_stages + return (local_index * self.pp_group_size) + rank + + def backward_stage_index(step): + local_index = ( + self.n_local_stages + - 1 + - ((step - warmup_ops) // self.microbatches_per_round) + % self.n_local_stages + ) + return (local_index * self.pp_group_size) + rank + + num_1f1b_microbatches = rank + + return _get_1f1b_rank_ops( + self.n_local_stages, + self.pp_group_size, + warmup_ops, + fwd_bwd_ops, + cooldown_ops, + rank, + forward_stage_index, + backward_stage_index, + num_1f1b_microbatches, + enable_zero_bubble=True, + ) + + def _add_bubbles_to_actions(self, num_stages_global): + actions = self.pipeline_order + + def need_bubble(stage, op, microbatch, num_stages_global, seen_ops): + if op == _ComputationType.FORWARD: + if stage != 0 and (stage - 1, op, microbatch) not in seen_ops: + return True + elif op == _ComputationType.FULL_BACKWARD: + if stage == num_stages_global - 1: + return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops + return (stage + 1, op, microbatch) not in seen_ops + return False + + seen_ops: set[tuple[int, _ComputationType, int]] = set() + result: dict[int, list[Optional[_Action]]] = {} + next_pointer: dict[int, int] = {} + bubbles_added: dict[int, int] = {} + total_bubbles_added = 0 + + for rank in range(self.pp_group_size): + result[rank] = [] + next_pointer[rank] = 0 + bubbles_added[rank] = 0 + + while True: + should_stop = True + + temp_seen_ops: set[tuple[int, _ComputationType, int]] = set() + + for rank in range(self.pp_group_size): + timestamp = next_pointer[rank] + if timestamp >= len(actions[rank]): + continue + + should_stop = False + + if actions[rank][timestamp] is not None: + temp_action = actions[rank][timestamp] + assert temp_action is not None + stage_index, op, microbatch = temp_action + if not need_bubble( + stage_index, op, microbatch, num_stages_global, seen_ops + ): + result[rank].append(actions[rank][timestamp]) + if microbatch is not None: + temp_seen_ops.add((stage_index, op, microbatch)) + next_pointer[rank] += 1 + else: + result[rank].append(None) + bubbles_added[rank] += 1 + else: + next_pointer[rank] += 1 + result[rank].append(None) + + seen_ops.update(temp_seen_ops) + if should_stop: + break + + if total_bubbles_added > 0: + logger.warning( + "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s", + total_bubbles_added, + bubbles_added, + ) + return result + + +class ScheduleZBVZeroBubble(PipelineScheduleMulti): + """ + The Zero Bubble schedule (ZBV variant). + See https://arxiv.org/pdf/2401.10241 Section 6 for details. + + This schedules requires exactly two stages per rank. + + This schedule will perform one forward and one backward on inputs for the microbatches in steady + state and supports multiple stages per rank. Uses backward with respect to weights to fill in + the pipeline bubble. + + This ZB-V schedule would have the "zero bubble" property only if time forward == time backward input == time backward weights. + In practice, this is not likely true for real models so alternatively + a greedy scheduler could be implemented for unequal/unbalanced time. + """ + + def __init__( + self, + stages: list[_PipelineStageBase], + n_microbatches: int, + loss_fn: Optional[Callable] = None, + args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None, + kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None, + output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None, + scale_grads: bool = True, + ): + self.pp_group_size = stages[0].group_size + super().__init__( + stages=stages, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + args_chunk_spec=args_chunk_spec, + kwargs_chunk_spec=kwargs_chunk_spec, + output_merge_spec=output_merge_spec, + scale_grads=scale_grads, + ) + self.stage_index_to_group_rank = generate_stage_to_rank_mapping( + self.pp_group_size, self._num_stages, style="v" + ) + for stage in self._stages: + stage.stage_index_to_group_rank = self.stage_index_to_group_rank + + self.n_local_stages = len(stages) + if self.n_local_stages != 2: + raise ValueError( + "ZBV requires exactly 2 stages per rank, but got " + f"{self.n_local_stages}." + ) + + self.rank = stages[0].group_rank + self.num_stages = stages[0].num_stages + + # 1. Create the pipeline_order (all ranks do this calculation) + # This will be used to keep track of the current state of the entire pipeline + # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...] + self.pipeline_order: dict[int, list[Optional[_Action]]] = {} + for rank in range(self.pp_group_size): + rank_ops = self._calculate_single_rank_operations(rank) + self.pipeline_order[rank] = rank_ops + + def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]: + # max(2 * self.pp_group_size - 1, ...) ensure the number of microbatches is at least + # as large of the number of microbatches needed to fully utilize the pipeline + n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches) + rank_ops: list[Optional[_Action]] = [None for _ in range(rank)] + + # Forward and backward action counts for stage chunk 0 and chunk 1 + f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0 + # warm-up phase + warmup_n1 = 2 * (self.pp_group_size - rank) - 1 + stage_id_chunk0 = rank + stage_id_chunk1 = self.num_stages - 1 - rank + + for _ in range(warmup_n1): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt) + ) + f0_cnt += 1 + warmup_n2 = rank + for _ in range(warmup_n2): + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt) + ) + f0_cnt += 1 + warmup_n3 = self.pp_group_size - rank + for _ in range(warmup_n3): + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + # stable phase + while f1_cnt < f0_cnt or f0_cnt < n_micro: + if f0_cnt < n_micro: + rank_ops.append( + _Action( + stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt + ) + ) + f0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + + rank_ops.append( + _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt) + ) + f1_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + # cool-down phase + w0_cnt, w1_cnt = b0_cnt, b1_cnt + cooldown_n1 = rank + for _ in range(cooldown_n1): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt) + ) + b1_cnt += 1 + cooldown_n2 = self.pp_group_size - rank + for _ in range(cooldown_n2): + rank_ops.append( + _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt) + ) + b0_cnt += 1 + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt) + ) + w0_cnt += 1 + while w1_cnt < b1_cnt: + rank_ops.append( + _Action(stage_id_chunk1, computation_type=W, microbatch_index=w1_cnt) + ) + w1_cnt += 1 + while w0_cnt < b0_cnt: + rank_ops.append( + _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt) + ) + w0_cnt += 1 + + assert w0_cnt == b0_cnt and b0_cnt == f0_cnt + assert w1_cnt == b1_cnt and b1_cnt == f1_cnt + # We use max() in the n_micro computation above, so we may need to + # remove redundant microbatches + rank_ops = [ + ( + action + if action is not None + and action.microbatch_index is not None + and action.microbatch_index < self._n_microbatches + else None + ) + for action in rank_ops + ] + return rank_ops + + +def get_schedule_class(schedule_name: str): + """ + Maps a schedule name (case insensitive) to its corresponding class object. + + Args: + schedule_name (str): The name of the schedule. + """ + schedule_map = { + "1F1B": Schedule1F1B, + "Interleaved1F1B": ScheduleInterleaved1F1B, + "GPipe": ScheduleGPipe, + "LoopedBFS": ScheduleLoopedBFS, + "InterleavedZeroBubble": ScheduleInterleavedZeroBubble, + "PipelineScheduleSingle": PipelineScheduleSingle, + "PipelineScheduleMulti": PipelineScheduleMulti, + "ZBVZeroBubble": ScheduleZBVZeroBubble, + } + lowercase_keys = {k.lower(): k for k in schedule_map.keys()} + lowercase_schedule_name = schedule_name.lower() + if lowercase_schedule_name not in lowercase_keys: + raise ValueError( + f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}" + ) + return schedule_map[lowercase_keys[lowercase_schedule_name]] + + +def _simulate_comms_compute( + pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int +): + """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags + any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank + can not execute any action due to waiting for unmet dependencies. The total number of simulator steps can be used + as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number + of simulated steps. + + The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams. + Future work may be to enhance this and model the compute time, comms overlap, and even memory. + """ + pipeline_order = { + rank: [a for a in pipeline_order[rank] if a is not None] + for rank in sorted(pipeline_order) + } + _schedule: dict[int, list[_Action | None]] = { + rank: [] for rank in sorted(pipeline_order) + } + + _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule} + + def add_to_schedule(rank: int, action: Optional[_Action]): + _schedule[rank].append(action) + if action is not None: + _prev_ops_rank[rank].add(action) + + def _ready_to_schedule(action: Optional[_Action]) -> bool: + if action is None: + return True + + stage_idx = action.stage_index + prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)] + if action.computation_type == F: + if action.stage_index == 0: + return True + elif ( + _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops + ): + return True + elif ( + _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops + ): + return True + return False + elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD): + if action.stage_index == num_stages - 1: + return True + if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops: + return True + if ( + _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index) + in prev_ops + ): + return True + if ( + _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index) + in prev_ops + ): + return True + return False + elif action.computation_type == BACKWARD_WEIGHT: + return True + elif action.computation_type == SEND_F: + expected_f = _Action(action.stage_index, F, action.microbatch_index) + return expected_f in prev_ops + elif action.computation_type == RECV_F: + peer_stage_idx = stage_idx - 1 + expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + elif action.computation_type == SEND_B: + expected_b = _Action( + action.stage_index, BACKWARD_INPUT, action.microbatch_index + ) + expected_bw = _Action( + action.stage_index, FULL_BACKWARD, action.microbatch_index + ) + return expected_b in prev_ops or expected_bw in prev_ops + elif action.computation_type == RECV_B: + peer_stage_idx = stage_idx + 1 + expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index) + return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)] + else: + raise ValueError(f"Unsupported action type {action}") + + while pipeline_order: + progress = False + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + add_to_schedule(rank, action) + pipeline_order[rank].pop(0) + progress = True + else: + add_to_schedule(rank, None) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked + # by one of the later ranks + for rank in sorted(pipeline_order): + if len(pipeline_order[rank]) == 0: + continue + + if _schedule[rank][-1] is not None: + continue + + action = pipeline_order[rank][0] + if _ready_to_schedule(action): + if action is not None: + _schedule[rank][-1] = action + _prev_ops_rank[rank].add(action) + pipeline_order[rank].pop(0) + + for i in sorted(pipeline_order, reverse=True): + if len(pipeline_order[i]) == 0: + del pipeline_order[i] + + if not progress: + print("WIP comms schedule:\n", _format_pipeline_order(_schedule)) + for rank in pipeline_order: + print(f"{rank=} next action= {pipeline_order[rank][0]}") + raise ValueError("Schedule is not progressing") + + return _schedule + + +def _dump_chrometrace(schedule, filename): + """ + This function dumps a schedule IR into a chrometrace format so it can be visualized. + + It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text. + + As future work we may extend this to include more accurate heuristics for durations, or let users input durations, + add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute + as separate streams on the chrometrace view. + """ + events = [] + for rank in sorted(schedule): + for timestep, action in enumerate(schedule[rank]): + if action is None: + continue + events.append( + { + "name": str(action), + "cat": ( + "computation" + if action.computation_type in (F, B, W) + else "communication" + ), + "ph": "X", + "pid": rank, + "tid": rank, + "ts": timestep, + "dur": 1, + } + ) + import json + + with open(filename, "w") as f: + json.dump({"traceEvents": events}, f) diff --git a/phivenv/Lib/site-packages/torch/distributed/pipelining/stage.py b/phivenv/Lib/site-packages/torch/distributed/pipelining/stage.py new file mode 100644 index 0000000000000000000000000000000000000000..4e8684dc2c23a67ec76641c3bdae23eb7a8bade6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/pipelining/stage.py @@ -0,0 +1,1509 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +import operator +from abc import ABC, abstractmethod +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch.distributed as dist +import torch.fx as fx +import torch.nn as nn +from torch._subclasses.fake_tensor import FakeTensor +from torch.distributed.fsdp import FSDPModule, fully_shard +from torch.fx.node import Argument, map_aggregate +from torch.nn.parallel import DistributedDataParallel +from torch.utils._pytree import tree_map_only + +from ._backward import stage_backward, stage_backward_input, stage_backward_weight +from ._debug import map_debug_info +from ._utils import flatten_args, PipeInfo, validate_tensors_metadata + + +__all__ = [ + "PipelineStage", + "build_stage", +] + +logger = logging.getLogger(__name__) + + +def _normalize_model_output_as_tuple(output: Any) -> tuple[Any]: + """[Note: pipeline model output type] + + The output of the model passed to pipelining can be any type, controlled by the user. + + However, there are 2 API surfaces that complicate this. + (1) the outputs of intermediate stages are passed via Send/Recv ops to subsequent stages. The implicit assumption + is that each element of the outputs is a tensor. Otherwise, Send/Recv would not be supported. The exception + is the last layer of the model, which can output anything any which won't be communicated via Send/Recv. + (2) the outputs of the last layer of the model are returned to the user, or, passed to the loss function. + The loss function can be written in any way, such that its inputs match the outputs of the model. + + It would be convenient if we could strictly type the output signature of the pipeline stage wrapping the model, + but we do not want to impose an unnecessary constraint on user provided models. + + Currently, we let user provided models return either a Tensor or a tuple of Tensors from each stage. Due to + torch.export tracing, compiled models may also return a list instead of a Tuple, which we will normalize back to a + tuple for consistency. + + TODO: should we be stricter about asserting that stage modules (intermediate and output) all return only Tensor + values? + """ + if type(output) is list: + # HACK: this is a hacky workaround for the fact that export creates + # output in list format + output = tuple(output) + + # Unify output form to tuple for easy correspondence with + # `act_send_info` + output_tuple = output if type(output) is tuple else (output,) + return output_tuple + + +class _RootArgPlaceholder: + """ + Placeholder for model-level inputs. + """ + + def __init__(self, tensor): + self.meta = tensor.to("meta") + + +class _RecvInfo: + """ + Represents a stage input. + """ + + def __init__( + self, + input_name: str, + source: int, + buffer: torch.Tensor, + ): + # Name of this input + self.input_name = input_name + # Stage index of the source of this input + self.source = source + # Buffer to receive the input into. + self.buffer = buffer + + def __repr__(self): + return f"_RecvInfo(input={self.input_name}, source={self.source}, shape={self.buffer.size()})" + + +# An input can be either a received activation or a model input +InputInfo = Union[_RecvInfo, _RootArgPlaceholder] + + +def _make_tensor_from_meta( + example: Union[torch.Tensor, FakeTensor], + device: torch.device, +) -> torch.Tensor: + """ + Create a real tensor from a tensor. + """ + return torch.empty( + example.size(), + dtype=example.dtype, + layout=example.layout, + device=device, + ) + + +class _PipelineStageBase(ABC): + """ + Base class for pipeline stages. + Defines or implements common methods used by the `_PipelineStage` used by + the tracing frontend and `PipelineStage` used by manual frontend. + """ + + def __init__( + self, + submodule: torch.nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, + dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + ): + """ + Args: + submodule (torch.nn.Module): The module to be executed in this stage. + stage_index (int): The index of this stage. + num_stages (int): The total number of stages in this pipeline. + device (torch.device): The device to run this stage on. + group (Optional[dist.ProcessGroup]): The process group to use for communication. + If `None`, the default process group will be used. + Default: `None`. + dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder is a builder function + that will build a new dw_runner function that will run parts of module backward that were intentionally + skipped during the module's actual backward pass. The builder must be invoked by stage after stage runs + model backwards, and stage should save the latest dw_runner to run during weight pas (W). + If not provided, a dw_runner will be generated automatically by traversing the autograd graph. + When used with schedules that only have F and B steps, the fresh dw_runner function will be called as + part of I (input backwards). When used with F,I,W schedules, the dw_runner function implements 'W'. + """ + super().__init__() + if stage_index >= num_stages: + raise ValueError( + f"Stage index {stage_index} is out of range of {num_stages}" + ) + + self.submod = submodule + self.stage_index = stage_index + self.num_stages = num_stages + self.device = device + self.group = group + + self.dw_builder = dw_builder + + # backward state + self.backward_state: dict[int, tuple[Any, ...]] = {} + + # store dw_runner per microbatch_id + self.dw_runner: dict[int, Callable[..., None]] = {} + + # `group_rank` is rank in process group `group`. + self.group_rank = dist.get_rank(self.group) + self.group_size = dist.get_world_size(self.group) + if self.group_size > self.num_stages: + raise RuntimeError( + f"Pipeline group size {self.group_size} cannot be larger than number of stages {self.num_stages}" + ) + + # Run time states + self._outputs_meta: Optional[tuple[torch.Tensor, ...]] = None + # map microbatch ID to list of forward tensor args + self.fwd_cache: dict[int, tuple[Any, list[torch.Tensor]]] = {} + # map microbatch ID to list of backward grad tensor args + self.bwd_cache: dict[int, tuple[Optional[torch.Tensor], ...]] = {} + # Caching chunk outputs for final output merge or reduction + self.output_chunks: list[Any] = [] + + # Initialize has_backward to false; this will be set to true if loss + # function is passed to pipeline schedule + self.has_backward = False + # Log prefix + self.log_prefix = f"[Stage {self.stage_index}]" + + # Forward infra + self.args_recv_info: dict[int, tuple[InputInfo, ...]] = {} + self.act_send_info: dict[int, list] = {} + + # Backward infra will created lazily + self.grad_recv_info: dict = {} + self.grad_send_info: Optional[list] = None + + # To be populated later by the Schedule + self.chunks: Optional[int] = None + self.stage_index_to_group_rank: dict[int, int] = { + i: i % self.group_size for i in range(self.num_stages) + } + + @property + def has_backward(self) -> bool: + """ + Returns true if this stage has a backward pass. + """ + return self._has_backward + + @has_backward.setter + def has_backward(self, has_backward: bool): + self._has_backward = has_backward + + @property + def is_first(self): + """ + Returns true if this stage is the first stage in the pipeline. + """ + return self.stage_index == 0 + + @property + def is_last(self): + """ + Returns true if this stage is the last stage in the pipeline. + """ + return self.stage_index == self.num_stages - 1 + + def _check_chunk_id(self, chunk_id: int): + if self.chunks is None: + raise RuntimeError( + "Attempted to access chunk_id before chunks have been configured." + ) + if chunk_id >= self.chunks: + raise RuntimeError( + f"Chunk id {chunk_id} is out of range [0, {self.chunks})" + ) + + def _configure_outputs_meta(self, outputs_meta: tuple[torch.Tensor, ...]): + """ + Track the output shapes/dtype of this stage since they determine the send operation(s) which must match + recv operations of the next stage. The next stage _will_ be freezing its recv buffers based on its initial + configuration, so it's important to also freeze/validate the output side to avoid any send/recv mismatches + which could show up as hangs, silent corruption, or other errors. + """ + assert self._outputs_meta is None, ( + "Attempting to reconfigure output_meta, which is not supported" + ) + self._outputs_meta = tuple(outputs_meta) # type: ignore[assignment] + + def get_outputs_meta(self) -> tuple[torch.Tensor, ...]: + """Get the output metadata (meta tensors) reprensenting the outputs of this stage""" + assert self._outputs_meta is not None, ( + "Attempted to get_outputs_meta() without configuring output meta" + ) + return self._outputs_meta + + def _create_grad_send_info( + self, + args_recv_info: tuple, + ) -> list[Optional[int]]: + """ + Create a list of stage indices to send gradients to. + """ + grad_send_info: list[Optional[int]] = [] + + def map_recv_to_send(a): + # Note: we send gradients back to previous stage as long as in + # forward it is a received input, regardless of whether it requires + # grad. It is up to the previous stage to discard this gradient. + if isinstance(a, _RecvInfo): + grad_send_info.append(a.source) + return a.source + else: + grad_send_info.append(None) + return None + + map_aggregate(args_recv_info, map_recv_to_send) + + logger.debug("%s Grad send info: %s", self.log_prefix, grad_send_info) + return grad_send_info + + @abstractmethod + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[Any, ...]: + raise NotImplementedError + + def _prepare_backward_infra(self, num_microbatches: int): + # TODO: this is needed for backward_maybe_with_nosync + self.chunks = num_microbatches + + for mb_index in range(num_microbatches): + # `grad_recv_info` is a mirror of `act_send_info` + self.grad_recv_info[mb_index] = self._create_grad_recv_info( + self.act_send_info + ) + + @abstractmethod + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + raise NotImplementedError + + def _get_recv_ops( + self, + recv_infos: tuple[InputInfo, ...], + ) -> list[dist.P2POp]: + """ + Helper function shared by `get_fwd_recv_ops` and `get_bwd_recv_ops`. + Returns a list of ops that correspond to the recv infos. + """ + ops: list[dist.P2POp] = [] + for info in recv_infos: + if not isinstance(info, _RecvInfo): + continue + + peer_rank = self.stage_index_to_group_rank[info.source] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append( + dist.P2POp(dist.irecv, info.buffer, peer_global_rank, self.group) + ) + + return ops + + """[Note: V-schedule special case] + + V-Schedules have a special case where 2 stages with adjacent stage_id are on the same rank. + + ex: 2 ranks, 4 stages forms a simple V: + rank0: stage 0 stage 3 + rank1: stage 1 stage 2 + + stage 0,1 and 2,3 communicate activations using send/recv as usual, but stage 1,2 do not need to + use communication ops. Instead, they should pass tensor data directly via function call. + + set_local_fwd_input and (get_local_bwd_output + set_local_bwd_input) facilitate this optimization, and + should be called at the appropriate time during the pipeline schedule (after forward or backward execution). + """ + + def set_local_fwd_input(self, prev_stage_outputs: Any, mb_index: int) -> None: + """ + Moves 'prev_stage_outputs' from another stage on the same rank into place as inputs for this stage. Avoids + copying tensor data or using send/recv op. Detaches original tensor and sets requires_grad so the + tensor can serve as a leaf for autograd and gradients can be collected from it during backward. + """ + recv_infos: tuple[InputInfo, ...] = self.args_recv_info[mb_index] + + # See [Note: pipeline model output type] + prev_stage_outputs = _normalize_model_output_as_tuple(prev_stage_outputs) + + for info, tensor in zip(recv_infos, prev_stage_outputs): + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + "set_local_Fwd_input should only be called on non-first stage, which should always have RecvInfo" + ) + + # We don't need to do a data copy here, since we can directly pass the activation tensor reference from + # one stage to the next. However, we do need to mark the activation as a leaf tensor since it will serve + # as the input tensor for a fresh autograd graph, not part of the previous stage's autograd graph. + # TODO: confirm, do we use this activation as the root of the backward call for the previous stage? does + # detach have any affect on that? + info.buffer = tensor.detach().requires_grad_(True) + + def get_local_bwd_output(self, mb_index): + """ + Returns the input grad tensors for this stage, which correspond to the stage inputs during forward. + """ + assert self.has_backward, ( + "can't steal_bwd_input if this stage doesn't have backward" + ) + assert not self.is_first, "can't get bwd output if this stage is first" + + self._check_chunk_id(mb_index) + return self.bwd_cache.pop(mb_index) + + def set_local_bwd_input( + self, next_stage_bwd_outputs: tuple[Optional[torch.Tensor], ...], mb_index: int + ) -> None: + """ + Moves 'grad input' tensors from the next stage to 'grad_output' on this stage, avoiding a copy or send/recv. + Does not detach or set '_requires_grad'. + """ + assert isinstance(next_stage_bwd_outputs, tuple), ( + f"Expected tuple, got {type(next_stage_bwd_outputs)}" + ) + + assert self.has_backward, ( + "can't set bwd input if this stage doesn't have backward" + ) + assert not self.is_last, "can't set bwd input if this stage is last" + recv_infos = self.grad_recv_info[mb_index] + for info, tensor in zip(recv_infos, next_stage_bwd_outputs): + assert isinstance(tensor, torch.Tensor), ( + f"expected tensor values as outputs from prev stage, got {type(tensor)}" + ) + assert isinstance(info, _RecvInfo), ( + f"Expected a recv info, got {type(info)}" + ) + info.buffer = tensor + + def get_fwd_recv_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the input arguments + for this stage. + """ + recv_infos: tuple[InputInfo, ...] = self.args_recv_info[fwd_chunk_id] + + return self._get_recv_ops(recv_infos) + + def get_bwd_recv_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Returns a list of ops that are needed to receive the gradients + for this stage. + """ + if not self.has_backward or self.is_last: + return [] + + recv_infos = self.grad_recv_info[bwd_chunk_id] + return self._get_recv_ops(recv_infos) + + def get_fwd_send_ops(self, fwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the activation send ops for current stage's forward. + """ + output_tuple, _ = self.fwd_cache[fwd_chunk_id] + + ops: list[dist.P2POp] = [] + + for idx, out in enumerate(output_tuple): + dst_stages = self.act_send_info[idx] + for dst in dst_stages: + if dst is None: + continue + logger.debug( + "%s Sending tensor to Stage %s: %s", + self.log_prefix, + dst, + out.size(), + ) + peer_rank = self.stage_index_to_group_rank[dst] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append(dist.P2POp(dist.isend, out, peer_global_rank, self.group)) + + return ops + + def get_bwd_send_ops(self, bwd_chunk_id: int) -> list[dist.P2POp]: + """ + Get the gradient send ops for current stage's backward. + """ + self._check_chunk_id(bwd_chunk_id) + + if not self.has_backward or self.is_first: + return [] + + # Create bwd send infra lazily + if self.grad_send_info is None: + # Send info for input grads during backward: + # List of destinations corresponding to input grads + # Can be None if an input has no grad + # `grad_send_info` is a mirror of `args_recv_info` + self.grad_send_info = self._create_grad_send_info(self.args_recv_info[0]) + + ops: list[dist.P2POp] = [] + grads_input = self.bwd_cache.pop(bwd_chunk_id) + for grad, grad_recv_stage in zip(grads_input, self.grad_send_info): + if isinstance(grad, torch.Tensor) and grad_recv_stage is not None: + logger.debug( + "%s Sending gradient to Stage %s: %s", + self.log_prefix, + grad_recv_stage, + grad.size(), + ) + peer_rank = self.stage_index_to_group_rank[grad_recv_stage] + peer_global_rank = ( + peer_rank + if self.group is None + else dist.get_global_rank(self.group, peer_rank) + ) + ops.append(dist.P2POp(dist.isend, grad, peer_global_rank, self.group)) + else: + if not (grad is None and grad_recv_stage is None): + raise RuntimeError( + f"[{self.stage_index}] for chunk {bwd_chunk_id} has gradients {grad} " + f"and is expecting to send gradients to stage {grad_recv_stage}" + ) + return ops + + def clear_runtime_states(self) -> None: + """ + Clear runtime states of the stage. + """ + # map microbatch ID to list of forward tensor args + self.fwd_cache.clear() + # Caching chunk outputs for final output merge or reduction + self.output_chunks.clear() + + # Clear grad of input buffers in between schedule steps. This is because + # `torch.autograd.backward()` will accumulate gradients into leaf + # tensors by default. For gradients to pass back to previous stages, we + # don't want such accumulation. + for recv_tuple in self.args_recv_info.values(): # iterate over all chunks + for a in recv_tuple: # iterate over all input args + if isinstance(a, _RecvInfo): + # Set to None is the newer and recommended way to clear grads, compared to `zero_()`. + # See https://github.com/pytorch/pytorch/pull/92731 + a.buffer.grad = None + + def _map_tensor_from_recv_info( + self, + recv_infos: tuple[InputInfo, ...], + ): + """ + Map tensors from recv infos to a list. + """ + + def get_recv_tensor(info): + if isinstance(info, _RecvInfo): + return info.buffer + else: + raise AssertionError(f"Expected _RecvInfo but got {type(info)}") + + return map_aggregate(cast(Argument, recv_infos), get_recv_tensor) + + def _retrieve_recv_activations(self, fwd_chunk_id: int): + """ + Retrieve the activations received for the current stage during forward. + """ + recv_infos = self.args_recv_info[fwd_chunk_id] + activations = self._map_tensor_from_recv_info(recv_infos) + return activations + + def _retrieve_recv_grads( + self, + bwd_chunk_id: int, + ): + """ + Retrieve the gradients received for the current stage during backward. + """ + recv_infos = self.grad_recv_info[bwd_chunk_id] + grads = self._map_tensor_from_recv_info(recv_infos) + return grads + + def forward_maybe_with_nosync(self, *args, **kwargs): + # If submod is wrapped with DDP, we use the `no_sync` context manager to + # avoid gradient all-reduce per microbatch + if isinstance(self.submod, DistributedDataParallel): + with self.submod.no_sync(): # type: ignore[operator] + out_val = self.submod(*args, **kwargs) + else: + out_val = self.submod(*args, **kwargs) + return out_val + + def scale_grads(self, grad_scale_factor: int) -> None: + """Scale gradients model gradients by `grad_scale_factor`, which should be specified in coordination with the + loss function used with pipelining. For loss functions which perform 'mean' loss reduction, `grad_scale_factor` + should be set to num_microbatches. For loss functions that use `sum` reduction, `grad_scale_factor` should + be set to 1. + + Should only be called once per pipeline schedule step, after all backwards passes have completed. + """ + + # PP scales only for its own contribution (microbatches), but relies on DP to scale further + # for DP degree. + if grad_scale_factor != 1: + for p in self.submod.parameters(): + if p.grad is not None: + p.grad.div_(grad_scale_factor) + + def backward_maybe_with_nosync( + self, + backward_type, + bwd_kwargs: dict, + last_backward: bool = False, + ) -> tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]]: + """ + Whether using PP with FSDP or DDP, there are some runtime differences between the last backward step and the + other steps. Namely, we need to accumulate gradients on previous steps and reduce them on the last step, but + there are additional state-variables and performance considerations depending on the data parallelism used. + This helper should adapt any pipeline parallel schedule to work with common/supported data parallel libraries. + """ + + def perform_backward( + backward_type, + ) -> Callable[ + [], + tuple[tuple[Optional[torch.Tensor], ...], Optional[list[dict[str, Any]]]], + ]: + if backward_type == "full": + return lambda: ( + stage_backward( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + ), + None, + ) + elif backward_type == "input": + return lambda: stage_backward_input( + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + bwd_kwargs["input_values"], + self.submod.parameters(), + ) + elif backward_type == "weight": + return lambda: ( + stage_backward_weight( + self.submod.parameters(), bwd_kwargs["param_groups"] + ), + None, + ) + else: + raise RuntimeError(f"Unknown backward type: {backward_type}") + + # If submod is wrapped by DDP + if isinstance(self.submod, DistributedDataParallel): + if last_backward: + # Last chunk, prepare for gradient reduction + # HACK: reaching into DDP implementation details here. Is there a better way? + self.submod.reducer.prepare_for_backward( # type: ignore[union-attr, operator] + list( + torch.nn.parallel.distributed._find_tensors( # type: ignore[attr-defined] + bwd_kwargs["stage_output"] + ) + ) + ) + result = perform_backward(backward_type)() + else: + with self.submod.no_sync(): # type: ignore[operator] + result = perform_backward(backward_type)() + # If submod is a FSDP module + elif isinstance(self.submod, FSDPModule): + self.submod.set_is_last_backward(False) + self.submod.set_reshard_after_backward(False) + self.submod.set_requires_gradient_sync(False) + result = perform_backward(backward_type)() + if last_backward: + # Manually call post backward for FSDP + def run_post_backward(fsdp_module: FSDPModule) -> None: + fsdp_module.set_is_last_backward(True) + fsdp_module.set_reshard_after_backward(True) + fsdp_module.set_requires_gradient_sync(True) + fsdp_state = fully_shard.state(fsdp_module) # type: ignore[attr-defined] + for state in fsdp_state._state_ctx.all_states: + if state._fsdp_param_group: + state._fsdp_param_group.post_backward() + + # it would be much better if pipelining backward invoked .backward so autograd hooks + # worked and modules like DDP/FSDP behaved as expected. Working around this for the time being, + # we need to call this too to ensure FSDP syncs its grad reduction ops back to the default stream. + fsdp_state._root_post_backward_final_callback() + + run_post_backward(self.submod) + + else: + # Non-DP submodule, regular backward + result = perform_backward(backward_type)() + + grads, param_groups = result + return grads, param_groups + + def forward_one_chunk( + self, + fwd_chunk_id: int, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + ): + """ + Perform forward pass on the stage with one microbatch. + `args` and `kwargs` are the inputs from *external* to this stage. + As of Sept 2024: + - `args` applies to the first stage only, other stages receives args + through activation transmission. + - `kwargs` can be passed to all stages via respective `step` calls. + """ + + if self.is_first: + # First stage doesn't need to receive anything + composite_args = args + else: + # Receive activations for this chunk + # Activations only come in args form + composite_args = self._retrieve_recv_activations(fwd_chunk_id) + + composite_kwargs = kwargs or {} + + self._validate_fwd_input(args, kwargs) + + # Compute forward + try: + output = self.forward_maybe_with_nosync(*composite_args, **composite_kwargs) + + except Exception as e: + exc_msg = f""" + {self.log_prefix} failed to run forward: + args: {map_debug_info(composite_args)} + kwargs: {map_debug_info(composite_kwargs)} + """ + raise RuntimeError(exc_msg) from e + + # See [Note: pipeline model output type] + output_tuple = _normalize_model_output_as_tuple(output) + + # Prepare for final output merge or reduction + # Output chunks is only used for the last stage since we only merge the output of the last stage + if self.is_last: + self.output_chunks.append(output) + + # Save activations and inputs for backward + flat_args = flatten_args(composite_args) + flat_kwargs = flatten_args(composite_kwargs) + flatten_input_tensors = flat_args + flat_kwargs + self.fwd_cache[fwd_chunk_id] = ( + output_tuple, # stage_output + flatten_input_tensors, # input_values + ) + + logger.debug( + "%s Forwarded chunk %s, outputs: %s", + self.log_prefix, + fwd_chunk_id, + map_debug_info(output), + ) + self._validate_fwd_outputs(output_tuple) + + # We return the original user-provied output, not normalized to tuple. + # See [Note: pipeline model output type] + return output + + def backward_one_chunk( + self, + bwd_chunk_id: int, + loss=None, + full_backward: bool = True, + last_backward=False, + ): + """ + Perform backward pass on the module. + This should only be called once per microbatch. + + If full_backward is True (the default), the full backward pass including weight and input gradients will be run, + and it is an error to call `backward_weight_one_chunk` for this bwd_chunk_id. + + If full_backward is False, it is optional that `dw_runner` was provided to the PipelineStage at __init__ time, + and a subsequent call to `backward_weight_one_chunk` is required to invoke dw_runner and complete the backward. + + last_backward is controlled by the schedule and signals synchronization of gradients across DP groups + after the last backward. + """ + self._check_chunk_id(bwd_chunk_id) + + ( + stage_output, + input_values, + ) = self.fwd_cache.pop(bwd_chunk_id) + + # Compute backward + if self.is_last: + # Last stage computes gradients from loss and has no gradients from + # next stage + bwd_kwargs = { + "stage_output": loss, + "output_grads": None, + "input_values": input_values, + } + else: + # Otherwise, receive gradients from next stage + grads_output = self._retrieve_recv_grads(bwd_chunk_id) + # If an input to the pipeline requires gradient, + # `torch.autograd.backward` will accumulate the gradient into the + # `.grad` field of such input + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": grads_output, + "input_values": input_values, + } + + grads_input: tuple[Optional[torch.Tensor], ...] = () + + # Custom backward function + if self.dw_builder: + # TODO: We may want to change our semantics so we are allowed to ignore + # the 'dw_builder' and call full_backward directly when it is a full_backward op. + grads_input, _ = self.backward_maybe_with_nosync( + "full", + bwd_kwargs, + last_backward=last_backward, + ) + if full_backward: + self.dw_builder()() + else: + self.dw_runner[bwd_chunk_id] = self.dw_builder() + else: + if full_backward: + grads_input, _ = self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + else: + param_groups: list[dict[str, Any]] | None = None + # Skip the backward for the first stage since we will perform the weight update with + # autograd.backward in backward_weight_one_chunk + if not self.is_first: + if isinstance(bwd_kwargs["stage_output"], torch.Tensor): + bwd_kwargs["stage_output"] = (bwd_kwargs["stage_output"],) + + # perform the partial backwards for the inputs with a custom backward function + # when the "stage_ouput" is a loss, then it is a tensor, otherwise it is a tuple of tensors + grads_input, param_groups = self.backward_maybe_with_nosync( + "input", bwd_kwargs, last_backward=last_backward + ) + + # TODO: we dont need to save this, add to dw_runner? + self.backward_state[bwd_chunk_id] = ( + bwd_kwargs["input_values"], + param_groups, + bwd_kwargs["stage_output"], + bwd_kwargs["output_grads"], + ) + # Save a placeholder for the dw_runner + self.dw_runner[bwd_chunk_id] = lambda: None + + self.bwd_cache[bwd_chunk_id] = grads_input + + if self.is_last and not self.is_first: + # Autograd dependencies: + # rest_of_autograd_graph -> stage_output -> loss + # stage_output is no longer used in the last stage for backward and only needed + # to return to the user in merge_output_chunks, therefore + # this should be detached to release autograd graph context and free memory earlier + for t in stage_output: + if not t._is_view(): # views are not detachable in-place + t.detach_() + + logger.debug("%s Backwarded chunk %s", self.log_prefix, bwd_chunk_id) + + def backward_weight_one_chunk(self, bwd_chunk_id: int, last_backward=False): + assert bwd_chunk_id in self.dw_runner, ( + f"{self.log_prefix} Attempted to run backward_weight_one_chunk for chunk {bwd_chunk_id}" + " without first calling `backward_one_chunk(full_backward=False)`" + ) + + if self.dw_builder is not None: + self.dw_runner.pop(bwd_chunk_id)() + else: + ( + input_values, + param_groups, + stage_output, + output_grads, + ) = self.backward_state.pop(bwd_chunk_id) + + if self.stage_index != 0: + bwd_kwargs = { + "stage_output": stage_output, + "param_groups": param_groups, + } + self.backward_maybe_with_nosync( + "weight", bwd_kwargs, last_backward=last_backward + ) + else: + # TODO: figure out a better way to do this: + # if inputs does not require gradient, + # then the parameter group will not be fully captured during stage_backward_input + # in this case, we need call grad directly on the parameters + # To solve: make input fn do the intersect compute and then finish it off during W + bwd_kwargs = { + "stage_output": stage_output, + "output_grads": output_grads, + "input_values": input_values, + } + self.backward_maybe_with_nosync( + "full", bwd_kwargs, last_backward=last_backward + ) + + def _validate_fwd_input(self, args, kwargs): + """Raises a RuntimeError if shapes of input args/kwargs do not match the shapes configured for this stage.""" + + if self.is_first: + # TODO why is there a separate recv_info for each pipeline chunk? + # kwen2501: to avoid passing a `fwd_chunk_id` to this function, we + # check all chunks against args_recv_info[0] + expected_args = self.args_recv_info[0] + else: + # We don't check inputs for non-0 stages assuming they don't accept + # user inputs in canonical pipeline scenarios + return + + if len(kwargs): + # TODO- need a mapping of kwarg to position in self.args_recv_info + # Without it, we are not 100% sure how to match the args and + # expected_args. + return + + # TODO- need a mapping of kwarg to position in self.args_recv_info + # maybe it's impossible to tell whether the len mismatches because + # (a) the user passed an extra arg or missed an arg + # (b) the user did not pass a kwarg, which has a default value baked into expected_args + expected_tensors_meta = [ + e.meta if isinstance(e, _RootArgPlaceholder) else e.buffer + for e in expected_args + ] + validate_tensors_metadata( + f"Stage {self.stage_index} forward inputs", expected_tensors_meta, args + ) + + def _validate_fwd_outputs(self, outputs: tuple[torch.Tensor, ...]): + """Raises a RuntimeError if this stage produces an output of unexpected shape/dtype. + Most likely, this could be cause either by incorrect user specification of output shapes, or because + shape inference was done on the original model but then at runtime the model is wrapped with something like + mixed precision which changes output dtype. + """ + expected_tensors_meta = self.get_outputs_meta() + validate_tensors_metadata( + f"Stage {self.stage_index} forward outputs", expected_tensors_meta, outputs + ) + + +class _PipelineStage(_PipelineStageBase): + def __init__( + self, + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, + ): + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and a `pipe_info` describing the stage relationship of the pipeline. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + """ + _PipelineStageBase.__init__( + self, + stage_module, + stage_index, + pipe_info.num_stages, + device, + group, + ) + self.pipe_info = pipe_info + + # Find stage nodes in graph + submod_nodes = [ + node for node in pipe_info.graph.nodes if node.op == "call_module" + ] + if len(submod_nodes) != self.num_stages: + raise AssertionError( + f"Number of submodules in pipe graph {len(submod_nodes)} does not match number of stages {self.num_stages}" + ) + + # Find my stage node in graph + self.node = submod_nodes[self.stage_index] + self.name = self.node.name + logger.info( + "[%s] Creating PipelineStage %s for %s", + self.group_rank, + stage_index, + self.name, + ) + + # Create mapping from stage name to stage index + self.submod_to_stage_index: dict[str, int] = {} + for i, node in enumerate(submod_nodes): + self.submod_to_stage_index.setdefault(node.name, i) + + # Cast submodule to device + self._move_submod_to_device() + + def _move_submod_to_device(self): + # Move submodule to indicated device if possible + # Note: we cannot move meta module to real devices because meta tensors + # do not support to() method. One needs to do an in-place tensor swap in + # that case. + has_meta_param = any( + isinstance(p, FakeTensor) or p.is_meta for p in self.submod.parameters() + ) + if has_meta_param: + logger.debug("%s Found meta parameters!", self.log_prefix) + else: + self.submod.to(self.device) + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[Any, ...]: + """ + Create send/recv infrastructures for activations (during forward) + """ + # TODO(whc) + # this method should be deleted once lazy buffer allocation is implemented + # for now, it ignores args/kwargs because it should not need to do shape inference + for chunk in range(num_microbatches): + self.args_recv_info[chunk] = self._create_act_recv_info() + + # Send info during forward for each activation + self.act_send_info = self._create_act_send_info() + return tuple() + + def get_stage_index_of_submod( + self, + submod_name: str, + ): + """ + Given a submodule name, return the stage index of the submodule. + """ + if submod_name not in self.submod_to_stage_index: + raise AssertionError(f"Stage id of {submod_name} not found") + + return self.submod_to_stage_index[submod_name] + + def _create_act_recv_info( + self, + ): + """ + Create a tuple of `_RecvInfo` for inputs to the stage. + """ + + def create_recv_tensor(placeholder, arg_node): + """ + Create a receive buffer for a placeholder. + """ + example_value = placeholder.meta["val"] + if arg_node.op == "placeholder": + # This is a root level placeholder, thus an input argument to the entire model. + # We are likely at stage 0, hence no need to create a receive buffer. + return _RootArgPlaceholder(example_value) + + # Figure out the source stage of this input + while arg_node.target is operator.getitem: + # If the input is a getitem, we need to go deeper + arg_node = arg_node.args[0] + + assert arg_node.op == "call_module", ( + f"Expecting call_module, got {arg_node.op}" + ) + src_stage = self.get_stage_index_of_submod(arg_node.name) + + # Create a receive buffer for this placeholder + logger.debug( + "%s Creating recv buffer for input '%s' : %s, %s", + self.log_prefix, + placeholder.name, + example_value.shape, + example_value.dtype, + ) + buffer = _make_tensor_from_meta(example_value, self.device) + # In case there is backward pass, set requires_grad for receive buffers + # before first forward + if self.has_backward: + buffer.requires_grad_(True) + + return _RecvInfo( + arg_node.name, + src_stage, + buffer, + ) + + args_recv_info: list[InputInfo] = [] + # Filter out placeholder nodes from `self.submod` (a GraphModule) + placeholders = filter( # type: ignore[var-annotated] + lambda node: node.op == "placeholder", # type: ignore[arg-type] + self.submod.graph.nodes, # type: ignore[arg-type,union-attr] + ) + # `placeholders` are nodes internal to submod. + # `self.node.args` are dependency nodes in the outer graph. + # The two are 1:1. + for placeholder, arg_node in zip(placeholders, self.node.args): + # Create a receive buffer for this placeholder + recv_info = create_recv_tensor(placeholder, arg_node) + args_recv_info.append(recv_info) + + logger.debug( + "%s Activation recv / args info: %s", self.log_prefix, args_recv_info + ) + # `args` is a Tuple, hence we will return a Tuple[InputInfo] + return tuple(args_recv_info) + + def find_dst_rank( + self, + user: fx.Node, + ) -> Optional[int]: + """ + Find the destination rank of a `user` node. + If the `user` is not a submod, `None` may be returned. + """ + if user.op == "call_module": + # User is a stage (`call_module`) + return self.get_stage_index_of_submod(user.name) + else: + # - If user.op == "output": + # No need to send back to rank 0 + # - If user.target is stage_backward: + # No need to send assuming submod output is stored locally or + # should be re-calucated in case of activation checkpointing + return None + + def _create_act_send_info(self): + """ + Create a dict of send info for activations. + The dict is of the form: + { + output_index: [dst_rank_0, dst_rank_1, ...], + ... + } + where the list of `dst_rank`s covers the case where an output value may + be consumed by multiple stages. + """ + # Output index: List of receiver ranks + act_send_info: dict[int, list] = {} + out_idx = 0 + + for user in self.node.users: + if user.target is operator.getitem: + # Recursively find the real destination + gi_dsts = act_send_info.setdefault(out_idx, []) + for gi_user in user.users: + dst_rank = self.find_dst_rank(gi_user) + if dst_rank is not None: + gi_dsts.append(dst_rank) + # Next `getitem` will point to the next output index + out_idx += 1 + else: + # In case of single output value, `out_idx` will not increase + dsts = act_send_info.setdefault(out_idx, []) + dst_rank = self.find_dst_rank(user) + if dst_rank is not None: + dsts.append(dst_rank) + + output_node = self._get_output_node() + output_vals: tuple[torch.Tensor] = tuple( + v.meta["val"] for v in flatten_args(output_node.args) + ) + self._configure_outputs_meta(output_vals) + + logger.debug("%s Send info: %s", self.log_prefix, act_send_info) + return act_send_info + + def _get_output_node(self): + output_nodes = [node for node in self.submod.graph.nodes if node.op == "output"] # type: ignore[union-attr] + assert len(output_nodes) == 1 + output_node = output_nodes[0] + return output_node + + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + """ + Create a tuple of `_RecvInfo` for gradients. + """ + # Dict[output_index, _RecvInfo] + grad_recv_info: dict[int, _RecvInfo] = {} + output_node = self._get_output_node() + + # The output node may take multiple args, meaning the submod having multiple output values. + output_vals = flatten_args(output_node.args) + + for out_idx, dst_list in act_send_info.items(): + if not dst_list: + # No actual receiver for activation so no grad coming back + continue + + output = output_vals[out_idx] + example_value = output.meta["val"] + logger.debug( + f"{self.log_prefix} Creating grad recv buffer for output {output.name} " # noqa: G004 + f": {example_value.shape}, {example_value.dtype}" + ) + + # TODO: otherwise needs grad accumulation + assert len(dst_list) == 1, "Backward of skip connections not supported yet" + grad_src = dst_list[0] + grad_recv_info[out_idx] = _RecvInfo( + f"{grad_src}", # noqa: G004 + grad_src, + _make_tensor_from_meta(example_value, self.device), + ) + + # Convert to tuple for convenience in get_ops and retrieve tensor + grad_recv_info_tuple = tuple(grad_recv_info.values()) + logger.debug("%s Grad recv info: %s", self.log_prefix, grad_recv_info_tuple) + return grad_recv_info_tuple + + +# A helper function to create a pipeline stage based on traced pipeline information +def build_stage( + stage_module: torch.nn.Module, + stage_index: int, + pipe_info: PipeInfo, + device: torch.device, + group: Optional[dist.ProcessGroup] = None, +) -> _PipelineStage: + """ + Create a pipeline stage given a stage_module to be wrapped by this stage + and pipeline information. + + Args: + stage_module (torch.nn.Module): the module to be wrapped by this stage + stage_index (int): the index of this stage in the pipeline + pipe_info (PipeInfo): information about the pipeline, can be retrieved by `pipe.info()` + device (torch.device): the device to be used by this stage + group (Optional[dist.ProcessGroup]): the process group to be used by this stage + + Returns: + _PipelineStage: a pipeline stage that can run with `PipelineSchedules`. + """ + return _PipelineStage( + stage_module, + stage_index, + pipe_info, + device, + group, + ) + + +class PipelineStage(_PipelineStageBase): + """ + A class representing a pipeline stage in a pipeline parallelism setup. + + PipelineStage assumes sequential partitioning of the model, i.e. the model is split into chunks where outputs from + one chunk feed into inputs of the next chunk, with no skip connections. + + PipelineStage performs runtime shape/dtype inference automatically by propagating the outputs from stage0 to + stage1 and so forth, in linear order. To bypass shape inference, pass the `input_args` and `output_args` to each + PipelineStage instance. + + Args: + submodule (nn.Module): The PyTorch module wrapped by this stage. + stage_index (int): The ID of this stage. + num_stages (int): The total number of stages. + device (torch.device): The device where this stage is located. + input_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The input arguments for the submodule. + output_args (Union[torch.Tensor, Tuple[torch.tensor]], optional): The output arguments for the submodule. + group (dist.ProcessGroup, optional): The process group for distributed training. If None, default group. + dw_builder (Optional[Callable[[], Callable[..., None]]): If provided, dw_builder will build a new dw_runner function + that will the W action (input weights) for F, I, W (Fwd, Input, Weight) zero bubble schedules. + """ + + def __init__( + self, + submodule: nn.Module, + stage_index: int, + num_stages: int, + device: torch.device, + input_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, + output_args: Optional[Union[torch.Tensor, tuple[torch.Tensor, ...]]] = None, + group: Optional[dist.ProcessGroup] = None, + dw_builder: Optional[Callable[[], Callable[..., None]]] = None, + ): + super().__init__(submodule, stage_index, num_stages, device, group, dw_builder) + self.inputs: Optional[list[torch.Tensor]] = None + self.inputs_meta: Optional[tuple[torch.Tensor, ...]] = None + # Note: inputs and submod should ideally be on meta device. We decided not to assert this (yet) because it + # might be breaking for existing users. + if input_args is None: + assert output_args is None, ( + "If specifying output_args, input_args must also be specified. " + "Otherwise, shape inference will be performed at runtime" + ) + else: + self.inputs_meta = ( + (input_args,) if isinstance(input_args, torch.Tensor) else input_args + ) + if output_args is None: + logger.warning( + "Deprecation warning: passing input_args and performing init-time shape inference is deprecated. " + "PipelineStage now supports runtime shape inference using the real inputs provided to schedule step(). " + "Either delete `input_args` arg to `PipelineStage` to opt-into runtime shape inference, " + "or additionally pass `output_args` to `PipelineStage` to fully override shape inference. " + ) + try: + with torch.no_grad(): + output_args = submodule(*self.inputs_meta) + output_args = tree_map_only( + torch.Tensor, lambda x: x.to("meta"), output_args + ) + except Exception as e: + raise RuntimeError( + "Failed to perform pipeline shape inference- are your inputs on the same device as your module?" + ) from e + assert output_args is not None, ( + "If passing input_args, also pass output_args to override shape inference" + ) + self._configure_outputs_meta( + (output_args,) if isinstance(output_args, torch.Tensor) else output_args + ) + + # these are the buffers used in backwards send/recv, they are allocated later + self.outputs_grad: list[torch.Tensor] = [] + + dbg_str = ( + f"Finished pipeline stage init, {self.stage_index=}, {self.is_first=}, " # noqa: G004 + f"{self.is_last=}, {self.num_stages=}, " + ) + if self.inputs_meta is not None: + dbg_str += ( + f"inputs: {[inp.shape for inp in self.inputs_meta]}, " + f"output: {[output.shape for output in self.get_outputs_meta()]}" + ) + else: + dbg_str += " running shape-inference at runtime" + + logger.debug(dbg_str) + + def _shape_inference( + self, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + ): + if kwargs is None: + kwargs = {} + assert args is not None, "Args may be an empty tuple but not None" + + # We skip recv communication if we're the first stage, but also if the previous stage is on the same rank + # and can pass its output shapes in as args instead of using send/recv. + if ( + self.is_first + # if not first stage, then check if prev stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index - 1] == self.group_rank + ): + logger.debug( + "Shape inference: stage %s skipping recv, because shape info passed in via `args`", + self.stage_index, + ) + args = tree_map_only(torch.Tensor, lambda x: x.to("meta"), args) + else: + assert len(args) == 0, ( + "Can't supply input args for shape inference on non-first stage" + ) + objects = [None] + logger.debug( + "Shape inference: stage %s receiving from stage %s", + self.stage_index, + self.stage_index - 1, + ) + dist.recv_object_list( + objects, + src=dist.get_global_rank( + self.group or dist.distributed_c10d._get_default_group(), + self.stage_index_to_group_rank[self.stage_index - 1], + ), + group=self.group, + device=self.device, + ) + recv_args = objects[0] + assert isinstance(recv_args, tuple), type(recv_args) + args = recv_args + + # cache input shapes for use during recv buffer allocation + self.inputs_meta = args + args = tree_map_only( + torch.Tensor, lambda x: torch.zeros_like(x, device=self.device), args + ) + + # set attributes needed for forward + with torch.no_grad(): + outputs = self.submod(*args, **kwargs) + + # if single tensor, convert so it is always a list + if isinstance(outputs, torch.Tensor): + outputs = [outputs] + + # communicate meta outputs not real outputs for two reasons + # 1 - its faster (esp. since obj coll pickles tensor data!) + # 2 - avoid activating a cuda context for the src rank when unpickling on the recv end! + outputs_meta = tuple( + tree_map_only(torch.Tensor, lambda x: x.to("meta"), outputs) + ) + logger.debug( + "Shape inference: stage %s inputs %s, outputs %s", + self.stage_index, + self.inputs_meta, + outputs_meta, + ) + self._configure_outputs_meta(outputs_meta) + + # Passing outputs to the next stage: + # two cases- + # 1. Usually: use send/recv communication to pass the output + # 2. Special case: for V-schedules, 2 'adjacent' stages (e.g. stage 3, 4 in an 8-stage 4-rank V) + # pass their shape info via return value and function args rather than send/recv. + if ( + self.is_last + # if not last stage, then check if next stage is on the same rank + or self.stage_index_to_group_rank[self.stage_index + 1] == self.group_rank + ): + # Case (2) above: pass shape info via return value and caller passes it as args to next stage's + # _shape_inference call + logger.debug( + "Shape inference: stage %s skipping send to next stage", + self.stage_index, + ) + + else: + # Case (1): send shapes via send operation, and ensure not to return it to the caller + logger.debug( + "Shape inference: stage %s sending to stage %s", + self.stage_index, + self.stage_index + 1, + ) + dist.send_object_list( + [outputs_meta], + dst=dist.get_global_rank( + self.group or dist.distributed_c10d._get_default_group(), + self.stage_index_to_group_rank[self.stage_index + 1], + ), + group=self.group, + device=self.device, + ) + outputs_meta = tuple() + + return outputs_meta + + def _prepare_forward_infra( + self, + num_microbatches: int, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + ) -> tuple[Any, ...]: + # TODO move self.device to an argument from step API (from its input tensors)? + assert num_microbatches is not None, "TODO fix num_microbatches" + + outputs: tuple[Any, ...] = tuple() + if self.inputs_meta is None: + outputs = self._shape_inference(args, kwargs) + + assert self.inputs_meta is not None + # Receive info during forward + # TODO: create args_recv_info lazily? (same needed for PipelineStage) + for chunk_id in range(num_microbatches): + if not self.is_first: + # We assume that we always receive from stage - 1 + recv_infos = tuple( + [ + _RecvInfo( + f"recv_for_{self.stage_index}_from_{self.stage_index - 1}", + self.stage_index - 1, + _make_tensor_from_meta(inp, self.device), + ) + for inp in self.inputs_meta + ] + ) + # In case there is backward pass, set requires_grad for receive buffers + if self.has_backward: + for r in recv_infos: + r.buffer.requires_grad_(True) + + self.args_recv_info[chunk_id] = recv_infos + else: + self.args_recv_info[chunk_id] = tuple( + [_RootArgPlaceholder(i) for i in self.inputs_meta] + ) + + # Send info during forward for each activation + # only need the rank that is being sent to + self.act_send_info: dict[int, list] = {} + + for idx in range(len(self.get_outputs_meta())): + # We assume we always send to stage + 1 + if not self.is_last: + self.act_send_info[idx] = [self.stage_index + 1] + else: + self.act_send_info[idx] = [] + + return outputs + + def _create_grad_recv_info( + self, + act_send_info: dict, + ) -> tuple[_RecvInfo, ...]: + grad_recv_info: tuple[_RecvInfo, ...] = () + if not self.is_last: + # Receiving gradients from multiple sources is not supported + # hence we only take the first destination + grad_recv_info = tuple( + [ + _RecvInfo( + f"recv_grad_for_{self.stage_index}_from_{dst_list[0]}", + dst_list[0], + _make_tensor_from_meta( + self.get_outputs_meta()[idx], self.device + ), + ) + for idx, dst_list in act_send_info.items() + ] + ) + return grad_recv_info diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__init__.py b/phivenv/Lib/site-packages/torch/distributed/rpc/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2cc103aeb58265939a16c9940ec59be7694703a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/__init__.py @@ -0,0 +1,257 @@ +# mypy: allow-untyped-defs +import logging +import os +import threading +import warnings +from collections.abc import Generator +from datetime import timedelta +from urllib.parse import urlparse + +import torch +import torch.distributed as dist + + +__all__ = ["is_available"] + + +logger = logging.getLogger(__name__) + + +_init_counter = 0 +_init_counter_lock = threading.Lock() + + +def is_available() -> bool: + return hasattr(torch._C, "_rpc_init") + + +if is_available() and not torch._C._rpc_init(): + raise RuntimeError("Failed to initialize torch.distributed.rpc") + + +if is_available(): + _is_tensorpipe_available = hasattr( + torch._C._distributed_rpc, "_TensorPipeRpcBackendOptionsBase" + ) + + import numbers + + import torch.distributed.autograd as dist_autograd + from torch._C._distributed_c10d import Store + from torch._C._distributed_rpc import ( # noqa: F401 + _cleanup_python_rpc_handler, + _DEFAULT_INIT_METHOD, + _DEFAULT_RPC_TIMEOUT_SEC, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _disable_jit_rref_pickle, + _disable_server_process_global_profiler, + _enable_jit_rref_pickle, + _enable_server_process_global_profiler, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _rref_context_get_debug_info, + _set_and_start_rpc_agent, + _set_profiler_node_id, + _set_rpc_timeout, + _UNSET_RPC_TIMEOUT, + enable_gil_profiling, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + RpcAgent, + RpcBackendOptions, + WorkerInfo, + ) + + if _is_tensorpipe_available: + from torch._C._distributed_rpc import ( # noqa: F401 + _DEFAULT_NUM_WORKER_THREADS, + _TensorPipeRpcBackendOptionsBase, + TensorPipeAgent, + ) + + from . import api, backend_registry, functions + from .api import * # noqa: F401,F403 + from .backend_registry import BackendType + from .options import TensorPipeRpcBackendOptions # noqa: F401 + from .server_process_global_profiler import _server_process_global_profile + + rendezvous_iterator: Generator[tuple[Store, int, int], None, None] + + __all__ += ["init_rpc", "BackendType", "TensorPipeRpcBackendOptions"] + __all__ = __all__ + api.__all__ + backend_registry.__all__ # noqa: PLE0605 + + def init_rpc( + name, + backend=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + r""" + Initializes RPC primitives such as the local RPC agent + and distributed autograd, which immediately makes the current + process ready to send and receive RPCs. + + Args: + name (str): a globally unique name of this node. (e.g., + ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``) + Name can only contain number, alphabet, underscore, colon, + and/or dash, and must be shorter than 128 characters. + backend (BackendType, optional): The type of RPC backend + implementation. Supported values is + ``BackendType.TENSORPIPE`` (the default). + See :ref:`rpc-backends` for more information. + rank (int): a globally unique id/rank of this node. + world_size (int): The number of workers in the group. + rpc_backend_options (RpcBackendOptions, optional): The options + passed to the RpcAgent constructor. It must be an agent-specific + subclass of :class:`~torch.distributed.rpc.RpcBackendOptions` + and contains agent-specific initialization configurations. By + default, for all agents, it sets the default timeout to 60 + seconds and performs the rendezvous with an underlying process + group initialized using ``init_method = "env://"``, + meaning that environment variables ``MASTER_ADDR`` and + ``MASTER_PORT`` need to be set properly. See + :ref:`rpc-backends` for more information and find which options + are available. + """ + torch._C._log_api_usage_once("torch.distributed.init_rpc") + if backend is not None and not isinstance( + backend, backend_registry.BackendType + ): + raise TypeError("Argument backend must be a member of BackendType") + + if rpc_backend_options is not None and not isinstance( + rpc_backend_options, RpcBackendOptions + ): + raise TypeError( + "Argument rpc_backend_options must be an instance of RpcBackendOptions" + ) + + # Try to detect the backend from the options + if backend is None and rpc_backend_options is not None: + for candidate_backend in BackendType: + if isinstance( + rpc_backend_options, + type( + backend_registry.construct_rpc_backend_options( + candidate_backend + ) + ), + ): + backend = candidate_backend + break + else: + raise TypeError( + f"Could not infer backend for options {rpc_backend_options}" + ) + # Ignore type error because mypy doesn't handle dynamically generated type objects (#4865) + if backend != BackendType.TENSORPIPE: # type: ignore[attr-defined] + logger.warning( + "RPC was initialized with no explicit backend but with options " # type: ignore[attr-defined] + "corresponding to %(backend)s, hence that backend will be used " + "instead of the default BackendType.TENSORPIPE. To silence this " + "warning pass `backend=%(backend)s` explicitly.", + {"backend": backend}, + ) + + if backend is None: + backend = BackendType.TENSORPIPE # type: ignore[attr-defined] + + if rpc_backend_options is None: + # default construct a set of RPC backend options. + rpc_backend_options = backend_registry.construct_rpc_backend_options( + backend + ) + + # Create store, performs rendezvous for static RPC group. + if not world_size: + # If world_size is not set in construction and also not set in environment variables + # The store will be created for the dynamic group setting + store = dist._create_store_from_options(rpc_backend_options, rank) + else: + # This rendezvous state sometimes is destroyed before all processes + # finishing handshaking. To avoid that issue, we make it global to + # keep it alive. + global rendezvous_iterator + rendezvous_iterator = dist.rendezvous( + rpc_backend_options.init_method, rank=rank, world_size=world_size + ) + store, _, _ = next(rendezvous_iterator) + # Use same timeout as RPC. + store.set_timeout(timedelta(seconds=rpc_backend_options.rpc_timeout)) + + # Use a PrefixStore to distinguish multiple invocations. + with _init_counter_lock: + global _init_counter + store = dist.PrefixStore(str(f"rpc_prefix_{_init_counter}"), store) + _init_counter += 1 + + # Initialize autograd before RPC since _init_rpc_backend guarantees all + # processes sync via the store. If we initialize autograd after RPC, + # there could be a race where some nodes might have initialized autograd + # and others might not have. As a result, a node calling + # torch.distributed.autograd.backward() would run into errors since + # other nodes might not have been initialized. + dist_autograd._init(rank) + + _set_profiler_node_id(rank) + # Initialize RPC. + _init_rpc_backend(backend, store, name, rank, world_size, rpc_backend_options) + + def _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options): + type_mapping = { + backend: backend_registry.BackendType, + store: dist.Store, + name: str, + rank: numbers.Integral, + # world_size can be None for a dynamic group + world_size: (numbers.Integral, type(None)), + rpc_backend_options: RpcBackendOptions, + } + for arg, arg_type in type_mapping.items(): + if not isinstance(arg, arg_type): # type: ignore[arg-type] + raise RuntimeError( + f"Argument {arg} must be of type {arg_type} but got type {type(arg)}" + ) + + def _init_rpc_backend( + backend=BackendType.TENSORPIPE, # type: ignore[attr-defined] + store=None, + name=None, + rank=-1, + world_size=None, + rpc_backend_options=None, + ): + _validate_rpc_args(backend, store, name, rank, world_size, rpc_backend_options) + + if _is_current_rpc_agent_set(): + raise RuntimeError("RPC is already initialized") + + # Initialize RPC. + rpc_agent = backend_registry.init_backend( + backend, + store=store, + name=name, + rank=rank, + world_size=world_size, + rpc_backend_options=rpc_backend_options, + ) + + api._init_rpc_states(rpc_agent) + + @api._require_initialized + def _get_debug_info(): + info = _rref_context_get_debug_info() + info.update(api._get_current_rpc_agent().get_debug_info()) + info.update(dist_autograd._get_debug_info()) + return info diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5dface2d82c5412299a8eee712c150f77aab2638 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20d900f4a4bceb4ad4813173972ec238962fa93f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65c62becf796880b2339c07b84893a1c8c39e1b2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74e5045dd358d54cfe45f43d09f23eb326c1f9ca Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/backend_registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..060888e412c62b9b2fbbcc9b66e867b3a2231e6c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..085a5da3ea2581761cdb9820abd958e919f83aaa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/functions.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30a34da45865d7a7d831ea19c5343f82aabd010c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/internal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/options.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/options.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2f106434d2a35a7e966b78a778165c88711ef5c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/options.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..066543aca00c48dd25f75d3274bd7dde543c3069 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/rref_proxy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca2e1065f1e72ea6f504278065672294b23e7091 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/__pycache__/server_process_global_profiler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__init__.py b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4141946004ff7a5c4e7baefc6c57be96a86e4b36 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__init__.py @@ -0,0 +1,18 @@ +import torch + + +def is_available() -> bool: + return hasattr(torch._C, "_faulty_agent_init") + + +if is_available() and not torch._C._faulty_agent_init(): + raise RuntimeError("Failed to initialize torch.distributed.rpc._testing") + +if is_available(): + # Registers FAULTY_TENSORPIPE RPC backend. + from torch._C._distributed_rpc_testing import ( + FaultyTensorPipeAgent, + FaultyTensorPipeRpcBackendOptions, + ) + + from . import faulty_agent_backend_registry diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8497ca2492edc3251215d18c02f5780b1f15c2e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ba29ccdd97040c9dc9f9ba2b5f50b43d9dab54c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/__pycache__/faulty_agent_backend_registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..bc6ac4d844b22d9d758e9e58a257764ce97e47e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/_testing/faulty_agent_backend_registry.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python3 +# mypy: allow-untyped-defs + +import torch.distributed as dist +import torch.distributed.rpc as rpc + + +def _faulty_tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads, + messages_to_fail, + messages_to_delay, + num_fail_sends, + **kwargs, +): + from . import FaultyTensorPipeRpcBackendOptions + + return FaultyTensorPipeRpcBackendOptions( + num_worker_threads=num_worker_threads, + rpc_timeout=rpc_timeout, + init_method=init_method, + messages_to_fail=messages_to_fail, + messages_to_delay=messages_to_delay, + num_fail_sends=num_fail_sends, + ) + + +def _faulty_tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from torch.distributed.rpc import api + + from . import FaultyTensorPipeAgent, FaultyTensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, FaultyTensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `FaultyTensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + agent = FaultyTensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, # reverse_device_map + [], # devices + ) + api._init_rpc_states(agent) + + return agent + + +rpc.backend_registry.register_backend( + "FAULTY_TENSORPIPE", + _faulty_tensorpipe_construct_rpc_backend_options_handler, + _faulty_tensorpipe_init_backend_handler, +) diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/_utils.py b/phivenv/Lib/site-packages/torch/distributed/rpc/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..809340fa17b59071ccfa37618d1bb748ac9d33ec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/_utils.py @@ -0,0 +1,47 @@ +# mypy: allow-untyped-defs +import logging +from contextlib import contextmanager +from typing import cast + + +logger = logging.getLogger(__name__) + + +@contextmanager +def _group_membership_management(store, name, is_join): + token_key = "RpcGroupManagementToken" + join_or_leave = "join" if is_join else "leave" + my_token = f"Token_for_{name}_{join_or_leave}" + while True: + # Retrieve token from store to signal start of rank join/leave critical section + returned = store.compare_set(token_key, "", my_token).decode() + if returned == my_token: + # Yield to the function this context manager wraps + yield + # Finished, now exit and release token + # Update from store to signal end of rank join/leave critical section + store.set(token_key, "") + # Other will wait for this token to be set before they execute + store.set(my_token, "Done") + break + else: + # Store will wait for the token to be released + try: + store.wait([returned]) + except RuntimeError: + logger.error( + "Group membership token %s timed out waiting for %s to be released.", + my_token, + returned, + ) + raise + + +def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join): + from . import api, TensorPipeAgent + + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + ret = agent._update_group_membership( + worker_info, my_devices, reverse_device_map, is_join + ) + return ret diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/api.py b/phivenv/Lib/site-packages/torch/distributed/rpc/api.py new file mode 100644 index 0000000000000000000000000000000000000000..92a3b587350a5decc6b16a1118b13e9809e1dfa6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/api.py @@ -0,0 +1,965 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs + +import collections +import contextlib +import functools +import inspect +import logging +import threading +from typing import Any, Generic, TYPE_CHECKING, TypeVar + +import torch +from torch._C._distributed_rpc import ( + _cleanup_python_rpc_handler, + _delete_all_user_and_unforked_owner_rrefs, + _destroy_rref_context, + _get_current_rpc_agent, + _invoke_remote_builtin, + _invoke_remote_python_udf, + _invoke_remote_torchscript, + _invoke_rpc_builtin, + _invoke_rpc_python_udf, + _invoke_rpc_torchscript, + _is_current_rpc_agent_set, + _reset_current_rpc_agent, + _set_and_start_rpc_agent, + get_rpc_timeout, + PyRRef, + RemoteProfilerManager, + WorkerInfo, +) +from torch.futures import Future + +from ._utils import _group_membership_management, _update_group_membership +from .constants import DEFAULT_SHUTDOWN_TIMEOUT, UNSET_RPC_TIMEOUT +from .internal import ( + _build_rpc_profiling_key, + _internal_rpc_pickler, + PythonUDF, + RPCExecMode, +) + + +__all__ = [ + "shutdown", + "get_worker_info", + "remote", + "rpc_sync", + "rpc_async", + "RRef", + "AllGatherStates", + "method_factory", + "new_method", +] + + +logger = logging.getLogger(__name__) + +# NB: Ignoring RRef leaks during shutdown. Without this, applications have to +# make sure there is no references to any RRef in the application code and +# Python GC has done its job to delete those RRefs. This is could result in bad +# debugging experiences especially when for large applications. Therefore, by +# default, we are going to ignore RRef leaks during shutdown. This is usually +# fine as shutdown means applications have done training and no longer care +# about states. +# +# To enable RRef leak checking, set this _ignore_rref_leak to False +_ignore_rref_leak = True +_default_pickler = _internal_rpc_pickler + + +@contextlib.contextmanager +def _use_rpc_pickler(rpc_pickler): + r""" + rpc_pickler: (.internal._InternalRPCPickler) Overrides the default RPC pickler + """ + global _default_pickler + _default_pickler = rpc_pickler + try: + yield + finally: + _default_pickler = _internal_rpc_pickler + + +def _require_initialized(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + if not _is_current_rpc_agent_set(): + raise RuntimeError( + "RPC has not been initialized. Call " + "torch.distributed.rpc.init_rpc first." + ) + return func(*args, **kwargs) + + return wrapper + + +class AllGatherStates: + def __init__(self): + # Each `gathered_objects` is an empty dict at beginning. + # The leader worker is elected as the first worker in a sorted worker + # name list. Whenever there is a worker entering `_all_gather()`, it + # runs `_gather_to_leader()` on the leader to add its own name and + # data obj to this dict. The leader also adds itself's name to the dict + # on calling `_all_gather()`. + # Once `set(gathered_objects.keys()) == _ALL_WORKER_NAMES`, the leader + # will broadcast the gathered dict to all follower workers and set their + # `gathered_objects` field and the `proceed_signal` field. + self.gathered_objects = {} + # All workers wait on this signal until it receives all gathered + # objects. + self.proceed_signal = threading.Event() + + +# States used by `def _all_gather()`. +# `_ALL_WORKER_NAMES` is initialized on initializing RPC layer. +_ALL_WORKER_NAMES: set[Any] = set() +_all_gather_dict_lock = threading.RLock() +_all_gather_sequence_id: dict[str, int] = {} +_all_gather_sequence_id_to_states: collections.defaultdict = collections.defaultdict( + AllGatherStates +) + + +def _init_rpc_states(agent): + worker_infos = agent.get_worker_infos() + global _ALL_WORKER_NAMES + _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos} + + # NB: backend implementation might have already set the rpc_agent. + if not _is_current_rpc_agent_set(): + _set_and_start_rpc_agent(agent) + + +def _gather_to_leader(sequence_id, worker_name, obj, worker_names=None): + with _all_gather_dict_lock: + if not worker_names: + worker_names = _ALL_WORKER_NAMES + assert worker_name in worker_names, ( + f"{worker_name} is not expected by leader." + ) + states = _all_gather_sequence_id_to_states[sequence_id] + assert worker_name not in states.gathered_objects, ( + f"{worker_name} reported intent sequence id {sequence_id} twice. " + ) + states.gathered_objects[worker_name] = obj + if worker_names == set(states.gathered_objects.keys()): + states.proceed_signal.set() + + +def _broadcast_to_followers(sequence_id, objects_map): + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + assert not states.proceed_signal.is_set(), ( + f"Termination signal sequence id {sequence_id} got set twice." + ) + states.gathered_objects = objects_map + states.proceed_signal.set() + + +_thread_local_var = threading.local() + + +@contextlib.contextmanager +def _wait_all(): + r""" + A context manager that collects all futures returned by ``rpc_async`` and + waits them on the context manager's exit; relieving the user of needing + to explicitly call wait. + + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> with rpc._wait_all(): + >>> fut_1 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> fut_2 = rpc.rpc_async(dst, torch.add, (torch.ones(2, 2), 1)) + >>> #fut_1 and fut_2 are waited on + """ + _thread_local_var.future_list = [] + try: + yield + finally: + try: + torch.futures.wait_all(_thread_local_var.future_list) + finally: + del _thread_local_var.future_list + + +@_require_initialized +def _all_gather(obj, worker_names=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + This is similar to torch.distributed.all_gather(), but is using RPC. It + picks the worker with the smallest name (alphabetic order) as the leader. + Then all followers send their data ``obj`` to the leader. After the leader + has received all, it will broadcast the results back to all followers. This + function blocks until all workers have received the gathered results. + """ + if not worker_names: + assert _ALL_WORKER_NAMES is not None, ( + "`_ALL_WORKER_NAMES` is not initialized for `def _all_gather`." + ) + worker_names = _ALL_WORKER_NAMES + leader_name = min(worker_names) + + self_name = _get_current_rpc_agent().get_worker_info().name + + with _all_gather_dict_lock: + concat_names = "".join(sorted(worker_names)) + sequence_num = _all_gather_sequence_id.get(concat_names, 0) + _all_gather_sequence_id[concat_names] = sequence_num + 1 + sequence_id = concat_names + str(sequence_num) + + is_leader = leader_name == self_name + + if timeout == UNSET_RPC_TIMEOUT: + # Timeout is specified by agent for RPC calls + rpc_timeout = get_rpc_timeout() + # No timeout for signal + signal_timeout = None + elif timeout == DEFAULT_SHUTDOWN_TIMEOUT: + # No timeout for RPC + rpc_timeout = timeout + # No timeout for signal + signal_timeout = None + else: + # Signal and RPC timeout use the same timeout + signal_timeout = rpc_timeout = timeout + + # Phase 1: Followers send it's object to the leader + if is_leader: + _gather_to_leader(sequence_id, self_name, obj, worker_names) + else: + rpc_sync( + leader_name, + _gather_to_leader, + args=(sequence_id, self_name, obj, worker_names), + timeout=rpc_timeout, + ) + + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states[sequence_id] + + # Timeout is either set by function parameter or None (which is indefinite) + states.proceed_signal.wait(timeout=signal_timeout) + + # Phase 2: Leader broadcast gathered results to all followers + # Leader's signal is the first to be unblocked, after receiving all + # followers' data objects. + if is_leader: + worker_name_to_response_future_dict = {} + for follower_name in worker_names - {leader_name}: + fut = rpc_async( + follower_name, + _broadcast_to_followers, + args=(sequence_id, states.gathered_objects), + timeout=rpc_timeout, + ) + worker_name_to_response_future_dict[follower_name] = fut + + errors = [] + for follower_name, fut in worker_name_to_response_future_dict.items(): + try: + fut.wait() + except RuntimeError as ex: + errors.append((follower_name, ex)) + + if errors: + raise RuntimeError( + f"Followers {[e[0] for e in errors]} timed out in _all_gather " + f"after {rpc_timeout:.2f} seconds. The first exception is {errors[0][1]}" + ) + + # Clean up for the states using the sequence_id + with _all_gather_dict_lock: + states = _all_gather_sequence_id_to_states.pop(sequence_id) + return states.gathered_objects + + +@_require_initialized +def _barrier(worker_names): + r""" + Synchronizes local and remote RPC processes. + + This will block until all local and remote RPC processes specified under worker_names + reach this method to wait for all outstanding work to complete. + + Args: + worker_names (List[str]): The set of workers to synchronize. + + """ + try: + _all_gather(None, set(worker_names)) + except RuntimeError as ex: + logger.error("Failed to complete barrier, got error %s", ex) + + +@_require_initialized +def _wait_all_workers(timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Block until all local and remote RPC processes reach this method and wait + for all outstanding work to complete. Every RPC process must call this + method before exit to perform a graceful shutdown. This should be used to + terminate the RPC framework, and there is no guarantee that the RPC + framework will work after this method returns. + """ + try: + _all_gather(None, timeout=timeout) + except RuntimeError as ex: + logger.error( + "Failed to respond to 'Shutdown Proceed' in time, got error %s", ex + ) + raise ex + + +@_require_initialized +def shutdown(graceful=True, timeout=DEFAULT_SHUTDOWN_TIMEOUT): + r""" + Perform a shutdown of the RPC agent, and then destroy the RPC agent. This + stops the local agent from accepting outstanding requests, and shuts + down the RPC framework by terminating all RPC threads. If ``graceful=True``, + this will block until all local and remote RPC processes reach this method + and wait for all outstanding work to complete. Otherwise, if + ``graceful=False``, this is a local shutdown, and it does not wait for other + RPC processes to reach this method. + + .. warning:: + For :class:`~torch.futures.Future` objects returned by + :meth:`~torch.distributed.rpc.rpc_async`, ``future.wait()`` should not + be called after ``shutdown()``. + + Args: + graceful (bool): Whether to do a graceful shutdown or not. If True, + this will 1) wait until there is no pending system + messages for ``UserRRefs`` and delete them; 2) block + until all local and remote RPC processes have reached + this method and wait for all outstanding work to + complete. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> # do some work + >>> result = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(1), 1)) + >>> # ready to shutdown + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + if graceful: + try: + agent = _get_current_rpc_agent() + from torch._C._distributed_rpc import TensorPipeAgent + + if not isinstance(agent, TensorPipeAgent) or agent.is_static_group: + _wait_all_workers(timeout) + _delete_all_user_and_unforked_owner_rrefs() + agent.join(shutdown=True, timeout=timeout) + else: + # This is a dynamic group so we need to grab the token for the operation + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + with _group_membership_management(agent.store, my_name, False): + all_worker_infos = agent.get_worker_infos() + for worker in all_worker_infos: + if worker.name != my_name: + rpc_sync( + worker.name, + _update_group_membership, + args=(my_worker_info, [], {}, False), + ) + agent.join(shutdown=True, timeout=timeout) + finally: + # In case of errors, continue to complete the local shutdown. + _finalize_shutdown() + else: + _finalize_shutdown() + + +def _finalize_shutdown(): + try: + # This raises a `TORCH_CHECK()` exception on RRef leak detected. + _destroy_rref_context(_ignore_rref_leak) + finally: + _get_current_rpc_agent().shutdown() + # clean up python rpc handler in shutdown(), see comments in + # PythonRpcHandler::cleanup(), call it in python API because the + # cleanup() function has python dependency, it assumes python + # interpreter exists. + # No matter if RRef leak exception is raised, this clean-up code + # must run to avoid destruction segfault in Python 3.5. + # + # future.wait() should not be called after shutdown(). + # pythonRpcHandler is cleaned up in shutdown(), after + # shutdown(), python objects returned from rpc python call can not be + # resolved. + _cleanup_python_rpc_handler() + _reset_current_rpc_agent() + + +@_require_initialized +def get_worker_info(worker_name=None): + r""" + Get :class:`~torch.distributed.rpc.WorkerInfo` of a given worker name. + Use this :class:`~torch.distributed.rpc.WorkerInfo` to avoid passing an + expensive string on every invocation. + + Args: + worker_name (str): the string name of a worker. If ``None``, return the + the id of the current worker. (default ``None``) + + Returns: + :class:`~torch.distributed.rpc.WorkerInfo` instance for the given + ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the + current worker if ``worker_name`` is ``None``. + """ + if worker_name is not None: + return _get_current_rpc_agent().get_worker_info(worker_name) + else: + return _get_current_rpc_agent().get_worker_info() + + +def _to_worker_info(to): + if isinstance(to, WorkerInfo): + return to + elif isinstance(to, (str, int)): + return get_worker_info(to) + else: + raise ValueError(f"Cannot get WorkerInfo from name {to}") + + +def _rref_typeof_on_owner(rref, blocking: bool = True): + rref_type = type(rref.local_value()) + if blocking: + return rref_type + else: + # Wrap result into a completed Future. This is so that if blocking=`False` + # is specified, we return a future regardless of if this call is on user + # or owner. + future = Future[type]() + future.set_result(rref_type) + return future + + +def _rref_typeof_on_user( + rref, timeout: float = UNSET_RPC_TIMEOUT, blocking: bool = True +): + fut = rpc_async(rref.owner(), _rref_typeof_on_owner, args=(rref,), timeout=timeout) + if blocking: + return fut.wait() + else: + return fut + + +T = TypeVar("T") +GenericWithOneTypeVar = Generic[T] + + +if TYPE_CHECKING: + + class RRef(PyRRef[T], Generic[T]): + pass + +else: + try: + # Combine the implementation class and the type class. + class RRef(PyRRef, Generic[T]): + pass + + except TypeError: + # TypeError: metaclass conflict: the metaclass of a derived class + # must be a (non-strict) subclass of the metaclasses of all its bases + # Mypy doesn't understand __class__ (mypy bug #4177) + class RRefMeta(PyRRef.__class__, GenericWithOneTypeVar.__class__): # type: ignore[name-defined, misc, valid-type] + pass + + # Combine the implementation class and the type class. + # Types for classes expecting a certain generic parameter (mypy bug #7791) + class RRef(PyRRef, GenericWithOneTypeVar, metaclass=RRefMeta): # type: ignore[misc, no-redef, valid-type] + pass + + +# Install docstrings from `PyRRef` to `RRef`. +# +# This is for the fact that pybind11 generates the parameter +# `self` as type `rpc.PyRRef`, so a `:inherited-members:` +# under `.. autoclass:: RRef` does not work. +# we have to do the following process to replace `rpc.PyRRef` with `rpc.RRef`. +# +def method_factory(method_name, docstring): + def method(self, *args, **kwargs): + return getattr(super(RRef, self), method_name)(*args, **kwargs) + + if method.__doc__: + method.__doc__ = docstring + return method + + +for method_name, method in inspect.getmembers(PyRRef): + # Ignore magic methods, except "__str__". + if method_name.startswith("_") and method_name != "__str__": + continue + + # Get pybind11 generated docstring. + # It's like, + """ + to_here(self: torch.distributed.rpc.PyRRef, timeout: float=-1.0) -> object + + Blocking call that copies the value of the RRef from the owner + to the local node and returns it. If the current node is the + owner, returns a reference to the local value. + """ + docstring = getattr(method, "__doc__", None) + assert docstring is not None, "RRef user-facing methods should all have docstrings." + + # Do surgery on pybind11 generated docstrings. + docstring = docstring.replace( + "torch.distributed.rpc.PyRRef", "torch.distributed.rpc.RRef" + ) + + # Attach user-facing RRef method with modified docstring. + new_method = method_factory(method_name, docstring) + setattr(RRef, method_name, new_method) + + +@_require_initialized +def remote(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a remote call to run ``func`` on worker ``to`` and return an + :class:`~torch.distributed.rpc.RRef` to the result value immediately. + Worker ``to`` will be the owner of the returned + :class:`~torch.distributed.rpc.RRef`, and the worker calling ``remote`` is + a user. The owner manages the global reference count of its + :class:`~torch.distributed.rpc.RRef`, and the owner + :class:`~torch.distributed.rpc.RRef` is only destructed when globally there + are no living references to it. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + + timeout (float, optional): timeout in seconds for this remote call. If the + creation of this + :class:`~torch.distributed.rpc.RRef` on worker + ``to`` is not successfully processed on this + worker within this timeout, then the next time + there is an attempt to use the RRef (such as + ``to_here()``), a timeout will be raised + indicating this failure. A value of 0 indicates + an infinite timeout, i.e. a timeout error will + never be raised. If not provided, the default + value set during initialization or with + ``_set_rpc_timeout`` is used. + + Returns: + A user :class:`~torch.distributed.rpc.RRef` instance to the result + value. Use the blocking API :meth:`torch.distributed.rpc.RRef.to_here` + to retrieve the result value locally. + + .. warning :: + The ``remote`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned RRef is + confirmed by the owner, which can be checked using the + :meth:`torch.distributed.rpc.RRef.confirmed_by_owner` API. + + .. warning :: + Errors such as timeouts for the ``remote`` API are handled on a + best-effort basis. This means that when remote calls initiated by + ``remote`` fail, such as with a timeout error, we take a best-effort + approach to error handling. This means that errors are handled and set + on the resulting RRef on an asynchronous basis. If the RRef has not been + used by the application before this handling (such as ``to_here`` or + fork call), then future uses of the ``RRef`` will appropriately raise + errors. However, it is possible that the user application will use the + ``RRef`` before the errors are handled. In this case, errors may not be + raised as they have not yet been handled. + + Example:: + + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref1 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rref2 = rpc.remote("worker1", torch.add, args=(torch.ones(2), 1)) + >>> x = rref1.to_here() + rref2.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> rref = rpc.remote("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> rref.to_here() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + torch._C._log_api_usage_once("torch.distributed.rpc_remote") + qualified_name = torch.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, RPCExecMode.REMOTE, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, torch.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + rref = _invoke_remote_builtin( + dst_worker_info, qualified_name, timeout, *args, **kwargs + ) + elif isinstance(func, torch.jit.ScriptFunction): + rref = _invoke_remote_torchscript( + dst_worker_info.name, + torch._jit_internal._qualified_name(func), + timeout, + is_async_exec, + *args, + **kwargs, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + rref = _invoke_remote_python_udf( + dst_worker_info, pickled_python_udf, tensors, timeout, is_async_exec + ) + # attach profiling information + if should_profile: + assert torch.autograd._profiler_enabled() + assert rf is not None + fut = rf._call_end_callbacks_on_future(rref._get_future()) + rref._set_profiling_future(fut) + + return rref + + +def _invoke_rpc( + to, func, rpc_type, args=None, kwargs=None, rpc_timeout: float = UNSET_RPC_TIMEOUT +): + if not callable(func): + raise TypeError("function should be callable.") + + qualified_name = torch.jit._builtins._find_builtin(func) + dst_worker_info = _to_worker_info(to) + + should_profile = _get_should_profile() + + ctx_manager = _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info + ) + + with ctx_manager as rf: + args = args if args else () + kwargs = kwargs if kwargs else {} + + is_async_exec = hasattr(func, "_wrapped_async_rpc_function") + + if is_async_exec: + wrapped = func._wrapped_async_rpc_function + if isinstance(wrapped, torch.jit.ScriptFunction): + func = wrapped + + if qualified_name is not None: + fut = _invoke_rpc_builtin( + dst_worker_info, qualified_name, rpc_timeout, *args, **kwargs + ) + elif isinstance(func, torch.jit.ScriptFunction): + fut = _invoke_rpc_torchscript( + dst_worker_info.name, + torch._jit_internal._qualified_name(func), + args, + kwargs, + rpc_timeout, + is_async_exec, + ) + else: + (pickled_python_udf, tensors) = _default_pickler.serialize( + PythonUDF(func, args, kwargs) + ) + fut = _invoke_rpc_python_udf( + dst_worker_info, pickled_python_udf, tensors, rpc_timeout, is_async_exec + ) + if should_profile: + assert torch.autograd._profiler_enabled() + assert rf is not None + # Schedule profiling callbacks to run when the future completes. + # This returns a future that is completed when the original future + # completes and the profiling callbacks have been completed as well, + # to guarantee that fut.wait() completes the profiling. This new + # future will contain the same value as the original future. + fut = rf._call_end_callbacks_on_future(fut) + return fut + + +@_require_initialized +def rpc_sync(to, func, args=None, kwargs=None, timeout: float = UNSET_RPC_TIMEOUT): + r""" + Make a blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + Returns: + Returns the result of running ``func`` with ``args`` and ``kwargs``. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", torch.add, args=(torch.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + """ + torch._C._log_api_usage_once("torch.distributed.rpc_sync") + fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs, timeout) + return fut.wait() + + +@_require_initialized +def rpc_async(to, func, args=None, kwargs=None, timeout=UNSET_RPC_TIMEOUT): + r""" + Make a non-blocking RPC call to run function ``func`` on worker ``to``. RPC + messages are sent and received in parallel to execution of Python code. This + method is thread-safe. This method will immediately return a + :class:`~torch.futures.Future` that can be awaited on. + + Args: + to (str or WorkerInfo or int): name/rank/``WorkerInfo`` of the destination worker. + func (Callable): a callable function, such as Python callables, builtin + operators (e.g. :meth:`~torch.add`) and annotated + TorchScript functions. + args (tuple): the argument tuple for the ``func`` invocation. + kwargs (dict): is a dictionary of keyword arguments for the ``func`` + invocation. + timeout (float, optional): timeout in seconds to use for this RPC. If + the RPC does not complete in this amount of + time, an exception indicating it has + timed out will be raised. A value of 0 + indicates an infinite timeout, i.e. a timeout + error will never be raised. If not provided, + the default value set during initialization + or with ``_set_rpc_timeout`` is used. + + + Returns: + Returns a :class:`~torch.futures.Future` object that can be waited + on. When completed, the return value of ``func`` on ``args`` and + ``kwargs`` can be retrieved from the :class:`~torch.futures.Future` + object. + + .. warning :: + Using GPU tensors as arguments or return values of ``func`` is not + supported since we don't support sending GPU tensors over the wire. You + need to explicitly copy GPU tensors to CPU before using them as + arguments or return values of ``func``. + + .. warning :: + The ``rpc_async`` API does not copy storages of argument tensors until + sending them over the wire, which could be done by a different thread + depending on the RPC backend type. The caller should make sure that the + contents of those tensors stay intact until the returned + :class:`~torch.futures.Future` completes. + + Example:: + Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly + on both workers. Refer to :meth:`~torch.distributed.init_process_group` + API for more details. For example, + + export MASTER_ADDR=localhost + export MASTER_PORT=5678 + + Then run the following code in two different processes: + + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut1 = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3)) + >>> fut2 = rpc.rpc_async("worker1", min, args=(1, 2)) + >>> result = fut1.wait() + fut2.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + + Below is an example of running a TorchScript function using RPC. + + >>> # On both workers: + >>> @torch.jit.script + >>> def my_script_add(tensor: torch.Tensor, scalar: int): + >>> return torch.add(tensor, scalar) + + >>> # On worker 0: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3)) + >>> ret = fut.wait() + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> rpc.shutdown() + """ + torch._C._log_api_usage_once("torch.distributed.rpc_async") + fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs, timeout) + if hasattr(_thread_local_var, "future_list"): + _thread_local_var.future_list.append(fut) + return fut + + +def _get_should_profile(): + # Legacy profiler should be enabled. RPC profiling is not supported with + # Kineto profiler. + ActiveProfilerType = torch._C._profiler.ActiveProfilerType + return ( + torch.autograd._profiler_enabled() + and torch._C._autograd._profiler_type() == ActiveProfilerType.LEGACY # type: ignore[attr-defined] + ) + + +def _enable_rpc_profiler( + should_profile, qualified_name, func, rpc_type, dst_worker_info +): + ctx_manager = contextlib.nullcontext() + + if should_profile: + # Create appropriate string representation based on type of func + # (builtin, script, python) + if qualified_name is None: + func_name = ( + torch._jit_internal._qualified_name(func) + if isinstance(func, torch.jit.ScriptFunction) + else func.__qualname__ + ) + else: + func_name = qualified_name + # Build RPC profiling key. + rpc_profiling_key = _build_rpc_profiling_key( + rpc_type, + func_name, + get_worker_info().name, + dst_worker_info.name, + ) + RemoteProfilerManager.set_current_profiling_key(rpc_profiling_key) + # Mypy doesn't support re-def of a variable not in the same block (#1174) + ctx_manager = torch.autograd.profiler.record_function(rpc_profiling_key) # type: ignore[assignment] + + return ctx_manager diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/backend_registry.py b/phivenv/Lib/site-packages/torch/distributed/rpc/backend_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..0e420e55cfd2c4cd96e4540d93c85929c055ac1d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/backend_registry.py @@ -0,0 +1,430 @@ +# mypy: allow-untyped-defs + + +import collections +import enum +from typing import cast + +import torch +import torch.distributed as dist + +from . import api, constants as rpc_constants +from ._utils import _group_membership_management, _update_group_membership + + +__all__ = [ + "backend_registered", + "register_backend", + "construct_rpc_backend_options", + "init_backend", + "BackendValue", + "BackendType", +] + +BackendValue = collections.namedtuple( + "BackendValue", ["construct_rpc_backend_options_handler", "init_backend_handler"] +) + + +def _backend_type_repr(self): + return "BackendType." + self.name + + +_backend_type_doc = """ + An enum class of available backends. + + PyTorch ships with a builtin ``BackendType.TENSORPIPE`` backend. + Additional ones can be registered using the + :func:`~torch.distributed.rpc.backend_registry.register_backend` function. +""" + +# Create an enum type, `BackendType`, with empty members. +# Can't handle Function Enum API (mypy bug #9079) +BackendType = enum.Enum(value="BackendType", names={}) # type: ignore[misc] +# Unable to assign a function a method (mypy bug #2427) +BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + +if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + + +def backend_registered(backend_name): + """ + Checks if backend_name is registered as an RPC backend. + + Args: + backend_name (str): string to identify the RPC backend. + Returns: + True if the backend has been registered with ``register_backend``, else + False. + """ + return backend_name in BackendType.__members__.keys() + + +def register_backend( + backend_name, construct_rpc_backend_options_handler, init_backend_handler +): + """Registers a new RPC backend. + + Args: + backend_name (str): backend string to identify the handler. + construct_rpc_backend_options_handler (function): + Handler that is invoked when + rpc_backend.construct_rpc_backend_options(**dict) is called. + init_backend_handler (function): Handler that is invoked when the + `_init_rpc_backend()` function is called with a backend. + This returns the agent. + """ + global BackendType + if backend_registered(backend_name): + raise RuntimeError(f"RPC backend {backend_name}: already registered") + # Create a new enum type, `BackendType`, with extended members. + existing_enum_dict = {member.name: member.value for member in BackendType} + extended_enum_dict = dict( + { + backend_name: BackendValue( + construct_rpc_backend_options_handler=construct_rpc_backend_options_handler, + init_backend_handler=init_backend_handler, + ) + }, + **existing_enum_dict, + ) + # Can't handle Function Enum API (mypy bug #9079) + BackendType = enum.Enum(value="BackendType", names=extended_enum_dict) # type: ignore[misc] + # Unable to assign a function a method (mypy bug #2427) + BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] + if BackendType.__doc__: + BackendType.__doc__ = _backend_type_doc + return BackendType[backend_name] + + +def construct_rpc_backend_options( + backend, + rpc_timeout=rpc_constants.DEFAULT_RPC_TIMEOUT_SEC, + init_method=rpc_constants.DEFAULT_INIT_METHOD, + **kwargs, +): + return backend.value.construct_rpc_backend_options_handler( + rpc_timeout, init_method, **kwargs + ) + + +def init_backend(backend, *args, **kwargs): + return backend.value.init_backend_handler(*args, **kwargs) + + +def _init_process_group(store, rank, world_size): + # Initialize ProcessGroup. + process_group_timeout = rpc_constants.DEFAULT_PROCESS_GROUP_TIMEOUT + + # We're using a bunch of private APIs here since `new_group` requires the + # default group to be initialized. + group = dist.ProcessGroupGloo(store, rank, world_size, process_group_timeout) + + assert group is not None, "Failed to initialize default ProcessGroup." + + if (rank != -1) and (rank != group.rank()): + raise RuntimeError(f"rank argument {rank} doesn't match pg rank {group.rank()}") + if (world_size != -1) and (world_size != group.size()): + raise RuntimeError( + f"world_size argument {world_size} doesn't match pg size {group.size()}" + ) + return group + + +def _tensorpipe_construct_rpc_backend_options_handler( + rpc_timeout, + init_method, + num_worker_threads=rpc_constants.DEFAULT_NUM_WORKER_THREADS, + _transports=None, + _channels=None, + **kwargs, +): + from . import TensorPipeRpcBackendOptions + + return TensorPipeRpcBackendOptions( + rpc_timeout=rpc_timeout, + init_method=init_method, + num_worker_threads=num_worker_threads, + _transports=_transports, + _channels=_channels, + ) + + +def _tensorpipe_validate_devices(devices, device_count): + return all( + d.type == "cpu" or (d.type == "cuda" and 0 <= d.index < device_count) + for d in devices + ) + + +# detect if any worker has invalid device_map configurations, and return +# reverse device maps +def _tensorpipe_exchange_and_check_all_device_maps( + my_name, my_device_count, my_device_maps, my_devices, group +): + gathered: list[ + tuple[str, int, dict[str, dict[torch.device, torch.device]], list[torch.device]] + ] = [("", 0, {}, []) for _ in range(group.size())] + dist.all_gather_object( + gathered, (my_name, my_device_count, my_device_maps, my_devices), group + ) + all_names = [name for name, _, _, _ in gathered] + all_device_counts = {name: count for name, count, _, _ in gathered} + all_device_maps = {name: map_ for name, _, map_, _ in gathered} + all_devices = {name: devices for name, _, _, devices in gathered} + + _validate_device_maps(all_names, all_device_counts, all_device_maps, all_devices) + + # passed all checked, construct reverse mapping and get list of devices handled by this agent + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + my_devices = _create_device_list(my_devices, my_device_maps, reverse_device_maps) + return reverse_device_maps, my_devices + + +def _validate_device_maps( + all_names, all_device_counts, all_device_maps, all_devices, is_static_group=True +): + for node in all_names: + devices = all_devices[node] + if len(set(devices)) != len(devices): + raise ValueError(f"Node {node} has duplicated devices\ndevices = {devices}") + if not _tensorpipe_validate_devices(devices, all_device_counts[node]): + raise ValueError( + f"Node {node} has devices with invalid indices\n" + f"devices = {devices}\n" + f"device count = {all_device_counts[node]}" + ) + + for source_node in all_names: + # For dynamic group (non-static) do not check the target node name since it may not have joined yet + if is_static_group and not set(all_device_maps[source_node].keys()).issubset( + all_names + ): + raise ValueError( + f"Node {source_node} has invalid target node names in its device maps\n" + f"device maps = {all_device_maps[source_node].keys()}\n" + f"node names = {all_names}" + ) + for target_node, map_ in all_device_maps[source_node].items(): + if len(set(map_.values())) != len(map_): + raise ValueError( + f"Node {source_node} has duplicated target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}" + ) + if all_devices[source_node]: + if not set(map_.keys()).issubset(all_devices[source_node]): + raise ValueError( + f"Node {source_node} has unexpected source devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[source_node]}" + ) + elif not _tensorpipe_validate_devices( + map_.keys(), all_device_counts[source_node] + ): + raise ValueError( + f"Node {source_node} has source devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[source_node]}" + ) + if all_devices.get(target_node, []): + if not set(map_.values()).issubset(all_devices[target_node]): + raise ValueError( + f"Node {source_node} has unexpected target devices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"devices = {all_devices[target_node]}" + ) + elif target_node in all_device_counts and not _tensorpipe_validate_devices( + map_.values(), all_device_counts[target_node] + ): + raise ValueError( + f"Node {source_node} has target devices with invalid indices " + f"in its device map for {target_node}\n" + f"device map = {map_}\n" + f"device count = {all_device_counts[target_node]}" + ) + + +def _create_device_list(my_devices, my_device_maps, reverse_device_maps): + if not my_devices: + devices_set: set[torch.device] = set() + for map_ in my_device_maps.values(): + devices_set.update(map_.keys()) + for map_ in reverse_device_maps.values(): + devices_set.update(map_.keys()) + devices_set.discard(torch.device("cpu")) + my_devices = list(devices_set) + my_devices = sorted(my_devices, key=lambda d: d.index) + return my_devices + + +def _create_reverse_mapping(my_name, all_names, all_device_maps): + reverse_device_maps: dict[str, dict[torch.device, torch.device]] = {} + for node in all_names: + if my_name in all_device_maps[node]: + reverse_device_maps[node] = { + v: k for k, v in all_device_maps[node][my_name].items() + } + return reverse_device_maps + + +def _get_device_infos(): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, api._get_current_rpc_agent()) + opts = agent._get_backend_options() + device_count = torch.cuda.device_count() + if torch.cuda.is_available() and opts.devices: + torch.cuda.init() + return device_count, opts.device_maps, opts.devices + + +def _set_devices_and_reverse_device_map(agent): + from . import TensorPipeAgent + + agent = cast(TensorPipeAgent, agent) + # Group state is retrieved from local agent + # On initialization, tensorpipe agent retrieves information from all existing workers, so group state is valid + my_worker_info = agent.get_worker_info() + my_name = my_worker_info.name + all_worker_infos = agent.get_worker_infos() + # One round to get device_maps of all workers and construct reverse device maps + all_device_counts, all_device_maps, all_devices, all_names = {}, {}, {}, [] + for worker_info in all_worker_infos: + worker_name = worker_info.name + if worker_name != my_name: + # TODO: make async? + device_count, device_map, devices = api.rpc_sync( + worker_name, _get_device_infos + ) + else: + opts = agent._get_backend_options() + device_count, device_map, devices = ( + torch.cuda.device_count(), + opts.device_maps, + opts.devices, + ) + all_device_counts[worker_name] = device_count + all_device_maps[worker_name] = device_map + all_devices[worker_name] = devices + all_names.append(worker_name) + + _validate_device_maps( + all_names, + all_device_counts, + all_device_maps, + all_devices, + is_static_group=False, + ) + reverse_device_maps = _create_reverse_mapping(my_name, all_names, all_device_maps) + + # Perform RPC call to all workers, including itself, to include newly joined worker information and device maps + for worker_name in all_names: + # Set device list for each worker + all_devices[worker_name] = _create_device_list( + all_devices[worker_name], all_device_maps[worker_name], reverse_device_maps + ) + api.rpc_sync( + worker_name, + _update_group_membership, + args=(my_worker_info, all_devices[worker_name], reverse_device_maps, True), + ) + + +def _tensorpipe_init_backend_handler( + store, name, rank, world_size, rpc_backend_options +): + from . import TensorPipeAgent, TensorPipeRpcBackendOptions + + if not isinstance(store, dist.Store): + raise TypeError(f"`store` must be a c10d::Store. {store}") + + if not isinstance(rpc_backend_options, TensorPipeRpcBackendOptions): + raise TypeError( + f"`rpc_backend_options` must be a `TensorPipeRpcBackendOptions`. {rpc_backend_options}" + ) + + device_count = torch.cuda.device_count() + + is_static_group = True if world_size else False + # world_size is specified so this is a static group (ranks cannot join and leave) + if is_static_group: + # The agent's join method is required to behave like a barrier and perform + # collective operations, for which it relies on a process group, instead of + # re-implementing this on top of RPCs. + group = _init_process_group(store, rank, world_size) + + reverse_device_maps, devices = _tensorpipe_exchange_and_check_all_device_maps( + name, + device_count, + rpc_backend_options.device_maps, + rpc_backend_options.devices, + group, + ) + + if torch.cuda.is_available() and devices: + # It's necessary to initialize PyTorch CUDA states here (e.g., + # CUDACachingAllocator). If this is missing, we could hit errors like + # "allocator not initialized", because other processes might send + # CUDA-related RPC request to this process before user code in this + # process initializes its PyTorch CUDA states. + torch.cuda.init() + + # TODO: add try-except and destroy _agent in all processes if any fails. + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + reverse_device_maps, + devices, + ) + + api._init_rpc_states(agent) + + # Run one dummy round of RPC to initialize channels/transports. Without + # this, it's easy to hit timeout in rpc.shutdown() if there is no other RPC + # on that process before rpc.shutdown(), as the agent initialization can + # take longer than 5s. + api._all_gather(None, timeout=rpc_backend_options.rpc_timeout) + # Need a barrier here to make sure no peers leave before the rank0 finishes + # _all_gather + group.barrier().wait() + + return agent + # initialization for dynamic rpc (ranks can join and leave) + else: + with _group_membership_management(store, name, True): + # Construct TPAgent with empty reverse_device_map and devices + # these properties will be updated after initialization + agent = TensorPipeAgent( + store, + name, + rank, + world_size, + rpc_backend_options, + {}, + [], + ) + api._init_rpc_states(agent) + + try: + # Notify all workers in group this rank has joined and set devices and reverse_device_map + # This is a synchronous operation that completes once all existing ranks are updated + _set_devices_and_reverse_device_map(agent) + except Exception: + api.shutdown() + raise + return agent + + +register_backend( + "TENSORPIPE", + _tensorpipe_construct_rpc_backend_options_handler, + _tensorpipe_init_backend_handler, +) diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/constants.py b/phivenv/Lib/site-packages/torch/distributed/rpc/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..1af6a9fff092a196f021d0894a3bb522f18b1633 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/constants.py @@ -0,0 +1,24 @@ +from datetime import timedelta + +from torch._C._distributed_rpc import ( + _DEFAULT_INIT_METHOD, + _DEFAULT_NUM_WORKER_THREADS, + _DEFAULT_RPC_TIMEOUT_SEC, + _UNSET_RPC_TIMEOUT, +) + + +# For any RpcAgent. +DEFAULT_RPC_TIMEOUT_SEC: float = _DEFAULT_RPC_TIMEOUT_SEC +DEFAULT_INIT_METHOD: str = _DEFAULT_INIT_METHOD +DEFAULT_SHUTDOWN_TIMEOUT: float = 0 + +# For TensorPipeAgent. +DEFAULT_NUM_WORKER_THREADS: int = _DEFAULT_NUM_WORKER_THREADS +# Ensure that we don't time out when there are long periods of time without +# any operations against the underlying ProcessGroup. +DEFAULT_PROCESS_GROUP_TIMEOUT: timedelta = timedelta(milliseconds=2**31 - 1) +# Value indicating that timeout is not set for RPC call, and the default should be used. +UNSET_RPC_TIMEOUT: float = _UNSET_RPC_TIMEOUT + +__all__: list[str] = [] diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/functions.py b/phivenv/Lib/site-packages/torch/distributed/rpc/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..c5074176fdd5787017fba324329a94f12947b61c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/functions.py @@ -0,0 +1,169 @@ +# mypy: allow-untyped-defs +import functools + + +def async_execution(fn): + r""" + A decorator for a function indicating that the return value of the function + is guaranteed to be a :class:`~torch.futures.Future` object and this + function can run asynchronously on the RPC callee. More specifically, the + callee extracts the :class:`~torch.futures.Future` returned by the wrapped + function and installs subsequent processing steps as a callback to that + :class:`~torch.futures.Future`. The installed callback will read the value + from the :class:`~torch.futures.Future` when completed and send the + value back as the RPC response. That also means the returned + :class:`~torch.futures.Future` only exists on the callee side and is never + sent through RPC. This decorator is useful when the wrapped function's + (``fn``) execution needs to pause and resume due to, e.g., containing + :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals. + + .. note:: To enable asynchronous execution, applications must pass the + function object returned by this decorator to RPC APIs. If RPC detected + attributes installed by this decorator, it knows that this function + returns a ``Future`` object and will handle that accordingly. + However, this does not mean this decorator has to be outmost one when + defining a function. For example, when combined with ``@staticmethod`` + or ``@classmethod``, ``@rpc.functions.async_execution`` needs to be the + inner decorator to allow the target function be recognized as a static + or class function. This target function can still execute asynchronously + because, when accessed, the static or class method preserves attributes + installed by ``@rpc.functions.async_execution``. + + + Example:: + The returned :class:`~torch.futures.Future` object can come from + :meth:`~torch.distributed.rpc.rpc_async`, + :meth:`~torch.futures.Future.then`, or :class:`~torch.futures.Future` + constructor. The example below shows directly using the + :class:`~torch.futures.Future` returned by + :meth:`~torch.futures.Future.then`. + + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @rpc.functions.async_execution + >>> def async_add_chained(to, x, y, z): + >>> # This function runs on "worker1" and returns immediately when + >>> # the callback is installed through the `then(cb)` API. In the + >>> # mean time, the `rpc_async` to "worker2" can run concurrently. + >>> # When the return value of that `rpc_async` arrives at + >>> # "worker1", "worker1" will run the lambda function accordingly + >>> # and set the value for the previously returned `Future`, which + >>> # will then trigger RPC to send the result back to "worker0". + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> # xdoctest: +SKIP + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add_chained, + >>> args=("worker2", torch.ones(2), 1, 1) + >>> ) + >>> print(ret) # prints tensor([3., 3.]) + + When combined with TorchScript decorators, this decorator must be the + outmost one. + + >>> from torch import Tensor + >>> from torch.futures import Future + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> @torch.jit.script + >>> def script_add(x: Tensor, y: Tensor) -> Tensor: + >>> return x + y + >>> + >>> @rpc.functions.async_execution + >>> @torch.jit.script + >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]: + >>> return rpc.rpc_async(to, script_add, (x, y)) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> async_add, + >>> args=("worker2", torch.ones(2), 1) + >>> ) + >>> print(ret) # prints tensor([2., 2.]) + + When combined with static or class method, this decorator must be the + inner one. + + >>> from torch.distributed import rpc + >>> + >>> # omitting setup and shutdown RPC + >>> + >>> # On all workers + >>> class AsyncExecutionClass: + >>> + >>> @staticmethod + >>> @rpc.functions.async_execution + >>> def static_async_add(to, x, y, z): + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> @classmethod + >>> @rpc.functions.async_execution + >>> def class_async_add(cls, to, x, y, z): + >>> ret_fut = torch.futures.Future() + >>> rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: ret_fut.set_result(fut.wait() + z) + >>> ) + >>> return ret_fut + >>> + >>> @rpc.functions.async_execution + >>> def bound_async_add(self, to, x, y, z): + >>> return rpc.rpc_async(to, torch.add, args=(x, y)).then( + >>> lambda fut: fut.wait() + z + >>> ) + >>> + >>> # On worker0 + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.static_async_add, + >>> args=("worker2", torch.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> ret = rpc.rpc_sync( + >>> "worker1", + >>> AsyncExecutionClass.class_async_add, + >>> args=("worker2", torch.ones(2), 1, 2) + >>> ) + >>> print(ret) # prints tensor([4., 4.]) + + This decorator also works with RRef helpers, i.e., . + :meth:`torch.distributed.rpc.RRef.rpc_sync`, + :meth:`torch.distributed.rpc.RRef.rpc_async`, and + :meth:`torch.distributed.rpc.RRef.remote`. + + >>> from torch.distributed import rpc + >>> + >>> # reuse the AsyncExecutionClass class above + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2) + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait() + >>> print(ret) # prints tensor([4., 4.]) + >>> + >>> rref = rpc.remote("worker1", AsyncExecutionClass) + >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here() + >>> print(ret) # prints tensor([4., 4.]) + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + # Can't declare and use attributes of function objects (mypy#2087) + wrapper._wrapped_async_rpc_function = fn # type: ignore[attr-defined] + return wrapper diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/internal.py b/phivenv/Lib/site-packages/torch/distributed/rpc/internal.py new file mode 100644 index 0000000000000000000000000000000000000000..90aa3ffc8c293c68091166d62c40a2f814dd6a1c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/internal.py @@ -0,0 +1,285 @@ +# mypy: allow-untyped-defs +import collections +import copyreg +import io +import pickle +import sys +import threading +import traceback +from enum import Enum + +import torch +import torch.distributed as dist +from torch._C._distributed_rpc import _get_current_rpc_agent + + +__all__ = ["RPCExecMode", "serialize", "deserialize", "PythonUDF", "RemoteException"] + +# Thread local tensor tables to store tensors while pickling torch.Tensor +# objects +_thread_local_tensor_tables = threading.local() +_pickler = pickle.Pickler +_unpickler = pickle.Unpickler + + +class RPCExecMode(Enum): + SYNC = "sync" + ASYNC = "async" + ASYNC_JIT = "async_jit" + REMOTE = "remote" + + +class _InternalRPCPickler: + r""" + This class provides serialize() and deserialize() interfaces to serialize + data to be "binary string + tensor table" format + So for RPC python UDF function and args, non tensor data will be serialized + into regular binary string, tensor data will be put into thread local tensor + tables, this serialization format is consistent with builtin operator and args + using JIT pickler. This format will make tensor handling in C++ much easier, + e.g. attach tensor to distributed autograd graph in C++ + """ + + def __init__(self): + # Ignore type error because dispatch_table is defined in third-party package + self._dispatch_table = copyreg.dispatch_table.copy() # type: ignore[attr-defined] + self._dispatch_table[torch.Tensor] = self._tensor_reducer + # Used for registering customized picklers. + self._class_reducer_dict = {} + + def _register_reducer(self, obj_class, reducer): + # For the same class, only register the reducer once. + if obj_class not in self._class_reducer_dict: + self._class_reducer_dict[obj_class] = reducer + + @classmethod + def _tensor_receiver(cls, tensor_index): + global _thread_local_tensor_tables + return _thread_local_tensor_tables.recv_tables[tensor_index] + + def _tensor_reducer(self, tensor): + global _thread_local_tensor_tables + _thread_local_tensor_tables.send_tables.append(tensor) + tensor_index = len(_thread_local_tensor_tables.send_tables) - 1 + return (_InternalRPCPickler._tensor_receiver, (tensor_index,)) + + @classmethod + def _py_rref_receiver(cls, rref_fork_data): + return dist.rpc.PyRRef._deserialize(rref_fork_data) + + def _py_rref_reducer(self, py_rref): + rref_fork_data = py_rref._serialize() + return (_InternalRPCPickler._py_rref_receiver, (rref_fork_data,)) + + def _rref_reducer(self, rref): + return self._py_rref_reducer(rref) + + @classmethod + def _script_module_receiver(cls, script_module_serialized): + """ + Given a serialized representation of a ScriptModule created with torch.jit.save, + loads and returns the ScriptModule. + """ + f = io.BytesIO(script_module_serialized) + m = torch.jit.load(f) + return m + + def _script_module_reducer(self, script_module): + """ + Serializes a ScriptModule. + """ + f = io.BytesIO() + torch.jit.save(script_module, f) + return (_InternalRPCPickler._script_module_receiver, (f.getvalue(),)) + + def serialize(self, obj): + r""" + Serialize non tensor data into binary string, tensor data into + tensor table + """ + f = io.BytesIO() + p = _pickler(f) + p.dispatch_table = self._dispatch_table + + # rpc api could accept user picklers inheriting from _InternalRPCPickler to serialize rref, + # user picklers could have different initialization function from _InternalRPCPickler, + # but all the user picklers should call serialize() and use _rref_reducer to pickle rref + # in python. also, when _internal_rpc_pickler is imported to rpc/api.py, rpc.RRef is not + # compiled yet, it is not good place to access rpc.RRef inside _InternalRPCPickler constructor, + # so putting rref's dispatch table here + # + # The return value of a `rpc.remote(..)` call is type of `rpc.PyRRef`. + # The deserialized RRef object on an RPC receiver side is type of `rpc.PyRRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.PyRRef] = self._py_rref_reducer # type: ignore[index] + # An RRef created locally by RRef Python constructor is type of `rpc.RRef`. + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[dist.rpc.RRef] = self._rref_reducer # type: ignore[index] + + # Add dispatch pickling for ScriptModule or its subclass. + if isinstance(obj, torch.jit.ScriptModule): + # Ignore type error because dispatch_table is defined in third-party package + p.dispatch_table[obj.__class__] = self._script_module_reducer # type: ignore[index] + + # Install customized picklers. + for class_name in self._class_reducer_dict.keys(): + p.dispatch_table[class_name] = self._class_reducer_dict[class_name] # type: ignore[index] + + # save _thread_local_tensor_tables.send_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "send_tables"): + old_send_tables = _thread_local_tensor_tables.send_tables + else: + old_send_tables = None + _thread_local_tensor_tables.send_tables = [] + + p.dump(obj) + + # restore _thread_local_tensor_tables.send_tables if return + # from nested call, otherwise clean up the table + tensors = _thread_local_tensor_tables.send_tables + if old_send_tables is not None: + _thread_local_tensor_tables.send_tables = old_send_tables + else: + del _thread_local_tensor_tables.send_tables + + return (f.getvalue(), tensors) + + def deserialize(self, binary_data, tensor_table): + r""" + Deserialize binary string + tensor table to original obj + """ + # save _thread_local_tensor_tables.recv_tables if it is in nested call + global _thread_local_tensor_tables + if hasattr(_thread_local_tensor_tables, "recv_tables"): + old_recv_tables = _thread_local_tensor_tables.recv_tables + else: + old_recv_tables = None + _thread_local_tensor_tables.recv_tables = tensor_table + + try: + unpickler = _unpickler(io.BytesIO(binary_data)) + ret = unpickler.load() + except AttributeError as e: + # Occurs when function is not found on module/class during + # unpickling. + except_str = ( + str(e) + + """ Default RPC pickler does not serialize + function code. Ensure that UDFs are defined on both caller and + callee modules.""" + ) + ret = AttributeError(except_str) + # Ensure the stack trace gets preserved + ret.__cause__ = e + + # restore _thread_local_tensor_tables.recv_tables if return + # from nested call, otherwise clean up the table + if old_recv_tables is not None: + _thread_local_tensor_tables.recv_tables = old_recv_tables + else: + del _thread_local_tensor_tables.recv_tables + + return ret + + +# Create _internal_rpc_pickler only once to initialize _dispatch_table only once +_internal_rpc_pickler = _InternalRPCPickler() + + +def serialize(obj): + return _internal_rpc_pickler.serialize(obj) + + +def deserialize(binary_data, tensor_table): + return _internal_rpc_pickler.deserialize(binary_data, tensor_table) + + +def _run_function(python_udf): + r""" + This function is exclusively called from C++. + See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``. + + Runs a Python UDF and returns its return value. + Wraps any exception in ``RemoteException`` if the function raises. + """ + try: + if isinstance(python_udf, AttributeError): + raise python_udf + result = python_udf.func(*python_udf.args, **python_udf.kwargs) + except Exception as e: + # except str = exception info + traceback string + except_str = ( + f"On {_get_current_rpc_agent().get_worker_info()}:\n" + f"{repr(e)}\n{traceback.format_exc()}" + ) + print(except_str, file=sys.stderr) + result = RemoteException(except_str, type(e)) + return result + + +def _handle_exception(result): + if isinstance(result, RemoteException): + exception_msg = result.msg.encode("utf-8").decode("unicode_escape") + # We wrap exception re-creation here in case some exception classes + # cannot be constructed directly from a string. + exc = None + try: + exc = result.exception_type(exception_msg) + except BaseException as e: + raise RuntimeError( # noqa: B904 + f"Failed to create original exception type. Error msg was {str(e)}" + f" Original exception on remote side was {exception_msg}" + ) from e + + if exc is not None: + raise exc + + +def _build_rpc_profiling_key( + exec_type, func_name, current_worker_name, dst_worker_name +): + """ + Builds the key that RPC calls are profiled with using the autograd profiler. + This will be the name of the corresponding Event recorded in the profiler. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dst_worker_name (str): Name of the destination worker. + + Returns: + String representing profiling key + """ + profile_key = ( + f"rpc_{exec_type.value}#{func_name}({current_worker_name} -> {dst_worker_name})" + ) + return profile_key + + +def _start_record_function(exec_type, func_name, current_worker_name, dest_worker_name): + """ + This function should be called from RPC/RRef functions to create a + RecordFunction object for profiling. This function also runs the before + callbacks that start the profiling, though the user is responsible for + running the appropriate callbacks when the function to be profiled finishes. + + Args: + exec_type (RPCExecMode): Type of RPC/RRef call + func_name (str): Name of function being profiled. + current_worker_name (str): Name of current worker. + dest_worker_name (str): Name of the destination worker. + + Returns: + An instance of `torch.autograd._RecordFunction`. + """ + assert torch.autograd._profiler_enabled(), "Autograd profiler should be enabled." + profile_key = f"rpc_{exec_type.value}#{str(func_name)}({current_worker_name} -> {dest_worker_name})" + rf = torch.autograd._RecordFunction() # type: ignore[attr-defined] + torch.autograd._run_before_callbacks(rf, profile_key) # type: ignore[attr-defined] + return rf + + +PythonUDF = collections.namedtuple("PythonUDF", ["func", "args", "kwargs"]) +RemoteException = collections.namedtuple("RemoteException", ["msg", "exception_type"]) diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/options.py b/phivenv/Lib/site-packages/torch/distributed/rpc/options.py new file mode 100644 index 0000000000000000000000000000000000000000..872657fbcbad74c04b8d6228421ee82f2b4468f0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/options.py @@ -0,0 +1,180 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch + +from . import _is_tensorpipe_available, constants as rpc_contants + + +DeviceType = Union[int, str, torch.device] + +__all__ = ["TensorPipeRpcBackendOptions"] + + +def _to_device(device: DeviceType) -> torch.device: + device = torch.device(device) + if device.type != "cuda": + raise ValueError( + "`set_devices` expect a list of CUDA devices, but got " + f"device type {device.type}." + ) + return device + + +def _to_device_map( + device_map: dict[DeviceType, DeviceType], +) -> dict[torch.device, torch.device]: + full_device_map: dict[torch.device, torch.device] = {} + reverse_map: dict[torch.device, torch.device] = {} + for k, v in device_map.items(): + k, v = torch.device(k), torch.device(v) + if v in reverse_map: + raise ValueError( + "`device_map` only supports 1-to-1 mapping, " + f"trying to map {k} and {reverse_map[v]} to {v}" + ) + full_device_map[k] = v + reverse_map[v] = k + return full_device_map + + +def _to_device_list(devices: list[DeviceType]) -> list[torch.device]: + return list(map(_to_device, devices)) + + +if _is_tensorpipe_available: # type: ignore[has-type] + from torch._C._distributed_rpc import _TensorPipeRpcBackendOptionsBase +else: + _TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc] + + +class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): + r""" + The backend options for + :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from + :class:`~torch.distributed.rpc.RpcBackendOptions`. + + Args: + num_worker_threads (int, optional): The number of threads in the + thread-pool used by + :class:`~torch.distributed.rpc.TensorPipeAgent` to execute + requests (default: 16). + rpc_timeout (float, optional): The default timeout, in seconds, + for RPC requests (default: 60 seconds). If the RPC has not + completed in this timeframe, an exception indicating so will + be raised. Callers can override this timeout for individual + RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and + :meth:`~torch.distributed.rpc.rpc_async` if necessary. + init_method (str, optional): The URL to initialize the distributed + store used for rendezvous. It takes any value accepted for the + same argument of :meth:`~torch.distributed.init_process_group` + (default: ``env://``). + device_maps (Dict[str, Dict], optional): Device placement mappings from + this worker to the callee. Key is the callee worker name and value + the dictionary (``Dict`` of ``int``, ``str``, or ``torch.device``) + that maps this worker's devices to the callee worker's devices. + (default: ``None``) + devices (List[int, str, or ``torch.device``], optional): all local + CUDA devices used by RPC agent. By Default, it will be initialized + to all local devices from its own ``device_maps`` and corresponding + devices from its peers' ``device_maps``. When processing CUDA RPC + requests, the agent will properly synchronize CUDA streams for + all devices in this ``List``. + """ + + def __init__( + self, + *, + num_worker_threads: int = rpc_contants.DEFAULT_NUM_WORKER_THREADS, + rpc_timeout: float = rpc_contants.DEFAULT_RPC_TIMEOUT_SEC, + init_method: str = rpc_contants.DEFAULT_INIT_METHOD, + device_maps: Optional[dict[str, dict[DeviceType, DeviceType]]] = None, + devices: Optional[list[DeviceType]] = None, + _transports: Optional[list] = None, + _channels: Optional[list] = None, + ): + full_device_maps = ( + {} + if device_maps is None + else {k: _to_device_map(v) for k, v in device_maps.items()} + ) + full_device_list = [] if devices is None else _to_device_list(devices) + super().__init__( + num_worker_threads, + _transports, + _channels, + rpc_timeout, + init_method, + full_device_maps, + full_device_list, + ) + + def set_device_map(self, to: str, device_map: dict[DeviceType, DeviceType]): + r""" + Set device mapping between each RPC caller and callee pair. This + function can be called multiple times to incrementally add + device placement configurations. + + Args: + to (str): Callee name. + device_map (Dict of int, str, or torch.device): Device placement + mappings from this worker to the callee. This map must be + invertible. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> # both workers + >>> def add(x, y): + >>> print(x) # tensor([1., 1.], device='cuda:1') + >>> return x + y, (x + y).to(2) + >>> + >>> # on worker 0 + >>> options = TensorPipeRpcBackendOptions( + >>> num_worker_threads=8, + >>> device_maps={"worker1": {0: 1}} + >>> # maps worker0's cuda:0 to worker1's cuda:1 + >>> ) + >>> options.set_device_map("worker1", {1: 2}) + >>> # maps worker0's cuda:1 to worker1's cuda:2 + >>> + >>> rpc.init_rpc( + >>> "worker0", + >>> rank=0, + >>> world_size=2, + >>> backend=rpc.BackendType.TENSORPIPE, + >>> rpc_backend_options=options + >>> ) + >>> + >>> x = torch.ones(2) + >>> rets = rpc.rpc_sync("worker1", add, args=(x.to(0), 1)) + >>> # The first argument will be moved to cuda:1 on worker1. When + >>> # sending the return value back, it will follow the invert of + >>> # the device map, and hence will be moved back to cuda:0 and + >>> # cuda:1 on worker0 + >>> print(rets[0]) # tensor([2., 2.], device='cuda:0') + >>> print(rets[1]) # tensor([2., 2.], device='cuda:1') + """ + full_device_map = _to_device_map(device_map) + curr_device_maps = super().device_maps + + if to in curr_device_maps: + for k, v in full_device_map.items(): + if k in curr_device_maps[to] and v != curr_device_maps[to][k]: + raise ValueError( + "`set_device_map` only supports 1-to-1 mapping, trying" + f" to map {k} to {v} and {curr_device_maps[to][k]}" + ) + + super()._set_device_map(to, full_device_map) + + def set_devices(self, devices: list[DeviceType]): + r""" + Set local devices used by the TensorPipe RPC agent. When processing + CUDA RPC requests, the TensorPipe RPC agent will properly synchronize + CUDA streams for all devices in this ``List``. + + Args: + devices (List of int, str, or torch.device): local devices used by + the TensorPipe RPC agent. + """ + self.devices = _to_device_list(devices) diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/rref_proxy.py b/phivenv/Lib/site-packages/torch/distributed/rpc/rref_proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..d0e1c362188ff0143ff95b9c19ded9c961f564f0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/rref_proxy.py @@ -0,0 +1,80 @@ +# mypy: allow-untyped-defs +from functools import partial + +import torch +from torch.futures import Future + +from . import functions, rpc_async +from .constants import UNSET_RPC_TIMEOUT + + +def _local_invoke(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +@functions.async_execution +def _local_invoke_async_execution(rref, func_name, args, kwargs): + return getattr(rref.local_value(), func_name)(*args, **kwargs) + + +def _invoke_rpc(rref, rpc_api, func_name, timeout, *args, **kwargs): + def _rref_type_cont(rref_fut): + rref_type = rref_fut.value() + + _invoke_func = _local_invoke + # Bypass ScriptModules when checking for async function attribute. + bypass_type = issubclass(rref_type, torch.jit.ScriptModule) or issubclass( + rref_type, torch._C.ScriptModule + ) + if not bypass_type: + func = getattr(rref_type, func_name) + if hasattr(func, "_wrapped_async_rpc_function"): + _invoke_func = _local_invoke_async_execution + + return rpc_api( + rref.owner(), + _invoke_func, + args=(rref, func_name, args, kwargs), + timeout=timeout, + ) + + rref_fut = rref._get_type(timeout=timeout, blocking=False) + + if rpc_api != rpc_async: + rref_fut.wait() + return _rref_type_cont(rref_fut) + else: + # A little explanation on this. + # rpc_async returns a Future pointing to the return value of `func_name`, it returns a `Future[T]` + # Calling _rref_type_cont from the `then` lambda causes Future wrapping. IOW, `then` returns a `Future[Future[T]]` + # To address that, we return a Future that is completed with the result of the async call. + result: Future = Future() + + def _wrap_rref_type_cont(fut): + try: + _rref_type_cont(fut).then(_complete_op) + except BaseException as ex: + result.set_exception(ex) + + def _complete_op(fut): + try: + result.set_result(fut.value()) + except BaseException as ex: + result.set_exception(ex) + + rref_fut.then(_wrap_rref_type_cont) + return result + + +# This class manages proxied RPC API calls for RRefs. It is entirely used from +# C++ (see python_rpc_handler.cpp). +class RRefProxy: + def __init__(self, rref, rpc_api, timeout=UNSET_RPC_TIMEOUT): + self.rref = rref + self.rpc_api = rpc_api + self.rpc_timeout = timeout + + def __getattr__(self, func_name): + return partial( + _invoke_rpc, self.rref, self.rpc_api, func_name, self.rpc_timeout + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/rpc/server_process_global_profiler.py b/phivenv/Lib/site-packages/torch/distributed/rpc/server_process_global_profiler.py new file mode 100644 index 0000000000000000000000000000000000000000..00c1d2ff647889e8ddc275ba49436d9f0dd3ff20 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/rpc/server_process_global_profiler.py @@ -0,0 +1,186 @@ +#!/usr/bin/python3 +# mypy: allow-untyped-defs + +import itertools + +import torch +from torch.autograd.profiler_legacy import profile + +from . import ( + _disable_server_process_global_profiler, + _enable_server_process_global_profiler, +) + + +__all__: list[str] = [] + + +class _server_process_global_profile(profile): + """ + It has the same API as ``torch.autograd.profiler.profile`` class, + except that it enables profiling on all threads running RPC server request callbacks. + + Context manager that manages autograd profiler state and holds a summary of results. + Under the hood it just records events of functions being executed in C++ and + exposes those events to Python. You can wrap any code into it and it will + only report runtime of PyTorch functions. + Note: profiler is thread local and is automatically propagated into the async tasks + + Args: + enabled (bool, optional): Setting this to False makes this context manager a no-op. + Default: ``True``. + + use_cuda (bool, optional): Enables timing of CUDA events as well using the cudaEvent API. + Adds approximately 4us of overhead to each tensor operation. + Default: ``False`` + + record_shapes (bool, optional): If shapes recording is set, information + about input dimensions will be collected. This allows one to see which + dimensions have been used under the hood and further group by them + using prof.key_averages(group_by_input_shape=True). Please note that + shape recording might skew your profiling data. It is recommended to + use separate runs with and without shape recording to validate the timing. + Most likely the skew will be negligible for bottom most events (in a case + of nested function calls). But for higher level functions the total + self cpu time might be artificially increased because of the shape + collection. + + profile_memory (bool, optional): Whether to report memory usage, default: ``False`` + + .. warning:: + Enabling memory profiling incurs additional profiler overhead + + .. warning:: + Due to some CUDA multiprocessing limitations (see :ref:`multiprocessing-cuda-note`), + one cannot use the profiler with ``use_cuda = True`` to benchmark + DataLoaders with ``num_workers > 0``. If you wish to benchmark data loading, + please use ``use_cuda = False`` or ``num_workers = 0``. + + Example: + >>> # xdoctest: +SKIP + >>> # On worker 0: + >>> import torch + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker0", rank=0, world_size=2) + >>> x, y = torch.tensor(1), torch.tensor(2) + >>> outer_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) + >>> outer_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, torch.add, (x, y)) + >>> inner_profile_rref = rpc.remote( + ... dst_worker_name, rpc._server_process_global_profile + ... ) + >>> inner_profile_rref.rpc_sync().__enter__() + >>> rpc.rpc_sync(dst_worker_name, torch.sub, (x, y)) + >>> inner_profile_rref.rpc_sync().__exit__(None, None, None) + >>> outer_profile_rref.rpc_sync().__exit__(None, None, None) + >>> print(inner_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 85.06% 76.275us 100.00% 89.667us 89.667us 1 + empty 14.94% 13.392us 14.94% 13.392us 13.392us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 89.667us + >>> print(outer_profile_rref.rpc_sync().key_averages()) + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Name Self CPU total % Self CPU total CPU total % CPU total CPU time avg Number of Calls + --------- --------------- --------------- --------------- --------------- --------------- --------------- + sub 35.65% 76.275us 41.91% 89.667us 89.667us 1 + empty 12.67% 27.101us 12.67% 27.101us 13.551us 2 + add 51.68% 110.550us 58.09% 124.259us 124.259us 1 + --------- --------------- --------------- --------------- --------------- --------------- --------------- + Self CPU time total: 213.926us + >>> rpc.shutdown() + + >>> # On worker 1: + >>> import torch.distributed.rpc as rpc + >>> rpc.init_rpc("worker1", rank=1, world_size=2) + >>> # wait for worker 0 to finish work, and then shutdown. + >>> rpc.shutdown() + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def __enter__(self): + """ + Turn on server-side process-global profiling. + This enables thread-local profiler on all RPC threads running server-side request callbacks. + """ + if not self.enabled: + return + + if self.entered: # type: ignore[has-type] + raise RuntimeError("autograd profiler traces are not reentrant") + self.entered = True + + profiler_kind = ( + torch.autograd.ProfilerState.CUDA + if self.use_cuda + else torch.autograd.ProfilerState.CPU + ) + profiler_config = torch.autograd.ProfilerConfig( + profiler_kind, + self.record_shapes, + self.profile_memory, + False, + False, + False, + torch.profiler._ExperimentalConfig(), + ) + _enable_server_process_global_profiler(profiler_config) + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Turn off server-side process-global profiling. + Aggregate all profiling events recorded by RPC threads. + + These attributes are assigned on exiting context. + + Attributes: + function_events (torch.autograd.profiler.EventList). It's a list that has helper + methods, like 1) show record items in a pretty-print table. + 2) do averaging by grouping on keys. 3) and more. + + process_global_function_events (List[torch.autograd.profiler.FunctionEvent]). + It's a list of ``FunctionEvent`` elements. Every element is a profiling result + of an RPC request handling within the profiling range. + """ + if not self.enabled: + return + + process_global_events = _disable_server_process_global_profiler() + + # Every element in this list is a thread profiling result from an RPC request handling. + process_global_function_events = [] + for thread_local_events in process_global_events: + # Parse from ``Event``s to ``FunctionEvent``s. + thread_local_function_events = ( + torch.autograd.profiler_legacy._parse_legacy_records( + thread_local_events + ) + ) + thread_local_function_events.sort( + key=lambda function_event: [ + function_event.time_range.start, + -(function_event.time_range.end), + ] + ) + process_global_function_events.append(thread_local_function_events) + + flattened_function_events = list( + itertools.chain.from_iterable(process_global_function_events) + ) + self.function_events = torch.autograd.profiler_util.EventList( + flattened_function_events, + use_device="cuda" if self.use_cuda else None, + profile_memory=self.profile_memory, + ) + self.function_events._build_tree() + + self.process_global_function_events = process_global_function_events + + return False diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__init__.py b/phivenv/Lib/site-packages/torch/distributed/tensor/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c1975c28a0e1f1589bc2f0038a91d77dbb498ee0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/__init__.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import torch +import torch.distributed.tensor._ops # force import all built-in dtensor ops +from torch.distributed.device_mesh import DeviceMesh, init_device_mesh # noqa: F401 +from torch.distributed.tensor._api import ( + distribute_module, + distribute_tensor, + DTensor, + empty, + full, + ones, + rand, + randn, + zeros, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) +from torch.optim.optimizer import ( + _foreach_supported_types as _optim_foreach_supported_types, +) +from torch.utils._foreach_utils import ( + _foreach_supported_types as _util_foreach_supported_types, +) + + +# All public APIs from dtensor package +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "Shard", + "Replicate", + "Partial", + "Placement", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +# For weights_only torch.load +from ._dtensor_spec import DTensorSpec as _DTensorSpec, TensorMeta as _TensorMeta + + +torch.serialization.add_safe_globals( + [ + DeviceMesh, + _DTensorSpec, + _TensorMeta, + DTensor, + Partial, + Replicate, + Shard, + ] +) + + +# Append DTensor to the list of supported types for foreach implementation for optimizer +# and clip_grad_norm_ so that we will try to use foreach over the for-loop implementation on CUDA. +if DTensor not in _optim_foreach_supported_types: + _optim_foreach_supported_types.append(DTensor) + +if DTensor not in _util_foreach_supported_types: + _util_foreach_supported_types.append(DTensor) # type: ignore[arg-type] + + +# Set namespace for exposed private names +DTensor.__module__ = "torch.distributed.tensor" +distribute_tensor.__module__ = "torch.distributed.tensor" +distribute_module.__module__ = "torch.distributed.tensor" +ones.__module__ = "torch.distributed.tensor" +empty.__module__ = "torch.distributed.tensor" +full.__module__ = "torch.distributed.tensor" +rand.__module__ = "torch.distributed.tensor" +randn.__module__ = "torch.distributed.tensor" +zeros.__module__ = "torch.distributed.tensor" diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64268625aa48327850b51a898cf2e22d4ae157a5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2b51a6eb6c8cb49714e9aad2d0cbc0b2a701e829 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_collective_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_collective_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f07ee598fcda36ffae88838d54af981974a985fb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_collective_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_dispatch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_dispatch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..265c062ec0fd2f0bd8360ffea0cff5554fa33bb3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_dispatch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_dtensor_spec.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_dtensor_spec.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c7a93e0dd727a2773d6b69bd8b00d5336606338 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_dtensor_spec.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_op_schema.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_op_schema.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..761a4dedb1dbcaa1d33ef9cbe53cd44fd8c02c64 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_op_schema.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_random.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_random.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81fa74ebe8a69258a10fd075fe8cec6c83d0246a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_random.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_redistribute.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_redistribute.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1b254715ac10ad4b58f4fd2a63613207ab63527 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_redistribute.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_sharding_prop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_sharding_prop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b6302ae0bda564a929d33d3948a72f1d021a1b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_sharding_prop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_shards_wrapper.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_shards_wrapper.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81efe8c2fee0a0fc951c60e7d5d0294362799a7c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_shards_wrapper.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_tp_conv.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_tp_conv.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd2612a0863ad4095d58ab2872acb959116bbaa5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_tp_conv.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5cdc5450ab0cc8e69464b0d1d542c47065d0cced Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/device_mesh.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/device_mesh.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cdfe7d94cdcbef428740ee998ef6af71912a0f3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/device_mesh.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/placement_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/placement_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca3b188c348d9b52c8175b3bb0a557d7380c7fc0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/__pycache__/placement_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_api.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_api.py new file mode 100644 index 0000000000000000000000000000000000000000..7355a1cb26c8b645551b510d046d263c1e853270 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_api.py @@ -0,0 +1,1315 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import inspect +import warnings +from collections.abc import Sequence +from typing import Any, Callable, cast, Optional +from typing_extensions import deprecated + +import torch +import torch.distributed.tensor._dispatch as op_dispatch +import torch.distributed.tensor._random as random +import torch.nn as nn +from torch._export.wrappers import mark_subclass_constructor_exportable_experimental +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor._collective_utils import check_tensor_meta, mesh_broadcast +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._redistribute import ( + Redistribute, + redistribute_local_tensor, +) +from torch.distributed.tensor._utils import ( + compute_global_tensor_info, + compute_local_shape_and_global_offset, + normalize_to_torch_size, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +__all__ = [ + "DTensor", + "distribute_tensor", + "distribute_module", + "ones", + "empty", + "full", + "rand", + "randn", + "zeros", +] + +aten = torch.ops.aten + + +# NOTE [Autograd interaction between torch.Tensor] +# +# The autograd functions defined below are being used by the public +# facing APIs (i.e. from_local, to_local) to ensure DTensor to work +# together with torch.Tensor within the autograd engine. This +# allows DTensor to only exist on part of the module hierarchy. +# +# As an example, we have the a module that consists of submodules +# A, B, and C, the execution flow would be like: +# input(torch.Tensor) -> Module A -> Module B -> Module C -> output (torch.Tensor) +# +# Suppose I only want to make Module B be a sharded module with +# DTensor params, the following forward/backward should work: +# +# input(torch.Tensor) -> Module A +# -> DTensor input (from_local) -> Sharded Module B -> DTensor output +# -> torch.Tensor output (to_local) -> Module C +# +# So from_local/to_local must be Autograd functions. +# +class _ToTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, + input: "DTensor", + grad_placements: Optional[Sequence[Placement]], + ): + ctx.dtensor_spec = input._spec + ctx.grad_placements = grad_placements + local_tensor = input._local_tensor + + # We need to return a fresh Tensor object there as autograd metadata + # will be inplaced into it. So we don't want to pollute the Tensor + # object stored in the _local_tensor of this DTensor. + return local_tensor.view_as(local_tensor) + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): # type: ignore[override] + dtensor_spec = ctx.dtensor_spec + mesh = dtensor_spec.mesh + grad_placements = ctx.grad_placements + dtensor_meta = dtensor_spec.tensor_meta + + _, tensor_stride = compute_global_tensor_info( + grad_output, mesh, dtensor_spec.placements + ) + tensor_stride = tuple(tensor_stride) + grad_placements = grad_placements or dtensor_spec.placements + grad_spec = DTensorSpec( + mesh, + grad_placements, + tensor_meta=TensorMeta( + shape=dtensor_meta.shape, + stride=tensor_stride, + dtype=dtensor_meta.dtype, + ), + ) + + return ( + DTensor( + grad_output, + grad_spec, + requires_grad=grad_output.requires_grad, + ), + None, + ) + + +class _FromTorchTensor(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + ctx, # pyre-ignore[2]: Parameter must be annotated. + input: torch.Tensor, + device_mesh: DeviceMesh, + placements: tuple[Placement, ...], + run_check: bool, + shape: Optional[torch.Size] = None, + stride: Optional[tuple[int, ...]] = None, + ) -> "DTensor": + ctx.previous_placement = placements + ctx.previous_device_mesh = device_mesh + + if shape and stride: + tensor_shape, tensor_stride = shape, stride + elif not shape and not stride: + # if it's not by default run_check, we assume user is certain that each + # rank has the same tensor shape, and we just use that to calculate the + # global shape + global_shape, global_stride = compute_global_tensor_info( + input, device_mesh, placements + ) + tensor_shape, tensor_stride = torch.Size(global_shape), tuple(global_stride) + else: + raise RuntimeError( + f"Found shape:{shape}, stride:{stride}.", + "Please pass both shape and stride at the same time.", + ) + + if device_mesh.get_coordinate() is None: + # if the global rank is not participating in the device mesh, we + # simply set the local tensor to an empty tensor + input = input.new_empty(0, requires_grad=input.requires_grad) + elif run_check: + # TODO: support uneven sharding when global shape/stride not passed, by + # building the global TensorMeta during check_tensor_meta + check_shape_stride = not shape and not stride + check_tensor_meta(input, check_shape_stride=check_shape_stride) + # TODO: See if we need to make this run_check logic + # have a corresponding backward. + for idx, placement in enumerate(placements): + if placement.is_replicate(): + # broadcast rank 0 tensor to all ranks + # only broadcast if run_check is True + input = input.contiguous() + mesh_broadcast(input, device_mesh, mesh_dim=idx) + + dist_spec = DTensorSpec( + device_mesh, + placements, + tensor_meta=TensorMeta( + tensor_shape, + tensor_stride, + input.dtype, + ), + ) + + # We want a fresh Tensor object that shares memory with the input tensor + dist_tensor = DTensor( + input.view_as(input), + dist_spec, + # requires_grad of the dist tensor depends on if input + # requires_grad or not + requires_grad=input.requires_grad, + ) + return dist_tensor + + @staticmethod + def backward(ctx, grad_output: "DTensor"): # type: ignore[override] + previous_placement = ctx.previous_placement + previous_device_mesh = ctx.previous_device_mesh + + # reshard to the placement when creating DistributedTensor + # so that the gradient layout matches, and we could return + # local gradients directly + if grad_output.placements != previous_placement: + current_spec = grad_output._spec + target_spec = DTensorSpec( + previous_device_mesh, + previous_placement, + tensor_meta=grad_output._spec.tensor_meta, + ) + local_tensor = grad_output._local_tensor + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, is_backward=True + ) + # TODO: return the redistributed local tensor directly without + # differentiable backward. see if this make sense for all cases. + return output, None, None, None, None, None + + # TODO: backward is also differentiable now, add a test + # to test higher level gradients. + return grad_output.to_local(), None, None, None, None, None + + +class DTensor(torch.Tensor): + """ + ``DTensor`` (Distributed Tensor) is a subclass of ``torch.Tensor`` that provides single-device like + abstraction to program with multi-device ``torch.Tensor``. It describes the distributed tensor sharding + layout (DTensor Layout) through the :class:`DeviceMesh` and following types of :class:`Placement`: + + * :class:`Shard`: Tensor sharded on the tensor dimension ``dim`` on the devices of the ``DeviceMesh`` dimension + * :class:`Replicate`: Tensor replicated on the devices of the ``DeviceMesh`` dimension + * :class:`Partial`: Tensor is pending reduction on the devices of the ``DeviceMesh`` dimension + + When calling PyTorch operators, ``DTensor`` overrides the PyTorch operators to perform sharded computation and issue + communications whenever necessary. Along with the operator computation, ``DTensor`` will transform or propagate the + placements (DTensor Layout) properly (based on the operator semantic itself) and generate new ``DTensor`` outputs. + + To ensure numerical correctness of the ``DTensor`` sharded computation when calling PyTorch operators, ``DTensor`` + requires every Tensor argument of the operator be DTensor. + + .. note:: Directly using the Tensor subclass constructor here is not the recommended way to create a ``DTensor`` + (i.e. it does not handle autograd correctly hence is not the public API). Please refer to the `create_dtensor`_ + section to see how to create a ``DTensor``. + """ + + _local_tensor: torch.Tensor + _spec: DTensorSpec + __slots__ = ["_local_tensor", "_spec"] + + # _op_dispatcher instance as a class attribute to handle runtime dispatching logic + _op_dispatcher: op_dispatch.OpDispatcher = op_dispatch.OpDispatcher() + + @staticmethod + @torch._disable_dynamo + def __new__( + cls, + local_tensor: torch.Tensor, + spec: DTensorSpec, + *, + requires_grad: bool, + ) -> "DTensor": + """ + Construct a DTensor from a local tensor, device mesh, and placement and + other tensor properties (i.e. shape, requires_grad, strides, etc). + + .. note:: This is not a public API and it's only supposed to be used by the + operator implementations and internals. If you want to construct a + DTensor from a local tensor, consider using ``DTensor.from_local``, if + you want to construct a DTensor from a "global" tensor (where you + already have tensor initialized and want to shard this tensor), + consider using ``distribute_tensor``. + """ + if local_tensor.requires_grad and not requires_grad: + warnings.warn( + "To construct DTensor from torch.Tensor, it's recommended to " + "use local_tensor.detach() and make requires_grad consistent." + ) + + # new method instruct wrapper tensor from local_tensor and add + # placement spec, it does not do actual distribution + assert spec.tensor_meta is not None, "TensorMeta should not be None!" + r = torch.Tensor._make_wrapper_subclass( + cls, + spec.tensor_meta.shape, + strides=spec.tensor_meta.stride, + dtype=local_tensor.dtype, + device=local_tensor.device, + layout=local_tensor.layout, + requires_grad=requires_grad, + ) + + r._spec = spec + r._local_tensor = local_tensor + return r + + @torch._disable_dynamo + @mark_subclass_constructor_exportable_experimental + def __init__(self, *args, **kwargs): + super().__init__() + + # pyre-fixme[14]: `__repr__` overrides method defined in `DTensor` inconsistently. + # pyre-fixme[3]: Return type must be annotated. + def __repr__(self): # type: ignore[override] + # TODO: consider all_gather the local tensors for better debugging + return f"DTensor(local_tensor={self._local_tensor}, device_mesh={self._spec.mesh}, placements={self._spec.placements})" + + def __tensor_flatten__(self): + """ + protocol to inform how to flatten a DTensor to local tensor + for PT2 tracing + """ + return ["_local_tensor"], (self._spec, self.requires_grad) + + @staticmethod + def __tensor_unflatten__(inner_tensors, flatten_spec, outer_size, outer_stride): + assert flatten_spec is not None, ( + "Expecting spec to be not None from `__tensor_flatten__` return value!" + ) + local_tensor = inner_tensors["_local_tensor"] + spec, requires_grad = flatten_spec + unflatten_tensor_meta = TensorMeta( + shape=outer_size, + stride=outer_stride, + dtype=spec.tensor_meta.dtype, + ) + unflatten_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=unflatten_tensor_meta, + ) + return DTensor( + local_tensor, + unflatten_spec, + requires_grad=requires_grad, + ) + + def __coerce_tangent_metadata__(self): + if not any(isinstance(p, Partial) for p in self.placements): + return self + placements = [ + Replicate() if isinstance(p, Partial) else p for p in self.placements + ] + return self.redistribute(device_mesh=self.device_mesh, placements=placements) + + def __coerce_same_metadata_as_tangent__(self, flatten_spec, expected_type=None): + if expected_type is not None: + return None + + (spec, _) = flatten_spec # Result of tensor_flatten() + return self.redistribute( + device_mesh=self.device_mesh, + placements=spec.placements, + ) + + @classmethod + @torch._disable_dynamo + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + return DTensor._op_dispatcher.dispatch( + func, + args, + kwargs or {}, + ) + + @staticmethod + def from_local( + local_tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + run_check: bool = False, + shape: Optional[torch.Size] = None, + stride: Optional[tuple[int, ...]] = None, + ) -> "DTensor": + """ + Create a :class:`DTensor` from a local torch.Tensor on each rank + according to the ``device_mesh`` and ``placements`` specified. + + Args: + local_tensor (torch.Tensor): local torch.Tensor on each rank. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + tensor, if not specified, must be called under a DeviceMesh + context manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the local torch.Tensor on DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + + Keyword args: + run_check (bool, optional): at a cost of extra communications, perform + sanity check across ranks to check each local tensor's meta information + to ensure correctness. If have :class:`Replicate` in ``placements``, the + data on first rank of the device mesh dimension will be broadcasted + to other ranks. default: False + shape (torch.Size, optional): A List of int which specifies the size of + DTensor which build on top of `local_tensor`. Note this needs to be + provided if the shape of ``local_tensor`` are different across the ranks. + If not provided, ``shape`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + stride (tuple, optional): A List of int which specifies the stride of DTensor. + If not provided, ``stride`` will be computed assuming the given distributed + tensor is evenly sharded across ranks. default: None + + Returns: + A :class:`DTensor` object + + .. note:: When ``run_check=False``, it is the user's responsibility to ensure the + local tensor passed in is correct across ranks (i.e. the tensor is sharded for + the ``Shard(dim)`` placement or replicated for the ``Replicate()`` placement). + If not, the behavior of the created DTensor is undefined. + + .. note:: ``from_local`` is differentiable, the `requires_grad` of the created + `DTensor` object will depend on if `local_tensor` requires_grad or not. + """ + # if same shape/dtype, no need to run_check, if not, must allgather + # the metadatas to check the size/dtype across ranks + # There should be no data communication unless there's replication + # strategy, where we broadcast the replication from the first rank + # in the mesh dimension + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + + # convert the local tensor to desired device base on device mesh's device_type + if device_type != local_tensor.device.type and not local_tensor.is_meta: + local_tensor = local_tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + else: + placements = list(placements) + for idx, placement in enumerate(placements): + # normalize shard dim to be positive + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + placements[idx] = Shard(placement.dim + local_tensor.ndim) + + # `from_local` is differentiable, and the gradient of the dist tensor this function + # created should flow back the gradients to the local_tensor, so we call an autograd + # function to construct the dist tensor instead. + return _FromTorchTensor.apply( # pyre-ignore[16]: autograd func + local_tensor, + device_mesh, + tuple(placements), + run_check, + shape, + stride, + ) + + def to_local( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Get the local tensor of this DTensor on its current rank. For sharding it returns + a local shard of the logical tensor view, for replication it returns the replica on + its current rank. + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the Tensor returned from this + function. + `to_local` converts DTensor to local tensor and the returned local tensor + might not be used as the original DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original DTensor layout. + If not specified, we will assume the gradient layout remains the same + as the original DTensor and use that for gradient computation. + + Returns: + A :class:`torch.Tensor` or ``AsyncCollectiveTensor`` object. it represents the + local tensor on its current rank. When an ``AsyncCollectiveTensor`` object is returned, + it means the local tensor is not ready yet (i.e. communication is not finished). In this + case, user needs to call ``wait`` to wait the local tensor to be ready. + + .. note:: ``to_local`` is differentiable, the ``requires_grad`` of the local tensor returned + will depend on if the `DTensor` requires_grad or not. + """ + if not torch.is_grad_enabled(): + return self._local_tensor + + if grad_placements is not None and not isinstance(grad_placements, tuple): + grad_placements = tuple(grad_placements) + return _ToTorchTensor.apply( + self, grad_placements + ) # pyre-ignore[16]: autograd func + + def redistribute( + self, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + async_op: bool = False, + forward_dtype: Optional[torch.dtype] = None, + backward_dtype: Optional[torch.dtype] = None, + ) -> "DTensor": + """ + ``redistribute`` performs necessary collective operations that redistribute the current + DTensor from its current placements to a new placements, or from its current DeviceMesh + to a new DeviceMesh. i.e. we can turn a Sharded DTensor to a Replicated DTensor by + specifying a Replicate placement for each dimension of the DeviceMesh. + + When redistributing from current to the new placements on one device mesh dimension, we + will perform the following operations including communication collective or local operation: + + 1. ``Shard(dim)`` -> ``Replicate()``: ``all_gather`` + 2. ``Shard(src_dim)`` -> ``Shard(dst_dim)``: ``all_to_all`` + 3. ``Replicate()`` -> ``Shard(dim)``: local chunking (i.e. ``torch.chunk``) + 4. ``Partial()`` -> ``Replicate()``: ``all_reduce`` + 5. ``Partial()`` -> ``Shard(dim)``: ``reduce_scatter`` + + + ``redistribute`` would correctly figure out the necessary redistribute steps for DTensors + that are created either on 1-D or N-D DeviceMesh. + + Args: + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. If not specified, it would use the current DTensor's DeviceMesh. + default: None + placements (List[:class:`Placement`], optional): the new placements that + describes how to place the DTensor into the DeviceMesh, must + have the same number of elements as ``device_mesh.ndim``. + default: replicate on all mesh dimensions + + Keyword args: + async_op (bool, optional): whether to perform the DTensor redistribute operation + asynchronously or not. Default: False + forward_dtype (torch.dtype, optional): the local tensor datatype can be converted to + ``forward_dtype`` before redistributing the local tensor in its forward. + The result DTensor will be in ``forward_dtype`` Default: None. + backward_dtype (torch.dtype, optional): the local tensor datatype can be converted to + ``backward_dtype`` before redistributing the local tensor in its backward. + The result DTensor gradient would be converted back to the current DTensor dtype. Default: None + + Returns: + A :class:`DTensor` object + + .. note:: ``redistribute`` is differentiable, which means user do not need to worry about + the backward formula of the redistribute operation. + + .. note:: ``redistribute`` currently only supports redistributing DTensor on the same DeviceMesh, + Please file an issue if you need to redistribute DTensor to different DeviceMesh. + """ + # NOTE: This redistribute API currently only supports out + # of place redistribution, i.e. it always create a new + # DTensor object and leave the original one unchanged. + + # if device_mesh is not specified, use the current device_mesh + device_mesh = device_mesh or self.device_mesh + # raise error if new placements not specified + if placements is None: + raise RuntimeError("placements is needed for redistribute!") + + placements = list(placements) + for i, placement in enumerate(placements): + if placement.is_partial(): + raise RuntimeError( + "Can not redistribute to Partial, redistributing to Partial is for internal use only!" + ) + elif isinstance(placement, Shard) and placement.dim < 0: + # normalize shard dim to be positive + placements[i] = Shard(placement.dim + self.ndim) + placements = tuple(placements) + + # pyre-fixme[16]: `Redistribute` has no attribute `apply`. + return Redistribute.apply( + self, device_mesh, placements, async_op, forward_dtype, backward_dtype + ) + + def full_tensor( + self, *, grad_placements: Optional[Sequence[Placement]] = None + ) -> torch.Tensor: + """ + Return the full tensor of this DTensor. It will perform necessary collectives + to gather the local tensors from other ranks in its DeviceMesh and concatenate + them together. It's a syntatic sugar of the following code: + + ``dtensor.redistribute(placements=[Replicate()] * mesh.ndim).to_local()`` + + Keyword args: + grad_placements (List[:class:`Placement`], optional): the placements describes + the future layout of any gradient layout of the full Tensor returned from this + function. + `full_tensor` converts DTensor to a full torch.Tensor and the returned torch.tensor + might not be used as the original replicated DTensor layout later in the code. This + argument is the hint that user can give to autograd in case the gradient + layout of the returned tensor does not match the original replicated DTensor layout. + If not specified, we will assume the gradient layout of the full tensor be replicated. + + Returns: + A :class:`torch.Tensor` object that represents the full tensor of this DTensor. + + .. note:: ``full_tensor`` is differentiable. + """ + + redist_res = self.redistribute( + placements=[Replicate()] * self.device_mesh.ndim, async_op=False + ) + return _ToTorchTensor.apply(redist_res, grad_placements) + + @property + def device_mesh(self) -> DeviceMesh: + """ + The :class:`DeviceMesh` attribute that associates with this DTensor object. + + .. note:: ``device_mesh`` is a read-only property, it can not be set. + """ + return self._spec.mesh + + @property + def placements(self) -> tuple[Placement, ...]: + """ + The placements attribute of this DTensor that describes the layout of this + DTensor on the its DeviceMesh. + + .. note:: ``placements`` is a read-only property, it can not be set. + """ + return self._spec.placements + + def __create_write_items__(self, fqn: str, object: Any): + from torch.distributed.checkpoint.planner_helpers import ( + _create_write_items_for_dtensor, + ) + + if hasattr(self._local_tensor, "__create_write_items__"): + return self._local_tensor.__create_write_items__(fqn, object) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_write_items_for_dtensor(fqn, object)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __create_chunk_list__(self): + """ + Return a list of ChunkStorageMetadata, which is a dataclass that describes the size/offset of the local shard/replica + on current rank. For DTensor, each rank will have a single local shard/replica, so the returned list usually only + has one element. + + This dunder method is primariy used for distributed checkpoint purpose. + + Returns: + A List[:class:`ChunkStorageMetadata`] object that represents the shard size/offset on the current rank. + """ + from torch.distributed.checkpoint.planner_helpers import ( + _create_chunk_from_dtensor, + ) + + if hasattr(self._local_tensor, "__create_chunk_list__"): + return self._local_tensor.__create_chunk_list__() # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return [_create_chunk_from_dtensor(self)] + else: + raise RuntimeError("Unsupported tensor type!") + + def __get_tensor_shard__(self, index): + if hasattr(self._local_tensor, "__get_tensor_shard__"): + return self._local_tensor.__get_tensor_shard__(index) # type: ignore[attr-defined] + elif isinstance(self._local_tensor, torch.Tensor): + return self.to_local() + else: + raise RuntimeError("Unsupported tensor type!") + + +def distribute_tensor( + tensor: torch.Tensor, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + *, + src_data_rank: Optional[int] = 0, +) -> DTensor: + """ + Distribute a leaf ``torch.Tensor`` (i.e. nn.Parameter/buffers) to the ``device_mesh`` according + to the ``placements`` specified. The rank of ``device_mesh`` and ``placements`` must be the + same. The ``tensor`` to distribute is the logical or "global" tensor, and the API would use + the ``tensor`` from first rank of the DeviceMesh dimension as the source of truth to preserve + the single-device semantic. If you want to construct a DTensor in the middle of the Autograd + computation, please use :meth:`DTensor.from_local` instead. + + Args: + tensor (torch.Tensor): torch.Tensor to be distributed. Note that if you + want to shard a tensor on a dimension that is not evenly divisible by + the number of devices in that mesh dimension, we use ``torch.chunk`` + semantic to shard the tensor and scatter the shards. The uneven sharding + behavior is experimental and subject to change. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to distribute the + tensor, if not specified, must be called under a DeviceMesh context + manager, default: None + placements (List[:class:`Placement`], optional): the placements that + describes how to place the tensor on DeviceMesh, must have the same + number of elements as ``device_mesh.ndim``. If not specified, we will + by default replicate the tensor across the ``device_mesh`` from the + first rank of each dimension of the `device_mesh`. + + Keyword args: + src_data_rank (int, optional): the rank of the source data for the logical/global tensor, it is + used by :meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. + By default, we use ``group_rank=0`` on each DeviceMesh dimension as the source data to preserve + the single-device semantic. If passing ``None`` explicitly, :meth:`distribute_tensor` simply uses + its local data instead of trying to preserve the single-device semantic via scatter/broadcast. + Default: 0 + + Returns: + A :class:`DTensor` or ``XLAShardedTensor`` object. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_tensor`` + return `XLAShardedTensor` instead. see `this issue `__ + for more details. The XLA integration is experimental and subject to change. + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_tensor") + + # get default device mesh if there's nothing specified + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # call PyTorch/XLA SPMD for `xla` backend type device mesh. + # This returns XLAShardedTensor + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_tensor, + ) + + return xla_distribute_tensor(tensor, device_mesh, placements) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + if not tensor.is_leaf: + raise RuntimeError( + "`distribute_tensor` should be used to distribute leaf tensors! but found non-leaf tensor!" + ) + + # convert tensor to the corresponding device type if it's not in that device type + if device_type != tensor.device.type and not tensor.is_meta: + tensor = tensor.to(device_type) + + # set default placements to replicated if not specified + if placements is None: + placements = [Replicate() for _ in range(device_mesh.ndim)] + + if len(placements) != device_mesh.ndim: + raise ValueError( + f"`placements` must have the same length as `device_mesh.ndim`! " + f"Found placements length: {len(placements)}, and device_mesh.ndim: {device_mesh.ndim}." + ) + if isinstance(tensor, DTensor): + # if the tensor is already a DTensor, we need to check: + # 1. if the we can further shard this DTensor if the two device mesh belong to + # the same parenet mesh and further sharding is possible. + # 2. check if device mesh and placements are the same + if tensor.device_mesh != device_mesh: + raise ValueError( + f"Cannot distribute a DTensor with device mesh {tensor.device_mesh} " + f"to a different device mesh {device_mesh}." + ) + if tensor.placements != tuple(placements): + raise ValueError( + f"Cannot distribute a DTensor with placements {tensor.placements} " + f"to a different placements {placements}. do you want to call " + f"`redistribute` instead?" + ) + return tensor + + local_tensor = tensor.detach() + + # TODO(xilun): address sharding order + # distribute the tensor according to the placements. + placements = list(placements) + for idx, placement in enumerate(placements): + if placement.is_shard(): + placement = cast(Shard, placement) + if placement.dim < 0: + # normalize shard placement dim + placement = Shard(placement.dim + tensor.ndim) + placements[idx] = placement + local_tensor = placement._shard_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + elif placement.is_replicate(): + placement = cast(Replicate, placement) + local_tensor = placement._replicate_tensor( + local_tensor, device_mesh, idx, src_data_rank + ) + else: + raise RuntimeError( + f"Trying to distribute tensor with unsupported placements {placement} on device mesh dimension {idx}!" + ) + placements = tuple(placements) + + assert local_tensor is not None, "distributing a tensor should not be None" + # detach the local tensor passed to DTensor since after the construction + # of DTensor, autograd would work on top of DTensor instead of local tensor + spec = DTensorSpec( + mesh=device_mesh, + placements=placements, + tensor_meta=TensorMeta( + shape=tensor.size(), + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) + return DTensor( + local_tensor.requires_grad_(tensor.requires_grad), + spec, + requires_grad=tensor.requires_grad, + ) + + +@deprecated("Please use `distribute_tensor` with `src_data_rank=None` instead.") +def _shard_tensor( + full_tensor: torch.Tensor, + placements: Sequence[Shard], + device_mesh: Optional[DeviceMesh] = None, +) -> "DTensor": + """ + Locally shards a full tensor based on indicated sharding arrangement, and + returns a DTensor containing the local shard. + + .. warning:: This is a private API that is subject to change. It skips the + communication otherwise required by `distribute_tensor`. It is only + applicable to cases where all ranks have the same `full_tensor`. For + example, in distributed inference all ranks load from the same + checkpoint. This API will not check for data equality between ranks, it + is thus user's responsibility to ensure the `full_tensor` is the same + across ranks. + + Args: + full_tensor (torch.Tensor): the full tensor to be sharded. + placements (Sequence[:class:`Shard`]): the placements that + describes how to place the local tensor on DeviceMesh. + device_mesh (:class:`DeviceMesh`, optional): DeviceMesh to place the + DTensor. Must have same dimension as the number of placements. + If not specified, would be retrieve from current context. + + Returns: + A :class:`DTensor` object with the shard as its local tensor. + + Examples: + >>> # xdoctest: +SKIP("need world_size and rank") + >>> device_mesh = dist.init_device_mesh("cuda", (world_size,)) + >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}") + >>> dtensor = _shard_tensor(full_tensor, [Shard(1)], device_mesh) + """ + return distribute_tensor(full_tensor, device_mesh, placements, src_data_rank=None) + + +def distribute_module( + module: nn.Module, + device_mesh: Optional[DeviceMesh] = None, + partition_fn: Optional[Callable[[str, nn.Module, DeviceMesh], None]] = None, + input_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, + output_fn: Optional[Callable[[nn.Module, Any, DeviceMesh], None]] = None, +) -> nn.Module: + """ + This function expose three functions to control the parameters/inputs/outputs of the module: + + 1. To perform sharding on the module before runtime execution by specifying the + ``partition_fn`` (i.e. allow user to convert Module parameters to :class:`DTensor` + parameters according to the `partition_fn` specified). + 2. To control the inputs or outputs of the module during runtime execution by + specifying the ``input_fn`` and ``output_fn``. (i.e. convert the input to + :class:`DTensor`, convert the output back to ``torch.Tensor``) + + Args: + module (:class:`nn.Module`): user module to be partitioned. + device_mesh (:class:`DeviceMesh`): the device mesh to place the module. + partition_fn (Callable): the function to partition parameters (i.e. shard certain + parameters across the ``device_mesh``). If ``partition_fn`` is not specified, + by default we replicate all module parameters of ``module`` across the mesh. + input_fn (Callable): specify the input distribution, i.e. could control how the + input of the module is sharded. ``input_fn`` will be installed as a module + ``forward_pre_hook`` (pre forward hook). + output_fn (Callable): specify the output distribution, i.e. could control how the + output is sharded, or convert it back to torch.Tensor. ``output_fn`` will be + installed as a module ``forward_hook`` (post forward hook). + + Returns: + A module that contains parameters/buffers that are all ``DTensor`` s. + + .. note:: + When initialize the DeviceMesh with the ``xla`` device_type, ``distribute_module`` + return nn.Module with PyTorch/XLA SPMD annotated parameters. See + `this issue `__ + for more details. The XLA integration is experimental and subject to change. + + """ + + torch._C._log_api_usage_once("torch.dtensor.distribute_module") + + already_distributed = getattr(module, "_distribute_module_applied", False) + if already_distributed: + raise RuntimeError( + "distribute_module should only be called once on a module, " + "but it has already been called on this module!" + ) + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + device_type = device_mesh.device_type + if device_type == "xla": + try: + # This function annotates all module parameters for auto-partitioning with + # PyTorch/XLA SPMD or explicitly partition to :class:`XLAShardedTensor` parameters + # according to the `partition_fn` specified. + from torch_xla.distributed.spmd import ( # type:ignore[import] + xla_distribute_module, + ) + + return xla_distribute_module( + module, device_mesh, partition_fn, input_fn, output_fn + ) # type:ignore[return-value] + except ImportError as e: + msg = "To use DTensor API with xla, you must install the torch_xla package!" + raise ImportError(msg) from e + + def replicate_module_params_buffers(m: nn.Module, mesh: DeviceMesh) -> None: + # This function loop over the immediate module parameters and + # buffers, replicate all non DTensor params/buffers to DTensor + # parameters/buffers, if they have not been partitioned in the + # partition_fn, we can't easily use `module._apply` here + # because we don't know what happened inside partition_fn as + # user could do anything, i.e. install hooks, and we want to + # preserve those. + full_replicate = [Replicate()] * mesh.ndim + for key, param in m._parameters.items(): + if param is not None and not isinstance(param, DTensor): + m.register_parameter( + key, + nn.Parameter(distribute_tensor(param.data, mesh, full_replicate)), + ) + for key, buffer in m._buffers.items(): + if buffer is not None and not isinstance(buffer, DTensor): + m._buffers[key] = distribute_tensor(buffer, mesh, full_replicate) + + if partition_fn is None: + # if partition_fn not specified, we by default replicate + # all module params/buffers + for name, submod in module.named_modules(): + replicate_module_params_buffers(submod, device_mesh) + else: + # apply partition_fun to submodules + for name, submod in module.named_modules(): + partition_fn(name, submod, device_mesh) + replicate_module_params_buffers(submod, device_mesh) + + # register input_fn as module forward pre hook + if input_fn is not None: + # check the input_fn signature + num_args = len(inspect.signature(input_fn).parameters) + if num_args == 2: + # input_fn only takes in inputs and device mesh + warnings.warn( + "Deprecating input_fn that takes two arguments (inputs, device_mesh), " + "please use input_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_pre_hook( + lambda _, inputs: input_fn(inputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + # input_fn takes in module, inputs, device mesh + module.register_forward_pre_hook( + lambda mod, inputs: input_fn(mod, inputs, device_mesh) + ) + else: + raise ValueError( + f"input_fn should take in 3 arguments, but got {num_args} arguments!" + ) + # register output_fn as module forward hook + if output_fn is not None: + num_args = len(inspect.signature(output_fn).parameters) + if num_args == 2: + # output_fn only takes in outputs and device mesh + warnings.warn( + "Deprecating output_fn that takes two arguments (inputs, device_mesh), " + "please use output_fn that takes in (module, inputs, device_mesh) instead!", + FutureWarning, + stacklevel=2, + ) + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(outputs, device_mesh) # type: ignore[call-arg] + ) + elif num_args == 3: + module.register_forward_hook( + lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh) + ) + else: + raise ValueError( + f"output_fn should take in 3 arguments, but got {num_args} arguments!" + ) + + module._distribute_module_applied = True # type: ignore[assignment] + return module + + +# Below are tensor factory function APIs, which are used to create a DTensor directly. We need +# to make separate factory function APIs because tensor subclass could not override the tensor +# factory methods, and we need user to call the factory functions with user intended device_mesh +# and placements to create a proper DTensor. + + +def _dtensor_init_helper( # type: ignore[no-untyped-def] + init_op, + size: torch.Size, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, + **kwargs, +) -> DTensor: + # if device_mesh is None, use the one from mesh resources + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + kwargs["device"] = device_mesh.device_type + + # set default placements to replicated if not specified + placements = placements or tuple(Replicate() for _ in range(device_mesh.ndim)) + + # check device_mesh againts placements + assert device_mesh.ndim == len(placements), ( + "mesh dimension does not match the length of placements" + ) + + assert kwargs["layout"] == torch.strided, "layout value not supported!" + torch_stride = torch._prims_common.make_contiguous_strides_for(size) + + # get local tensor shape + local_shape, _ = compute_local_shape_and_global_offset( + size, device_mesh, placements + ) + + # initialize the local tensor + if init_op == torch.full: + fill_value = kwargs.pop("fill_value", 0) + local_tensor = init_op(local_shape, fill_value, **kwargs) + elif init_op == torch.rand or init_op == torch.randn: + # this tensor meta is not used except `shape` + dtype = kwargs.get("dtype", torch.get_default_dtype()) + + tensor_meta = TensorMeta(size, (0,), dtype) + spec = DTensorSpec(device_mesh, tuple(placements), tensor_meta=tensor_meta) + + if random.is_rng_supported_mesh(device_mesh) and not random._rng_tracker: + random._rng_tracker = random.OffsetBasedRNGTracker(device_mesh) + + assert random._rng_tracker is not None + with random._rng_tracker._distribute_region(spec): + local_tensor = init_op(local_shape, **kwargs) + else: + local_tensor = init_op(local_shape, **kwargs) + + spec = DTensorSpec( + device_mesh, + tuple(placements), + tensor_meta=TensorMeta( + size, + torch_stride, + local_tensor.dtype, + ), + ) + + return DTensor( + local_tensor, + spec, + requires_grad=kwargs["requires_grad"], + ) + + +def ones( # type: ignore[no-untyped-def] + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 1, with the shape defined + by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.ones, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def empty( # type: ignore[no-untyped-def] + *size, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with uninitialized data. The shape of the :class:`DTensor` + is defined by the variable argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: empty(1,2,3..) or empty([1,2,3..]) or empty((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\ + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.empty, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def full( # type: ignore[no-untyped-def] + size, + fill_value, + *, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + requires_grad: bool = False, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with ``fill_value`` according to ``device_mesh`` and + ``placements``, with the shape defined by the argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + fill_value(Scalar): the value to fill the output tensor with. + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.full, + torch_size, + fill_value=fill_value, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def rand( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a uniform distribution + on the interval ``[0, 1)``. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.rand, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def randn( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with random numbers from a normal distribution + with mean 0 and variance 1. The shape of the tensor is defined by the variable + argument ``size``. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: ones(1,2,3..) or ones([1,2,3..]) or ones((1,2,3..)) + + Keyword args: + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned DTensor. + Default: ``torch.strided``. + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks. + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.randn, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) + + +def zeros( # type: ignore[no-untyped-def] + *size, + requires_grad: bool = False, + dtype: Optional[torch.dtype] = None, + layout: torch.layout = torch.strided, + device_mesh: Optional[DeviceMesh] = None, + placements: Optional[Sequence[Placement]] = None, +) -> DTensor: + """ + Returns a :class:`DTensor` filled with the scalar value 0. + + Args: + size (int...): a sequence of integers defining the shape of the output :class:`DTensor`. + Can be a variable number of arguments or a collection like a list or tuple. + E.g.: zeros(1,2,3..) or zeros([1,2,3..]) or zeros((1,2,3..)) + Keyword args: + requires_grad (bool, optional): If autograd should record operations on the + returned :class:`DTensor`. Default: ``False``. + dtype (:class:`torch.dtype`, optional): the desired data type of returned :class:`DTensor`. + Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`). + layout (:class:`torch.layout`, optional): the desired layout of returned :class:`DTensor`. + Default: ``torch.strided``. + device_mesh: :class:`DeviceMesh` type, contains the mesh info of ranks + placements: a sequence of :class:`Placement` type: ``Shard``, ``Replicate`` + + Returns: + A :class:`DTensor` object on each rank + """ + torch_size = normalize_to_torch_size(size) + + return _dtensor_init_helper( + torch.zeros, + torch_size, + dtype=dtype, + layout=layout, + requires_grad=requires_grad, + device_mesh=device_mesh, + placements=placements, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_collective_utils.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_collective_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a1d2d22cf1cb927143c75f0ae128e12c5cf6ce00 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_collective_utils.py @@ -0,0 +1,379 @@ +# mypy: allow-untyped-defs +import logging +import math +from dataclasses import dataclass +from functools import lru_cache +from typing import Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._dtensor_spec as dtensor_spec +from torch._C._distributed_c10d import _resolve_process_group +from torch._logging import warning_once +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.distributed_c10d import ( + _get_group_size_by_name, + broadcast, + get_group_rank, + get_rank, + ProcessGroup, + scatter, + Work, +) + + +logger = logging.getLogger(__name__) + + +if not torch._running_with_deploy(): + + @torch.library.register_fake("_dtensor::shard_dim_alltoall") + def _shard_dim_alltoall_meta(input, gather_dim, shard_dim, group_name): + group_size = _get_group_size_by_name(group_name) + stacked_list = [torch.empty_like(input) for _ in range(group_size)] + group = _resolve_process_group(group_name) + group_rank = get_group_rank(group, get_rank()) + + return ( + torch.cat(stacked_list, dim=gather_dim) + .chunk(group_size, dim=shard_dim)[group_rank] + .contiguous() + ) + +else: + import warnings + + warnings.warn( + "PyTorch Distributed functional collectives do not work with torch::deploy." + ) + + +def shard_dim_alltoall(input, gather_dim, shard_dim, mesh, mesh_dim): + if mesh.device_type == "cpu": + # Gloo does not support alltoall, so falling back to allgather + chunk + warning_once( + logger, + "CPU process group does not support alltoall yet, falling back with allgather + chunk!", + ) + out = funcol.all_gather_tensor(input, gather_dim, (mesh, mesh_dim)) + if isinstance(out, funcol.AsyncCollectiveTensor): + # stick to the same behavior for the alltoall case, remove this once we enable alltoall async + out = out.wait() + out = torch.chunk(out, mesh.size(mesh_dim), dim=shard_dim)[ + mesh.get_local_rank(mesh_dim) + ] + return out.contiguous() + + group_name = funcol._resolve_group_name((mesh, mesh_dim)) + # TODO: enable async op for shard_dim_alltoall + return torch.ops._dtensor.shard_dim_alltoall( + input, gather_dim, shard_dim, group_name + ) + + +def mesh_scatter( + output: torch.Tensor, + scatter_list: list[torch.Tensor], + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, + *, + group_src: int = 0, +) -> Optional[Work]: + """ + scatter a list of tensors to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we scatter on mesh_dim = 1, we will + scatter the tensor list on rank 0 to rank 0/1, and tensor list on rank + 2 to rank 2/3. + + Args: + output (torch.Tensor): the tensor to receive the scattered list. + scatter_list (List[torch.Tensor]): the tensor list to be scattered. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Keyword args: + group_src (int, optional): the group rank of the source data for the + logical/global tensor, on the specific mesh dimension. By default, we + use ``group_rank=0`` on each DeviceMesh dimension as the source data + to preserve the single-device semantic. If passing ``None`` explicitly, + this method simply uses its local data with no communication. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. + if output.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + if group_src == get_rank(dim_group): + fut = scatter( + output, + scatter_list=scatter_list, + group=dim_group, + async_op=async_op, + group_src=group_src, + ) + else: + fut = scatter( + output, + scatter_list=None, + group=dim_group, + async_op=async_op, + group_src=group_src, + ) + + return fut + + +def mesh_broadcast( + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int = 0, + async_op: bool = False, + *, + group_src: int = 0, +) -> Optional[Work]: + """ + broadcast the tensor to a device mesh dimension. We by default + use the first rank of the mesh dimension as the source of truth, i.e + for a 2d mesh [[0, 1], [2, 3]], if we broadcast on mesh_dim = 1, we will + broadcast the tensor on rank 0 to rank 0/1, and tensor on rank 2 + to rank 2/3. + + Args: + tensor (torch.Tensor): tensor to broadcast. + mesh_dim (int, optional): indicate which mesh dimension we want + to scatter on, we by default choose the first rank on the + mesh dimension as source of truth. + + Keyword args: + group_src (int, optional): the group rank of the source data for the + logical/global tensor, on the specific mesh dimension. By default, we + use ``group_rank=0`` on each DeviceMesh dimension as the source data + to preserve the single-device semantic. If passing ``None`` explicitly, + this method simply uses its local data with no communication. + + Returns: + A :class:`Work` object + """ + # TODO: Ideally we should use the meta tensor way + # (to register a meta kernel for the collective op) + # so that it would avoid the communication. Need to + # remove the check below once that is done. + if tensor.is_meta: + return None + dim_group = mesh.get_group(mesh_dim) + assert isinstance(dim_group, ProcessGroup) + + return broadcast(tensor, group=dim_group, async_op=async_op, group_src=group_src) + + +def pad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + pad = [0, 0] * (tensor.ndim - pad_dim) + pad[-1] = pad_size + return torch.nn.functional.pad(tensor, pad) + + +def unpad_tensor(tensor: torch.Tensor, pad_dim: int, pad_size: int) -> torch.Tensor: + if pad_size == 0: + return tensor + return tensor.narrow( + pad_dim, + start=0, + length=tensor.size(pad_dim) - pad_size, + ) + + +def fill_empty_tensor_to_shards( + shards: list[torch.Tensor], shard_dim: int, num_empty_tensors: int +) -> list[torch.Tensor]: + if num_empty_tensors == 0: + return shards + tensor_size = list(shards[0].size()) + tensor_size[shard_dim] = 0 + tensor = shards[0].new_zeros(tensor_size) + shards.extend(tensor for _ in range(num_empty_tensors)) + return shards + + +def check_tensor_meta( + local_tensor, check_shape_stride=False +) -> Optional["dtensor_spec.TensorMeta"]: + local_metadata = { + "dtype": local_tensor.dtype, + "requires_grad": local_tensor.requires_grad, + } + + if check_shape_stride: + local_metadata.update( + {"shape": local_tensor.shape, "stride": local_tensor.stride()} + ) + + gathered_metadata = [None for _ in range(torch.distributed.get_world_size())] + torch.distributed.all_gather_object(gathered_metadata, local_metadata) + + # Check if metadata is consistent across ranks + if not all(meta == local_metadata for meta in gathered_metadata): + raise ValueError( + "Inconsistent tensor metadata (including shape and stride) across ranks." + ) + return None + + +def spec_to_bytes(spec: "dtensor_spec.DTensorSpec") -> int: + assert spec.tensor_meta is not None, "spec should have tensor meta defined!" + return spec.tensor_meta.dtype.itemsize * math.prod(spec.shape) + + +@dataclass +class MeshTopoInfo: + """ + Mesh information for collective cost estimation + """ + + mesh: DeviceMesh + mesh_dim_devices: list[int] + mesh_dim_bandwidth: list[float] + mesh_dim_latency: list[float] + + @staticmethod + @lru_cache(None) + def build_from_mesh(mesh: DeviceMesh) -> "MeshTopoInfo": + # Generate mesh topology info for intra-host/inter-host communication pattern + # Note that we made bunch of assumptions for simplicity: + # 1. we assume the mesh is homogeneous, and it's gpu/nccl model + # 2. we assume gpu arch is Ampere or Hopper + # 3. we assume collectives are all ring base algo for now + num_devices_per_host = _mesh_resources.num_devices_per_host(mesh.device_type) + # the base bw number (intra-node), GB/s + base_bw = 87.7 + mesh_dim_bandwidth = [base_bw] * mesh.ndim + # the latency in terms of us (intra-node, nv-link) + mesh_dim_latency = [0.6] * mesh.ndim + mesh_dim_devices = [1] * mesh.ndim + + total_num_devices = 1 + for mesh_dim in reversed(range(mesh.ndim)): + num_devices = mesh.size(mesh_dim) + mesh_dim_devices[mesh_dim] = num_devices + total_num_devices *= num_devices + if total_num_devices > num_devices_per_host: + # magic number for inter-host communication bandwidth/latency factor + # This number assumes latest GPU arch, i.e. Ampere or Hopper + # TODO: see if we need to tweak this or offer a way for user + # to specify the bandwidths/latency + mesh_dim_bandwidth[mesh_dim] *= 0.22 + # set to ethernet latency for inter-host + mesh_dim_latency[mesh_dim] = 2.7 + + return MeshTopoInfo( + mesh, mesh_dim_devices, mesh_dim_bandwidth, mesh_dim_latency + ) + + +def allgather_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + num_hops = num_devices_on_mesh_dim - 1 + # base latency + comm latency + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] # us + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth # s + return latency + bw * 1e6 # rescale to us + + +def allreduce_cost(bytes_gb: float, mesh_topo: MeshTopoInfo, mesh_dim: int) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + # allreduce have almost 2x comm bytes compare to allgather/reduce_scatter + num_hops = 2 * (num_devices_on_mesh_dim - 1) + + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth + return latency + bw * 1e6 + + +def reduce_scatter_cost( + bytes_gb: float, + mesh_topo: MeshTopoInfo, + mesh_dim: int, +) -> float: + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[mesh_dim] + mesh_dim_bandwidth = mesh_topo.mesh_dim_bandwidth[mesh_dim] + num_hops = num_devices_on_mesh_dim - 1 + # base latency + comm latency + latency = 6.6 + num_hops * mesh_topo.mesh_dim_latency[mesh_dim] + bw = (bytes_gb * num_hops / num_devices_on_mesh_dim) / mesh_dim_bandwidth + return latency + bw * 1e6 + + +def redistribute_cost( + current_spec: "dtensor_spec.DTensorSpec", + target_spec: "dtensor_spec.DTensorSpec", +) -> float: + """ + This function returns the cost of redistribute from current to target DTensorSpec. + + NOTE: + 1. Only consider communication cost here, since computation costs for redistribute + are quite trival (i.e. we only need to narrow or simple division) + 2. Only consider redistribute cost on same mesh, cross mesh communication cost is + not quite needed for operator strategy estimation/selection. + """ + if current_spec.mesh != target_spec.mesh: + # make infinite cost if meshes are not same + # TODO: see if we want to support this once there's cross mesh communication + return float("inf") + + if current_spec.is_replicated(): + # short-cut: + # comm cost is 0 if current spec is already full replication + return 0.0 + + mesh_topo = MeshTopoInfo.build_from_mesh(current_spec.mesh) + cost = 0.0 + comm_bytes_gb = ( + spec_to_bytes(current_spec) / current_spec.num_shards / 1024 / 1024 / 1024 + ) + # Transformation that considered for redistribute cost: + # 1. allgather 2. alltoall + # 3. allreduce 4. reduce_scatter + for i, (current, target) in enumerate( + zip(current_spec.placements, target_spec.placements) + ): + if current == target: + continue + + num_devices_on_mesh_dim = mesh_topo.mesh_dim_devices[i] + if current.is_shard() and target.is_replicate(): + # allgather gives larger comm bytes + comm_bytes_gb *= num_devices_on_mesh_dim + # add up allgather comm cost + cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + elif current.is_shard() and target.is_shard(): + # should be alltoall comm, since we haven't implement it yet, add penalty + # to favor allgather instead + cost += allgather_cost(comm_bytes_gb, mesh_topo, i) + 1.0 + elif current.is_partial() and target.is_replicate(): + # add up allreduce comm cost + cost += allreduce_cost(comm_bytes_gb, mesh_topo, i) + elif current.is_partial() and target.is_shard(): + # add up reduce_scatter comm cost + cost += reduce_scatter_cost(comm_bytes_gb, mesh_topo, i) + # after reduce_scatter the comm bytes for further collectives halved. + comm_bytes_gb /= num_devices_on_mesh_dim + elif current.is_shard() and target.is_partial(): + # ban shard -> partial as it does not make sense to perform + # this redistribute + return float("inf") + + return cost diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_dispatch.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..be106958e73c6c60b3122044c3d7d67256f1dfe6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_dispatch.py @@ -0,0 +1,456 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import functools +import logging +import operator +import warnings +from collections.abc import Sequence +from typing import cast, Optional + +import torch +import torch.distributed as dist +import torch.distributed.tensor._api as dtensor +import torch.distributed.tensor._random as random +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpInfo, OpSchema, OutputSpecType +from torch.distributed.tensor._random import is_rng_supported_mesh +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor._sharding_prop import ShardingPropagator +from torch.distributed.tensor._tp_conv import ( + convolution_backward_handler, + convolution_handler, +) +from torch.distributed.tensor._utils import try_find_mesh_from_args +from torch.distributed.tensor.placement_types import Partial, Placement, Replicate + + +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +def is_same_size_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> bool: + lhs = cast(torch.Tensor, args[0]) + rhs = cast(torch.Tensor, args[1]) + return lhs.shape == rhs.shape + + +def found_inf_reduce_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> None: + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + local_tensor_args = pytree.tree_unflatten( + cast(list[object], op_info.local_args), + op_info.args_tree_spec, # type: ignore[arg-type] + ) + local_tensor_args = cast(tuple[object, ...], local_tensor_args) + op_call(*local_tensor_args, **op_info.local_kwargs) + + grad_dtensor = cast(list[dtensor.DTensor], args[0])[0] + grad_placements = grad_dtensor.placements + mesh = grad_dtensor.device_mesh + + found_inf_placements: list[Placement] = [] + for placement in grad_placements: + if isinstance(placement, Replicate): + found_inf_placements.append(placement) + else: + found_inf_placements.append(Partial("max")) + + target_tensor = cast(torch.Tensor, args[1]) + spec = DTensorSpec( + mesh=mesh, + placements=tuple(found_inf_placements), + tensor_meta=TensorMeta( + shape=target_tensor.size(), + stride=target_tensor.stride(), + dtype=target_tensor.dtype, + ), + ) + found_inf_dtensor = dtensor.DTensor( + local_tensor=target_tensor, spec=spec, requires_grad=False + ) + found_inf = found_inf_dtensor.full_tensor() + target_tensor.copy_(found_inf) + + +class OpDispatcher: + """ + Op dispatching class instance to handle args/kwargs pre-processing (un-wrapping), sharding + propagation, redistribute local args, local compute, and post-processing (re-wrapping). It + also handles any op specific logic if necessary. + + NOTE: Given the runtime overhead of Tensor subclass (__torch_dispatch__), the OpDispatcher + is designed to minimize the CPU overhead by using the tricks of proper unflattening, faster + pytree if needed, and leveraging various caching mechanisms implemented in the sharding + propagation and redistribute modules. The CPU overhead is critical to eager mode performance, + one need to carefully measure the CPU overhead when making significant changes to the + OpDispatcher and ShardingPropagator. + """ + + def __init__(self) -> None: + self.sharding_propagator = ShardingPropagator() + self._random_ops = { + aten.native_dropout.default, + aten.normal_.default, + aten.rand_like.default, + aten.randn_like.default, + aten.randint_like.default, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + aten.uniform_.default, + aten.bernoulli.default, + aten.bernoulli_.float, + } + self._custom_op_handlers = { + aten.is_same_size.default: is_same_size_handler, + aten.convolution.default: convolution_handler, + aten.convolution_backward.default: convolution_backward_handler, + aten._amp_foreach_non_finite_check_and_unscale_.default: found_inf_reduce_handler, + } + + # This flag is used internally to control whether we treat the torch.Tensor(non-DTensor) + # as implicitly replicated or we throw error to user. + # NOTE: It is EXTREMELY UNSAFE to turn this flag on by default so we intentionally leave + # it as False by default. + self._allow_implicit_replication = False + + def dispatch( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + """ + Main dispatching logic + """ + # operators that does not need to go through sharding propagation + if torch._C._dispatch_has_kernel_for_dispatch_key( + op_call.name(), torch._C.DispatchKey.CompositeImplicitAutograd + ): + # When running under inference mode, CompositeImplicitAutograd ops show up in __torch_dispatch__, + # so we manually decompose them, here + out = op_call.decompose(*args, **kwargs) + assert out is not NotImplemented + return out + if op_call in self._custom_op_handlers: + return self._custom_op_handlers[op_call](op_call, args, kwargs) # type: ignore[operator] + + # extract local tensor and sharding infos to a OpInfo + op_info = self.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + self.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + logger.debug("output_sharding for %s: %s", op_call, output_sharding) + assert output_sharding is not None, "output sharding should not be None" + + mesh = op_info.compute_mesh + if mesh.get_coordinate() is not None: + # computation that happens in the current rank of the mesh, normal case + if output_sharding.needs_redistribute: + # If sharding propagation decision needs redistribute, perform redistribute + # on args first, which could potentially modify args (i.e. allgather certain arg) + assert output_sharding.redistribute_schema is not None + self.redistribute_local_args( + op_info, output_sharding.redistribute_schema + ) + + local_tensor_args = ( + pytree.tree_unflatten( + cast(list[object], op_info.local_args), op_info.args_tree_spec + ) + if op_info.args_tree_spec + else op_info.local_args + ) + + # run local op computation with potentially modified args/kwargs + local_tensor_args = cast(tuple[object, ...], local_tensor_args) + if op_call in self._random_ops: + if not random._rng_tracker and is_rng_supported_mesh(mesh): + # Default to `OffsetBasedRNGTracker` if the parallelism API + # did not already construct one + random._rng_tracker = random.OffsetBasedRNGTracker(mesh) + + first_arg, first_local_arg = ( + cast(dtensor.DTensor, args[0]), + cast(torch.Tensor, local_tensor_args[0]), + ) + rng_context = ( + random._rng_tracker._distribute_region(first_arg._spec) + if random._rng_tracker and not first_local_arg.is_meta + else contextlib.nullcontext() + ) + # For DTensor random operator, run it within a RNGTracker context to + # ensure the random number generator is properly distributed. + with rng_context: + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + else: + # normal case, run local sharded op computation + local_results = op_call(*local_tensor_args, **op_info.local_kwargs) + + else: + # For a non-participating device (happens on rank that does not belong to + # the device mesh), we do: + # 1. if the return type is scalar, set the local result to None. + # 2. if the return type is Tensor or List[Tensor], return empty + # tensor(s) with correct dtype. + spec = output_sharding.output_spec + ret_list = op_info.schema.op._schema.returns + + if spec is None: + # For a scalar return type, the non-participating device has None + # as its local result + local_results = None + else: + + def default_tensor(spec: DTensorSpec) -> torch.Tensor: + if spec.tensor_meta is not None: + shape = spec.tensor_meta.shape + dtype = spec.tensor_meta.dtype + if len(shape) == 0: + # scalar tensor + return torch.zeros((), dtype=dtype) + else: + # non-scalar tensor + return torch.tensor([], dtype=dtype) + else: + raise RuntimeError(f"{spec} has no tensor metadata.") + + if isinstance(spec, DTensorSpec): + # return a Tensor value + local_results = default_tensor(spec) + elif isinstance(spec, Sequence): + # return a List[Tensor] value + local_results = [ + default_tensor(s) if s is not None else None for s in spec + ] + assert isinstance(local_results, list) + if None in local_results: + ret_type = str(ret_list[0].type) + raise NotImplementedError( + f"return type {ret_type} in DTensor op is not supported" + ) + + if output_sharding.output_spec is None: + if op_call == aten.equal.default: + # For equal operator, The local results from all devices should be all-gathered + # and a reduce op (AND) will be performed on the list of results to ensure SPMD + # execution. We can extend this for more ops if necessary. + obj_list = [None for _ in range(dist.get_world_size())] + dist.all_gather_object(obj_list, local_results) # type: ignore[possibly-undefined] + obj_list = list(filter(lambda x: x is not None, obj_list)) + # perform reduce on the collection with AND op + local_results = functools.reduce(operator.and_, obj_list, True) + + if op_info.schema.is_inplace_op(): + # inplace op should return self instead of re-wrapping + if output_sharding.output_spec is not None: + return args[0] + else: + return None + elif op_info.schema.is_out_variant_op(): + # out variant could possibly have multiple out args (i.e. lu_unpack.out) + output_specs = ( + (output_sharding.output_spec,) + if not isinstance(output_sharding.output_spec, tuple) + else output_sharding.output_spec + ) + out_dts = [] + spec_idx = 0 + for argument in op_call._schema.arguments: + if argument.is_out: + out_dt = cast(dtensor.DTensor, kwargs[argument.name]) + out_dt._spec = cast(DTensorSpec, output_specs[spec_idx]) + out_dts.append(out_dt) + spec_idx += 1 + + assert len(out_dts) >= 1, "out variant should have at least one out arg" + return tuple(out_dts) if len(out_dts) > 1 else out_dts[0] + else: + return self.wrap(local_results, output_sharding.output_spec) # type: ignore[possibly-undefined] + + @staticmethod + def redistribute_local_args( + op_info: OpInfo, + suggested_input_schema: OpSchema, + ) -> None: + # NOTE: it's very rare that we need to reshard kwargs so we intentionally skip it + if op_info.args_tree_spec is not None: + flatten_args_schema_to_reshard = tuple( + pytree.tree_leaves(suggested_input_schema.args_schema) + ) + else: + flatten_args_schema_to_reshard = suggested_input_schema.args_schema + + new_local_args: list[object] = [] + for i, arg_spec in enumerate(op_info.flat_args_schema): + reshard_arg_spec = flatten_args_schema_to_reshard[i] + if isinstance(arg_spec, DTensorSpec): + local_tensor = cast(torch.Tensor, op_info.local_args[i]) + if arg_spec != reshard_arg_spec: + resharded_local_tensor = redistribute_local_tensor( + local_tensor, arg_spec, reshard_arg_spec + ) + new_local_args.append(resharded_local_tensor) + else: + new_local_args.append(local_tensor) + else: + new_local_args.append(reshard_arg_spec) + + op_info.local_args = tuple(new_local_args) + + def unwrap_to_op_info( + self, + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> OpInfo: + # get runtime schema info to determine whether to use pytree to flatten inputs + runtime_schema_info = self.sharding_propagator.op_to_schema_info.get( + op_call, None + ) + + if runtime_schema_info is not None and runtime_schema_info.needs_pytree: + # flatten args/kwargs when op says necessary + tree_args, args_spec = pytree.tree_flatten(args) + args_list: Sequence[object] = tree_args + else: + args_list, args_spec = args, None + + args_schema: list[object] = [] + kwargs_schema: dict[str, object] = {} + local_args: list[object] = [] + local_kwargs: dict[str, object] = {} + compute_mesh: Optional[DeviceMesh] = None + + for arg in args_list: + if isinstance(arg, dtensor.DTensor): + local_args.append(arg._local_tensor) + args_schema.append(arg._spec) + if compute_mesh is None: + # record the first compute device mesh from args + compute_mesh = arg.device_mesh + elif isinstance(arg, torch.Tensor): + compute_mesh = compute_mesh or try_find_mesh_from_args( + op_call, args_list + ) + args_schema.append( + self._try_replicate_spec_for_scalar_tensor( + op_call, arg, compute_mesh + ) + ) + local_args.append(arg) + else: + # non DTensor/Tensor args (i.e. int/float/bool), just add to args_schema/local_args + args_schema.append(arg) + local_args.append(arg) + + for k, v in kwargs.items(): + if isinstance(v, dtensor.DTensor): + local_kwargs[k] = v._local_tensor + kwargs_schema[k] = v._spec + elif isinstance(v, torch.Tensor): + compute_mesh = compute_mesh or try_find_mesh_from_args( + op_call, args_list + ) + kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( + op_call, v, compute_mesh + ) + local_kwargs[k] = v + else: + # non DTensor/Tensor args (i.e. int/float/bool), just add to args_schema/local_args + kwargs_schema[k] = v + local_kwargs[k] = v + + assert compute_mesh is not None, ( + f"found no DeviceMesh from dtensor args for {op_call}!" + ) + op_info = OpInfo( + compute_mesh, + OpSchema( + op_call, + ( + pytree.tree_unflatten(args_schema, args_spec) + if args_spec + else tuple(args_schema) + ), + kwargs_schema, + schema_info=runtime_schema_info, + ), + args_schema, + tuple(local_args), + local_kwargs, + args_spec, + ) + return op_info + + @staticmethod + def wrap(res: object, spec: OutputSpecType) -> object: + if isinstance(res, torch.Tensor): + if spec is not None: + assert isinstance(spec, DTensorSpec), ( + f"output spec does not match with output! Expected DTensorSpec, got {spec}." + ) + return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) + else: + # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor + assert res.ndim == 0, "output tensor should be scalar!" + return res + elif isinstance(res, (list, tuple)): + assert spec is not None and isinstance(spec, (list, tuple)), ( + f"output spec does not match with output! Expected list/tuple, got {spec}." + ) + res_list = [] + for e, s in zip(res, spec): + res_list.append(OpDispatcher.wrap(e, s)) + + return tuple(res_list) if isinstance(res, tuple) else res_list + else: + # if the res contains only non tensor values (i.e. int/float/none), we simply return it + # without rewrapping to DTensor. + return res + + def _try_replicate_spec_for_scalar_tensor( + self, + op_call: torch._ops.OpOverload, + tensor_arg: torch.Tensor, + compute_mesh: DeviceMesh, + ) -> DTensorSpec: + # util function to produce a replicate spec for a scalar tensor arg/kwarg + if tensor_arg.numel() == 1 and tensor_arg.ndim == 1: + warnings.warn( + "Found a non-scalar tensor with numel=1 and ndim!=0, " + "we are implicitly creating a replicated DTensor for it. " + "However, please consider changing it to a scalar tensor " + "or explicitly create a DTensor under distributed enviroment." + ) + + if tensor_arg.numel() == 1 or self._allow_implicit_replication: + # scalar tensor can be safely treated as replicated + replication_spec = DTensorSpec( + compute_mesh, + (Replicate(),) * compute_mesh.ndim, + tensor_meta=TensorMeta( + shape=tensor_arg.shape, + stride=tensor_arg.stride(), + dtype=tensor_arg.dtype, + ), + ) + else: + raise RuntimeError( + f"{op_call}: got mixed torch.Tensor and DTensor, need to convert all" + " torch.Tensor to DTensor before calling distributed operators!" + ) + return replication_spec diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_dtensor_spec.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_dtensor_spec.py new file mode 100644 index 0000000000000000000000000000000000000000..65691b71e269689427886dc92c62513c45a84dbb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_dtensor_spec.py @@ -0,0 +1,276 @@ +from dataclasses import dataclass +from typing import Any, cast, NamedTuple, Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +class TensorMeta(NamedTuple): + # simple named tuple to represent tensor metadata + # intentionally to stay simple only for sharding + # propagation purposes. + shape: torch.Size + stride: tuple[int, ...] + dtype: torch.dtype + + +# used internally to propagate the placements +@dataclass +class DTensorSpec: + mesh: DeviceMesh + placements: tuple[Placement, ...] + + # tensor meta will only be set during sharding propagation + tensor_meta: Optional[TensorMeta] = None + + def __post_init__(self) -> None: + if not isinstance(self.placements, tuple): + self.placements = tuple(self.placements) + self._hash: Optional[int] = None + + def __setattr__(self, attr: str, value: Any) -> None: + super().__setattr__(attr, value) + # Make sure to recompute the hash in case any of the hashed attributes + # change (though we do not expect `mesh` or `placements` to change) + if hasattr(self, "_hash") and attr in ("mesh", "placements", "tensor_meta"): + self._hash = None + + def _hash_impl(self) -> int: + # hashing and equality check for DTensorSpec are used to cache the sharding + # propagation results. We only need to consider the mesh, placements, shape + # dtype and stride. + # Caveat: we need to keep this in mind and sync hash and eq if we add more + # fields to them. + if self.tensor_meta is not None: + return hash( + ( + self.mesh, + self.placements, + self.tensor_meta.shape, + self.tensor_meta.stride, + self.tensor_meta.dtype, + ) + ) + return hash((self.mesh, self.placements)) + + def __hash__(self) -> int: + # We lazily cache the spec to avoid recomputing the hash upon each + # use, where we make sure to update the hash when the `tensor_meta` + # changes by overriding `__setattr__`. This must be lazy so that Dynamo + # does not try to hash non-singleton `SymInt`s for the stride. + if self._hash is None: + self._hash = self._hash_impl() + return self._hash + + def __eq__(self, other: object, /) -> bool: + if not ( + isinstance(other, DTensorSpec) + and self.mesh == other.mesh + and self.placements == other.placements + ): + return False + if self.tensor_meta is None or other.tensor_meta is None: + return self.tensor_meta == other.tensor_meta + + return ( + self.tensor_meta.shape == other.tensor_meta.shape # type: ignore[union-attr] + and self.tensor_meta.stride == other.tensor_meta.stride # type: ignore[union-attr] + and self.tensor_meta.dtype == other.tensor_meta.dtype # type: ignore[union-attr] + ) + + def __str__(self) -> str: + """ + human readable representation of the DTensorSpec + """ + if len(self.placements) == 1: + placement_str = str(self.placements[0]) + else: + placement_str = str(self.placements) + + if self.tensor_meta is not None: + tensor_shape = str(tuple(self.tensor_meta.shape)) + else: + tensor_shape = "unknown shape" + + return f"Spec({placement_str} on {tensor_shape})" + + @property + def shape(self) -> torch.Size: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.shape + + @property + def stride(self) -> tuple[int, ...]: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return self.tensor_meta.stride + + @property + def ndim(self) -> int: + if self.tensor_meta is None: + raise ValueError("tensor_meta is not set") + return len(self.tensor_meta.shape) + + @property + def num_shards(self) -> int: + num_shards = 1 + for i, placement in enumerate(self.placements): + if placement.is_shard(): + num_shards *= self.mesh.size(i) + return num_shards + + @property + def device_mesh(self) -> DeviceMesh: + # simple aliasing for the mesh field, make some + # checks that mixes DTensor/DTensorSpec easier + return self.mesh + + @property + def dim_map(self) -> list[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. It simply return a list of ints + where dim_map[i] denotes the sharding mapping to the mesh + dimension, and len(dim_map) == dist_tensor.ndim + dim_map[i] = -1: means tensor dim i replicate on mesh + dim_map[i] = j: means tensor dim i shard on mesh dim j + + For example, we have a dist tensor that have the shape of + [18, 20, 30], and device_mesh([0, 1, 2, 3]), placements: + [Shard(1)], the dim_map of this placement would be: + [-1, 0, -1]. This representation is pretty helpful during + sharding propagation where we could know exactly each + tensor dimension is sharded or not. + + Note that if placements contains `_Partial`, we have to + explicitly deal with it, so that when we create a DTensorSpec + with dim_map, we could properly record the pending sums. + """ + # dims mapping of dist tensor sharding + # return size of tensor ndim, -1 represent replicate + # and int >=0 represent shard on that device mesh dim + r = [-1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if r[shard_dim] > -1: + raise ValueError( + f"Tensor dim {shard_dim} is already sharded on mesh dim {r[shard_dim]}," + " DTensor operator implementation does not support things like hybrid" + " sharding strategies yet (i.e. [Shard(0), Shard(0)])" + ) + r[shard_dim] = i + return r + + @property + def num_shards_map(self) -> list[int]: + """ + dim_map is a property we derive from `placements` of + the distributed tensor. Unlike `dim_map`, `num_shards_map` + denotes how many shards each tensor dim has. Like `dim_map`: + len(num_shards_map) == dist_tensor.ndim + num_shards_map[i] = 1: means tensor dim i is not sharded + num_shards_map[i] = j: means tensor dim i has j shards in total + + For example, we have a dist tensor of shape [18, 20, 30], + a device_mesh ([[0, 1, 2, 3], [4, 5, 6, 7]]), and placements + ([Shard(1), Shard(0)]), the num_shards_map of this distributed tensor + would be: [4, 2, 1]. + """ + r = [1] * self.ndim + for i, placement in enumerate(self.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + r[shard_dim] *= self.mesh.size(i) + + return r + + @property + def sums(self) -> list[int]: + """ + sums is a property we derive from `placements` of the + distributed tensor. It simply return a list of ints where + sums[i] denotes the pending sum (partial) on mesh dim i + """ + return [ + idx + for idx, placement in enumerate(self.placements) + if placement.is_partial() + ] + + @classmethod + def from_dim_map( + cls, + mesh: DeviceMesh, + dim_map: list[int], + sums: list[int], + tensor_meta: Optional[TensorMeta] = None, + ) -> "DTensorSpec": + """ + Construct a DTensorSpec from dim_map list and pending sum. + + Args: + mesh (class:`DeviceMesh`): device mesh to be used in the DTensorSpec + dim_map (List[int]): a list of integer that represents sharding on each + tensor dimension, see `dim_map` property doc for details + sums (List[int]): a list of integer that represents the dist tensor have + pending sum on which device mesh dimension. + tensor meta (TensorMeta): DTensor metadata + + Return: + a class:`DTensorSpec` object + """ + # by default replicate on device mesh dims + placements: list[Placement] = [Replicate() for _ in range(mesh.ndim)] + + # find all mesh dims that need pending reductions + for s in sums: + placements[s] = Partial() + + for i, m in enumerate(dim_map): + if m >= 0: + placement = placements[m] + if placement.is_shard(): + placement = cast(Shard, placement) + raise RuntimeError( + f"DeviceMesh dimension cann't be mapped to two dimension of the same tensor: {i} and {placement.dim}" + ) + elif placement.is_partial(): + raise RuntimeError( + f"DeviceMesh dimension {m} cannot be both shard and partial!" + ) + placements[m] = Shard(i) + + return cls(mesh, tuple(placements), tensor_meta=tensor_meta) + + def is_replicated(self) -> bool: + """ + return True if the current DTensorSpec replicates on all mesh dims (devices) + """ + return all(placement.is_replicate() for placement in self.placements) + + def is_sharded(self) -> bool: + """ + return True if the current DTensorSpec is sharded on any mesh dims (devices) + """ + return any(placement.is_shard() for placement in self.placements) + + def shallow_copy_with_tensor_meta( + self, tensor_meta: Optional[TensorMeta] + ) -> "DTensorSpec": + """ + Shallow copy the DTensorSpec with a new tensor_meta. + """ + assert tensor_meta is not None, "shallow copy with no tensor_meta!" + return DTensorSpec( + self.mesh, + self.placements, + tensor_meta=tensor_meta, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_op_schema.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_op_schema.py new file mode 100644 index 0000000000000000000000000000000000000000..322b857753ca511f14fb45241d7f2fba8e13eec8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_op_schema.py @@ -0,0 +1,532 @@ +# mypy: allow-untyped-defs +from collections.abc import Sequence +from dataclasses import dataclass +from functools import cached_property +from typing import Any, Optional, Union + +import torch +from torch._ops import OpOverload +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Placement + + +try: + from torch.utils._cxx_pytree import tree_leaves, tree_map_only, TreeSpec +except ImportError: + from torch.utils._pytree import ( # type: ignore[no-redef, assignment] + tree_leaves, + tree_map_only, + TreeSpec, + ) + + +# Common type aliases +ArgsType = tuple[object, ...] +KwargsType = dict[str, object] + +PlacementList = list[Optional[Placement]] + +# ATen op schemas could have Tensor, Tuple[Tensor] and List[Tensor], so output type sould +# be the same set of possibilities. +OutputSpecType = Optional[Union[DTensorSpec, Sequence[Optional[DTensorSpec]]]] + + +def _rebuild_tensor_from_dtensor_meta(arg) -> object: + """ + This is used to propagate tensor metadata, must be under fake mode + """ + assert arg.tensor_meta is not None, "DTensorSpec does not contain tensor_meta." + return torch.empty_strided( + arg.tensor_meta.shape, + arg.tensor_meta.stride, + dtype=arg.tensor_meta.dtype, + ) + + +def _pretty_print_spec(spec: object) -> str: + if spec is None: + return "None" + elif isinstance(spec, DTensorSpec): + return "".join([str(p) for p in spec.placements]) + elif isinstance(spec, Sequence): + return "(" + ", ".join([_pretty_print_spec(s) for s in spec]) + ")" + else: + raise RuntimeError(f"Unknown spec type to print: spec={spec}") + + +@dataclass +class OpSpec: + """ + An OpSpec describes an acceptable sharding placements of an operation, with the + specified DTensorSpecs for both the output and the inputs. + + note: when the op return value is a single DTensor object, output_specs is + DTensorSpec; when the return value is a tuple of Optional[DTensor], + output_specs is a tuple of Optional[DTensorSpec]. + """ + + output_specs: Union[DTensorSpec, tuple[Optional[DTensorSpec], ...]] + input_specs: Optional[Sequence[DTensorSpec]] = None + + # redistribute costs to redistribute the operator input shardings to this OpSpec. + # Note that We need a nested list to record the cost for each operand of this + # operator, and for each operand of this operator it might have multiple OpSpecs. + redistribute_cost: Optional[list[list[float]]] = None + + @cached_property + def output_spec(self) -> DTensorSpec: + """ + This function requires that the strategy have exactly one DTensorSpec as the + output spec. If the output_specs is a tuple, we throw an exception. + """ + if isinstance(self.output_specs, DTensorSpec): + return self.output_specs + else: + raise ValueError( + f"function output_spec expects a single DTensorSpec but got: {self.output_specs}" + ) + + @cached_property + def mesh(self): + if isinstance(self.output_specs, DTensorSpec): + return self.output_specs.mesh + elif isinstance(self.output_specs, tuple): + out_spec = self.output_specs[0] + assert isinstance(out_spec, DTensorSpec) + return out_spec.mesh + else: + raise ValueError( + f"function output_spec expects a single DTensorSpec or a tuple of DTensorSpec but got: {self.output_specs}" + ) + + def input_spec(self, index: int = 0) -> DTensorSpec: + assert self.input_specs is not None, "input_specs of OpSpec is None!" + assert len(self.input_specs) > index, ( + f"Invalid index {index} for input_specs of length " + f"{len(self.input_specs)}: {self.input_specs}" + ) + return self.input_specs[index] + + def __str__(self) -> str: + if self.input_specs is not None: + input_specs_str = f"{_pretty_print_spec(self.input_specs)} -> " + else: + input_specs_str = "" + output_spec_str = _pretty_print_spec(self.output_specs) + return f"{input_specs_str}{output_spec_str}" + + +class StrategyType: + """ + Base class type for op strategy, We have two StrategyType: + OpStrategy and TupleStrategy + """ + + +class OpStrategy(StrategyType): + """ + OpStrategy that consists of a list of sharding strategies associated with the op, + where each strategy is an OpSpec that describes the acceptable input/output sharding. + """ + + def __init__(self, strategies: list[OpSpec]) -> None: + super().__init__() + self.strategies: list[OpSpec] = strategies + + def __str__(self) -> str: + strategy_list_str = ", ".join([str(strategy) for strategy in self.strategies]) + mesh_shape = self.mesh_shape + return f"[{strategy_list_str}] @ mesh: {mesh_shape}" + + def max_num_shards(self) -> int: + """ + Returns the max number of shards across all OpSpecs + """ + return max(strategy.output_spec.num_shards for strategy in self.strategies) + + @property + def mesh(self): + return self.strategies[0].mesh + + @property + def mesh_shape(self): + return self.strategies[0].mesh.shape + + @property + def ndim(self): + return self.strategies[0].output_spec.ndim + + @property + def shape(self): + return self.strategies[0].output_spec.shape + + +class TupleStrategy(StrategyType): + """ + TupleStrategy represents the output strategy of this op is a tuple of OpStrategies, + i.e. If the output of this op is a tuple of tensors or list of tensors with possibly + different OpStrategies, we should return a TupleStrategy that contains a tuple of + OpStrategy, where each child represents the sharding strategy of "each element" of + the tuple/list of tensors the op returns. + + NOTE: if the output of the op is a List[Tensor] and they share the same OpStrategy, + then we should return a single OpStrategy instead of a TupleStrategy + """ + + def __init__(self, childs: Sequence[StrategyType]) -> None: + super().__init__() + self.childs: Sequence[StrategyType] = childs + + def child_mesh(self, index: int) -> DeviceMesh: + op_strategy = self.childs[index] + assert isinstance(op_strategy, OpStrategy) + return op_strategy.mesh + + def __str__(self) -> str: + child_strategies_str = ", ".join( + [f"{str(strat)}" for idx, strat in enumerate(self.childs)] + ) + return f"TupleStrategy({child_strategies_str})" + + +@dataclass +class RuntimeSchemaInfo: + """ + RuntimeSchemaInfo stores the operator schema related information for runtime (eager) + execution. This is mainly used for two ways: 1. to generate hash for args to determine + whether to re-run sharding prop or not 2. to determine if we need pytree + """ + + # This static_argnum records static arg "starting index" for ops that have non-tensor + # args/kwargs which would affect sharding propagation results. All args starting from + # this index would be hashed to our sharding cache. + # Note that only a few ops need this information, e.g. view, transpose, var.dim, etc. + static_argnum: int = 100 + # This static_kwargkey records static kwarg names which would affect sharding prop + static_kwargkey: Optional[list[str]] = None + # each op can decide if it wants to use pytree flatten/unflatten during operator + # eager execution, by default we don't need to do flatten/unflatten, only if the + # op indicate it needs to, this is to accelerate eager performance. + needs_pytree: bool = False + + +@dataclass +class OpSchema: + """ + OpSchema is a data class that describes an operator input schemas, it includes + DTensorSpecs/OpStrategies (instead of DTensor) and non-tensor args/kwargs (positional + order preserved). It is mainly used by the DTensor's dispatching logic to perform various + actions (i.e. sharding propagation, caching sharding decisions, redistribute, etc.) + + NOTE: this should be used as a read only data class + TODO: make this a frozen dataclass + + Args: + op: the operator overload we are intercepting + args_schema: contains args except that the DTensor args have been replaced + with its DTensorSpec or OpStrategy + kwargs_schema: contains kwargs except that the DTensor kwargs have been replaced + with its DTensorSpec or OpStrategy + """ + + op: OpOverload + args_schema: ArgsType + kwargs_schema: KwargsType + + schema_info: Optional[RuntimeSchemaInfo] = None + + @property + def args_spec(self) -> tuple[DTensorSpec, ...]: + """ + args_spec: Tuple[DTensorSpec, ...]: contains a clean list of args spec list + with NO non-DTensor positional arguments (i.e. int/float/tuple, etc) + mainly used by sharding propagation to propagate the output spec + """ + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, DTensorSpec)) + + @property + def args_strategy(self) -> tuple[OpStrategy, ...]: + # filter out non-relevant values from args schema to get a clean OpStrategy list + # separate with args_spec for the ease of type annotation + # TODO: see if we should merge this with args_spec + args = ( + tree_leaves(self.args_schema) + if self.schema_info is not None and self.schema_info.needs_pytree + else self.args_schema + ) + return tuple(item for item in args if isinstance(item, OpStrategy)) + + def __repr__(self) -> str: + args_schema = ", ".join([str(arg_schema) for arg_schema in self.args_schema]) + return ( + f"OpSchema(op={self.op}," + f" args_schema=({args_schema})," + f" kwargs_schema={self.kwargs_schema})" + ) + + def __str__(self) -> str: + args_schema: list[str] = [] + mesh_shape = None + for arg in self.args_schema: + if isinstance(arg, DTensorSpec): + args_schema.append(str(arg)) + mesh_shape = arg.mesh.shape + elif isinstance(arg, OpStrategy): + assert len(arg.strategies) == 1 + args_schema.append(_pretty_print_spec(arg.strategies[0].output_specs)) + mesh_shape = arg.mesh_shape + elif isinstance(arg, TupleStrategy): + first_op_strategy = arg.childs[0] + assert isinstance(first_op_strategy, OpStrategy) + mesh_shape = first_op_strategy.mesh_shape + args_schema.append(str(arg)) + else: + args_schema.append(str(arg)) + return f"Op(op={self.op}, args_schema={', '.join(args_schema)} @ mesh: {mesh_shape})" + + def __post_init__(self) -> None: + has_symints = False + for a in self.args_schema: + if isinstance(a, DTensorSpec) and a.tensor_meta is not None: + if any(isinstance(s, torch.SymInt) for s in a.tensor_meta.shape): + has_symints = True + break + self.has_symints = has_symints + + def arg_type_tensor_or_tensor_list_like(self, arg_idx: int) -> bool: + arg = self.args_schema[arg_idx] + is_tensor = isinstance(arg, DTensorSpec) + if is_tensor: + return True + + if not isinstance(arg, list): + return False + + return all(isinstance(e, DTensorSpec) or e is None for e in arg) + + def return_type_tuple_tensor_like(self) -> bool: + # all dispatch ops could only return Tuple[Tensor] or have None/ints/floats + # in the tuple, but the first element must be a Tensor, so this check is enough + return_types = self.op._schema.returns + return len(return_types) > 1 and isinstance( + return_types[0].type, torch.TensorType + ) + + def return_type_tensor(self) -> bool: + return_types = self.op._schema.returns + # all dispatch ops only return Tensor or Tuple[Tensor] for tensor like + # return types, so this check is enough for tensor like types + return isinstance(return_types[0].type, torch.TensorType) + + def get_mesh_from_args(self, validate: bool = True) -> DeviceMesh: + """ + This util can be used to get a mesh from the OpSchema that contains multiple + DTensors as arguments. When `validate` is True, it will try to validate that all the + arguments have the same mesh to avoid unexpected cross mesh errors. + + NOTE: this util currently does not handle TupleStrategy when `validate=True`, + this is because for TupleStrategy there could be different types of checks, i.e.: + - for stack and cat like op, we need to check within a TupleStrategy is every + input is on the same mesh + - for foreach like ops we need to check "zipped" inputs are on the same mesh + for each index. + """ + first_arg = self.args_schema[0] + if isinstance(first_arg, (DTensorSpec, OpStrategy)): + mesh = first_arg.mesh + elif isinstance(first_arg, (list, tuple, TupleStrategy)): + first_elem = ( + first_arg.childs[0] + if isinstance(first_arg, TupleStrategy) + else first_arg[0] + ) + assert isinstance(first_elem, (DTensorSpec, OpStrategy)) + mesh = first_elem.mesh + else: + raise ValueError(f"Cannot find device mesh from args for op : {self.op}.") + + if validate: + for arg in self.args_schema[1:]: + if isinstance(arg, (DTensorSpec, OpStrategy)) and arg.mesh != mesh: + raise RuntimeError( + f"DTensor does not support cross-mesh operation on {self.op}! " + f"Got meshes: {mesh} {arg.mesh}. " + f"Please make sure all the arguments have the same DeviceMesh." + ) + + return mesh + + def is_inplace_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an inplace variant, it might not + # be entirely correct, but it's good enough for now. + return self.op._schema.name[-1] == "_" + + def is_out_variant_op(self) -> bool: + # simple analysis of function schema to determine + # if this is an out variant, it might not + # be entirely correct, but it's good enough for now. + return "out" in self.op._schema.overload_name + + def __hash__(self) -> int: + # Only hash args and kwargs that op indicates to hash + if not self.schema_info: + static_argnum = len(self.args_schema) + static_kwargkey = None + else: + static_argnum = self.schema_info.static_argnum + static_kwargkey = self.schema_info.static_kwargkey + + args_to_hash = tuple( + tuple(e) if isinstance(e, list) else e + for i, e in enumerate(self.args_schema) + if self.arg_type_tensor_or_tensor_list_like(i) or i >= static_argnum + ) + if static_kwargkey is not None: + kwargs_to_hash = tuple( + self.kwargs_schema.get(k, None) for k in static_kwargkey + ) + return hash((self.op, args_to_hash, kwargs_to_hash)) + else: + return hash((self.op, args_to_hash)) + + def __eq__(self, other: object) -> bool: + # early return checks + if not isinstance(other, OpSchema): + return False + + if self.op != other.op: + return False + + if len(self.args_schema) != len(other.args_schema): + return False + + # compare each element and early return if any of them is different + if not self.schema_info: + static_argnum = len(self.args_schema) + static_kwargkey = None + else: + static_argnum = self.schema_info.static_argnum + static_kwargkey = self.schema_info.static_kwargkey + + for i, (self_arg, other_arg) in enumerate( + zip(self.args_schema, other.args_schema) + ): + if isinstance(self_arg, DTensorSpec) and self_arg != other_arg: + return False + elif i >= static_argnum and self_arg != other_arg: + return False + + # check kwarg equality when there's a static kwarg key + if static_kwargkey: + for key in static_kwargkey: + if self.kwargs_schema.get(key, None) != other.kwargs_schema.get( + key, None + ): + return False + + return True + + def gen_fake_args(self) -> ArgsType: + """ + gen_fake_args: generate fake args for the operator, this is mainly used + by sharding propagation rules to generate fake args for the operator + to run the local tensor operator and get the output spec. + """ + return tree_map_only( + DTensorSpec, + _rebuild_tensor_from_dtensor_meta, + self.args_schema, + is_leaf=lambda x: isinstance(x, DTensorSpec), + ) + + def gen_fake_kwargs(self) -> KwargsType: + """ + gen_fake_kwargs: generate fake kwargs for the operator, this is mainly used + by sharding propagation rules to generate fake kwargs for the operator + to run the local tensor operator and get the output spec. + """ + return tree_map_only( + DTensorSpec, + _rebuild_tensor_from_dtensor_meta, + self.kwargs_schema, + is_leaf=lambda x: isinstance(x, DTensorSpec), + ) + + def _inplace_rewrap_schema_suggestion(self, origin_schema: "OpSchema") -> None: + suggestion_args_spec = self.args_spec + new_arg_schema: list[object] = [] + idx_of_args_spec = 0 + if ( + origin_schema.schema_info is not None + and origin_schema.schema_info.needs_pytree + ): + args_schema: Sequence[Any] = tree_leaves(origin_schema.args_schema) + else: + args_schema = origin_schema.args_schema + for arg in args_schema: + if isinstance(arg, DTensorSpec): + new_arg_schema.append(suggestion_args_spec[idx_of_args_spec]) + idx_of_args_spec += 1 + else: + new_arg_schema.append(arg) + self.args_schema = tuple(new_arg_schema) + self.kwargs_schema = origin_schema.kwargs_schema + + +@dataclass +class OutputSharding: + """ + OutputSharding is a data class that is used by the sharding propagation, + it could set the output_spec upon successful propagation. If needs_redistribute + is set to True, a redistribute_schema would be returned together to indicate + the input arguments needs to be redistributed before the op execution. + + NOTE: the redistribute_schema generated by sharding propagation should be + exactly the same as the operator OpSchema, except the DTensorSpecs + """ + + output_spec: OutputSpecType + redistribute_schema: Optional[OpSchema] = None + needs_redistribute: bool = False + + @cached_property + def mesh(self): + if isinstance(self.output_spec, DTensorSpec): + return self.output_spec.mesh + elif isinstance(self.output_spec, tuple): + out_spec = self.output_spec[0] + if isinstance(out_spec, DTensorSpec): + return out_spec.mesh + else: + raise ValueError(f"Unknown output spec type: {type(out_spec)}") + else: + raise ValueError(f"Unknown output spec type: {type(self.output_spec)}") + + +@dataclass +class OpInfo: + """ + All Runtime Op execution info are packed here + """ + + # The first compute device mesh recorded from args + # NOTE: one op could have multiple meshes from its args. We just record the first + # mesh here to check if current rank should participate in computation or not. + compute_mesh: DeviceMesh + + # compete runtime operator infos + schema: OpSchema + flat_args_schema: list[object] + local_args: Sequence[object] + local_kwargs: dict[str, object] + args_tree_spec: Optional[TreeSpec] = None + + # the output sharding info + output_sharding: Optional[OutputSharding] = None diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__init__.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bef1ee9b36e7f77cdde60282fa4e3f5f8d956223 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from ._conv_ops import * # noqa: F403 +from ._embedding_ops import * # noqa: F403 +from ._math_ops import * # noqa: F403 +from ._matrix_ops import * # noqa: F403 +from ._pointwise_ops import * # noqa: F403 +from ._random_ops import * # noqa: F403 +from ._tensor_ops import * # noqa: F403 +from ._view_ops import * # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b221d544e5ffd50c1e3ad5779ddfc2dfe238bd7f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a9a9c7f825f73dcae3c4fdd548b5bc4922aefba Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_common_rules.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..04231d6156aef120b798e3d3339647703e94a50a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_conv_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b74bc2b71d8607a6164fbb703cb5077bb44f8184 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_einsum_strategy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce3e025ccf51349a6f9028b3c5f585d86e2e6f48 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_embedding_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ef523358604720577a26be7a7ad1f0d7244e0b2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_math_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17cafa02c0f64f2fefdb7c816e910875af3f1e01 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_matrix_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0f536e01981648a58a0ca7c735dd8f46568f4f09 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_pointwise_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6064639d655868cf171080d829ee547ae0995de9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_random_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_tensor_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_tensor_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e87ccc14eb95ba577c97cda23041a02d758fdbee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_tensor_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c9e68235c0868df63ad9f7d876b6f207986eb94 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/_view_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b546ef99c7d7122e836a386da887f17e9232f52 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_common_rules.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_common_rules.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b2ca2ac9902b09e2c0ca258afdcb7ea51e5fe0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_common_rules.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import string +from typing import cast, Optional + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema, OutputSharding +from torch.distributed.tensor._ops.utils import prod +from torch.distributed.tensor._utils import compute_local_shape_and_global_offset + + +def _replace_char_in_str(string: str, new_char: str, idx: int) -> str: + return string[:idx] + new_char + string[idx + 1 :] + + +def _gen_reshard_suggestions( + op_schema: OpSchema, + input_dims: list[str], + input_specs: tuple[DTensorSpec, ...], + dim_to_sharding: dict[str, int], + pending_sum: list[int], +) -> OutputSharding: + suggested_arg_specs: list[DTensorSpec] = [] + for input_dim, input_spec in zip(input_dims, input_specs): + dim_map = [dim_to_sharding[dim] for dim in input_dim] + suggested_arg_specs.append( + DTensorSpec.from_dim_map( + mesh=input_spec.mesh, + dim_map=dim_map, + sums=pending_sum, + tensor_meta=input_spec.tensor_meta, + ) + ) + suggested_schema = OpSchema(op_schema.op, tuple(suggested_arg_specs), {}) + suggested_schema._inplace_rewrap_schema_suggestion(op_schema) + return OutputSharding( + None, + redistribute_schema=suggested_schema, + ) + + +def einop_rule( + equation: str, + op_schema: OpSchema, + *, + linearity: bool = False, + enforce_sharding: Optional[dict[str, int]] = None, +) -> OutputSharding: + """ + Propagate the sharding of inputs to output for ops whose data moves according to einsum notation. + + This is mostly borrowed from @zdevito's sharding simulator. Examples: + mk,kn->mn - einsum + ij,ij->ij - addition + ij,j->ij - broadcasted addition + ij->i - reduction + Other ops could use this propagation algorithm when applied, note + that einsum propagation only deal with list of specs (DTensor specs) + as it only works on list of tensors! + + linearity in einop_rule means that the calling op `f` follows this rule: + f(a + b) = f(a) + f(b) + + In this case we can propagate the partial sum, note that linearity in einop + only applies to partial sum, not other operations like min/max (which are + associative but not linear). + """ + # parse einop equation and extract arg specs + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + input_specs = op_schema.args_spec + # NOTE: only support single output unless needed in future + output_dim = output_dims[0] + + dim_to_sharding: dict[str, int] = {} + dim_to_size: dict[str, int] = {} + # record pending sum, key is mesh dimension, value is pending sum + # counter across input specs + pending_sums_counter: dict[int, int] = {} + seen_shardings: dict[int, str] = {} + needs_reshard = False + + def merge_sharding(dim: str, a: int, b: int) -> int: + # merge the sharding of inputs if it's able to merge, i.e. we can merge + # replicate and shard to shard, but this will trigger an reshard operation + if a != b: + if a == -1 or b == -1: + # reshard the replicate to match the sharded one + nonlocal needs_reshard + needs_reshard = True + return a if a != -1 else b + else: + # TODO: further merge the sharding properly (i.e. reshard one input to replicate) + raise RuntimeError( + f"{equation}: dim {dim} sharded two different ways: {a} and {b}" + ) + else: + return a + + for input_dim, input_spec in zip(input_dims, input_specs): + # deal with partial sums + input_sums = input_spec.sums + for sum_dim in input_sums: + if sum_dim not in pending_sums_counter: + seen_shardings[sum_dim] = "+" + # update pending sum counter for pending sum mesh + # dimension with the occurrence from each input + pending_sums_counter[sum_dim] = pending_sums_counter.get(sum_dim, 0) + 1 + + for idx, (dim, mesh_dim) in enumerate(zip(input_dim, input_spec.dim_map)): + if enforce_sharding and dim in enforce_sharding: + if enforce_sharding[dim] != mesh_dim: + needs_reshard = True + dim_to_sharding[dim] = enforce_sharding[dim] + dim_to_size[dim] = input_spec.shape[idx] + elif dim not in dim_to_sharding: + dim_to_sharding[dim] = mesh_dim + dim_to_size[dim] = input_spec.shape[idx] + else: + dim_to_sharding[dim] = merge_sharding( + dim, dim_to_sharding[dim], mesh_dim + ) + assert dim_to_size[dim] == input_spec.shape[idx] + + # after merging sharding, we check if there're multiple + # sharding on the same mesh dim. + merged_sharding_for_dim = dim_to_sharding[dim] + if merged_sharding_for_dim != -1: + if ( + merged_sharding_for_dim in seen_shardings + and dim != seen_shardings[merged_sharding_for_dim] + ): + needs_reshard = True + seen_shardings[merged_sharding_for_dim] += dim + else: + seen_shardings[merged_sharding_for_dim] = dim + + if pending_sums_counter and not linearity: + # return reshard suggestion with no pending sum, because we already properly + # merge the sharding, this reshard suggestion is legit to use + return _gen_reshard_suggestions( + op_schema, input_dims, input_specs, dim_to_sharding, [] + ) + else: + # It's a op that support linearity, but not all input arguments are partial + # we fail the sharding propagation with suggestion to make all inputs be + # partial on the corresponding mesh dim (all inputs should be partial for + # the mesh dims in order to execute locally and delay the sum reduction) + for value in pending_sums_counter.values(): + if value != len(input_specs): + needs_reshard = True + + for mesh_dim, dims in seen_shardings.items(): + if len(dims) > 1: + # we found different input dims are being sharded on the same mesh dim + # in order to perform local op computation, we need to reshard inputs + # base on some simple heuristics, now we simply pick the one with least comm + # volume. (i.e. the input with least size) + # TODO: consider a more advanced heuristic to pick the best sharding + costs = [] + for d in dims: + cost = 0 + for input_dim, input_spec in zip(input_dims, input_specs): + if ( + d in input_dim + and input_spec.dim_map[input_dim.index(d)] == mesh_dim + ): + assert input_spec.tensor_meta is not None + global_shape = input_spec.tensor_meta.shape + local_shape, _ = compute_local_shape_and_global_offset( + global_shape, input_spec.mesh, input_spec.placements + ) + cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) + costs.append(cost) + d_to_keep_sharding = dims[costs.index(max(costs))] + for d in dims: + # update dim_to_sharding to keep the sharding of the dim with + # highest comm and make the rest of the dims to replicate + if d != d_to_keep_sharding: + dim_to_sharding[d] = -1 + + pending_sums = list(pending_sums_counter.keys()) + if needs_reshard: + return _gen_reshard_suggestions( + op_schema, input_dims, input_specs, dim_to_sharding, pending_sums + ) + + # generate output pending sum if a dim is sharded, and it appears in input + # but not output + for dim, shard_on_mesh in dim_to_sharding.items(): + if dim not in output_dims[0] and shard_on_mesh != -1: + pending_sums.append(shard_on_mesh) + + # if no need to reshard, we directly generate the output sharding + output_dim_map = [] + output_shape = [] + for dim in output_dim: + if dim == "1": + # find output dim that is a singleton dimension, mark sharding and shape + output_dim_map.append(-1) + output_shape.append(1) + else: + output_dim_map.append(dim_to_sharding[dim]) + output_shape.append(dim_to_size[dim]) + + # XXX: since we still need to have intermediate shape calculation, we need + # to pass in the shape here. We should remove this once sharding decomp works + # for ops like addmm + assert input_specs[0].tensor_meta is not None + tensor_meta = TensorMeta( + torch.Size(output_shape), + input_specs[0].tensor_meta.stride, + input_specs[0].tensor_meta.dtype, + ) + return OutputSharding( + DTensorSpec.from_dim_map( + input_specs[0].mesh, + output_dim_map, + pending_sums, + tensor_meta=tensor_meta, + ) + ) + + +def pointwise_rule(op_schema: OpSchema, linearity: bool = False) -> OutputSharding: + """ + Propagate the sharding for pointwise operations. + + Examples: + ij,ij->ij - addition/mul + ij,j->ij - broadcasted addition + """ + alphabet = string.ascii_lowercase + # find the max_dim first in case we need to broadcasting + input_specs = op_schema.args_spec + max_dim = max(input.ndim for input in input_specs) + dimchars = [] + singleton_counter: list[int] = [0] * max_dim + for input in input_specs: + start_dim = max_dim - input.ndim + p = alphabet[start_dim:max_dim] + # handle the "broadcasting to a common shape case" + # see https://pytorch.org/docs/stable/notes/broadcasting.html + # If any of the dimensions is singleton dimension (i.e. 1). + # we mark the dim char as a special "1" to distinguish with + # the non-singleton dimension, so that sharding propagation + # should just ignore the singleton dimension. + if len(input_specs) > 1: + for i in range(max_dim): + if i < start_dim: + # treat the leading miss dim chars as singleton + singleton_counter[i] += 1 + elif input.shape[i - start_dim] == 1: + # mark singleton dim char as a special "1" in einop rule + singleton_counter[i] += 1 + p = _replace_char_in_str(p, "1", (i - start_dim)) + + dimchars.append(p) + out_dimchars = alphabet[:max_dim] + # check if we replace the all inputs dim char with singleton dimension, + # if we replace all inputs, we also need to replace the output dimension. + for output_dim_idx in range(len(out_dimchars)): + if singleton_counter[output_dim_idx] == len(input_specs): + out_dimchars = _replace_char_in_str(out_dimchars, "1", output_dim_idx) + + fmt = f"{','.join(p for p in dimchars)}->{out_dimchars}" + + enforce_sharding: dict[str, int] = {} + if op_schema.is_inplace_op(): + follow_spec = op_schema.args_spec[0] + enforce_sharding.update(zip(out_dimchars, follow_spec.dim_map)) + elif op_schema.is_out_variant_op(): + follow_spec = cast(DTensorSpec, op_schema.kwargs_schema["out"]) + enforce_sharding.update(zip(out_dimchars, follow_spec.dim_map)) + + return einop_rule( + fmt, + op_schema, + linearity=linearity, + enforce_sharding=enforce_sharding, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_conv_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_conv_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..ea1b5a250e2b4993a9cedc1fa80616621619b0e9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_conv_ops.py @@ -0,0 +1,112 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import OpSchema, OutputSharding +from torch.distributed.tensor._ops.utils import register_prop_rule + + +aten = torch.ops.aten + + +@register_prop_rule(aten.convolution.default) +def convolution_rules(op_schema: OpSchema) -> OutputSharding: + ( + input_spec, + weight_spec, + bias_spec, + stride, + padding, + dilation, + _transposed, + _output_padding, + _groups, + ) = op_schema.args_schema + + assert isinstance(input_spec, DTensorSpec) + assert isinstance(weight_spec, DTensorSpec) + assert isinstance(bias_spec, DTensorSpec) + assert input_spec.tensor_meta is not None + assert weight_spec.tensor_meta is not None + in_shape = input_spec.tensor_meta.shape + weight_shape = weight_spec.tensor_meta.shape + assert isinstance(stride, list) + assert isinstance(padding, list) + assert isinstance(dilation, list) + assert isinstance(weight_shape, torch.Size) + N, H_in, W_in = in_shape[0], in_shape[2], in_shape[3] + C_out = weight_shape[0] + H_out = (H_in + 2 * padding[0] - dilation[0] * (weight_shape[2] - 1) - 1) // stride[ + 0 + ] + 1 + W_out = (W_in + 2 * padding[1] - dilation[1] * (weight_shape[3] - 1) - 1) // stride[ + 1 + ] + 1 + output_shape = [N, C_out, H_out, W_out] + output_stride = (C_out * H_out * W_out, H_out * W_out, W_out, 1) + output_dim_map = input_spec.dim_map + pending_sums = input_spec.sums + + tensor_meta = TensorMeta( + torch.Size(output_shape), + output_stride, + input_spec.tensor_meta.dtype, + ) + return OutputSharding( + DTensorSpec.from_dim_map( + input_spec.mesh, + output_dim_map, + pending_sums, + tensor_meta=tensor_meta, + ) + ) + + +@register_prop_rule(aten.convolution_backward.default) +def convolution_backward_rules(op_schema: OpSchema) -> OutputSharding: + input_spec = op_schema.args_schema[0] + ( + grad_output_spec, + input_spec, + weight_spec, + bias_shape_opt, + _stride, + _padding, + _dilation, + _transposed, + _output_padding, + _groups, + _output_mask, + ) = op_schema.args_schema + + assert isinstance(grad_output_spec, DTensorSpec) + assert isinstance(input_spec, DTensorSpec) + assert isinstance(weight_spec, DTensorSpec) + assert isinstance(bias_shape_opt, list) + assert input_spec.tensor_meta is not None + weight_tensor_meta = weight_spec.tensor_meta + bias_tensor_meta = TensorMeta( + torch.Size(bias_shape_opt), + (1,), + input_spec.tensor_meta.dtype, + ) + + grad_input_spec = input_spec + grad_weight_spec = DTensorSpec.from_dim_map( + input_spec.mesh, + [-1, -1, -1, -1], + [0], + tensor_meta=weight_tensor_meta, + ) + grad_bias_spec = DTensorSpec.from_dim_map( + input_spec.mesh, + [-1], + [0], + tensor_meta=bias_tensor_meta, + ) + # TODO: actually the output_mask is not respected here, we should + # set the corresponding spec to `None` if the output_mask is not `False` + # for a certain output Tensor. This also applies to the conv handler + # in torch/distributed/tensor/_tp_conv.py + return OutputSharding([grad_input_spec, grad_weight_spec, grad_bias_spec]) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..c5e29172950798d93f4c16d73421f71cdf7b21c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_einsum_strategy.py @@ -0,0 +1,173 @@ +import itertools +from dataclasses import dataclass + +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import OpSpec, OpStrategy +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +@dataclass +class EinsumDims: + contracting_dims: list[str] + batch_dims: list[str] + lhs_out_only_dims: list[str] + rhs_out_only_dims: list[str] + + @classmethod + def parse_equation(cls, equation: str) -> tuple[list[str], str]: + # parse einop equation and extract arg specs + """ + Parse the einsum equation str to input dim chars and output dim char + """ + inputs, outputs = equation.split("->") + input_dims, output_dims = inputs.split(","), outputs.split(",") + + # NOTE: only support at most two inputs, and single output + # extend to support more inputs if needed in future + assert len(input_dims) <= 2, "Only support at most two inputs" + assert len(output_dims) == 1, "Only support single output" + output_dim = output_dims[0] + return input_dims, output_dim + + @classmethod + def parse_dims(cls, input_dims: list[str], output_dim: str) -> "EinsumDims": + """ + Parse the dims and extract the contracting, batch, and free dimensions + for the left and right hand sides. + """ + dim_char_set: set[str] = set() + for input_dim in input_dims: + dim_char_set.update(input_dim) + + # get a determinisitc order of all dim chars + all_dim_chars = sorted(dim_char_set) + + # parse input and output dimensions + lhs_out_only_dims, rhs_out_only_dims = [], [] + batch_dims, contracting_dims = [], [] + + for dim_char in all_dim_chars: + if dim_char not in output_dim: + contracting_dims.append(dim_char) + else: + is_batch_dim = True + for input_dim in input_dims: + is_batch_dim = is_batch_dim and dim_char in input_dim + + if is_batch_dim: + batch_dims.append(dim_char) + else: + assert len(input_dims) == 2, ( + "free dimension only supported for two inputs!" + ) + lhs, rhs = input_dims + if dim_char in lhs: + lhs_out_only_dims.append(dim_char) + elif dim_char in rhs: + rhs_out_only_dims.append(dim_char) + else: + raise RuntimeError("Invalid dimension character") + + return cls( + contracting_dims=contracting_dims, + batch_dims=batch_dims, + lhs_out_only_dims=lhs_out_only_dims, + rhs_out_only_dims=rhs_out_only_dims, + ) + + +def gen_einsum_strategies( + equation: str, + mesh: DeviceMesh, + *, + linearity: bool = False, +) -> OpStrategy: + """ + Generate a strategy list for the ops that follow einsum style notation. + """ + # parse einop equation and extract dims + input_dims, output_dim = EinsumDims.parse_equation(equation) + edims = EinsumDims.parse_dims(input_dims, output_dim) + + all_mesh_dim_strategies = [] + + # generate strategies for each mesh dim + for mesh_dim in range(mesh.ndim): + mesh_dim_strategies = [] + + # placement list stores placements of [output, input1, input2, ...] + # first we always have replicate all for inputs and output + placement_list: list[Placement] = [Replicate()] * (len(input_dims) + 1) + mesh_dim_strategies.append(placement_list) + + # split batch dim + for batch_dim in edims.batch_dims: + output_batch_dim = output_dim.index(batch_dim) + placement_list = [Shard(output_batch_dim)] + for input_dim in input_dims: + input_batch_dim = input_dim.index(batch_dim) + placement_list.append(Shard(input_batch_dim)) + + mesh_dim_strategies.append(placement_list) + + # split contracting dim + for contracting_dim in edims.contracting_dims: + placement_list = [Partial()] + for input_dim in input_dims: + input_contracting_dim = input_dim.index(contracting_dim) + placement_list.append(Shard(input_contracting_dim)) + + mesh_dim_strategies.append(placement_list) + + # split lhs free dim + for lhs_dim in edims.lhs_out_only_dims: + lhs_free_dim = output_dim.index(lhs_dim) + # this means split the lhs input and output + # i.e. S(0), R -> S(0) + lhs_placement_list: list[Placement] = [ + Shard(lhs_free_dim), + Shard(lhs_free_dim), + Replicate(), + ] + mesh_dim_strategies.append(lhs_placement_list) + + # split rhs free dim + for rhs_dim in edims.rhs_out_only_dims: + rhs_free_dim = output_dim.index(rhs_dim) + rhs_placement_list: list[Placement] = [ + Shard(rhs_free_dim), + Replicate(), + Shard(rhs_free_dim), + ] + mesh_dim_strategies.append(rhs_placement_list) + + # linearity strategy + if linearity: + linearity_placement_list: list[Placement] = [Partial()] + for input_dim in input_dims: + linearity_placement_list.append(Partial()) + mesh_dim_strategies.append(linearity_placement_list) + + all_mesh_dim_strategies.append(mesh_dim_strategies) + + # generate strategies for entire mesh + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + # TODO: filter out invalid strategies, at this point we generate + # all possible strategies without considering the whether the tensor + # dim could be sharded or not, we would need to filter out invalid + # strategies base on the actual tensor shape + # (i.e. for Shard, tensor dim size must > mesh size) + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)] + strat = OpSpec(output_specs=spec_list[0], input_specs=spec_list[1:]) + all_strategies.append(strat) + + return OpStrategy(all_strategies) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c48c96670261b28d5e950782cd4e3d75027283ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_embedding_ops.py @@ -0,0 +1,272 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor +from dataclasses import dataclass, field +from typing import cast, Optional + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + StrategyType, +) +from torch.distributed.tensor._ops.utils import ( + expand_to_full_mesh_op_strategy, + register_op_strategy, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +@dataclass +class MaskBuffer: + data: Optional[torch.Tensor] = None + # refcount allows shared usage of the MaskBuffer, as long as all users have the same data + refcount: int = 0 + + def materialize_mask(self, mask): + if self.refcount == 0: + self.data = mask + else: + assert self.data is not None + if not torch.equal(self.data, mask): + raise RuntimeError( + "MaskBuffer has been materialized with conflicting data" + ) + self.refcount += 1 + + def release_mask(self): + if self.refcount == 0 or self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + self.refcount -= 1 + if self.refcount == 0: + self.data = None + + def apply_mask(self, tensor): + if self.refcount == 0 or self.data is None: + raise RuntimeError("MaskBuffer has not been materialized") + + # NOTE: _MaskPartial is being used by the embedding op and the gather op. + # For gather, the mask has the same dimension as the output tensor, whereas + # the output of the embedding op has an additional dimension compare to the input, + # hence the output masking logic below having two different cases. + if tensor.ndim == self.data.ndim: + tensor[self.data] = 0.0 + else: + tensor[self.data, :] = 0.0 + + +@dataclass(frozen=True) +class _MaskPartial(Partial): + """ + A partial mask placement devised for rowwise sharded embedding op, where we need + to mask and adjust the indices to the local embedding shard, embedding masking + is a special type of the Partial placement + + NOTE: the lifecycle of this MaskPartial placement follows the corresponding DTensor + lifecycle, i.e. the indices_mask would only be alive during the lifetime of the DTensor. + """ + + mask_buffer: MaskBuffer = field(default_factory=MaskBuffer) + + # required fields for computing the local offset and deriving the mask + offset_shape: Optional[torch.Size] = None + offset_dim: int = 0 + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # override parent logic to perform partial mask for embedding + num_chunks = mesh.size(mesh_dim) + # get local shard size and offset on the embedding_dim + assert self.offset_shape is not None, ( + "offset_shape needs to be set for _MaskPartial" + ) + local_shard_size, local_offset_on_dim = Shard._local_shard_size_and_offset( + self.offset_shape[self.offset_dim], + num_chunks, + mesh.get_local_rank(mesh_dim), + ) + # Build the input mask and save it for the current partial placement + # this is so that the output of embedding op can reuse the same partial + # placement saved mask to perform mask + reduction + mask = (tensor < local_offset_on_dim) | ( + tensor >= local_offset_on_dim + local_shard_size + ) + # mask the input tensor + masked_tensor = tensor.clone() - local_offset_on_dim + masked_tensor[mask] = 0 + # materialize the mask buffer to be used for reduction + self.mask_buffer.materialize_mask(mask) + return masked_tensor + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # perform sum reduction + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # by the time we ned reduction, we should have already saved the mask + assert self.mask_buffer.data is not None + + # apply the mask to the tensor that pending reduction + self.mask_buffer.apply_mask(tensor) + + # clear the mask buffer + self.mask_buffer.release_mask() + + # call reduce_shard_tensor of the shard_spec. + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _MaskPartial): + return False + + # if either data is not None, we invalidate the sharding cache, as this indicates + # the current MaskPartial placement is still in use and should not be used for cache hit. + if self.mask_buffer.data is not None or other.mask_buffer.data is not None: + return False + + return ( + self.reduce_op == other.reduce_op + and self.offset_shape == other.offset_shape + and self.offset_dim == other.offset_dim + ) + + def __hash__(self) -> int: + return 1 + hash( + ( + self.reduce_op, + self.offset_shape, + self.offset_dim, + ) + ) + + def __repr__(self) -> str: + """ + machine readable representation of the MaskPartial placement + """ + return f"_MaskPartial(offset_shape={self.offset_shape}, offset_dim={self.offset_dim})" + + def __str__(self) -> str: + """ + human readable representation of the MaskPartial placement + """ + return "MaskP" + + +@register_op_strategy(aten.embedding.default) +def embedding_strategy(op_schema: OpSchema) -> StrategyType: + """ + This strategy handles embedding op. We have two possible embedding shardings: + rowwise and colwise + """ + weight_strategy = cast(OpStrategy, op_schema.args_schema[0]) + indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) + mesh = op_schema.get_mesh_from_args() + + weight_shape = weight_strategy.shape + indices_shape = indices_strategy.shape + output_emd_dim = len(indices_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding, output shard on last dim, weight shard on dim 1, input replicate + colwise_sharding: PlacementList = [Shard(output_emd_dim), Shard(1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # rowwise sharding, output is embedding partial, weight shard on dim 0, input accepts embedding partial + embedding_partial_placement = _MaskPartial(offset_shape=weight_shape, offset_dim=0) + + # NOTE we want to reuse the same mask partial placement so that we can reuse the same mask that generates + # from the input indices and use it for output reduction + rowwise_sharding: PlacementList = [ + embedding_partial_placement, + Shard(0), + embedding_partial_placement, + ] + single_mesh_dim_strategies.append(rowwise_sharding) + + # batch dim sharding, weight replicated, input can shard on any dim, output follows input + for input_dim in range(len(indices_shape)): + batch_sharding: PlacementList = [ + Shard(input_dim), + Replicate(), + Shard(input_dim), + ] + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) + + +@register_op_strategy(aten.embedding_dense_backward.default) +def embedding_dense_backward_strategy(op_schema: OpSchema) -> StrategyType: + """ + This strategy handles embedding op. We have two possible embedding shardings: + rowwise and colwise + """ + grad_out_strategy = cast(OpStrategy, op_schema.args_schema[0]) + indices_strategy = cast(OpStrategy, op_schema.args_schema[1]) + mesh = op_schema.get_mesh_from_args() + + grad_out_shape = grad_out_strategy.shape + indices_shape = indices_strategy.shape + grad_out_ndim = len(grad_out_shape) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, weight, input_indices] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # colwise sharding backward, grad_out shard on last dim, input replicate, + # weight grad shard colwise + colwise_sharding: PlacementList = [Shard(1), Shard(grad_out_ndim - 1), Replicate()] + single_mesh_dim_strategies.append(colwise_sharding) + + # batch dim sharding, weight replicated, grad_out/input have same sharding + # that can shard on any dim, weight grad partial + for input_dim in range(len(indices_shape)): + batch_sharding: PlacementList = [Partial(), Shard(input_dim), Shard(input_dim)] + single_mesh_dim_strategies.append(batch_sharding) + + # grad_out partial, input replicate, weight grad keep partial + partial_sharding: PlacementList = [Partial(), Partial(), Replicate()] + single_mesh_dim_strategies.append(partial_sharding) + + return expand_to_full_mesh_op_strategy(mesh, op_schema, single_mesh_dim_strategies) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_math_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_math_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..369dcfb47ba0c484e2c9ae11b0bf0666276a9ec1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_math_ops.py @@ -0,0 +1,1092 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import math +from collections.abc import Sequence +from dataclasses import dataclass +from enum import Enum +from typing import cast, Optional, Union + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, + TupleStrategy, +) +from torch.distributed.tensor._ops.utils import ( + as_list, + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + is_tensor_evenly_shardable, + normalize_dim, + normalize_dims, + register_op_strategy, +) +from torch.distributed.tensor._utils import normalize_to_torch_size +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +class Reduction(Enum): + NONE = 0 + MEAN = 1 + SUM = 2 + + +@dataclass(frozen=True) +class NormReduction: + norm_type: Union[int, float, str] + + +ReductionOpType = Union[NormReduction, str] + + +@dataclass(frozen=True) +class _NormPartial(Partial): + """ + This placement is used for partial vector norm. + + For p-norms (where p not inf or -inf), the p-norm over n elements computes + (sum_i x_i^p)^(1/p) + where the sum is from i=1 to n. The reduction op is the p-norm itself. + For example, consider 2 ranks, a (4,) tensor sharded on dim-0, and 2-norm: + Rank 0: [t1, t2] | Rank 1: [t3, t4] + After computing 2-norm per gradient (partial placement): + Rank 0: [sqrt(t1^2 + t2^2)] | Rank 1: [sqrt(t3^2 + t4^2)] + Converting from partial to replicate wants to ultimately get: + Rank 0/1: [sqrt(t1^2 + t2^2 + t3^2 + t4^2)] + This can be achieved by computing 2-norm on each rank's result. This holds + similarly for inf and -inf norm. For 0-norm, the reduction op is sum. + """ + + norm_type: Union[int, float, str] = 2 + + def __post_init__(self): + """Set the appropriate reduce op based on the norm type.""" + # Use `object.__setattr__` to bypass frozen checks + if self.norm_type in (float("inf"), "inf"): + object.__setattr__(self, "reduce_op", "max") + elif self.norm_type in (float("-inf"), "-inf"): + object.__setattr__(self, "reduce_op", "min") + elif isinstance(self.norm_type, (int, float)): + object.__setattr__(self, "reduce_op", "sum") + else: + raise NotImplementedError(f"Unsupported norm type: {self.norm_type}") + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + """ + For example, consider 4 ranks, a (3,) replicated tensor, and 2-norm: + Ranks 0 and 1: sqrt(t1^2 + t2^2 + t3^3) + To convert from replicated to partial, we want f(x) such that + sqrt(t1^2 + t2^2 + t3^3) = sqrt(4f(t1)^2 + 4f(t2)^2 + 4f(t3)^2) + = sqrt(4) sqrt(f(t1)^2 + f(t2)^2 + f(t3)^2). + One such f(x) is f(x) = x / sqrt(4). This generalizes to d ranks and + p-norm as f(x) = x / d^(1/p). + """ + if self.reduce_op in ("max", "min"): + return tensor + elif self.reduce_op == "sum": + if self.norm_type == 0: + raise NotImplementedError(f"Unsupported norm type:: {self.norm_type}") + elif self.norm_type == 1: + return tensor / mesh.size(mesh_dim) + assert isinstance(self.norm_type, (int, float)) + return tensor / math.pow(mesh.size(mesh_dim), 1 / self.norm_type) + raise NotImplementedError(self.reduce_op) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + assert isinstance(shard_spec, Shard), f"{shard_spec}" + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_shard_value(tensor, mesh, mesh_dim, shard_spec) + return self._post_reduce_transform(reduced_tensor) + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + tensor = self._pre_reduce_transform(tensor) + reduced_tensor = super()._reduce_value(tensor, mesh, mesh_dim) + return self._post_reduce_transform(reduced_tensor) + + def _pre_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" + if self.norm_type != 0 and self.norm_type != 1: + return tensor**self.norm_type + return tensor + + def _post_reduce_transform(self, tensor: torch.Tensor) -> torch.Tensor: + if self.reduce_op == "sum": + assert isinstance(self.norm_type, (int, float)), f"{self.norm_type}" + if self.norm_type != 0 and self.norm_type != 1: + return tensor ** (1.0 / self.norm_type) + return tensor + + def __eq__(self, other: object) -> bool: + if not isinstance(other, _NormPartial): + return False + return self.norm_type == other.norm_type + + def __hash__(self) -> int: + return 1 + hash(self.norm_type) + + +def _infer_reduction_dims(dims_arg: object, ndim: int) -> Optional[list[int]]: + if dims_arg is None: + return None + dims = cast(list[int], as_list(dims_arg)) + dims = cast(list[int], normalize_dims(dims, ndim)) + empty_dims = [[0], [-1], []] + if ndim == 0 and dims_arg in empty_dims: + return None + return dims + + +def _infer_reduce_dims_map( + reduction_dims: list[int], input_ndim: int, keep_dim=False +) -> list[int]: + reduction_dims_map = [] + new_dim_count = 0 + for input_dim in range(input_ndim): + if input_dim in reduction_dims and not keep_dim: + # if input dim in reduction dims, mark it as -1 + reduction_dims_map.append(-1) + else: + # otherwise mark it as the new dim + reduction_dims_map.append(new_dim_count) + new_dim_count += 1 + + return reduction_dims_map + + +def _replicate_dims_start_at( + placements: Sequence[Placement], start_dim: int = 0 +) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if p.is_partial() or (isinstance(p, Shard) and p.dim >= start_dim): + new_placements.append(Replicate()) # make it replicate + else: + new_placements.append(p) # keep the placement + return tuple(new_placements) + + +# return new_placements which align with placements but skip the skipped_dim +def _skip_dim( + placements: tuple[Placement, ...], skipped_dim: int +) -> tuple[Placement, ...]: + new_placements: list[Placement] = [] + for p in placements: + if isinstance(p, Shard) and p.dim >= skipped_dim: + new_placements.append(Shard(p.dim - 1)) + else: + new_placements.append(p) + return tuple(new_placements) + + +def replicate_reduction_dims( + placements: tuple[Placement, ...], reduction_dims: list[int] +) -> tuple[Placement, ...]: + # replicate the reduction dims if not reduction_linear + new_placements: list[Placement] = [] + + for p in placements: + if p.is_partial(): + new_placements.append(Replicate()) + elif isinstance(p, Shard) and p.dim in reduction_dims: + new_placements.append(Replicate()) + else: + new_placements.append(p) + + return tuple(new_placements) + + +def map_placements_after_reduction( + placements: tuple[Placement, ...], + reduction_dims: list[int], + reduction_dims_map: list[int], + reduction_op: ReductionOpType, +) -> tuple[Placement, ...]: + """ + Map each placement based on the output shape after reduction. + """ + new_placements: list[Placement] = [] + for placement in placements: + if isinstance(placement, (Replicate, Partial)): + new_placements.append(placement) + else: + assert isinstance(placement, Shard) + shard_dim = placement.dim + new_shard_dim = reduction_dims_map[shard_dim] + if new_shard_dim == -1 or shard_dim in reduction_dims: + # if new_shard_dim collapsed or its in the reduction dims + # (i.e. for the case where keepdims=True), we generate partial + new_placements.append(get_placement_from_reduction_op(reduction_op)) + else: + new_placements.append(Shard(new_shard_dim)) + return tuple(new_placements) + + +def get_placement_from_reduction_op(reduction_op: ReductionOpType) -> Placement: + if isinstance(reduction_op, NormReduction): + return _NormPartial(norm_type=reduction_op.norm_type) + return Partial(reduction_op) + + +def common_reduction_strategy( + input_strategy: OpStrategy, + reduce_dims: list[int], + keep_dim: bool = False, + reduction_linear: bool = True, + reduction_op: ReductionOpType = "sum", +) -> OpStrategy: + """ + reduction_linear means that the reduction `f` follows this rule: + f([f(a), f(b)]) = f([a, b]) + + reduction linear should be super set of linearity. + """ + # by default follow reduction input strategy + reduction_strategy = OpStrategy([]) + + for op_spec in input_strategy.strategies: + if not reduction_linear: + # input placements for this strategy should clear out pending sum and sharding + # on the reduction dimension + input_placements = replicate_reduction_dims( + op_spec.output_spec.placements, reduce_dims + ) + else: + input_placements = op_spec.output_spec.placements + + input_spec = DTensorSpec( + mesh=input_strategy.mesh, + placements=input_placements, + tensor_meta=op_spec.output_spec.tensor_meta, + ) + + reduce_dims_map = _infer_reduce_dims_map(reduce_dims, input_spec.ndim, keep_dim) + out_placements = map_placements_after_reduction( + input_spec.placements, reduce_dims, reduce_dims_map, reduction_op + ) + redistribute_cost = [generate_redistribute_costs(input_strategy, input_spec)] + reduction_strategy.strategies.append( + OpSpec( + output_specs=DTensorSpec( + mesh=input_strategy.mesh, + placements=out_placements, + ), + input_specs=(input_spec,), + redistribute_cost=redistribute_cost, + ) + ) + + return reduction_strategy + + +LINEAR_REDUCTION_OP_MAP = { + aten.all.default: "sum", + aten.all.dim: "sum", + aten.sum.default: "sum", + aten.sum.dim_IntList: "sum", + aten.prod.default: "product", + aten.prod.dim_int: "product", + aten.prod.int_out: "product", + aten.mean.default: "avg", + aten.mean.dim: "avg", + aten.mean.out: "avg", + aten.max.default: "max", + aten.max.dim: "max", + aten.max.out: "max", + aten.min.default: "min", + aten.min.dim: "min", + aten.min.out: "min", + aten.any.default: "sum", + aten.any.dim: "sum", + aten.any.out: "sum", + aten.amax.default: "max", + aten.amax.out: "max", + aten.amin.default: "min", + aten.amin.out: "min", +} + + +@register_op_strategy( + list(LINEAR_REDUCTION_OP_MAP.keys()), schema_info=RuntimeSchemaInfo(1) +) +def linear_reduction_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = len(op_schema.args_schema) > 2 and bool(op_schema.args_schema[2]) + reduction_op = LINEAR_REDUCTION_OP_MAP[op_schema.op] + return common_reduction_strategy( + input_strategy, + reduce_dims, + keep_dim=keep_dim, + reduction_linear=True, + reduction_op=reduction_op, + ) + + +@register_op_strategy(aten.cumsum.default, schema_info=RuntimeSchemaInfo(1)) +def cumsum_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + dim = args_schema[1] + assert isinstance(dim, int), f"{dim}" + + return common_reduction_strategy( + input_strategy, [dim], keep_dim=True, reduction_linear=False + ) + + +@register_op_strategy( + [aten.var.correction, aten.var.correction_out], + schema_info=RuntimeSchemaInfo(1, ["keepdim"]), +) +def var_reduction_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + dims = None + if len(op_schema.args_schema) > 1: + dims = _infer_reduction_dims(args_schema[1], input_strategy.ndim) + + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + + keep_dim = cast(bool, op_schema.kwargs_schema.get("keepdim", False)) + return common_reduction_strategy( + input_strategy, reduce_dims, keep_dim=keep_dim, reduction_linear=False + ) + + +@register_op_strategy( + [aten.linalg_vector_norm.default], schema_info=RuntimeSchemaInfo(1) +) +def vector_norm_strategy(op_schema: OpSchema) -> OpStrategy: + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy) + + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + assert isinstance(norm_type, (int, float, str)), f"{norm_type}" + dim = args_schema[2] if len(args_schema) > 2 else None + keepdim = args_schema[3] if len(args_schema) > 3 else False + dims = _infer_reduction_dims(dim, input_strategy.ndim) + reduce_dims = list(range(input_strategy.ndim)) if dims is None else dims + return common_reduction_strategy( + input_strategy, + reduce_dims, + keep_dim=cast(bool, keepdim), + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + + +@register_op_strategy( + [aten._foreach_norm.Scalar], schema_info=RuntimeSchemaInfo(1, needs_pytree=True) +) +def foreach_norm_strategy(op_schema: OpSchema) -> TupleStrategy: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy) + norm_type = args_schema[1] if len(args_schema) > 1 else 2 + assert isinstance(norm_type, (int, float, str)), f"{norm_type}" + output_tuple_strategy_childs: list[OpStrategy] = [] + for op_strategy in input_tuple_strategy.childs: + assert isinstance(op_strategy, OpStrategy), f"{op_strategy}" + reduce_dims = list(range(op_strategy.ndim)) + output_strategy = common_reduction_strategy( + op_strategy, + reduce_dims, + reduction_linear=True, + reduction_op=NormReduction(norm_type), + ) + output_tuple_strategy_childs.append(output_strategy) + return TupleStrategy(output_tuple_strategy_childs) + + +@register_op_strategy( + [ + aten._linalg_svd.default, + aten.linalg_qr.default, + # TODO: The diagonal ops can have an improved sharding strategy for + # shard placements that does not require redistributing to replicate. + aten.diagonal_copy.default, + aten.diag_embed.default, + aten.diag.default, + aten.diagonal.default, + aten.tril.default, + aten.triu.default, + aten._linalg_eigh.default, + aten.upsample_bicubic2d.default, + aten.upsample_bilinear2d.default, + aten.upsample_linear1d.default, + aten.upsample_nearest2d.default, + aten.upsample_trilinear3d.default, + # TODO: support the full F.interpolate set of options. + ], + schema_info=RuntimeSchemaInfo(1), +) +def linalg_replicate_strategy(op_schema: OpSchema) -> OpStrategy: + """ + Since we do not have a simple way to compute some linear algebra operations + like SVD or QR decomposition, always fall back to replicate. + """ + args_schema = op_schema.args_schema + input_strategy = args_schema[0] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + mesh = input_strategy.mesh + + output_strategies: list[OpSpec] = [] + for placement_strategy in input_strategy.strategies: + replicate_placements = tuple(Replicate() for _ in range(mesh.ndim)) + replicate_spec = DTensorSpec( + mesh=mesh, + placements=replicate_placements, + tensor_meta=placement_strategy.output_spec.tensor_meta, + ) + redistribute_cost = [ + generate_redistribute_costs(input_strategy, replicate_spec) + ] + replicate_strategy = OpSpec( + output_specs=replicate_spec, + input_specs=(replicate_spec,), + redistribute_cost=redistribute_cost, + ) + output_strategies.append(replicate_strategy) + return OpStrategy(output_strategies) + + +@register_op_strategy( + [aten._log_softmax.default, aten._softmax.default, aten._safe_softmax.default], + schema_info=RuntimeSchemaInfo(1), +) +def softmax_strategy(op_schema: OpSchema) -> OpStrategy: + input_strategy, softmax_dim, *_ = op_schema.args_schema + input_strategy = cast(OpStrategy, input_strategy) + + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, input_strategy.ndim) + + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # make sure input is replicated along the softmax dim + input_target_spec = DTensorSpec( + mesh=input_strategy.mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [softmax_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + output_target_spec = input_target_spec + output_strategy.strategies.append( + OpSpec( + output_specs=output_target_spec, + input_specs=[input_target_spec], + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [ + aten._log_softmax_backward_data.default, + aten._softmax_backward_data.default, + ], + schema_info=RuntimeSchemaInfo(2), +) +def softmax_backward_strategy(op_schema: OpSchema) -> OpStrategy: + grad_out_strategy, out_strategy, softmax_dim, _ = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + out_strategy = cast(OpStrategy, out_strategy) + softmax_dim = cast(int, softmax_dim) + softmax_dim = normalize_dim(softmax_dim, grad_out_strategy.ndim) + + grad_in_strategy = OpStrategy([]) + for grad_out_placement_strat, out_placement_strat in zip( + grad_out_strategy.strategies, out_strategy.strategies + ): + # follow the sharding of the grad_out or out depending on which has more shards + grad_out_src_spec = grad_out_placement_strat.output_spec + out_src_spec = out_placement_strat.output_spec + src_spec = ( + grad_out_src_spec + if grad_out_src_spec.num_shards >= out_src_spec.num_shards + else out_src_spec + ) + + # make sure inputs are replicated along the softmax dim + tgt_spec = DTensorSpec( + mesh=grad_out_strategy.mesh, + placements=replicate_reduction_dims(src_spec.placements, [softmax_dim]), + ) + redist_grad_out_cost = generate_redistribute_costs(grad_out_strategy, tgt_spec) + redist_out_cost = generate_redistribute_costs(out_strategy, tgt_spec) + grad_in_strategy.strategies.append( + OpSpec( + output_specs=tgt_spec, + redistribute_cost=[redist_grad_out_cost, redist_out_cost], + ) + ) + + return grad_in_strategy + + +@register_op_strategy( + [aten.nll_loss_forward.default, aten.nll_loss2d_forward.default], + schema_info=RuntimeSchemaInfo(3), +) +def nll_loss_forward_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + assert len(op_schema.args_schema) == 5 + + ( + input_strategy, + target_strategy, + weight_strategy, + reduction, + _, + ) = op_schema.args_schema + input_strategy = cast(OpStrategy, input_strategy) + target_strategy = cast(OpStrategy, target_strategy) + reduction = cast(int, reduction) + + input_shape = input_strategy.shape + channel_dim = 1 if len(input_shape) >= 2 else 0 + + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + + # make sure input is replicated along the channel dim + input_src_spec = input_placement_strategy.output_spec + input_expected_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [channel_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_expected_spec) + ) + + # target doesn't have channel dim, and it follows input on other dims + target_src_spec = target_strategy.strategies[idx].output_spec + target_expected_spec = DTensorSpec( + mesh=mesh, + placements=_skip_dim(input_expected_spec.placements, channel_dim), + tensor_meta=target_src_spec.tensor_meta, + ) + op_args_target_specs.append(target_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(target_strategy, target_expected_spec) + ) + + # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] + # make sure it is replicated + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_expected_spec) + ) + + if reduction == Reduction.NONE.value: + output_expected_spec = target_expected_spec + total_weight_expected_spec = DTensorSpec( + mesh=mesh, placements=tuple([Replicate()] * mesh.ndim) + ) + else: + if reduction == Reduction.MEAN.value: + reduction_op = "avg" + if not is_tensor_evenly_shardable( + target_expected_spec.shape, target_expected_spec + ): + raise ValueError( + "The intermediate results of nll_loss cannot be evenly sharded, \ + resulting in biased mean result." + ) + else: # reduction == Reduction.SUM.value: + reduction_op = "sum" + reduce_dims = list(range(target_expected_spec.ndim)) + reduce_dims_map = _infer_reduce_dims_map( + reduce_dims, target_expected_spec.ndim, keep_dim=False + ) + out_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + reduction_op, + ) + output_expected_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + ) + + # whether reduction is sum or mean, the total weight has to be summed up if not replicated + total_weight_placements = map_placements_after_reduction( + target_expected_spec.placements, + reduce_dims, + reduce_dims_map, + "sum", + ) + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=total_weight_placements, + ) + + output_strategy.strategies.append( + OpSpec( + output_specs=(output_expected_spec, total_weight_expected_spec), + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [aten.nll_loss_backward.default, aten.nll_loss2d_backward.default], + schema_info=RuntimeSchemaInfo(4), +) +def nll_loss_backward_strategy(op_schema: OpSchema) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + assert len(op_schema.args_schema) == 7 + ( + grad_out_strategy, + input_strategy, + target_strategy, + weight_strategy, + reduction, + _, + total_weight_strategy, + ) = op_schema.args_schema + grad_out_strategy = cast(OpStrategy, grad_out_strategy) + input_strategy = cast(OpStrategy, input_strategy) + target_strategy = cast(OpStrategy, target_strategy) + reduction = cast(int, reduction) + total_weight_strategy = cast(OpStrategy, total_weight_strategy) + + input_shape = input_strategy.shape + channel_dim = 1 if len(input_shape) >= 2 else 0 + + grad_in_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + + # make sure input is replicated along the channel dim + input_src_spec = input_placement_strategy.output_spec + input_expected_spec = DTensorSpec( + mesh=mesh, + placements=replicate_reduction_dims( + input_src_spec.placements, [channel_dim] + ), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_expected_spec) + ) + + # target doesn't have channel dim, and it follows input on other dims + target_src_spec = target_strategy.strategies[idx].output_spec + target_expected_spec = DTensorSpec( + mesh=mesh, + placements=_skip_dim(input_expected_spec.placements, channel_dim), + tensor_meta=target_src_spec.tensor_meta, + ) + op_args_target_specs.append(target_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(target_strategy, target_expected_spec) + ) + + # grad_out follows target if there is no reduction; + # otherwise, it should be a replicated scalar. + grad_out_src_spec = grad_out_strategy.strategies[idx].output_spec + if reduction == Reduction.NONE.value: + grad_out_expected_spec = target_expected_spec + else: + grad_out_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(grad_out_src_spec.placements), + tensor_meta=grad_out_src_spec.tensor_meta, + ) + op_args_target_specs.insert(0, grad_out_expected_spec) + redistribute_costs.insert( + 0, generate_redistribute_costs(grad_out_strategy, grad_out_expected_spec) + ) + + # weight tensor, if given, has to be a Tensor of size input_shape[channel_dim] + # make sure it is replicated + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_expected_spec) + ) + + # total_weight should always be replicated + total_weight_src_spec = total_weight_strategy.strategies[idx].output_spec + total_weight_expected_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(total_weight_src_spec.placements), + tensor_meta=total_weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(total_weight_expected_spec) + redistribute_costs.append( + generate_redistribute_costs( + total_weight_strategy, total_weight_expected_spec + ) + ) + + grad_in_expected_spec = input_expected_spec + grad_in_strategy.strategies.append( + OpSpec( + output_specs=grad_in_expected_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return grad_in_strategy + + +@register_op_strategy( + [aten.native_layer_norm.default], + schema_info=RuntimeSchemaInfo(1), +) +def layer_norm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + # args must be: input, normalized_shape, weight, bias, eps + # for None weight and bias, their corresponding objects will + # be None as well. layer_norm_strategy returns one OpStrategy + # for the triple return values (out, mean, rstd). + assert len(op_schema.args_schema) == 5 + ( + input_strategy, + normalized_shape, + weight_strategy, + bias_strategy, + _, + ) = op_schema.args_schema + + # the current layer norm implementation requires that all + # input DTensor's sharding must be in form of OpStrategy + assert isinstance(input_strategy, OpStrategy) + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + + # we use OpStrategy because the output (out, mean, rstd) + # should have the same placements + output_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + op_args_target_specs = [] + redistribute_costs = [] + input_src_spec = input_placement_strategy.output_spec + + # for the input tensor, we replicate it on the inner dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + op_args_target_specs.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + if weight_strategy is not None: + assert isinstance(weight_strategy, OpStrategy) + weight_src_spec = weight_strategy.strategies[idx].output_spec + + # for the weight tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + weight_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(weight_src_spec.placements), + tensor_meta=weight_src_spec.tensor_meta, + ) + op_args_target_specs.append(weight_target_spec) + redistribute_costs.append( + generate_redistribute_costs(weight_strategy, weight_target_spec) + ) + + if bias_strategy is not None: + assert isinstance(bias_strategy, OpStrategy) + bias_src_spec = bias_strategy.strategies[idx].output_spec + + # for the bias tensor, we replicate it on all dims if necessary + # TODO: we can avoid forcing the redistribution once we figure out + # how to decompose layer norm + bias_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(bias_src_spec.placements), + tensor_meta=bias_src_spec.tensor_meta, + ) + op_args_target_specs.append(bias_target_spec) + redistribute_costs.append( + generate_redistribute_costs(bias_strategy, bias_target_spec) + ) + + # the output spec is the same as input spec + output_target_spec = input_target_spec + output_strategy.strategies.append( + OpSpec( + output_specs=output_target_spec, + input_specs=op_args_target_specs, + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +@register_op_strategy( + [aten.native_layer_norm_backward.default], + schema_info=RuntimeSchemaInfo(2), +) +def layer_norm_bwd_strategy(op_schema: OpSchema) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + # args must be: grad_out, input, normalized_shape, mean, rstd, + # weight, bias, output_mask. For None weight and bias, their + # corresponding objects will be None as well. + + assert len(op_schema.args_schema) == 8 + ( + grad_out_strategy, + input_strategy, + normalized_shape, + mean_strategy, + rstd_strategy, + weight_strategy, + bias_strategy, + output_mask, + ) = op_schema.args_schema + + assert isinstance(grad_out_strategy, OpStrategy) + assert isinstance(input_strategy, OpStrategy) + assert isinstance(mean_strategy, OpStrategy) + assert isinstance(rstd_strategy, OpStrategy) + + assert isinstance(normalized_shape, (int, Sequence, torch.Size)) + normalized_size = normalize_to_torch_size(normalized_shape) + input_ndim = input_strategy.ndim + axis = input_ndim - len(normalized_size) + outer_dims = list(range(axis)) + + assert isinstance(output_mask, list) and len(output_mask) == 3 + + # output triple: (d_input, d_weight, d_bias) + out_tuple_strategy = OpStrategy([]) + for idx, input_placement_strategy in enumerate(input_strategy.strategies): + # args for OpSpec + output_specs_list: list[Optional[DTensorSpec]] = [] + input_specs_list: list[DTensorSpec] = [] + redistribute_costs = [] + + input_src_spec = input_placement_strategy.output_spec + # arg: grad_out + # TODO: change the strategy to the following rule. + # d_input is basically a product of element-wise mul of + # grad_out, rstd, and normalized input, among which rstd + # and normalized input (x_hat) should have the same sharding + # placements, and grad_out's sharding is determined by the + # pointwise result of x_hat and weight/bias. + # TODO: now grad_out spec follows input spec. we may need + # to change it to apply a pointwise rule over grad_out, + # input, and weight. + grad_out_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + input_specs_list.append(grad_out_target_spec) + redistribute_costs.append( + generate_redistribute_costs(grad_out_strategy, grad_out_target_spec) + ) + output_specs_list.append(grad_out_target_spec if output_mask[0] else None) + + # arg: input + input_target_spec = DTensorSpec( + mesh=mesh, + placements=_replicate_dims_start_at(input_src_spec.placements, axis), + tensor_meta=input_src_spec.tensor_meta, + ) + input_specs_list.append(input_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_strategy, input_target_spec) + ) + + # arg: mean, rstd + mean_src_spec = mean_strategy.strategies[idx].output_spec + input_specs_list.append(mean_src_spec) + redistribute_costs.append([0.0 for _ in mean_strategy.strategies]) + rstd_src_spec = rstd_strategy.strategies[idx].output_spec + input_specs_list.append(rstd_src_spec) + redistribute_costs.append([0.0 for _ in rstd_strategy.strategies]) + + def _add_target_input_spec(strategy) -> DTensorSpec: + # shared logic for setting the weight and bias target input specs + assert isinstance(strategy, OpStrategy) + src_spec = strategy.strategies[idx].output_spec + # no need to redistribute since they should be replicated in forward pass + input_specs_list.append(src_spec) + redistribute_costs.append([0.0 for _ in strategy.strategies]) + return src_spec + + # arg: weight + # d_weight = sum(grad_out * (input - mean) / rstd, outer_dim, keepdim=False) + if weight_strategy is not None: + weight_src_spec = _add_target_input_spec(weight_strategy) + # TODO: now d_weight spec follows input spec w/ a reduction. + # we may need to change to a pointwise rule over grad_out and + # input, then apply a reduction. + inp_placements = _replicate_dims_start_at(input_src_spec.placements, axis) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, input_src_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + weight_out_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=weight_src_spec.tensor_meta, + ) + output_specs_list.append(weight_out_spec if output_mask[1] else None) + else: + assert output_mask[1] is False, ( + "output_mask[1] should not be `True` while weight argument is `None` in native_layer_norm_backward." + ) + output_specs_list.append(None) + + # arg: bias + # d_bias = sum(grad_out, outer_dim, keepdim=False) + if bias_strategy is not None: + bias_src_spec = _add_target_input_spec(bias_strategy) + # d_bias spec follows a reduction over grad_out + inp_placements = _replicate_dims_start_at( + grad_out_target_spec.placements, axis + ) + reduce_dims_map = _infer_reduce_dims_map( + outer_dims, grad_out_target_spec.ndim, False + ) + out_placements = map_placements_after_reduction( + inp_placements, outer_dims, reduce_dims_map, "sum" + ) + bias_out_spec = DTensorSpec( + mesh=mesh, + placements=out_placements, + tensor_meta=bias_src_spec.tensor_meta, + ) + output_specs_list.append(bias_out_spec if output_mask[2] else None) + else: + assert output_mask[2] is False, ( + "output_mask[2] should not be `True` while bias argument is `None` in native_layer_norm_backward." + ) + output_specs_list.append(None) + + out_tuple_strategy.strategies.append( + OpSpec( + output_specs=tuple(output_specs_list), + input_specs=input_specs_list, + redistribute_cost=redistribute_costs, + ) + ) + + return out_tuple_strategy + + +@register_op_strategy( + [aten.topk.default], + schema_info=RuntimeSchemaInfo(2), +) +def topk_strategy(op_schema: OpSchema) -> OpStrategy: + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + topk_dim = ( + cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else -1 + ) + topk_dim = normalize_dim(topk_dim, input_strategy.ndim) + + single_mesh_dim_strategies = [] + + # two outputs (values, indices), 1 input + # replicate always works + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # every dim except topk dim should work + for dim in range(input_strategy.ndim): + if dim != topk_dim: + dim_shardings: PlacementList = [Shard(dim)] * 3 + single_mesh_dim_strategies.append(dim_shardings) + # TODO: topk on sharded dim requries non-trival reduction, address it later + + return expand_to_full_mesh_op_strategy( + input_strategy.mesh, op_schema, single_mesh_dim_strategies, input_index=2 + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b1dc4109189107d6c7a9586f2a79be2cb2c7dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_matrix_ops.py @@ -0,0 +1,1040 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor + + +from typing import Optional + +import torch +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, +) +from torch.distributed.tensor._ops._einsum_strategy import gen_einsum_strategies +from torch.distributed.tensor._ops.utils import ( + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + infer_broadcast_dims_map, + is_tensor_shardable, + map_placements_after_broadcast, + prod, + register_op_strategy, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +@register_op_strategy(aten.t.default) +def transpose_strategy(op_schema: OpSchema) -> OpStrategy: + self_strategy = op_schema.args_schema[0] + assert isinstance(self_strategy, OpStrategy) + + transpose_strategies = [] + for input_strategy in self_strategy.strategies: + input_spec = input_strategy.output_spec + # follow the input spec but transpose the Shard placements + output_placements = [ + Shard(1 - p.dim) if isinstance(p, Shard) else p + for p in input_spec.placements + ] + transpose_strategy = OpSpec( + output_specs=DTensorSpec( + mesh=input_strategy.mesh, + placements=tuple(output_placements), + ), + input_specs=(input_strategy.output_spec,), + ) + transpose_strategies.append(transpose_strategy) + + return OpStrategy(strategies=transpose_strategies) + + +def _mm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + self_strategy, mat2_strategy = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(mat2_strategy, OpStrategy) + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + assert strtg.input_specs is not None + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + if is_tensor_shardable(self_strategy.shape, self_spec) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec + ): + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def _addmm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + self_strategy, mat1_strategy, mat2_strategy = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(mat1_strategy, OpStrategy) + assert isinstance(mat2_strategy, OpStrategy) + self_shape = self_strategy.shape + mm_out_shape = torch.Size( + [ + mat2_strategy.shape[-1] if i == len(mat1_strategy.shape) - 1 else dim_size + for i, dim_size in enumerate(mat1_strategy.shape) + ] + ) + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + # construct new strategy by consider the self arg + assert strtg.input_specs is not None + mat1_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + out_spec = strtg.output_spec + + # self arg's spec should follow the output of mm, but need + # to consider broadcast for the self arg + broadcast_dims_map = infer_broadcast_dims_map(mm_out_shape, self_shape) + self_placements = map_placements_after_broadcast( + out_spec.placements, mm_out_shape, broadcast_dims_map + ) + self_spec = DTensorSpec(mesh=mesh, placements=self_placements) + + if is_tensor_shardable(mat1_strategy.shape, mat1_spec) and is_tensor_shardable( + mat2_strategy.shape, mat2_spec + ): + # update input specs with new self spec + strtg.input_specs = (self_spec, mat1_spec, mat2_spec) + + # associate costs + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat1_strategy, mat1_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +def _scaled_mm_like_strategy( + mm_equation: str, mesh: DeviceMesh, op_schema: OpSchema +) -> OpStrategy: + ( + self_strategy, + mat2_strategy, + scale_self_strategy, + scale_mat2_strategy, + bias_strategy, + scale_result_strategy, + *_, + ) = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(mat2_strategy, OpStrategy) + assert isinstance(scale_self_strategy, OpStrategy) + assert isinstance(scale_mat2_strategy, OpStrategy) + # TODO: add support for these later + assert bias_strategy is None, "_scaled_mm on DTensors doesn't support bias" + assert scale_result_strategy is None, ( + "_scaled_mm on DTensors doesn't support scale_result" + ) + # generate all possible strategies for mm + mm_strategy = gen_einsum_strategies(mm_equation, mesh) + # filter out invalid strategies and associate costs + strategies = mm_strategy.strategies + filtered_strategies = [] + for strtg in strategies: + assert strtg.input_specs is not None + self_spec = strtg.input_specs[0] + mat2_spec = strtg.input_specs[1] + # propagate the operands' specs to their scales, except for tensor-wise + # scaling which can have any numbers of dims (legacy...), hence sharding + # dims won't map. for tensor-wise, anyways, we can only do replication. + scale_self_spec = ( + DTensorSpec(self_spec.mesh, (Replicate(),)) + if prod(scale_self_strategy.shape) == 1 + else self_spec + ) + scale_mat2_spec = ( + DTensorSpec(mat2_spec.mesh, (Replicate(),)) + if prod(scale_mat2_strategy.shape) == 1 + else mat2_spec + ) + strtg.input_specs = list(strtg.input_specs) + [scale_self_spec, scale_mat2_spec] + if ( + is_tensor_shardable(self_strategy.shape, self_spec) + and is_tensor_shardable(mat2_strategy.shape, mat2_spec) + and is_tensor_shardable(scale_self_strategy.shape, scale_self_spec) + and is_tensor_shardable(scale_mat2_strategy.shape, scale_mat2_spec) + ): + redistribute_cost = [ + generate_redistribute_costs(self_strategy, self_spec), + generate_redistribute_costs(mat2_strategy, mat2_spec), + generate_redistribute_costs(scale_self_strategy, scale_self_spec), + generate_redistribute_costs(scale_mat2_strategy, scale_mat2_spec), + ] + strtg.redistribute_cost = redistribute_cost + filtered_strategies.append(strtg) + + mm_strategy.strategies = filtered_strategies + + return mm_strategy + + +@register_op_strategy(aten.dot.default) +def dot_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _mm_like_strategy("i,i->", mesh, op_schema) + + +@register_op_strategy(aten.mm.default) +def mm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _mm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.addmm.default) +def addmm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _addmm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy(aten.bmm.default) +def bmm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _mm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy(aten.baddbmm.default) +def baddmm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _addmm_like_strategy("bmk,bkn->bmn", mesh, op_schema) + + +@register_op_strategy(aten._scaled_mm.default) +def scaled_mm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + return _scaled_mm_like_strategy("mk,kn->mn", mesh, op_schema) + + +@register_op_strategy( + aten._scaled_dot_product_flash_attention.default, schema_info=RuntimeSchemaInfo(5) +) +def scaled_dot_product_flash_attention_strategy(op_schema: OpSchema) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + # TODO: sdpa might be a good candidate for us to explore decomposed sharding propagation + # as it involves: matmul, pointwise, reduction ops together. + + mesh = op_schema.get_mesh_from_args() + + return_debug_mask = len(op_schema.args_schema) >= 6 and op_schema.args_schema[5] + q_input_strategy = op_schema.args_schema[0] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 3 valid tensor outputs and 3 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), + Replicate(), + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + Replicate(), + Replicate(), + Replicate(), + Replicate(), + ] + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + if return_debug_mask: + debug_attn_mask_sharding: Placement = Shard(1) # num head dim + else: + # empty debug mask, replicated + debug_attn_mask_sharding = Replicate() + + num_heads_dim_sharding: PlacementList = [ + output_sharding, + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + debug_attn_mask_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Shard on the batch dimension + single_mesh_dim_strategies.append( + [ + Shard(0), # output + Shard(0), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + Shard(0), # debugattn + Shard(0), # q + Shard(0), # k + Shard(0), # v + ] + ) + + # Context Parallelism: shards on the sequence dim + single_mesh_dim_strategies.append( + [ + Shard(2), # output + Shard(2), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + Replicate(), # rng_state + None, # unused + Shard(2), # debugattn + Shard(2), # q + Shard(2), # k + Shard(2), # v + ] + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +@register_op_strategy(aten._scaled_dot_product_flash_attention_backward.default) +def scaled_dot_product_flash_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + q_input_strategy = op_schema.args_schema[1] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + + tensor_input_indices = [ + i + for i, arg_spec in enumerate(op_schema.args_schema) + if isinstance(arg_spec, OpStrategy) + ] + num_tensor_inputs = len(tensor_input_indices) + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda backward case, we have 3 tensor outputs and 6 to 10 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [Replicate()] * (3 + num_tensor_inputs) + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + grad_output_sharding = Shard(1) # num head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + grad_qkv_sharding = Shard(1) # num head dim + + num_heads_dim_sharding: PlacementList = [ + grad_qkv_sharding, + grad_qkv_sharding, + grad_qkv_sharding, + grad_output_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + output_sharding, + logsumexp_sharding, + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + num_heads_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Batch sharding + batch_dim_sharding: PlacementList = [ + Shard(0), # grad_q + Shard(0), # grad_k + Shard(0), # grad_v + Shard(0), # grad_output + Shard(0), # q + Shard(0), # k + Shard(0), # v + Shard(0), # output + Shard(0), # logsumexp + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + batch_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(batch_dim_sharding) + + # Context Parallelism: shards on the sequence dim + seq_dim_sharding: PlacementList = [ + Shard(2), # grad_q + Shard(2), # grad_k + Shard(2), # grad_v + Shard(2), # grad_output + Shard(2), # q + Shard(2), # k + Shard(2), # v + Shard(2), # output + Shard(2), # logsumexp + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + seq_dim_sharding.extend([Replicate()] * (num_tensor_inputs - 6)) + single_mesh_dim_strategies.append(seq_dim_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +@register_op_strategy(aten.constant_pad_nd.default) +def constant_pad_nd_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args(validate=False) + + # TODO(d4l3k); implement a more correct strategy for constant_pad_nd + return OpStrategy( + [ + OpSpec( + output_specs=DTensorSpec(mesh, (Replicate(),)), + input_specs=( + DTensorSpec(mesh, (Replicate(),)), + DTensorSpec(mesh, (Replicate(),)), + ), + redistribute_cost=[[1]], + ) + ] + ) + + +@register_op_strategy( + aten._scaled_dot_product_efficient_attention.default, + schema_info=RuntimeSchemaInfo(4), +) +def scaled_dot_product_efficient_attention_strategy(op_schema: OpSchema) -> OpStrategy: + # NOTE: currently we only support some simple strategies to support tensor parallelism + mesh = op_schema.get_mesh_from_args() + q_input_strategy = op_schema.args_schema[0] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + + has_attn_bias = op_schema.args_schema[3] is not None + compute_log_sumexp = op_schema.args_schema[4] + + single_mesh_dim_strategies: list[PlacementList] = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 2 valid tensor outputs and 3 or 4 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), + Replicate(), + None, + None, + Replicate(), + Replicate(), + Replicate(), + ] + if has_attn_bias: + all_replicate.append(Replicate()) # attn bias + + # Context Parallelism: shards on the sequence dim + single_mesh_dim_strategies.append( + [ + Shard(2), # output + Shard(2), # logsumexp + None, # philox_seed + None, # philox_offset + Shard(2), # q + Shard(2), # k + Shard(2), # v + ] + ) + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the heads dimension + qkv_sharding = Shard(1) + output_sharding = Shard(1) + if compute_log_sumexp: + logsumexp_sharding: Placement = Shard(1) + else: + # empty logsumexp, replicated + logsumexp_sharding = Replicate() + + num_heads_dim_sharding = [ + output_sharding, + logsumexp_sharding, + None, + None, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + if has_attn_bias: + num_heads_dim_sharding.append(Shard(1)) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # batch sharding + if compute_log_sumexp: + logsumexp_sharding_dp: Placement = Shard(0) + else: + # empty logsumexp, replicated + logsumexp_sharding_dp = Replicate() + batch_sharding = [ + Shard(0), # output + logsumexp_sharding_dp, # logsumexp + None, # philox_seed + None, # philox_offset + Shard(0), # q + Shard(0), # k + Shard(0), # v + ] + if has_attn_bias: + batch_sharding.append(Shard(0)) + + single_mesh_dim_strategies.append(batch_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, + ) + + +@register_op_strategy(aten._scaled_dot_product_efficient_attention_backward.default) +def scaled_dot_product_efficient_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + q_input_strategy = op_schema.args_schema[1] + assert isinstance(q_input_strategy, OpStrategy) + # assuming q/k/v have the same shape + has_attn_bias = op_schema.args_schema[4] is not None + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda backward case, we have 4 tensor outputs and 8 or 9 tensor inputs + # NOTE: Output sharding of grad_bias on heads dim if attn_bias is present; + # otherwise grad_bias will be empty and its DTensorSpec will be removed. + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [Replicate()] * (12 + has_attn_bias) + + if not has_attn_bias: + all_replicate[3] = None # grad bias is None if attn_bias is not present + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the heads dimension + grad_output_sharding = Shard(1) + qkv_sharding = Shard(1) + output_sharding = Shard(1) + logsumexp_sharding = Shard(1) + grad_qkv_sharding = Shard(1) + grad_bias_sharding = Shard(1) if has_attn_bias else None + + num_heads_dim_sharding: PlacementList = [ + grad_qkv_sharding, + grad_qkv_sharding, + grad_qkv_sharding, + grad_bias_sharding, + grad_output_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + # the place for optional input attn_bias, + output_sharding, + logsumexp_sharding, + ] + # input sharding of attn_bias on heads dim if present + if has_attn_bias: + num_heads_dim_sharding.insert(8, Shard(1)) + # accept replicate on the rest scalar tensor inputs + # namely philox_seed and philox_offset + num_heads_dim_sharding.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # Shards on batch dim + batch_dim_sharding: PlacementList = [ + Shard(0), # grad_q + Shard(0), # grad_k + Shard(0), # grad_v + Shard(0) if has_attn_bias else None, # grad_bias + Shard(0), # grad_output + Shard(0), # q + Shard(0), # k + Shard(0), # v + Shard(0), # output + Shard(0), # logsumexp + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + if has_attn_bias: + batch_dim_sharding.insert(8, Shard(0)) + batch_dim_sharding.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(batch_dim_sharding) + + # Context Parallelism: shards on the sequence dim + seq_dim_sharding: PlacementList = [ + Shard(2), # grad_q + Shard(2), # grad_k + Shard(2), # grad_v + Shard(1) if has_attn_bias else None, # grad_bias + Shard(2), # grad_output + Shard(2), # q + Shard(2), # k + Shard(2), # v + Shard(2), # output + Shard(2), # logsumexp + ] + # accept replicate on the rest tensor inputs, potentially + # cum_seq_q, cum_seq_k, philox_seed, philox_offset + # at indices 6, 7, 12, 13, respectively + if has_attn_bias: + num_heads_dim_sharding.insert(8, Shard(1)) + seq_dim_sharding.extend([Replicate(), Replicate()]) + single_mesh_dim_strategies.append(seq_dim_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=4, + ) + + +@register_op_strategy( + aten._scaled_dot_product_cudnn_attention.default, + schema_info=RuntimeSchemaInfo(4), +) +def scaled_dot_product_cudnn_attention_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + ( + query_strategy, # query + _, # key + _, # value + attn_bias_strategy, + compute_log_sumexp, # compute_log_sumexp + *rest_args, # optional args: dropout_p, is_causal, return_debug_mask, scale + ) = op_schema.args_schema + return_debug_mask = len(op_schema.args_schema) >= 8 and rest_args[2] + has_attn_bias = attn_bias_strategy is not None + debug_attn_mask_sharding: Optional[Placement] = ( + Replicate() if return_debug_mask else None + ) + + assert isinstance(query_strategy, OpStrategy) + # assuming q/k/v have the same shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # in the spda case, we have 2 valid tensor outputs and 3 tensor inputs + # first we can always accept full replication for both inputs and outputs + all_replicate: PlacementList = [ + Replicate(), # output + Replicate(), # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + # NOTE: debug_attn_mask is not supproted by pytorch and is always an empty tensor + # https://github.com/pytorch/pytorch/blob/60205b0eb2602317856312a66d955c88334ade0b/aten/src/ATen/native/transformers/cuda/attention.cu#L839-L840 + debug_attn_mask_sharding, # debug_attn_mask + Replicate(), # q + Replicate(), # k + Replicate(), # v + ] + if has_attn_bias: + all_replicate.append(Replicate()) # attn bias + + single_mesh_dim_strategies.append(all_replicate) + + # second we can accept the sharding pattern of tensor parallelism, which + # shard on the num of head dim + tp_sharding = Shard(1) # num head dim + qkv_sharding = tp_sharding + output_sharding = tp_sharding + logsumexp_sharding = tp_sharding if compute_log_sumexp else Replicate() + debug_attn_mask_sharding = tp_sharding if return_debug_mask else None + + num_heads_dim_sharding: PlacementList = [ + output_sharding, + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, + qkv_sharding, + qkv_sharding, + qkv_sharding, + ] + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # batch parallelism + logsumexp_sharding = Shard(0) if compute_log_sumexp else Replicate() + debug_attn_mask_sharding = Shard(0) if return_debug_mask else None + batch_dim_sharding: PlacementList = [ + Shard(0), # output + logsumexp_sharding, + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, + Shard(0), # q + Shard(0), # k + Shard(0), # v + ] + single_mesh_dim_strategies.append(batch_dim_sharding) + + # Context Parallelism: shards on the sequence dim + cp_sharding = Shard(2) # seq dim + logsumexp_sharding = cp_sharding if compute_log_sumexp else Replicate() + debug_attn_mask_sharding = cp_sharding if return_debug_mask else None + + single_mesh_dim_strategies.append( + [ + cp_sharding, # output + logsumexp_sharding, # logsumexp + None, # cum_seq_q + None, # cum_seq_k + None, # max_q + None, # max_k + None, # philox_seed + None, # philox_offset + debug_attn_mask_sharding, # debug_attn_mask + cp_sharding, # q + cp_sharding, # k + cp_sharding, # v + ] + ) + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=9 + ) + + +@register_op_strategy(aten._scaled_dot_product_cudnn_attention_backward.default) +def scaled_scaled_dot_product_cudnn_attention_backward_strategy( + op_schema: OpSchema, +) -> OpStrategy: + # backward op does not need to validate the mesh since forward op has already done it + mesh = op_schema.get_mesh_from_args(validate=False) + + assert len(op_schema.args_schema) >= 15 + has_attn_bias = op_schema.args_schema[8] is not None + has_scale = len(op_schema.args_schema) >= 16 and False + + query_strategy = op_schema.args_schema[1] + assert isinstance(query_strategy, OpStrategy) + # assuming q/k/v have the same shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [outputs, inputs] + # cudnn outputs: (Tensor dq, Tensor dk, Tensor dv) + # cudnn inputs: ( + # Tensor grad_out, + # Tensor query, + # Tensor key, + # Tensor value, + # Tensor out, + # Tensor logsumexp, + # Tensor philox_seed, + # Tensor philox_offset, + # Tensor attn_bias, + # Tensor cum_seq_q, + # Tensor cum_seq_k, + # SymInt max_q, + # SymInt max_k, + # float dropout_p, + # bool is_causal, + # int? scale, + # ) + + # case 1: we can always accept full replication for both inputs and outputs + all_replicate_out: PlacementList = [ + Replicate(), # dq + Replicate(), # dk + Replicate(), # dv + ] + all_replicate_inp: PlacementList = [Replicate()] * 6 + all_replicate_inp += [ + Replicate() + ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor + all_replicate_inp += [Replicate() if has_attn_bias else None] + all_replicate_inp += [None] * 6 + if has_scale: + all_replicate_inp.append(None) + + all_replicate: PlacementList = all_replicate_out + all_replicate_inp + single_mesh_dim_strategies.append(all_replicate) + + # case 2: we can accept the sharding pattern of tensor parallelism, which + # shards on the num of head dim + qkv_sharding = Shard(1) # num head dim + output_sharding = Shard(1) # num head dim + logsumexp_sharding = Shard(1) # num head dim + + num_heads_dim_sharding_out: PlacementList = [qkv_sharding] * 3 + num_heads_dim_sharding_inp: PlacementList = [qkv_sharding] * 4 + num_heads_dim_sharding_inp += [output_sharding] + num_heads_dim_sharding_inp += [logsumexp_sharding] + num_heads_dim_sharding_inp += [ + Replicate() + ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor + num_heads_dim_sharding_inp += [Shard(1) if has_attn_bias else None] + num_heads_dim_sharding_inp += [None] * 6 + if has_scale: + num_heads_dim_sharding_inp.append(None) + + num_heads_dim_sharding = num_heads_dim_sharding_out + num_heads_dim_sharding_inp + single_mesh_dim_strategies.append(num_heads_dim_sharding) + + # case 3: Context Parallelism which shards on the sequence dim + context_parallel_sharding_out: PlacementList = [Shard(2)] * 3 + context_parallel_sharding_inp: PlacementList = [Shard(2)] * 6 + context_parallel_sharding_inp += [ + Replicate() + ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor + context_parallel_sharding_inp += [Shard(2) if has_attn_bias else None] + context_parallel_sharding_inp += [None] * 6 + if has_scale: + context_parallel_sharding_inp.append(None) + + context_parallel_sharding = ( + context_parallel_sharding_out + context_parallel_sharding_inp + ) + single_mesh_dim_strategies.append(context_parallel_sharding) + + # case 4: we can accept the sharding pattern of batch parallelism, which + # shards on the batch dimension + qkv_sharding = Shard(0) + output_sharding = Shard(0) + logsumexp_sharding = Shard(0) + + batch_dim_sharding_out: PlacementList = [qkv_sharding] * 3 + batch_dim_sharding_inp: PlacementList = [qkv_sharding] * 4 + batch_dim_sharding_inp += [output_sharding] + batch_dim_sharding_inp += [logsumexp_sharding] + batch_dim_sharding_inp += [ + Replicate() + ] * 2 # philox_seed, philox_offset is casted to Replicate() in DTensor + batch_dim_sharding_inp += [Shard(0) if has_attn_bias else None] + batch_dim_sharding_inp += [None] * 6 + if has_scale: + batch_dim_sharding_inp.append(None) + + batch_dim_sharding = batch_dim_sharding_out + batch_dim_sharding_inp + single_mesh_dim_strategies.append(batch_dim_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=3 + ) + + +@register_op_strategy(aten._grouped_mm.default) +def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy: + mesh = op_schema.get_mesh_from_args() + + mat1_strategy = op_schema.args_schema[0] + assert isinstance(mat1_strategy, OpStrategy) + mat2_strategy = op_schema.args_schema[1] + assert isinstance(mat2_strategy, OpStrategy) + if len(op_schema.args_schema) > 3: + bias_strategy = op_schema.args_schema[3] + assert bias_strategy is None, "grouped_mm doesn't support bias yet" + + single_mesh_dim_strategies = [] + + offs_placement = None + if len(op_schema.args_schema) > 2 and op_schema.args_schema[2] is not None: + offs_placement = Replicate() # offs should always be replicated + + all_replicate: PlacementList = [ + Replicate(), + Replicate(), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + partial_replicate: PlacementList = [ + Partial(), + Partial(), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + replicate_partial: PlacementList = [ + Partial(), + Replicate(), # mat1 + Partial(), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies = [all_replicate, partial_replicate, replicate_partial] + + if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 3: + # rowwise_replicate for 2dx3d not supported + replicate_colwise_2x3: PlacementList = [ + Shard(1), + Replicate(), # mat1 + Shard(2), # mat2 + offs_placement, # offs + None, # bias + ] + colwise_rowwise_2x3: PlacementList = [ + Partial(), + Shard(1), # mat1 + Shard(1), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend([replicate_colwise_2x3, colwise_rowwise_2x3]) + + if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 2: + # replicate_colwise for 3dx2d not supported + colwise_rowwise_3x2: PlacementList = [ + Partial(), + Shard(2), # mat1 + Shard(0), # mat2 + offs_placement, # offs + None, # bias + ] + rowwise_replicate_3x2: PlacementList = [ + Shard(0), + Shard(1), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend([colwise_rowwise_3x2, rowwise_replicate_3x2]) + + if mat1_strategy.ndim == 2 and mat2_strategy.ndim == 2: + # colwise_rowwise for 2dx2d not supported + replicate_colwise_2x2: PlacementList = [ + Shard(2), + Replicate(), # mat1 + Shard(1), # mat2 + offs_placement, # offs + None, # bias + ] + rowwise_replicate_2x2: PlacementList = [ + Shard(1), + Shard(0), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend( + [replicate_colwise_2x2, rowwise_replicate_2x2] + ) + + if mat1_strategy.ndim == 3 and mat2_strategy.ndim == 3: + replicate_colwise_3x3: PlacementList = [ + Shard(2), + Replicate(), # mat1 + Shard(2), # mat2 + offs_placement, # offs + None, # bias + ] + rowwise_replicate_3x3: PlacementList = [ + Shard(1), + Shard(1), # mat1 + Replicate(), # mat2 + offs_placement, # offs + None, # bias + ] + colwise_rowwise_3x3: PlacementList = [ + Partial(), + Shard(2), # mat1 + Shard(1), # mat2 + offs_placement, # offs + None, # bias + ] + batch_dim_sharding: PlacementList = [ + Shard(0), + Shard(0), # mat1 + Shard(0), # mat2 + offs_placement, # offs + None, # bias + ] + single_mesh_dim_strategies.extend( + [ + replicate_colwise_3x3, + rowwise_replicate_3x3, + colwise_rowwise_3x3, + batch_dim_sharding, + ] + ) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d5b3bce37f4ec6491d1c1ac440592545cb3fdf8c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_pointwise_ops.py @@ -0,0 +1,700 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Sequence +from typing import cast + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._ops.utils import ( + generate_redistribute_costs, + infer_broadcast_dims_map, + map_placements_after_broadcast, + normalize_dim, + register_op_strategy, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten +# leave the remaining pointwise_ops list here for convenience, +# Below ops are some pointwise ops that are yet to be supported, +# they might not be a complete list. +# pointwise_ops = [ +# "fake_quantize_per_channel_affine", +# "fake_quantize_per_tensor_affine", +# "floor_divide", # floor_divide is deprecated +# "frexp", # multiple output pointwise op, need to add support +# "gradient", # need investigation on this op +# "imag", # complex data type only +# "quantized_batch_norm", +# "quantized_max_pool1d", +# "quantized_max_pool2d", +# "real", # complex data type only +# ] + + +linear_pointwise_ops = [ + aten.div.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.div_.Scalar, # this op is linear on the first argument, and the second argument is scalar, so it fits as a linear op. + aten.to.dtype, + aten.add.Tensor, + aten.add_.Tensor, +] + + +pointwise_ops = [ + # please keep the entries below alphabetically sorted + aten.__ilshift__.Scalar, + aten.__ilshift__.Tensor, + aten.__irshift__.Scalar, + aten.__irshift__.Tensor, + aten.__lshift__.Scalar, + aten.__lshift__.Tensor, + aten.__rshift__.Scalar, + aten.__rshift__.Tensor, + aten._conj.default, + aten.abs.default, + aten.abs.out, + aten.abs_.default, + aten.acos.default, + aten.acos.out, + aten.acos_.default, + aten.acosh.default, + aten.acosh.out, + aten.acosh_.default, + aten.add.Scalar, + aten.add.out, + aten.add_.Scalar, + aten.addcdiv.default, + aten.addcdiv.out, + aten.addcdiv_.default, + aten.addcmul.default, + aten.addcmul.out, + aten.addcmul_.default, + aten.angle.default, + aten.angle.out, + aten.asin.default, + aten.asin.out, + aten.asin_.default, + aten.asinh.default, + aten.asinh.out, + aten.asinh_.default, + aten.atan.default, + aten.atan.out, + aten.atan2.default, + aten.atan2.out, + aten.atan2_.default, + aten.atan_.default, + aten.atanh.default, + aten.atanh.out, + aten.atanh_.default, + aten.bitwise_and.Scalar, + aten.bitwise_and.Scalar_Tensor, + aten.bitwise_and.Scalar_out, + aten.bitwise_and.Tensor, + aten.bitwise_and.Tensor_out, + aten.bitwise_and_.Scalar, + aten.bitwise_and_.Tensor, + aten.bitwise_left_shift.Scalar_Tensor, + aten.bitwise_left_shift.Tensor, + aten.bitwise_left_shift.Tensor_Scalar, + aten.bitwise_left_shift.Tensor_Scalar_out, + aten.bitwise_left_shift.Tensor_out, + aten.bitwise_left_shift_.Tensor, + aten.bitwise_left_shift_.Tensor_Scalar, + aten.bitwise_not.default, + aten.bitwise_not.out, + aten.bitwise_not_.default, + aten.bitwise_or.Scalar, + aten.bitwise_or.Scalar_Tensor, + aten.bitwise_or.Scalar_out, + aten.bitwise_or.Tensor, + aten.bitwise_or.Tensor_out, + aten.bitwise_or_.Scalar, + aten.bitwise_or_.Tensor, + aten.bitwise_right_shift.Scalar_Tensor, + aten.bitwise_right_shift.Tensor, + aten.bitwise_right_shift.Tensor_Scalar, + aten.bitwise_right_shift.Tensor_Scalar_out, + aten.bitwise_right_shift.Tensor_out, + aten.bitwise_right_shift_.Tensor, + aten.bitwise_right_shift_.Tensor_Scalar, + aten.bitwise_xor.Scalar, + aten.bitwise_xor.Scalar_Tensor, + aten.bitwise_xor.Scalar_out, + aten.bitwise_xor.Tensor, + aten.bitwise_xor.Tensor_out, + aten.bitwise_xor_.Scalar, + aten.bitwise_xor_.Tensor, + aten.ceil.default, + aten.ceil.out, + aten.ceil_.default, + aten.clamp.default, + aten.clamp.out, + aten.clamp_.default, + aten.clip.default, + aten.clip.out, + aten.clip_.default, + aten.conj_physical.default, + aten.conj_physical.out, + aten.conj_physical_.default, + aten.copysign.Scalar, + aten.copysign.Scalar_out, + aten.copysign.Tensor, + aten.copysign.out, + aten.copysign_.Scalar, + aten.copysign_.Tensor, + aten.cos.default, + aten.cos.out, + aten.cos_.default, + aten.cosh.default, + aten.cosh.out, + aten.cosh_.default, + aten.deg2rad.default, + aten.deg2rad.out, + aten.deg2rad_.default, + aten.digamma.default, + aten.digamma.out, + aten.digamma_.default, + aten.div.Tensor, + aten.div.Tensor_mode, + aten.div.out, + aten.div.out_mode, + aten.div_.Tensor, + aten.div_.Tensor_mode, + aten.eq.Tensor, + aten.eq.Tensor_out, + aten.eq.Scalar, + aten.eq.Scalar_out, + aten.erf.default, + aten.erf.out, + aten.erf_.default, + aten.erfc.default, + aten.erfc.out, + aten.erfc_.default, + aten.erfinv.default, + aten.erfinv.out, + aten.erfinv_.default, + aten.exp.default, + aten.exp.out, + aten.exp2.default, + aten.exp2.out, + aten.exp2_.default, + aten.exp_.default, + aten.expm1.default, + aten.expm1.out, + aten.expm1_.default, + aten.float_power.Scalar, + aten.float_power.Scalar_out, + aten.float_power.Tensor_Scalar, + aten.float_power.Tensor_Scalar_out, + aten.float_power.Tensor_Tensor, + aten.float_power.Tensor_Tensor_out, + aten.float_power_.Scalar, + aten.float_power_.Tensor, + aten.floor.default, + aten.floor.out, + aten.floor_.default, + aten.fmod.Scalar, + aten.fmod.Scalar_out, + aten.fmod.Tensor, + aten.fmod.Tensor_out, + aten.fmod_.Scalar, + aten.fmod_.Tensor, + aten.frac.default, + aten.frac.out, + aten.frac_.default, + aten.ge.Scalar, + aten.ge.Tensor, + aten.gelu.default, + aten.gt.Tensor, + aten.gt.Tensor_out, + aten.gt.Scalar, + aten.gt.Scalar_out, + aten.gt.Scalar, + aten.gt.Tensor, + aten.hypot.default, + aten.hypot.out, + aten.hypot_.default, + aten.i0.default, + aten.i0.out, + aten.i0_.default, + aten.igamma.default, + aten.igamma.out, + aten.igamma_.default, + aten.igammac.default, + aten.igammac.out, + aten.igammac_.default, + aten.isinf.default, + aten.isnan.default, + aten.isneginf.default, + aten.isneginf.out, + aten.isposinf.default, + aten.isposinf.out, + aten.ldexp.default, + aten.ldexp.out, + aten.ldexp_.default, + aten.lt.Tensor, + aten.lt.Tensor_out, + aten.lt.Scalar, + aten.lt.Scalar_out, + aten.le.Scalar, + aten.le.Tensor, + aten.lerp.Scalar, + aten.lerp.Scalar_out, + aten.lerp.Tensor, + aten.lerp.Tensor_out, + aten.lerp_.Scalar, + aten.lerp_.Tensor, + aten.lgamma.default, + aten.lgamma.out, + aten.lgamma_.default, + aten.log.default, + aten.log.out, + aten.log10.default, + aten.log10.out, + aten.log10_.default, + aten.log1p.default, + aten.log1p.out, + aten.log1p_.default, + aten.log2.default, + aten.log2.out, + aten.log2_.default, + aten.log_.default, + aten.logaddexp.default, + aten.logaddexp.out, + aten.logaddexp2.default, + aten.logaddexp2.out, + aten.logical_and.default, + aten.logical_and.out, + aten.logical_and_.default, + aten.logical_not.default, + aten.logical_not.out, + aten.logical_not_.default, + aten.logical_or.default, + aten.logical_or.out, + aten.logical_or_.default, + aten.logical_xor.default, + aten.logical_xor.out, + aten.logical_xor_.default, + aten.logit.default, + aten.logit.out, + aten.logit_.default, + aten.masked_fill.Scalar, + aten.maximum.default, + aten.maximum.out, + aten.minimum.default, + aten.minimum.out, + aten.mul.Scalar, + aten.mul.Tensor, + aten.mul.out, + aten.mul_.Scalar, + aten.mul_.Tensor, + aten.mvlgamma.default, + aten.mvlgamma.out, + aten.mvlgamma_.default, + aten.native_dropout_backward.default, + aten.native_dropout_backward.out, + aten.nan_to_num.default, + aten.nan_to_num.out, + aten.nan_to_num_.default, + aten.ne.Scalar, + aten.neg.default, + aten.neg.out, + aten.neg_.default, + aten.nextafter.default, + aten.nextafter.out, + aten.nextafter_.default, + aten.polygamma.default, + aten.polygamma.out, + aten.polygamma_.default, + aten.positive.default, + aten.pow.Scalar, + aten.pow.Scalar_out, + aten.pow.Tensor_Scalar, + aten.pow.Tensor_Scalar_out, + aten.pow.Tensor_Tensor, + aten.pow.Tensor_Tensor_out, + aten.pow_.Scalar, + aten.pow_.Tensor, + aten.reciprocal.default, + aten.reciprocal.out, + aten.reciprocal_.default, + aten.rad2deg.default, + aten.rad2deg.out, + aten.rad2deg_.default, + aten.relu.default, + aten.relu_.default, + aten.remainder.Scalar, + aten.remainder.Scalar_Tensor, + aten.remainder.Scalar_out, + aten.remainder.Tensor, + aten.remainder.Tensor_out, + aten.remainder_.Scalar, + aten.remainder_.Tensor, + aten.round.decimals, + aten.round.decimals_out, + aten.round.default, + aten.round.out, + aten.round_.decimals, + aten.round_.default, + aten.rsqrt.default, + aten.rsqrt.out, + aten.rsqrt_.default, + aten.rsub.Scalar, + aten.sgn.default, + aten.sgn.out, + aten.sgn_.default, + aten.sigmoid.default, + aten.sigmoid.out, + aten.sigmoid_.default, + aten.sign.default, + aten.sign.out, + aten.sign_.default, + aten.signbit.default, + aten.signbit.out, + aten.silu.default, + aten.silu.out, + aten.sin.default, + aten.sin.out, + aten.sin_.default, + aten.sinc.default, + aten.sinc.out, + aten.sinc_.default, + aten.sinh.default, + aten.sinh.out, + aten.sinh_.default, + aten.sqrt.default, + aten.sqrt.out, + aten.sqrt_.default, + aten.square.default, + aten.square.out, + aten.square_.default, + aten.sub.Scalar, + aten.sub.Tensor, + aten.sub.out, + aten.sub_.Scalar, + aten.sub_.Tensor, + aten.tan.default, + aten.tan.out, + aten.tan_.default, + aten.tanh.default, + aten.tanh.out, + aten.tanh_.default, + aten.true_divide.Tensor, + aten.trunc.default, + aten.trunc.out, + aten.trunc_.default, + aten.where.self, + aten.where.self_out, + aten.xlogy.OutScalar_Self, + aten.xlogy.OutScalar_Other, + aten.xlogy.OutTensor, + aten.xlogy.Scalar_Other, + aten.xlogy.Scalar_Self, + aten.xlogy.Tensor, + aten.xlogy_.Scalar_Other, + aten.xlogy_.Tensor, + # backward point-wise ops + # please keep the entries below alphabetically sorted + aten.gelu_backward.default, + aten.sigmoid_backward.default, + aten.silu_backward.default, + aten.tanh_backward.default, + aten.threshold_backward.default, +] + + +def pointwise_strategy(op_schema: OpSchema, linearity: bool = False) -> OpStrategy: + max_shards_strategy_index = -1 + max_shards = -1 + max_ndim = -1 + + if op_schema.is_inplace_op(): + # inplace op should follow the first arg strategy + followed_strategy = op_schema.args_schema[0] + elif op_schema.is_out_variant_op(): + # out variant op should follow the out kwarg strategy + followed_strategy = op_schema.kwargs_schema["out"] + else: + # normal pointwise op, we choose to follow the arg with + # the max shards in case operands needs reshard + # in case of multiple operands with max shard, we take + # the one with the max number of dimensions + for idx, arg_strategy in enumerate(op_schema.args_schema): + if not isinstance(arg_strategy, OpStrategy): + continue + + arg_max_shards = arg_strategy.max_num_shards() + arg_max_ndim = arg_strategy.ndim + if (arg_max_shards > max_shards) or ( + arg_max_shards == max_shards and arg_max_ndim > max_ndim + ): + max_shards_strategy_index = idx + max_shards = arg_max_shards + max_ndim = arg_max_ndim + + followed_strategy = op_schema.args_schema[max_shards_strategy_index] + + assert isinstance(followed_strategy, OpStrategy), ( + f"no strategy to follow for {op_schema}!" + ) + return common_pointwise_strategy( + op_schema.args_schema, followed_strategy, linearity + ) + + +def common_pointwise_strategy( + args_schema: Sequence[object], + followed_strategy: OpStrategy, + linearity: bool, +) -> OpStrategy: + # handle broadcasting + common_shape = torch.broadcast_shapes( + *[arg.shape for arg in args_schema if isinstance(arg, OpStrategy)] + ) + pointwise_strategy = OpStrategy([]) + + for placement_strategy in followed_strategy.strategies: + spec_to_follow = placement_strategy.output_spec + out_placements: list[Placement] = [] + for placement in spec_to_follow.placements: + if isinstance(placement, Shard): + shard_dim = normalize_dim(placement.dim, len(spec_to_follow.shape)) + common_ndim = len(common_shape) + new_shard_dim = common_ndim - len(spec_to_follow.shape) + shard_dim + out_placements.append(Shard(new_shard_dim)) + elif isinstance(placement, Partial) and not linearity: + # clear the partial placemnet if op does not support linearity + # by default we just replicate the partial, need to see if this + # is optimal for all cases + out_placements.append(Replicate()) + else: + out_placements.append(placement) + + input_specs: list[DTensorSpec] = [] + redistribute_costs: list[list[float]] = [] + for arg_idx, input_arg in enumerate(args_schema): + if isinstance(input_arg, OpStrategy): + # sanity check that all args that follow the same strategy + # are on the same DeviceMesh + if input_arg.mesh != followed_strategy.mesh: + raise ValueError( + f"Could not run pointwise computation across different mesh: " + f"Found {input_arg.mesh} and {followed_strategy.mesh}!" + ) + + # every arg follow the out_placements, but need to handle broadcasting + input_arg_spec = input_arg.strategies[0].output_spec + input_arg_dims_map = infer_broadcast_dims_map( + common_shape, input_arg_spec.shape + ) + input_target_placements = map_placements_after_broadcast( + tuple(out_placements), + common_shape, + input_arg_dims_map, + ) + input_arg_target_spec = DTensorSpec( + mesh=followed_strategy.mesh, + placements=input_target_placements, + tensor_meta=input_arg_spec.tensor_meta, + ) + input_specs.append(input_arg_target_spec) + redistribute_costs.append( + generate_redistribute_costs(input_arg, input_arg_target_spec) + ) + + pointwise_strategy.strategies.append( + OpSpec( + output_specs=DTensorSpec( + mesh=followed_strategy.mesh, + placements=tuple(out_placements), + ), + input_specs=input_specs, + redistribute_cost=redistribute_costs, + ) + ) + return pointwise_strategy + + +def linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: + """ + Linear pointwise operators can propagate pending reductions. + For example, c = add(a, b); if a is pending sum, then c will be + pending sum as well without any communication overhead. + """ + return pointwise_strategy(op_schema, linearity=True) + + +for op in linear_pointwise_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( + linear_pointwise_strategy + ) + +for op in pointwise_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(static_kwargkey=["out"]))( + pointwise_strategy + ) + + +# TODO: add all for_each ops +for_each_ops = [ + aten._foreach_abs.default, + aten._foreach_abs_.default, + aten._foreach_addcdiv_.Scalar, + aten._foreach_addcdiv_.ScalarList, + aten._foreach_addcdiv_.Tensor, + aten._foreach_addcmul.Scalar, + aten._foreach_addcmul_.Scalar, + aten._foreach_addcmul_.ScalarList, + aten._foreach_addcmul_.Tensor, + aten._foreach_clamp_max_.Scalar, + aten._foreach_clamp_min_.Scalar, + aten._foreach_div_.List, + aten._foreach_div_.Scalar, + aten._foreach_div_.ScalarList, + aten._foreach_div_.Tensor, + aten._foreach_div.List, + aten._foreach_div.Scalar, + aten._foreach_div.ScalarList, + aten._foreach_div.Tensor, + aten._foreach_lerp_.Scalar, + aten._foreach_maximum_.List, + aten._foreach_mul.Scalar, + aten._foreach_mul.ScalarList, + aten._foreach_mul.Tensor, + aten._foreach_mul.List, + aten._foreach_mul_.Scalar, + aten._foreach_mul_.ScalarList, + aten._foreach_mul_.Tensor, + aten._foreach_mul_.List, + aten._foreach_neg.default, + aten._foreach_neg_.default, + aten._foreach_reciprocal_.default, + aten._foreach_sub.Scalar, + aten._foreach_sub_.Scalar, + aten._foreach_sub.List, + aten._foreach_sub_.List, + aten._foreach_sub.ScalarList, + aten._foreach_sub_.ScalarList, + aten._foreach_sqrt.default, + aten._foreach_sqrt_.default, + aten._foreach_zero_.default, + aten._foreach_exp.default, + aten._foreach_exp_.default, + aten._foreach_cos.default, + aten._foreach_cos_.default, + aten._foreach_log.default, + aten._foreach_log_.default, + aten._amp_foreach_non_finite_check_and_unscale_.default, +] + +for_each_linearity_ops = [ + aten._foreach_add.Scalar, + aten._foreach_add_.Scalar, + aten._foreach_add_.ScalarList, + aten._foreach_add.List, + aten._foreach_add_.List, +] + + +def list_pointwise_strategy( + op_schema: OpSchema, linearity: bool = False +) -> StrategyType: + """ + Apply the pointwise strategy to the zipped arguments. For example, if we + run a foreach add of two lists l1 and l2, then we apply the pointwise + strategy on each pair (l1[i], l2[i]). If the first argument is a list but + the second (or later) one is a tensor, then we broadcast the tensor by + replicating it into a list with the length of the first argument. + + Args: + mesh (DeviceMesh): device mesh for pointwise ops + op_schema (OpSchema): schema of the operator to generate strategy for + linearity (bool): specify whether op(a) + op(b) = op(a + b) + + Returns: + OpStrategy: generated strategy + """ + + def args_tuple_strategies(args_schema: tuple[object, ...]) -> list[TupleStrategy]: + first_arg = args_schema[0] + assert isinstance(first_arg, TupleStrategy) + strategy_len = len(first_arg.childs) + tuple_strategies: list[TupleStrategy] = [] + for arg_idx, arg in enumerate(args_schema): + if isinstance(arg, TupleStrategy): + # every tuple strategy should have the same length + assert len(arg.childs) == strategy_len + tuple_strategies.append(arg) + elif isinstance(arg, OpStrategy): + if arg_idx > 0: # implicitly broadcast + tuple_strategies.append( + TupleStrategy([arg for _ in range(strategy_len)]) + ) + else: + raise RuntimeError( + f"list op only supports tuple strategy! {op_schema}" + ) + return tuple_strategies + + args_strategies = args_tuple_strategies(op_schema.args_schema) + follow_strategy: TupleStrategy = args_strategies[0] + list_strategy: list[OpStrategy] = [] + for child_idx, child_strtgy in enumerate(follow_strategy.childs): + assert isinstance(child_strtgy, OpStrategy) + args_schema: list[OpStrategy] = [ + cast(OpStrategy, arg_strategy.childs[child_idx]) + for arg_strategy in args_strategies + ] + pointwise_strategy: OpStrategy = common_pointwise_strategy( + args_schema, child_strtgy, linearity + ) + list_strategy.append(pointwise_strategy) + return TupleStrategy(list_strategy) + + +def list_linear_pointwise_strategy(op_schema: OpSchema) -> StrategyType: + """ + for each list op stratgy that supports linearity + """ + return list_pointwise_strategy(op_schema, linearity=True) + + +for op in for_each_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( + list_pointwise_strategy + ) + +for op in for_each_linearity_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( + list_linear_pointwise_strategy + ) + +fused_ops = [ + aten._fused_adam_.default, + aten._fused_adam.default, + aten._fused_adam.tensor_lr, + aten._fused_adam_.tensor_lr, + aten._fused_adamw_.default, + aten._fused_adamw.default, + aten._fused_adamw.tensor_lr, + aten._fused_adamw_.tensor_lr, +] + +for op in fused_ops: + register_op_strategy(op, schema_info=RuntimeSchemaInfo(needs_pytree=True))( + list_pointwise_strategy + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_random_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_random_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce26672aac185d4cb8154ff45617c27fe35818a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_random_ops.py @@ -0,0 +1,36 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import torch +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + StrategyType, +) +from torch.distributed.tensor._ops.utils import is_tensor_partial, register_op_strategy + + +aten = torch.ops.aten + + +@register_op_strategy( + [ + aten.normal_.default, + aten.uniform_.default, + aten.native_dropout.default, + aten.bernoulli_.float, + aten.bernoulli.default, + ] +) +def random_op_strategy(op_schema: OpSchema) -> StrategyType: + self_strategy = op_schema.args_schema[0] + assert isinstance(self_strategy, OpStrategy) + + random_strategy = OpStrategy([]) + for arg_strategy in self_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # TODO: figure out how inplace random op should behave when it's partial + raise RuntimeError(f"{op_schema.op} with Partial is not supported yet!") + random_strategy.strategies.append(OpSpec(output_specs=arg_spec)) + + return random_strategy diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d0af26ca33f375708c321055d836962b41b40d44 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_tensor_ops.py @@ -0,0 +1,915 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Sequence, Sized +from typing import cast, Optional + +import torch +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + OutputSharding, + PlacementList, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._ops._common_rules import pointwise_rule +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops.utils import ( + expand_to_full_mesh_op_strategy, + generate_redistribute_costs, + is_tensor_dim_sharded, + is_tensor_evenly_shardable, + is_tensor_partial, + normalize_dim, + register_op_strategy, + register_prop_rule, +) +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +aten = torch.ops.aten + + +def default_strategy(op_schema: OpSchema) -> StrategyType: + # Default strategy by default just propagate the first input strategy + select_strategy = op_schema.args_schema[0] + assert isinstance(select_strategy, OpStrategy) + # we create new DTensorSpecs even for default strategy to assure that + # the tensor metas are distinct between the arguments and outputs + default_strategy = [ + OpSpec( + output_specs=DTensorSpec( + mesh=select_strategy.mesh, + placements=strategy.output_spec.placements, + ) + ) + for strategy in select_strategy.strategies + ] + return OpStrategy(default_strategy) + + +register_op_strategy( + [ + aten.clone.default, + aten.contiguous.default, + aten.copy_.default, + aten.detach.default, + aten.fill_.Scalar, + aten.view.dtype, + aten.zero_.default, + ] +)(default_strategy) + +register_op_strategy( + aten._to_copy.default, schema_info=RuntimeSchemaInfo(static_kwargkey=["dtype"]) +)(default_strategy) + + +@register_op_strategy( + [ + aten.equal.default, + aten.is_same_size.default, + ] +) +def equal_strategy(op_schema: OpSchema) -> StrategyType: + # equal_strategy deals with ops that comparing two tensor, we need to make sure + # sharding layout the same with two operands, we choose to follow the arg with max + # num of shards, still keep is_same_size here for completeness as they share the + # same strategy in theory. + mesh = op_schema.get_mesh_from_args() + self_strategy, other_strategy = op_schema.args_schema + assert isinstance(self_strategy, OpStrategy) + assert isinstance(other_strategy, OpStrategy) + + select_strategy = ( + self_strategy + if self_strategy.max_num_shards() >= other_strategy.max_num_shards() + else other_strategy + ) + equal_strategy = OpStrategy([]) + + for arg_strategy in select_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + # if the arg_spec have partial, reshard to replicate + # otherwise local shard tensor comparison would be invalid + output_spec = DTensorSpec( + mesh=mesh, + placements=tuple( + Replicate() if isinstance(p, Partial) else p + for p in arg_spec.placements + ), + ) + equal_strategy.strategies.append(OpSpec(output_specs=output_spec)) + else: + equal_strategy.strategies.append(OpSpec(arg_spec)) + return equal_strategy + + +@register_op_strategy( + [ + aten.empty_like.default, + aten.ones_like.default, + aten.rand_like.default, + aten.randn_like.default, + aten.zeros_like.default, + ], + schema_info=RuntimeSchemaInfo(1, ["dtype"]), +) +@register_op_strategy( + [aten.full_like.default], + schema_info=RuntimeSchemaInfo(2, ["dtype"]), +) +@register_op_strategy( + [ + aten.randint_like.default, + aten.randint_like.low_dtype, + aten.randint_like.low_dtype_out, + ], + schema_info=RuntimeSchemaInfo(3, ["dtype"]), +) +def create_like_strategy(op_schema: OpSchema) -> StrategyType: + # create_like_strategy deals with ops that creating tensors with same + # shape as input, but with specific content that does not depend on + # the input, we can propagate sharding, but we have to make sure we + # move from partial to replicated. + select_strategy = op_schema.args_schema[0] + create_like_strategy = OpStrategy([]) + assert isinstance(select_strategy, OpStrategy) + for arg_strategy in select_strategy.strategies: + arg_spec = arg_strategy.output_spec + output_spec = DTensorSpec( + mesh=select_strategy.mesh, + placements=tuple( + Replicate() if isinstance(p, Partial) else p + for p in arg_spec.placements + ), + ) + create_like_strategy.strategies.append( + OpSpec(output_specs=output_spec, input_specs=(arg_spec,)) + ) + + return create_like_strategy + + +@register_op_strategy( + [ + aten.new_empty.default, + aten.new_full.default, + aten.new_ones.default, + aten.new_zeros.default, + aten.new_empty_strided.default, + ], + schema_info=RuntimeSchemaInfo(1, ["dtype"]), +) +def new_factory_strategy(op_schema: OpSchema) -> StrategyType: + # Currently there are two strategies: + # 1. let the output be replicated + # 2. let the output follow the input if input and output have the same shape + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + + mesh = input_strategy.mesh + input_shape = input_strategy.shape + output_shape = op_schema.args_schema[1] + assert isinstance(output_shape, list) + + new_factory_strategy = OpStrategy([]) + for arg_strategy in input_strategy.strategies: + input_spec = arg_strategy.output_spec + replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + new_factory_strategy.strategies.append( + OpSpec( + output_specs=replica_spec, + input_specs=(input_spec,), + redistribute_cost=[[0.0] * mesh.ndim], + ) + ) + + if tuple(input_shape) == tuple(output_shape) and input_spec.is_sharded(): + # NOTE: for new_empty_strided, currently the non-replicate sharding + # is supported only when the shape is evenly shardable + if ( + op_schema.op == aten.new_empty_strided.default + and not is_tensor_evenly_shardable(input_shape, input_spec) + ): + continue + + new_factory_strategy.strategies.append( + OpSpec( + output_specs=input_spec, + input_specs=(input_spec,), + # encouraging new tensor placement to be the same as input + redistribute_cost=[[-0.1] * mesh.ndim], + ) + ) + + return new_factory_strategy + + +@register_op_strategy(aten.bucketize.Tensor) +def gen_bucketize_strategy(op_schema: OpSchema) -> StrategyType: + """Just propagate input sharding, but expect replicated for boundaries input.""" + mesh = op_schema.get_mesh_from_args() + input_strategy = op_schema.args_schema[0] + bucketize_strategy = OpStrategy([]) + assert isinstance(input_strategy, OpStrategy) + for arg_strategy in input_strategy.strategies: + arg_spec = DTensorSpec(mesh, arg_strategy.output_spec.placements) + replica_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + bucketize_strategy.strategies.append( + OpSpec(output_specs=arg_spec, input_specs=(arg_spec, replica_spec)) + ) + + return bucketize_strategy + + +@register_op_strategy(aten.select.int, schema_info=RuntimeSchemaInfo(1)) +def select_int_strategy(op_schema: OpSchema) -> StrategyType: + """ + In this select op, first determine the input specs, then determine the output specs. + - Input specs: + - If the input is sharded on the selected dim, unshard it and change to replicate. + - Otherwise, keep the original input specs. + - Output specs: + - It checks the input specs with the following cases: + - Case 1 shard_dim == selected_dim: not possible as the input is already unsharded. + - Case 2 shard_dim < selected_dim: keep the input specs. + - Case 3 shard_dim > selected_dim: shard_dim -= 1. + """ + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + assert len(op_schema.args_schema) == 3 + selected_dim, index = ( + cast(int, op_schema.args_schema[1]), + cast(int, op_schema.args_schema[2]), + ) + input_shape = input_strategy.shape + input_ndim = input_strategy.ndim + selected_dim = normalize_dim(selected_dim, input_ndim) + index = normalize_dim(index, input_shape[selected_dim]) + + select_strategy = OpStrategy([]) + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + + # determine input spec + input_specs = arg_spec + if is_tensor_dim_sharded(arg_spec, dim=selected_dim): + # if input is sharded on the selected dim, need to unshard it, change to replicate + arg_target_placements = unshard_tensor_dim( + arg_spec.placements, dim=selected_dim + ) + input_specs = DTensorSpec(arg_spec.mesh, arg_target_placements) # R + + # determine output spec + output_specs = input_specs + if input_specs.is_sharded(): + # handle cases with sharded_dim != selected_dim + output_spec_placements = [] + for placement in input_specs.placements: + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + if shard_dim > selected_dim: + shard_dim -= 1 + placement = Shard(dim=shard_dim) + output_spec_placements.append(placement) + output_specs = DTensorSpec( + arg_spec.mesh, placements=tuple(output_spec_placements) + ) + + select_strategy.strategies.append( + OpSpec( + output_specs=output_specs, + input_specs=(input_specs,), + ) + ) + return select_strategy + + +@register_op_strategy( + aten.select_backward.default, + schema_info=RuntimeSchemaInfo(1), +) +def select_backward_strategy(op_schema: OpSchema) -> OpStrategy: + # func: select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + args_schema = op_schema.args_schema + input_strategy, dim = args_schema[0], args_schema[2] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + assert isinstance(dim, int) + output_strategies: list[OpSpec] = [] + for placement_strategy in input_strategy.strategies: + input_spec = placement_strategy.output_spec + output_spec_placements: list[Placement] = [] + for placement in input_spec.placements: + if isinstance(placement, Shard): + shard_dim = placement.dim + if shard_dim >= dim: + # NOTE: shard_dim is guaranteed to exist because + # grad_input has one more dim than grad_output + output_spec_placements.append(Shard(shard_dim + 1)) + else: + output_spec_placements.append(Shard(shard_dim)) + else: + output_spec_placements.append(placement) + output_specs = DTensorSpec(input_spec.mesh, tuple(output_spec_placements)) + output_strategies.append( + OpSpec(output_specs=output_specs, input_specs=(input_spec,)) + ) + return OpStrategy(output_strategies) + + +@register_op_strategy(aten.slice.Tensor, schema_info=RuntimeSchemaInfo(1)) +def gen_slice_strategy(op_schema: OpSchema) -> StrategyType: + """Forward all shardings except the slice dimension.""" + defaults = (None, 0, None, None, 1) + input_strategy, dim, start, end, step = ( + op_schema.args_schema + defaults[len(op_schema.args_schema) :] + ) + assert isinstance(input_strategy, OpStrategy) + + mesh = input_strategy.mesh + input_shape = input_strategy.shape + input_ndim = input_strategy.ndim + assert isinstance(dim, int) + if start is None: + start = 0 + if end is None or end > input_shape[dim]: + end = input_shape[dim] + assert isinstance(start, int) + assert isinstance(end, int) + assert isinstance(step, int) + + # normalize args + slice_dim = normalize_dim(dim, input_ndim) + start = normalize_dim(start, input_shape[dim]) + end = normalize_dim(end, input_shape[dim]) + + redundant_slice = start == 0 and end == input_shape[dim] and step == 1 + + slice_strategy = OpStrategy([]) + + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + if not is_tensor_dim_sharded(arg_spec, dim=slice_dim) or redundant_slice: + # only add the strategy if the slice dim is not sharded + out_spec = DTensorSpec(mesh, arg_spec.placements) + slice_strategy.strategies.append(OpSpec(output_specs=out_spec)) + if not slice_strategy.strategies: + # if all strategies are filtered out, unsharding all specs on slice dim + # of the input strategy, and use that as the op strategy + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + unshard_spec = DTensorSpec( + mesh, unshard_tensor_dim(arg_spec.placements, dim=slice_dim) + ) + slice_strategy.strategies.append(OpSpec(output_specs=unshard_spec)) + return slice_strategy + + +@register_op_strategy( + aten.slice_backward.default, + schema_info=RuntimeSchemaInfo(1), +) +def slice_backward_rules(op_schema: OpSchema) -> OpStrategy: + # func: slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + args_schema = op_schema.args_schema + input_strategy, dim = args_schema[0], args_schema[2] + assert isinstance(input_strategy, OpStrategy), f"{input_strategy}" + output_strategies: list[OpSpec] = [] + for placement_strategy in input_strategy.strategies: + output_spec = placement_strategy.output_spec + new_placements: list[Placement] = [] + for placement in output_spec.placements: + # Redistribute to replicate only if the dim is sharded and matches the slice dim + if isinstance(placement, Shard) and placement.dim == dim: + new_placements.append(Replicate()) + else: + new_placements.append(placement) + new_spec = DTensorSpec(output_spec.mesh, tuple(new_placements)) + redistribute_cost = [generate_redistribute_costs(input_strategy, new_spec)] + placement_strategy.redistribute_cost = redistribute_cost + new_strategy = OpSpec(output_specs=new_spec) + output_strategies.append(new_strategy) + return OpStrategy(output_strategies) + + +def unshard_tensor_dim( + placements: Sequence[Placement], dim: int +) -> tuple[Placement, ...]: + """Disallow the given tensor dimension to be sharded.""" + return tuple( + p if (not isinstance(p, Shard) or p.dim != dim) else Replicate() + for p in placements + ) + + +def replicate_tensor_dim( + placements: Sequence[Placement], dim: int +) -> tuple[Placement, ...]: + """Force the given tensor dimension to be replicated.""" + # Not using p.is_shard() to avoid mypy complain about Placement not having + # attribute dim. + return tuple( + Replicate() if p.is_partial() or isinstance(p, Shard) and p.dim == dim else p + for p in placements + ) + + +@register_op_strategy(aten.slice_scatter.default, schema_info=RuntimeSchemaInfo(2)) +def gen_slice_scatter_strategy(op_schema: OpSchema) -> StrategyType: + # 1. number of dimensions in input and src need to match. + # 2. number of elements on all non-dim need to match between input and src. + # 3. numer of elements in src in dim need to match the slice size. + # Given the above: + # - We suggest for src to follow the sharding of input, except on the scatter dimension, + # where our best bet for now is to make them replicated as a fall-back. + # TODO: Ideally we'd like to make sure the output is re-sharded afterwards to keep input sharding. + mesh = op_schema.get_mesh_from_args() + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + input_ndim = input_strategy.ndim + slice_dim = ( + cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 + ) + slice_dim = normalize_dim(slice_dim, input_ndim) + + slice_scatter_strategy = OpStrategy([]) + # by default follow the input strategy for both input and src + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + if not ( + is_tensor_dim_sharded(arg_spec, dim=slice_dim) + or is_tensor_partial(arg_spec) + ): + # only add the strategy if the slice_scatter dim is not sharded or partial + slice_scatter_strategy.strategies.append(OpSpec(output_specs=arg_spec)) + + if not slice_scatter_strategy.strategies: + # if all strategies are filtered out, replicating all specs on slice_scatter dim + # of the input strategy, and use that as the op strategy + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + replicate_spec = DTensorSpec( + mesh, replicate_tensor_dim(arg_spec.placements, dim=slice_dim) + ) + slice_scatter_strategy.strategies.append( + OpSpec(output_specs=replicate_spec) + ) + return slice_scatter_strategy + + +@register_op_strategy(aten._local_scalar_dense.default) +def replica_only_strategy(op_schema: OpSchema) -> StrategyType: + """Only allow replication on the input/output.""" + input_strategy = op_schema.args_schema[0] + assert isinstance(input_strategy, OpStrategy) + mesh = input_strategy.mesh + replicate_spec = DTensorSpec(mesh, tuple([Replicate()] * mesh.ndim)) + return OpStrategy([OpSpec(replicate_spec)]) + + +@register_op_strategy( + [aten.scatter_.value, aten.scatter.value, aten.scatter_.src, aten.scatter.src], + schema_info=RuntimeSchemaInfo(1), +) +def scatter_strategy(op_schema: OpSchema) -> StrategyType: + mesh = op_schema.get_mesh_from_args() + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index, src] + # first we always have replicate all for inputs and output + if len(op_schema.args_strategy) < 3: + # scatter_.src/scatter.src with src be float number instead of tensor + all_replicate: PlacementList = [Replicate()] * 3 + else: + all_replicate = [Replicate()] * 4 + single_mesh_dim_strategies.append(all_replicate) + + # TODO: see if we can support input sharding pattern + op_strategy = expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + inplace_op=op_schema.is_inplace_op(), + ) + return op_strategy + + +@register_op_strategy(aten.gather.default) +def gather_strategy(op_schema: OpSchema) -> StrategyType: + mesh = op_schema.get_mesh_from_args() + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + dim = cast(int, op_schema.args_schema[1]) + index_strategy = cast(OpStrategy, op_schema.args_schema[2]) + + input_shape = input_strategy.shape + index_shape = index_strategy.shape + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, input, index] + # first we always have replicate all for inputs and output + all_replicate: PlacementList = [Replicate()] * 3 + single_mesh_dim_strategies.append(all_replicate) + + # input sharding, input sharded, index accepts mask partial, output follows index + # this only works when the input is sharded on the gather dimension, and + # index has size 1 on the gather dimension + if index_shape[dim] == 1: + index_partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=dim) + input_sharding: PlacementList = [ + index_partial_placement, + Shard(dim), + index_partial_placement, + ] + single_mesh_dim_strategies.append(input_sharding) + + # index sharding, input replicated, index sharded, output follows index + # this only works when the sharding dimension is the gather dimension + index_sharding: PlacementList = [Shard(dim), Replicate(), Shard(dim)] + single_mesh_dim_strategies.append(index_sharding) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + +def _derive_follow_placements_from_tuple_strategy( + op: torch._ops.OpOverload, + tuple_strategy: TupleStrategy, +) -> Sequence[Placement]: + """ + derive the placements to follow from the tuple strategy, mainly used by + aten.stack, aten.cat, where each operand have the same shape, and correspondingly + expecting the same sharding + """ + + def merge_placement( + cur_placement: Placement, new_placement: Placement + ) -> Placement: + # semantic if we already have a follow placement, we + # check each placement for the current arg placement + # to see if we want to merge/adjust the placement to follow + # the priority: Partial -> Shard -> Replicate + if cur_placement == new_placement: + return cur_placement + + if cur_placement.is_partial(): + if new_placement.is_shard(): + # follow new placement + return new_placement + elif new_placement.is_partial(): + # different partial types, we can't merge and have to replicate all here + return Replicate() + else: + # follow partial + return cur_placement + elif cur_placement.is_shard(): + if new_placement.is_shard(): + # cur/new placement are different sharding (i.e. different shard dim) + # currently fallback to replicate all args + return Replicate() + else: + # for partial/replicate, follow the current shard placement + return cur_placement + else: + # current replicate, just follow new placement + return new_placement + + follow_placements: Optional[list[Placement]] = None + mesh = tuple_strategy.child_mesh(0) + for arg_strategy in tuple_strategy.childs: + assert isinstance(arg_strategy, OpStrategy) + if arg_strategy.mesh != mesh: + raise ValueError( + f"All operands in {op} must have the same mesh, " + f"but got {arg_strategy.mesh} and {mesh}." + ) + + for placement_strategy in arg_strategy.strategies: + arg_placements = placement_strategy.output_spec.placements + if follow_placements is None: + follow_placements = list(arg_placements) + continue + assert follow_placements is not None + for mesh_idx in range(mesh.ndim): + # merge placements with the priority + follow_placements[mesh_idx] = merge_placement( + follow_placements[mesh_idx], arg_placements[mesh_idx] + ) + assert follow_placements is not None, "follow placements should not be None!" + return follow_placements + + +def normalize_shard_for_stack( + placements: Sequence[Placement], insert_dim: int = 0 +) -> Sequence[Placement]: + # stack op would "insert" new dim, so all sharded dim >= the inserted dim need to + # be normalized with the new Shard placement + normalized_placements: list[Placement] = [] + for placement in placements: + if isinstance(placement, Shard) and placement.dim >= insert_dim: + normalized_placements.append(Shard(placement.dim + 1)) + else: + normalized_placements.append(placement) + return normalized_placements + + +@register_op_strategy(aten.stack.default, RuntimeSchemaInfo(1, needs_pytree=True)) +def stack_strategy(op_schema: OpSchema) -> StrategyType: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" + first_input_strategy = input_tuple_strategy.childs[0] + assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" + common_input_ndim = first_input_strategy.ndim + dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 + # normalize the dim to be within the common input ndim + dim = normalize_dim(dim, common_input_ndim) + + mesh = first_input_strategy.mesh + + follow_placements = _derive_follow_placements_from_tuple_strategy( + op_schema.op, input_tuple_strategy + ) + + # create op strategy base on the follow placements + op_strategy = OpStrategy([]) + + input_specs = tuple( + DTensorSpec(mesh, tuple(follow_placements)) + for _ in range(len(input_tuple_strategy.childs)) + ) + + follow_placements = normalize_shard_for_stack(follow_placements, dim) + + op_strategy.strategies.append( + OpSpec( + output_specs=DTensorSpec(mesh, tuple(follow_placements)), + input_specs=input_specs, + ) + ) + return op_strategy + + +@register_op_strategy(aten.cat.default, RuntimeSchemaInfo(1, needs_pytree=True)) +def cat_strategy(op_schema: OpSchema) -> StrategyType: + args_schema = op_schema.args_schema + input_tuple_strategy = args_schema[0] + assert isinstance(input_tuple_strategy, TupleStrategy), f"{input_tuple_strategy}" + first_input_strategy = input_tuple_strategy.childs[0] + assert isinstance(first_input_strategy, OpStrategy), f"{first_input_strategy}" + common_input_ndim = first_input_strategy.ndim + dim = cast(int, args_schema[1]) if len(args_schema) > 1 else 0 + # normalize the dim to be within the common input ndim + dim = normalize_dim(dim, common_input_ndim) + + mesh = first_input_strategy.mesh + + follow_placements = _derive_follow_placements_from_tuple_strategy( + op_schema.op, input_tuple_strategy + ) + # for cat we unshard the cat dim if it is sharded + follow_placements = unshard_tensor_dim(follow_placements, dim) + + # create op strategy base on the follow placements + op_strategy = OpStrategy([]) + + input_specs = tuple( + DTensorSpec(mesh, tuple(follow_placements)) + for _ in range(len(input_tuple_strategy.childs)) + ) + op_strategy.strategies.append( + OpSpec( + output_specs=DTensorSpec(mesh, tuple(follow_placements)), + input_specs=input_specs, + ) + ) + return op_strategy + + +@register_prop_rule(aten.index_select.default, schema_info=RuntimeSchemaInfo(1)) +def prop_index_select(op_schema: OpSchema) -> OutputSharding: + values_spec, dim, indices_spec = op_schema.args_schema + + assert isinstance(values_spec, DTensorSpec) + assert isinstance(dim, int) + assert isinstance(indices_spec, DTensorSpec) + + all_indices_spec: list[Optional[DTensorSpec]] = [ + indices_spec if dim == i else None for i in range(values_spec.ndim) + ] + + result = prop_index( + OpSchema( + op=op_schema.op, + args_schema=(values_spec, all_indices_spec), + kwargs_schema=op_schema.kwargs_schema, + ) + ) + if result.redistribute_schema: + schema_suggestion = result.redistribute_schema + result.redistribute_schema = OpSchema( + op=op_schema.op, + args_schema=( + schema_suggestion.args_schema[0], + dim, + schema_suggestion.args_schema[1][dim], # type: ignore[index] + ), + kwargs_schema=op_schema.kwargs_schema, + ) + return result + + +@register_prop_rule(aten.index.Tensor, schema_info=RuntimeSchemaInfo(needs_pytree=True)) +def prop_index(op_schema: OpSchema) -> OutputSharding: + """ + Expect replicated on the first input; _mostly_ pointwise on the second input. + + TODO: exception: when the dtype of second input is "bool", then a torch.nonzero needs to be triggered first. + """ + # Current sharding constraints: + # For values: + # 1. We currently require that the dimension of values_spec be replicated or partial + # if they are being indexed on. + # 2. Other dimensions of values_spec can remain sharded if they are so. + # For indices: + # Indices can be either sharded or replicated. All index tensors need to be sharded + # in a compatible way, following the pointwise rule (including resolving Partial + # into either sharded or replicated) + + values_spec, multi_indices_spec = op_schema.args_schema + assert isinstance(values_spec, DTensorSpec) + assert isinstance(multi_indices_spec, list) + multi_indices_spec = cast(list[Optional[DTensorSpec]], multi_indices_spec) + valid_indices_spec: list[tuple[int, DTensorSpec]] = [ + (i, a) for i, a in enumerate(multi_indices_spec) if a is not None + ] + + # 1. All indices have to be sharded equally. Moreover, indices can be broadcast. + # Here, we piggyback on the pointwise sharding rule for indices. + indices_out = pointwise_rule( + OpSchema( + op=op_schema.op, + args_schema=tuple(v[1] for v in valid_indices_spec), + kwargs_schema={}, + ) + ) + need_reshard_on_indices = indices_out.output_spec is None + + if not need_reshard_on_indices: + # this means that our inputs are already sharded properly and we will use that as our indices_spec + assert isinstance(indices_out.output_spec, DTensorSpec) + indices_spec: DTensorSpec = indices_out.output_spec + else: + assert indices_out.redistribute_schema is not None + valid_indices_suggestion = indices_out.redistribute_schema + for i, v in enumerate(valid_indices_suggestion.args_spec): + multi_indices_spec[valid_indices_spec[i][0]] = v + # we'll need to call pointwise_rule again to see what's our ideal indices_spec and then + # use that to compute our ideal values_spec + indices_output_spec = pointwise_rule(valid_indices_suggestion).output_spec + assert isinstance(indices_output_spec, DTensorSpec) + indices_spec = indices_output_spec + + lookup_dims = {v[0] for v in valid_indices_spec} + + need_reshard_on_values = tuple( + (isinstance(vp, Shard) and (vp.dim in lookup_dims or isinstance(ip, Shard))) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + + if not need_reshard_on_indices and not any(need_reshard_on_values): + value_placements = values_spec.placements + + all_dims_consecutive = all( + b[0] - a[0] == 1 + for b, a in zip(valid_indices_spec[1:], valid_indices_spec[:-1]) + ) + if all_dims_consecutive: + # if all index vectors are consecutives, insert at the dimension of the first index + insert_dim: int = valid_indices_spec[0][0] + else: + # else, insert on the first dimension + insert_dim = 0 + + def place(vp: Placement, ip: Placement) -> Placement: + if isinstance(vp, Shard): + return Shard( + vp.dim + if vp.dim < insert_dim + # accounts for the offset in output dimensions + else vp.dim + + indices_spec.ndim + - sum(1 if vp.dim > v[0] else 0 for v in valid_indices_spec) + ) + if isinstance(ip, Shard): + return Shard(ip.dim + insert_dim) + # Partial or Replicated + return vp + + value_placements = tuple( + place(vp, ip) + for vp, ip in zip(values_spec.placements, indices_spec.placements) + ) + result = OutputSharding( + output_spec=DTensorSpec( + mesh=values_spec.mesh, + placements=value_placements, + ) + ) + return result + else: + result = OutputSharding( + output_spec=None, + redistribute_schema=OpSchema( + op=op_schema.op, + args_schema=( + DTensorSpec( + mesh=values_spec.mesh, + placements=tuple( + [ + Replicate() if need_reshard_on_values[i] else v + for i, v in enumerate(values_spec.placements) + ] + ), + tensor_meta=values_spec.tensor_meta, + ), + multi_indices_spec, + ), + kwargs_schema=op_schema.kwargs_schema, + ), + ) + return result + + +@register_op_strategy( + [ + aten.split.Tensor, + aten.split_with_sizes.default, + aten.split_with_sizes_copy.default, + ], + RuntimeSchemaInfo(1), +) +def split_strategy(op_schema: OpSchema) -> TupleStrategy: + input_strategy = op_schema.args_schema[0] + split_size_or_sections = op_schema.args_schema[1] + assert isinstance(input_strategy, OpStrategy) + input_ndim = input_strategy.ndim + split_dim = ( + cast(int, op_schema.args_schema[2]) if len(op_schema.args_schema) > 2 else 0 + ) + dim = normalize_dim(split_dim, input_ndim) + + # tensor to split cannot have Partial for now + for arg_strategy in input_strategy.strategies: + arg_spec = arg_strategy.output_spec + if is_tensor_partial(arg_spec): + raise NotImplementedError( + f"splitting distributed tensor with " + f"Partial placement is not implemented!\n" + f"DTensorSpec={arg_strategy}" + ) + + def size_split(N, i) -> list: + # Last chunk will be smaller if the tensor size N + # along the given dimension dim is not divisible by i. + assert i > 0 + return [i] * (N // i) + ([N % i] if N % i != 0 else []) + + output_size_list = ( + size_split(input_strategy.shape[dim], split_size_or_sections) + if isinstance(split_size_or_sections, int) + else split_size_or_sections + ) + assert isinstance(output_size_list, Sized) + + split_strategies = [] + + for _ in range(len(output_size_list)): + op_strategy = OpStrategy([]) + + for strategy in input_strategy.strategies: + spec = strategy.output_spec + placements = spec.placements + if is_tensor_dim_sharded(spec, dim=dim): + # if the input is sharded on the split dim, we need to unshard it + placements = unshard_tensor_dim(spec.placements, dim=dim) + + spec = DTensorSpec(spec.mesh, placements) + + op_strategy.strategies.append( + OpSpec(output_specs=spec, input_specs=([spec])) + ) + split_strategies.append(op_strategy) + + return TupleStrategy(split_strategies) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_view_ops.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_view_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..d8cadd28860c2b218a4543c3913e077b728fc8f8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/_view_ops.py @@ -0,0 +1,714 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from typing import Callable, cast, Optional, Union + +import torch +from torch import Tensor +from torch._prims_common import DimsType +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + RuntimeSchemaInfo, + StrategyType, +) +from torch.distributed.tensor._ops.utils import ( + generate_redistribute_costs, + normalize_dim, + normalize_dims, + prod, + register_op_strategy, +) +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard + + +aten = torch.ops.aten + +Shape = tuple[int, ...] + + +@dataclass +class DimSpec: + """Specifies how an output dimension maps to an input dimension.""" + + def inputs(self) -> Iterable["DimSpec"]: + return () + + +# Rules that map each dimension of the output to dimensions of the input tensor +DimMap = tuple[DimSpec, ...] + + +@dataclass +class Singleton(DimSpec): + """Output dimension is a singleton.""" + + +@dataclass +class InputDim(DimSpec): + """Output dimension maps directly to an input dimension.""" + + input_dim: int + + +@dataclass +class Broadcast(DimSpec): + """Output is the broadcast of a singleton input dimension.""" + + dim: DimSpec + dim_size: int + + @classmethod + def new(cls, dim: DimSpec, dim_size: int) -> DimSpec: + return Broadcast(dim, dim_size) + + def inputs(self) -> Iterable[DimSpec]: + return (self.dim,) + + +@dataclass +class NewDim(DimSpec): + """This is a new dimension created by the op.""" + + size: int + + @classmethod + def new(cls, size: int) -> DimSpec: + return Singleton() if size == 1 else NewDim(size) + + +@dataclass +class Repeat(DimSpec): + """Output dimension is the input dimension repeated n-times.""" + + input_dim: DimSpec + times: int + + @classmethod + def new(cls, dim: DimSpec, times: int) -> DimSpec: + if times == 1: + return dim + elif isinstance(dim, Singleton): + # repeating a singleton is the same as broadcasting it + return Broadcast(dim, times) + else: + return Repeat(dim, times) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +@dataclass +class Flatten(DimSpec): + """Flatten a set of input dimensions, ensuring right-most adjacent elements remain adjacent in the output.""" + + input_dims: Sequence[DimSpec] + + @classmethod + def new(cls, dims: Sequence[DimSpec]) -> DimSpec: + if len(dims) == 0: + # flattening a scalar leads to a singleton + return Singleton() + elif len(dims) == 1: + # flattening a single dimension is no-op + return dims[0] + else: + return Flatten(dims) + + def inputs(self) -> Iterable[DimSpec]: + return self.input_dims + + +@dataclass +class Split(DimSpec): + """ + This dimension is a member of a decomposition of the input dim. + + Note that input_dim itself could be a Flattened set of input dims. + """ + + input_dim: DimSpec + group_shape: Shape + split_id: int + + @classmethod + def new(cls, dim: DimSpec, group_shape: tuple[int, ...], idx: int) -> DimSpec: + assert len(group_shape) > 0 + if len(group_shape) == 1: + # not really a group, just return the input dim back + assert idx == 0 + return dim + elif group_shape[idx] == 1: + return Singleton() + else: + # remove singletons from group + # group_mapping = [(new_index, (shape, old_index)) ...] + group_mapping = list( + enumerate((s, i) for i, s in enumerate(group_shape) if s != 1) + ) + new_group_shape = tuple(m[1][0] for m in group_mapping) + new_idx = next(filter(lambda x: x[1][1] == idx, group_mapping))[0] + return Split(dim, new_group_shape, new_idx) + + def inputs(self) -> Iterable[DimSpec]: + return (self.input_dim,) + + +def dim_pad_left(ndim: int, min_dims: int) -> DimMap: + return (Singleton(),) * max(0, min_dims - ndim) + tuple( + InputDim(i) for i in range(ndim) + ) + + +def dim_atleast_3d(ndim: int) -> DimMap: + if ndim == 0: + return (Singleton(), Singleton(), Singleton()) + elif ndim == 1: + return (Singleton(), InputDim(0), Singleton()) + elif ndim == 2: + return (InputDim(0), InputDim(1), Singleton()) + else: + return tuple(InputDim(i) for i in range(ndim)) + + +def expand(input_shape: Shape, shape: Shape) -> DimMap: + """Implement broadcast on multiple dimensions.""" + assert len(shape) >= len(input_shape) + + # 1. create padded input dimensions + padded_input = dim_pad_left(len(input_shape), len(shape)) + # 2. check that input shapes are compatible + mapping = [] + for p, desired_s in zip(padded_input, shape): + if isinstance(p, Singleton): + actual_s = 1 + assert desired_s >= 0 + else: + assert isinstance(p, InputDim), f"DimSpec not supported in expand: {p}" + actual_s = input_shape[p.input_dim] + assert actual_s == 1 or desired_s == -1 or desired_s == actual_s + mapping.append( + p + if desired_s in (1, -1) or desired_s == actual_s + else Broadcast.new(p, desired_s) + ) + return tuple(mapping) + + +def normalize_sizes(sizes: Union[Shape, tuple[Shape]]) -> Shape: + if isinstance(sizes[0], int): + return cast(Shape, sizes) + elif len(sizes) == 1: + return sizes[0] + else: + raise RuntimeError("Size must be int... or tuple") + + +def dim_flatten(ndim: int, start_dim=0, end_dim=-1) -> DimMap: + if ndim == 0: + return (Singleton(),) + elif ndim == 1: + return (InputDim(0),) + else: + # only flattening dims from start_dim to end_dim (inclusive) + # other dims are passed through + if end_dim < 0: + end_dim += ndim + results: list[DimSpec] = [InputDim(i) for i in range(start_dim)] + results.append( + Flatten.new(tuple(InputDim(i) for i in range(start_dim, end_dim + 1))) + ) + results.extend([InputDim(i) for i in range(end_dim + 1, ndim)]) + return tuple(results) + + +def dim_movedim( + ndim: int, + input: DimsType, + destination: DimsType, +) -> DimMap: + input = normalize_dims(input, ndim) + destination = normalize_dims(destination, ndim) + + assert len(input) == len(destination) + input_set = set(input) + assert len(input_set) == len(input), "Found repeated input dims" + assert len(set(destination)) == len(destination), "Found repeated output dims" + assert max(input) < ndim + assert max(destination) < ndim + + dest = [-1] * ndim + for i, d in zip(input, destination): + dest[d] = i + + unused_inputs_iter = iter(i for i in range(ndim) if i not in input_set) + for i in range(ndim): + if dest[i] == -1: + dest[i] = next(unused_inputs_iter) + + return tuple(InputDim(i) for i in dest) + + +def dim_repeat(ndim: int, sizes: Shape) -> DimMap: + sizes = normalize_sizes(sizes) + assert len(sizes) >= ndim, ( + f"Number of dimensions of repeat dims {sizes} can not be smaller than number of dimensions of tensor {ndim}." + ) + pad = len(sizes) - ndim + return tuple(Repeat.new(Singleton(), s) for s in sizes[:pad]) + tuple( + Repeat.new(InputDim(i), s) for i, s in enumerate(sizes[pad:]) + ) + + +def infer_size(total_size: int, sizes: Shape) -> Shape: + """ + One dimension input to view may be "-1". + + Infer the size of this dimension given the total_size. + """ + infers = [i for i, s in enumerate(sizes) if s == -1] + size = prod(sizes) + assert len(infers) <= 1, "can only infer one size" + if infers: + size = -size + missing_size = total_size // size + assert total_size % size == 0, ( + f"size inferred for -1 is not integral {sizes} should have {total_size} elements." + ) + return tuple(s if s != -1 else missing_size for s in sizes) + assert size == total_size, f"sizes do not match {total_size} vs {size}" + return sizes + + +def view_groups(from_size: Shape, to_size: Shape) -> DimMap: + """ + Decompose a reshape operation into forwarding, flattening, or splitting dimensions for each output dimension. + + A view or reshape operation can be decomposed into a set of 3 types of smaller operations: + 1) Forward a dimension from input to output + 2) Flatten a set of dimensions into a single dimension + 3) Split one dimension into multiple dimensions + + view_groups identifies these operations and returns, for each output dimension, what + is operation was performed in the input dimension. For example: + + view_groups([2, 3, 4], [2, 12]) -> ( + InputDim(0), + Flatten((InputDim(1), InputDim(2))) + ) + + - ouptut dimension 0 maps to input dimension 0 + - output dimension 1 maps to a flattened input dimensions 1 and 2 + + + view_groups([2, 3], [3, 2]) -> ( + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 0), + Split(Flatten((InputDim(0), InputDim(1))), (3, 2), 1), + ) + + - in the above, input is flattened into a single dimension and then split + into two separate dimensions with different sizes from the input. + """ + from_nelem = prod(from_size) + to_size = infer_size(from_nelem, normalize_sizes(to_size)) + + assert from_nelem == prod(to_size), "Total view shape does not add up" + + from_idx = 0 + to_idx = 0 + from_len = len(from_size) + to_len = len(to_size) + + result_pp = [] + + while from_idx < from_len or to_idx < to_len: + from_group_dim, to_group_shape = [], [] + + if from_idx >= from_len: + f = 1 + else: + f = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + + if to_idx >= to_len: + t = 1 + else: + t = to_size[to_idx] + to_group_shape.append(t) + to_idx += 1 + + # if any of the groups is singleton, great, we need to backtrack though + if f == 1 and t != 1: + # produces ([1], []) + to_idx -= 1 + to_group_shape = [] + elif f != 1 and t == 1: + # produces ([], [1]) + from_idx -= 1 + from_group_dim = [] + else: + # produces ([1], [1]), ([2], [2]), ([2,3], [6]) + while f != t: + if f < t: + nf = from_size[from_idx] + from_group_dim.append(from_idx) + from_idx += 1 + f *= nf + else: + nt = to_size[to_idx] + to_group_shape.append(nt) + to_idx += 1 + t *= nt + + if len(to_group_shape) > 0: + flattened = Flatten.new( + tuple(InputDim(fi) for fi in from_group_dim if from_size[fi] >= 1) + ) + result_pp += [ + Split.new(flattened, tuple(to_group_shape), i) + for i in range(len(to_group_shape)) + ] + + return tuple(result_pp) + + +def dim_tile(ndim: int, dims: tuple[int, ...]) -> DimMap: + if len(dims) < ndim: + dims = (1,) * (ndim - len(dims)) + dims + return dim_repeat(ndim, dims) + + +def dim_transpose(ndim: int, dim1: int, dim2: int) -> DimMap: + dim1 = normalize_dim(dim1, ndim) + dim2 = normalize_dim(dim2, ndim) + assert dim1 < ndim + assert dim2 < ndim + dimmap = [InputDim(i) for i in range(ndim)] + swapdim = dimmap[dim1] + dimmap[dim1] = dimmap[dim2] + dimmap[dim2] = swapdim + return tuple(dimmap) + + +def dim_squeeze(shape: Shape, dim: Optional[int] = None) -> DimMap: + # FIXME: this is wrong when dim=None and one of the dimensions + # equals size of the mesh. For example squeeze(DTensor(tensor(4), Shard[0])) could + # end up as squeeze(tensor(1)) if we have 4 devices; this would lead to + # removal of a dimension that is not actually a singleton. + return tuple( + InputDim(i) + for i, s in enumerate(shape) + if s > 1 or (dim is not None and i != normalize_dim(dim, len(shape))) + ) + + +def dim_unsqueeze(ndim: int, dim: int) -> DimMap: + dims = tuple(InputDim(i) for i in range(ndim)) + if dim < 0: + dim += ndim + 1 + return dims[:dim] + (Singleton(),) + dims[dim:] + + +def dim_view_as_real(shape: Shape) -> DimMap: + ndim = len(shape) + results: list[DimSpec] = [InputDim(i) for i in range(ndim - 1)] + # each complex number is split into two real numbers, + # resulting in one more dimension of size 2 + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 0)) + results.append(Split(InputDim(ndim - 1), (shape[-1], 2), 1)) + return tuple(results) + + +def dim_reduction(ndim: int, dim_or_dims: Optional[DimsType], keepdim: bool) -> DimMap: + """ + General fallback for reduction ops where Partial() does not apply. + + This will cause incoming tensor to be replicated on the reducing dimensions. + """ + if dim_or_dims is None: + dim_or_dims = tuple(range(ndim)) + if isinstance(dim_or_dims, int): + dim_or_dims = (dim_or_dims,) + dim_or_dims = tuple(d if d >= 0 else d + ndim for d in dim_or_dims) + return tuple( + InputDim(i) if i not in dim_or_dims else Singleton() + for i in range(ndim) + if i not in dim_or_dims or keepdim + ) + + +dim_maps: dict[Callable[..., torch.Tensor], Callable[..., DimMap]] = { + torch.atleast_1d: lambda x: dim_pad_left(x.ndim, 1), + torch.atleast_2d: lambda x: dim_pad_left(x.ndim, 2), + torch.atleast_3d: lambda x: dim_atleast_3d(x.ndim), + torch.broadcast_to: lambda input, shape: expand(input.shape, shape), + Tensor.expand: lambda self, *sizes: expand(self.shape, normalize_sizes(sizes)), + torch.flatten: lambda tensor: dim_flatten(tensor.ndim), + torch.movedim: lambda input, source, destination: dim_movedim( + input.ndim, source, destination + ), + torch.permute: lambda input, dims: tuple( + InputDim(i) for i in normalize_dims(dims, input.ndim) + ), + torch.ravel: lambda tensor: dim_flatten(tensor.ndim), + Tensor.repeat: lambda self, *sizes: dim_repeat(self.ndim, sizes), + torch.reshape: lambda input, shape: view_groups(input.shape, shape), + torch.squeeze: lambda input, dim=None: dim_squeeze(input.shape, dim), + torch.tile: lambda input, dims: dim_tile(input.ndim, dims), + torch.transpose: lambda input, dim0, dim1: dim_transpose(input.ndim, dim0, dim1), + torch.unsqueeze: lambda input, dim: dim_unsqueeze(input.ndim, dim), + Tensor.view: lambda input, *shape: view_groups(input.shape, shape), + torch.view_as_complex: lambda input: dim_flatten(input.ndim, input.ndim - 2), + torch.view_as_real: lambda input: dim_view_as_real(input.shape), +} + + +def propagate_shape_and_sharding( + input_src_placements: Sequence[Placement], + global_input_shape: Shape, + rule: DimMap, + mesh_sizes: Shape, + strict_view: bool = False, +) -> tuple[Sequence[Placement], Sequence[Placement]]: + """ + Determine input target sharding and output sharding based on + given global tensor shape and input source sharding. + + Sharding propagation follows mapped dimensions: + - An output dimension that maps directly to an input dimension is sharded equally + - An output dimension that is a flattened set of input dimensions can only be + sharded if only the leftmost flattened dimension is sharded. + - An output dimension that is a split of the input dimension can only be sharded + if the leftmost split size is divisible by the mesh dimension + """ + assert len(input_src_placements) == len(mesh_sizes) + # for each input dim, for each mesh dim, provides a list of possible shardable dimensions + mesh_ndim = len(mesh_sizes) + shardable_dims: dict[int, list[bool]] = {} + + # in case an input dimension disappears (e.g. collapsing, reduction) + # we cannot shard in that dimension (we need a replication fall-back rule) + seen_input_dims: set[int] = set() + + def collect_used_inputs(cmd: DimSpec) -> None: + if isinstance(cmd, InputDim): + seen_input_dims.add(cmd.input_dim) + for inp in cmd.inputs(): + collect_used_inputs(inp) + + for cmd in rule: + collect_used_inputs(cmd) + for dim in range(len(global_input_shape)): + shardable_dims[dim] = [dim in seen_input_dims] * mesh_ndim + + def maybe_get_shard_mesh_dim_and_placement( + input_dim: InputDim, + ) -> tuple[Optional[int], Optional[Shard]]: + # if input_dim is sharded, return the mesh_dim and shard placement + for i, placement in enumerate(input_src_placements): + if isinstance(placement, Shard) and placement.dim == input_dim.input_dim: + return i, placement + return None, None + + def get_in_dim_to_shard(cmd: DimSpec) -> Optional[InputDim]: + # TODO(whc) this helper is pretty hard to understand, at least it should be better documented if not refactored + if isinstance(cmd, InputDim): + return cmd + elif isinstance(cmd, Flatten): + for i, dim in enumerate(cmd.input_dims): + if isinstance(dim, InputDim): + can_shard_dim = True + shard_mesh_dim, shard_placement = ( + maybe_get_shard_mesh_dim_and_placement(dim) + ) + input_sharded = shard_mesh_dim is not None + if i > 0: + can_shard_dim = False + if strict_view and input_sharded: + raise RuntimeError( + f"Attempted to flatten sharded dimension {i}, ", + "but only the leftmost dim of a Flatten can be sharded.", + ) + elif input_sharded: + assert ( + shard_placement is not None and shard_mesh_dim is not None + ) + tensor_dim_size = global_input_shape[shard_placement.dim] + mesh_dim_size = mesh_sizes[shard_mesh_dim] + if tensor_dim_size % mesh_dim_size != 0: + can_shard_dim = False + if strict_view: + raise RuntimeError( + f"Attempted to flatten unevenly sharded dimension {i}, " + "which would require resharding the input. " + "Please explicitly redistribute the tensor instead." + ) + + shardable_dims[dim.input_dim] = [can_shard_dim] * mesh_ndim + dim0 = cmd.input_dims[0] + # TODO(whc) dim0 can be sharded or not sharded, can't it? + # should we only return it if its sharded in the placement? + return dim0 if isinstance(dim0, InputDim) else None + elif isinstance(cmd, Split): + in_dim = get_in_dim_to_shard(cmd.input_dim) + out_size = cmd.group_shape[cmd.split_id] + if cmd.split_id == 0 and in_dim is not None: + # we need to check that the input dimension is divisible + # by the size of the submesh we're sharding it on + # NOTE: it would be possible to shard the same input dimension + # on more than one mesh dimension. In that case, the dimension + # needs to be divisible by the product of mesh sizes. + # In order to keep the problem more tractable, we will not consider + # double resharding as a suggestion (e.g. [Shard(0), Shard(0) ]) + # but we will allow it if that's the input and it's compatible + + # 1. is this dimension shardable on each individual mesh dim? + shardable_dims[in_dim.input_dim] = [ + out_size % mesh_dim_size == 0 for mesh_dim_size in mesh_sizes + ] + + # 2. here we special case things like [Shard(0), Shard(0)] + submesh_size = 1 + for size, shard in zip(mesh_sizes, input_src_placements): + if isinstance(shard, Shard) and shard.dim == in_dim: + submesh_size *= size + assert out_size % submesh_size == 0, ( + f"Resulting dimension size {out_size} is not divisible by its mesh dimension {submesh_size}." + ) + + # we will only shard our first component of the split + return in_dim if cmd.split_id == 0 else None + elif isinstance(cmd, Repeat): + in_dim = get_in_dim_to_shard(cmd.input_dim) + if in_dim is not None: + shardable_dims[in_dim.input_dim] = [False] * mesh_ndim + return None + else: + return None + + # for each output dim, find the corresponding input dim in terms of sharding prop + shard_dim_map = {} + for dim, cmd in enumerate(rule): + in_dim = get_in_dim_to_shard(cmd) + if in_dim is not None: + shard_dim_map[in_dim.input_dim] = dim + + input_tgt_placements = [ + ( + Replicate() + if isinstance(p, Shard) and not shardable_dims[p.dim][mesh_dim] + else p + ) + for mesh_dim, p in enumerate(input_src_placements) + ] + output_placements = [ + Shard(shard_dim_map[p.dim]) if isinstance(p, Shard) else p + for p in input_tgt_placements + ] + + return input_tgt_placements, output_placements + + +def register_op_strategy_map( + aten_op_overload: torch._ops.OpOverload, + local_op_name: Callable[..., torch.Tensor], + schema_info: Optional[RuntimeSchemaInfo] = None, + strict_view: bool = False, +) -> None: + """ + Helper that registers strategies for view-like operators that follow a pattern: + (1) define the way input dims are split/combined to form output dims (dim_maps) + (2) register a strategy for the op schema that uses the dim_map as a sharding prop rule + + strict_view: if True, we will error out if the view-operation would require resharding the input. + Currently, this should be set to 'true' for any "view" ops. + We could diverge behavior for "reshape" ops which could perform a redistribute implicitly. + """ + dim_map: Callable[..., DimMap] = dim_maps[local_op_name] + + @register_op_strategy(aten_op_overload, schema_info=schema_info) + def reshape_strategy(op_schema: OpSchema) -> StrategyType: + rules = dim_map(*op_schema.args_schema, **op_schema.kwargs_schema) + input_strategy = cast(OpStrategy, op_schema.args_schema[0]) + mesh = op_schema.get_mesh_from_args(validate=False) + + global_in_shape = input_strategy.shape + assert global_in_shape is not None, "Shape required." + + output_strategy = OpStrategy([]) + for input_placement_strategy in input_strategy.strategies: + input_src_spec = input_placement_strategy.output_spec + + input_tgt_placements, output_placements = propagate_shape_and_sharding( + input_src_spec.placements, + tuple(global_in_shape), + rules, + mesh.shape, + strict_view, + ) + + # TODO: optimize this. we shouldn't simply blindly replicate + # unshardable dims ... + # FIXME: this can be wrong for situations where we have + # [Shard(0), Shard(0)] + input_tgt_spec = DTensorSpec( + placements=tuple(input_tgt_placements), + mesh=mesh, + tensor_meta=input_src_spec.tensor_meta, + ) + redistribute_costs: list[list[float]] = [ + generate_redistribute_costs(input_strategy, input_tgt_spec) + ] + + output_spec = DTensorSpec(mesh=mesh, placements=tuple(output_placements)) + output_strategy.strategies.append( + OpSpec( + output_specs=output_spec, + input_specs=(input_tgt_spec,), + redistribute_cost=redistribute_costs, + ) + ) + + return output_strategy + + +register_op_strategy_map(aten.squeeze.default, torch.squeeze) +register_op_strategy_map( + aten.squeeze.dim, torch.squeeze, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.view.default, + Tensor.view, + schema_info=RuntimeSchemaInfo(1), + strict_view=True, +) +register_op_strategy_map( + aten.reshape.default, torch.reshape, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten._unsafe_view.default, + Tensor.view, + schema_info=RuntimeSchemaInfo(1), + strict_view=True, +) +register_op_strategy_map( + aten.unsqueeze.default, torch.unsqueeze, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.expand.default, Tensor.expand, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.permute.default, torch.permute, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.repeat.default, Tensor.repeat, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map( + aten.transpose.int, torch.transpose, schema_info=RuntimeSchemaInfo(1) +) +register_op_strategy_map(aten.view_as_complex.default, torch.view_as_complex) +register_op_strategy_map(aten.view_as_real.default, torch.view_as_real) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/utils.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..589f1e5c9fc2bef12cb95a52271dd5ed23a74933 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_ops/utils.py @@ -0,0 +1,298 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import functools +import itertools +import operator +from collections.abc import Iterable, Sequence +from typing import Callable, cast, Optional, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch._prims_common import DimsSequenceType, DimsType +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor._collective_utils import redistribute_cost +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OpStrategy, + OutputSharding, + PlacementList, + RuntimeSchemaInfo, +) +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + + +# convenient wrapper to register sharding propagation rules +# pyre-fixme[3]: Return type must be annotated. +# pyre-fixme[2]: Parameter must be annotated. +def register_prop_rule( + op: Union[torch._ops.OpOverload, list[torch._ops.OpOverload]], + schema_info: Optional[RuntimeSchemaInfo] = None, +) -> Callable[ + [Callable[[OpSchema], OutputSharding]], Callable[[OpSchema], OutputSharding] +]: + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def wrapper( + impl: Callable[[OpSchema], OutputSharding], + ) -> Callable[[OpSchema], OutputSharding]: + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_sharding_prop_rule( + overload, impl, schema_info + ) + return impl + + return wrapper + + +def register_op_strategy( + op, schema_info=None +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + # pyre-fixme[53]: Captured variable `func` is not annotated. + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + + # For every ATen op that accepts any args in this list, + # the arg itself can impact the strides (and potentially the sharding strategy) + # of the output tensor. + # thus, we will detect ATen schemas with any of these args and ensure + # that they get specialized here. + arg_names_that_require_specializing_cache_strategy = [ + "memory_format", + ] + + def wrapper(impl): + if isinstance(op, list): + overloads = op + else: + overloads = [op] + + for overload in overloads: + curr_schema_info = None + if schema_info is None: + specialized_args = [ + a.name + for a in overload._schema.arguments + if a.name in arg_names_that_require_specializing_cache_strategy + ] + if any(specialized_args): + curr_schema_info = RuntimeSchemaInfo( + static_kwargkey=specialized_args + ) + else: + curr_schema_info = schema_info + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, impl, curr_schema_info + ) + return impl + + return wrapper + + +def as_list( + x: Union[list[object], object], + # pyre-fixme[11]: Annotation `immutable_list` is not defined as a type. +) -> Union[list[object], torch.fx.immutable_collections.immutable_list]: # type: ignore[valid-type] + # During tracing, `aten.sum.dim_IntList` uses `immutable_list` for its args, + # which is an object but treated as a list by the tracer. Therefore, keep + # `immutable_list` intact here as well. + if type(x) is list or isinstance(x, torch.fx.immutable_collections.immutable_list): + return x + else: + return [x] + + +def normalize_dim(dim: int, ndim: int) -> int: + return dim if dim >= 0 else dim + ndim + + +def normalize_dims(dims: DimsType, ndim: int) -> DimsSequenceType: + """Normalize a dim or a sequence of dims, so that they are all positive.""" + if isinstance(dims, int): + dims = (normalize_dim(dims, ndim),) + elif isinstance(dims, list): + dims = [normalize_dim(dim, ndim) for dim in dims] + elif isinstance(dims, tuple): + dims = tuple([normalize_dim(dim, ndim) for dim in dims]) + return dims + + +def prod(xs: Iterable[int]) -> int: + return functools.reduce(operator.mul, xs, 1) + + +def is_tensor_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: + """Check if the shape is shardable according to the spec.""" + # number of shards in each tensor dimension + shards_map = [1] * len(shape) + for i, placement in enumerate(spec.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + shards_map[shard_dim] *= spec.mesh.size(i) + + for i, dim_size in enumerate(shape): + # TODO: maybe we should determine is_shardable based on + # whether it's evenly sharded or not + if shards_map[i] > 1 and dim_size < shards_map[i]: + return False + + return True + + +def is_tensor_evenly_shardable(shape: Sequence[int], spec: DTensorSpec) -> bool: + """Check if the shape is evenly shardable according to the spec.""" + # number of shards in each tensor dimension + shards_map = [1] * len(shape) + for i, placement in enumerate(spec.placements): + if placement.is_shard(): + shard_dim = cast(Shard, placement).dim + shards_map[shard_dim] *= spec.mesh.size(i) + + for i, dim_size in enumerate(shape): + if shards_map[i] > 1 and (dim_size % shards_map[i] != 0): + return False + + return True + + +def is_tensor_dim_sharded(spec: DTensorSpec, dim: int) -> bool: + """Return True if tensor dim is sharded.""" + return any(p.is_shard(dim) for p in spec.placements) + + +def is_tensor_partial(spec: DTensorSpec) -> bool: + """Return True if tensor is partial on the mesh.""" + return any(p.is_partial() for p in spec.placements) + + +def infer_broadcast_dims_map( + common_shape: torch.Size, input_shape: torch.Size +) -> list[int]: + # infer the broadcast dims map, where it maps from the common shape dim to the input shape dim + # this is aligned with the broadcast semantics + common_ndim = len(common_shape) + input_ndim = len(input_shape) + broadcast_dims_map = [-1] * common_ndim + for idx in range(-1, -1 - input_ndim, -1): + if input_shape[idx] == common_shape[idx]: + broadcast_dims_map[common_ndim + idx] = input_ndim + idx + return broadcast_dims_map + + +def map_placements_after_broadcast( + placements: tuple[Placement, ...], + shape: torch.Size, + broadcast_dims_map: list[int], +) -> tuple[Placement, ...]: + """Map each placement based on the output shape after broadcast.""" + new_placements: list[Placement] = [] + for placement in placements: + if isinstance(placement, (Replicate, Partial)): + new_placements.append(placement) + else: + assert isinstance(placement, Shard) + shard_dim = normalize_dim(placement.dim, len(shape)) + new_shard_dim = broadcast_dims_map[shard_dim] + if new_shard_dim != -1: + # there's a map from the common shape shard dim to + # the input shape shard dim before broadcasting, + # use that instead + new_placements.append(Shard(new_shard_dim)) + else: + # there's no map between common shape shard dim and + # the input shape shard dim before broadcasting, + # in this case it means implicit broadcasting happen + # in this dim, so we can just mark it as replicate + # and implict broadcast will broadcast automatically + # to the sharded shape + new_placements.append(Replicate()) + + return tuple(new_placements) + + +def generate_redistribute_costs( + src_strategy: OpStrategy, dst_spec: DTensorSpec +) -> list[float]: + redistribute_costs: list[float] = [ + redistribute_cost(strat.output_spec, dst_spec) + for strat in src_strategy.strategies + ] + + return redistribute_costs + + +def expand_to_full_mesh_op_strategy( + mesh: DeviceMesh, + op_schema: OpSchema, + single_mesh_dim_strategies: list[PlacementList], + *, + input_index: int = 1, + inplace_op: bool = False, +) -> OpStrategy: + # Expand the single_mesh_dim_strategies to full mesh dim strategies. + all_mesh_dim_strategies = [single_mesh_dim_strategies] * mesh.ndim + + strategy_combs = itertools.product(*all_mesh_dim_strategies) + + all_strategies = [] + for strategy_comb in strategy_combs: + spec_list: list[Optional[DTensorSpec]] = [] + for specs in zip(*strategy_comb): + if specs[0] is not None: + spec_list.append(DTensorSpec(mesh, specs)) + else: + spec_list.append(None) + + input_specs: list[DTensorSpec] = [ + s for s in spec_list[input_index:] if isinstance(s, DTensorSpec) + ] + + input_args_strategy = op_schema.args_strategy + assert len(input_specs) == len(input_args_strategy) + self_spec = input_args_strategy[0].strategies[0].output_spec + + if inplace_op and self_spec.placements != input_specs[0].placements: + # if it's inplace op, we would only allow the OpSpec to be added when the + # input_spec matches the first argument's runtime sharding, otherwise we skip + continue + + # check inputs shardable + inputs_shardable = all( + is_tensor_shardable(inp.shape, s) + for inp, s in zip(input_args_strategy, input_specs) + ) + + # only add to the all_strategies list when all inputs are shardable + if inputs_shardable: + redistribute_cost = [ + generate_redistribute_costs(input_strategy, input_spec) + for input_strategy, input_spec in zip(input_args_strategy, input_specs) + ] + if input_index > 1: + output_specs = tuple(spec_list[:input_index]) + else: + if spec_list[0] is not None: + output_specs = spec_list[0] # type: ignore[assignment] + else: + raise RuntimeError("output spec is None") + strategy = OpSpec( + output_specs=output_specs, + input_specs=input_specs, + redistribute_cost=redistribute_cost, + ) + all_strategies.append(strategy) + + return OpStrategy(all_strategies) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_random.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_random.py new file mode 100644 index 0000000000000000000000000000000000000000..459757d430333a203c7851d15ae13ea986a903fe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_random.py @@ -0,0 +1,393 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +import warnings +from typing import Optional, Union + +import torch +import torch.distributed as dist +from torch import Tensor +from torch.distributed.device_mesh import _get_device_handle, DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import Shard + + +__all__ = [ + "is_rng_supported_mesh", + "manual_seed", + "OffsetBasedRNGTracker", +] + +_rng_tracker: Optional["_RNGStateTracker"] = None + + +def is_rng_supported_mesh(device_mesh: DeviceMesh) -> bool: + """Checks if the current device of ``device_mesh`` supports DTensor's random APIs. + Currently DTensor Random APIs only supports cuda/cuda-like devices. We suggest + users call this API to test the availability before using our random APIs. + + Args: + device_mesh (:class:`DeviceMesh`): The device mesh on which we check if the + random ops APIs are supported. + + Returns: + A bool value. True if ``device_mesh`` supports DTensor Random APIs; False otherwise. + + .. warning:: + Currently we only support correct RNG on cuda/cuda-like devices. + """ + device_handle = _get_device_handle(device_mesh.device_type) + if device_handle and hasattr(device_handle, "set_rng_state"): + return True + else: + # TODO: Logs way too much + warnings.warn( + f"DTensor random operators may not have complete support on {device_mesh.device_type} device mesh" + ) + return False + + +def manual_seed(seed: int, device_mesh: DeviceMesh) -> None: + """Sets the seed for generating random numbers for the calling rank. + + Args: + seed (int): The desired seed. + device_mesh (:class:`DeviceMesh`): The device mesh to set the seed. It is + required that the ``device_mesh`` include the calling rank. This is + to ensure that the SPMD region maintains a synchronous RNG state, which + means no ranks should be initialized with values other than ``seed``. + + Returns: + None + + .. warning:: + :func:`manual_seed` does not check the ``seed`` value correctness. Users must + ensure on their own that the value passed in is the desired ``seed`` for ranks + within ``device_mesh``. + If ``device_mesh`` is a sub-mesh and the calling rank is not a part of it, + ``manual_seed`` will throw an error. + Current implementation only supports a GPU device mesh. + """ + if not is_rng_supported_mesh(device_mesh): + warnings.warn( + "DTensor manual_seed() may not have complete support " + f"on {device_mesh.device_type} device mesh" + ) + return + + # instantiate a RNG tracker if haven't. By default DTensor uses an + # OffsetBasedRNGTracker to perform random operators. + global _rng_tracker + if not _rng_tracker: + _rng_tracker = OffsetBasedRNGTracker(device_mesh, run_state_sync=False) + + # the current rank is in mesh + if device_mesh.get_coordinate() is not None: + _rng_tracker._manual_seed(seed) + else: + raise RuntimeError( + "manual_seed requires the current rank to be a part of the device mesh " + "otherwise DTensor RNG state on the rank will not be initialized and " + "the behavior of DTensor random ops is undefined." + ) + + +class _RNGStateTracker: + """ + _RNGStateTracker stores Random Number Generator (RNG) state (a ByteTensor object) + in a dict, mapping from a corresponding tag to each state tensor. It also provides + a set of convenient utility methods to help access/modify the state tensors. The most + important interface is _distribute_region which will be used when DTensor executes + a random op (an operator that calls RNG). + """ + + def __init__(self, device: torch.device): + self._device = device + self._device_handle = _get_device_handle(self._device.type) + if not (self._device_handle and self._device_handle.is_available()): + raise RuntimeError( + f"{self.__class__.__name__} instantiation requires the presence of " + f"{device.type} device but couldn't find." + ) + + self._states: dict[str, Tensor] = {} + self._use_distribute_region = True + + @property + def rng_states(self) -> dict[str, Tensor]: + return self._states + + @property + def distribute_region_enabled(self) -> bool: + return self._use_distribute_region + + @distribute_region_enabled.setter + def distribute_region_enabled(self, value) -> None: + self._use_distribute_region = value + + def rng_state_is_sync(self, name) -> bool: + return name in self.rng_states + + def get_seed(self, name: str) -> int: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + seed_tensor = (self.rng_states[name])[0:8].view(dtype=torch.int64) + return int(seed_tensor.item()) + + def set_seed(self, name: str, seed: int) -> None: + seed_tensor = torch.tensor([seed], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + offset_tensor = torch.tensor([0], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) + + def _distribute_region(self, spec: DTensorSpec): + pass + + def _manual_seed(self, parallel_seed: int) -> None: + pass + + +class OffsetBasedRNGTracker(_RNGStateTracker): + """ + This subclass of ``_RNGStateTracker`` defines the default policy of how RNG states + should be shared and synchronized among all ranks to respect the semantics of DTensor + random operators. + + note: _RNGStateTracker only supports cuda/cuda-like device. + """ + + def __init__( + self, + device_mesh: DeviceMesh, + run_state_sync: bool = True, + ): + super().__init__(_resolve_device(device_mesh=device_mesh)) + assert self._device_handle is not None + # DTensor RNG tracker so far only supports CUDA/CUDA-like devices + if self._device.type == "cpu": + raise RuntimeError( + f"{self.__class__.__name__} instantiation requires the presence of " + f"CUDA/CUDA-like/XPU device. Got {self._device.type} instead." + ) + + rng_state = self._device_handle.get_rng_state().to(self._device) + if run_state_sync: + # synchronize RNG state using rank 0's current one + dist.broadcast(rng_state, 0) + + self.rng_states["parallel-rng"] = rng_state.to("cpu") + + def _manual_seed(self, parallel_seed: int) -> None: + self.set_seed("parallel-rng", parallel_seed) + + @contextlib.contextmanager + def _distribute_region(self, spec: DTensorSpec): + # check if the parallel rng state has been synchronized or not + if not self.rng_state_is_sync("parallel-rng"): + raise RuntimeError( + "OffsetBasedRNGTracker requires the random state to be synchronized " + "before entering into a distribute region!" + ) + + if self.distribute_region_enabled: + old_offset = self.get_offset("parallel-rng") + self._set_pre_op_offset(spec) + with torch.random.fork_rng( + devices=[self._device], device_type=self._device.type + ): + assert self._device_handle is not None + self._device_handle.set_rng_state(self.rng_states["parallel-rng"]) + try: + yield # execute the region code + finally: + # update offset to synchronize among ranks + self._set_post_op_offset(spec, old_offset) + else: + yield + + def get_offset(self, name: str) -> int: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + offset_tensor = (self.rng_states[name])[8:].view(dtype=torch.int64) + return int(offset_tensor.item()) + + def set_offset(self, name: str, offset: int) -> None: + if name not in self.rng_states: + raise RuntimeError( + f"{self.__class__.__name__} does not have random state for {name}" + ) + + seed_tensor = (self.rng_states[name])[0:8] + offset_tensor = torch.tensor([offset], dtype=torch.uint64, device="cpu").view( + torch.uint8 + ) + self.rng_states[name] = torch.cat([seed_tensor, offset_tensor]) + + def _set_pre_op_offset(self, spec: DTensorSpec) -> None: + """Set the starting RNG offset for current device's local shard before actual + op execution. The pre_op_offset value should start from the current RNG offset + and increment by the size of local shard until it reaches the size of the whole + DTensor. For different ranks that hold the same DTensor shard, their pre_op_offset + will be the same. + + Args: + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we prepare the offset for running random ops. + + Returns: + None + + .. warning:: + Note that, current implementation does not consider DTensor's continguity. + + Example: + take a DTensor of shape [8, 16] as an example. Assume that the DTensor + is placed on a device mesh with placements ([Shard(1), Replicate(), Shard(0)]), + and the mesh is: + [[[0, 1], [2, 3]], [[4, 5], [6, 7]]] + ``spec.mesh.get_coordinate()`` provides the coordinate of the current rank + in the mesh. For example, the coordinate of rank 5 is (1, 0, 1). + + Another concept to introduce besides rank coordinate is shard coordinate. + Each rank holds a local shard of the DTensor. In the example, the DTensor + is partitioned into 4 [4, 8] shards. The first shard has 2 replicas and + rank 0 (coord (0, 0, 0)) and rank 2 (coord (0, 1, 0)) have 1 replica each. + That being said, the local shard on rank 0 and rank 2 correspond to the same + shard of the DTensor. To denote each DTensor shard, we use a shard coordinate + (in the example, it will be a tuple (i, j) where shard (i, j) has the slice + DTensor[4 * i : 4 * (i + 1), 8 * j : 8 * (j + 1)], 0 <= i < 2, 0 <= j < 2). + + Once we have rank coordinate and shard coordinate, we can calculate on each rank + what shard of the DTensor the rank holds, with the help of dim_map. The dim_map + of the above DTensor is [2, 0] so the shard coordinate of a rank with rank coord + (x, y, z) is simply (z, x) by taking(rank_coord[dim_map[0]],rank_coord[dim_map[1]]). + Following this calculation, + rank 0 and rank 2 holds the shard of coord (0, 0); + rank 1 and rank 3 holds the shard of coord (0, 1); + rank 4 and rank 6 holds the shard of coord (1, 0); + rank 5 and rank 7 holds the shard of coord (1, 1); + + The last value to calculate before obtaining the starting offset is the shard linear index. + The starting offset for each rank will be its shard_linear_index * local_tensor_numel. + """ + dtensor_shape = spec.shape + mesh = spec.mesh + # note: dim_map does not allow double sharding which is the FSDP(fully_shard)+TP + # case. Replace the custom logic with dim_map once we support it. + dim_map: list[Union[int, list[int]]] = [-1] * spec.ndim + for i, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + shard_dim = placement.dim + if dim_map[shard_dim] == -1: + dim_map[shard_dim] = [i] + else: + mesh_dim_list = dim_map[shard_dim] + assert isinstance(mesh_dim_list, list) + mesh_dim_list.append(i) + + # Compute shard coordinate: + # The coordinate on each tensor dim is a tuple (idx, range) + # If a DTensor is partitioned on its dim i into n shards, and the current rank + # holds the j-th, then its shard coordinate will be (idx=j, range=n) on dim i + mesh_coordinate = mesh.get_coordinate() + assert mesh_coordinate is not None + mesh_size = mesh.shape + shard_idx_by_dim = [] + total_num_shards_by_dim = [] # total number of shards on each tensor dim + for mesh_dim in dim_map: + shard_idx = 0 + total_num_shards = 1 + # the tensor dim is sharded on more than 1 mesh dim + if isinstance(mesh_dim, list): + rank_coord = [mesh_coordinate[d] for d in mesh_dim] + num_shards = [mesh_size[d] for d in mesh_dim] + # compute the shard idx and total number of shards + for idx, size in zip(rank_coord, num_shards): + shard_idx = shard_idx * size + idx + total_num_shards *= size + + shard_idx_by_dim.append(shard_idx) + total_num_shards_by_dim.append(total_num_shards) + + # compute shard linear index + shard_linear_idx = self._calc_shard_linear_idx( + shard_idx_by_dim, total_num_shards_by_dim + ) + + # compute starting offset using the first shard's size + local_size_on_rank_0 = list(dtensor_shape) + for idx, placement in enumerate(spec.placements): + if isinstance(placement, Shard): + mesh_dim_size = mesh.size(idx) + shard_dim = placement.dim + local_size_on_rank_0[shard_dim], _ = ( + placement._local_shard_size_and_offset( + dtensor_shape[shard_dim], + mesh_dim_size, + 0, + ) + ) + + from torch.distributed.tensor._ops.utils import prod + + local_size = prod(local_size_on_rank_0) + + # get current RNG offset + current_offset = self.get_offset("parallel-rng") + + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + offset_incr = (shard_linear_idx * local_size + 3) // 4 * 4 + self.set_offset("parallel-rng", current_offset + offset_incr) + + def _set_post_op_offset(self, spec: DTensorSpec, old_offset: int) -> None: + """Sets the RNG to a synchronized state after running the local random op. Every + rank should set its RNG offset to `old_offset + DTensor.numel()` where old_offset is + the offset before calling `set_pre_op_offset` i.e. the offset before running DTensor + random ops. + + Args: + spec (:class:`DTensorSpec`): the spec of the DTensor object on which + we post-process the offset for running random ops. + + Returns: + None + """ + dtensor_shape = spec.shape + + from torch.distributed.tensor._ops.utils import prod + + numel = prod(dtensor_shape) + # pytorch: offset must be multiple of 4 + # source: aten/src/ATen/cuda/CUDAGeneratorImpl.cpp + numel = (numel + 3) // 4 * 4 + self.set_offset("parallel-rng", old_offset + numel) + + def _calc_shard_linear_idx( + self, shard_coord: list[int], shard_size: list[int] + ) -> int: + # compute shard linear index + shard_linear_idx = 0 + shard_coord_stride = 1 + for idx, size in zip(reversed(shard_coord), reversed(shard_size)): + shard_linear_idx += idx * shard_coord_stride + shard_coord_stride *= size + + return shard_linear_idx + + +def _resolve_device(device_mesh: DeviceMesh) -> torch.device: + device_type = device_mesh.device_type + device_handle = _get_device_handle(device_type) + assert device_handle is not None + device_idx = device_mesh.get_rank() % device_handle.device_count() + return torch.device(f"{device_type}:{device_idx:d}") diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_redistribute.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_redistribute.py new file mode 100644 index 0000000000000000000000000000000000000000..7c089ed09124c6ad1594390f9d29972c4ad450e7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_redistribute.py @@ -0,0 +1,403 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import logging +from functools import cache +from typing import cast, NamedTuple, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor.device_mesh import DeviceMesh +from torch.distributed.tensor.placement_types import ( + Partial, + Placement, + Replicate, + Shard, +) + + +logger = logging.getLogger(__name__) + + +class _TransformInfo(NamedTuple): + mesh_dim: int + src_dst_placements: tuple[Placement, Placement] + # logical_shape on this mesh dimension + logical_shape: list[int] + + +def _gen_transform_infos_non_cached( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, +) -> list[_TransformInfo]: + """ + Generate the transform infos from the source placements to the target placements. + + To transform from source to target placement it might have multiple steps, i.e. it + might decompose Si -> Sj into Si -> R -> Sj. + This would detect if there're mis-aligned/nested shardings between src/dst placements. + E.g. Suppose the redistribution to perform is (Shard(0), Shard(0)) -> (Replicate(), Shard(0)), + in this case Shard(0) -> Shard(0) for mesh dimension 1 actually needs resharding, because in + the former is a nested-sharding of a tensor already already sharded dimension 0, whereras + the latter is the first sharding on tensor dimension 0. + """ + transform_infos: list[_TransformInfo] = [] + + device_mesh = src_spec.device_mesh + my_coordinate = device_mesh.get_coordinate() + assert my_coordinate is not None + + # logical shape records the logic tensor shape on the mesh dimension + # this is useful to ensure uneven sharding gets correct output shape + initial_logical_shape = list(src_spec.shape) + mesh_dims_to_logical_shape = [initial_logical_shape] + + if device_mesh.ndim == 1: + # if device_mesh is 1D, redistribute is a simple direct transformation + transform_infos.append( + _TransformInfo( + mesh_dim=0, + src_dst_placements=(src_spec.placements[0], dst_spec.placements[0]), + logical_shape=initial_logical_shape, + ) + ) + return transform_infos + + # Handle multi-dim device mesh placement redistribution + # First, we need to build the logical shape for each mesh dim + # for correct allgathering uneven shards on each mesh dim (with dynamic padding) + for i, src in enumerate(src_spec.placements): + current_logical_shape = mesh_dims_to_logical_shape[i] + if isinstance(src, Shard): + if i < device_mesh.ndim - 1: + # calculate and save the logical shape for this sharding + mesh_dim_size = device_mesh.size(mesh_dim=i) + local_shard_size, _ = src._local_shard_size_and_offset( + current_logical_shape[src.dim], + mesh_dim_size, + my_coordinate[i], + ) + new_logical_shape = list(current_logical_shape) + new_logical_shape[src.dim] = local_shard_size + mesh_dims_to_logical_shape.append(new_logical_shape) + else: + mesh_dims_to_logical_shape.append(current_logical_shape) + + # Next, we need to derive the transform infos from src to dst placements, + # here we use a greedy search with step by step state transformations + current_placements = list(src_spec.placements) + target_placements = list(dst_spec.placements) + + if src_spec.num_shards > 1: + # If src_spec have sharding, it could potentially have sharding that is misaligned with dst_spec + # a common case of this is nested sharding (i.e. (S(0), S(0)) -> (R, S(0))). + # In those cases, we first traverse from inner placement to outer placement + # to detect misaligned shardings and properly replicate nested sharding first. + for mesh_dim in reversed(range(len(current_placements))): + current = current_placements[mesh_dim] + target = target_placements[mesh_dim] + # If target is not Shard, we can directly redistribute since we are traversing from innner + # to outer placements here + if isinstance(target, Shard): + # If target is Shard, check for nested sharding on the tensor dim BEFORE the current mesh_dim + shard_dim = target.dim + current_mesh_sharding, target_mesh_sharding = [], [] + for i, (s, p) in enumerate(zip(current_placements, target_placements)): + if i >= mesh_dim: + break + if s.is_shard(shard_dim): + current_mesh_sharding.append(i) + if p.is_shard(shard_dim): + target_mesh_sharding.append(i) + + if current_mesh_sharding != target_mesh_sharding: + # if current/target_placements have misaligned sharding on the tensor dim BEFORE the current + # mesh_dim, we need to replicate the tensor on the mesh dim first to clear the nested sharding + target = Replicate() + + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + + # We always traverse from outer placement to inner placement to collect the remaining + # needed transform infos (i.e. the replication from nested sharding might need to further + # perform resharding to Shard again) + for mesh_dim, (current, target) in enumerate( + zip(current_placements, target_placements) + ): + if current != target: + transform_infos.append( + _TransformInfo( + mesh_dim=mesh_dim, + src_dst_placements=(current, target), + logical_shape=mesh_dims_to_logical_shape[mesh_dim], + ) + ) + current_placements[mesh_dim] = target + + return transform_infos + + +@cache +def _gen_transform_infos( + src_spec: DTensorSpec, + dst_spec: DTensorSpec, +) -> list[_TransformInfo]: + return _gen_transform_infos_non_cached(src_spec, dst_spec) + + +def redistribute_local_tensor( + local_tensor: torch.Tensor, + current_spec: DTensorSpec, + target_spec: DTensorSpec, + *, + async_op: bool = False, + is_backward: bool = False, +) -> torch.Tensor: + """ + This redistribute the local tensor (torch.Tensor) from the current DTensorSpec to + the target DTensorSpec, which involves the necessary collective calls to transform + the local shard of the DTensor from its current spec to the target spec. + """ + + if current_spec.mesh != target_spec.mesh: + # TODO: alltoall/permute reshuffling to change device_mesh if they are not the same + raise NotImplementedError("Cross device mesh comm not supported yet!") + + new_local_tensor = local_tensor + device_mesh = current_spec.mesh + + my_coordinate = device_mesh.get_coordinate() + + if my_coordinate is None: + # if rank is not part of mesh, we skip redistribute and simply return local_tensor, + # which should be an empty tensor + return local_tensor + + has_symints = any(isinstance(s, torch.SymInt) for s in current_spec.shape) or any( + isinstance(s, torch.SymInt) for s in target_spec.shape + ) + if has_symints: + transform_infos = _gen_transform_infos_non_cached(current_spec, target_spec) + else: + transform_infos = _gen_transform_infos(current_spec, target_spec) + + for transform_info in transform_infos: + i = transform_info.mesh_dim + current, target = transform_info.src_dst_placements + device_mesh.size(mesh_dim=i) + + if current == target: + # short cut, just use the original local tensor + new_local_tensor = local_tensor + continue + + logger.debug("redistribute from %s to %s on mesh dim %s", current, target, i) + + if target.is_replicate(): + # Case 1: target is Replicate + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_value( + local_tensor, device_mesh, i + ) + elif current.is_shard(): + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + elif target.is_shard(): + # Case 2: target is Shard + target_placement = cast(Shard, target) + if current.is_partial(): + partial_spec = cast(Partial, current) + new_local_tensor = partial_spec._reduce_shard_value( + local_tensor, device_mesh, i, target_placement + ) + elif current.is_replicate(): + # split the tensor and return the corresponding cloned local shard + new_local_tensor = target_placement._replicate_to_shard( + local_tensor, device_mesh, i, my_coordinate[i] + ) + else: + assert current.is_shard(), ( + f"Current placement should be shard but found {current}" + ) + shard_spec = cast(Shard, current) + if shard_spec.dim != target_placement.dim: + new_local_tensor = shard_spec._to_new_shard_dim( + local_tensor, + device_mesh, + i, + transform_info.logical_shape, + target_placement.dim, + ) + elif target.is_partial(): + if current.is_replicate(): + partial_spec = cast(Partial, target) + # skip the replicate to partial transformation when we are in backward pass + # In this case we keep the grad as replicate, this is because we don't + # want to convert the replicated gradients back to partial, although + # that's logically conform with the same layout, converting the gradients + # back to partial is actually useless as you would have to do reduce later + # which would be more expensive than keeping it replicate! For this reason, + # we keep the replicate grad here. + new_local_tensor = ( + partial_spec._partition_value(local_tensor, device_mesh, i) + if not is_backward + else local_tensor + ) + elif current.is_shard(): + if not is_backward: + raise RuntimeError( + f"redistribute from {current} to {target} not supported yet" + ) + # for backward shard -> partial, we just need to convert the shard to replicate + current_placement = cast(Shard, current) + new_local_tensor = current_placement._to_replicate_tensor( + local_tensor, device_mesh, i, transform_info.logical_shape + ) + else: + # partial -> partial no op, should never hit + new_local_tensor = local_tensor + + local_tensor = new_local_tensor + + if not async_op and isinstance(new_local_tensor, funcol.AsyncCollectiveTensor): + new_local_tensor = new_local_tensor.wait() + + return new_local_tensor + + +class Redistribute(torch.autograd.Function): + @staticmethod + def forward( # type: ignore[override] + # pyre-fixme[2]: Parameter must be annotated. + ctx, + input: "dtensor.DTensor", + device_mesh: DeviceMesh, + placements: tuple[Placement, ...], + async_op: bool = False, + forward_dtype: Optional[torch.dtype] = None, + backward_dtype: Optional[torch.dtype] = None, + ): + ctx.async_op = async_op + ctx.backward_dtype = backward_dtype + ctx.original_dtype = input._local_tensor.dtype + + if forward_dtype is not None and forward_dtype != input._local_tensor.dtype: + local_tensor = input._local_tensor.to(dtype=forward_dtype) + current_spec = DTensorSpec( + mesh=device_mesh, + placements=input._spec.placements, + tensor_meta=TensorMeta( + shape=input.shape, + stride=input.stride(), + dtype=forward_dtype, + ), + ) + else: + local_tensor = input._local_tensor + current_spec = input._spec + + ctx.current_spec = current_spec + + if current_spec.placements != placements: + target_spec = DTensorSpec( + device_mesh, placements, tensor_meta=current_spec.tensor_meta + ) + + output = redistribute_local_tensor( + local_tensor, current_spec, target_spec, async_op=async_op + ) + else: + # use the same local tensor if placements are the same. + output = local_tensor + target_spec = current_spec + + return dtensor.DTensor( + output, + target_spec, + requires_grad=input.requires_grad, + ) + + @staticmethod + def backward(ctx, grad_output: "dtensor.DTensor"): # type: ignore[override] + previous_spec = ctx.current_spec + async_op = ctx.async_op + backward_dtype = ctx.backward_dtype or ctx.original_dtype + + if backward_dtype != grad_output._local_tensor.dtype: + local_tensor = grad_output._local_tensor.to(dtype=backward_dtype) + current_spec = DTensorSpec( + mesh=grad_output._spec.device_mesh, + placements=grad_output._spec.placements, + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=backward_dtype, + ), + ) + previous_spec = DTensorSpec( + mesh=previous_spec.device_mesh, + placements=previous_spec.placements, + tensor_meta=current_spec.tensor_meta, + ) + else: + local_tensor = grad_output._local_tensor + current_spec = grad_output._spec + + output = redistribute_local_tensor( + local_tensor, + current_spec, + previous_spec, + async_op=async_op, + is_backward=True, + ) + + if output.dtype != ctx.original_dtype: + output = output.to(ctx.original_dtype) + + # normalize the target placement to replicate if it is partial + normalized_placements: list[Placement] = [] + for previous_placement in previous_spec.placements: + if previous_placement.is_partial(): + # keep target placement to replicate instead of partial in this case + normalized_placements.append(Replicate()) + else: + normalized_placements.append(previous_placement) + + spec = DTensorSpec( + previous_spec.device_mesh, + tuple(normalized_placements), + tensor_meta=TensorMeta( + shape=grad_output.shape, + stride=grad_output.stride(), + dtype=output.dtype, + ), + ) + output_dtensor = dtensor.DTensor( + output, + spec, + requires_grad=grad_output.requires_grad, + ) + + return ( + output_dtensor, + None, + None, + None, + None, + None, + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_sharding_prop.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_sharding_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..ab0fc73133227ae6deefd73016607005a0467884 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_sharding_prop.py @@ -0,0 +1,532 @@ +# mypy: allow-untyped-defs +import threading +from collections.abc import Sequence +from functools import lru_cache +from itertools import chain +from typing import Callable, cast, Optional, Union + +import torch +from torch._ops import OpOverload +from torch._subclasses import FakeTensorMode +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpInfo, + OpSchema, + OpSpec, + OpStrategy, + OutputSharding, + OutputSpecType, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._utils import ( + compute_local_shape_and_global_offset, + compute_local_stride, +) + + +aten = torch.ops.aten + + +def _length(obj) -> int: + if obj is None: + return 0 + if not isinstance(obj, Sequence): + return 1 + return len(obj) + + +class LocalLRUCache(threading.local): + def __init__(self, user_function: Callable) -> None: + self.cache = lru_cache(None)(user_function) + + def __call__(self, *args, **kwargs) -> object: + return self.cache(*args, **kwargs) + + def cache_info(self): + return self.cache.cache_info() + + +class ShardingPropagator: + def __init__(self) -> None: + self.op_to_rules: dict[OpOverload, Callable[[OpSchema], OutputSharding]] = {} + self.op_strategy_funcs: dict[ + OpOverload, + Callable[[OpSchema], StrategyType], + ] = {} + # op map to save static argnum to decide to reuse sharding prop cache or + # re-run sharding prop + self.op_to_schema_info: dict[OpOverload, RuntimeSchemaInfo] = {} + self.propagate_op_sharding = LocalLRUCache( + self.propagate_op_sharding_non_cached + ) + # op map to save indices of shape (and stride) args which may need to be + # modified in sharding prop + self.op_to_shape_and_stride_idx: dict[ + OpOverload, Union[int, tuple[int, int]] + ] = { + # new factory ops + aten.new_empty.default: 1, + aten.new_full.default: 1, + aten.new_ones.default: 1, + aten.new_zeros.default: 1, + aten.new_empty_strided.default: (1, 2), + # view ops + aten.expand.default: 1, + aten.reshape.default: 1, + aten.view.default: 1, + aten._unsafe_view.default: 1, + aten.select_backward.default: 1, + aten.slice_backward.default: 1, + } + + def register_sharding_prop_rule( + self, + op_overload: OpOverload, + rule_func: Callable[[OpSchema], OutputSharding], + schema_info: Optional[RuntimeSchemaInfo] = None, + ): + """ + Register a sharding propagation rule for an operator. + """ + self.op_to_rules[op_overload] = rule_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def register_op_strategy( + self, + op_overload: OpOverload, + strategy_func: Callable[[OpSchema], StrategyType], + schema_info: Optional[RuntimeSchemaInfo] = None, + ): + """ + Register a sharding strategy generator for an operator. + """ + self.op_strategy_funcs[op_overload] = strategy_func + if schema_info is not None: + self.op_to_schema_info[op_overload] = schema_info + + def _propagate_tensor_meta_non_cached( + self, op_schema: OpSchema + ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + """ + Propagate the tensor metadata, it could either return a TensorMeta + or a list/tuple of TensorMetas + """ + if op_schema.op == aten.equal.default: + # data dependent ops can't be used for fake propagation + return None + + # NOTE: We must call the tracing in fake tensor mode so that it + # avoids materializing memory + with FakeTensorMode(): + fake_args = op_schema.gen_fake_args() + fake_kwargs = op_schema.gen_fake_kwargs() + fake_out = op_schema.op(*fake_args, **fake_kwargs) + + if isinstance(fake_out, torch.Tensor): + return TensorMeta( + shape=fake_out.shape, stride=fake_out.stride(), dtype=fake_out.dtype + ) + + elif isinstance(fake_out, (tuple, list)): + tensor_meta_list: list[Optional[TensorMeta]] = [] + for fake_out_item in fake_out: + if isinstance(fake_out_item, torch.Tensor): + tensor_meta_list.append( + TensorMeta( + shape=fake_out_item.shape, + stride=fake_out_item.stride(), + dtype=fake_out_item.dtype, + ) + ) + else: + tensor_meta_list.append(None) + return ( + tuple(tensor_meta_list) + if isinstance(fake_out, tuple) + else tensor_meta_list + ) + else: + # if fake is not a tensor or tuple of tensor, return as none + return None + + @lru_cache # noqa: B019 + def _propagate_tensor_meta( + self, op_schema: OpSchema + ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]: + return self._propagate_tensor_meta_non_cached(op_schema) + + def _wrap_output_spec_tensor_meta( + self, + op: OpOverload, + output_specs: OutputSpecType, + output_tensor_meta: Union[None, TensorMeta, Sequence[Optional[TensorMeta]]], + ) -> None: + """ + Wrap the output_specs with the tensor metadata from the output. + """ + + if isinstance(output_specs, DTensorSpec): + if not isinstance(output_tensor_meta, TensorMeta): + # Either error due to ShardingPropagator or due to incorrect OutputSpec + if not isinstance(output_tensor_meta, (tuple, list)): + raise ValueError( + "ShardingPropagator error: output does not have an associated " + "TensorMeta" + ) + raise ValueError( + f"For the op {op.name()}, `output_specs` has 1 output which does " + "not equal the " + f"number of op outputs: {len(output_tensor_meta)}." + ) + output_specs.tensor_meta = output_tensor_meta + elif isinstance(output_specs, (tuple, list)): + if not isinstance(output_tensor_meta, (tuple, list)) or len( + output_specs + ) != len(output_tensor_meta): + raise ValueError( + f"For the op {op.name()}, `output_specs` has {len(output_specs)} " + "outputs which does not equal the " + f"number of op outputs {_length(output_tensor_meta)}." + ) + + for i, spec in enumerate(output_specs): + if isinstance(spec, DTensorSpec): + output_tensor_meta_i = output_tensor_meta[i] + if not isinstance(output_tensor_meta_i, TensorMeta): + # NOTE: aten.convolution_backward.default is an exception and it + # needs extra handling because the first Tensor in the output + # tuple can be `None` if the input Tensor to convolution op has + # `requires_grad=False` (e.g. convolution layer is the first + # layer in the model). We explicitly allow its corresponding + # TensorMeta to be `None`. + if ( + op == aten.convolution_backward.default + and i == 0 + and output_tensor_meta_i is None + ): + assert isinstance(output_specs, list) + output_specs[i] = None + continue + else: + raise ValueError( + f"ShardingPropagator error: output {i} of {op.name()} " + "does not have an associated TensorMeta" + ) + + spec.tensor_meta = output_tensor_meta_i + + def _wrap_with_op_strategy(self, op_schema: OpSchema) -> OpSchema: + """ + wrap a op_schema that contains DTensorSpec to another op_schema that contains + OpStrategy/TupleStrategy, the returned op_schema is then used for sharding + strategy propagation on pytorch operators. + """ + + def spec_to_strategy(spec: object) -> object: + if isinstance(spec, DTensorSpec): + return OpStrategy([OpSpec(spec)]) + elif ( + isinstance(spec, (list, tuple)) + and len(spec) > 0 + and isinstance(spec[0], DTensorSpec) + ): + # tensor list create tuple strategy + tuple_strategy = [spec_to_strategy(s) for s in spec] + tuple_strategy = cast(Sequence[StrategyType], tuple_strategy) + return TupleStrategy( + tuple(tuple_strategy) if isinstance(spec, tuple) else tuple_strategy + ) + else: + return spec + + args_op_strategy = [spec_to_strategy(i) for i in op_schema.args_schema] + + kwargs_op_strategy = { + k: spec_to_strategy(v) for k, v in op_schema.kwargs_schema.items() + } + + return OpSchema( + op=op_schema.op, + args_schema=tuple(args_op_strategy), + kwargs_schema=kwargs_op_strategy, + ) + + def propagate(self, op_info: OpInfo) -> None: + # We cannot use an lru cache if we know that inputs will have dynamic shapes, + # because SymInts are not hashable. + # This is generally ok because this only happens during tracing in torch.compile, + # and tracing does not need to be as fast as eagermode DTensor usages. + if op_info.schema.has_symints: + output_sharding = self.propagate_op_sharding_non_cached(op_info.schema) + else: + output_sharding = cast( + OutputSharding, self.propagate_op_sharding(op_info.schema) + ) + op_info.output_sharding = output_sharding + + def propagate_op_sharding_non_cached(self, op_schema: OpSchema) -> OutputSharding: + """ + Propagate the sharding for an operator given the op_schema. + """ + # special case op, we don't need to propagate for local + # scalar. TODO: figure out a better way to handle this + if op_schema.op is aten._local_scalar_dense.default: + return OutputSharding(None, op_schema) + + out_tensor_meta = self._propagate_tensor_meta_non_cached(op_schema) + + if op_schema.op in self.op_strategy_funcs: + # wrap the op_schema with op strategy for sharding strategy propagation + strategy_schema = self._wrap_with_op_strategy(op_schema) + + # run sharding strategy propagation/generation + op_strategy = self.op_strategy_funcs[op_schema.op](strategy_schema) + + if isinstance(op_strategy, OpStrategy): + # single Op strategy + output_strategy = self._select_strategy(op_strategy) + + # check if we need to redistribute the input + needs_redistribute = False + expected_input_specs: list[DTensorSpec] = [] + + # in case where the op does not specify input_specs and output_specs + # is a DTensorSpec, we use output_specs as the spec for each DTensor + # input arg. + if output_strategy.input_specs is None: + assert isinstance(output_strategy.output_specs, DTensorSpec) + + for idx, input_spec in enumerate(op_schema.args_spec): + desired_spec = ( + output_strategy.output_spec + if output_strategy.input_specs is None + else output_strategy.input_specs[idx] + ) + expected_input_specs.append( + desired_spec.shallow_copy_with_tensor_meta( + input_spec.tensor_meta + ) + ) + if input_spec.placements != desired_spec.placements: + needs_redistribute = True + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(expected_input_specs), {} + ) + suggestion_schema._inplace_rewrap_schema_suggestion(op_schema) + + # shape and stride args need to be modified for + # view ops and new factory ops, potentially + if op_schema.op in self.op_to_shape_and_stride_idx: + assert isinstance(output_strategy.output_spec, DTensorSpec) + # It happens when the output has the same shape as the input + # and the input placements are not all Replicate(). + if output_strategy.output_spec.is_sharded(): + schema = suggestion_schema or op_schema + assert isinstance(out_tensor_meta, TensorMeta) + suggestion_schema = self._adjust_shape_and_stride_args( + out_tensor_meta, schema, output_strategy.output_spec + ) + needs_redistribute = True + + # construct output spec for the op + if op_schema.return_type_tuple_tensor_like(): + # for ops that return multiple tensors and the output_specs is not + # a tuple, we use a tuple of that single output spec as the new + # output_specs + output_specs: OutputSpecType = output_strategy.output_specs + if isinstance(output_specs, DTensorSpec): + output_specs = tuple( + [ + # create a new DTensorSpec with the same placement as the + # output_specs in output_strategy + DTensorSpec( + mesh=output_specs.mesh, + placements=output_specs.placements, + tensor_meta=output_specs.tensor_meta, + ) + for _ in range(len(op_schema.op._schema.returns)) + ] + ) + elif op_schema.return_type_tensor(): + output_specs = output_strategy.output_specs + else: + output_specs = None + + output_sharding = OutputSharding( + output_specs, + suggestion_schema, + needs_redistribute=needs_redistribute, + ) + elif isinstance(op_strategy, TupleStrategy): + # tuple strategy output sharding processing + # runtime select OpSpec for each TupleStrategy input arg + selected_strategies: list[OpSpec] = [] + out_spec_list: list[DTensorSpec] = [] + for strategy in op_strategy.childs: + assert isinstance(strategy, OpStrategy) + selected_strategy = self._select_strategy(strategy) + selected_strategies.append(selected_strategy) + out_spec_list.append(selected_strategy.output_spec) + + needs_redistribute = False + suggestion_args: list[object] = [] + tensor_or_list_tensor_arg_idx = 0 + + for arg in op_schema.args_schema: + if ( + arg + and isinstance(arg, (list, tuple)) + and isinstance(arg[0], DTensorSpec) + ): + expected_input_spec_list: list[DTensorSpec] = [] + for idx, arg_spec in enumerate(arg): + expected_input_spec = selected_strategies[idx].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg_spec.tensor_meta + ) + ) + if arg_spec.placements != expected_input_spec.placements: + needs_redistribute = True + expected_input_spec_list.append(expected_input_spec) + suggestion_args.append( + tuple(expected_input_spec_list) + if isinstance(arg, tuple) + else expected_input_spec_list + ) + tensor_or_list_tensor_arg_idx += 1 + + elif isinstance(arg, DTensorSpec): + expected_input_spec = selected_strategies[0].input_spec( + tensor_or_list_tensor_arg_idx + ) + expected_input_spec = ( + expected_input_spec.shallow_copy_with_tensor_meta( + arg.tensor_meta + ) + ) + if arg.placements != expected_input_spec.placements: + needs_redistribute = True + suggestion_args.append(expected_input_spec) + tensor_or_list_tensor_arg_idx += 1 + else: + suggestion_args.append(arg) + + suggestion_schema = None + if needs_redistribute: + suggestion_schema = OpSchema( + op_schema.op, tuple(suggestion_args), op_schema.kwargs_schema + ) + + output_sharding = OutputSharding( + tuple(out_spec_list) if out_tensor_meta is not None else None, + suggestion_schema, + needs_redistribute=needs_redistribute, + ) + else: + raise ValueError("Unsupported op strategy type") + + # associate the output sharding with the output tensor metadata + self._wrap_output_spec_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + return output_sharding + elif op_schema.op in self.op_to_rules: + # propagate the sharding with rule + sharding_prop_func = self.op_to_rules[op_schema.op] + + # step 1. there's sharding propagation rule, run + # sharding propagation to get the output sharding + try: + output_sharding = sharding_prop_func(op_schema) + except NotImplementedError as e: + raise e + except Exception as e: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}.\nError: {e}" + ) from e + + # step 2. if can't get output_spec from sharding + # propagation (i.e. no rules apply for input + # placements), we return the output sharding + # with schema suggestions, which can be used to + # decide how to do redistribute on inputs + if output_sharding.output_spec is None: + if output_sharding.redistribute_schema is None: + raise RuntimeError( + f"Sharding propagation failed on op {op_schema}!" + ) + else: + # we do auto redistribute on inputs if necessary + # run sharding propagation again with suggested schema + propagation_res = sharding_prop_func( + output_sharding.redistribute_schema + ) + # we set the output sharding with the new propagation result + # so that dispatching know both output_spec and redistribute_schema + # exist, which indicates a reshard is needed + output_sharding.output_spec = propagation_res.output_spec + output_sharding.needs_redistribute = True + + # associate the output sharding with the output tensor metadata + self._wrap_output_spec_tensor_meta( + op_schema.op, output_sharding.output_spec, out_tensor_meta + ) + + return output_sharding + else: + raise NotImplementedError( + f"Operator {op_schema.op} does not have a sharding strategy registered." + ) + + def _select_strategy(self, strategy: OpStrategy) -> OpSpec: + if len(strategy.strategies) == 1: + # short cut with only one possible OpSpec + return strategy.strategies[0] + + op_spec_costs: list[float] = [] + for op_spec in strategy.strategies: + assert op_spec.redistribute_cost is not None, ( + "must set redistribute cost each OpSpec!" + ) + redistribute_cost = sum(chain.from_iterable(op_spec.redistribute_cost)) + op_spec_costs.append(redistribute_cost) + + # for eager execution, we just select the one with the minimal redistribute cost + return strategy.strategies[op_spec_costs.index(min(op_spec_costs))] + + def _adjust_shape_and_stride_args( + self, + out_tensor_meta: TensorMeta, + schema: OpSchema, + spec: DTensorSpec, + ) -> OpSchema: + shape_stride_idx = self.op_to_shape_and_stride_idx[schema.op] + if isinstance(shape_stride_idx, tuple): + shape_idx, stride_idx = shape_stride_idx + else: + shape_idx = shape_stride_idx + stride_idx = None + + expected_input_schema = list(schema.args_schema) + # adjust shape to be the same as that of the _local_tensor + # of the DTensor input arg at index 0, which is inferred + expected_input_schema[shape_idx], _ = compute_local_shape_and_global_offset( + out_tensor_meta.shape, spec.mesh, spec.placements + ) + + # adjust the stride arg for aten.new_empty_strided.default + if stride_idx: + expected_input_schema[stride_idx] = compute_local_stride( + out_tensor_meta.stride, spec.mesh, spec.placements + ) + + return OpSchema(schema.op, tuple(expected_input_schema), schema.kwargs_schema) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_shards_wrapper.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_shards_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..224c490cc5520529aa409a5df02922d09c6a8a94 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_shards_wrapper.py @@ -0,0 +1,359 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any + +import torch +from torch.distributed.checkpoint.metadata import ( + ChunkStorageMetadata, + MetadataIndex, + TensorProperties, + TensorStorageMetadata, +) +from torch.distributed.checkpoint.planner import ( + TensorWriteData, + WriteItem, + WriteItemType, +) + + +aten = torch.ops.aten + + +class LocalShardsWrapper(torch.Tensor): + """ + A wrapper class to hold local shards of a DTensor. + This class is used largely for checkpointing purposes and implicity subtypes + the _Checkpointable protocol. + """ + + __slots__ = ["_local_shards", "_storage_meta"] + _local_shards: list[torch.Tensor] + _storage_meta: TensorStorageMetadata + + @staticmethod + def __new__( + cls, local_shards: list[torch.Tensor], local_offsets: list[tuple[int, ...]] + ) -> "LocalShardsWrapper": + assert all( + tensor.device == local_shards[0].device for tensor in local_shards[1:] + ) + + # if empty shard, we create a empty tensor + if len(local_shards) == 0: + r = torch.Tensor._make_wrapper_subclass( + cls, + torch.Size([0, 0]), + ) + r._local_shards = [] + r._storage_meta = TensorStorageMetadata( + properties=TensorProperties(), + size=torch.Size([0, 0]), + chunks=[ + ChunkStorageMetadata( + offsets=torch.Size([0, 0]), sizes=torch.Size([0, 0]) + ) + ], + ) + return r + + # we calculate the total tensor size by "concat" on second tensor dimension + cat_tensor_shape = list(local_shards[0].size()) + if len(local_shards) > 1 and local_shards[0].ndim == 2: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[1] += shard.size()[1] + + # in cases of sharding optimizer rowwise, we calculate total tensor size by "concat" on first tensor dimension + if len(local_shards) > 1 and local_shards[0].ndim == 1: # column-wise sharding + for shard in local_shards[1:]: + cat_tensor_shape[0] += shard.size()[0] + + wrapper_properties = TensorProperties.create_from_tensor(local_shards[0]) + wrapper_shape = torch.Size(cat_tensor_shape) + chunks_meta = [ + ChunkStorageMetadata( + offsets=torch.Size(offset), + sizes=shard.size(), + ) + for shard, offset in zip(local_shards, local_offsets) + ] + + r = torch.Tensor._make_wrapper_subclass( + cls, + torch.Size(cat_tensor_shape), + ) + r._local_shards = local_shards + r._storage_meta = TensorStorageMetadata( + properties=wrapper_properties, + size=wrapper_shape, + chunks=chunks_meta, + ) + + return r + + # necessary for ops dispatching from this subclass to its local shards + @classmethod + def __torch_dispatch__(cls, func, types, args=(), kwargs=None): # type: ignore[override] + kwargs = kwargs or {} + + dispatcher = { + torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor, + torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor, + aten._to_copy.default: cls.handle_to_copy, + aten.view.default: cls.handle_view, + aten.equal.default: cls.handle_equal, + aten.detach.default: cls.handle_detach, + aten.clone.default: cls.handle_clone, + aten.new_empty.default: cls.handle_new_empty, + } + + if func in dispatcher: + return dispatcher[func](args, kwargs) + else: + raise NotImplementedError( + f"{func} is not supported for LocalShardsWrapper!" + ) + + @staticmethod + def handle_all_gather_into_tensor(args, kwargs) -> torch.Tensor: + dim = args[0].local_sizes()[0][1] + cat_tensor = torch.cat( + [t.view(-1) for t in args[0].local_shards()], dim=0 + ).view(-1, dim) + return torch.ops._c10d_functional.all_gather_into_tensor.default( + cat_tensor, *args[1:], **kwargs + ) + + @staticmethod + def handle_wait_tensor(args, kwargs) -> torch.Tensor: + return torch.ops._c10d_functional.wait_tensor(args[0]) + + @staticmethod + def handle_to_copy(args, kwargs) -> torch.Tensor: + res_shards_list = [ + aten._to_copy.default(shard, *args[1:], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + def handle_view(args, kwargs) -> "LocalShardsWrapper": + view_shape = args[1] + res_shards_list = [] + if len(args[0].local_shards()) > 1: + if args[0].local_shards()[0].ndim == 2: + assert ( + args[0].storage_metadata().size[0] == view_shape[0] + and args[0].storage_metadata().size[1] == view_shape[1] + ) + # This accounts for a DTensor quirk, when multiple shards are present on a rank, DTensor on + # init calls view_as() on the global tensor shape + # will fail because the view shape is not applicable to individual shards. + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + elif args[0].local_shards()[0].ndim == 1: + assert args[0].storage_metadata().size[0] == view_shape[0] + # This case is for optimizer sharding as regardles of sharding type, optimizer state is row wise sharded + res_shards_list = [ + aten.view.default(shard, shard.shape, **kwargs) + for shard in args[0].local_shards() + ] + else: + raise NotImplementedError("No support for view on tensors ndim > 2") + else: + # view is called per shard + res_shards_list = [ + aten.view.default(shard, args[1], **kwargs) + for shard in args[0].local_shards() + ] + return LocalShardsWrapper(res_shards_list, args[0].local_offsets()) + + @staticmethod + def handle_equal(args, kwargs) -> bool: + """ + LocalShardsWrapper equal impl also checks for equality of storage metadata + and the order of shards + """ + a, b = args[0], args[1] + if len(a.local_shards()) != len(b.local_shards()): + return False + if not all( + aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards()) + ): + return False + if not a.storage_metadata() == b.storage_metadata(): + return False + return True + + @staticmethod + def handle_detach(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + deatched_local_shards = [ + aten.detach.default(shard) for shard in self_ls.local_shards() + ] + self_ls._local_shards = deatched_local_shards + self_ls._storage_meta.properties.requires_grad = False + return self_ls + + @staticmethod + def handle_clone(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + desired_memory_format = kwargs.get("memory_format", None) + if desired_memory_format and desired_memory_format != torch.preserve_format: + raise NotImplementedError( + f"{desired_memory_format} is not supported for LocalShardsWrapper!" + ) + cloned_local_shards = [ + shard.clone(memory_format=desired_memory_format) + for shard in self_ls._local_shards + ] + return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets()) + + @staticmethod + def handle_new_empty(args, kwargs) -> "LocalShardsWrapper": + self_ls = args[0] + return LocalShardsWrapper( + [torch.empty_like(shard) for shard in self_ls._local_shards], + self_ls.local_offsets(), + ) + + @property + def device(self) -> torch._C.device: # type: ignore[override] + return ( + self._local_shards[0].device if self._local_shards else torch.device("meta") + ) + + @property + def is_meta(self) -> bool: # type: ignore[override] + return self._local_shards[0].is_meta if self._local_shards else True + + def is_pinned(self) -> bool: # type: ignore[override] + return self._storage_meta.properties.pin_memory + + def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper": + self._storage_meta.properties.requires_grad = requires_grad + [shard.requires_grad_(requires_grad) for shard in self._local_shards] + return self + + def local_shards(self) -> list[torch.Tensor]: + """ + Returns a list of :class:`torch.Tensor' corresponding to the + local shards for this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return self._local_shards + + def local_sizes(self) -> list[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local sizes for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.sizes for chunk in self._storage_meta.chunks] + + def local_offsets(self) -> list[torch.Size]: + """ + Returns a list of :class:`torch.Size' corresponding to the + local offsets for the shards on this rank. Returns an empty list if the current rank + does not host any shards for this Tensor. + """ + return [chunk.offsets for chunk in self._storage_meta.chunks] + + @property + def local_chunks(self) -> list[ChunkStorageMetadata]: + """ + Returns a :class:`list[ChunkStorageMetadata]` object corresponding to the + metadata for each tensor shard + """ + return self._storage_meta.chunks + + def storage_metadata(self) -> TensorStorageMetadata: + """ + Returns a :class:`TensorStorageMetadata` object corresponding to the + metadata for the local tensor on current rank + """ + return self._storage_meta + + def is_empty_shard(self) -> bool: + """ + Returns a :class:`bool` object indicating if the local tensor on current rank + is an empty tensor + """ + return self._storage_meta.size[0] == 0 and self._storage_meta.size[1] == 0 + + def __create_write_items__(self, fqn: str, object: Any) -> list[WriteItem]: + """ + For compatibility with DCP, we support creation of WriteItems + such that they can be saved properly. + """ + return [ + WriteItem( + index=MetadataIndex(fqn, chunks.offsets), + type=WriteItemType.SHARD, + tensor_data=TensorWriteData( + chunk=ChunkStorageMetadata( + offsets=chunks.offsets, + sizes=chunks.sizes, + ), + properties=self._storage_meta.properties, + size=object.size(), + ), + ) + for tensor, chunks in zip(self.local_shards(), self.local_chunks) + ] + + def __create_chunk_list__(self) -> list[ChunkStorageMetadata]: + """ + For compatibility with DCP, we support creation of chunk lists + such that they can be saved properly. + """ + return self._storage_meta.chunks + + def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor: + """ + For compatibility with DCP, we support finding shard based on index + Return a 'torch.Tensor' shard based on 'MetadataIndex'. + """ + # Fast lookup path + if index.index is not None: + if ( + len(self._local_shards) > index.index + and self._storage_meta.chunks[index.index].offsets == index.offset + ): + return self._local_shards[index.index] + + if index.offset is not None: + for shard, chunk in zip(self._local_shards, self._storage_meta.chunks): + if chunk.offsets == index.offset: + return shard + + # Empty shard case + if len(self._local_shards) == 0 and self._storage_meta.chunks[ + 0 + ].sizes == torch.Size([0, 0]): + return torch.empty(0) + + raise ValueError( + f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'" + ) + + def _get_tensor_size_bytes(self) -> int: + object_size = 0 + for shard in self.local_shards(): + object_size += shard.nelement() * shard.element_size() + return object_size + + def __hash__(self) -> int: + return id(self) + + def __repr__(self) -> str: # type: ignore[override] + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" + + def __str__(self) -> str: + return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}" diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_tp_conv.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_tp_conv.py new file mode 100644 index 0000000000000000000000000000000000000000..31bcc3ad3d6e2a72107be7cb3b7862e5aaf987dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_tp_conv.py @@ -0,0 +1,279 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +# implement matrix related ops for distributed tensor +from typing import cast + +import torch +import torch.distributed as dist +import torch.distributed.tensor._api as dtensor + + +aten = torch.ops.aten + + +def _requires_data_exchange(padding): + # TODO: whether there requires data exchange is currently determined by padding + return padding[1] != 0 + + +def _is_supported(input_size, kernel_size, stride, padding, dilation): + if dilation[1] != 1: + raise RuntimeError("Dilation must be 1 for tensor parallel convolution.") + if padding[1] != 0: + if stride[1] != 1: + raise RuntimeError( + "Stride must be 1 when there is padding for tensor parallel convolution." + ) + if kernel_size[3] // 2 > input_size[3]: + raise RuntimeError( + "kernel_size[3] // 2 should be less than or equal to input_size[3] for tensor parallel convolution." + ) + else: + if not (input_size[3] % stride[1] == 0 and stride[1] == kernel_size[3]): + raise RuntimeError( + "It requires that input_size[3] is divisible by stride[1] and stride[1] equals kernel_size[3] " + "when there is padding for tensor parallel convolution." + ) + return True + + +def _ring_send_recv_construct(in_tensor, d1, d2, left, right, rank, size): + # dist comms and reconstruct local input tensor + send_to_right = in_tensor[:, :, :, -d1:].contiguous() + send_to_left = in_tensor[:, :, :, :d2].contiguous() + recv_from_right = torch.zeros_like(send_to_left) + recv_from_left = torch.zeros_like(send_to_right) + + send_op_right = dist.P2POp(dist.isend, send_to_right, right) + send_op_left = dist.P2POp(dist.isend, send_to_left, left) + recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right) + recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left) + + reqs = dist.batch_isend_irecv( + [send_op_right, send_op_left, recv_op_left, recv_op_right] + ) + for req in reqs: + req.wait() + + if rank == 0: + in_tensor = torch.cat([in_tensor, recv_from_right], dim=-1) + elif rank == size - 1: + in_tensor = torch.cat([recv_from_left, in_tensor], dim=-1) + else: + in_tensor = torch.cat([recv_from_left, in_tensor, recv_from_right], dim=-1) + + return in_tensor + + +def _ring_send_recv_aggregate(grad_in_tensor, d1, d2, left, right, rank, size): + # dist comms and aggregate gradients for edge pixels + send_to_right = grad_in_tensor[:, :, :, -d2:].contiguous() + send_to_left = grad_in_tensor[:, :, :, :d1].contiguous() + recv_from_right = torch.zeros_like(send_to_left) + recv_from_left = torch.zeros_like(send_to_right) + + send_op_right = dist.P2POp(dist.isend, send_to_right, right) + send_op_left = dist.P2POp(dist.isend, send_to_left, left) + recv_op_right = dist.P2POp(dist.irecv, recv_from_right, right) + recv_op_left = dist.P2POp(dist.irecv, recv_from_left, left) + + reqs = dist.batch_isend_irecv( + [send_op_right, send_op_left, recv_op_left, recv_op_right] + ) + for req in reqs: + req.wait() + + if rank == 0: + grad_in_tensor = grad_in_tensor[:, :, :, :-d2] + grad_in_tensor[:, :, :, -d1:] = torch.add( + grad_in_tensor[:, :, :, -d1:], recv_from_right + ) + elif rank == size - 1: + grad_in_tensor = grad_in_tensor[:, :, :, d1:] + grad_in_tensor[:, :, :, :d2] = torch.add( + grad_in_tensor[:, :, :, :d2], recv_from_left + ) + else: + grad_in_tensor = grad_in_tensor[:, :, :, d1:-d2] + grad_in_tensor[:, :, :, -d1:] = torch.add( + grad_in_tensor[:, :, :, -d1:], recv_from_right + ) + grad_in_tensor[:, :, :, :d2] = torch.add( + grad_in_tensor[:, :, :, :d2], recv_from_left + ) + + +def tp_convolution( + op_call: torch._ops.OpOverload, + local_tensor_args: tuple[object, ...], + local_tensor_kwargs: dict[str, object], +) -> object: + assert op_call == aten.convolution.default + assert len(local_tensor_args) == 9 + + rank = dist.get_rank() + size = dist.get_world_size() + in_tensor = cast(torch.Tensor, local_tensor_args[0]) + weight = cast(torch.Tensor, local_tensor_args[1]) + stride, padding, dilation = local_tensor_args[3:6] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, list) + + if not _requires_data_exchange(padding): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[3] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = in_tensor + local_tensor_args = cast(tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # step3 remove extra outputs from the results + padding_w = padding[1] + w = local_results.size(3) + if rank == 0: + local_results = local_results[:, :, :, : w - padding_w] + elif rank == size - 1: + local_results = local_results[:, :, :, padding_w:] + else: + local_results = local_results[:, :, :, padding_w : w - padding_w] + + return local_results + + +def tp_convolution_backward( + op_call: torch._ops.OpOverload, + local_tensor_args: tuple[object, ...], + local_tensor_kwargs: dict[str, object], +) -> object: + assert op_call == aten.convolution_backward.default + assert len(local_tensor_args) == 11 + + rank = dist.get_rank() + size = dist.get_world_size() + grad_out_tensor = cast(torch.Tensor, local_tensor_args[0]) + in_tensor = cast(torch.Tensor, local_tensor_args[1]) + weight = cast(torch.Tensor, local_tensor_args[2]) + stride, padding, dilation = local_tensor_args[4:7] + + assert _is_supported(in_tensor.shape, weight.shape, stride, padding, dilation) + assert isinstance(padding, list) + + if not _requires_data_exchange(padding): + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + return local_results + else: + # step 0 compute the overlap pixels of the input tensor + d = weight.shape[3] - 1 + d1 = d // 2 + d2 = d - d1 + assert d1 + d2 == d + right = (rank + 1) % size + left = (rank - 1 + size) % size + + # step1 reconstruct local input tensor + in_tensor = _ring_send_recv_construct( + in_tensor, d1, d2, left, right, rank, size + ) + + # step2 reconstruct local gradient output tensor + padding_w = padding[1] + if rank == 0: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (0, padding_w), "constant", 0 + ) + elif rank == size - 1: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, 0), "constant", 0 + ) + else: + grad_out_tensor = torch.nn.functional.pad( + grad_out_tensor, (padding_w, padding_w), "constant", 0 + ) + + # step3 feed local input tensor to op_call + local_tensor_args_list = list(local_tensor_args) + local_tensor_args_list[0] = grad_out_tensor + local_tensor_args_list[1] = in_tensor + local_tensor_args = cast(tuple[object, ...], local_tensor_args_list) + local_results = op_call(*local_tensor_args, **local_tensor_kwargs) + + # step4 aggregate gradients for edge pixels + grad_in_tensor = local_results[0] + if grad_in_tensor is not None: + grad_in_tensor = _ring_send_recv_aggregate( + grad_in_tensor, d1, d2, left, right, rank, size + ) + local_results = list(local_results) + local_results[0] = grad_in_tensor + + local_results = cast(tuple[object, ...], local_results) + + return local_results + + +def convolution_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + + # local propagation + local_results = tp_convolution( + op_call, tuple(op_info.local_args), op_info.local_kwargs + ) + + return dtensor.DTensor._op_dispatcher.wrap( + local_results, output_sharding.output_spec + ) + + +def convolution_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # Redistribute grad_output tensor to the same placement as input tensor + args = list(args) + assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor) + args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements) + args = tuple(args) + + # extract local tensor and sharding infos to a OpInfo + op_info = dtensor.DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + + # sharding propagation + dtensor.DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + + # local propagation + local_results = tp_convolution_backward( + op_call, tuple(op_info.local_args), op_info.local_kwargs + ) + + return dtensor.DTensor._op_dispatcher.wrap( + local_results, output_sharding.output_spec + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/_utils.py b/phivenv/Lib/site-packages/torch/distributed/tensor/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..07c005386a79b53965a5c80b5fee7ebbcc91a94d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/_utils.py @@ -0,0 +1,371 @@ +from collections import defaultdict +from collections.abc import Sequence +from typing import cast, Optional + +import torch +import torch.distributed._functional_collectives as funcol +import torch.distributed.tensor._api as dtensor +from torch._prims_common import ShapeType +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._dtensor_spec import DTensorSpec +from torch.distributed.tensor.placement_types import ( + _StridedShard, + Partial, + Placement, + Replicate, + Shard, +) + + +def _explicit_order_placements( + mesh_shape: ShapeType, placements: Sequence[Placement] +) -> Sequence[tuple[int, Placement]]: + """ + Replace Strided Shards with regular shards in an adjusted order. + + Returns a list of (mesh_dim, placement) tuples where the list order is the sharding order. + + ex. + [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] -> + [(0, Shard(0)), (2, Shard(0)), (1, Shard(0))] + + """ + if not len(placements) == len(mesh_shape): + raise RuntimeError( + "Expected one placement per mesh dim, " + f"but found {len(placements)} placements and {len(mesh_shape)} mesh dims." + ) + ordered = [] + deferred_strided_placements = defaultdict(list) + strided_part_ended_for_dim = set() + for mesh_dim, p in enumerate(placements): + if isinstance(p, _StridedShard): + # validate the stride is the correct multiple of the meshdim and the earlier shard + deferred_strided_placements[p.dim].append((mesh_dim, p)) + + else: + ordered.append((mesh_dim, p)) + if isinstance(p, Shard): + if p.dim in strided_part_ended_for_dim: + raise NotImplementedError( + f"Strided sharding does not allow Shard() to appear after " + f"the strided part has ended. {p} at mesh dim {mesh_dim} in " + f"{placements} violates this assumption." + ) + + if p.dim in deferred_strided_placements: + strided_part_ended_for_dim.add(p.dim) + strided_placements = deferred_strided_placements.pop(p.dim) + aggregate_size = mesh_shape[mesh_dim] + while len(strided_placements) > 0: + strided_mesh_dim, strided = strided_placements.pop() + if not strided.split_factor == aggregate_size: + raise RuntimeError( + f"Can only convert _StridedShard to ordered Shard if split_factor({strided.split_factor})" + f" == aggregate mesh size ({aggregate_size})" + ) + aggregate_size *= mesh_shape[strided_mesh_dim] + ordered.append((strided_mesh_dim, Shard(p.dim))) + + return ordered + + +def compute_local_shape_and_global_offset( + global_shape: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> tuple[tuple[int, ...], tuple[int, ...]]: + """ + Compute the local tensor shape and the global offsets into the original tensor + of a DTensor on its current global rank. This is useful for checkpointing purpose. + + Example: + global_tensor = [[0, 1, 2, 3, 4], sharded on mesh (DP=2, TP=2) with (Shard(1), Shard(1)) + [10, 11, 12, 13, 14]] + + This table shows the return value of local_shape and global_offset for each rank. + (`local_tensor` is for illustration only). + + Note how the first coordinate of global_offset is always 0, corresponding to tensor dim 0 being replicated. + + Rank local_tensor local_shape global_offset + ------------------------------------------------------------- + 0 [[0, 1], (2, 2) (0, 0) + [10, 11]] + + 1 [[2], (2, 1) (0, 2) + [12]] + + 2 [[3], (2, 1) (0, 3) + [13]] + + 3 [[4], (2, 1) (0, 4) + [14]] + + Args: + global_shape (ShapeType): The global shape of the DTensor. + mesh (:class:`DeviceMesh`): The device mesh this DTensor is distributed on. + placements (Sequence[:class:`Placement`]]): The placements of the DTensor. + + Return: + local_shape: the shape of the DTensor's _local_tensor on the current rank. + global_offset: a tuple of offsets for each dimension of the global tensor shape, + identifying how this shard fits into the global tensor in each dimension. + + """ + return _compute_local_shape_and_global_offset( + global_shape, mesh.shape, mesh.get_coordinate(), placements + ) + + +# accept 'plain data types' to enable simpler unit testing without creating device mesh +def _compute_local_shape_and_global_offset( + global_shape: ShapeType, + mesh_shape: ShapeType, + my_coordinate: Optional[list[int]], + placements: Sequence[Placement], +) -> tuple[tuple[int, ...], tuple[int, ...]]: + ordered_placements = _explicit_order_placements(mesh_shape, placements) + + if my_coordinate is None: + # if rank not in the mesh, return empty offset + return ((0,), ()) + else: + local_shape = list(global_shape) + global_offset = [0] * len(global_shape) + for mesh_dim, placement in ordered_placements: + mesh_dim_size = mesh_shape[mesh_dim] + if isinstance(placement, Shard): + shard_dim = placement.dim + local_offset = [0] * len(global_shape) + assert shard_dim < len(local_shape), ( + f"Sharding dim {shard_dim} greater than tensor ndim {len(local_shape)}" + ) + shard_size, shard_offset = placement._local_shard_size_and_offset( + local_shape[shard_dim], + mesh_dim_size, + my_coordinate[mesh_dim], + ) + + local_shape[shard_dim] = shard_size + local_offset[shard_dim] = shard_offset + if shard_size == 0: + # Special case to fill in a standardized non-garbage value for the global_offset + # of zero-sized shards. This value is out of bounds of the tensor, so it won't conflict + # with any real offsets. DCP may rely on this value to de-duplicate shards. + global_offset[shard_dim] = global_shape[shard_dim] + else: + # On a given dimension, if the local_offset[shard_dim] is smaller than global_offset[shard_dim], + # it means that this dimension has been already sharded in previous placement. + # Therefore, we cannot simply replace the global_offset[shard_dim] with local_offset[shard_dim]. + # Instead, for the given shard_dim, we need to add local_offset[shard_dim] to existing global_offset[shard_dim]. + if global_offset[shard_dim] <= local_offset[shard_dim]: + global_offset[shard_dim] = local_offset[shard_dim] + else: + global_offset[shard_dim] += local_offset[shard_dim] + + # NOTE: the offset compute relies on the local shard index and it has no + # problem when strided sharding is not present. To correctly compute, we assume + # that the ``_StridedShard.split_factor`` field encodes how many partitions + # each local tensor will be further split into when sharding on higher mesh + # dimensions. However, this number is only correct if the DTensor is not + # sharded after the strided sharding completes. For example, + # [Shard(0), _StridedShard(0, split_factor=2), Shard(0)] is the placements + # where the DTensor's dim-0 is first sharded on device mesh dim-0, then on + # device mesh dim-2, and last on mesh dim-1. We define the + # "_StridedShard(0, split_factor=2), Shard(0)" part as the strided sharding + # part because strided sharding happens on mesh dim-1 and it was caused by + # the fact that sharding on dim-2 occurred ahead. In this case, there's no + # further sharding after this strided sharding part and ``split_factor`` + # correctly encodes the number. Another example is + # [_StridedShard(0, split_factor=2), Shard(0), Shard(0)] where the DTensor's + # dim-0 is first sharded on mesh dim-1, then on mesh dim-0, and last on mesh + # dim-2. This violates our assumption that no further sharding shall occur + # after the strided sharding part and ``split_factor`` won't correctly + # encode the number of further split. So far, the only case where _StridedShard + # placement would appear is FSDP2 + TP on 2D mesh and the above case could only + # happen on mesh of 3 or more dimensions. + # TODO: change this function to correctly address this. + # TODO: this logic can be applied to contiguous sharding as well + return tuple(local_shape), tuple(global_offset) + + +def compute_global_tensor_info( + tensor: torch.Tensor, mesh: DeviceMesh, placements: Sequence[Placement] +) -> tuple[list[int], list[int]]: + """ + Compute the global size and stride of a DTensor from the given local tensor. + The local size is multiplited by `world_size` per Sharding dim. + The local stride is multiplited by `world_size` per Sharding dim, as long as the + dimension is outside sharding dim. + + For example, if we have a local tensor with size (4, 8, 2) and stride (16, 1, 8). + If the DTensor placements are [Shard(2)] and world_size is 2; + then the global size is (4, 8, 4) and stride is (16 * 2, 1, 8). + + Args: + tensor (:class:`torch.Tensor`): + Local tensor which DTensor will be constructed from. + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + placements (Sequence[:class:`Placement`]]): + The attribute of the DTensor that describes its layout + on the mesh topology. + + Return: + tensor_shape: A List of int which specifies the size of DTensor which build + on top of the local tensor. + tensor_stride: A List of int which specifies the stride of DTensor. + """ + tensor_shape = list(tensor.size()) + tensor_stride = list(tensor.stride()) + for idx, placement in enumerate(placements): + mesh_dim_size = mesh.size(idx) + if placement.is_shard(): + shard_placement = cast(Shard, placement) + if shard_placement.dim < 0: + raise AssertionError( + "Shard placements should have negative dims normalized in " + f"the user-facing APIs: {shard_placement}" + ) + shard_dim = shard_placement.dim + + assert shard_dim < tensor.ndim, ( + f"Sharding dim {shard_dim} greater than tensor ndim {tensor.ndim} for placement number {idx}." + ) + + local_dim_size = tensor_shape[shard_dim] + tensor_shape[shard_dim] = local_dim_size * mesh_dim_size + + # recover tensor stride by modifying the stride that larger than + # the current stride on the shard_dim + for i in range(len(tensor_stride)): + if i != shard_dim and tensor_stride[i] >= tensor_stride[shard_dim]: + # rescale the stride by the shard size + tensor_stride[i] = tensor_stride[i] * mesh_dim_size + elif not isinstance(placement, (Replicate, Partial)): + raise RuntimeError(f"placement type {type(placement)} not supported!") + return tensor_shape, tensor_stride + + +def compute_global_tensor_shape( + shape: torch.Size, mesh: DeviceMesh, placements: Sequence[Placement] +) -> torch.Size: + """ + Compute the global size of a DTensor from the given local tensor shape, + the mesh and placements. Different from `compute_global_tensor_info`, + which assumes sharding is even, this util allgathers local shards' shapes + from all ranks and thus can support uneven sharding. + NOTE: Currently this function only supports 1D mesh. + + Args: + shape (:class:`torch.Size`): + Shape of the local tensor + mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for the DTensor. + placements (Sequence[:class:`Placement`]]): + The attribute of the DTensor that describes its layout + on the mesh topology. + + Return: + tensor_shape: Shape of the global DTensor. + """ + if len(placements) != 1: + raise NotImplementedError( + "compute_global_tensor_shape only supports 1 placement for now." + ) + + if len(placements) != mesh.ndim: + raise RuntimeError( + "Expected one placement per mesh dim, " + f"but found {len(placements)} placements and {mesh.ndim} mesh dims." + ) + + if isinstance(placements[0], Replicate): + return shape + elif isinstance(placements[0], Shard): + local_shape = torch.tensor(list(shape)) + gathered_shaped_tensors = [ + torch.empty_like(local_shape, device=local_shape.device) + for _ in range(mesh.size()) + ] + funcol.all_gather_inplace(gathered_shaped_tensors, local_shape) + sharded_dim_sum = 0 + shard_dim = placements[0].dim + other_dims = [d for d in range(mesh.ndim) if d != shard_dim] + for shape_tensor in gathered_shaped_tensors: + if not torch.equal(local_shape[other_dims], shape_tensor[other_dims]): + raise RuntimeError( + "Non-sharded dimentions should have identical size across ranks." + ) + shape_tensor_list = shape_tensor.tolist() + sharded_dim_sum += shape_tensor_list[shard_dim] + global_shape = list(shape) + global_shape[placements[0].dim] = sharded_dim_sum + return torch.Size(global_shape) + else: + raise NotImplementedError( + f"Placement type {type(placements[0])} not supported." + ) + + +def try_find_mesh_from_args( + op_call: torch._ops.OpOverload, args: Sequence[object] +) -> DeviceMesh: + """ + Find the device mesh object from args. + It returns None if no mesh is found. + NOTE: we can optimize this search if needed + """ + for arg in args: + if isinstance(arg, (dtensor.DTensor, DTensorSpec)): + return arg.device_mesh + elif ( + isinstance(arg, (list, tuple)) + and len(arg) > 0 + and isinstance(arg[0], (dtensor.DTensor, DTensorSpec)) + ): + return arg[0].device_mesh + + raise ValueError(f"Cannot find device mesh from args for op : {op_call}.") + + +def compute_local_stride( + global_stride: ShapeType, mesh: DeviceMesh, placements: Sequence[Placement] +) -> tuple[int, ...]: + """ + Compute the stride of a local tensor shard, given the global stride of the DTensor. + NOTE: Currently this function is assuming the DTensor is evenly shardable. + """ + stride_divisors = [1] * len(global_stride) + for mesh_idx, p in enumerate(placements): + if p.is_shard(): + i = cast(Shard, p).dim + # tensor dimension i is sharded on mesh dimension mesh_idx, + # so we need to divide all the strides larger than stride[i] + # (by the submesh size) + for j in range(len(global_stride)): + if global_stride[j] > global_stride[i]: + stride_divisors[j] *= mesh.size(mesh_idx) + return tuple( + global_stride[i] // stride_divisors[i] for i in range(len(global_stride)) + ) + + +def normalize_to_torch_size(size) -> torch.Size: # type: ignore[no-untyped-def] + """ + Unify variable types of size argument to torch.Size + Acceptable types include: + int, Sequence[int], Tuple[int], Tuple[Sequence[int]], + or torch.Size + """ + if isinstance(size, torch.Size): + return size + + if isinstance(size, int): + torch_size = [size] + elif len(size) == 1 and isinstance(size[0], Sequence): + torch_size = list(size[0]) + else: + torch_size = list(size) + return torch.Size(torch_size) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__init__.py b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f5e633d85c0ac968354a86c992f319e7a3dfad3e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__init__.py @@ -0,0 +1,24 @@ +# mypy: allow-untyped-defs +from torch.distributed.tensor.debug._comm_mode import CommDebugMode +from torch.distributed.tensor.debug._visualize_sharding import visualize_sharding + + +__all__ = ["CommDebugMode", "visualize_sharding"] + + +def _get_sharding_prop_cache_info(): + """ + Get the cache info for the sharding propagation cache, used for debugging purpose only. + This would return a named tuple showing hits, misses, maxsize and cursize of the sharding + propagator cache. + """ + from torch.distributed.tensor._api import DTensor + + return ( + DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding.cache_info() # type:ignore[attr-defined] + ) + + +# Set namespace for exposed private names +CommDebugMode.__module__ = "torch.distributed.tensor.debug" +visualize_sharding.__module__ = "torch.distributed.tensor.debug" diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..090376fb0c410a79573de155cbc172887fa432b2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_comm_mode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_comm_mode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af063b427535d86e3c1b393b728af4cbdaaeac8e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_comm_mode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_op_coverage.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_op_coverage.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c89863268432e58bee38c8c0c75fcb0e10c234ac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_op_coverage.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_visualize_sharding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_visualize_sharding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e53c11094a3afb55c75930ce970933bd55d1f602 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/__pycache__/_visualize_sharding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_comm_mode.py b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_comm_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..6b07e7713973aeebd735a2b249dccd3b7848c37b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_comm_mode.py @@ -0,0 +1,735 @@ +# mypy: allow-untyped-defs +import copy +import json +import re +import weakref +from collections import defaultdict +from typing import Any + +import torch +import torch.nn +from torch._guards import detect_fake_mode +from torch.autograd.graph import register_multi_grad_hook +from torch.distributed._tools.mod_tracker import ModTracker +from torch.distributed.tensor._api import DTensor +from torch.nn.modules.module import ( + register_module_forward_hook, + register_module_forward_pre_hook, + register_module_full_backward_pre_hook, +) +from torch.utils._python_dispatch import TorchDispatchMode +from torch.utils._pytree import tree_flatten + + +__all__ = ["CommDebugMode"] + +funcol_native = torch.ops._c10d_functional +funcol_py = torch.ops.c10d_functional +funcol_autograd = torch.ops._c10d_functional_autograd +c10d_ops = torch.ops.c10d + +NATIVE_TO_PY_MAPPING = { + funcol_native.all_gather_into_tensor: funcol_py.all_gather_into_tensor, + funcol_native.all_gather_into_tensor_coalesced: funcol_py.all_gather_into_tensor_coalesced, + funcol_native.all_reduce: funcol_py.all_reduce, + funcol_native.all_reduce_coalesced: funcol_py.all_reduce_coalesced, + funcol_native.all_to_all_single: funcol_py.all_to_all_single, + funcol_native.broadcast: funcol_py.broadcast, + funcol_native.reduce_scatter_tensor: funcol_py.reduce_scatter_tensor, + funcol_native.reduce_scatter_tensor_coalesced: funcol_py.reduce_scatter_tensor_coalesced, + # functional ops + funcol_autograd.all_to_all_single: funcol_py.all_to_all_single, +} + +c10d_collective_ops = { + c10d_ops._allgather_base_, + c10d_ops._reduce_scatter_base_, + c10d_ops.allgather_, + c10d_ops.allgather_coalesced_, + c10d_ops.allgather_into_tensor_coalesced_, + c10d_ops.allreduce_, + c10d_ops.allreduce_coalesced_, + c10d_ops.alltoall_, + c10d_ops.alltoall_base_, + c10d_ops.broadcast_, + c10d_ops.gather_, + c10d_ops.scatter_, + c10d_ops.reduce_, + c10d_ops.reduce_scatter_, + c10d_ops.reduce_scatter_tensor_coalesced_, +} + +trivial_ops = { + "aten.detach.default", + "aten.t.default", + "aten.view.default", + "aten._to_copy.default", + "aten.as_strided.default", + "aten.transpose.int", +} + + +class _CommModeModuleTracker(ModTracker): + """ + Inherits ModuleTracker and expands on its functionality to track the + parameters and sharding information of a model at a module-level + """ + + def __init__(self): + super().__init__() + self.module_helper_dict = {} + self.module_parameters_dict = {} + self.module_parents_dict = {} + self.register_forward_hook_handles = {} + self.parent_dict = {} + self.parent_list = [] + self.sharding_dict = {} + self.activation_checkpointing = False + self.name = "" + + def _fw_set_module_hook(self, mod, input, output): + """ + Updates the current module after module finishes running and + all other hooks are resolved + """ + + if self.is_bw: + self.activation_checkpointing = True + else: + self.activation_checkpointing = False + + if not self.activation_checkpointing: + # module is no longer parent of next modules + self.parent_list.pop() + + # set current module to previous parent module + self.name = self.parent_list[-1] + + def _fw_pre_hook(self, mod, input): + """ + This function is called before the forward pass of a module. It + collects the parameters and sharding information of a module and + stores it in a dictionary. + """ + if self.is_bw: + self.activation_checkpointing = True + else: + self.activation_checkpointing = False + + self.name = super()._get_mod_name(mod) + w_mod = weakref.ref(mod) + + # adds current sub-module to module tracker parent class + super()._get_append_fn(w_mod, self.name, False)() + + args, _ = tree_flatten(input) + tensors = [a for a in args if isinstance(a, torch.Tensor) and a.requires_grad] + if not self.is_bw and tensors: + register_multi_grad_hook( + tensors, super()._get_pop_fn(w_mod, self.name, True) + ) + + if not self.activation_checkpointing: + # contains information about module ordering and depth in the module tree + if self.name not in self.module_helper_dict: + self.module_helper_dict[self.name] = {} + + self.module_helper_dict[self.name]["module_type"] = ( + str(type(mod)).replace("<", "").replace(">", "") + ) + self.module_helper_dict[self.name]["depth"] = len(self.parents) - 1 + + for param_name, param in mod.named_parameters(recurse=False): + if self.name not in self.module_parameters_dict: + self.module_parameters_dict[self.name] = {} + + self.module_parameters_dict[self.name][param_name] = param.data + + if isinstance(param.data, DTensor): + key_name = self.name + "." + param_name + self.sharding_dict[key_name] = param.data.placements + + if "parameters" not in self.module_helper_dict[self.name]: + self.module_helper_dict[self.name]["parameters"] = {} + + self.module_helper_dict[self.name]["parameters"][param_name] = str( + param.data.placements + ) + + # used to store module's parents to ensure correctness in backward pass/checkpointing + if self.name not in self.module_parents_dict: + self.module_parents_dict[self.name] = copy.deepcopy(self.parents) + + # used to create parent-child module associations for json dumps + parent = self.parent_list[-1] + if parent not in self.parent_dict: + self.parent_dict[parent] = [] + + self.parent_dict[parent].append(self.name) + self.parent_list.append(self.name) + + self.register_forward_hook_handles[self.name] = mod.register_forward_hook( + self._fw_set_module_hook + ) + + def _fw_post_hook(self, mod, input, output): + """ + This function is called when the forward pass of a module is called. + It updates the module tracker and removes the module from parent data + """ + + super()._fw_post_hook(mod, input, output) + + def _bw_hook(self, mod, output): + """ + This function is called when the backward pass of a module is called. It + updates the current module for backward passes + """ + self.activation_checkpointing = False + self.name = super()._get_mod_name(mod) + + def __enter__(self): + self.activation_checkpointing = False + self.module_parameters_dict.clear() + self.sharding_dict.clear() + self.parent_dict.clear() + self.parent_list = ["Global"] + self.module_helper_dict.clear() + self.module_helper_dict["Global"] = {"depth": 0} + self.module_parents_dict.clear() + self.module_parents_dict["Global"] = set() + self._fw_pre_handle = register_module_forward_pre_hook(self._fw_pre_hook) + self._fw_post_handle = register_module_forward_hook(self._fw_post_hook) + self.register_forward_hook_handles.clear() + self._bw_handle = register_module_full_backward_pre_hook(self._bw_hook) + self.name = "Global" + + def __exit__(self, *args): + super().__exit__(*args) + self._bw_handle.remove() + + # removes all forward_hook handles added in the pre-hook + for handle in self.register_forward_hook_handles.values(): + handle.remove() + + def print_paramater_info(self): + print(self.module_parameters_dict) + + def print_sharding_info(self): + for key, value in self.sharding_dict.items(): + print(key + ": " + str(value)) + + +class CommDebugMode(TorchDispatchMode): + """ + :class:`CommDebugMode` is a context manager that counts the number of + functional collectives within its context. It does this using a + ``TorchDispatchMode``. + + .. note:: Not all collectives are supported yet. + + Example usage + + .. code-block:: python + + mod = ... + comm_mode = CommDebugMode() + with comm_mode: + mod.sum().backward() + print(comm_mode.get_comm_counts()) + """ + + def __init__(self): + self.comm_counts: dict[Any, int] = defaultdict(int) + self.comm_module_counts = {} + self.comm_module_operation_counts = {} + self.comm_registry = set() + for native_op, py_op in NATIVE_TO_PY_MAPPING.items(): + self.comm_registry.add(native_op) + self.comm_registry.add(py_op) + + self.comm_registry.add(torch.ops._dtensor.shard_dim_alltoall) + self.advanced_module_tracker = _CommModeModuleTracker() + + def generate_json_dump(self, file_name="comm_mode_log.json", noise_level=3): + """ + Creates json file used to build browser visual + 0. prints module-level collective counts + 1. prints dTensor operations not included in trivial operations + 2. prints operations not included in trivial operations + 3. prints all operations + """ + + ( + include_DTensor_ops, + include_module_data, + include_ops, + include_trivial_ops, + ) = self._set_noise_parameters(noise_level) + + # recursively builds json data + def add_json_information(json_dict, fqn): + json_dict["fqn"] = fqn + json_dict["module_type"] = "" + json_dict["parameters"] = [] + json_dict["children"] = [] + json_dict["collectives_forward"] = [] + json_dict["collectives_backward"] = [] + json_dict["operations_forward"] = [] + json_dict["operations_backward"] = [] + + # adds module layer type and parameters, and their sharding + if ( + "module_type" in self.advanced_module_tracker.module_helper_dict[fqn] + and include_module_data + ): + json_dict["module_type"] = ( + self.advanced_module_tracker.module_helper_dict[fqn]["module_type"] + ) + + if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]: + for ( + param_name, + placement, + ) in self.advanced_module_tracker.module_helper_dict[fqn][ + "parameters" + ].items(): + json_dict["parameters"].append((param_name, placement)) + + # adds module collective information + if fqn in self.comm_module_counts: + for collective, count in self.comm_module_counts[fqn][ + "forward" + ].items(): + json_dict["collectives_forward"].append((str(collective), count)) + + for collective, count in self.comm_module_counts[fqn][ + "backward" + ].items(): + json_dict["collectives_backward"].append((str(collective), count)) + + # adds module operation information + forward_operations = [] + backward_operations = [] + checkpointing_operations = [] + + # only get operations if the minimum operation noise level is set to true + if include_DTensor_ops: + if fqn in self.comm_module_operation_counts: + ( + forward_operations, + backward_operations, + checkpointing_operations, + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) + + # remove all operations who don't have DTensor inputs + if not include_ops: + forward_operations = [ + op for op in forward_operations if len(op["input_sharding"]) + ] + backward_operations = [ + op for op in backward_operations if len(op["input_sharding"]) + ] + checkpointing_operations = [ + op for op in checkpointing_operations if len(op["input_sharding"]) + ] + + # remove all operations in trivial operations set + if not include_trivial_ops: + forward_operations = [ + op + for op in forward_operations + if str(op["name"]) not in trivial_ops + ] + backward_operations = [ + op + for op in backward_operations + if str(op["name"]) not in trivial_ops + ] + checkpointing_operations = [ + op + for op in checkpointing_operations + if str(op["name"]) not in trivial_ops + ] + + # converts operation information into string format for json.dumps() + forward_operations = copy.deepcopy(forward_operations) + for op in forward_operations: + op["name"] = str(op["name"]) + + for i in range(len(op["input_sharding"])): + op["input_sharding"][i] = str(op["input_sharding"][i]) + op["input_shape"][i] = str(op["input_shape"][i]) + + backward_operations = copy.deepcopy(backward_operations) + for op in backward_operations: + op["name"] = str(op["name"]) + + for i in range(len(op["input_sharding"])): + op["input_sharding"][i] = str(op["input_sharding"][i]) + op["input_shape"][i] = str(op["input_shape"][i]) + + checkpointing_operations = copy.deepcopy(checkpointing_operations) + for op in checkpointing_operations: + op["name"] = str(op["name"]) + + for i in range(len(op["input_sharding"])): + op["input_sharding"][i] = str(op["input_sharding"][i]) + op["input_shape"][i] = str(op["input_shape"][i]) + + json_dict["operations_forward"] = forward_operations + json_dict["operations_backward"] = backward_operations + json_dict["operations_checkpointing"] = checkpointing_operations + + if fqn not in self.advanced_module_tracker.parent_dict: + return json_dict + + # recursively adds module's children + for ele in self.advanced_module_tracker.parent_dict[fqn]: + json_dict["children"].append(add_json_information({}, ele)) + + return json_dict + + json_dict: dict[str, Any] = {} + add_json_information(json_dict, "Global") + + # converts dictonary into json file + with open(file_name, "w") as json_file: + json.dump(json_dict, json_file, indent=4) + + def generate_comm_debug_tracing_table(self, noise_level=3): + """ + Generates detailed table displaying operations and collective tracing information + on a module level. Amount of information is dependent on noise_level + + 0. prints module-level collective counts + 1. prints dTensor operations not included in trivial operations, module information + 2. prints operations not included in trivial operations + 3. prints all operations + """ + + ( + include_DTensor_ops, + include_module_data, + include_ops, + include_trivial_ops, + ) = self._set_noise_parameters(noise_level) + + table = "" + for fqn in self.advanced_module_tracker.module_helper_dict: + # setting up indentations for table formatting + indent = " " * ( + 2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + ) + table += f"{indent}{fqn}\n" + + if include_module_data: + if ( + "module_type" + in self.advanced_module_tracker.module_helper_dict[fqn] + ): + module_type = self.advanced_module_tracker.module_helper_dict[fqn][ + "module_type" + ] + table += f"{indent}*module type: {module_type}\n" + + if "parameters" in self.advanced_module_tracker.module_helper_dict[fqn]: + table += f"{indent}*Parameter List\n" + for ( + param_name, + placement, + ) in self.advanced_module_tracker.module_helper_dict[fqn][ + "parameters" + ].items(): + table += f"{indent} *{param_name}: {placement}\n" + + indent += " " + collective_indent = " " * ( + 2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 2 + ) + operation_indent = " " * ( + 2 * self.advanced_module_tracker.module_helper_dict[fqn]["depth"] + 3 + ) + + # separate the module's collective and operations by forward and backward + forward_collectives = {} + backward_collectives = {} + if fqn in self.comm_module_counts: + forward_collectives = self.comm_module_counts[fqn]["forward"] + backward_collectives = self.comm_module_counts[fqn]["backward"] + + forward_operations = [] + backward_operations = [] + checkpointing_operations = [] + + if include_DTensor_ops: + if fqn in self.comm_module_operation_counts: + ( + forward_operations, + backward_operations, + checkpointing_operations, + ) = self._get_operations_list( + self.comm_module_operation_counts[fqn] + ) + + def add_tracing_information(table, collectives_dict, operation_list): + """ + adds tracing information for module's forward or backward + """ + for collective, count in collectives_dict.items(): + table += ( + f"\033[1;33m{collective_indent}*{collective}: {count}\033[0m\n" + ) + + def add_operations( + table, operation, collective_indent, operation_indent + ): + """ + adds operation information to the table + """ + table += f"\033[1;33m{collective_indent}**{operation_name}\033[0m\n" + + if len(operation["input_shape"]): + operation_shape = operation["input_shape"] + operation_sharding = operation["input_sharding"] + operation_device_mesh = operation["device_mesh"] + + table += f"\033[1;31m{operation_indent}shape: {operation_shape}\033[0m\n" + table += f"\033[1;31m{operation_indent}sharding: {operation_sharding}\033[0m\n" + table += f"\033[1;31m{operation_indent}device mesh: {operation_device_mesh}\033[0m\n" + + return table + + for operation in operation_list: + operation_name = str(operation["name"]) + + # include all operations + if include_trivial_ops: + table = add_operations( + table, operation, collective_indent, operation_indent + ) + + # include all operations not in trivial operations + elif include_ops and operation_name not in trivial_ops: + table = add_operations( + table, operation, collective_indent, operation_indent + ) + + # only include dTensor operations not in trivial set + elif ( + include_DTensor_ops + and (operation_name not in trivial_ops) + and len(operation["input_shape"]) + ): + table = add_operations( + table, operation, collective_indent, operation_indent + ) + + return table + + if len(forward_collectives) or len(forward_operations): + table += f"{indent}FORWARD PASS\n" + table = add_tracing_information( + table, forward_collectives, forward_operations + ) + + if len(backward_collectives) or len(backward_operations): + table += f"{indent}BACKWARD PASS\n" + table = add_tracing_information( + table, backward_collectives, backward_operations + ) + + if len(checkpointing_operations): + table += f"{indent}ACTIVATION CHECKPOINTING\n" + table = add_tracing_information(table, {}, checkpointing_operations) + + return table + + def _get_operations_list(self, module_operation_counts): + forward_operations = [ + op for op in module_operation_counts["operations_list"] if not op["is_bw"] + ] + backward_operations = [ + op + for op in module_operation_counts["operations_list"] + if op["is_bw"] and not op["is_activation_checkpointing"] + ] + checkpointing_operations = [ + op + for op in module_operation_counts["operations_list"] + if op["is_activation_checkpointing"] + ] + + return forward_operations, backward_operations, checkpointing_operations + + def get_total_counts(self) -> int: + return sum(self.comm_counts.values()) + + def get_comm_counts(self) -> dict[Any, int]: + """Returns the communication counts as a dictionary. + + Returns: + Dict[Any, int]: The communication counts as a dictionary. + """ + return self.comm_counts + + def get_parameter_info(self) -> dict[str, dict[str, Any]]: + return self.advanced_module_tracker.module_parameters_dict + + def get_sharding_info(self) -> dict[str, dict[str, Any]]: + return self.advanced_module_tracker.sharding_dict + + def __enter__(self): + self.comm_counts.clear() + self.comm_module_counts.clear() + self.comm_module_counts["Global"] = {} + self.comm_module_counts["Global"]["forward"] = defaultdict(int) + self.comm_module_counts["Global"]["backward"] = defaultdict(int) + + self.comm_module_operation_counts.clear() + + super().__enter__() + self.advanced_module_tracker.__enter__() + return self + + def __exit__(self, *args): + self.advanced_module_tracker.__exit__() + super().__exit__(*args) + + def log_comm_debug_tracing_table_to_file( + self, file_name="comm_mode_log.txt", noise_level=3 + ): + """ + Alternative to console CommDebugMode output, writes to file specified by the user + """ + ansi_escape = re.compile(r"\x1B\[[0-?]*[ -/]*[@-~]") + table = ansi_escape.sub("", self.generate_comm_debug_tracing_table(noise_level)) + + with open(file_name, "w") as log_file: + log_file.write(table) + + def _set_noise_parameters(self, noise_level): + """ + sets variables controlling what information displays based on noise level + """ + include_DTensor_ops = False + include_module_data = False + include_ops = False + include_trivial_ops = False + + if noise_level > 0: + include_DTensor_ops = True + include_module_data = True + + if noise_level > 1: + include_ops = True + + if noise_level > 2: + include_trivial_ops = True + + return ( + include_DTensor_ops, + include_module_data, + include_ops, + include_trivial_ops, + ) + + def __torch_dispatch__(self, func, types, args=(), kwargs=None): + # When running this mode with DTensor, ordinarily all modes will + # run **before** subclasses get a chance to run. + # Returning NotImplemented here gives us a chance to let DTensor + # run and desugar into comms ops, before CommDebugMode sees them. + + # sets up operation-level collective count + if self.advanced_module_tracker.name not in self.comm_module_operation_counts: + # dictionary should hold module input and output shape, operations list and collective counter + self.comm_module_operation_counts[self.advanced_module_tracker.name] = { + "operations_list": [] + } + operation_dict = {} + operation_dict["name"] = func + + operation_dict["input_shape"] = [] + operation_dict["input_sharding"] = [] + operation_dict["device_mesh"] = "" + + # tracks if the operation is part of the backward pass + operation_dict["is_bw"] = self.advanced_module_tracker.is_bw + + # tracks if the operation is part of activation checkpointing + operation_dict["is_activation_checkpointing"] = ( + self.advanced_module_tracker.activation_checkpointing + ) + + if any(t == DTensor for t in types): + for ele in args: + if isinstance(ele, DTensor): + # saves shapes and placements of all DTensor args + operation_dict["input_shape"].append(ele.shape) + operation_dict["input_sharding"].append(ele.placements) + operation_dict["device_mesh"] = str(ele.device_mesh) + + self.comm_module_operation_counts[self.advanced_module_tracker.name][ + "operations_list" + ].append(operation_dict) + + return NotImplemented + + kwargs = kwargs if kwargs else {} + out = func(*args, **kwargs) + func_packet = func._overloadpacket + + # We have many tests that use CommDebugMode to verify the occurrence of + # collectives. These tests do so by querying comm_counts with legacy + # funcol ops as key. For the purpose of native funcol migration, we + # need these tests to work for both legacy and native funcol. To avoid + # the need to modify all tests to accommodate the two implementations, + # we make CommDebugMode translate native funcol ops into legacy funcol + # ops until the migration finishes. + + if func_packet in self.comm_registry or func_packet in c10d_collective_ops: + if func_packet in NATIVE_TO_PY_MAPPING: + func_packet = NATIVE_TO_PY_MAPPING[func_packet] + self.comm_counts[func_packet] += 1 + + key = "forward" + if self.advanced_module_tracker.is_bw: + key = "backward" + + # adds collective count to current module + if self.advanced_module_tracker.name not in self.comm_module_counts: + self.comm_module_counts[self.advanced_module_tracker.name] = {} + self.comm_module_counts[self.advanced_module_tracker.name][ + "forward" + ] = defaultdict(int) + self.comm_module_counts[self.advanced_module_tracker.name][ + "backward" + ] = defaultdict(int) + self.comm_module_counts[self.advanced_module_tracker.name][key][ + func_packet + ] += 1 + + # adds collective count to parent modules + for par in self.advanced_module_tracker.module_parents_dict[ + self.advanced_module_tracker.name + ]: + # makes sure we aren't double counting when current sub-module hasn't been removed from parents + if par != self.advanced_module_tracker.name: + if par not in self.comm_module_counts: + self.comm_module_counts[par] = {} + self.comm_module_counts[par]["forward"] = defaultdict(int) + self.comm_module_counts[par]["backward"] = defaultdict(int) + self.comm_module_counts[par][key][func_packet] += 1 + + # if tensor op uses fake tensors, return + if detect_fake_mode(args): + return out + + # add tensor operation to module operation list + self.comm_module_operation_counts[self.advanced_module_tracker.name][ + "operations_list" + ].append(operation_dict) + + return out diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_op_coverage.py b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_op_coverage.py new file mode 100644 index 0000000000000000000000000000000000000000..0013d18ca8b8e376dcc8e5c56064a231907ca5d8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_op_coverage.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +from operator import itemgetter + +import torch +import torch.fx +import torch.nn as nn +from functorch.compile import make_boxed_func +from torch._functorch.compilers import aot_module +from torch._inductor.decomposition import select_decomp_table +from torch.distributed.tensor import DTensor + + +inductor_decomps = select_decomp_table() + +graphs: list[torch.fx.GraphModule] = [] + + +def fwd_bwd_compiler(fx_g, _): + graphs.append(fx_g) + return make_boxed_func(fx_g) + + +def get_inductor_decomp_graphs(model: nn.Module, args, kwargs): + """ + Obtain forward and backward graphs of a model with inductor decompositions using tracing and aot_module. + + Convenient util to get the fwd and bwd graphs of an arbitrary model + with inductor decompositions. Note that this would simply do tracing + with aot_module and don't ensure correctness. This is useful to track + the ops needed in DTensor. + """ + compiled_mod = aot_module( + model, fw_compiler=fwd_bwd_compiler, decompositions=inductor_decomps + ) + output = compiled_mod(*args, **kwargs) + + if output.ndim != 0: + # if output is not a scalar tensor, by default sum it in order to + # run backward + output = output.sum() + + output.backward() + + # one fwd, one bwd graph + assert len(graphs) == 2 + return graphs + + +def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=False): + """ + Util to print the operator coverage summary of a certain model with tabulute. + + Must have tabulate module installed. + """ + # python module required for summary + import csv + + from tabulate import tabulate + + fwd_graph, bwd_graph = get_inductor_decomp_graphs(model, args, kwargs) + + op_counts = {} + + for node in fwd_graph.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + if node.target not in op_counts: + op_counts[node.target] = 0 + + op_counts[node.target] += 1 + + for node in bwd_graph.graph.nodes: + if node.op == "call_function" and isinstance( + node.target, torch._ops.OpOverload + ): + if node.target not in op_counts: + op_counts[node.target] = 0 + + op_counts[node.target] += 1 + + op_infos = [] + + for op, count in op_counts.items(): + supported = op in DTensor._op_dispatcher.sharding_propagator.op_to_rules + op_infos.append([op, str(op._schema), count, supported]) + + # sort the op info base on the total count index + count_idx = 2 + op_infos.sort(key=itemgetter(count_idx), reverse=True) + + headers = ["Operator", "Schema", "Total Count", "Supported"] + print(tabulate(op_infos, headers=headers)) + + if output_csv: + # Open a CSV file for writing + with open("op_summary.csv", "w", newline="") as csv_file: + # Create a CSV writer object + csv_writer = csv.writer(csv_file) + + csv_writer.writerow(headers) + # Write each table row to the CSV file + for row in op_infos: + csv_writer.writerow(row) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_visualize_sharding.py b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_visualize_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..b678840c069a0e71f6dda40684d653ef950f448a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/debug/_visualize_sharding.py @@ -0,0 +1,227 @@ +# mypy: allow-untyped-defs +import importlib.util + +import numpy as np + +from torch._prims_common import ShapeType +from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset + + +__all__ = ["visualize_sharding"] + +Color = tuple[float, float, float] + + +def _create_table( + shards: list[tuple[tuple[int, int], tuple[int, int], int]], device_kind: str = "" +): + """ + Creates a tabulate table given row and column ranges with device name + """ + from tabulate import tabulate + + # Extract unique row and column ranges + row_ranges = sorted({block[0] for block in shards}) + col_ranges = sorted({block[1] for block in shards}) + + # Create a matrix initialized with empty strings + matrix = [["" for _ in col_ranges] for _ in row_ranges] + + # Fill the matrix with values + for block in shards: + row_index = row_ranges.index(block[0]) + col_index = col_ranges.index(block[1]) + if matrix[row_index][col_index] == "": + matrix[row_index][col_index] = device_kind + ":" + str(block[2]) + else: + matrix[row_index][col_index] += "," + str(block[2]) + + # Prepare headers + row_headers = [f"Row {r[0]}-{r[1]}" for r in row_ranges] + col_headers = [f"Col {c[0]}-{c[1]}" for c in col_ranges] + + return tabulate(matrix, headers=col_headers, showindex=row_headers) + + +def make_color_iter(color_map, num_rows, num_cols): + num_colors = num_rows * num_cols + for idx in range(num_colors): + yield color_map(idx) + + +def _canonicalize_color(color: Color) -> str: + if isinstance(color, str): + return color + r, g, b = (int(a * 255) for a in color) + return f"#{r:02X}{g:02X}{b:02X}" + + +def _get_text_color(color: str) -> str: + r, g, b = map(lambda x: int(x, 16), (color[1:3], color[3:5], color[5:7])) # noqa: C417 + if (r * 0.299 + g * 0.587 + b * 0.114) > 186: + return "#000000" + return "#ffffff" + + +def _create_rich_table( + shape: ShapeType, + shards: list[tuple[tuple[int, int], tuple[int, int], int]], + device_kind: str = "", + scale: float = 1.0, + min_width: int = 9, + max_width: int = 80, +): + import matplotlib + import rich.align + import rich.box + import rich.console + import rich.padding + import rich.style + import rich.table + + dtensor_height = shape[0] + dtensor_width = shape[1] if len(shape) == 2 else 1 + + row_ranges = sorted({s[0] for s in shards}) + col_ranges = sorted({s[1] for s in shards}) + num_rows, num_cols = len(row_ranges), len(col_ranges) + + console = rich.console.Console(width=max_width) + use_color = console.color_system + color_iter = make_color_iter(matplotlib.colormaps["tab20b"], num_rows, num_cols) + + base_height = int(10 * scale) + aspect_ratio = (shape[1] if len(shape) == 2 else 1) / shape[0] + base_width = int(base_height * aspect_ratio) + height_to_width_ratio = 2.5 + + table = rich.table.Table( + show_header=False, + show_lines=not use_color, + padding=0, + highlight=not use_color, + pad_edge=False, + box=rich.box.SQUARE if not use_color else None, + ) + for row in range(num_rows): + table_row = [] + for col in range(num_cols): + entry = ( + device_kind + + ":" + + ",".join( + [ + str(device_id) + for row_range, col_range, device_id in shards + if row_range == row_ranges[row] and col_range == col_ranges[col] + ] + ) + ) + width = (col_ranges[col][1] - col_ranges[col][0]) / dtensor_width + width = int(width * base_width * height_to_width_ratio) + height = (row_ranges[row][1] - row_ranges[row][0]) / dtensor_height + height = int(height * base_height) + left_padding, remainder = divmod(width - len(entry) - 2, 2) + right_padding = left_padding + remainder + top_padding, remainder = divmod(height - 2, 2) + bottom_padding = top_padding + remainder + if use_color: + color = _canonicalize_color(next(color_iter)[:3]) + text_color = _get_text_color(color) + top_padding += 1 + bottom_padding += 1 + left_padding += 1 + right_padding += 1 + else: + color = None + text_color = None + padding = ( + max(top_padding, 0), + max(right_padding, 0), + max(bottom_padding, 0), + max(left_padding, 0), + ) + table_row.append( + rich.padding.Padding( + rich.align.Align(entry, "center", vertical="middle"), + padding, + style=rich.style.Style(bgcolor=color, color=text_color), + ) + ) + table.add_row(*table_row) + console.print(table, end="\n\n") + + +def visualize_sharding(dtensor, header="", use_rich: bool = False): + """ + Visualizes sharding in the terminal for :class:`DTensor` that are 1D or 2D. + + .. note:: This requires the ``tabulate`` package, or ``rich`` and ``matplotlib``. + No sharding info will be printed for empty tensors + """ + if dtensor.numel() == 0: # Do not print empty dtensors. + return + + if len(dtensor.shape) >= 3: + raise RuntimeError("visualize sharding supports only 1D or 2D DTensor") + + if dtensor.device_mesh.get_coordinate() is None: # current rank is not in the mesh + return + + # Only display the visualization once for each DTensor, on the rank whose + # coordinate is 0 on all dimensions. For example, if the mesh is a full mesh, + # we will only print on rank 0. + local_rank_zero_on_all_dim = all( + dtensor.device_mesh.get_local_rank(mesh_dim=dim) == 0 + for dim in range(dtensor.device_mesh.ndim) + ) + if not local_rank_zero_on_all_dim: + return + + device_coords = { + int(device_index.item()): list(coord) + for coord, device_index in np.ndenumerate( + np.array(dtensor.device_mesh.mesh.tolist()) + ) + } + + device_shard_shape_and_offsets = { + device_index: _compute_local_shape_and_global_offset( + dtensor.shape, + dtensor.device_mesh.shape, + device_coords[device_index], + dtensor.placements, + ) + for device_index in device_coords + } + + # Extend shards in a 1D tensor to 2D + device_shard_shape_and_offsets = { + device_index: ( + shape if len(shape) == 2 else (shape[0], 1), + offset if len(offset) == 2 else (offset[0], 0), + ) + for device_index, (shape, offset) in device_shard_shape_and_offsets.items() + } + + shards = [ + ( + (offset[0], offset[0] + shape[0] - 1), + (offset[1], offset[1] + shape[1] - 1), + device_index, + ) + for device_index, (shape, offset) in device_shard_shape_and_offsets.items() + ] + + if ( + importlib.util.find_spec("rich") + and importlib.util.find_spec("matplotlib") + and use_rich + ): + _create_rich_table( + dtensor.shape, shards, device_kind=dtensor.device_mesh.device_type + ) + elif importlib.util.find_spec("tabulate"): + print(_create_table(shards, device_kind=dtensor.device_mesh.device_type)) + else: + raise ValueError("`visualize_sharding` requires either `rich` or `tabulate`.") diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/device_mesh.py b/phivenv/Lib/site-packages/torch/distributed/tensor/device_mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..e629de83865ddcf5223d9fda7ba7111acc35a52b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/device_mesh.py @@ -0,0 +1,9 @@ +from torch.distributed.device_mesh import ( # noqa: F401 + _get_device_handle, + _mesh_resources, + DeviceMesh, + init_device_mesh, +) + + +__all__ = ["init_device_mesh", "DeviceMesh"] diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__init__.py b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5b11181deac25952dee45059fb5933397244fba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Iterator +from contextlib import contextmanager + +from torch.distributed.tensor._api import DTensor +from torch.distributed.tensor.experimental._attention import context_parallel +from torch.distributed.tensor.experimental._func_map import local_map +from torch.distributed.tensor.experimental._register_sharding import register_sharding + + +__all__ = ["context_parallel", "implicit_replication", "local_map", "register_sharding"] + + +@contextmanager +def implicit_replication() -> Iterator[None]: + """ + This context manager allows :class:`DTensor` to implicitly treat all non-DTensors (``torch.Tensor``) + in the program be replicate :class:`DTensor` s during the operator computation. + + .. warning:: This might possible lead to incorrect results if ``torch.Tensor`` s are not replicated + in practice, please use it at your discretion. + """ + try: + DTensor._op_dispatcher._allow_implicit_replication = True + yield + finally: + DTensor._op_dispatcher._allow_implicit_replication = False + + +# Set namespace for exposed private names +context_parallel.__module__ = "torch.distributed.tensor.experimental" +implicit_replication.__module__ = "torch.distributed.tensor.experimental" +local_map.__module__ = "torch.distributed.tensor.experimental" +register_sharding.__module__ = "torch.distributed.tensor.experimental" diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02719c983d9344325764640238ba454ae4fa8469 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_attention.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_attention.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eaa063108815985d52c678bc0d0d364201cb038d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_attention.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_func_map.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_func_map.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0db4a98c54012e6973bf7730458bba79010ea79c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_func_map.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_register_sharding.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_register_sharding.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..110cd06d3801fb88c0ce07e042940fa6a8cafcbc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_register_sharding.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_tp_transform.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_tp_transform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d67053eff92094a26f0751ee945e8d1d46ad7c64 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/__pycache__/_tp_transform.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_attention.py b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3961f7e44c2f5c2c2f549c93ca3bbe61af6649e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_attention.py @@ -0,0 +1,1458 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates + +import contextlib +import itertools +import logging +import types +import weakref +from abc import ABC, abstractmethod +from collections.abc import Generator +from dataclasses import dataclass +from enum import auto, Enum +from typing import Any, Callable, Optional, Protocol, Union + +import torch +import torch.distributed as dist +import torch.distributed._functional_collectives as ft_c +import torch.nn.functional as F +from torch import nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import distribute_module, DTensor, Replicate, Shard +from torch.distributed.tensor.parallel.style import ParallelStyle +from torch.overrides import TorchFunctionMode + + +__all__ = ["context_parallel", "set_rotate_method"] + + +class _CausalBehavior(Enum): + SKIP = None + NOT_IS_CAUSAL = False + IS_CAUSAL = True + + +class _RotateMethod(Enum): + ALL_TO_ALL = auto() + ALL_GATHER = auto() + + +aten = torch.ops.aten +logger = logging.getLogger(__name__) + + +class _DispatchMode(Enum): + MONKEY_PATCH = auto() + TORCH_FUNCTION = auto() + TORCH_DISPATCH = auto() + + +_dispatch_mode: _DispatchMode = _DispatchMode.MONKEY_PATCH + + +@dataclass +class _ContextParallelOptions: + # Whether to upcast parameters and gradients to float32 to avoid accumulation + # errors. It is likely this is always True but we currently keep this variable + # for the experimental purpose. + convert_to_f32: bool = True + enable_load_balance = True + rotate_method: _RotateMethod = _RotateMethod.ALL_GATHER + + +_cp_options = _ContextParallelOptions() + + +def _is_causal_behavior( + rank: int, world_size: int, i: int, is_causal: bool +) -> _CausalBehavior: + """ + Calculate is_causal behavior for each KV block. The attention can either be + calculated in full, not at all or with the causal mask applied. + """ + if not is_causal: + return _CausalBehavior.NOT_IS_CAUSAL + + if i == 0: + return _CausalBehavior.IS_CAUSAL + + source_rank = (rank - i) % world_size + if source_rank < rank or _cp_options.enable_load_balance: + return _CausalBehavior.NOT_IS_CAUSAL + else: + return _CausalBehavior.SKIP + + +def _maybe_wait(tensor: torch.Tensor) -> torch.Tensor: + """ + When tracing the code, the result tensor is not an AsyncCollectiveTensor, + so we cannot call ``wait()``. + """ + if isinstance(tensor, ft_c.AsyncCollectiveTensor): + return tensor.wait() + return tensor + + +def _partial_update( + original: torch.Tensor, + new: torch.Tensor, + dim: int, + n_chunks: int, + idx: int, + add: bool, +) -> torch.Tensor: + """ + This API partially update a chunk of ``original`` tensor. The ``original`` + tensor will be first chunked along ``dim`` dimension then the ``idx`` chunk + will be updated with ``new``. If ``add`` is True, the chunk will be added + with ``new``, otherwise the chunk with be replaced by ``add``. + + The result is a tensor that is the same size as ``original``. + """ + chunks = list(original.chunk(n_chunks, dim=dim)) + assert chunks[idx].shape == new.shape, (original.shape, new.shape, idx) + if add: + chunks[idx] += new + else: + chunks[idx] = new + return torch.cat(chunks, dim=dim) + + +class _SDPAMerger: + """A class to help to merge the local SDPA result.""" + + def __init__(self, convert_to_f32: bool, seq_dim: int): + self._seq_dim = seq_dim + self._out: Optional[torch.Tensor] = None + self._lse: Optional[torch.Tensor] = None + self._convert_to_f32 = convert_to_f32 + self._out_dtype = torch.float32 + self._lse_dtype = torch.float32 + + def _merge_one( + self, block_out: torch.Tensor, block_lse: torch.Tensor, partial: bool + ) -> None: + block_lse = block_lse.unsqueeze(dim=-1) + if self._lse is None: + self._lse = block_lse + self._out = block_out + else: + ROUND_ROBIN_CYCLE = 2 + assert self._lse is not None + assert self._out is not None + lse = ( + self._lse.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._lse + ) + out = ( + self._out.chunk(ROUND_ROBIN_CYCLE, dim=self._seq_dim)[1] + if partial + else self._out + ) + + # The algorithm from + # github.com/zhuzilin/ring-flash-attention/pull/34#issuecomment-2076126795 + # gives a relatively stable result. + out = out - F.sigmoid(block_lse - lse) * (out - block_out) + lse = lse - F.logsigmoid(lse - block_lse) + if partial: + self._lse = _partial_update( + self._lse, + lse, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + self._out = _partial_update( + self._out, + out, + dim=self._seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=False, + ) + else: + self._lse = lse + self._out = out + + def step(self, out: torch.Tensor, lse: torch.Tensor, partial: bool) -> None: + self._out_dtype = out.dtype + self._lse_dtype = lse.dtype + + if self._convert_to_f32: + out = out.to(torch.float32) + lse = lse.to(torch.float32) + + self._merge_one(out, lse, partial) + + def results(self) -> tuple[torch.Tensor, torch.Tensor]: + assert self._out is not None + assert self._lse is not None + out, lse = self._out, self._lse.squeeze(-1) + return out.to(self._out_dtype), lse.to(self._lse_dtype) + + +def _scaled_dot_product_ring_flash_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if return_debug_mask: + raise NotImplementedError("return_debug_mask is not supported yet") + + seq_dim = 2 + return _templated_ring_attention( + mesh, + seq_dim, + aten._scaled_dot_product_flash_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + dropout_p=dropout_p, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + seq_dim = 2 + return _templated_ring_attention( + mesh, + seq_dim, + aten._scaled_dot_product_efficient_attention, + query=query, + key=key, + value=value, + is_causal=is_causal, + attn_bias=attn_bias, + dropout_p=dropout_p, + scale=scale, + compute_log_sumexp=compute_log_sumexp, + ) + + +def _scaled_dot_product_ring_cudnn_attention( + mesh: DeviceMesh, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attn_bias: Optional[torch.Tensor] = None, + compute_log_sumexp: bool = True, + dropout_p: float = 0.0, + is_causal: bool = False, + return_debug_mask: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + if attn_bias is not None: + raise NotImplementedError("attn_bias is not supported yet") + + if not compute_log_sumexp: + # CP requires compute_log_sumexp to be True because it always merges LSE + compute_log_sumexp = True + + seq_dim = 2 + return _templated_ring_attention( + mesh, + seq_dim, + aten._scaled_dot_product_cudnn_attention, + query=query, + key=key, + value=value, + attn_bias=attn_bias, + compute_log_sumexp=compute_log_sumexp, + dropout_p=dropout_p, + is_causal=is_causal, + return_debug_mask=return_debug_mask, + scale=scale, + ) + + +class _AttentionOp(Protocol): + def __call__( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + **kwargs: object, + ) -> tuple[torch.Tensor, ...]: ... + + +class _RingRotater(ABC): + @abstractmethod + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: ... + + @abstractmethod + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: ... + + @abstractmethod + def next_buffer(self) -> torch.Tensor: ... + + +class _AllToAllRotater(_RingRotater): + """Use all_to_all to send the kv to the next rank""" + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._buffer: Optional[torch.Tensor] = None + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + curr_buffer = curr_buffer.contiguous() + size = dist.get_world_size(self._pg) + dsts = list(range(1, size)) + [0] + self._buffer = ft_c.permute_tensor(curr_buffer, dsts, self._pg) + + def next_buffer(self) -> torch.Tensor: + assert self._buffer is not None + return _maybe_wait(self._buffer) + + +class _AllGatherRotater(_RingRotater): + """ + Allgather the kv and return the only the requried kv. + Only one communication will be done. + """ + + def __init__(self, pg: dist.ProcessGroup, seq_dim: int) -> None: + self._pg = pg + self._seq_dim = seq_dim + self._aggregated_buffer: Optional[torch.Tensor] = None + self._idx = 0 + + def exchange_buffers(self, curr_buffer: torch.Tensor) -> None: + # We only need to perform the allgather once. + self._idx += 1 + if self._aggregated_buffer is None: + self._aggregated_buffer = ft_c.all_gather_tensor( + curr_buffer.contiguous(), gather_dim=0, group=self._pg + ) + + def next_buffer(self) -> torch.Tensor: + rank = dist.get_rank(self._pg) + idx = rank - self._idx + + assert self._aggregated_buffer is not None + self._aggregated_buffer = _maybe_wait(self._aggregated_buffer) + return self._aggregated_buffer.chunk(dist.get_world_size(self._pg))[idx] + + +def _create_rotater( + pg: dist.ProcessGroup, seq_dim: int, method: Optional[_RotateMethod] = None +) -> _RingRotater: + if method is None: + method = _cp_options.rotate_method + + if method == _RotateMethod.ALL_TO_ALL: + return _AllToAllRotater(pg, seq_dim) + elif method == _RotateMethod.ALL_GATHER: + return _AllGatherRotater(pg, seq_dim) + else: + raise NotImplementedError(f"Unkonwn method {method}") + + +def _ring_rotate( + block: torch.Tensor, pg: dist.ProcessGroup, send_to_next: bool +) -> torch.Tensor: + block = block.contiguous() + size = dist.get_world_size(pg) + dsts = ( + list(range(1, size)) + [0] + if send_to_next + else [size - 1] + list(range(0, size - 1)) + ) + return ft_c.permute_tensor(block, dsts, pg) + + +def _templated_ring_attention( + mesh: DeviceMesh, + seq_dim: int, + op: _AttentionOp, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + is_causal: bool = False, + **kwargs: object, +) -> tuple[torch.Tensor, ...]: + """ + This is a generalized ring attention implementation that can support multiple attention ops. + + Note [Context parallelism load balance algorithm for causal masking] + ===================== + This explanation uses an example to illustrate the CP algorithm with causal + masking. + + Consider a scenario where the sequence length of q, k, and v is 4 (e.g., + q = (q0, q1, q2, q3)), and there are two ranks. For simplicity, we will discuss + only q and k, as v follows the same pattern as k. + + The diagram below represents a complete QK^T operation without parallelism. + The `****` entries indicate that the result is not required due to causal + masking (e.g., q0k1 is marked as `****`). + + +----+------------------------+ + | | k0 k1 k2 k3 | + +----+------------------------+ + | q0 | q0k0, ****, ****, **** | + | q1 | q1k0, q1k1, ****, **** | + | q2 | q2k0, q2k1, q2k2, **** | + | q3 | q3k0, q3k1, q3k2, q3k3 | + +----+------------------------+ + + ### No Load Balance: + + In this scenario, each rank owns a local chunk of q, k, and v, with each chunk + containing two elements. Rank0 is responsible for managing (q0, q1) and (k0, k1), + while rank1 manages (q2, q3) and (k2, k3). + + First Iteration: Both rank0 and rank1 perform SDPA with their local qkv pairs. + Causal masking is enabled as some results are not required (e.g., q0k1). + + Second Iteration: Local queries remain the same, but local kv pairs are exchanged. + Rank0 now has (q0, q1) and (k2, k3); rank1 has (q2, q3) and (k0, k1). Rank0 performs + no computation, while rank1 computes locally without causal masking since all results + (q2k0, q2k1, q3k0, q3k1) are needed. + + ### Round-robin Load Balance: + + In this setup, each rank owns two local chunks of q, k, and v, with each chunk + containing one element. Rank0 manages (q0, q3) and (k0, k3); Rank1 manages (q1, q2) + and (k1, k2). Although the local chunks are not consecutive, they are concatenated to + enable SDPA to be performed in a single call for each step. Consequently, the chunk() + function may be required to prepare the correct q, k, and v configurations. + + First Iteration: Both ranks perform SDPA with their local qkv pairs, similar to the + no-load-balance case. This iteration corresponds to the `if` of the + (`if, `elif`, `else`) in the implemementation. + + Second Iteration: Rank0 now has (q0, q3) and (k1, k2); rank1 has (q1, q2) and + (k0, k3). For rank0, no computation is needed for q0. However, computations for + q3k1 and q3k2 are required, so only q3 is used for SDPA. This corresponds to the + `else` of the (`if`, `elif`, `else`) in the implemementation. + For rank1, k0 is not needed for q1 and q2, so only k3 is used for SDPA. This + corresponds to the `elif` of (`if`, `elif`, `else`) in the implementation. + + Parameters + ---------- + op: + The attention op to use + *args: + additional args are passed to the op + **kwargs: + additional kwargs are passed to the op + + Returns + ------- + out: + The merged attention output + softmax_lse: + The logsumexp of the merged attention output + """ + if is_causal and (query.size(2) != key.size(2)): + raise NotImplementedError( + "is_causal requires the same query and context sequence lengths" + ) + if not is_causal and _cp_options.enable_load_balance: + raise RuntimeError("Load balancing requires `is_causal=True`.") + + if isinstance(mesh, dist.ProcessGroup): + pg: Union[dist.ProcessGroup, list[dist.ProcessGroup]] = mesh + else: + pg = mesh.get_group() + assert isinstance(pg, dist.ProcessGroup), "process group must be single dimension" + rank = dist.get_rank(pg) + size = dist.get_world_size(pg) + + next_kv = None + + # Without making key and value contiguous(), the lose curve is bad. + # TODO(fegin): figure out why this is a requirement since SDPA does not have + # this requirement. + key = key.contiguous() + value = value.contiguous() + + sdpa_merger = _SDPAMerger(_cp_options.convert_to_f32, seq_dim=seq_dim) + + rest: list[Any] + out: torch.Tensor + logsumexp: torch.Tensor + + rotater = _create_rotater(pg, 2) + + for i in range(size): + if i > 0: + # Wait for the kv from the (cp_rank - 1) rank. + next_kv = rotater.next_buffer() + key = next_kv[: key.numel()].reshape(key.shape) + value = next_kv[key.numel() :].reshape(value.shape) + + if i < (size - 1): + # Send the k, v to the next rank + next_kv = torch.cat([key.flatten(), value.flatten()]) + next_kv = rotater.exchange_buffers(next_kv) + + is_causal_behavior = _is_causal_behavior( + rank=rank, world_size=size, i=i, is_causal=is_causal + ) + + # For a detailed understanding of the load balancing algorithm, see + # Note [Context parallelism load balance algorithm for causal masking] + if is_causal_behavior == _CausalBehavior.SKIP: + # If i > rank and load balancing is not turned on. + continue + + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # When local balance is enabled, we still need to do SDPA with + # the both local chunks of q, k, v for the first iteration. + q, k, v, partial = (query, key, value, False) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SPDA, with only the first local chunk of the k, v. + # Note that q, k, v, each contains two local chunks. + ROUND_ROBIN_CYCLE = 2 + q, k, v, partial = ( + query, + key.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + value.chunk(ROUND_ROBIN_CYCLE, dim=2)[0], + False, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SPDA with only the second half of the q, and update + # only the the second part of logsumexp. So partial is True. + # Note that q, k, v, each contains two chunks. + q, k, v, partial = query.chunk(2, dim=2)[1], key, value, True + + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + out, logsumexp, *rest = op( + q, + k, + v, + is_causal=is_causal_behavior.value, + **kwargs, + ) + sdpa_merger.step(out, logsumexp, partial) + + return *sdpa_merger.results(), *rest + + +def _sdpa_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # extract local tensor and sharding infos to a OpInfo + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + # sharding propagation + # TODO: remove the context parallel strategy from the default propagation + # rule. Either figure out how to dynamically enable it or just don't call + # propagate. + DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert not output_sharding.needs_redistribute, "inputs need to be redistributed" + + if op_call == aten._scaled_dot_product_flash_attention.default: + local_results = _scaled_dot_product_ring_flash_attention( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_efficient_attention.default: + local_results = _scaled_dot_product_ring_efficient_attention( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_cudnn_attention.default: + local_results = _scaled_dot_product_ring_cudnn_attention( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + else: + raise NotImplementedError( + "CP only supports flash attention and memory efficient attention now." + ) + + return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) + + +def _sdpa_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + # Redistribute grad_output tensor to the same placement as output tensor + args = list(args) + args = tuple(args) + + # extract local tensor and sharding infos to a OpInfo + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + logger.debug("Dispatching op_call: %s", op_info.schema) + + # sharding propagation + DTensor._op_dispatcher.sharding_propagator.propagate(op_info) + output_sharding = op_info.output_sharding + assert output_sharding is not None, "output sharding should not be None" + assert not output_sharding.needs_redistribute, "inputs need to be redistributed" + + if op_call == aten._scaled_dot_product_flash_attention_backward.default: + local_results = _scaled_dot_product_ring_flash_attention_backward( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_efficient_attention_backward.default: + local_results = _scaled_dot_product_ring_efficient_attention_backward( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + elif op_call == aten._scaled_dot_product_cudnn_attention_backward.default: + local_results = _scaled_dot_product_ring_cudnn_attention_backward( + op_info.compute_mesh, + *op_info.local_args, # type: ignore[arg-type] + **op_info.local_kwargs, # type: ignore[arg-type] + ) + else: + raise NotImplementedError(f"{op_call=}") + + return DTensor._op_dispatcher.wrap(local_results, output_sharding.output_spec) + + +def _templated_ring_attention_backward( + mesh: DeviceMesh, + seq_dim: int, + op: _AttentionOp, + grad_out: torch.Tensor, + grad_out_name: str, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + is_causal: bool, + **kwargs: Any, +) -> tuple[torch.Tensor, ...]: + """This API implements the backward of the ring attention.""" + if not is_causal and _cp_options.enable_load_balance: + raise RuntimeError("Load balancing requires `is_causal=True`.") + pg = mesh.get_group() + assert isinstance(pg, dist.ProcessGroup), "must be single dimension" + rank = dist.get_rank(pg) + size = dist.get_world_size(pg) + next_kv = None + next_grad_kv = None + rest: list[Any] + grad_query_, grad_key_, grad_value_ = None, None, None + + accum_dtype = torch.float32 if _cp_options.convert_to_f32 else query.dtype + grad_query = torch.zeros_like(query, dtype=accum_dtype) + grad_key = torch.zeros_like(key, dtype=accum_dtype) + grad_value = torch.zeros_like(value, dtype=accum_dtype) + + key = key.contiguous() + value = value.contiguous() + kv_rotater = _create_rotater(pg, 2) + dkv_rotater = _create_rotater(pg, 2, method=_RotateMethod.ALL_TO_ALL) + for i in range(size): + if i > 0: + # Wait for the kv from the (cp_rank - 1) rank. + buffer = kv_rotater.next_buffer() + pointer = 0 + key = buffer[pointer : pointer + key.numel()].reshape(key.shape) + pointer += key.numel() + value = buffer[pointer : pointer + value.numel()].reshape(value.shape) + pointer += value.numel() + + if i != size - 1: + # Send the kv to the next rank. + next_kv = torch.cat([key.flatten(), value.flatten()]) + kv_rotater.exchange_buffers(next_kv) + + is_causal_behavior = _is_causal_behavior( + rank=rank, world_size=size, i=i, is_causal=is_causal + ) + + if is_causal_behavior != _CausalBehavior.SKIP: + if i == 0 or (not _cp_options.enable_load_balance or not is_causal): + # We need to do SDPA with the full local q, k, v. + q, k, v, out_, dout, lse = (query, key, value, out, grad_out, logsumexp) + elif i <= rank: + # Round-robin load balancing case, and i <= rank. + # We need to do SPDA with only the first half of the k, v. + # Note that q, k, v, each contains two chunks. + q, k, v, out_, dout, lse = ( + query, + key.chunk(2, dim=seq_dim)[0], + value.chunk(2, dim=seq_dim)[0], + out, + grad_out, + logsumexp, + ) + else: + # Round-robin load balancing case, and i > rank. + # We need to do SPDA with only the second half of the q + # Note that q, k, v, each contains two chunks. + q, k, v, out_, dout, lse = ( + query.chunk(2, dim=seq_dim)[1], + key, + value, + out.chunk(2, dim=seq_dim)[1], + grad_out.chunk(2, dim=seq_dim)[1], + # Need to make logsumexp contiguous, otherwise there will + # be numerical error. + logsumexp.chunk(2, dim=seq_dim)[1].contiguous(), + ) + + kwargs[grad_out_name] = dout + # See https://github.com/pytorch/pytorch/blob/release/2.4/aten/src/ATen/native/native_functions.yaml#L14695 + # for the SDPA kernel definitions. + grad_query_, grad_key_, grad_value_, *rest = op( + query=q, + key=k, + value=v, + out=out_, + logsumexp=lse, + is_causal=is_causal_behavior.value, + **kwargs, + ) + else: + grad_query_ = torch.zeros_like(query, dtype=accum_dtype) + grad_key_ = torch.zeros_like(key, dtype=accum_dtype) + grad_value_ = torch.zeros_like(value, dtype=accum_dtype) + + ROUND_ROBIN_CYCLE = 2 + if i == 0: + grad_key += grad_key_ + grad_value += grad_value_ + else: + pointer = 0 + # Wait for the kv gradient from (cp_rank - 1) rank. + next_grad_kv = dkv_rotater.next_buffer() + grad_key = next_grad_kv[pointer : pointer + grad_key.numel()].reshape( + grad_key.shape + ) + pointer += grad_key.numel() + grad_value = next_grad_kv[pointer : pointer + grad_value.numel()].reshape( + grad_value.shape + ) + + if i <= rank and _cp_options.enable_load_balance: + grad_key = _partial_update( + grad_key, + grad_key_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + grad_value = _partial_update( + grad_value, + grad_value_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=0, + add=True, + ) + else: + grad_key += grad_key_ + grad_value += grad_value_ + + next_grad_kv = torch.cat([grad_key.flatten(), grad_value.flatten()]) + # Send the grad key, and grad value to the next rank. + dkv_rotater.exchange_buffers(next_grad_kv) + + if i <= rank or not _cp_options.enable_load_balance: + grad_query += grad_query_ + else: + grad_query = _partial_update( + grad_query, + grad_query_, + dim=seq_dim, + n_chunks=ROUND_ROBIN_CYCLE, + idx=1, + add=True, + ) + + assert grad_key_ is not None + assert grad_value_ is not None + grad_query = grad_query.to(query.dtype) + next_grad_kv = dkv_rotater.next_buffer().to(key.dtype) + grad_key = next_grad_kv[: grad_key.numel()].reshape(grad_key.shape) + grad_value = next_grad_kv[grad_key.numel() :].reshape(grad_value.shape) + return ( + grad_query, + grad_key, + grad_value, + *rest, + ) + + +def _scaled_dot_product_ring_flash_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + cum_seq_q: torch.Tensor, + cum_seq_k: torch.Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + seq_dim = 2 + return _templated_ring_attention_backward( + mesh, + seq_dim, + aten._scaled_dot_product_flash_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=logsumexp, + is_causal=is_causal, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout_p, + philox_seed=philox_seed, + philox_offset=philox_offset, + scale=scale, + ) + + +def _scaled_dot_product_ring_efficient_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + bias: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + dropout_p: float, + grad_input_mask: tuple[bool, ...], + is_causal: bool = False, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + seq_dim = 2 + return _templated_ring_attention_backward( + mesh, + seq_dim, + aten._scaled_dot_product_efficient_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out_", + query=query, + key=key, + value=value, + attn_bias=bias, + out=out, + logsumexp=logsumexp, + philox_seed=philox_seed, + philox_offset=philox_offset, + dropout_p=dropout_p, + grad_input_mask=grad_input_mask, + is_causal=is_causal, + scale=scale, + ) + + +def _scaled_dot_product_ring_cudnn_attention_backward( + mesh: DeviceMesh, + grad_out: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + out: torch.Tensor, + logsumexp: torch.Tensor, + philox_seed: torch.Tensor, + philox_offset: torch.Tensor, + attn_bias: torch.Tensor, + cum_seq_q: torch.Tensor, + cum_seq_k: torch.Tensor, + max_q: int, + max_k: int, + dropout_p: float, + is_causal: bool, + *, + scale: Optional[float] = None, +) -> tuple[torch.Tensor, ...]: + seq_dim = 2 + return _templated_ring_attention_backward( + mesh, + seq_dim, + aten._scaled_dot_product_cudnn_attention_backward.default, + grad_out=grad_out, + grad_out_name="grad_out", + query=query, + key=key, + value=value, + out=out, + logsumexp=logsumexp, + philox_seed=philox_seed, + philox_offset=philox_offset, + attn_bias=attn_bias, + cum_seq_q=cum_seq_q, + cum_seq_k=cum_seq_k, + max_q=max_q, + max_k=max_k, + dropout_p=dropout_p, + is_causal=is_causal, + scale=scale, + ) + + +customized_ops = { + aten._scaled_dot_product_flash_attention.default: _sdpa_handler, + aten._scaled_dot_product_flash_attention_backward.default: _sdpa_backward_handler, + aten._scaled_dot_product_efficient_attention.default: _sdpa_handler, + aten._scaled_dot_product_efficient_attention_backward.default: _sdpa_backward_handler, + aten._scaled_dot_product_cudnn_attention.default: _sdpa_handler, + aten._scaled_dot_product_cudnn_attention_backward.default: _sdpa_backward_handler, +} + + +_replaced_functions: dict[Callable, tuple[str, Callable]] = {} + + +def _distribute_function( + fn: Callable, + fn_module: types.ModuleType, + device_mesh: DeviceMesh, + input_fn: Optional[Callable] = None, + output_fn: Optional[Callable] = None, +) -> None: + """ + ``distribute_function`` is an experimental API that allows users to "distribute" + the inputs and outputs of a function. Similar to ``distribute_module``, this API + installs hooks to the ``fn`` to convert the inputs and outputs. There are two + major differences between ``distribute_function`` and ``distribute_module``. + First, a function does not have parammeters and buffers, as a result, + ``distribute_function`` itself won't convert any parameters/buffers but simply + install the input and output hooks. The tensor conversion will happen in the hooks. + Another difference is an nn.Module subclass can have several instances and each + instance be fed into ``distribute_module`` independently with affecting other + instance. On the other hand, function is a singleton object. So if a function + is distributed by ``distribute_function`` all subsequent calls to the function + will invoke the installed hooks. + + Args: + fn (Callable): the function to be distributed. + fn_module (types.ModuleType): the Python module that the function is declared. + e.g., if ``fn`` is ``torch.nn.functional.scaled_dot_product_attention``, + ``fn_module`` is ``torch.nn.functional``. + device_mesh (:class:`DeviceMesh`): the device mesh that will be used by the + input and output hooks to distribute the tensors. + input_fn (Optioinal[Callable]): the hook to distribute or convert the input + arguments of ``fn``. + output_fn (Optioinal[Callable]): the hook to distribute or convert the output + arguments of ``fn``. + """ + + def wrapper( + target_fn: Callable, input_fn: Optional[Callable], output_fn: Optional[Callable] + ) -> Callable: + def inner_fn(*args: tuple[Any, ...], **kwargs: dict[str, Any]) -> Any: + if input_fn is not None: + args, kwargs = input_fn(device_mesh, *args, **kwargs) + output = target_fn(*args, **kwargs) + if output_fn is not None: + output = output_fn(device_mesh, output) + return output + + return inner_fn + + global _replaced_functions + + if fn in _replaced_functions: + return + + wrapper_fn = wrapper(fn, input_fn, output_fn) + setattr(fn_module, fn.__name__, wrapper_fn) + _replaced_functions[wrapper_fn] = (fn.__name__, fn) + + +def _restore_function(fn: Callable, fn_module: types.ModuleType) -> None: + """Restore the function that is replaced by _distribute_function.""" + global _original_functions + global _wrapper_functions + + if fn not in _replaced_functions: + return + + original_name, original_fn = _replaced_functions[fn] + setattr(fn_module, original_name, original_fn) + + +@contextlib.contextmanager +def _enable_cp_dispatcher() -> Generator[None, None, None]: + """Enables DTensor dispatcher to dispatch SDPA to CP.""" + old_handlers = DTensor._op_dispatcher._custom_op_handlers + DTensor._op_dispatcher._custom_op_handlers = {**old_handlers, **customized_ops} + + yield + + DTensor._op_dispatcher._custom_op_handlers = old_handlers + + +class _AttentionContextParallel(ParallelStyle): + """ + Applies context parallel optimizations to the attention layer. + + This will work for nn.MultiHeadedAttention and custom attention layers that + call F.scaled_dotproduct_attention with a simliar signature. + + This expects the `forward` method consumes either: + + * a single tensor for self attention + * one argument for each of: query, key, value + + This currently only supports ring attention and the + SDPBackend.FLASH_ATTENTION backend. See sdpa_kernel. + + Non-flash attention backends will result in incorrect results. + """ + + # use a weakref dictionary to store context managers for each nn.Module + _CONTEXT_MANAGERS: "weakref.WeakKeyDictionary[nn.Module, Any]" = ( + weakref.WeakKeyDictionary() + ) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if not isinstance(device_mesh, DeviceMesh): + raise ValueError( + f"{type(device_mesh)} is not supported by {type(self)} yet." + ) + + if not device_mesh.ndim == 1: + raise ValueError + + return distribute_module( + module, + device_mesh, + input_fn=self._input_fn, # type: ignore[arg-type] + output_fn=self._output_fn, # type: ignore[arg-type] + ) + + @classmethod + def _input_fn( + cls, + module: nn.Module, + inputs: tuple[Union[torch.Tensor, int, float], ...], + device_mesh: DeviceMesh, + ) -> tuple[Union[torch.Tensor, int, float], ...]: + # TODO(d4l3k); this should be Shard(2), need to fix Linear layer rules + placement = [Replicate()] + + def backward_hook(grad: torch.Tensor) -> None: + if module in cls._CONTEXT_MANAGERS: + cls._CONTEXT_MANAGERS[module].__exit__(None, None, None) + del cls._CONTEXT_MANAGERS[module] + + # convert inputs to DTensor + inp = [] + for input in inputs: + if isinstance(input, torch.Tensor) and not isinstance(input, DTensor): + input = DTensor.from_local( + input.contiguous(), device_mesh, placement, run_check=False + ) + + if isinstance(input, torch.Tensor) and input.requires_grad: + input.register_hook(backward_hook) + + inp.append(input) + + manager = _enable_cp_dispatcher() + manager.__enter__() + cls._CONTEXT_MANAGERS[module] = manager + + return tuple(inp) + + @classmethod + def _output_fn( + cls, + module: nn.Module, + outputs: Union[torch.Tensor, tuple[Union[torch.Tensor, int, float], ...]], + device_mesh: DeviceMesh, + ) -> Union[ + Union[torch.Tensor, int, float], tuple[Union[torch.Tensor, int, float], ...] + ]: + cls._CONTEXT_MANAGERS[module].__exit__(None, None, None) + del cls._CONTEXT_MANAGERS[module] + + def backward_hook(grad: torch.Tensor) -> None: + if module not in cls._CONTEXT_MANAGERS: + manager = _enable_cp_dispatcher() + manager.__enter__() + cls._CONTEXT_MANAGERS[module] = manager + + # back to local tensor + out = [] + for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs: + output = output.to_local() if isinstance(output, DTensor) else output + + if isinstance(output, torch.Tensor) and output.requires_grad: + output.register_hook(backward_hook) + + out.append(output) + + if isinstance(outputs, torch.Tensor): + return out[0] + + return tuple(out) + + +@contextlib.contextmanager +def _context_parallel(seq_dim: int, mesh: DeviceMesh) -> Generator[None, None, None]: + """Replace SDPA with the CP-wrapped version and enable DTensor CP dispatcher.""" + + def attention_input_fn( + mesh: DeviceMesh, *args: tuple[Any, ...], **kwargs: dict[str, Any] + ) -> tuple[tuple[Any, ...], dict[str, Any]]: + placement = [Shard(seq_dim)] + all_args = [] + + for arg in itertools.chain(args, kwargs.values()): + if isinstance(arg, torch.Tensor) and not isinstance(arg, DTensor): + arg = DTensor.from_local(arg, mesh, placement, run_check=False) + + all_args.append(arg) + + new_args = tuple(all_args[0 : len(args)]) + new_kwargs = dict(zip(kwargs.keys(), all_args[len(args) :])) + return new_args, new_kwargs + + def attention_output_fn(mesh: DeviceMesh, outputs: Any) -> Any: + new_outputs = [] + for output in [outputs] if isinstance(outputs, torch.Tensor) else outputs: + output = output.to_local() if isinstance(output, DTensor) else output + new_outputs.append(output) + + if isinstance(outputs, torch.Tensor): + return new_outputs[0] + + return tuple(new_outputs) + + class DistributeFunction(TorchFunctionMode): + def __init__( + self, + fn: Callable, + device_mesh: DeviceMesh, + input_fn: Optional[Callable] = None, + output_fn: Optional[Callable] = None, + ): + self._device_mesh = device_mesh + self._input_fn = input_fn + self._output_fn = output_fn + self._fn = fn + + def __torch_function__( + self, + func: Callable, + types: Any, + args: tuple[Any, ...] = (), + kwargs: Optional[dict[str, Any]] = None, + ) -> Any: + kwargs = kwargs or {} + + if func != self._fn: + return func(*args, **kwargs) + + if self._input_fn is not None: + args, kwargs = self._input_fn(self._device_mesh, *args, **kwargs) + output = func(*args, **kwargs) + if self._output_fn is not None: + output = self._output_fn(self._device_mesh, output) + return output + + if _dispatch_mode == _DispatchMode.MONKEY_PATCH: + _distribute_function( + F.scaled_dot_product_attention, + F, + mesh, + attention_input_fn, + attention_output_fn, + ) + with _enable_cp_dispatcher(): + yield + _restore_function(F.scaled_dot_product_attention, F) + elif _dispatch_mode == _DispatchMode.TORCH_FUNCTION: + with DistributeFunction( + F.scaled_dot_product_attention, + mesh, + attention_input_fn, + attention_output_fn, + ): + with _enable_cp_dispatcher(): + yield + else: + raise NotImplementedError("torch dispatch mode is not supported yet.") + + +class _LoadBalancer(ABC): + @classmethod + @abstractmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: ... + + @classmethod + @abstractmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: ... + + +class _SequentialSharder(_LoadBalancer): + """ + This load balancer chunks the buffer into cp_world_size and rank0 gets + 0th shard, rank1 gets 1st shard, ... + So this doesn't have any load balancing effect when using the causal masking. + """ + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert buffer.size()[seq_dim] % mesh.size() == 0 + return buffer.chunk(mesh.size(), dim=seq_dim)[mesh.get_local_rank()] + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + buffer = buffer.contiguous() + all_buffers = [torch.empty_like(buffer) for _ in range(mesh.size())] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + return torch.cat(all_buffers, dim=seq_dim) + + +class _RoundRobinLoadBalancer(_LoadBalancer): + """ + This load balancer chunk the buffer into cp_world_size * ROUND_ROBIN_CYCLE + shards, and uses a round robin approach to achieve load balancing. + Since ROUND_ROBIN_CYCLE being 2 will achieve perfect load balancing for + causal masking, we assume ROUND_ROBIN_CYCLE is always 2 to simplify the + implementation. + """ + + ROUND_ROBIN_CYCLE = 2 + + @classmethod + def shard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert cls.ROUND_ROBIN_CYCLE == 2, ( + "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + ) + cp_world_size = mesh.size() + cp_rank = mesh.get_local_rank() + assert buffer.size()[seq_dim] % (cp_world_size * 2) == 0 + chunks = buffer.chunk(cp_world_size * 2, dim=seq_dim) + return torch.cat( + (chunks[cp_rank], chunks[cp_world_size * 2 - cp_rank - 1]), + dim=seq_dim, + ) + + @classmethod + def unshard( + cls, buffer: torch.Tensor, mesh: DeviceMesh, seq_dim: int + ) -> torch.Tensor: + assert cls.ROUND_ROBIN_CYCLE == 2, ( + "The current implementation only works if ROUND_ROBIN_CYCLE is 2." + ) + buffer = buffer.contiguous() + cp_world_size = mesh.size() + + all_buffers = [torch.empty_like(buffer) for _ in range(cp_world_size)] + ft_c.all_gather_inplace(all_buffers, buffer, mesh) + sliced_buffers = [sb for b in all_buffers for sb in b.chunk(2, dim=seq_dim)] + ordered_buffers = list(sliced_buffers) + for i, b in enumerate(sliced_buffers): + if i % 2 == 0: + ordered_buffers[i // 2] = b + else: + ordered_buffers[cp_world_size * 2 - (i // 2) - 1] = b + return torch.cat(ordered_buffers, dim=seq_dim) + + +def _context_parallel_buffers( + mesh: DeviceMesh, + buffers: list[torch.Tensor], + buffer_seq_dims: list[int], +) -> list[torch.Tensor]: + """Shard the buffers along the sequence dimensions according to CP rules.""" + new_buffers = [] + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) + for buffer, seq_dim in zip(buffers, buffer_seq_dims): + new_buffers.append(sharder.shard(buffer, mesh, seq_dim)) + + return new_buffers + + +@contextlib.contextmanager +@torch.no_grad() +def context_parallel( + mesh: DeviceMesh, + *, + buffers: Optional[list[torch.Tensor]] = None, + buffer_seq_dims: Optional[list[int]] = None, + no_restore_buffers: Optional[set[torch.Tensor]] = None, +) -> Generator[None, None, None]: + """ + + ``context_parallel`` is an experimental API to enable context + parallelism (CP). This API performs two actions: 1) patch the SDPA + (``torch.nn.functional.scaled_dot_product_attention``) with the CP-enabled + one, 2) shard ``buffers`` along the sequence dimension and each rank will + preserve the corresponding shard according ``mesh``. + + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (Optional[List[torch.Tensor]]): buffers that the usage depend + on the sequence dimension. Examples are input batch, labels and + positional embedding buffers. These buffers must be sharded along + the sequence dimension to ensure the accuracy. The sharding will + happen in-place, the buffer's shape will change within the context. + The buffers will be restored after the context finishes. + ``no_restore_buffers`` can be used to specify which buffers don't + need to be restored. Note that ``buffers`` should not contain any + nn.Parameter. + buffer_seq_dims (Optional[List[int]]): the sequence dimensions of ``buffers``. + no_restore_buffers (Optional[Set[torch.Tensor]]): buffers in these set + won't be restored after the context exits. This set must be a subset + of ``buffers``. If the buffers won't be used after the context exits, + these buffers can be put in this list to avoid extra restore time. + + .. warning:: + `torch.distributed.tensor.experimental.context_parallel` is a + prototype feature in PyTorch. The API is subject to change. + """ + buffers = [] if buffers is None else buffers + buffer_seq_dims = [] if buffer_seq_dims is None else buffer_seq_dims + no_restore_buffers = set() if no_restore_buffers is None else no_restore_buffers + + if len(buffers) != len(buffer_seq_dims): + raise ValueError( + "`seq_dims` must have the same number of elements as `buffers`." + ) + + for buffer in no_restore_buffers: + # Cannot use `if not buffer in buffers` which will incur tensor comparison. + if not any(b is buffer for b in buffers): + raise ValueError("`no_restore_buffers` must be a subset of `buffers`.") + + original_buffers = [None if b in no_restore_buffers else b.clone() for b in buffers] + chunks = _context_parallel_buffers(mesh, buffers, buffer_seq_dims) + for buffer, chunk in zip(buffers, chunks): + chunk = chunk.clone() + buffer.resize_(chunk.shape) + buffer.copy_(chunk) + + with _context_parallel(seq_dim=2, mesh=mesh): + yield + + for buffer, original_buffer in zip(buffers, original_buffers): + if original_buffer is not None: + buffer.resize_(original_buffer.shape) + buffer.copy_(original_buffer) + + +@torch.no_grad() +def context_parallel_unshard( + mesh: DeviceMesh, + buffers: list[torch.Tensor], + seq_dims: list[int], +) -> list[torch.Tensor]: + """ + Unshard the tensors (e.g., output) that are sharded due to context parallelism. + + Args: + mesh (:class:`DeviceMesh`): the device mesh for the context parallelism. + buffers (List[torch.Tensor]): the buffers to be unsharded. + seq_dims (List[int]): the sequence dimensions of ``buffers``. This list + must have the same length as ``buffers``. + + Returns: + List[torch.Tensor]: the unsharded buffers. + """ + sharder = ( + _RoundRobinLoadBalancer + if _cp_options.enable_load_balance + else _SequentialSharder + ) + return [sharder.unshard(b, mesh, dim) for b, dim in zip(buffers, seq_dims)] + + +def set_rotate_method(rotate_method: str) -> None: + """ + Context Parallel SDPA requires the rotation of kv shards. Users can call this + API to specify which rotation method to use. "alltoall" shuffles the kv shards + using all-to-all collective. While "allgather" gathers the kv shards using + all-gather collective after the first sub-SDPA computation. If this API has not + been called, the default rotate method is "allgather". + + Args: + rotate_method (str): the rotate method to use. Currently only supports + "allgather" and "alltoall". If a different string other than these two + is passed in, the function will raise an error. + + Returns: + None + """ + if rotate_method == "allgather": + _cp_options.rotate_method = _RotateMethod.ALL_GATHER + elif rotate_method == "alltoall": + _cp_options.rotate_method = _RotateMethod.ALL_TO_ALL + else: + raise NotImplementedError( + "Context Parallel does not support " + f"using {rotate_method} for kv shards rotation" + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_func_map.py b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_func_map.py new file mode 100644 index 0000000000000000000000000000000000000000..5094202f839f5d1a407b4d5ac8c3fefca21011a2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_func_map.py @@ -0,0 +1,254 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import functools +from collections.abc import Sequence +from typing import Callable, Optional, Union + +import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.tensor import DeviceMesh, DTensor +from torch.distributed.tensor.placement_types import Placement + + +try: + from torch.utils import _cxx_pytree as pytree +except ImportError: + from torch.utils import _pytree as pytree # type: ignore[no-redef] + + +__all__ = ["local_map"] + +PlacementType = Optional[Sequence[Placement]] +InputPlacements = Optional[tuple[PlacementType, ...]] +OutputPlacements = Union[PlacementType, tuple[PlacementType, ...]] + + +def local_map( + func: Callable, + out_placements: OutputPlacements, + in_placements: Optional[InputPlacements] = None, + in_grad_placements: Optional[InputPlacements] = None, + device_mesh: Optional[DeviceMesh] = None, + *, + redistribute_inputs: bool = False, +): + """ + :meth:`local_map` is an experimental API that allows users to pass :class:`DTensor` s + to a function that is written to be applied on ``torch.Tensor`` s. It is done by extracting + the local components of :class:`DTensor`, call the function, and wrap the outputs to + :class:`DTensor` according to the ``out_placements``. + + Args: + func (Callable): the function to be applied on each local shard of + :class:`DTensor` s. + out_placements (Union[`PlacementType`, Tuple[`PlacementType`, ...]]): + the desired placements of the :class:`DTensor` s in ``func``'s flattened output. + If the flattened ``output`` is a single value, the ``out_placements`` should be + of type `PlacementType`. Otherwise if the flattened ``output`` has multiple + values, the ``out_placements`` should be a tuple of `PlacementType` values 1:1 + mapping to the flattened ``output``. + Besides, for :class:`Tensor` output, we use `PlacementType` as its + placements (a `Tuple[Placement]` value). For non-Tensor output, the `PlacementType` + should be `None`. + Note that the only exception is when no :class:`DTensor` argument is passed + in. In this case, even if `out_placements` is not `None`, the result function + should ignore the desired placements because the function is not running with + :class:`DTensor` s. + in_placements (Tuple[`PlacementType`, ...], optional): + the required placements of the :class:`DTensor` s in the flattened inputs of ``func``. + If ``in_placements`` is specified, :meth:`local_map` would examine whether the + placements of each :class:`DTensor` argument is the same as the required + placements or not. If the placements are not the same and + ``redistribute_inputs`` is ``False``, an exception will be raised. Otherwise if + ``redistribute_inputs`` is ``True``, the argument will be first redistributed to + the required sharding placements before passing its local tensor to ``func``. + The only exception is when required placements are not ``None`` and the + argument is a :class:`torch.Tensor`. In this case, the placements examination + will be skipped and the argument will be directly passed to ``func``. + If ``in_placements`` is ``None``, no placements examination will be performed. + Default: None + in_grad_placements (Tuple[`PlacementType`, ...], optional): + the placements hint of the :class:`DTensor` s gradient corresponds + to the flattened input DTensor. This argument is the hint that user + can give to :meth:`to_local` in case the gradient layout of the + local tensor input does not match its :class:`DTensor` input layout. + If not specified, we will assume the gradient layout of the local + tensor input remains the same as the original :class:`DTensor` input + and use that for gradient computation. Default: None. + device_mesh (:class:`DeviceMesh`, optional): + the device mesh that all the :class:`DTensor` s are placed on. If not + specified, this will be inferred from the input :class:`DTensor` s' device + mesh. `local_map` requires every :class:`DTensor` s to be placed on the same + device mesh. Default: None. + redistribute_inputs (bool, optional): + the bool value indicating whether to reshard the input :class:`DTensor` s when + their placements are different from the required input placements. If this + value is ``False`` and some :class:`DTensor` input has a different placement, + an exception will be raised. Default: False. + + Returns: + A ``Callable`` that applies ``func`` to each local shard of the input :class:`DTensor` + and returns a :class:`DTensor` constructed from the return value of ``func``. + + Raises: + AssertionError: If the input :class:`DTensor` is not placed on the same device + mesh, or if they are placed on a different device mesh than the ``device_mesh`` + argument passed in. + + AssertionError: For any non-DTensor output, we require its corresponding + output placement in ``out_placements`` be None. An AssertionError will be raised + if this is not the case. + + ValueError: If ``redistribute_inputs=False`` but the input :class:`DTensor` needs + a redistribution according to ``in_placements``. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> def mm_allreduce_forward(device_mesh, W, X): + >>> partial_sum_tensor = torch.mm(W, X) + >>> reduced_tensor = funcol.all_reduce(partial_sum_tensor, "sum", device_mesh) + >>> return reduced_tensor + >>> + >>> W = torch.randn(12, 8, requires_grad=False) + >>> X = torch.randn(8, 16, requires_grad=False) + >>> Y = torch.mm(W, X) + >>> row_wise = [Shard(0)] # row-wise sharding placements on 1-d mesh + >>> col_wise = [Shard(1)] # col-wise sharding placements on 1-d mesh + >>> + >>> # local_mm_allreduce_forward is the function wrapped with DTensor/Tensor convertion + >>> local_mm_allreduce_forward = local_map( + >>> mm_allreduce_forward, + >>> out_placements=[Replicate()], + >>> in_placements=[col_wise, row_wise], + >>> device_mesh=device_mesh, + >>> ) + >>> + >>> W_dt = distribute_tensor( + ... W, device_mesh, (col_wise) + ... ) # col-wisely sharded W tensor + >>> X_dt = distribute_tensor( + ... X, device_mesh, (row_wise) + ... ) # row-wisely sharded X tensor + >>> Y_dt = local_mm_allreduce_forward( + ... device_mesh, W_dt, X_dt + ... ) # apply local_mm_allreduce_forward to DTensors + + .. note:: This API is currently experimental and subject to change + """ + + def wrapped(device_mesh: Optional[DeviceMesh], *args, **kwargs): + # process input args + flat_args, args_spec = pytree.tree_flatten(args) + if in_placements is not None: + assert len(in_placements) == len(flat_args), ( + f"in_placements length {len(in_placements)} does not match the number " + f"of input args {len(flat_args)}!" + ) + + # we assume every DTensor object is placed on the same device mesh + flat_local_args = [] + seen_dtensor_arg = False + for idx, arg in enumerate(flat_args): + if isinstance(arg, DTensor): + # TODO: the current code doesn't consider the uneven sharding case + # Need to think about what the consequence is when the input DTensor + # is uneven sharded. + if device_mesh is None: # infer device mesh from the DTensor arg + device_mesh = arg.device_mesh + + # this function is applied to at least one DTensor argument + seen_dtensor_arg = True + + assert arg.device_mesh == device_mesh, ( + f"arg {arg} in local_map has a mismatched device mesh: " + f"{arg} has device mesh {arg.device_mesh} while " + f"the expected device mesh is {device_mesh}!" + ) + if in_placements is not None: + spec = in_placements[idx] + assert spec is not None, ( + f"DTensor input {arg} expects placements but received {spec}!" + ) + + if not isinstance(spec, tuple): + spec = tuple(spec) + + if arg.placements != spec: + if redistribute_inputs: + # redistribute to input placements + arg = arg.redistribute(device_mesh, spec) + else: + raise ValueError( + f"arg {arg} in local_map has a mismatched placements: " + f"arg placements is {arg.placements} but the input " + f"placements is {spec}! " + "If redistribute_inputs is wanted, set " + "redistribute_inputs=True to local_map." + ) + + if in_grad_placements is not None: + spec = in_grad_placements[idx] + assert spec is not None, ( + f"DTensor input {arg} expects in grad placements but received {spec}!" + ) + if not isinstance(spec, tuple): + spec = tuple(spec) + local_arg = arg.to_local(grad_placements=spec) + else: + local_arg = arg.to_local() + + if isinstance(local_arg, AsyncCollectiveTensor): + local_arg = local_arg.wait() + + flat_local_args.append(local_arg) + else: + # Non-Tensor input must have None in `in_placements` + if in_placements is not None and not isinstance(arg, torch.Tensor): + spec = in_placements[idx] + assert spec is None, ( + f"Non-Tensor input {arg} expects None placements " + f"but received {spec}!" + ) + + flat_local_args.append(arg) + + local_args = pytree.tree_unflatten(flat_local_args, args_spec) + + out = func(*local_args, **kwargs) + + if seen_dtensor_arg: + # process output + flat_out, out_spec = pytree.tree_flatten(out) + + flat_dist_out = [] + out_placements_tuple = ( + out_placements + if isinstance(out_placements, tuple) + else (out_placements,) + ) + assert len(flat_out) == len(out_placements_tuple), ( + "local_map requires one PlacementType be provided for each output value," + f" received {len(out_placements_tuple)} out_placements but" + f" {len(flat_out)} is expected!" + ) + for out, spec in zip(flat_out, out_placements_tuple): + if isinstance(out, torch.Tensor): + assert not isinstance(out, DTensor), ( + f"torch.Tensor output expected but received {type(out)}: {out}" + ) + + flat_dist_out.append( + DTensor.from_local(out, device_mesh, spec, run_check=False) + ) + else: + assert spec is None, ( + f"Non-tensor output {out} expects None placements but received {spec}!" + ) + + flat_dist_out.append(out) + + return pytree.tree_unflatten(flat_dist_out, out_spec) + else: + return out + + return functools.partial(wrapped, device_mesh) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_register_sharding.py b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_register_sharding.py new file mode 100644 index 0000000000000000000000000000000000000000..ab17a00643f3731daa5a49d7b6d799053b254744 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_register_sharding.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from collections.abc import Sequence +from functools import partial +from typing import Callable, Union + +import torch +from torch._ops import OpOverload +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpStrategy, + PlacementList, + RuntimeSchemaInfo, + StrategyType, + TupleStrategy, +) +from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy + + +__all__ = ["register_sharding"] + + +def register_sharding(op: Union[OpOverload, list[OpOverload]]): + """ + :meth:`register_sharding` is an experimental API that allows users to register sharding + strategies for an operator when the tensor inputs and outputs are DTensor. + It can be useful when: (1) there doesn't exist a default sharding strategy for ``op``, + e.g. when ``op`` is a custom operator that is not supported by :class:`DTensor`; (2) + when users would like to overwrite default sharding strategies of existing operators. + + Args: + op (Union[OpOverload, List[OpOverload]]): + An op or a list of ops to register the customized sharding function. + + Returns: + A function decorator which can be used to wrap a function that defines the sharding + strategy for the operator specified in ``op``. The defined sharding strategy will be + registered to DTensor and will override the default sharding strategy if DTensor has + already implemented the operator. The customized sharding function takes the same inputs + as the original op (except that if an arg is a :class:`torch.Tensor`, it will be + replaced by a tensor-like object that DTensor uses internally). The function should + return a sequence of 2-tuples, each specifying acceptable output placements and its + corresponding intput placements. + + Example: + >>> # xdoctest: +SKIP("distributed") + >>> @register_sharding(aten._softmax.default) + >>> def custom_softmax_sharding(x, dim, half_to_float): + >>> softmax_dim = dim if dim >= 0 else dim + x.ndim + >>> acceptable_shardings = [] + >>> + >>> all_replicate = ([Replicate()], [Replicate(), None, None]) + >>> acceptable_shardings.append(all_replicate) + >>> + >>> for sharding_dim in range(x.ndim): + >>> if sharding_dim != softmax_dim: + >>> all_sharded = ( + >>> [Shard(sharding_dim)], + >>> [Shard(sharding_dim), None, None], + >>> ) + >>> acceptable_shardings.append(all_sharded) + >>> + >>> return acceptable_shardings + + .. note:: This API is currently experimental and subject to change + """ + + def custom_strategy( + custom_sharding_fn: Callable[ + ..., Sequence[tuple[PlacementList, PlacementList]] + ], + op_schema: OpSchema, + ) -> StrategyType: + def strategy_to_spec(strategy: object) -> object: + if isinstance(strategy, OpStrategy): + # take the output spec from the first strategy + return strategy.strategies[0].output_spec + elif isinstance(strategy, TupleStrategy): + return tuple(strategy_to_spec(s) for s in strategy.childs) + else: + return strategy + + mesh = op_schema.get_mesh_from_args() + + args_schema = tuple(strategy_to_spec(i) for i in op_schema.args_schema) + kwargs_schema = { + k: strategy_to_spec(v) for k, v in op_schema.kwargs_schema.items() + } + + acceptable_shardings = custom_sharding_fn(*args_schema, **kwargs_schema) + + single_mesh_dim_strategies: list[PlacementList] = [] + for output_specs, input_specs in acceptable_shardings: + single_mesh_dim_strategies.append(output_specs + input_specs) + + # TODO: handle out variant ops + return expand_to_full_mesh_op_strategy( + mesh, + op_schema, + single_mesh_dim_strategies, + input_index=len(op_schema.op._schema.returns), + inplace_op=op_schema.is_inplace_op(), + ) + + def wrapper(custom_sharding_fn): + def derive_schema_info(op): + # NOTE: without user directly providing RuntimeSchemaInfo, for now + # we create it in a conservative fashion as follows: + # 1. let static_argnum be the first int argument + # 2. let static_kwargkey include all the int type kwargs + # 3. always set needs_pytree=True + static_argnum = 100 + static_kwargkey: list[str] = [] + for i, arg in enumerate(op._schema.arguments): + if isinstance(arg.type, torch.IntType) or ( + isinstance(arg.type, torch.OptionalType) + and isinstance(arg.type.getElementType(), torch.IntType) + ): + static_argnum = min(i, static_argnum) + if arg.kwarg_only: + static_kwargkey.append(arg.name) + return RuntimeSchemaInfo( + static_argnum, static_kwargkey or None, needs_pytree=True + ) + + overloads = op if isinstance(op, list) else [op] + for overload in overloads: + DTensor._op_dispatcher.sharding_propagator.register_op_strategy( + overload, + partial(custom_strategy, custom_sharding_fn), + derive_schema_info(overload), + ) + + return custom_sharding_fn + + return wrapper diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_tp_transform.py b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_tp_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..1d59f07fa3bf18a12c04059cd54229595d8fc9f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/experimental/_tp_transform.py @@ -0,0 +1,554 @@ +# mypy: allow-untyped-defs +import copy +import operator +from collections.abc import Sequence +from typing import Any, cast, Optional + +import torch +from torch._subclasses.fake_tensor import FakeTensor +from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._op_schema import ( + OpSchema, + OpSpec, + OutputSharding, + OutputSpecType, +) +from torch.distributed.tensor._redistribute import redistribute_local_tensor +from torch.distributed.tensor.parallel.style import ColwiseParallel, ParallelStyle +from torch.distributed.tensor.placement_types import Placement, Replicate, Shard +from torch.export import ExportedProgram +from torch.export.exported_program import ExportGraphSignature +from torch.fx import GraphModule +from torch.fx.experimental.proxy_tensor import make_fx +from torch.fx.node import Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.utils import _pytree as pytree + + +__all__ = ["tensor_parallel_transformation"] + +aten = torch.ops.aten + + +def tensor_parallel_transformation( + exported_program: ExportedProgram, + rank: int, + world_size: int, + device_type: str, + parallel_strategies: dict[str, ParallelStyle], +) -> ExportedProgram: + """ + The entry point function to perform graph transformations on an exported program + to transform a single-device graph into a tensor parallel graph. + + .. warning:: + This API is experimental and subject to change. + """ + + gm = exported_program.graph_module + sig = copy.deepcopy(exported_program.graph_signature) + state_dict = copy.copy(exported_program.state_dict) + + with gm._set_replace_hook(sig.get_replace_hook()): + res = _TensorParallelTransformPass( + rank, + world_size, + device_type, + state_dict, + exported_program.graph_signature, + parallel_strategies, + )(gm) + assert res is not None + gm = res.graph_module + + return exported_program._update(gm, sig, state_dict=state_dict) + + +class _TensorParallelTransformPass(PassBase): + """ + This pass is responsible for transforming a single-device graph into a tensor parallel + graph. It will mark the OpSpec of each node in the graph, partition the graph into + distributed graph, then shard the parameters/buffers accordingly. + """ + + def __init__( + self, + rank: int, + world_size: int, + device_type: str, + state_dict: dict[str, torch.Tensor], + graph_signature: ExportGraphSignature, + parallel_strategies: dict[str, ParallelStyle], + ) -> None: + super().__init__() + self.rank = rank + self.mesh = DeviceMesh(device_type, torch.arange(world_size)) + self.state_dict: dict[str, torch.Tensor] = state_dict + self.graph_signature = graph_signature + self.parallel_strategies = parallel_strategies + + def call(self, graph_module) -> PassResult: + gm = copy.deepcopy(graph_module) + + parameter_placements = _generate_parameter_and_buffer_placements( + list(self.state_dict.keys()), self.parallel_strategies + ) + placement_strategies = _mark_sharding( + gm, self.graph_signature, self.mesh, parameter_placements + ) + _partitioner(gm) + _shard_state_dict( + self.state_dict, placement_strategies, self.graph_signature, self.mesh + ) + return PassResult(gm, True) + + +def _generate_parameter_and_buffer_placements( + params_and_buffers: list[str], + parallel_strategies: dict[str, ParallelStyle], +) -> dict[str, Placement]: + """ + Build parameter placements based on the give parallel style of linear layers. + """ + parameter_placements: dict[str, Placement] = {} + for linear_fqn, parallel_style in parallel_strategies.items(): + weight_fqn = f"{linear_fqn}.weight" + bias_fqn = f"{linear_fqn}.bias" + assert weight_fqn in params_and_buffers + parameter_placements[weight_fqn] = ( + Shard(0) if parallel_style == ColwiseParallel else Shard(1) + ) + if bias_fqn in params_and_buffers: + parameter_placements[bias_fqn] = ( + Shard(0) if parallel_style == ColwiseParallel else Replicate() + ) + return parameter_placements + + +def _mark_tensor_parallel_shardings( + gm: GraphModule, + graph_signature: ExportGraphSignature, + mesh: DeviceMesh, + parameter_placements: dict[str, Placement], +) -> dict[Node, OpSpec]: + """ + Mark the placement strategies of the parameter and buffer placeholder nodes. + """ + placement_strategies: dict[Node, OpSpec] = {} + num_params_and_buffers = len(graph_signature.inputs_to_parameters) + len( + graph_signature.inputs_to_buffers + ) + placeholder_idx: int = 0 + for node in gm.graph.nodes: + if node.op == "placeholder": + if placeholder_idx < num_params_and_buffers: + fqn: str = _get_input_node_fqn(node.name, graph_signature) + placement: Placement = ( + parameter_placements[fqn] + if fqn in parameter_placements + else Replicate() + ) + placement_strategies[node] = _create_placement_strategy( + node, + mesh, + placements=(placement,), + ) + placeholder_idx += 1 + else: + placement_strategies[node] = _create_placement_strategy( + node, + mesh, + placements=(Replicate(),), + ) + return placement_strategies + + +def _get_input_node_fqn(input_name: str, graph_signature: ExportGraphSignature) -> str: + """ + Return the FQN of an input node. + """ + if input_name in graph_signature.inputs_to_parameters: + return graph_signature.inputs_to_parameters[input_name] + elif input_name in graph_signature.inputs_to_buffers: + return graph_signature.inputs_to_buffers[input_name] + else: + raise ValueError( + f"{input_name} not found in inputs_to_parameters or inputs_to_buffers" + ) + + +def _mark_sharding( + gm: GraphModule, + graph_signature: ExportGraphSignature, + mesh: DeviceMesh, + parameter_placements: dict[str, Placement], +) -> dict[Node, OpSpec]: + """ + Mark the sharding strategy for each node in the graph module. + """ + placement_strategies: dict[Node, OpSpec] = _mark_tensor_parallel_shardings( + gm, + graph_signature, + mesh, + parameter_placements, + ) + + for node in gm.graph.nodes: + if node.op == "placeholder": + if node not in placement_strategies: + placement_strategies[node] = _create_placement_strategy( + node, mesh, placements=(Replicate(),) + ) + node.meta["sharding"] = placement_strategies[node] + elif node.op == "call_function": + if node.target == operator.getitem: + input_nodes = node.all_input_nodes + assert len(input_nodes) == 1, ( + f"non-compute op only support one input now, found node: {node} with length of inputs: {len(node.args)}" + ) + arg_strategy = placement_strategies[input_nodes[0]] + placement_strategies[node] = _create_placement_strategy( + node, + mesh, + placements=arg_strategy.output_spec.placements, + input_specs=_get_input_node_specs(node, placement_strategies), + ) + node.meta["sharding"] = placement_strategies[node] + else: + op_schema = _get_op_schema(node, placement_strategies) + + # get DTensor specs for inputs and outputs + if ( + op_schema.op + not in DTensor._op_dispatcher.sharding_propagator.op_strategy_funcs + and op_schema.op + not in DTensor._op_dispatcher.sharding_propagator.op_to_rules + ): + # Mark all as replicated + output_sharding = _generate_default_output_sharding( + node, + mesh, + op_schema, + ) + else: + output_sharding = DTensor._op_dispatcher.sharding_propagator.propagate_op_sharding( # type: ignore[assignment] + op_schema, + ) + placement_strategies[node] = OpSpec( + output_specs=_get_output_spec_from_output_sharding(output_sharding), + input_specs=output_sharding.redistribute_schema.args_spec + if output_sharding.redistribute_schema is not None + else _get_input_node_specs(node, placement_strategies), + ) + node.meta["sharding"] = placement_strategies[node] + elif node.op == "output": + node.meta["sharding"] = None + else: + raise RuntimeError(f"op code {node.op} not supported") + return placement_strategies + + +def _get_output_spec_from_output_sharding( + output_sharding: OutputSharding, +) -> DTensorSpec: + """ + Util function to extract output spec from output sharding. + """ + if isinstance(output_sharding.output_spec, DTensorSpec): + return output_sharding.output_spec + else: + # For ops that return multiple outputs, the outputs should have the same output spec + assert isinstance(output_sharding.output_spec, Sequence) + assert output_sharding.output_spec[0] is not None + output_sharding.output_spec[0].tensor_meta = None + return output_sharding.output_spec[0] + + +def _create_placement_strategy( + node: Node, + mesh: DeviceMesh, + placements: tuple[Placement, ...], + input_specs: Optional[Sequence[DTensorSpec]] = None, +) -> OpSpec: + """ + Util function to construct an OpSpec for a given node. + """ + placement = OpSpec( + input_specs=input_specs, + output_specs=DTensorSpec( + mesh=mesh, + placements=placements, + ), + ) + _populate_tensor_meta(node, placement.output_specs) + return placement + + +def _populate_tensor_meta(node: Node, output_spec: OutputSpecType) -> None: + """ + Util function to populate tensor meta of output_spec based on node metadata. + """ + if isinstance(node.meta["val"], Sequence): + assert isinstance(output_spec, Sequence) + for spec, fake_tensor in zip(output_spec, node.meta["val"]): + assert spec is not None + spec.tensor_meta = TensorMeta( + shape=fake_tensor.shape, + stride=fake_tensor.stride(), + dtype=fake_tensor.dtype, + ) + else: + assert isinstance(output_spec, DTensorSpec) + output_spec.tensor_meta = TensorMeta( + shape=node.meta["val"].shape, + stride=node.meta["val"].stride(), + dtype=node.meta["val"].dtype, + ) + + +def _generate_default_output_sharding( + node: Node, + mesh: DeviceMesh, + op_schema: OpSchema, +) -> OutputSharding: + """ + Util function to create a default output sharding that suggests Replicate placement for both args and outputs. + """ + + def update_arg_spec(arg_spec: DTensorSpec) -> DTensorSpec: + return DTensorSpec( + mesh=arg_spec.mesh, + placements=(Replicate(),), + tensor_meta=arg_spec.tensor_meta, + ) + + new_op_schema = OpSchema( + op=op_schema.op, + args_schema=pytree.tree_map_only( + DTensorSpec, update_arg_spec, op_schema.args_schema + ), + kwargs_schema=op_schema.kwargs_schema, + ) + + def create_output_spec(tensor: FakeTensor) -> DTensorSpec: + return DTensorSpec( + mesh=mesh, + placements=(Replicate(),), + tensor_meta=TensorMeta( + shape=tensor.shape, + stride=tensor.stride(), + dtype=tensor.dtype, + ), + ) + + return OutputSharding( + output_spec=pytree.tree_map_only( + FakeTensor, create_output_spec, node.meta["val"] + ), + redistribute_schema=new_op_schema, + needs_redistribute=True, + ) + + +def _partitioner(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Graph partitioner that partitions the single device graph + to distributed graph + """ + for node in gm.graph.nodes: + node_sharding = node.meta["sharding"] + if node.op == "placeholder": + out_spec = node_sharding.output_spec + local_val = _partition_val(node.meta["val"], out_spec) + # update node value + node.meta["val"] = local_val + elif node.op == "call_function": + out_spec = node_sharding.output_spec + # check if there's misaligned sharding, insert reshard if there is + expected_input_specs = node_sharding.input_specs + for idx, input_arg in enumerate(node.all_input_nodes): + input_arg_sharding = input_arg.meta["sharding"] + input_arg_spec = input_arg_sharding.output_spec + desired_spec = ( + out_spec + if expected_input_specs is None + else expected_input_specs[idx] + ) + if input_arg_spec != desired_spec: + _insert_reshard_gm( + gm, node, input_arg, input_arg_spec, desired_spec + ) + # convert output val to its local component + output_val = node.meta["val"] + node.meta["val"] = _partition_val(output_val, out_spec) + elif node.op == "output": + for input_arg in node.all_input_nodes: + # input args of output should be Replicate, otherwise redistribution is needed. + input_args_to_check: Sequence[Node] = ( + input_arg if isinstance(input_arg, Sequence) else [input_arg] + ) + for arg in input_args_to_check: + arg_sharding = arg.meta["sharding"] + arg_spec = arg_sharding.output_spec + desired_spec = copy.copy(arg_spec) + desired_spec.placements = (Replicate(),) + if arg_spec != desired_spec: + _insert_reshard_gm(gm, node, arg, arg_spec, desired_spec) + else: + raise RuntimeError(f"op code {node} not supported") + + _clean_up_graph_metadata(gm) + gm.graph.lint() + gm.recompile() + return gm + + +def _partition_val(val: Any, spec: DTensorSpec) -> Any: + """ + util function to convert a full tensor val to its local component + """ + if isinstance(val, torch.Tensor): + local_shard = val + if val.ndim == 0: + # If it's already a scalar tensor, it is already local, we don't + # need to do anything + return local_shard + + for idx, placement in enumerate(spec.placements): + if placement.is_shard(): + placement = cast(Shard, placement) + num_chunks = spec.mesh.size(mesh_dim=idx) + my_coord = spec.mesh.get_coordinate() + assert my_coord is not None, "current rank not in mesh!" + my_coord_on_mesh_dim = my_coord[idx] + local_shard = placement._split_tensor( + local_shard, num_chunks, with_padding=False, contiguous=True + )[0][my_coord_on_mesh_dim] + return local_shard + elif isinstance(val, (list, tuple)): + return val.__class__(_partition_val(v, spec) for v in val) + else: + raise RuntimeError(f"val type {type(val)} not supported") + + +def _insert_reshard_gm( + gm: torch.fx.GraphModule, + node: Node, + input_arg: Node, + input_arg_spec: DTensorSpec, + desired_spec: DTensorSpec, +) -> None: + """ + Transform the graph for tensor redistribution. + """ + input_arg_spec.tensor_meta = input_arg.meta["tensor_meta"] + desired_spec.tensor_meta = input_arg.meta["tensor_meta"] + input_arg_tensor = input_arg.meta["val"] + + # insert reshard operation + def reshard_fn(local_tensor: torch.Tensor) -> torch.Tensor: + return redistribute_local_tensor( + local_tensor, + input_arg_spec, + desired_spec, + ) + + reshard_gm = make_fx(reshard_fn)(input_arg_tensor) + reshard_gm_nodes = list(reshard_gm.graph.nodes) + input_node = reshard_gm_nodes[0] + with gm.graph.inserting_before(node): + # copy nn_module_stack metadata for output, all-reduce nodes + for reshard_node in reshard_gm.graph.nodes: + if reshard_node.op not in ["placeholder", "output"]: + reshard_node.meta["nn_module_stack"] = ( + copy.copy(input_arg.meta["nn_module_stack"]) + if not input_arg.op == "placeholder" + else copy.copy(node.meta["nn_module_stack"]) + ) + output_node = gm.graph.graph_copy( + reshard_gm.graph, + val_map={ + input_node: input_arg, + }, + ) + node.replace_input_with(input_arg, output_node) # type: ignore[arg-type] + + +def _clean_up_graph_metadata(gm: torch.fx.GraphModule) -> None: + """ + Clean up the graph by removing sharding and partitioning related metadata + """ + for node in gm.graph.nodes: + if "sharding" in node.meta: + del node.meta["sharding"] + if "val" in node.meta and isinstance(node.meta["val"], torch.Tensor): + local_tensor_meta = _extract_tensor_metadata(node.meta["val"]) + node.meta["tensor_meta"] = local_tensor_meta + + +def _get_input_node_specs( + node: Node, placement_strategies: dict[Node, OpSpec] +) -> tuple[DTensorSpec, ...]: + """ + Get the input specs of a node. + """ + input_specs_list: list[DTensorSpec] = [] + for input_arg in node.all_input_nodes: + if input_arg in placement_strategies: + output_spec = placement_strategies[input_arg].output_specs + assert isinstance(output_spec, DTensorSpec) + input_specs_list.append(output_spec) + else: + raise ValueError(f"{input_arg} does not have output_spec populated.") + return tuple(input_specs_list) + + +def _get_op_schema(node: Node, placement_strategies: dict[Node, OpSpec]) -> OpSchema: + """ + Util function to construct the operator schema of a node. + """ + args_schema_list = pytree.tree_map_only( + Node, lambda arg: placement_strategies[arg].output_specs, node.args + ) + op_schema = OpSchema( + op=cast(torch._ops.OpOverload, node.target), + args_schema=tuple(args_schema_list), + kwargs_schema=cast(dict[str, object], node.kwargs), + ) + return op_schema + + +def _shard_state_dict( + state_dict: dict[str, torch.Tensor], + placement_strategies: dict[Node, OpSpec], + graph_signature: ExportGraphSignature, + mesh: DeviceMesh, +) -> None: + """ + Inplace partition the weights based on the OpSpec + """ + for node, op_spec in placement_strategies.items(): + if node.op != "placeholder": + continue + if node.name in graph_signature.inputs_to_parameters: + fqn = graph_signature.inputs_to_parameters[node.name] + elif node.name in graph_signature.inputs_to_buffers: + fqn = graph_signature.inputs_to_buffers[node.name] + else: + continue + assert fqn in state_dict, f"{fqn} not found in state dict: {state_dict.keys()}" + + original_param = state_dict[fqn] + dtensor_param = distribute_tensor( + original_param, + mesh, + op_spec.output_spec.placements, + ) + local_param = dtensor_param.to_local() + state_dict[fqn] = ( + torch.nn.Parameter(local_param) + if isinstance(original_param, torch.nn.Parameter) + else local_param + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__init__.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e5468263d30aa49506380759b032a627cfea574d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__init__.py @@ -0,0 +1,25 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from torch.distributed.tensor.parallel.api import parallelize_module +from torch.distributed.tensor.parallel.loss import loss_parallel +from torch.distributed.tensor.parallel.style import ( + ColwiseParallel, + ParallelStyle, + PrepareModuleInput, + PrepareModuleInputOutput, + PrepareModuleOutput, + RowwiseParallel, + SequenceParallel, +) + + +__all__ = [ + "ColwiseParallel", + "ParallelStyle", + "PrepareModuleInput", + "PrepareModuleInputOutput", + "PrepareModuleOutput", + "RowwiseParallel", + "SequenceParallel", + "parallelize_module", + "loss_parallel", +] diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79f5ff10351455a4c1aadc866f02cee36d116427 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84c84d65eb3b2a705b0e57b29e7902258c738954 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_data_parallel_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a224bae92ab9a8106a04eef06fbeead000f394a1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb223b8cae9f8314d41711853deaa8fd73a8a63 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/api.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76949982015b3f41944d1ff98b5fd9dc07dcac7c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/ddp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5306bfeaad3d71fd043ebd169f12c309b88483c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/fsdp.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2640bdb73282590cfff152921c22957a651baf56 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/input_reshard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3bf47371ba61bbedc58d501edf8b1383a0df8fe7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/loss.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..599e67358475e5b13ebecb023d3fbec0ed089eba Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/__pycache__/style.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..50e9ece999f4ac0872994197fc1c83861b1ba691 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/_data_parallel_utils.py @@ -0,0 +1,51 @@ +from functools import partial +from typing import no_type_check, Optional + +import torch +from torch.distributed._functional_collectives import AsyncCollectiveTensor +from torch.distributed.tensor import DTensor +from torch.distributed.tensor._dtensor_spec import DTensorSpec + + +@no_type_check +def sync_grad_hook(grad, *, device_handle=None, compute_stream=None): + if isinstance(grad, AsyncCollectiveTensor): + if compute_stream is not None: + with device_handle.stream(compute_stream): + grad = grad.wait() + else: + grad = grad.wait() + + return grad + + +def _flatten_tensor( + tensor: torch.Tensor, +) -> tuple[torch.Tensor, Optional[DTensorSpec]]: + if isinstance(tensor, DTensor): + tensor._local_tensor.requires_grad_() + return tensor._local_tensor, tensor._spec + return tensor, None + + +@no_type_check +def _unflatten_tensor(tensor, spec, *, device_handle=None, compute_stream=None): + # unflatten would mainly be called everytime FSDP allgather parameters. + result = DTensor.from_local( + tensor, + spec.mesh, + spec.placements, + run_check=False, + shape=spec.shape, + stride=spec.stride, + ) + if tensor.requires_grad: + # only register the hook if the tensor requires grad + tensor.register_hook( + partial( + sync_grad_hook, + device_handle=device_handle, + compute_stream=compute_stream, + ) + ) + return result diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/_utils.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eaf0208ea6e12a3662e757e433cd59176a4d75fa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/_utils.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Union + +from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.tensor import DeviceMesh +from torch.distributed.tensor.placement_types import Placement + + +try: + from torch._dynamo.external_utils import is_compiling as is_torchdynamo_compiling +except Exception: + + def is_torchdynamo_compiling(): # type: ignore[misc] + return False + + +LayoutsType = Union[Placement, tuple[Placement, ...]] + + +def _deprecate_warnings(func_name: str, extra_msg: str) -> None: + """ + Inject common validation logics for `_prepare_input` funcs via this decorator. + + Include verifying that input needs to be either a :class:`Tensor` or :class:`DTensor` + and only 1D :class:`DeviceMesh` is passed in. + """ + # TODO: Will follow up with dynamo POC to make warnings.warn working with dynamo. + if not is_torchdynamo_compiling(): + warnings.warn( + f"{func_name} is deprecated and will be removed soon. {extra_msg}", + FutureWarning, + stacklevel=3, + ) + + +def _validate_tp_mesh_dim( + device_mesh: DeviceMesh, +) -> None: + """ + Check whether TP mesh dimension is valid or not. + + Args: + device_mesh (:class:`DeviceMesh`): + The `device_mesh` where we perform + Tensor Parallelism on. + + Return: + `True` if the mesh dimension + is valid, `False` otherwise. + """ + if device_mesh.ndim > 1: + raise ValueError( + f"Tensor Parallel only accepts a 1D DeviceMesh, but found {device_mesh.ndim}D!" + 'If you have a 2-D or N-D device_mesh, consider passing in device_mesh["tp"]' + ) + + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + # if a root mesh is not the same as device_mesh, + # meaning the device_mesh is sliced out from the root mesh. + if root_mesh and root_mesh != device_mesh: + tp_mesh_dim_in_root = _mesh_resources.get_root_mesh_dim(device_mesh) + if tp_mesh_dim_in_root != root_mesh.ndim - 1: + raise RuntimeError( + f"Found TP device_mesh on the {tp_mesh_dim_in_root} dimension of its parent mesh.", + "Currently we only support intranode TP and TP needs to be the innermost dimension on its parent mesh.", + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/api.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/api.py new file mode 100644 index 0000000000000000000000000000000000000000..2fc1c3a78484f4132eda55dba78a0879f431d069 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/api.py @@ -0,0 +1,129 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +import warnings +from fnmatch import fnmatch +from typing import Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import _mesh_resources, DeviceMesh +from torch.distributed.tensor.parallel._utils import _validate_tp_mesh_dim +from torch.distributed.tensor.parallel.style import ParallelStyle + + +__all__ = ["parallelize_module"] + + +def parallelize_module( # type: ignore[return] + module: nn.Module, + device_mesh: Optional[DeviceMesh] = None, + parallelize_plan: Optional[Union[ParallelStyle, dict[str, ParallelStyle]]] = None, + *, + src_data_rank: Optional[int] = 0, +) -> nn.Module: + """ + Apply Tensor Parallelism in PyTorch by parallelizing modules or sub-modules based on a user-specified plan. + + We parallelize module or sub_modules based on a parallelize_plan. The parallelize_plan contains + :class:`ParallelStyle`, which indicates how user wants the module or sub_module + to be parallelized. + + User can also specify different parallel style per module fully qualified name (FQN). + + Note that ``parallelize_module`` only accepts a 1-D :class:`DeviceMesh`, if you have a 2-D or N-D :class:`DeviceMesh`, + slice the DeviceMesh to a 1-D sub DeviceMesh first then pass to this API(i.e. ``device_mesh[\"tp\"]``) + + Args: + module (:class:`nn.Module`): + Module to be parallelized. + device_mesh (:class:`DeviceMesh`, optional): + Object which describes the mesh topology of devices for the DTensor. + If not specified, the call must be under a DeviceMesh context. + parallelize_plan (Union[:class:`ParallelStyle`, Dict[str, :class:`ParallelStyle`]], optional): + The plan used to parallelize the module. It can be either a + :class:`ParallelStyle` object which contains how we prepare + input/output for Tensor Parallelism or it can be a dict of module + FQN and its corresponding :class:`ParallelStyle` object. If not + specified, the call will do nothing at the moment. + Keyword args: + src_data_rank (int, optional): the rank of the source data for the logical/global tensor, it is used by + :meth:`distribute_tensor` to scatter/broadcast the shards/replicas to other ranks. By default, + we use ``group_rank=0`` on each DeviceMesh dimension as the source data to preserve the single-device + semantic. If passing ``None`` explicitly, :meth:`parallelize_module` simply uses its local data instead + of trying to preserve the single-device semantic via scatter/broadcast. Default: 0 + Return: + A :class:`nn.Module` object parallelized. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> + >>> # Define the module. + >>> m = Model(...) + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> m = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel(), "w2": RowwiseParallel()}) + >>> + + .. note:: For complex module architecture like Attention, MLP layers, we recommend composing + different ParallelStyles together (i.e. ``ColwiseParallel`` and ``RowwiseParallel``) and pass + as a parallelize_plan, to achieves the desired sharding computation. + """ + torch._C._log_api_usage_once("torch.distributed.tensor.parallel.parallelize_module") + + device_mesh = device_mesh or _mesh_resources.get_current_mesh() + _validate_tp_mesh_dim(device_mesh) + + if parallelize_plan is None: + warnings.warn( + "No parallelize_plan is provided and auto-parallel is not supported " + "at the moment, so this parallelize_module call will do nothing." + ) + return module + + # note: The RNG tracker will be initialized in distribute_tensor() call if it hasn't + # been initialized. + + if isinstance(parallelize_plan, ParallelStyle): + parallelize_plan.src_data_rank = src_data_rank + return parallelize_plan._apply(module, device_mesh) + elif isinstance(parallelize_plan, dict): + for module_path, parallelize_style in parallelize_plan.items(): + path_splits = module_path.split(".") + if len(path_splits) == 0: + raise ValueError( + "Expect module path to be non-empty, but got empty string!" + ) + while path_splits: + atom = path_splits.pop(0) + matched_children = filter( + # `t[0]` is child name + lambda t: fnmatch(t[0], atom), + module.named_children(), + ) + # apply the plan to all matched submodules + for _, submodule in matched_children: + if path_splits: + # we haven't reached the leaf, apply in dict style + leaf_path = ".".join( + path_splits + ) # rest of the path after `atom` + parallelize_module( + submodule, + device_mesh, + {leaf_path: parallelize_style}, + src_data_rank=src_data_rank, + ) + else: + # otherwise, directly apply style to this submodule + parallelize_module( + submodule, + device_mesh, + parallelize_style, + src_data_rank=src_data_rank, + ) + return module + else: + raise TypeError( # pyre-ignore[7] + "Expect Union[ParallelStyle, Dict[str, ParallelStyle]] for" + f" parallelize_plan, {type(parallelize_plan)} found!" + ) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/ddp.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/ddp.py new file mode 100644 index 0000000000000000000000000000000000000000..53e2c61f2e4309028b42ceb3d321d3c7555bea0e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/ddp.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +from typing import Any, Optional + +import torch.nn as nn +from torch.distributed.tensor.parallel._data_parallel_utils import ( + _flatten_tensor, + _unflatten_tensor, +) + + +__all__ = [] # type: ignore[var-annotated] + + +def _get_submodule_n_params(module: nn.Module, path: str): + """ + Get submodule and the direct path of parameter from the module + """ + if "." in path: + path_list = path.split(".") + parent_module_path = ".".join(path_list[:-1]) + module = module.get_submodule(parent_module_path) + path = path_list[-1] + return module, path + + +def _update_module_param(param_list: list[tuple[nn.Module, str, nn.Parameter]]): + """ + Update parameters within the module + """ + for item in param_list: + parent_module, module_path, t = item + assert hasattr(parent_module, module_path) + delattr(parent_module, module_path) + setattr(parent_module, module_path, t) + + +def _reconstruct_dtensor(module: nn.Module, _input: Any): + """ + Recontruct DTensor parameters from local tensors + """ + param_list = [] + # TODO: To add perf optimizations to this iterations + for name, t in module.named_parameters(): + if hasattr(t, "_st_info"): + dtensor = _unflatten_tensor(t, t._st_info) + param_list.append((*_get_submodule_n_params(module, name), dtensor)) + _update_module_param(param_list) # type: ignore[arg-type] + + +def _localize_dtensor( + module: nn.Module, *_: Any, ignored_params: Optional[set[nn.Parameter]] = None +): + """ + Convert DTensor parameters to local tensors + """ + if ignored_params is None: + ignored_params = set() + param_list = [] + for name, param in module.named_parameters(): + if param in ignored_params: + continue + t, sharding_info = _flatten_tensor(param) + if sharding_info is not None: + t = nn.Parameter(t) + t._st_info = sharding_info # type: ignore[attr-defined] + param_list.append((*_get_submodule_n_params(module, name), t)) + _update_module_param(param_list) # type: ignore[arg-type] + + +def _pre_dp_module_transform(module: nn.Module): + """ + Enable the composability between Tensor Parallelism (TP) and Data + Parallelism(DP) in PyTorch when using DDP. We need to convert Parameters which + are DTensors to local tensors before wrapping with data parallelism API. + We then register two hooks, one for converting local tensors back to DTensor + preforward and one to convert DTensors back to tensors after Forward. By + integrating this way, we avoid any special handling of DTensor parameters by DDP + and get DTensor's gradients propagated back to DP, e.g. gradient buckets of DDP. + + For now, this API only works with ``DistributedDataParallel``. It will later support + other DP methods such as FSDP. + + Args: + module (:class:`nn.Module`): + Module which has been applied TP on. + + Example:: + >>> # xdoctest: +SKIP("distributed") + >>> from torch.distributed.tensor.parallel import parallelize_module, PairwiseParallel + >>> from torch.nn.parallel import DistributedDataParallel as DDP + >>> from torch.distributed.tensor.parallel.ddp import pre_dp_module_transform + >>> + >>> # Define the module. + >>> m = module(...) + >>> parallelize_module(m, PairwiseParallel()) + >>> m = pre_dp_module_transform(m) + >>> m = DDP(m) + >>> + """ + + _localize_dtensor(module, None, None) + # TODO: To add test cases and ensure that it works for nested modules + module.register_forward_pre_hook(_reconstruct_dtensor) + module.register_forward_hook(_localize_dtensor) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/fsdp.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/fsdp.py new file mode 100644 index 0000000000000000000000000000000000000000..f06fc3d1caef19293ded1f979ba6864cb082e0fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/fsdp.py @@ -0,0 +1,390 @@ +# mypy: allow-untyped-defs +import copy +from typing import Any, cast, Optional + +import torch +import torch.distributed as dist +import torch.distributed._shard.sharding_spec as shard_spec +import torch.distributed.distributed_c10d as c10d +from torch.distributed._shard.sharded_tensor import ( + Shard, + ShardedTensor, + ShardedTensorMetadata, + TensorProperties, +) +from torch.distributed._shard.sharding_spec import ShardMetadata +from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec +from torch.distributed.device_mesh import _mesh_resources +from torch.distributed.fsdp._common_utils import _set_fsdp_flattened +from torch.distributed.fsdp._fsdp_extensions import FSDPExtensions +from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor +from torch.distributed.remote_device import _remote_device +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard as DShard +from torch.distributed.tensor.parallel._data_parallel_utils import ( + _flatten_tensor, + _unflatten_tensor, +) + + +__all__ = ["DTensorExtensions"] + + +def _get_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]: + device_mesh = tensor.device_mesh + assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + placement = tensor.placements[0] + offsets = [0] * len(tensor.size()) + num_chunks = device_mesh.size(mesh_dim=0) + + if tensor.placements[0].is_shard(): + shard_dim = cast(DShard, placement).dim + chunk_size = tensor.size(shard_dim) // num_chunks + offsets[shard_dim] = chunk_size + + return (torch.Size(offsets), tensor._local_tensor.size()) + + +def _get_box_for(tensor: DTensor, idx: int) -> tuple[torch.Size, torch.Size]: + offsets, size = _get_box(tensor) + return (torch.Size([val * idx for val in offsets]), size) + + +def _get_local_box(tensor: DTensor) -> tuple[torch.Size, torch.Size]: + device_mesh = tensor.device_mesh + coord = device_mesh.get_coordinate() + assert coord is not None + return _get_box_for(tensor, coord[0]) + + +def _create_shard_md_from_dt(dt: DTensor, current_rank: int) -> ShardMetadata: + mesh = dt.device_mesh + assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + offsets, sizes = _get_local_box(dt) + return ShardMetadata( + shard_offsets=list(offsets), + shard_sizes=list(sizes), + placement=f"rank:{current_rank}/{dt._local_tensor.device}", + ) + + +def _create_sharded_tensor_md_from_dt( + dt: DTensor, dt_pg: c10d.ProcessGroup +) -> ShardedTensorMetadata: + # This is where it gets tricky, we have to produce a ShardedTensor that has full coverage + # and yet has only one valid shard for the current rank. + + shards_md = [] + my_rank = dist.get_rank(dt_pg) + scapegoat_rank = 0 if my_rank > 0 else 1 + + if dt.placements[0].is_shard(): + shard_count = dt_pg.size() + else: + shard_count = 1 + + for i in range(shard_count): + offsets, sizes = _get_box_for(dt, i) + shards_md.append( + ShardMetadata( + shard_offsets=list(offsets), + shard_sizes=list(sizes), + placement=( + f"rank:{scapegoat_rank if i > 0 else my_rank}/{dt._local_tensor.device}" + ), + ) + ) + + return ShardedTensorMetadata( + shards_metadata=shards_md, + size=dt.size(), + tensor_properties=TensorProperties( + dtype=dt.dtype, + layout=dt.layout, + requires_grad=dt.requires_grad, + # ignore memory_format and pin_memory as those are not supported by DT + ), + ) + + +def _get_dt_pg(dt: DTensor) -> c10d.ProcessGroup: + mesh = dt.device_mesh + assert mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + return mesh.get_group() + + +def _rewrite_spec_if_needed( + spec: shard_spec.ShardingSpec, tensor: torch.Tensor, rank: int +) -> shard_spec.ShardingSpec: + """ + Rewrite ``spec`` to match the device of ``tensor``. + + FSDP.sharded_optim_state_dict sneakly ships optimizer state to CPU so if the original ShardingSpec + produces CUDA metadata, ST construction bombs. + """ + if not isinstance(spec, ChunkShardingSpec): + return spec + + # let's see if we need + rewrite = False + for p in spec.placements: + p = cast(_remote_device, p) + if p.rank() == rank and p.device() != tensor.device: + rewrite = True + break + if rewrite: + spec = copy.deepcopy(spec) + for i, placement in enumerate(spec.placements): + placement = cast(_remote_device, placement) + if placement.rank() == rank and placement.device() != tensor.device: + spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}") + + return spec + + +def _chunk_tensor( + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, +) -> torch.Tensor: + if type(tensor) is ShardedTensor: + assert len(tensor.local_shards()) == 1 + + inner_param = tensor.local_tensor() + inner_st = _create_chunk_sharded_tensor( + inner_param, + rank, + world_size, + num_devices_per_node, + pg, + ) + + outer_local_shard = tensor.local_shards()[0] + shards: list[Shard] = [ + Shard(inner_st, copy.deepcopy(outer_local_shard.metadata)) + ] + st_meta = copy.deepcopy(tensor.metadata()) + st_meta.tensor_properties.requires_grad = False + + st_outer = ShardedTensor._init_from_local_shards_and_global_metadata( + shards, + sharded_tensor_metadata=st_meta, + process_group=tensor._process_group, + init_rrefs=False, + ) + return st_outer + elif type(tensor) is DTensor: + device_mesh = tensor.device_mesh + assert device_mesh.ndim == 1, "Only 1D DeviceMeshes currently handled" + + inner_param = tensor._local_tensor + + inner_st = _create_chunk_sharded_tensor( + inner_param, + rank, + world_size, + torch.accelerator.device_count(), + pg, + ) + + dt_pg = _get_dt_pg(tensor) + # We do this differently here, we create a ST with no local shards then patch it + shards = [ + Shard(inner_st, _create_shard_md_from_dt(tensor, dist.get_rank(dt_pg))) + ] + + st_meta = _create_sharded_tensor_md_from_dt(tensor, dt_pg) + st_meta.tensor_properties.requires_grad = False + + st_outer = ShardedTensor._init_from_local_shards_and_global_metadata( + shards, + sharded_tensor_metadata=st_meta, + process_group=dt_pg, + init_rrefs=False, + ) + + return st_outer + else: + return _create_chunk_sharded_tensor( + tensor, + rank, + world_size, + num_devices_per_node, + pg, + ) + + +def _chunk_dtensor( + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, +) -> DTensor: + """ + Shard a tensor to chunks along the first dimension. + + The local rank will gets its corresponding chunk as the local tensor to create a DTensor. + """ + root_mesh = _mesh_resources.get_root_mesh(device_mesh) + if root_mesh is None: + raise RuntimeError("No parent device_mesh is found for FSDP device_mesh.") + if root_mesh.ndim < 2: + raise RuntimeError( + f"Found parent device_mesh of ndim={root_mesh.ndim},", + "but meshes must be at least 2D.", + ) + + # We need to explicitly call .detach() to return a new tensor detached from the current graph. + tensor = tensor.detach().clone() + + # When a layer is not involved in TP, then the tensor will not be a DTensor. + # e.g. When a layer is not sppecified in the parallelize_plan, TP will have no effect on the layer. + # e.g. When you do PairwiseParallel on a 3 layer model, TP will have no effect on the third layer. + if isinstance(tensor, torch.Tensor) and not isinstance(tensor, DTensor): + # For tensors, it is replicated across tp dimension and sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), Replicate()). + replicate_placements = [Replicate() for _ in range(root_mesh.ndim)] + shard_placements = [Replicate() for _ in range(root_mesh.ndim)] + shard_placements[0] = DShard(0) # type: ignore[call-overload] + + return DTensor.from_local( + tensor, root_mesh, replicate_placements, run_check=False + ).redistribute( + device_mesh=root_mesh, + placements=shard_placements, + ) + + else: + tp_placements = tensor.placements + tp_placement = tp_placements[0] + + tensor = tensor.to_local() + + # For DTensors, it is sharded across tp dimension first and then sharded across FSDP dimension. + # TP is the inner dimension and FSDP is the outer dimension. + # Therefore, shard placements for tensor is (Shard(0), tp_placement). + # For higher dimensional meshes, it is replicated across other dimensions. For example, with + # HSDP the shard placements for tensor is (Replicate, Shard(0), tp_placement). + replicate_placements = [Replicate() for _ in range(root_mesh.ndim)] + replicate_placements[-1] = tp_placement # type: ignore[call-overload] + shard_placements = [Replicate() for i in range(root_mesh.ndim)] # type: ignore[misc] + shard_placements[-2] = DShard(0) # type: ignore[call-overload] + shard_placements[-1] = tp_placement # type: ignore[call-overload] + + return DTensor.from_local( + tensor, root_mesh, replicate_placements, run_check=False + ).redistribute( + device_mesh=root_mesh, + placements=shard_placements, + ) + + +def _pre_load_state_dict( + tensor: torch.Tensor, +) -> tuple[torch.Tensor, list[Shard]]: + shards = cast(ShardedTensor, tensor).local_shards() + if len(shards) == 1 and type(shards[0].tensor) is ShardedTensor: + inner_tensor = shards[0].tensor + shards = inner_tensor.local_shards() # pyre-ignore[16] + tensor = inner_tensor + + return (tensor, shards if len(shards) > 0 else []) + + +def _all_gather_dtensor( + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], +) -> torch.Tensor: + """All gather a DTensor in its FSDP dimension and return the local tensor.""" + assert parent_mesh == tensor.device_mesh + + placements = list(copy.deepcopy(tensor.placements)) + # FSDP + TP: [Shard(0), tp_placement] -> [Replicate(), tp_placement] + # HSDP + TP: [Replicate(), Shard(0), tp_placement] -> [Replicate(), Replicate(), tp_placement] + for i in range(0, len(placements) - 1): + placements[i] = Replicate() + tensor = tensor.redistribute( + device_mesh=tensor.device_mesh, + placements=placements, + ) + + return tensor.to_local() + + +class DTensorExtensions(FSDPExtensions): + """ + DTensorExtension is the TensorFlattener extension needed for 2D FSDP + TP. + + This is the implementation for FSDPExtensions defined in + https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fsdp_extensions.py + """ + + def __init__(self, device_handle) -> None: + super().__init__() + self.compute_stream = None + self.device_handle = device_handle + # we have to use the dynamo disable this way to disable dynamo as the decorater way would + # trigger build failure with torch deploy... + self.post_unflatten_transform = torch._dynamo.disable( # type: ignore[method-assign] + self.post_unflatten_transform + ) + + def pre_flatten_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, Optional[Any]]: + return _flatten_tensor(tensor) + + def post_unflatten_transform( + self, tensor: torch.Tensor, param_extension: Any + ) -> torch.Tensor: + stream = self.compute_stream or self.device_handle.current_stream() + with self.device_handle.stream(stream): + # runtime we put the unflattened tensor call on the compute stream since + # the unflattened tensor might contain computations in fwd/bwd where we + # need to sync properly. + # TODO: this is a short term fix and we should make the get_unflat_views + # directly happen in the compute stream. + result = _unflatten_tensor( + tensor, + param_extension, + device_handle=self.device_handle, + compute_stream=self.compute_stream, + ) + _set_fsdp_flattened(result) + return result + + def chunk_tensor( + self, + tensor: torch.Tensor, + rank: int, + world_size: int, + num_devices_per_node: int, + pg: dist.ProcessGroup, + device: Optional[torch.device] = None, + ) -> torch.Tensor: + return _chunk_tensor(tensor, rank, world_size, num_devices_per_node, pg) + + def chunk_dtensor( + self, + tensor: torch.Tensor, + rank: int, + device_mesh: DeviceMesh, + ) -> torch.Tensor: + return _chunk_dtensor(tensor, rank, device_mesh) + + def pre_load_state_dict_transform( + self, + tensor: torch.Tensor, + ) -> tuple[torch.Tensor, list[Shard]]: + return _pre_load_state_dict(tensor) + + def all_gather_dtensor( + self, + tensor: DTensor, + parent_mesh: Optional[DeviceMesh], + ) -> torch.Tensor: + return _all_gather_dtensor(tensor, parent_mesh) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/input_reshard.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/input_reshard.py new file mode 100644 index 0000000000000000000000000000000000000000..26ccdeaba311911bdc4eba76f68b0d562d0bf60a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/input_reshard.py @@ -0,0 +1,106 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates +from functools import partial +from typing import Any, Optional + +import torch +from torch.distributed.tensor import DeviceMesh, DTensor, Replicate, Shard + + +__all__ = [ + "input_reshard", +] + + +def input_reshard( + module: torch.nn.Module, + tp_device_mesh: DeviceMesh, + input_reshard_dim: Optional[int] = None, +) -> torch.nn.Module: + """ + Register hooks to an nn.Module for input resharding, enabling sharding and restoration during backward computation. + + Register hooks to an nn.Module with input resharding so that we can shard + per the given `tp_device_mesh` and `input_reshard_dim` and restore the + input back when recomputing the activations in the backward. The reason + why we can do this is that for Tensor Parallel(TP), the input are same + across all TP ranks. + + Args: + module (:class:`nn.Module`): + Module to be registered with input resharding. + tp_device_mesh (:class:`DeviceMesh`): + Object which describes the mesh topology + of devices for Tensor Parallel. + input_reshard_dim (Optional[int]): + The dimension of where we perform the sharding + of input. If set None, there is no sharding of input. + Default: None + + Return: + A :class:`nn.Module` object registered with TP input resharding. + """ + if input_reshard_dim is None: + return module + + cx: Optional[torch.autograd.graph.saved_tensors_hooks] = None + + def input_reshard_forward_pre_hook(_: torch.nn.Module, _i: tuple[Any, ...]) -> None: + saved_tensor_hooks = torch.autograd.graph.saved_tensors_hooks( + partial(_pack_hook_tp, tp_device_mesh, input_reshard_dim), + partial(_unpack_hook_tp, tp_device_mesh, input_reshard_dim), + ) + saved_tensor_hooks.__enter__() + nonlocal cx + cx = saved_tensor_hooks # type: ignore[name-defined] + + def input_reshard_backward_hook( + _: torch.nn.Module, _i: tuple[Any, ...], _o: Any + ) -> Any: + nonlocal cx + cx.__exit__() # type: ignore[name-defined, union-attr] + + module.register_forward_pre_hook(input_reshard_forward_pre_hook) + module.register_forward_hook(input_reshard_backward_hook) + return module + + +def _pack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: torch.Tensor) -> Any: # noqa: D401 + """Hook function called after FWD to shard input.""" + if isinstance(x, DTensor) and all(p.is_replicate() for p in x._spec.placements): + return x.redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) + elif ( + not isinstance(x, DTensor) + and isinstance(x, torch.Tensor) + and x.numel() >= mesh.size() + ): + return ( + DTensor.from_local(x, device_mesh=mesh) + .redistribute(device_mesh=mesh, placements=[Shard(input_reshard_dim)]) + .to_local() + ) + else: + return x + + +def _unpack_hook_tp(mesh: DeviceMesh, input_reshard_dim: int, x: Any) -> torch.Tensor: # noqa: D401 + """Hook function called before activation recomputing in BWD to restore input.""" + if ( + isinstance(x, DTensor) + and len(x._spec.placements) == 1 + and x._spec.placements[0].is_shard() + ): + return x.redistribute(device_mesh=mesh, placements=[Replicate()]) + elif ( + not isinstance(x, DTensor) + and isinstance(x, torch.Tensor) + and x.numel() >= mesh.size() + ): + return ( + DTensor.from_local( + x, device_mesh=mesh, placements=[Shard(input_reshard_dim)] + ) + .redistribute(device_mesh=mesh, placements=[Replicate()]) + .to_local() + ) + else: + return x diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/loss.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..a27a5881bb6cb6b389232505ecb7daa6208a0c6b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/loss.py @@ -0,0 +1,490 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +import contextlib +from typing import cast, Optional + +import torch +import torch._prims_common as utils +import torch.distributed._functional_collectives as funcol +import torch.distributed.distributed_c10d as c10d +from torch import Tensor +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor import DTensor, Replicate, Shard +from torch.distributed.tensor._dtensor_spec import DTensorSpec, TensorMeta +from torch.distributed.tensor._ops._embedding_ops import _MaskPartial +from torch.distributed.tensor._ops._math_ops import ( + _skip_dim, + Reduction, + replicate_reduction_dims, +) +from torch.distributed.tensor._ops.utils import normalize_dim +from torch.distributed.tensor.placement_types import Placement + + +aten = torch.ops.aten + + +__all__ = ["loss_parallel"] + + +@contextlib.contextmanager +def loss_parallel(): + """ + A context manager that enables loss parallelism, where efficient parallelized loss computation + can be performed when the input is sharded on the class dimension. Currently only the cross-entropy + loss is supported. + + Within this context manager, one can use :func:`~torch.nn.functional.cross_entropy` or + :class:`~torch.nn.CrossEntropyLoss` as usual, with the following assumptions on the input parameters. + The corresponding ``backward()`` call, if any, also needs to happen under this context manager. + + Args: + input (:class:`DTensor`): + Input logits. Assumed to be sharded on the class dimension. + target (Union[:class:`torch.Tensor`, :class:`DTensor`]): + Must be ground truth class indices (class probabilities currently not supported). + Assumed to be replicated across the ``DeviceMesh``. + weight (Union[:class:`torch.Tensor`, :class:`DTensor`], optional): + If given, assumed to be replicated across the ``DeviceMesh``. + label_smoothing: + Currently not supported. + + Returns: + A replicated :class:`DTensor`. + + Example: + A sharded DTensor is manually created here to showcase the usage. + In practice, it is usually the output of a TP module. + + >>> # xdoctest: +SKIP("distributed") + >>> from torch.distributed.tensor.parallel import loss_parallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> device_mesh = init_device_mesh("cuda", (8,)) + >>> input = torch.randn(4, 16, device="cuda", requires_grad=True) + >>> dist_input = distribute_tensor(input, device_mesh, placements=[Shard(1)]) + >>> target = torch.randint(16, (4,), device="cuda") + >>> with loss_parallel(): + >>> loss = F.cross_entropy(dist_input, target, reduction="mean") + >>> loss.backward() + >>> ... + """ + _enable_custom_loss_ops() + + yield + + _disable_custom_loss_ops() + + +# Currently only needs to support one dimensional DeviceMesh; in general return +# the mesh_dim with placements[mesh_dim].is_shard(dim) +def _find_all_reduce_mesh_dim(placements: tuple[Placement, ...], dim: int) -> int: + if not len(placements) == 1: + raise ValueError( + "Currently loss_parallel() only supports input on one-dimensional DeviceMesh." + ) + if not placements[0].is_shard(dim): + raise ValueError( + f"loss_parallel() should be enabled only when the input tensor is sharded on dimension {dim}." + ) + return 0 + + +def _cast_to_dtensor( + tensor, placements: tuple[Placement, ...], mesh: DeviceMesh +) -> DTensor: + if isinstance(tensor, DTensor): + if tensor.placements == placements: + return tensor + else: + raise RuntimeError(f"Expected {placements} but got {tensor.placements}.") + elif isinstance(tensor, torch.Tensor): + return DTensor.from_local( + tensor, device_mesh=mesh, placements=placements, run_check=False + ) + else: + raise TypeError(f"Unsupported type {type(tensor)}") + + +def _propagate_tensor_meta( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> TensorMeta: + op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs) + tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta( + op_info.schema + ) + if isinstance(tensor_meta, TensorMeta): + return tensor_meta + elif isinstance(tensor_meta, tuple): + return tensor_meta[0] + else: + raise RuntimeError(f"Unexpected tensor meta type: {type(tensor_meta)}.") + + +# NOTE: The implementation follows torch._decomp.decomposition._log_softmax, +# with all_reduce manually inserted to perform distributed computation. +def _log_softmax(x, dim, half_to_float, mesh, mesh_dim): + if half_to_float: + assert x.dtype == torch.half + computation_dtype, result_dtype = utils.elementwise_dtypes( + x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT + ) + x = x.to(dtype=computation_dtype, memory_format=torch.contiguous_format) + if x.numel() == 0: + shifted = x + else: + x_max = torch.amax(x, dim, keepdim=True) + x_max = funcol.all_reduce( + x_max, reduceOp=c10d.ReduceOp.MAX.name, group=(mesh, mesh_dim) + ) + shifted = x - x_max + shifted_sumexp = torch.sum(torch.exp(shifted), dim, keepdim=True) + shifted_sumexp = funcol.all_reduce( + shifted_sumexp, reduceOp=c10d.ReduceOp.SUM.name, group=(mesh, mesh_dim) + ) + shifted_logsumexp = torch.log(shifted_sumexp) + result = shifted - shifted_logsumexp + if not half_to_float: + result = result.to(result_dtype) + return result + + +def _log_softmax_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + dim = cast(int, args[1]) + half_to_float = cast(bool, args[2]) + + spec = x._spec + dim = normalize_dim(dim, x.dim()) + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, dim) + + output_tensor_meta = _propagate_tensor_meta(op_call, args, kwargs) + + res = _log_softmax(x._local_tensor, dim, half_to_float, spec.mesh, mesh_dim) + + res_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=output_tensor_meta, + ) + + return DTensor( + res, + res_spec, + requires_grad=res.requires_grad, + ) + + +# NOTE: As explained below at _nll_loss_and_log_softmax_backward, the +# _log_softmax_backward_handler does not actually do any computation. +def _log_softmax_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + grad_output = cast(DTensor, args[0]) + input_dtype = cast(torch.dtype, args[3]) + return grad_output.to(input_dtype) + + +# NOTE: The implementation follows torch._decomp.decomposition._nll_loss_forward, +# with customized communication inserted to perform distributed computation. +def _nll_loss_forward( + x: Tensor, + target: Tensor, + weight: Optional[Tensor], + local_weight: Optional[Tensor], + reduction: int, + ignore_index: int, + input_shape: torch.Size, + channel_dim: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> tuple[Tensor, Tensor]: + n_dims = x.dim() + channel_dim = 1 + if n_dims < 2: + channel_dim = 0 + + def _weight_view(weight: Tensor) -> Tensor: + if n_dims > 1: + shape = [ + 1, + ] * n_dims + shape[channel_dim] = weight.shape[0] + w = weight.view(shape) + else: + w = weight + return w + + if weight is not None: + w = _weight_view(weight) + assert local_weight is not None + local_w = _weight_view(local_weight) + x = x * local_w + safe_target = torch.where(target != ignore_index, target, 0) + safe_target_ = safe_target.unsqueeze(channel_dim) + + # The following code block is a distributed version of + # result = -torch.gather(self, channel_dim, safe_target_).squeeze(channel_dim) + partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) + safe_target_partial_ = partial_placement._partition_value( + safe_target_, mesh, mesh_dim + ) + result_partial = torch.gather(x, channel_dim, safe_target_partial_) + # an all_reduce happens here + result_reduced = partial_placement._reduce_value(result_partial, mesh, mesh_dim) + result = -result_reduced.squeeze(channel_dim) + + result = torch.where(target != ignore_index, result, 0) + + if reduction == Reduction.NONE.value and n_dims > 1: + total_weight = x.new_full((), 0.0) + return result, total_weight + + if weight is not None: + new_shape = list(x.shape) + new_shape[channel_dim] = -1 + w = w.expand(new_shape) + wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) + wsum = torch.where(target != ignore_index, wsum, 0) + total_weight = wsum.sum() + else: + total_weight = (target != ignore_index).sum().to(x) + + # NOTE: this is correct only on 1D DeviceMesh; o/w additional + # all-reduce on result and total_weight is needed + if reduction == Reduction.SUM.value: + result = result.sum() + elif reduction == Reduction.MEAN.value: + result = result.sum() / total_weight + + return result, total_weight + + +def _nll_loss_forward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + x = cast(DTensor, args[0]) + target = args[1] + weight = args[2] + reduction = cast(int, args[3]) + ignore_index = cast(int, args[4]) + + channel_dim = 1 if x.dim() >= 2 else 0 + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) + + # Check user input: if target and weight are not DTensors, convert them to DTensors; + # if they are DTensors, check that they have the desired placements. + target_placements = _skip_dim( + replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim + ) + all_replicate_placements = (Replicate(),) * spec.mesh.ndim + target = _cast_to_dtensor(target, target_placements, spec.mesh) + local_weight = None + if weight is not None: + weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) + # For local computation, both (replicated) weight and (sharded) local_weight + # are needed in _nll_loss_forward(). local_weight is generated here using + # DTensor API, without incurring any communication. + sharded_placements = [ + Shard(0) if i == mesh_dim else Replicate() for i in range(spec.mesh.ndim) + ] + local_weight = weight.redistribute(spec.mesh, sharded_placements)._local_tensor + assert local_weight.shape[0] == x._local_tensor.shape[channel_dim] + + if reduction == Reduction.NONE.value: + output_placements = target_placements + else: + output_placements = all_replicate_placements + + # tensor inputs to _propagate_tensor_meta need to be DTensors + args = list(args) + args[1], args[2] = target, weight + output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) + + result, total_weight = _nll_loss_forward( + x._local_tensor, + target._local_tensor, + weight._local_tensor if weight is not None else None, + local_weight, + reduction, + ignore_index, + x.shape, + channel_dim, + spec.mesh, + mesh_dim, + ) + out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta) + + return ( + DTensor( + result, + out_spec, + requires_grad=result.requires_grad, + ), + total_weight, + ) + + +# NOTE: The backward computation of cross_entropy goes through two steps: +# backward for nll_loss and then backward for log_softmax. In loss parallel, +# the two steps are fused into the following function (called by _nll_loss_backward_handler) +# to avoid communication when target contains class indices not class probabilities. +# Also note that the _log_softmax_backward_handler does not perform computation. +# The implementation resembles _nll_loss_backward and _log_softmax_backward_data +# from torch._decomp.decomposition. +def _nll_loss_and_log_softmax_backward( + grad_output: Tensor, + x: Tensor, + target: Tensor, + weight: Optional[Tensor], + reduction: int, + ignore_index: int, + total_weight: Tensor, + input_shape: torch.Size, + channel_dim: int, + mesh: DeviceMesh, + mesh_dim: int, +) -> Tensor: + channel_dim = 0 if x.dim() < 2 else 1 + if reduction == Reduction.MEAN.value: + grad_output = grad_output / total_weight + + target = target.unsqueeze(channel_dim) + safe_target = torch.where(target != ignore_index, target, 0) + grad_input = torch.zeros_like(x) + + # The following code block is a distributed version of + # grad_input = torch.scatter(grad_input, channel_dim, safe_target, -1.0) + partial_placement = _MaskPartial(offset_shape=input_shape, offset_dim=channel_dim) + safe_target = safe_target.squeeze(channel_dim).flatten() + masked_safe_target = partial_placement._partition_value(safe_target, mesh, mesh_dim) + # only update grad_input to -1 if not masked + assert partial_placement.mask_buffer.data is not None + grad_update = partial_placement.mask_buffer.data.to(grad_input.dtype) - 1.0 + arange_1d = torch.arange( + masked_safe_target.shape[0], device=masked_safe_target.device + ) + # The first two cases with x.dim() <= 2 are for aten.nll_loss_backward.default; + # the last case is for aten.nll_loss2d_backward.default. + if x.dim() == 1: + grad_input[masked_safe_target] = grad_update + elif x.dim() == 2: + grad_input[arange_1d, masked_safe_target] = grad_update + else: + grad_input_t = grad_input.transpose(channel_dim, -1) + intermidate_shape = grad_input_t.shape + grad_input_2d = grad_input_t.reshape(-1, x.shape[channel_dim]) + grad_input_2d[arange_1d, masked_safe_target] = grad_update + grad_input = grad_input_2d.view(intermidate_shape).transpose(channel_dim, -1) + + if grad_input.dim() > grad_output.dim() > 0: + grad_output = grad_output.unsqueeze(channel_dim) + + if weight is not None: + new_shape = [1 for _ in range(x.dim())] + new_shape[channel_dim] = weight.shape[0] + weight = weight.reshape(new_shape) + # In order for fused computation to work, the following line is rewritten. + # grad_output = grad_output * weight + new_shape = list(x.shape) + new_shape[channel_dim] = -1 + w = weight.expand(new_shape) + w_target = torch.gather(w, channel_dim, target) + grad_output = grad_output * w_target + + grad_output = torch.where(target != ignore_index, grad_output, 0) + + # NOTE: Instead of directly returning the grad_input as grad_output for log_softmax, + # here we perform backward computation for log_softmax altogether to avoid the + # otherwise extra all_gather communication. + # return grad_input * grad_output + return (grad_input + torch.exp(x)) * grad_output + + +def _nll_loss_backward_handler( + op_call: torch._ops.OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + grad_output = cast(DTensor, args[0]) + x = cast(DTensor, args[1]) + target = args[2] + weight = args[3] + reduction = cast(int, args[4]) + ignore_index = cast(int, args[5]) + total_weight = cast(Tensor, args[6]) + + channel_dim = 1 if x.dim() >= 2 else 0 + spec = x._spec + mesh_dim = _find_all_reduce_mesh_dim(spec.placements, channel_dim) + + # if target and weight are not DTensors, convert them to DTensors + target_placements = _skip_dim( + replicate_reduction_dims(spec.placements, [channel_dim]), channel_dim + ) + all_replicate_placements = (Replicate(),) * spec.mesh.ndim + target = _cast_to_dtensor(target, target_placements, spec.mesh) + if weight is not None: + weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) + + # tensor inputs to _propagate_tensor_meta need to be DTensors + args = list(args) + args[2], args[3] = target, weight + args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh) + output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) + + result = _nll_loss_and_log_softmax_backward( + grad_output._local_tensor, + x._local_tensor, + target._local_tensor, + weight._local_tensor if weight is not None else None, + reduction, + ignore_index, + total_weight, + x.shape, + channel_dim, + spec.mesh, + mesh_dim, + ) + # the output sharding is the same as input sharding: Shard(channel_dim) on mesh_dim + out_spec = DTensorSpec( + spec.mesh, + spec.placements, + tensor_meta=output_tensor_meta, + ) + + return DTensor( + result, + out_spec, + requires_grad=result.requires_grad, + ) + + +customized_loss_ops = { + aten._log_softmax.default: _log_softmax_handler, + aten._log_softmax_backward_data.default: _log_softmax_backward_handler, + aten.nll_loss_forward.default: _nll_loss_forward_handler, + aten.nll_loss2d_forward.default: _nll_loss_forward_handler, + aten.nll_loss_backward.default: _nll_loss_backward_handler, + aten.nll_loss2d_backward.default: _nll_loss_backward_handler, +} + + +def _enable_custom_loss_ops(): + DTensor._op_dispatcher._custom_op_handlers.update(customized_loss_ops) + + +def _disable_custom_loss_ops(): + for custom_op in customized_loss_ops: + DTensor._op_dispatcher._custom_op_handlers.pop(custom_op) diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/style.py b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/style.py new file mode 100644 index 0000000000000000000000000000000000000000..17b542f60819e7f336daa223012d5afdc5bb8fa2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/parallel/style.py @@ -0,0 +1,812 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates +from abc import ABC, abstractmethod +from functools import partial +from typing import Any, Optional, Union + +import torch +import torch.nn as nn +from torch.distributed.tensor import ( + DeviceMesh, + distribute_module, + distribute_tensor, + DTensor, + Replicate, + Shard, +) +from torch.distributed.tensor.placement_types import Placement + + +__all__ = [ + "ParallelStyle", + "RowwiseParallel", + "SequenceParallel", + "ColwiseParallel", + "PrepareModuleInput", + "PrepareModuleInputOutput", + "PrepareModuleOutput", +] + + +class ParallelStyle(ABC): + """ + The parallel style contract defines how the module or submodule should be parallelized. + + It only defines the ``apply`` method for ``parallelize_module`` to use, this allows maximum + flexibility for different kind of style implementations. + """ + + src_data_rank: Optional[int] = 0 + + @abstractmethod + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: ... + + +class ColwiseParallel(ParallelStyle): + """ + Partition a compatible nn.Module in a column-wise fashion. Currently supports nn.Linear and nn.Embedding. + Users can compose it together with RowwiseParallel to achieve the sharding of more complicated modules. + (i.e. MLP, Attention) + + Keyword Args: + input_layouts (Placement, optional): + The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to + become a DTensor. If not specified, we assume the input tensor to be replicated. + output_layouts (Placement, optional): + The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module + with the user desired layout. If not specified, the output tensor is sharded on the last dimension. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. + Returns: + A :class:`ParallelStyle` object that represents Colwise sharding of the nn.Module. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> m = Model(...) # m is a nn.Module that contains a "w1" nn.Linear submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # By default, the input of the "w1" Linear will be converted to Replicated DTensor + >>> # and the output of "w1" will return :class:`torch.Tensor` that shards on the last dim. + >>> + >>> sharded_mod = parallelize_module(m, tp_mesh, {"w1": ColwiseParallel()}) + >>> ... + + .. note:: By default ``ColwiseParallel`` output is sharded on the last dimension if the ``output_layouts`` not + specified, if there're operators that require specific tensor shape (i.e. before the paired ``RowwiseParallel``), + keep in mind that if the output is sharded the operator might need to be adjusted to the sharded size. + """ + + def __init__( + self, + *, + input_layouts: Optional[Placement] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = (input_layouts or Replicate(),) + self.output_layouts = (output_layouts or Shard(-1),) + # colwise linear runtime sharding (desired sharding): + # 1. requires replicate input + # 2. shard output on last dim + self.desired_input_layouts = (Replicate(),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + # TODO: figure out dynamo support for instance method and switch this to instance method + + # annotate module input placements/sharding with input_layouts + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) + + # transform the input layouts to the desired layouts of ColwiseParallel + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) + return input_tensor + + def _partition_linear_fn(self, name, module, device_mesh): + # colwise shard weight/bias to Shard(0), weight be Shard(0) + # means Colwise as Linear is input * weight^T + bias, where + # weight would become Shard(1) + for name, param in module.named_parameters(): + dist_param = nn.Parameter( + distribute_tensor( + param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank + ) + ) + module.register_parameter(name, dist_param) + + def _partition_embedding_fn(self, name, module, device_mesh): + # colwise shard embedding.weight is straight forward as Shard(1) + for name, param in module.named_parameters(): + dist_param = nn.Parameter( + distribute_tensor( + param, device_mesh, [Shard(1)], src_data_rank=self.src_data_rank + ) + ) + module.register_parameter(name, dist_param) + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # outputs is a shard on last dimension DTensor, i.e. Shard(-1) + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + # back to local tensor + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if isinstance(module, nn.Linear): + partition_fn = self._partition_linear_fn + elif isinstance(module, nn.Embedding): + partition_fn = self._partition_embedding_fn + else: + raise NotImplementedError( + "ColwiseParallel currently only support nn.Linear and nn.Embedding!" + ) + + return distribute_module( + module, + device_mesh, + partition_fn, + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), + ) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.input_layouts}, " + tmpstr += f"output_layouts={self.output_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class RowwiseParallel(ParallelStyle): + """ + Partition a compatible nn.Module in a row-wise fashion. Currently supports nn.Linear and nn.Embedding. + Users can compose it with ColwiseParallel to achieve the sharding of more complicated modules. + (i.e. MLP, Attention) + + Keyword Args: + input_layouts (Placement, optional): + The DTensor layout of input tensor for the nn.Module, this is used to annotate the input tensor to + become a DTensor. If not specified, we assume the input tensor to be sharded on the last dimension. + output_layouts (Placement, optional): + The DTensor layout of the output for the nn.Module, this is used to ensure the output of the nn.Module + with the user desired layout. If not specified, the output tensor is replicated. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: True. + Returns: + A :class:`ParallelStyle` object that represents Rowwise sharding of the nn.Module. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, RowwiseParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> m = Model(...) # m is a nn.Module that contains a "w2" nn.Linear submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # By default, the input of the "w2" Linear will be converted to DTensor that shards on the last dim + >>> # and the output of "w2" will return a replicated :class:`torch.Tensor`. + >>> + >>> sharded_mod = parallelize_module(m, tp_mesh, {"w2": RowwiseParallel()}), + >>> ... + """ + + def __init__( + self, + *, + input_layouts: Optional[Placement] = None, + output_layouts: Optional[Placement] = None, + use_local_output: bool = True, + ): + super().__init__() + self.input_layouts = (input_layouts or Shard(-1),) + self.output_layouts = (output_layouts or Replicate(),) + self.use_local_output = use_local_output + + @staticmethod + def _prepare_input_fn( + input_layouts, desired_input_layouts, mod, inputs, device_mesh + ): + input_tensor = inputs[0] + if not isinstance(input_tensor, DTensor): + input_tensor = DTensor.from_local( + input_tensor, device_mesh, input_layouts, run_check=False + ) + + if input_layouts != desired_input_layouts: + input_tensor = input_tensor.redistribute( + placements=desired_input_layouts, async_op=True + ) + return input_tensor + + def _partition_linear_fn(self, name, module, device_mesh): + # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1) + # means Rowwise as nn.Linear is input * weight^T + bias, where + # weight would become Shard(0) + module.register_parameter( + "weight", + nn.Parameter( + distribute_tensor( + module.weight, + device_mesh, + [Shard(1)], + src_data_rank=self.src_data_rank, + ) + ), + ) + if getattr(module, "bias", None) is not None: + # The Linear module has bias + module.register_parameter( + "bias", + nn.Parameter( + distribute_tensor( + module.bias, + device_mesh, + [Replicate()], + src_data_rank=self.src_data_rank, + ) + ), + ) + + def _partition_embedding_fn(self, name, module, device_mesh): + # rowwise shard embedding.weight is Shard(0) + for name, param in module.named_parameters(): + dist_param = nn.Parameter( + distribute_tensor( + param, device_mesh, [Shard(0)], src_data_rank=self.src_data_rank + ) + ) + module.register_parameter(name, dist_param) + + @staticmethod + def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh): + # Rowwise sharding produces partial output, depending on output layouts: + # 1. to replicate -> allreduce + # 2. to shard -> reduce_scatter + if outputs.placements != output_layouts: + outputs = outputs.redistribute(placements=output_layouts, async_op=True) + # back to local tensor if use_local_output is True + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if isinstance(module, nn.Linear): + partition_fn = self._partition_linear_fn + # rowwise linear runtime sharding requires input tensor shard on last dim + self.desired_input_layouts: tuple[Placement, ...] = (Shard(-1),) + elif isinstance(module, nn.Embedding): + partition_fn = self._partition_embedding_fn + # rowwise embedding runtime sharding requires input tensor replicated + self.desired_input_layouts = (Replicate(),) + else: + raise NotImplementedError( + "RowwiseParallel currently only support nn.Linear and nn.Embedding!" + ) + + return distribute_module( + module, + device_mesh, + partition_fn, + partial( + self._prepare_input_fn, self.input_layouts, self.desired_input_layouts + ), + partial( + self._prepare_output_fn, self.output_layouts, self.use_local_output + ), + ) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.input_layouts}, " + tmpstr += f"output_layouts={self.output_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class SequenceParallel(ParallelStyle): + """ + SequenceParallel replicates a compatible ``nn.Module`` parameters and runs the sharded computation with + input sharded on the sequence dimension. This currently supports ``nn.LayerNorm``, ``nn.Dropout``, and the + `RMSNorm python implementation `__ + + This style implements the operation that is described in the paper + `Reducing Activation Recomputation in Large Transformer Models `__ + + If the input passed in to this ``nn.Module`` is a :class:`torch.Tensor`, it assumes that the input is already sharded + on the sequence dimension and converts the input to a :class:`DTensor` sharded on the sequence dimension. If the input + passed in to this ``nn.Module`` is already a :class:`DTensor` but is not sharded on the sequence dimension, it would + redistribute the input to be sharded on the sequence dimension. + + The output of the ``nn.Module`` will be sharded on the sequence dimension. + + Keyword Args: + sequence_dim (int, optional): + The sequence dimension of the input tensor for the ``nn.Module``, this is used to annotate the input tensor to + become a DTensor that is sharded on the sequence dimension, default: 1. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module output, default: False. + Returns: + A :class:`ParallelStyle` object that represents Sequence Parallel of the ``nn.Module``. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, SequenceParallel + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> m = Model(...) # m is a nn.Module that contains a "norm" nn.LayerNorm submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # By default, the input of the "norm" will be converted to DTensor that shards on the sequence dim + >>> # and the output of "norm" will return a sharded on sequence dimension :class:`DTensor`. + >>> + >>> sharded_mod = parallelize_module(m, tp_mesh, {"norm": SequenceParallel()}), + >>> ... + + .. note:: SequenceParallel style assumes ones initialization if there are weights in the nn.Module (i.e. + ``nn.LayerNorm`` or ``RMSNorm``, and they by default have ones initialization). If you have custom + inits for the weights on those modules, you need to broadcast the weights before/after parallelizing + to ensure that they are replicated. + """ + + def __init__(self, *, sequence_dim: int = 1, use_local_output: bool = False): + super().__init__() + self.sequence_sharding = (Shard(sequence_dim),) + self.use_local_output = use_local_output + + def _replicate_module_fn( + self, name: str, module: nn.Module, device_mesh: DeviceMesh + ): + for p_name, param in module.named_parameters(): + # simple replication with fixed ones_ init from LayerNorm/RMSNorm, which allow + # us to simply just use from_local + replicated_param = torch.nn.Parameter( + DTensor.from_local(param, device_mesh, [Replicate()], run_check=False) + ) + module.register_parameter(p_name, replicated_param) + + @staticmethod + def _prepare_input_fn(sequence_sharding, mod, inputs, device_mesh): + input_tensor = inputs[0] + if isinstance(input_tensor, DTensor): + # if the passed in input DTensor is not sharded on the sequence dim, we need to redistribute it + if input_tensor.placements != sequence_sharding: + input_tensor = input_tensor.redistribute( + placements=sequence_sharding, async_op=True + ) + return input_tensor + elif isinstance(input_tensor, torch.Tensor): + # assume the input passed in already sharded on the sequence dim and create the DTensor + return DTensor.from_local( + input_tensor, device_mesh, sequence_sharding, run_check=False + ) + else: + raise ValueError( + f"expecting input of {mod} to be a torch.Tensor or DTensor, but got {input_tensor}" + ) + + @staticmethod + def _prepare_output_fn(use_local_output, mod, outputs, device_mesh): + return outputs.to_local() if use_local_output else outputs + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + return distribute_module( + module, + device_mesh, + self._replicate_module_fn, + partial(self._prepare_input_fn, self.sequence_sharding), + partial(self._prepare_output_fn, self.use_local_output), + ) + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + if len(self.sequence_sharding) == 1: + tmpstr += f"sequence_dim={self.sequence_sharding[0].dim}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class PrepareModuleInput(ParallelStyle): + """ + Configure the nn.Module's inputs to convert the input tensors of the nn.Module to DTensors at runtime according to + ``input_layouts``, and perform layout redistribution according to the ``desired_input_layouts``. + + Keyword Args: + input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to + DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified + as a placeholder. default: None. + desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. + input_kwarg_layouts (Dict[str, Placement]): + The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. + default: None + desired_input_kwarg_layouts: (Dict[str, Placement]): + The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. default: None. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. + Returns: + A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the first input of attn will be annotated to Sharded DTensor + >>> # and then redistributed to Replicated DTensor. + >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan={ + >>> "attn": PrepareModuleInput( + >>> input_layouts=(Shard(0), None, None, ...), + >>> desired_input_layouts=(Replicate(), None, None, ...) + >>> ), + >>> } + >>> ) + """ + + def __init__( + self, + *, + input_layouts: Optional[Union[Placement, tuple[Optional[Placement]]]] = None, + desired_input_layouts: Optional[ + Union[Placement, tuple[Optional[Placement]]] + ] = None, + input_kwarg_layouts: Optional[dict[str, Placement]] = None, + desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None, + use_local_output: bool = False, + ): + self.input_layouts = ( + (input_layouts,) if isinstance(input_layouts, Placement) else input_layouts + ) + self.desired_input_layouts = ( + (desired_input_layouts,) + if isinstance(desired_input_layouts, Placement) + else desired_input_layouts + ) + self.use_local_output = use_local_output + if self.input_layouts is not None: + assert self.desired_input_layouts is not None, ( + "desired module inputs should not be None!" + ) + assert len(self.input_layouts) == len(self.desired_input_layouts), ( + "input_layouts and desired_input_layouts should have same length!" + ) + self.with_kwargs = input_kwarg_layouts is not None + self.input_kwarg_layouts = input_kwarg_layouts or {} + self.desired_input_kwarg_layouts = desired_input_kwarg_layouts or {} + if self.with_kwargs: + assert len(self.input_kwarg_layouts) == len( + self.desired_input_kwarg_layouts + ), ( + "input_kwarg_layouts and desired_input_kwarg_layouts should have same length!" + ) + + def _prepare_input_arg( + self, + input: Any, + mesh: DeviceMesh, + input_layout: Optional[Placement], + desired_layout: Optional[Placement], + ): + if input_layout is not None: + if isinstance(input, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert inp.placements[0] == input_layout + dt_inp = input + else: + assert isinstance(input, torch.Tensor), ( + "expecting input to be a torch.Tensor!" + ) + dt_inp = DTensor.from_local( + input, mesh, (input_layout,), run_check=False + ) + + if desired_layout is not None and input_layout != desired_layout: + dt_inp = dt_inp.redistribute(placements=(desired_layout,)) + + return dt_inp.to_local() if self.use_local_output else dt_inp + else: + return input + + def _prepare_input_fn(self, inputs, device_mesh): + if self.input_layouts is None: + return inputs + prepared_inputs = [] + if not isinstance(inputs, tuple): + inputs = (inputs,) + if len(inputs) != len(self.input_layouts): + raise ValueError("module inputs and input_layouts should have same length!") + + assert self.desired_input_layouts is not None, ( + "desired module inputs should not be None!" + ) + for inp, input_layout, desired_layout in zip( + inputs, self.input_layouts, self.desired_input_layouts + ): + prepared_inputs.append( + self._prepare_input_arg(inp, device_mesh, input_layout, desired_layout) + ) + return tuple(prepared_inputs) + + def _prepare_input_kwarg_fn(self, inputs, kwarg_inputs, device_mesh): + prepared_arg_inputs = self._prepare_input_fn(inputs, device_mesh) + prepared_kwarg_inputs = {} + for kwarg_key in kwarg_inputs.keys(): + kwarg_val = kwarg_inputs[kwarg_key] + input_layout = self.input_kwarg_layouts.get(kwarg_key) + desired_input_layout = self.desired_input_kwarg_layouts.get(kwarg_key) + + prepared_kwarg_inputs[kwarg_key] = self._prepare_input_arg( + kwarg_val, device_mesh, input_layout, desired_input_layout + ) + + return (prepared_arg_inputs, prepared_kwarg_inputs) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + if self.with_kwargs: + module.register_forward_pre_hook( + lambda _, inputs, kwargs: self._prepare_input_kwarg_fn( + inputs, kwargs, device_mesh + ), + with_kwargs=True, + ) # type: ignore[misc] + else: + module.register_forward_pre_hook( + lambda _, inputs: self._prepare_input_fn(inputs, device_mesh) + ) # type: ignore[misc, call-arg] + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.input_layouts}, " + tmpstr += f"desired_input_layouts={self.desired_input_layouts}, " + tmpstr += f"input_kwarg_layouts={self.input_kwarg_layouts}, " + tmpstr += f"desired_input_kwarg_layouts={self.desired_input_kwarg_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class PrepareModuleOutput(ParallelStyle): + """ + Configure the nn.Module's outputs to convert the output tensors of the nn.Module to DTensors at runtime according to + ``output_layouts``, and perform layout redistribution according to the ``desired_output_layouts``. + + Keyword Args: + output_layouts (Union[Placement, Tuple[Placement]]): + The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to + DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, + ``None`` need to be specified as a placeholder. + desired_output_layouts (Union[Placement, Tuple[Placement]]): + The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module + have the desired DTensor layouts. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. + Returns: + A ParallelStyle object that prepares the sharding layouts of the nn.Module's outputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleOutput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the output of the TransformerBlock will be converted to Replicated DTensor + >>> # and then redistributed to Sharded DTensor. + >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan = PrepareModuleOutput( + >>> output_layouts=Replicate(), + >>> desired_output_layouts=Shard(0) + >>> ) + >>> ) + """ + + def __init__( + self, + *, + output_layouts: Union[Placement, tuple[Placement]], + desired_output_layouts: Union[Placement, tuple[Placement]], + use_local_output: bool = True, + ): + self.output_layouts = ( + (output_layouts,) + if isinstance(output_layouts, Placement) + else output_layouts + ) + self.desired_output_layouts = ( + (desired_output_layouts,) + if isinstance(desired_output_layouts, Placement) + else desired_output_layouts + ) + self.use_local_output = use_local_output + assert len(self.output_layouts) == len(self.desired_output_layouts), ( + "output_layouts and desired_output_layouts should have same length!" + ) + + def _prepare_out_fn(self, outputs, device_mesh): + prepared_outputs = [] + if not isinstance(outputs, tuple): + outputs = (outputs,) + if len(outputs) != len(self.output_layouts): + raise ValueError( + "module outputs and output_layouts should have same length!" + ) + for out, out_layout, desired_out_layout in zip( + outputs, self.output_layouts, self.desired_output_layouts + ): + if out_layout is not None: + if isinstance(out, DTensor): + # TODO: re-enable the check once we fix the compile path + # assert out.placements[0] == out_layout + dt_out = out + else: + dt_out = DTensor.from_local( + out, device_mesh, (out_layout,), run_check=False + ) + + if out_layout != desired_out_layout: + dt_out = dt_out.redistribute(placements=(desired_out_layout,)) + prepared_outputs.append( + dt_out.to_local() if self.use_local_output else dt_out + ) + else: + prepared_outputs.append(out) + if len(prepared_outputs) == 1: + return prepared_outputs[0] + else: + return tuple(prepared_outputs) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + module.register_forward_hook( + lambda _, inputs, outputs: self._prepare_out_fn(outputs, device_mesh) + ) # type: ignore[misc, call-arg] + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"output_layouts={self.output_layouts}, " + tmpstr += f"desired_output_layouts={self.desired_output_layouts}, " + tmpstr += f"use_local_output={self.use_local_output}" + tmpstr += ")" + return tmpstr + + +class PrepareModuleInputOutput(ParallelStyle): + """ + Configure the nn.Module's inputs (and outputs) to convert the input tensors (and output tensors, respectively) of the nn.Module + to DTensors at runtime according to ``input_layouts`` (and output_layouts, respectively), and perform layout redistribution + according to the ``desired_input_layouts`` (and ``desired_output_layouts``, respectively). This is a combination of + :class:`PrepareModuleInput` and :class:`PrepareModuleOutput`. + + Keyword Args: + input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The DTensor layouts of input tensors for the nn.Module, this is used to convert the input tensors to + DTensors. If some inputs are not torch.Tensor or no need to convert to DTensors, ``None`` need to be specified + as a placeholder. default: None. + desired_input_layouts (Union[Placement, Tuple[Optional[Placement]]]): + The desired DTensor layout of input tensors for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. This argument needs to have the same length with ``input_layouts``. default: None. + input_kwarg_layouts (Dict[str, Placement]): + The DTensor layouts of input kwargs for the nn.Module, this is used to convert the input kwarg tensors to DTensors. + default: None + desired_input_kwarg_layouts: (Dict[str, Placement]): + The desired DTensor layout of input kwargs for the nn.Module, this is used to ensure the inputs of the nn.Module + have the desired DTensor layouts. default: None. + use_local_input (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module inputs, default: False. + output_layouts (Union[Placement, Tuple[Placement]]): + The DTensor layouts of output tensors for the nn.Module, this is used to convert the output tensors to + DTensors if they are :class:`torch.Tensor`. If some outputs are not torch.Tensor or no need to convert to DTensors, + ``None`` need to be specified as a placeholder. + desired_output_layouts (Union[Placement, Tuple[Placement]]): + The desired DTensor layouts of output tensors for the nn.Module, this is used to ensure the outputs of the nn.Module + have the desired DTensor layouts. + use_local_output (bool, optional): + Whether to use local :class:`torch.Tensor` instead of :class:`DTensor` for the module outputs, default: True. + Returns: + A :class:`ParallelStyle` object that prepares the sharding layouts of the nn.Module's inputs and outputs. + + Example:: + >>> # xdoctest: +SKIP(failing) + >>> from torch.distributed.tensor.parallel import parallelize_module, PrepareModuleInputOutput + >>> from torch.distributed.device_mesh import init_device_mesh + >>> ... + >>> block = TransformerBlock(...) # block is a nn.Module that contains an "attn" Attention submodule + >>> tp_mesh = init_device_mesh("cuda", (8,)) + >>> + >>> # According to the style specified below, the first input of attn will be annotated as Sharded DTensor + >>> # and then redistributed to Replicated DTensor, and the output of the TransformerBlock will be annotated + >>> # as Replicated DTensor and then redistributed to Sharded DTensor. + >>> parallelize_module( + >>> block, # this can be a submodule or module + >>> tp_mesh, + >>> parallelize_plan={ + >>> "attn": PrepareModuleInputOutput( + >>> input_layouts=(Shard(0), None, None, ...), + >>> desired_input_layouts=(Replicate(), None, None, ...), + >>> output_layouts=Replicate(), + >>> desired_output_layouts=Shard(0), + >>> ), + >>> } + >>> ) + """ + + def __init__( + self, + *, + input_layouts: Optional[Union[Placement, tuple[Optional[Placement]]]] = None, + desired_input_layouts: Optional[ + Union[Placement, tuple[Optional[Placement]]] + ] = None, + input_kwarg_layouts: Optional[dict[str, Placement]] = None, + desired_input_kwarg_layouts: Optional[dict[str, Placement]] = None, + use_local_input: bool = False, + output_layouts: Union[Placement, tuple[Placement]], + desired_output_layouts: Union[Placement, tuple[Placement]], + use_local_output: bool = True, + ): + self.prepare_module_input = PrepareModuleInput( + input_layouts=input_layouts, + desired_input_layouts=desired_input_layouts, + input_kwarg_layouts=input_kwarg_layouts, + desired_input_kwarg_layouts=desired_input_kwarg_layouts, + use_local_output=use_local_input, + ) + self.prepare_module_output = PrepareModuleOutput( + output_layouts=output_layouts, + desired_output_layouts=desired_output_layouts, + use_local_output=use_local_output, + ) + + def _apply(self, module: nn.Module, device_mesh: DeviceMesh) -> nn.Module: + self.prepare_module_input._apply(module, device_mesh) + self.prepare_module_output._apply(module, device_mesh) + + return module + + def __repr__(self) -> str: + tmpstr = self.__class__.__name__ + "(" + tmpstr += f"input_layouts={self.prepare_module_input.input_layouts}, " + tmpstr += ( + f"desired_input_layouts={self.prepare_module_input.desired_input_layouts}, " + ) + tmpstr += ( + f"input_kwarg_layouts={self.prepare_module_input.input_kwarg_layouts}, " + ) + tmpstr += f"desired_input_kwarg_layouts={self.prepare_module_input.desired_input_kwarg_layouts}, " + tmpstr += f"use_local_input={self.prepare_module_input.use_local_output}, " + tmpstr += f"output_layouts={self.prepare_module_output.output_layouts}, " + tmpstr += f"desired_output_layouts={self.prepare_module_output.desired_output_layouts}, " + tmpstr += f"use_local_output={self.prepare_module_output.use_local_output}" + tmpstr += ")" + return tmpstr diff --git a/phivenv/Lib/site-packages/torch/distributed/tensor/placement_types.py b/phivenv/Lib/site-packages/torch/distributed/tensor/placement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..58025e2ff25d74e15681d4ba015d551789806e33 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributed/tensor/placement_types.py @@ -0,0 +1,732 @@ +# mypy: allow-untyped-defs +# Copyright (c) Meta Platforms, Inc. and affiliates + +from dataclasses import dataclass +from typing import cast, Optional + +import torch +import torch.distributed._functional_collectives as funcol +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.tensor._collective_utils import ( + fill_empty_tensor_to_shards, + mesh_broadcast, + mesh_scatter, + pad_tensor, + shard_dim_alltoall, + unpad_tensor, +) + + +__all__ = ["Placement", "Shard", "Replicate", "Partial"] + + +class Placement: + """ + The base class for the Placement type, where it describes how a DTensor is placed onto the + ``DeviceMesh``. ``Placement`` and ``DeviceMesh`` together could describe the DTensor Layout. + It is the base class of the three main DTensor Placement types: ``Shard``, ``Replicate``, + and ``Partial``. + + This class is not meant to be used directly, mainly served as a typing stub. + """ + + # convenient utils to check for placement types + def is_shard(self, dim: Optional[int] = None) -> bool: + is_shard_instance = isinstance(self, Shard) + if dim is not None and is_shard_instance: + return cast(Shard, self).dim == dim + else: + return is_shard_instance + + def is_replicate(self) -> bool: + return isinstance(self, Replicate) + + def is_partial(self) -> bool: + return isinstance(self, Partial) + + +@dataclass(frozen=True) +class Shard(Placement): + """ + The ``Shard(dim)`` placement describes the DTensor sharding on tensor dimension + ``dim`` over a corresponding ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension only holds a shard/piece of the global Tensor. The + ``Shard(dim)`` placement follows the ``torch.chunk(dim)`` semantic, where the + last few shards on the DeviceMesh dimension might be empty when the tensor dimension + is not evenly divisible on the DeviceMesh dimension. The ``Shard`` placement can be + used by all DTensor APIs (i.e. distribute_tensor, from_local, etc.) + + Args: + dim (int): The tensor dimension that describes the DTensor is sharded over its + corresponding DeviceMesh dimension. + + .. warning:: sharding on a tensor dimension where the tensor dimension size is not + evenly divisible on a DeviceMesh dimension is currently experimental and subject to change. + """ + + dim: int + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> tuple[list[torch.Tensor], list[int]]: + """ + This function uses torch.chunk to split a tensor into num_chunks shards along + the Shard placement dimension, and return a list of shards with their pad sizes. + + Keyword args: + with_padding (bool, optional): when True, we pad the tensor on the last + few ranks before calling the collectives (i.e. scatter/all_gather, etc.). + This is because collectives usually require equal size tensor inputs + """ + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) + + # chunk tensor over dimension `dim` into n slices + tensor_list = list(torch.chunk(tensor, num_chunks, dim=self.dim)) + tensor_list = fill_empty_tensor_to_shards( + tensor_list, self.dim, num_chunks - len(tensor_list) + ) + + # compute the chunk size inline with ``torch.chunk`` to calculate padding + full_chunk_size = (tensor.size(self.dim) + num_chunks - 1) // num_chunks + + shard_list: list[torch.Tensor] = [] + pad_sizes: list[int] = [] + for shard in tensor_list: + if with_padding: + pad_size = full_chunk_size - shard.size(self.dim) + shard = pad_tensor(shard, self.dim, pad_size) + pad_sizes.append(pad_size) + if contiguous: + shard = shard.contiguous() + shard_list.append(shard) + return shard_list, pad_sizes + + @staticmethod + def _local_shard_size_and_offset( + curr_local_size: int, + num_chunks: int, + rank: int, + ) -> tuple[int, int]: + """ + Given the size of the current local tensor (which may already be sharded on some dimensions), + computes the new local shard size and offset given the desired number of chunks + (num_chunks is generally equal to the size of the current sharding dim). + + Note: new local shard offset is relative to the current sharded tensor, not the global tensor. + See `_utils.compute_local_shape_and_global_offset` for computing global offset. + + Returns (new local shard size, offset) + + """ + # Compute the chunk size inline with ``torch.chunk`` + if curr_local_size % num_chunks == 0: + full_chunk_size = curr_local_size // num_chunks + return full_chunk_size, full_chunk_size * rank + + # uneven sharding case + full_chunk_size = (curr_local_size + num_chunks - 1) // num_chunks + shard_starting_idx = full_chunk_size * rank + + if curr_local_size < shard_starting_idx: + return 0, curr_local_size + else: + local_shard_size = ( + min(curr_local_size, shard_starting_idx + full_chunk_size) + - shard_starting_idx + ) + return local_shard_size, shard_starting_idx + + def _shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: Optional[int] = 0, + ) -> torch.Tensor: + """ + shard and scatter a tensor on a mesh dimension (use coordinate + 0 on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + mesh_dim_local_rank = my_coordinate[mesh_dim] + + if src_data_rank is None: + # src_data_rank specified as None explicitly means to skip the + # communications, simply split + scatter_list, _ = self._split_tensor( + tensor, num_chunks, with_padding=False, contiguous=True + ) + + return scatter_list[mesh_dim_local_rank] + + scatter_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + output = torch.empty_like(scatter_list[mesh_dim_local_rank]) + + # perform scatter from the src_data_rank as data source when it is not None + mesh_scatter( + output, scatter_list, mesh, mesh_dim=mesh_dim, group_src=src_data_rank + ) + + # Only unpad if the local_tensor was padded on the dimension. + if pad_sizes[mesh_dim_local_rank] > 0: + output = unpad_tensor(output, self.dim, pad_sizes[mesh_dim_local_rank]) + # Unpad might return a view, hence we need to remake it contiguous + output = output.contiguous() + return output + + def _reduce_shard_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + reduce_op: str, + mesh_dim: int, + ) -> torch.Tensor: + """ + reduce and scatter a tensor on a mesh dimension + """ + my_coordinate = mesh.get_coordinate() + num_chunks = mesh.size(mesh_dim=mesh_dim) + + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return tensor + + is_padded = tensor.size(self.dim) % num_chunks != 0 + if is_padded: + scattered_list, pad_sizes = self._split_tensor( + tensor, num_chunks, with_padding=True, contiguous=True + ) + tensor = torch.cat(scattered_list, dim=self.dim) + elif not tensor.is_contiguous(): + tensor = tensor.contiguous() + + output = funcol.reduce_scatter_tensor( + tensor, reduce_op, scatter_dim=self.dim, group=(mesh, mesh_dim) + ) + + if is_padded: + output = unpad_tensor(output, self.dim, pad_sizes[my_coordinate[mesh_dim]]) # type: ignore[possibly-undefined] + return output + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + ) -> torch.Tensor: + """ + This function all_gather all shards and return a tensor that + is replicated on the previously sharded mesh dimension + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + + logical_dim_size = current_logical_shape[self.dim] + is_padded = logical_dim_size % num_chunks != 0 + + if is_padded: + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + pad_size = full_chunk_size - local_tensor.size(self.dim) + local_tensor = pad_tensor(local_tensor, self.dim, pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if is_padded: + unpad_size = full_chunk_size * num_chunks - logical_dim_size # type: ignore[possibly-undefined] + result = unpad_tensor(result, self.dim, unpad_size) + return result + + def _replicate_to_shard( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_index: int, + ) -> torch.Tensor: + """ + transform from replicated tensor to a sharded tensor on + the current rank, which would perform a local chunk + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + shards, _ = self._split_tensor( + local_tensor, + num_chunks, + with_padding=False, + contiguous=False, + ) + return shards[shard_index].clone() + + def _to_new_shard_dim( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + new_shard_dim: int, + ) -> torch.Tensor: + """ + transform from existing sharded tensor to a new sharded tensor on + that shard on a new dimension, which performs an alltoall + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return local_tensor, + # which should be an empty tensor + return local_tensor + + num_chunks = mesh.size(mesh_dim=mesh_dim) + + old_dim_logical_size = current_logical_shape[self.dim] + new_dim_logical_size = current_logical_shape[new_shard_dim] + old_dim_padding = old_dim_logical_size % num_chunks != 0 + new_dim_padding = new_dim_logical_size % num_chunks != 0 + if old_dim_padding: + old_dim_full_chunk_size = ( + old_dim_logical_size + num_chunks - 1 + ) // num_chunks + old_dim_pad_size = old_dim_full_chunk_size - local_tensor.size(self.dim) + local_tensor = pad_tensor(local_tensor, self.dim, old_dim_pad_size) + if new_dim_padding: + new_dim_full_chunk_size = ( + new_dim_logical_size + num_chunks - 1 + ) // num_chunks + new_dim_pad_size = new_dim_full_chunk_size * num_chunks - local_tensor.size( + new_shard_dim + ) + local_tensor = pad_tensor(local_tensor, new_shard_dim, new_dim_pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + new_tensor = shard_dim_alltoall( + local_tensor, self.dim, new_shard_dim, mesh, mesh_dim + ) + + if old_dim_padding: + old_dim_unpad_size = ( + old_dim_full_chunk_size * num_chunks - current_logical_shape[self.dim] # type: ignore[possibly-undefined] + ) + new_tensor = unpad_tensor(new_tensor, self.dim, old_dim_unpad_size) # type: ignore[possibly-undefined] + + if new_dim_padding: + local_shard_size_on_new_dim = self._local_shard_size_and_offset( + new_dim_logical_size, num_chunks, my_coordinate[mesh_dim] + )[0] + new_dim_unpad_size = new_dim_full_chunk_size - local_shard_size_on_new_dim # type: ignore[possibly-undefined] + new_tensor = unpad_tensor(new_tensor, new_shard_dim, new_dim_unpad_size) # type: ignore[possibly-undefined] + + return new_tensor + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Shard): + return False + return self.dim == other.dim + + def __hash__(self) -> int: + return hash(self.dim) + + def __repr__(self) -> str: + """ + machine readable representation of the Shard placement + """ + return f"Shard(dim={self.dim})" + + def __str__(self) -> str: + """human readable representation of the Shard placement""" + return f"S({self.dim})" + + +# kw_only is only available in python >= 3.10 +kw_only_dataclass = dict(kw_only=True) if "kw_only" in dataclass.__kwdefaults__ else {} + + +@dataclass(frozen=True, **kw_only_dataclass) +class _StridedShard(Shard): + """ + _StridedShard is only introduced to support 2D FSDP2 + TP sharding where the tensor + is sharded on the TP mesh dimension first, then sharded on the FSDP mesh dimension. + We call this right-to-left sharding which is the opposite of the default + left-to-right sharding. See the example below: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [Shard(0), Shard(0)] + + The default sharding behavior shards the tensor on "dp" mesh dimension first then + "tp" dimension. The sharding result will be: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 1 (row 2-3) + 2 | (1, 0) | 2 (row 4-5) + 3 | (1, 1) | 3 (row 6-7) + + While the FSDP2 + TP sharding behavior does the opposite: it shards the tensor on + "tp" mesh dim first then "dp" dim. This right-to-left sharding will produce the + result: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The consequence is, any attempt to redistribute this DTensor to a full replica will + produce a wrong result because the shard-to-replicate redistribution always happens + right-to-left, regardless it's left-to-right sharding or right-to-left. To address + this, we use _StridedShard placement to make this right-to-left sharding compatible + with our left-to-right convention on both tensor distribution and redistribution. + + Now with _StridedShard, the right-to-left sharding above can be represented as: + tensor shape: [8, 8] + mesh: [[0, 1], [2, 3]], names=("dp", "tp") + placements: [_StridedShard(0, split_factor=2), Shard(0)] + + And a left-to-right processing of `placements` will produce the same result, which is + different from using the `Shard` placement: + Rank | Mesh Coordinate | Shard Index + ------------------------------------------------ + 0 | (0, 0) | 0 (row 0-1) + 1 | (0, 1) | 2 (row 4-5) + 2 | (1, 0) | 1 (row 2-3) + 3 | (1, 1) | 3 (row 6-7) + + The argument `split_factor` is the number of existing shards over the tensor sharding + dimension before processing the _StridedShard placement, as if the sharding happened + right-to-left. In the example above, the tensor should first be sharded on the "tp" + dimension into 2 shards before being sharded on the "dp" dimension. Therefore, the + `split_factor` of the _StridedShard placement on "dp" dim is 2. + + TODO: we should remove _StridedShard placement once we can unify it with Shard + """ + + split_factor: int + + def __eq__(self, other: object) -> bool: + if isinstance(other, _StridedShard): + return self.dim == other.dim and self.split_factor == other.split_factor + elif isinstance(other, Shard): + # TODO: this is to avoid extra all-gather in dtensor op dispatch + # note that sharding prop would not produce _StridedShard and an + # placement inequality would introduce an all-gather for resharding + return self.dim == other.dim + return False + + def __hash__(self) -> int: + return hash((self.dim, self.split_factor)) + + def __repr__(self) -> str: + """ + machine readable representation of the _StridedShard placement + """ + return f"_StridedShard(dim={self.dim}, sf={self.split_factor})" + + def __str__(self) -> str: + """human readable representation of the _StridedShard placement""" + return f"_S({self.dim}, {self.split_factor})" + + def _split_tensor( + self, + tensor: torch.Tensor, + num_chunks: int, + *, + with_padding: bool = True, + contiguous: bool = True, + ) -> tuple[list[torch.Tensor], list[int]]: + assert self.dim <= tensor.ndim, ( + f"Sharding dim {self.dim} greater than tensor ndim {tensor.ndim}" + ) + + # num_chunks represents the size of this StridedShard mesh dim, while self.split_factor + # represents the aggregate num chunks for other shardings applied logically earlier than this strided shard. + # (e.g. in FSDP+TP case, num_chunks is size(dp dim), split_factor is size(tp dim)) + total_split = num_chunks * self.split_factor + + tensor_list = list(torch.chunk(tensor, total_split, dim=self.dim)) + tensor_list = fill_empty_tensor_to_shards( + tensor_list, self.dim, total_split - len(tensor_list) + ) + + # compute the chunk size inline with ``torch.chunk`` to calculate padding + full_chunk_size = (tensor.size(self.dim) + total_split - 1) // total_split + + shard_list: list[torch.Tensor] = [] + pad_sizes: list[int] = [] + for i in range(num_chunks): + shard = torch.cat( + [tensor_list[i + j * num_chunks] for j in range(self.split_factor)], + dim=self.dim, + ) + if with_padding: + pad_size = full_chunk_size * self.split_factor - shard.size(self.dim) + shard = pad_tensor(shard, self.dim, pad_size) + pad_sizes.append(pad_size) + if contiguous: + shard = shard.contiguous() + shard_list.append(shard) + return shard_list, pad_sizes + + def _to_replicate_tensor( + self, + local_tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + current_logical_shape: list[int], + ) -> torch.Tensor: + """ + Given a tensor with strided sharding (e.g. [StridedShard(d), Shard(d)]), + this function is called during the process of converting to [Replicate(), Replicate()], + and `local_tensor` represents the portion of the tensor on this rank after the intermediate step of + converting to [StridedShard(d), Replicate()] in right-to-left unsharding order. + + note: this conversion logic is pretty specialized on this 2D case. It could be generalized further. This + is a common enough case to be worth fixing (since it occurs when applying TP and then FSDP to a model). + + note: this does not support 'reduce_scatter' for StridedShard. + + Example + ------- + mesh = (DP=2, TP=2) + # single-gpu "weight" of size 5, will be 'uneven' for sharding + original = torch.arange(5) + + tp sharded tensor + ----------------- + `tp = distribute_tensor(x, world_mesh['tp'], [Shard(0)])` + + local_tensors: + rank0: [0,1,2] rank1: [3,4] + rank1: [0,1,2] rank3: [3,4] + + fsdp+tp sharded tensor + ---------------------- + `dp_tp = ...` (the process of creating a strided-shard tensor is skipped over as it is complicated + dp_tp has placement (_StridedShard(0, split_factor=2), Shard(0)) + local_tensors: + rank0: [0,1] rank1: [3] + rank1: [2] rank3: [4] + + Now, say someone wants to reconstruct dp_tp's full tensor. This will invoke 'redistribute' to replicate. + redistribute will first replicate the "Shard(0)" placement on the rightmost mesh dim, then replicate the + StridedShard placement second, which is implemented by this function. + So our starting point (`local_tensor` arg) is the result of replicating the Shard(0) placement across the + TP dim, which looks like this. + + Note the discrepancy with the 'tp sharded tensor' line above! We'll fix it by locally shuffling data. + + local_tensors: + rank0: [0,1,3] rank1: [0,1,3] + rank2: [2,4] rank3: [2,4] + + Step 1: replicate over the DP dimension. Afterwards, each rank can locally sort the values. + note: we need padding to do this allgather, and we'll need to keep track of the padding amount for later + local_tensors: + rank0: [0,1,3,2,4] rank1: [0,1,3,2,4] + rank2: [0,1,3,2,4] rank3: [0,1,3,2,4] + + Step 2: chunk and shuffle values around to account for the wrong order of operations above + and get the original tensor content back + + 01324# <- our allgather includes padding, if padding was applied in step 1 + 01324 <- Remove the padding + 013, 24 <- chunk once, 'undoing' the DP allgather + 01, 3, 2, 4 <- chunk each chunk, 'undoing' the initial (wrong) TP allgather performed by Shard(0)->Replicate() + 012, 34 <- interleave with stride=TP mesh dim size + 01234 <- concatenate + + Note: the current implementation of this function is incomplete, and supports only the common pattern of one + strided shard placement, which is used in the FSDP + TP case. We could extend this implementation to handle + multiple strided shardings (e.g. [StridedShard, StridedShard, Shard]), by repeating the chunking step more times + and handling more complex shuffling in the last step. On the other hand, we plan to replace 'StridedShard' + with using just Shard and specifying a sharding order, so it may be ok to leave this as-is for the time being. + """ + num_chunks = mesh.size(mesh_dim=mesh_dim) + logical_dim_size = current_logical_shape[self.dim] + full_chunk_size = (logical_dim_size + num_chunks - 1) // num_chunks + local_pad_size = full_chunk_size - local_tensor.size(self.dim) + + local_tensor = pad_tensor(local_tensor, self.dim, local_pad_size) + + if not local_tensor.is_contiguous(): + local_tensor = local_tensor.contiguous() + + result = funcol.all_gather_tensor( + local_tensor, + gather_dim=self.dim, + group=(mesh, mesh_dim), + ) + if isinstance(result, funcol.AsyncCollectiveTensor): + result = result.wait() + + if result.shape[self.dim] > logical_dim_size: + result = unpad_tensor( + result, self.dim, result.shape[self.dim] - logical_dim_size + ) + + # this reverses our 'all_gather' but gives every rank a copy + outer_shards = torch.chunk(result, num_chunks, dim=self.dim) + # this undoes the 'Shard(0)' -> Replicate() that happened over the wrong mesh dim in the first place + inner_shards: list[torch.Tensor] = [] + for p in outer_shards: + inner_shards.extend(torch.chunk(p, self.split_factor, dim=self.dim)) + # now we just have to correctly stride the shards + reordered_shards = [] + for i in range(self.split_factor): + reordered_shards.extend(inner_shards[i :: self.split_factor]) + return torch.cat(reordered_shards, dim=self.dim).contiguous() + + +@dataclass(frozen=True) +class Replicate(Placement): + """ + The ``Replicate()`` placement describes the DTensor replicating on a corresponding + ``DeviceMesh`` dimension, where each rank on the DeviceMesh dimension holds a + replica of the global Tensor. The ``Replicate`` placement can be used by all + DTensor APIs (i.e. ``distribute_tensor``, ``DTensor.from_local``, etc.) + """ + + def __eq__(self, other: object) -> bool: + return isinstance(other, Replicate) + + def __hash__(self) -> int: + # every replicate placement is the same + return -1 + + def __repr__(self) -> str: + """ + machine readable representation of the Replicate placement + """ + return "Replicate()" + + def __str__(self) -> str: + """ + human readable representation of the Replicate placement + """ + return "R" + + def _replicate_tensor( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + src_data_rank: Optional[int] = 0, + ) -> torch.Tensor: + """ + Replicate (broadcast) a torch.Tensor on a mesh dimension (use + the first coordinate on the mesh dimension as source of truth) + """ + my_coordinate = mesh.get_coordinate() + if my_coordinate is None: + # if rank is not part of mesh, we simply return an empty tensor + return tensor.new_empty(0, requires_grad=tensor.requires_grad) + + tensor = tensor.contiguous() + + if src_data_rank is not None: + # perform broadcast from the src_data_rank as data source when it is not None + mesh_broadcast(tensor, mesh, mesh_dim=mesh_dim, group_src=src_data_rank) + return tensor + + +@dataclass(frozen=True) +class Partial(Placement): + """ + The ``Partial(reduce_op)`` placement describes the DTensor that is pending + reduction on a specified ``DeviceMesh`` dimension, where each rank on the + DeviceMesh dimension holds the partial value of the global Tensor. User can + redistribute the ``Partial`` DTensor to a ``Replicate`` or ``Shard(dim)`` + placement on the specified ``DeviceMesh`` dimension using ``redistribute``, + which would trigger necessary communication operations under the hood (i.e. + ``allreduce``, ``reduce_scatter``). + + Args: + reduce_op (str, optional): The reduction op to be used for the partial DTensor + to produce Replicated/Sharded DTensor. Only element-wise reduction operations + are supported, including: "sum", "avg", "product", "max", "min", default: "sum". + + .. note:: The ``Partial`` placement can be generated as a result of the DTensor operators, + and can only be used by the ``DTensor.from_local`` API. + """ + + reduce_op: str = "sum" + + def _reduce_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #1: + # _reduce_value: reduce the value of the tensor on the mesh dimension + return funcol.all_reduce( + tensor, reduceOp=self.reduce_op, group=(mesh, mesh_dim) + ) + + def _reduce_shard_value( + self, + tensor: torch.Tensor, + mesh: DeviceMesh, + mesh_dim: int, + shard_spec: Placement, + ) -> torch.Tensor: + # Partial placement contract #2: + # _reduce_shard_value: reduce_scatter the value of the tensor over the mesh dimension + shard_spec = cast(Shard, shard_spec) + return shard_spec._reduce_shard_tensor(tensor, mesh, self.reduce_op, mesh_dim) + + def _partition_value( + self, tensor: torch.Tensor, mesh: DeviceMesh, mesh_dim: int + ) -> torch.Tensor: + # Partial placement contract #3: + # _partition_value: partition the value of a replicated tensor on the mesh dimension + + # _partition_value is the conjugate operation of _reduce_value + # - i.e. _partition_value on a sum reduce op is just a divison operation + # - the _reduce_value on a sum reduce op would just be a sum(allreduce) operation + # TODO: if the reduce_op is min/max, etc. the _partition_value should be a + # different operation + assert self.reduce_op == "sum", "only support replicate to PartialSUM for now!" + num_chunks = mesh.size(mesh_dim=mesh_dim) + return tensor / num_chunks + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Partial): + return False + return self.reduce_op == other.reduce_op + + def __hash__(self) -> int: + return 1 + hash(self.reduce_op) + + def __repr__(self) -> str: + """ + machine readable representation of the Partial placement + """ + return f"Partial({self.reduce_op})" + + def __str__(self) -> str: + """ + human readable representation of the Partial placement + """ + return "P" + + +# We keep the old _Partial name for a while for BC reason +_Partial = Partial diff --git a/phivenv/Lib/site-packages/torch/distributions/__init__.py b/phivenv/Lib/site-packages/torch/distributions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8e204d12fba8fcc3f6eac76a9cc112814263b293 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/__init__.py @@ -0,0 +1,174 @@ +r""" +The ``distributions`` package contains parameterizable probability distributions +and sampling functions. This allows the construction of stochastic computation +graphs and stochastic gradient estimators for optimization. This package +generally follows the design of the `TensorFlow Distributions`_ package. + +.. _`TensorFlow Distributions`: + https://arxiv.org/abs/1711.10604 + +It is not possible to directly backpropagate through random samples. However, +there are two main methods for creating surrogate functions that can be +backpropagated through. These are the score function estimator/likelihood ratio +estimator/REINFORCE and the pathwise derivative estimator. REINFORCE is commonly +seen as the basis for policy gradient methods in reinforcement learning, and the +pathwise derivative estimator is commonly seen in the reparameterization trick +in variational autoencoders. Whilst the score function only requires the value +of samples :math:`f(x)`, the pathwise derivative requires the derivative +:math:`f'(x)`. The next sections discuss these two in a reinforcement learning +example. For more details see +`Gradient Estimation Using Stochastic Computation Graphs`_ . + +.. _`Gradient Estimation Using Stochastic Computation Graphs`: + https://arxiv.org/abs/1506.05254 + +Score function +^^^^^^^^^^^^^^ + +When the probability density function is differentiable with respect to its +parameters, we only need :meth:`~torch.distributions.Distribution.sample` and +:meth:`~torch.distributions.Distribution.log_prob` to implement REINFORCE: + +.. math:: + + \Delta\theta = \alpha r \frac{\partial\log p(a|\pi^\theta(s))}{\partial\theta} + +where :math:`\theta` are the parameters, :math:`\alpha` is the learning rate, +:math:`r` is the reward and :math:`p(a|\pi^\theta(s))` is the probability of +taking action :math:`a` in state :math:`s` given policy :math:`\pi^\theta`. + +In practice we would sample an action from the output of a network, apply this +action in an environment, and then use ``log_prob`` to construct an equivalent +loss function. Note that we use a negative because optimizers use gradient +descent, whilst the rule above assumes gradient ascent. With a categorical +policy, the code for implementing REINFORCE would be as follows:: + + probs = policy_network(state) + # Note that this is equivalent to what used to be called multinomial + m = Categorical(probs) + action = m.sample() + next_state, reward = env.step(action) + loss = -m.log_prob(action) * reward + loss.backward() + +Pathwise derivative +^^^^^^^^^^^^^^^^^^^ + +The other way to implement these stochastic/policy gradients would be to use the +reparameterization trick from the +:meth:`~torch.distributions.Distribution.rsample` method, where the +parameterized random variable can be constructed via a parameterized +deterministic function of a parameter-free random variable. The reparameterized +sample therefore becomes differentiable. The code for implementing the pathwise +derivative would be as follows:: + + params = policy_network(state) + m = Normal(*params) + # Any distribution with .has_rsample == True could work based on the application + action = m.rsample() + next_state, reward = env.step(action) # Assuming that reward is differentiable + loss = -reward + loss.backward() +""" + +from . import transforms +from .bernoulli import Bernoulli +from .beta import Beta +from .binomial import Binomial +from .categorical import Categorical +from .cauchy import Cauchy +from .chi2 import Chi2 +from .constraint_registry import biject_to, transform_to +from .continuous_bernoulli import ContinuousBernoulli +from .dirichlet import Dirichlet +from .distribution import Distribution +from .exp_family import ExponentialFamily +from .exponential import Exponential +from .fishersnedecor import FisherSnedecor +from .gamma import Gamma +from .generalized_pareto import GeneralizedPareto +from .geometric import Geometric +from .gumbel import Gumbel +from .half_cauchy import HalfCauchy +from .half_normal import HalfNormal +from .independent import Independent +from .inverse_gamma import InverseGamma +from .kl import _add_kl_info, kl_divergence, register_kl +from .kumaraswamy import Kumaraswamy +from .laplace import Laplace +from .lkj_cholesky import LKJCholesky +from .log_normal import LogNormal +from .logistic_normal import LogisticNormal +from .lowrank_multivariate_normal import LowRankMultivariateNormal +from .mixture_same_family import MixtureSameFamily +from .multinomial import Multinomial +from .multivariate_normal import MultivariateNormal +from .negative_binomial import NegativeBinomial +from .normal import Normal +from .one_hot_categorical import OneHotCategorical, OneHotCategoricalStraightThrough +from .pareto import Pareto +from .poisson import Poisson +from .relaxed_bernoulli import RelaxedBernoulli +from .relaxed_categorical import RelaxedOneHotCategorical +from .studentT import StudentT +from .transformed_distribution import TransformedDistribution +from .transforms import * # noqa: F403 +from .uniform import Uniform +from .von_mises import VonMises +from .weibull import Weibull +from .wishart import Wishart + + +_add_kl_info() +del _add_kl_info + +__all__ = [ + "Bernoulli", + "Beta", + "Binomial", + "Categorical", + "Cauchy", + "Chi2", + "ContinuousBernoulli", + "Dirichlet", + "Distribution", + "Exponential", + "ExponentialFamily", + "FisherSnedecor", + "Gamma", + "GeneralizedPareto", + "Geometric", + "Gumbel", + "HalfCauchy", + "HalfNormal", + "Independent", + "InverseGamma", + "Kumaraswamy", + "LKJCholesky", + "Laplace", + "LogNormal", + "LogisticNormal", + "LowRankMultivariateNormal", + "MixtureSameFamily", + "Multinomial", + "MultivariateNormal", + "NegativeBinomial", + "Normal", + "OneHotCategorical", + "OneHotCategoricalStraightThrough", + "Pareto", + "RelaxedBernoulli", + "RelaxedOneHotCategorical", + "StudentT", + "Poisson", + "Uniform", + "VonMises", + "Weibull", + "Wishart", + "TransformedDistribution", + "biject_to", + "kl_divergence", + "register_kl", + "transform_to", +] +__all__.extend(transforms.__all__) diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eab7bc1acb4822e0519dbd14f0eb360b52e98443 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/bernoulli.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/bernoulli.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60ba0ddcae923bdcfe524393256dd85940533e3e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/bernoulli.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/beta.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/beta.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bd5e3c75bde48ba41818aa5ec47b6047d7e4818 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/beta.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/binomial.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/binomial.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f89f5d98e1db135c226e70000a5f8c41d61f2d40 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/binomial.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/categorical.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/categorical.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..423520fc3789d39067480325ca1cc11acd0960a7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/categorical.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/cauchy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/cauchy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa8641f31be4554f7a3297dff911156389ce8d95 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/cauchy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/chi2.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/chi2.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2493fbb613ff1df5cc876fe91bf2fd5c5b1903fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/chi2.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6738c64938c74ab02b7cbe562249ee8661b22ce6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/constraint_registry.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/constraints.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/constraints.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2314eb83852328b3c7c561ac42b2e64ddc2a1ced Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/constraints.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06e35c076277ec214c758b4ca87c144a0fa9edcd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/continuous_bernoulli.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/dirichlet.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/dirichlet.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba683f04db110f39103b53fb41addd9d96df6b04 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/dirichlet.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/distribution.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/distribution.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7b7e63eaf0e88931e2cb75f9d1807dbb0e09610 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/distribution.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/exp_family.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/exp_family.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba1896a392d457a81cf3323ede44bf5217d83b0e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/exp_family.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/exponential.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/exponential.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15d54352aa9989bb4631ccf1ce867441fb0c03b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/exponential.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c666cc86297047bf4409a8abbf09a0074392873 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/fishersnedecor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/gamma.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/gamma.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d57186651b10a5ede754ccbe59a8d6947bb55d1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/gamma.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29d460493f631c0280f3cca40e7bf0fd20b51953 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/generalized_pareto.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/geometric.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/geometric.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10cac3ce11eca35706c7aa97d309e5bed753ff7a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/geometric.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/gumbel.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/gumbel.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a574ba7dc0d79d786a9ea5349bee6ccd336639f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/gumbel.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af559cf374dd0142f30df68a747818e9c123c2fa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/half_cauchy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/half_normal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/half_normal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3a444fbd4f02ab709dd0e9637444d2eb4b704de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/half_normal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/independent.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/independent.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4d516c86ca174094d19ac43d15f47e590e7047c0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/independent.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03dc55c014bf48f00a5130734808fc63e2938757 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/inverse_gamma.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/kl.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/kl.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..baeaa0daf5b88f6e1b345cc2f13e2d0b6b9d90e5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/kl.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3726880f5acb47cc2e51a9444d6e1de636da37df Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/kumaraswamy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/laplace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/laplace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc8e70d16004fd46e2293e00da02c63caabc7e8f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/laplace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a684354d8bd5c6e8d0047d111672d17be4fc7b9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/lkj_cholesky.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/log_normal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/log_normal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23c5ef888a25166dbd27f478ca8a8fc566c4b513 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/log_normal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb4c5be5f62274ff954aab8d3b56b6bc274562be Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/logistic_normal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..527f8cf01b9f289c075d2995ebca3cfc596418b7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/lowrank_multivariate_normal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a0371d9e70fc91f458a9f48d97021b412cc9e2d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/mixture_same_family.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/multinomial.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/multinomial.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90fb7d5ed7959f42a2e7ef68c179b62032dd0a6c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/multinomial.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30b82f6f3c6397c84a980f458ad7cf7ada12d875 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/multivariate_normal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67259e949074dce610a756090654c7ab858dae50 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/negative_binomial.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/normal.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/normal.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5ff1daca0eb82c240d9b143ac6ef4aa17644f57d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/normal.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1675b13a2319e9e1d061f5250077bdc5b5e340ee Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/one_hot_categorical.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/pareto.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/pareto.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1a7ced50898e6b30176065b00a7acfc69053564 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/pareto.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/poisson.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/poisson.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9fbb9dfe3fa64d399553b96b85f93b291bc33e6b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/poisson.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c6ebf91fc87d51e1fa663969562cfddc5faf1dd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/relaxed_bernoulli.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96a0782afe1c233cd37c6c2adec549a6b89188c4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/relaxed_categorical.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/studentT.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/studentT.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83b7b77ed0447c1e847421d963e47560bc91f44a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/studentT.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1fa6fa06abea6ba86c4d6d8f2a173bfbece012c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/transformed_distribution.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/transforms.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/transforms.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9bd08049c213431d2473ec71ef80e6b65b461ef Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/transforms.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/uniform.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/uniform.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa5c92b0419fccdee801cf268538fcc9b79f6980 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/uniform.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5389802d7c9aa517b760d9be6a5dcc3d3f949339 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/von_mises.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/von_mises.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..665c02fb4f87f6b32adc98a73d0691b39f708a26 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/von_mises.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/weibull.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/weibull.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c199432f43faebfed58cbdfa795ceb0d40fc2a4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/weibull.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/__pycache__/wishart.cpython-39.pyc b/phivenv/Lib/site-packages/torch/distributions/__pycache__/wishart.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..63abe741e098ca6f1849b6cb89549f6aecac870a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/distributions/__pycache__/wishart.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/distributions/bernoulli.py b/phivenv/Lib/site-packages/torch/distributions/bernoulli.py new file mode 100644 index 0000000000000000000000000000000000000000..7dd1faf8c0174dad65b545c1841dbe15e1dacc0b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/bernoulli.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import nan, Tensor +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.nn.functional import binary_cross_entropy_with_logits +from torch.types import _Number, Number + + +__all__ = ["Bernoulli"] + + +class Bernoulli(ExponentialFamily): + r""" + Creates a Bernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both). + + Samples are binary (0 or 1). They take the value `1` with probability `p` + and `0` with probability `1 - p`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Bernoulli(torch.tensor([0.3])) + >>> m.sample() # 30% chance 1; 70% chance 0 + tensor([ 0.]) + + Args: + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + validate_args (bool, optional): whether to validate arguments, None by default + """ + + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.boolean + has_enumerate_support = True + _mean_carrier_measure = 0 + + def __init__( + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + is_scalar = isinstance(probs, _Number) + (self.probs,) = broadcast_all(probs) + else: + assert logits is not None # helps mypy + is_scalar = isinstance(logits, _Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = torch.Size() + else: + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Bernoulli, _instance) + batch_shape = torch.Size(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(Bernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @property + def mean(self) -> Tensor: + return self.probs + + @property + def mode(self) -> Tensor: + mode = (self.probs >= 0.5).to(self.probs) + mode[self.probs == 0.5] = nan + return mode + + @property + def variance(self) -> Tensor: + return self.probs * (1 - self.probs) + + @lazy_property + def logits(self) -> Tensor: + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self) -> Tensor: + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self) -> torch.Size: + return self._param.size() + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.bernoulli(self.probs.expand(shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + return -binary_cross_entropy_with_logits(logits, value, reduction="none") + + def entropy(self): + return binary_cross_entropy_with_logits( + self.logits, self.probs, reduction="none" + ) + + def enumerate_support(self, expand=True): + values = torch.arange(2, dtype=self._param.dtype, device=self._param.device) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.expand((-1,) + self._batch_shape) + return values + + @property + def _natural_params(self) -> tuple[Tensor]: + return (torch.logit(self.probs),) + + def _log_normalizer(self, x): + return torch.log1p(torch.exp(x)) diff --git a/phivenv/Lib/site-packages/torch/distributions/beta.py b/phivenv/Lib/site-packages/torch/distributions/beta.py new file mode 100644 index 0000000000000000000000000000000000000000..914f16af3b7b0f2d1b3f9237bab665d5c0322fc3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/beta.py @@ -0,0 +1,117 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.dirichlet import Dirichlet +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all +from torch.types import _Number, _size + + +__all__ = ["Beta"] + + +class Beta(ExponentialFamily): + r""" + Beta distribution parameterized by :attr:`concentration1` and :attr:`concentration0`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Beta(torch.tensor([0.5]), torch.tensor([0.5])) + >>> m.sample() # Beta distributed with concentration concentration1 and concentration0 + tensor([ 0.1046]) + + Args: + concentration1 (float or Tensor): 1st concentration parameter of the distribution + (often referred to as alpha) + concentration0 (float or Tensor): 2nd concentration parameter of the distribution + (often referred to as beta) + """ + + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + } + support = constraints.unit_interval + has_rsample = True + + def __init__( + self, + concentration1: Union[Tensor, float], + concentration0: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + if isinstance(concentration1, _Number) and isinstance(concentration0, _Number): + concentration1_concentration0 = torch.tensor( + [float(concentration1), float(concentration0)] + ) + else: + concentration1, concentration0 = broadcast_all( + concentration1, concentration0 + ) + concentration1_concentration0 = torch.stack( + [concentration1, concentration0], -1 + ) + self._dirichlet = Dirichlet( + concentration1_concentration0, validate_args=validate_args + ) + super().__init__(self._dirichlet._batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Beta, _instance) + batch_shape = torch.Size(batch_shape) + new._dirichlet = self._dirichlet.expand(batch_shape) + super(Beta, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self) -> Tensor: + return self.concentration1 / (self.concentration1 + self.concentration0) + + @property + def mode(self) -> Tensor: + return self._dirichlet.mode[..., 0] + + @property + def variance(self) -> Tensor: + total = self.concentration1 + self.concentration0 + return self.concentration1 * self.concentration0 / (total.pow(2) * (total + 1)) + + def rsample(self, sample_shape: _size = ()) -> Tensor: + return self._dirichlet.rsample(sample_shape).select(-1, 0) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + heads_tails = torch.stack([value, 1.0 - value], -1) + return self._dirichlet.log_prob(heads_tails) + + def entropy(self): + return self._dirichlet.entropy() + + @property + def concentration1(self) -> Tensor: + result = self._dirichlet.concentration[..., 0] + if isinstance(result, _Number): + return torch.tensor([result]) + else: + return result + + @property + def concentration0(self) -> Tensor: + result = self._dirichlet.concentration[..., 1] + if isinstance(result, _Number): + return torch.tensor([result]) + else: + return result + + @property + def _natural_params(self) -> tuple[Tensor, Tensor]: + return (self.concentration1, self.concentration0) + + def _log_normalizer(self, x, y): + return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) diff --git a/phivenv/Lib/site-packages/torch/distributions/binomial.py b/phivenv/Lib/site-packages/torch/distributions/binomial.py new file mode 100644 index 0000000000000000000000000000000000000000..1ae0f261cf45d3c716d5e1d75f2e4b6898a85ba7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/binomial.py @@ -0,0 +1,178 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + + +__all__ = ["Binomial"] + + +def _clamp_by_zero(x): + # works like clamp(x, min=0) but has grad at 0 is 0.5 + return (x.clamp(min=0) + x - x.clamp(max=0)) / 2 + + +class Binomial(Distribution): + r""" + Creates a Binomial distribution parameterized by :attr:`total_count` and + either :attr:`probs` or :attr:`logits` (but not both). :attr:`total_count` must be + broadcastable with :attr:`probs`/:attr:`logits`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Binomial(100, torch.tensor([0 , .2, .8, 1])) + >>> x = m.sample() + tensor([ 0., 22., 71., 100.]) + + >>> m = Binomial(torch.tensor([[5.], [10.]]), torch.tensor([0.5, 0.8])) + >>> x = m.sample() + tensor([[ 4., 5.], + [ 7., 6.]]) + + Args: + total_count (int or Tensor): number of Bernoulli trials + probs (Tensor): Event probabilities + logits (Tensor): Event log-odds + """ + + arg_constraints = { + "total_count": constraints.nonnegative_integer, + "probs": constraints.unit_interval, + "logits": constraints.real, + } + has_enumerate_support = True + + def __init__( + self, + total_count: Union[Tensor, int] = 1, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) + self.total_count = self.total_count.type_as(self.probs) + else: + assert logits is not None # helps mypy + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + + self._param = self.probs if probs is not None else self.logits + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Binomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count.expand(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(Binomial, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=0) + def support(self): + return constraints.integer_interval(0, self.total_count) + + @property + def mean(self) -> Tensor: + return self.total_count * self.probs + + @property + def mode(self) -> Tensor: + return ((self.total_count + 1) * self.probs).floor().clamp(max=self.total_count) + + @property + def variance(self) -> Tensor: + return self.total_count * self.probs * (1 - self.probs) + + @lazy_property + def logits(self) -> Tensor: + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self) -> Tensor: + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self) -> torch.Size: + return self._param.size() + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.binomial( + self.total_count.expand(shape), self.probs.expand(shape) + ) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + log_factorial_n = torch.lgamma(self.total_count + 1) + log_factorial_k = torch.lgamma(value + 1) + log_factorial_nmk = torch.lgamma(self.total_count - value + 1) + # k * log(p) + (n - k) * log(1 - p) = k * (log(p) - log(1 - p)) + n * log(1 - p) + # (case logit < 0) = k * logit - n * log1p(e^logit) + # (case logit > 0) = k * logit - n * (log(p) - log(1 - p)) + n * log(p) + # = k * logit - n * logit - n * log1p(e^-logit) + # (merge two cases) = k * logit - n * max(logit, 0) - n * log1p(e^-|logit|) + normalize_term = ( + self.total_count * _clamp_by_zero(self.logits) + + self.total_count * torch.log1p(torch.exp(-torch.abs(self.logits))) + - log_factorial_n + ) + return ( + value * self.logits - log_factorial_k - log_factorial_nmk - normalize_term + ) + + def entropy(self): + total_count = int(self.total_count.max()) + if not self.total_count.min() == total_count: + raise NotImplementedError( + "Inhomogeneous total count not supported by `entropy`." + ) + + log_prob = self.log_prob(self.enumerate_support(False)) + return -(torch.exp(log_prob) * log_prob).sum(0) + + def enumerate_support(self, expand=True): + total_count = int(self.total_count.max()) + if not self.total_count.min() == total_count: + raise NotImplementedError( + "Inhomogeneous total count not supported by `enumerate_support`." + ) + values = torch.arange( + 1 + total_count, dtype=self._param.dtype, device=self._param.device + ) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.expand((-1,) + self._batch_shape) + return values diff --git a/phivenv/Lib/site-packages/torch/distributions/categorical.py b/phivenv/Lib/site-packages/torch/distributions/categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..4d1a7d317ff6037b8dcd8f16bd4327f2f4e6f356 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/categorical.py @@ -0,0 +1,166 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch import nan, Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import lazy_property, logits_to_probs, probs_to_logits + + +__all__ = ["Categorical"] + + +class Categorical(Distribution): + r""" + Creates a categorical distribution parameterized by either :attr:`probs` or + :attr:`logits` (but not both). + + .. note:: + It is equivalent to the distribution that :func:`torch.multinomial` + samples from. + + Samples are integers from :math:`\{0, \ldots, K-1\}` where `K` is ``probs.size(-1)``. + + If `probs` is 1-dimensional with length-`K`, each element is the relative probability + of sampling the class at that index. + + If `probs` is N-dimensional, the first N-1 dimensions are treated as a batch of + relative probability vectors. + + .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, + and it will be normalized to sum to 1 along the last dimension. :attr:`probs` + will return this normalized value. + The `logits` argument will be interpreted as unnormalized log probabilities + and can therefore be any real number. It will likewise be normalized so that + the resulting probabilities sum to 1 along the last dimension. :attr:`logits` + will return this normalized value. + + See also: :func:`torch.multinomial` + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Categorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) + >>> m.sample() # equal probability of 0, 1, 2, 3 + tensor(3) + + Args: + probs (Tensor): event probabilities + logits (Tensor): event log probabilities (unnormalized) + """ + + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + has_enumerate_support = True + + def __init__( + self, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + if probs.dim() < 1: + raise ValueError("`probs` parameter must be at least one-dimensional.") + self.probs = probs / probs.sum(-1, keepdim=True) + else: + assert logits is not None # helps mypy + if logits.dim() < 1: + raise ValueError("`logits` parameter must be at least one-dimensional.") + # Normalize + self.logits = logits - logits.logsumexp(dim=-1, keepdim=True) + self._param = self.probs if probs is not None else self.logits + self._num_events = self._param.size()[-1] + batch_shape = ( + self._param.size()[:-1] if self._param.ndimension() > 1 else torch.Size() + ) + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Categorical, _instance) + batch_shape = torch.Size(batch_shape) + param_shape = batch_shape + torch.Size((self._num_events,)) + if "probs" in self.__dict__: + new.probs = self.probs.expand(param_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(param_shape) + new._param = new.logits + new._num_events = self._num_events + super(Categorical, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=0) + def support(self): + return constraints.integer_interval(0, self._num_events - 1) + + @lazy_property + def logits(self) -> Tensor: + return probs_to_logits(self.probs) + + @lazy_property + def probs(self) -> Tensor: + return logits_to_probs(self.logits) + + @property + def param_shape(self) -> torch.Size: + return self._param.size() + + @property + def mean(self) -> Tensor: + return torch.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + device=self.probs.device, + ) + + @property + def mode(self) -> Tensor: + return self.probs.argmax(dim=-1) + + @property + def variance(self) -> Tensor: + return torch.full( + self._extended_shape(), + nan, + dtype=self.probs.dtype, + device=self.probs.device, + ) + + def sample(self, sample_shape=torch.Size()): + if not isinstance(sample_shape, torch.Size): + sample_shape = torch.Size(sample_shape) + probs_2d = self.probs.reshape(-1, self._num_events) + samples_2d = torch.multinomial(probs_2d, sample_shape.numel(), True).T + return samples_2d.reshape(self._extended_shape(sample_shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = value.long().unsqueeze(-1) + value, log_pmf = torch.broadcast_tensors(value, self.logits) + value = value[..., :1] + return log_pmf.gather(-1, value).squeeze(-1) + + def entropy(self): + min_real = torch.finfo(self.logits.dtype).min + logits = torch.clamp(self.logits, min=min_real) + p_log_p = logits * self.probs + return -p_log_p.sum(-1) + + def enumerate_support(self, expand=True): + num_events = self._num_events + values = torch.arange(num_events, dtype=torch.long, device=self._param.device) + values = values.view((-1,) + (1,) * len(self._batch_shape)) + if expand: + values = values.expand((-1,) + self._batch_shape) + return values diff --git a/phivenv/Lib/site-packages/torch/distributions/cauchy.py b/phivenv/Lib/site-packages/torch/distributions/cauchy.py new file mode 100644 index 0000000000000000000000000000000000000000..b94eea6f0118190b080e4ac0ba7e4bd7b09e7037 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/cauchy.py @@ -0,0 +1,99 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +from torch import inf, nan, Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all +from torch.types import _Number, _size + + +__all__ = ["Cauchy"] + + +class Cauchy(Distribution): + r""" + Samples from a Cauchy (Lorentz) distribution. The distribution of the ratio of + independent normally distributed random variables with means `0` follows a + Cauchy distribution. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Cauchy(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Cauchy distribution with loc=0 and scale=1 + tensor([ 2.3214]) + + Args: + loc (float or Tensor): mode or median of the distribution. + scale (float or Tensor): half width at half maximum. + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, _Number) and isinstance(scale, _Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Cauchy, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Cauchy, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self) -> Tensor: + return torch.full( + self._extended_shape(), nan, dtype=self.loc.dtype, device=self.loc.device + ) + + @property + def mode(self) -> Tensor: + return self.loc + + @property + def variance(self) -> Tensor: + return torch.full( + self._extended_shape(), inf, dtype=self.loc.dtype, device=self.loc.device + ) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + eps = self.loc.new(shape).cauchy_() + return self.loc + eps * self.scale + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return ( + -math.log(math.pi) + - self.scale.log() + - (((value - self.loc) / self.scale) ** 2).log1p() + ) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return torch.atan((value - self.loc) / self.scale) / math.pi + 0.5 + + def icdf(self, value): + return torch.tan(math.pi * (value - 0.5)) * self.scale + self.loc + + def entropy(self): + return math.log(4 * math.pi) + self.scale.log() diff --git a/phivenv/Lib/site-packages/torch/distributions/chi2.py b/phivenv/Lib/site-packages/torch/distributions/chi2.py new file mode 100644 index 0000000000000000000000000000000000000000..7a654306aad967459a46e191816c79381da163ee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/chi2.py @@ -0,0 +1,43 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.gamma import Gamma + + +__all__ = ["Chi2"] + + +class Chi2(Gamma): + r""" + Creates a Chi-squared distribution parameterized by shape parameter :attr:`df`. + This is exactly equivalent to ``Gamma(alpha=0.5*df, beta=0.5)`` + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Chi2(torch.tensor([1.0])) + >>> m.sample() # Chi2 distributed with shape df=1 + tensor([ 0.1046]) + + Args: + df (float or Tensor): shape parameter of the distribution + """ + + arg_constraints = {"df": constraints.positive} + + def __init__( + self, + df: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + super().__init__(0.5 * df, 0.5, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Chi2, _instance) + return super().expand(batch_shape, new) + + @property + def df(self) -> Tensor: + return self.concentration * 2 diff --git a/phivenv/Lib/site-packages/torch/distributions/constraint_registry.py b/phivenv/Lib/site-packages/torch/distributions/constraint_registry.py new file mode 100644 index 0000000000000000000000000000000000000000..17b5928b2afd399cc18d7d7064638aa53d76a9d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/constraint_registry.py @@ -0,0 +1,291 @@ +# mypy: allow-untyped-defs +r""" +PyTorch provides two global :class:`ConstraintRegistry` objects that link +:class:`~torch.distributions.constraints.Constraint` objects to +:class:`~torch.distributions.transforms.Transform` objects. These objects both +input constraints and return transforms, but they have different guarantees on +bijectivity. + +1. ``biject_to(constraint)`` looks up a bijective + :class:`~torch.distributions.transforms.Transform` from ``constraints.real`` + to the given ``constraint``. The returned transform is guaranteed to have + ``.bijective = True`` and should implement ``.log_abs_det_jacobian()``. +2. ``transform_to(constraint)`` looks up a not-necessarily bijective + :class:`~torch.distributions.transforms.Transform` from ``constraints.real`` + to the given ``constraint``. The returned transform is not guaranteed to + implement ``.log_abs_det_jacobian()``. + +The ``transform_to()`` registry is useful for performing unconstrained +optimization on constrained parameters of probability distributions, which are +indicated by each distribution's ``.arg_constraints`` dict. These transforms often +overparameterize a space in order to avoid rotation; they are thus more +suitable for coordinate-wise optimization algorithms like Adam:: + + loc = torch.zeros(100, requires_grad=True) + unconstrained = torch.zeros(100, requires_grad=True) + scale = transform_to(Normal.arg_constraints["scale"])(unconstrained) + loss = -Normal(loc, scale).log_prob(data).sum() + +The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where +samples from a probability distribution with constrained ``.support`` are +propagated in an unconstrained space, and algorithms are typically rotation +invariant.:: + + dist = Exponential(rate) + unconstrained = torch.zeros(100, requires_grad=True) + sample = biject_to(dist.support)(unconstrained) + potential_energy = -dist.log_prob(sample).sum() + +.. note:: + + An example where ``transform_to`` and ``biject_to`` differ is + ``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a + :class:`~torch.distributions.transforms.SoftmaxTransform` that simply + exponentiates and normalizes its inputs; this is a cheap and mostly + coordinate-wise operation appropriate for algorithms like SVI. In + contrast, ``biject_to(constraints.simplex)`` returns a + :class:`~torch.distributions.transforms.StickBreakingTransform` that + bijects its input down to a one-fewer-dimensional space; this a more + expensive less numerically stable transform but is needed for algorithms + like HMC. + +The ``biject_to`` and ``transform_to`` objects can be extended by user-defined +constraints and transforms using their ``.register()`` method either as a +function on singleton constraints:: + + transform_to.register(my_constraint, my_transform) + +or as a decorator on parameterized constraints:: + + @transform_to.register(MyConstraintClass) + def my_factory(constraint): + assert isinstance(constraint, MyConstraintClass) + return MyTransform(constraint.param1, constraint.param2) + +You can create your own registry by creating a new :class:`ConstraintRegistry` +object. +""" + +from torch.distributions import constraints, transforms +from torch.types import _Number + + +__all__ = [ + "ConstraintRegistry", + "biject_to", + "transform_to", +] + + +class ConstraintRegistry: + """ + Registry to link constraints to transforms. + """ + + def __init__(self): + self._registry = {} + super().__init__() + + def register(self, constraint, factory=None): + """ + Registers a :class:`~torch.distributions.constraints.Constraint` + subclass in this registry. Usage:: + + @my_registry.register(MyConstraintClass) + def construct_transform(constraint): + assert isinstance(constraint, MyConstraint) + return MyTransform(constraint.arg_constraints) + + Args: + constraint (subclass of :class:`~torch.distributions.constraints.Constraint`): + A subclass of :class:`~torch.distributions.constraints.Constraint`, or + a singleton object of the desired class. + factory (Callable): A callable that inputs a constraint object and returns + a :class:`~torch.distributions.transforms.Transform` object. + """ + # Support use as decorator. + if factory is None: + return lambda factory: self.register(constraint, factory) + + # Support calling on singleton instances. + if isinstance(constraint, constraints.Constraint): + constraint = type(constraint) + + if not isinstance(constraint, type) or not issubclass( + constraint, constraints.Constraint + ): + raise TypeError( + f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}" + ) + + self._registry[constraint] = factory + return factory + + def __call__(self, constraint): + """ + Looks up a transform to constrained space, given a constraint object. + Usage:: + + constraint = Normal.arg_constraints["scale"] + scale = transform_to(constraint)(torch.zeros(1)) # constrained + u = transform_to(constraint).inv(scale) # unconstrained + + Args: + constraint (:class:`~torch.distributions.constraints.Constraint`): + A constraint object. + + Returns: + A :class:`~torch.distributions.transforms.Transform` object. + + Raises: + `NotImplementedError` if no transform has been registered. + """ + # Look up by Constraint subclass. + try: + factory = self._registry[type(constraint)] + except KeyError: + raise NotImplementedError( + f"Cannot transform {type(constraint).__name__} constraints" + ) from None + return factory(constraint) + + +biject_to = ConstraintRegistry() +transform_to = ConstraintRegistry() + + +################################################################################ +# Registration Table +################################################################################ + + +@biject_to.register(constraints.real) +@transform_to.register(constraints.real) +def _transform_to_real(constraint): + return transforms.identity_transform + + +@biject_to.register(constraints.independent) +def _biject_to_independent(constraint): + base_transform = biject_to(constraint.base_constraint) + return transforms.IndependentTransform( + base_transform, constraint.reinterpreted_batch_ndims + ) + + +@transform_to.register(constraints.independent) +def _transform_to_independent(constraint): + base_transform = transform_to(constraint.base_constraint) + return transforms.IndependentTransform( + base_transform, constraint.reinterpreted_batch_ndims + ) + + +@biject_to.register(constraints.positive) +@biject_to.register(constraints.nonnegative) +@transform_to.register(constraints.positive) +@transform_to.register(constraints.nonnegative) +def _transform_to_positive(constraint): + return transforms.ExpTransform() + + +@biject_to.register(constraints.greater_than) +@biject_to.register(constraints.greater_than_eq) +@transform_to.register(constraints.greater_than) +@transform_to.register(constraints.greater_than_eq) +def _transform_to_greater_than(constraint): + return transforms.ComposeTransform( + [ + transforms.ExpTransform(), + transforms.AffineTransform(constraint.lower_bound, 1), + ] + ) + + +@biject_to.register(constraints.less_than) +@transform_to.register(constraints.less_than) +def _transform_to_less_than(constraint): + return transforms.ComposeTransform( + [ + transforms.ExpTransform(), + transforms.AffineTransform(constraint.upper_bound, -1), + ] + ) + + +@biject_to.register(constraints.interval) +@biject_to.register(constraints.half_open_interval) +@transform_to.register(constraints.interval) +@transform_to.register(constraints.half_open_interval) +def _transform_to_interval(constraint): + # Handle the special case of the unit interval. + lower_is_0 = ( + isinstance(constraint.lower_bound, _Number) and constraint.lower_bound == 0 + ) + upper_is_1 = ( + isinstance(constraint.upper_bound, _Number) and constraint.upper_bound == 1 + ) + if lower_is_0 and upper_is_1: + return transforms.SigmoidTransform() + + loc = constraint.lower_bound + scale = constraint.upper_bound - constraint.lower_bound + return transforms.ComposeTransform( + [transforms.SigmoidTransform(), transforms.AffineTransform(loc, scale)] + ) + + +@biject_to.register(constraints.simplex) +def _biject_to_simplex(constraint): + return transforms.StickBreakingTransform() + + +@transform_to.register(constraints.simplex) +def _transform_to_simplex(constraint): + return transforms.SoftmaxTransform() + + +# TODO define a bijection for LowerCholeskyTransform +@transform_to.register(constraints.lower_cholesky) +def _transform_to_lower_cholesky(constraint): + return transforms.LowerCholeskyTransform() + + +@transform_to.register(constraints.positive_definite) +@transform_to.register(constraints.positive_semidefinite) +def _transform_to_positive_definite(constraint): + return transforms.PositiveDefiniteTransform() + + +@biject_to.register(constraints.corr_cholesky) +@transform_to.register(constraints.corr_cholesky) +def _transform_to_corr_cholesky(constraint): + return transforms.CorrCholeskyTransform() + + +@biject_to.register(constraints.cat) +def _biject_to_cat(constraint): + return transforms.CatTransform( + [biject_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths + ) + + +@transform_to.register(constraints.cat) +def _transform_to_cat(constraint): + return transforms.CatTransform( + [transform_to(c) for c in constraint.cseq], constraint.dim, constraint.lengths + ) + + +@biject_to.register(constraints.stack) +def _biject_to_stack(constraint): + return transforms.StackTransform( + [biject_to(c) for c in constraint.cseq], constraint.dim + ) + + +@transform_to.register(constraints.stack) +def _transform_to_stack(constraint): + return transforms.StackTransform( + [transform_to(c) for c in constraint.cseq], constraint.dim + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/constraints.py b/phivenv/Lib/site-packages/torch/distributions/constraints.py new file mode 100644 index 0000000000000000000000000000000000000000..1529492d8ff30aaecb27f6e36265f460f2199ee1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/constraints.py @@ -0,0 +1,737 @@ +# mypy: allow-untyped-defs + +from typing import Any, Callable, Optional + + +r""" +The following constraints are implemented: + +- ``constraints.boolean`` +- ``constraints.cat`` +- ``constraints.corr_cholesky`` +- ``constraints.dependent`` +- ``constraints.greater_than(lower_bound)`` +- ``constraints.greater_than_eq(lower_bound)`` +- ``constraints.independent(constraint, reinterpreted_batch_ndims)`` +- ``constraints.integer_interval(lower_bound, upper_bound)`` +- ``constraints.interval(lower_bound, upper_bound)`` +- ``constraints.less_than(upper_bound)`` +- ``constraints.lower_cholesky`` +- ``constraints.lower_triangular`` +- ``constraints.MixtureSameFamilyConstraint(base_constraint)`` +- ``constraints.multinomial`` +- ``constraints.nonnegative`` +- ``constraints.nonnegative_integer`` +- ``constraints.one_hot`` +- ``constraints.positive_integer`` +- ``constraints.positive`` +- ``constraints.positive_semidefinite`` +- ``constraints.positive_definite`` +- ``constraints.real_vector`` +- ``constraints.real`` +- ``constraints.simplex`` +- ``constraints.symmetric`` +- ``constraints.stack`` +- ``constraints.square`` +- ``constraints.symmetric`` +- ``constraints.unit_interval`` +""" + +import torch + + +__all__ = [ + "Constraint", + "boolean", + "cat", + "corr_cholesky", + "dependent", + "dependent_property", + "greater_than", + "greater_than_eq", + "independent", + "integer_interval", + "interval", + "half_open_interval", + "is_dependent", + "less_than", + "lower_cholesky", + "lower_triangular", + "MixtureSameFamilyConstraint", + "multinomial", + "nonnegative", + "nonnegative_integer", + "one_hot", + "positive", + "positive_semidefinite", + "positive_definite", + "positive_integer", + "real", + "real_vector", + "simplex", + "square", + "stack", + "symmetric", + "unit_interval", +] + + +class Constraint: + """ + Abstract base class for constraints. + + A constraint object represents a region over which a variable is valid, + e.g. within which a variable can be optimized. + + Attributes: + is_discrete (bool): Whether constrained space is discrete. + Defaults to False. + event_dim (int): Number of rightmost dimensions that together define + an event. The :meth:`check` method will remove this many dimensions + when computing validity. + """ + + is_discrete = False # Default to continuous. + event_dim = 0 # Default to univariate. + + def check(self, value): + """ + Returns a byte tensor of ``sample_shape + batch_shape`` indicating + whether each event in value satisfies this constraint. + """ + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__[1:] + "()" + + +class _Dependent(Constraint): + """ + Placeholder for variables whose support depends on other variables. + These variables obey no simple coordinate-wise constraints. + + Args: + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. + event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. + """ + + def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): + self._is_discrete = is_discrete + self._event_dim = event_dim + super().__init__() + + @property + def is_discrete(self) -> bool: # type: ignore[override] + if self._is_discrete is NotImplemented: + raise NotImplementedError(".is_discrete cannot be determined statically") + return self._is_discrete + + @property + def event_dim(self) -> int: # type: ignore[override] + if self._event_dim is NotImplemented: + raise NotImplementedError(".event_dim cannot be determined statically") + return self._event_dim + + def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented): + """ + Support for syntax to customize static attributes:: + + constraints.dependent(is_discrete=True, event_dim=1) + """ + if is_discrete is NotImplemented: + is_discrete = self._is_discrete + if event_dim is NotImplemented: + event_dim = self._event_dim + return _Dependent(is_discrete=is_discrete, event_dim=event_dim) + + def check(self, x): + raise ValueError("Cannot determine validity of dependent constraint") + + +def is_dependent(constraint): + """ + Checks if ``constraint`` is a ``_Dependent`` object. + + Args: + constraint : A ``Constraint`` object. + + Returns: + ``bool``: True if ``constraint`` can be refined to the type ``_Dependent``, False otherwise. + + Examples: + >>> import torch + >>> from torch.distributions import Bernoulli + >>> from torch.distributions.constraints import is_dependent + + >>> dist = Bernoulli(probs=torch.tensor([0.6], requires_grad=True)) + >>> constraint1 = dist.arg_constraints["probs"] + >>> constraint2 = dist.arg_constraints["logits"] + + >>> for constraint in [constraint1, constraint2]: + >>> if is_dependent(constraint): + >>> continue + """ + return isinstance(constraint, _Dependent) + + +class _DependentProperty(property, _Dependent): + """ + Decorator that extends @property to act like a `Dependent` constraint when + called on a class and act like a property when called on an object. + + Example:: + + class Uniform(Distribution): + def __init__(self, low, high): + self.low = low + self.high = high + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return constraints.interval(self.low, self.high) + + Args: + fn (Callable): The function to be decorated. + is_discrete (bool): Optional value of ``.is_discrete`` in case this + can be computed statically. If not provided, access to the + ``.is_discrete`` attribute will raise a NotImplementedError. + event_dim (int): Optional value of ``.event_dim`` in case this + can be computed statically. If not provided, access to the + ``.event_dim`` attribute will raise a NotImplementedError. + """ + + def __init__( + self, + fn: Optional[Callable[..., Any]] = None, + *, + is_discrete: Optional[bool] = NotImplemented, + event_dim: Optional[int] = NotImplemented, + ) -> None: + super().__init__(fn) + self._is_discrete = is_discrete + self._event_dim = event_dim + + def __call__(self, fn: Callable[..., Any]) -> "_DependentProperty": # type: ignore[override] + """ + Support for syntax to customize static attributes:: + + @constraints.dependent_property(is_discrete=True, event_dim=1) + def support(self): ... + """ + return _DependentProperty( + fn, is_discrete=self._is_discrete, event_dim=self._event_dim + ) + + +class _IndependentConstraint(Constraint): + """ + Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many + dims in :meth:`check`, so that an event is valid only if all its + independent entries are valid. + """ + + def __init__(self, base_constraint, reinterpreted_batch_ndims): + assert isinstance(base_constraint, Constraint) + assert isinstance(reinterpreted_batch_ndims, int) + assert reinterpreted_batch_ndims >= 0 + self.base_constraint = base_constraint + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__() + + @property + def is_discrete(self) -> bool: # type: ignore[override] + return self.base_constraint.is_discrete + + @property + def event_dim(self) -> int: # type: ignore[override] + return self.base_constraint.event_dim + self.reinterpreted_batch_ndims + + def check(self, value): + result = self.base_constraint.check(value) + if result.dim() < self.reinterpreted_batch_ndims: + expected = self.base_constraint.event_dim + self.reinterpreted_batch_ndims + raise ValueError( + f"Expected value.dim() >= {expected} but got {value.dim()}" + ) + result = result.reshape( + result.shape[: result.dim() - self.reinterpreted_batch_ndims] + (-1,) + ) + result = result.all(-1) + return result + + def __repr__(self): + return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})" + + +class MixtureSameFamilyConstraint(Constraint): + """ + Constraint for the :class:`~torch.distribution.MixtureSameFamily` + distribution that adds back the rightmost batch dimension before + performing the validity check with the component distribution + constraint. + + Args: + base_constraint: The ``Constraint`` object of + the component distribution of + the :class:`~torch.distribution.MixtureSameFamily` distribution. + """ + + def __init__(self, base_constraint): + assert isinstance(base_constraint, Constraint) + self.base_constraint = base_constraint + super().__init__() + + @property + def is_discrete(self) -> bool: # type: ignore[override] + return self.base_constraint.is_discrete + + @property + def event_dim(self) -> int: # type: ignore[override] + return self.base_constraint.event_dim + + def check(self, value): + """ + Check validity of ``value`` as a possible outcome of sampling + the :class:`~torch.distribution.MixtureSameFamily` distribution. + """ + unsqueezed_value = value.unsqueeze(-1 - self.event_dim) + result = self.base_constraint.check(unsqueezed_value) + if value.dim() < self.event_dim: + raise ValueError( + f"Expected value.dim() >= {self.event_dim} but got {value.dim()}" + ) + num_dim_to_keep = value.dim() - self.event_dim + result = result.reshape(result.shape[:num_dim_to_keep] + (-1,)) + result = result.all(-1) + return result + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.base_constraint)})" + + +class _Boolean(Constraint): + """ + Constrain to the two values `{0, 1}`. + """ + + is_discrete = True + + def check(self, value): + return (value == 0) | (value == 1) + + +class _OneHot(Constraint): + """ + Constrain to one-hot vectors. + """ + + is_discrete = True + event_dim = 1 + + def check(self, value): + is_boolean = (value == 0) | (value == 1) + is_normalized = value.sum(-1).eq(1) + return is_boolean.all(-1) & is_normalized + + +class _IntegerInterval(Constraint): + """ + Constrain to an integer interval `[lower_bound, upper_bound]`. + """ + + is_discrete = True + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return ( + (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) + ) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _IntegerLessThan(Constraint): + """ + Constrain to an integer interval `(-inf, upper_bound]`. + """ + + is_discrete = True + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (value % 1 == 0) & (value <= self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(upper_bound={self.upper_bound})" + return fmt_string + + +class _IntegerGreaterThan(Constraint): + """ + Constrain to an integer interval `[lower_bound, inf)`. + """ + + is_discrete = True + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return (value % 1 == 0) & (value >= self.lower_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _Real(Constraint): + """ + Trivially constrain to the extended real line `[-inf, inf]`. + """ + + def check(self, value): + return value == value # False for NANs. + + +class _GreaterThan(Constraint): + """ + Constrain to a real half line `(lower_bound, inf]`. + """ + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return self.lower_bound < value + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _GreaterThanEq(Constraint): + """ + Constrain to a real half line `[lower_bound, inf)`. + """ + + def __init__(self, lower_bound): + self.lower_bound = lower_bound + super().__init__() + + def check(self, value): + return self.lower_bound <= value + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(lower_bound={self.lower_bound})" + return fmt_string + + +class _LessThan(Constraint): + """ + Constrain to a real half line `[-inf, upper_bound)`. + """ + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return value < self.upper_bound + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += f"(upper_bound={self.upper_bound})" + return fmt_string + + +class _Interval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound]`. + """ + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (self.lower_bound <= value) & (value <= self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _HalfOpenInterval(Constraint): + """ + Constrain to a real interval `[lower_bound, upper_bound)`. + """ + + def __init__(self, lower_bound, upper_bound): + self.lower_bound = lower_bound + self.upper_bound = upper_bound + super().__init__() + + def check(self, value): + return (self.lower_bound <= value) & (value < self.upper_bound) + + def __repr__(self): + fmt_string = self.__class__.__name__[1:] + fmt_string += ( + f"(lower_bound={self.lower_bound}, upper_bound={self.upper_bound})" + ) + return fmt_string + + +class _Simplex(Constraint): + """ + Constrain to the unit simplex in the innermost (rightmost) dimension. + Specifically: `x >= 0` and `x.sum(-1) == 1`. + """ + + event_dim = 1 + + def check(self, value): + return torch.all(value >= 0, dim=-1) & ((value.sum(-1) - 1).abs() < 1e-6) + + +class _Multinomial(Constraint): + """ + Constrain to nonnegative integer values summing to at most an upper bound. + + Note due to limitations of the Multinomial distribution, this currently + checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future + this may be strengthened to ``value.sum(-1) == upper_bound``. + """ + + is_discrete = True + event_dim = 1 + + def __init__(self, upper_bound): + self.upper_bound = upper_bound + + def check(self, x): + return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound) + + +class _LowerTriangular(Constraint): + """ + Constrain to lower-triangular square matrices. + """ + + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + return (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + + +class _LowerCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with positive diagonals. + """ + + event_dim = 2 + + def check(self, value): + value_tril = value.tril() + lower_triangular = ( + (value_tril == value).view(value.shape[:-2] + (-1,)).min(-1)[0] + ) + + positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).min(-1)[0] + return lower_triangular & positive_diagonal + + +class _CorrCholesky(Constraint): + """ + Constrain to lower-triangular square matrices with positive diagonals and each + row vector being of unit length. + """ + + event_dim = 2 + + def check(self, value): + tol = ( + torch.finfo(value.dtype).eps * value.size(-1) * 10 + ) # 10 is an adjustable fudge factor + row_norm = torch.linalg.norm(value.detach(), dim=-1) + unit_row_norm = (row_norm - 1.0).abs().le(tol).all(dim=-1) + return _LowerCholesky().check(value) & unit_row_norm + + +class _Square(Constraint): + """ + Constrain to square matrices. + """ + + event_dim = 2 + + def check(self, value): + return torch.full( + size=value.shape[:-2], + fill_value=(value.shape[-2] == value.shape[-1]), + dtype=torch.bool, + device=value.device, + ) + + +class _Symmetric(_Square): + """ + Constrain to Symmetric square matrices. + """ + + def check(self, value): + square_check = super().check(value) + if not square_check.all(): + return square_check + return torch.isclose(value, value.mT, atol=1e-6).all(-2).all(-1) + + +class _PositiveSemidefinite(_Symmetric): + """ + Constrain to positive-semidefinite matrices. + """ + + def check(self, value): + sym_check = super().check(value) + if not sym_check.all(): + return sym_check + return torch.linalg.eigvalsh(value).ge(0).all(-1) + + +class _PositiveDefinite(_Symmetric): + """ + Constrain to positive-definite matrices. + """ + + def check(self, value): + sym_check = super().check(value) + if not sym_check.all(): + return sym_check + return torch.linalg.cholesky_ex(value).info.eq(0) + + +class _Cat(Constraint): + """ + Constraint functor that applies a sequence of constraints + `cseq` at the submatrices at dimension `dim`, + each of size `lengths[dim]`, in a way compatible with :func:`torch.cat`. + """ + + def __init__(self, cseq, dim=0, lengths=None): + assert all(isinstance(c, Constraint) for c in cseq) + self.cseq = list(cseq) + if lengths is None: + lengths = [1] * len(self.cseq) + self.lengths = list(lengths) + assert len(self.lengths) == len(self.cseq) + self.dim = dim + super().__init__() + + @property + def is_discrete(self) -> bool: # type: ignore[override] + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self) -> int: # type: ignore[override] + return max(c.event_dim for c in self.cseq) + + def check(self, value): + assert -value.dim() <= self.dim < value.dim() + checks = [] + start = 0 + for constr, length in zip(self.cseq, self.lengths): + v = value.narrow(self.dim, start, length) + checks.append(constr.check(v)) + start = start + length # avoid += for jit compat + return torch.cat(checks, self.dim) + + +class _Stack(Constraint): + """ + Constraint functor that applies a sequence of constraints + `cseq` at the submatrices at dimension `dim`, + in a way compatible with :func:`torch.stack`. + """ + + def __init__(self, cseq, dim=0): + assert all(isinstance(c, Constraint) for c in cseq) + self.cseq = list(cseq) + self.dim = dim + super().__init__() + + @property + def is_discrete(self) -> bool: # type: ignore[override] + return any(c.is_discrete for c in self.cseq) + + @property + def event_dim(self) -> int: # type: ignore[override] + dim = max(c.event_dim for c in self.cseq) + if self.dim + dim < 0: + dim += 1 + return dim + + def check(self, value): + assert -value.dim() <= self.dim < value.dim() + vs = [value.select(self.dim, i) for i in range(value.size(self.dim))] + return torch.stack( + [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim + ) + + +# Public interface. +dependent = _Dependent() +dependent_property = _DependentProperty +independent = _IndependentConstraint +boolean = _Boolean() +one_hot = _OneHot() +nonnegative_integer = _IntegerGreaterThan(0) +positive_integer = _IntegerGreaterThan(1) +integer_interval = _IntegerInterval +real = _Real() +real_vector = independent(real, 1) +positive = _GreaterThan(0.0) +nonnegative = _GreaterThanEq(0.0) +greater_than = _GreaterThan +greater_than_eq = _GreaterThanEq +less_than = _LessThan +multinomial = _Multinomial +unit_interval = _Interval(0.0, 1.0) +interval = _Interval +half_open_interval = _HalfOpenInterval +simplex = _Simplex() +lower_triangular = _LowerTriangular() +lower_cholesky = _LowerCholesky() +corr_cholesky = _CorrCholesky() +square = _Square() +symmetric = _Symmetric() +positive_semidefinite = _PositiveSemidefinite() +positive_definite = _PositiveDefinite() +cat = _Cat +stack = _Stack diff --git a/phivenv/Lib/site-packages/torch/distributions/continuous_bernoulli.py b/phivenv/Lib/site-packages/torch/distributions/continuous_bernoulli.py new file mode 100644 index 0000000000000000000000000000000000000000..78cff84d30e8fb1c01fd6753e2596def412f4287 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/continuous_bernoulli.py @@ -0,0 +1,245 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import ( + broadcast_all, + clamp_probs, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.nn.functional import binary_cross_entropy_with_logits +from torch.types import _Number, _size, Number + + +__all__ = ["ContinuousBernoulli"] + + +class ContinuousBernoulli(ExponentialFamily): + r""" + Creates a continuous Bernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both). + + The distribution is supported in [0, 1] and parameterized by 'probs' (in + (0,1)) or 'logits' (real-valued). Note that, unlike the Bernoulli, 'probs' + does not correspond to a probability and 'logits' does not correspond to + log-odds, but the same names are used due to the similarity with the + Bernoulli. See [1] for more details. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = ContinuousBernoulli(torch.tensor([0.3])) + >>> m.sample() + tensor([ 0.2538]) + + Args: + probs (Number, Tensor): (0,1) valued parameters + logits (Number, Tensor): real valued parameters whose sigmoid matches 'probs' + + [1] The continuous Bernoulli: fixing a pervasive error in variational + autoencoders, Loaiza-Ganem G and Cunningham JP, NeurIPS 2019. + https://arxiv.org/abs/1907.06845 + """ + + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.unit_interval + _mean_carrier_measure = 0 + has_rsample = True + + def __init__( + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + lims: tuple[float, float] = (0.499, 0.501), + validate_args: Optional[bool] = None, + ) -> None: + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + is_scalar = isinstance(probs, _Number) + (self.probs,) = broadcast_all(probs) + # validate 'probs' here if necessary as it is later clamped for numerical stability + # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass + if validate_args is not None: + if not self.arg_constraints["probs"].check(self.probs).all(): + raise ValueError("The parameter probs has invalid values") + self.probs = clamp_probs(self.probs) + else: + assert logits is not None # helps mypy + is_scalar = isinstance(logits, _Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = torch.Size() + else: + batch_shape = self._param.size() + self._lims = lims + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ContinuousBernoulli, _instance) + new._lims = self._lims + batch_shape = torch.Size(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(ContinuousBernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + def _outside_unstable_region(self): + return torch.max( + torch.le(self.probs, self._lims[0]), torch.gt(self.probs, self._lims[1]) + ) + + def _cut_probs(self): + return torch.where( + self._outside_unstable_region(), + self.probs, + self._lims[0] * torch.ones_like(self.probs), + ) + + def _cont_bern_log_norm(self): + """computes the log normalizing constant as a function of the 'probs' parameter""" + cut_probs = self._cut_probs() + cut_probs_below_half = torch.where( + torch.le(cut_probs, 0.5), cut_probs, torch.zeros_like(cut_probs) + ) + cut_probs_above_half = torch.where( + torch.ge(cut_probs, 0.5), cut_probs, torch.ones_like(cut_probs) + ) + log_norm = torch.log( + torch.abs(torch.log1p(-cut_probs) - torch.log(cut_probs)) + ) - torch.where( + torch.le(cut_probs, 0.5), + torch.log1p(-2.0 * cut_probs_below_half), + torch.log(2.0 * cut_probs_above_half - 1.0), + ) + x = torch.pow(self.probs - 0.5, 2) + taylor = math.log(2.0) + (4.0 / 3.0 + 104.0 / 45.0 * x) * x + return torch.where(self._outside_unstable_region(), log_norm, taylor) + + @property + def mean(self) -> Tensor: + cut_probs = self._cut_probs() + mus = cut_probs / (2.0 * cut_probs - 1.0) + 1.0 / ( + torch.log1p(-cut_probs) - torch.log(cut_probs) + ) + x = self.probs - 0.5 + taylor = 0.5 + (1.0 / 3.0 + 16.0 / 45.0 * torch.pow(x, 2)) * x + return torch.where(self._outside_unstable_region(), mus, taylor) + + @property + def stddev(self) -> Tensor: + return torch.sqrt(self.variance) + + @property + def variance(self) -> Tensor: + cut_probs = self._cut_probs() + vars = cut_probs * (cut_probs - 1.0) / torch.pow( + 1.0 - 2.0 * cut_probs, 2 + ) + 1.0 / torch.pow(torch.log1p(-cut_probs) - torch.log(cut_probs), 2) + x = torch.pow(self.probs - 0.5, 2) + taylor = 1.0 / 12.0 - (1.0 / 15.0 - 128.0 / 945.0 * x) * x + return torch.where(self._outside_unstable_region(), vars, taylor) + + @lazy_property + def logits(self) -> Tensor: + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self) -> Tensor: + return clamp_probs(logits_to_probs(self.logits, is_binary=True)) + + @property + def param_shape(self) -> torch.Size: + return self._param.size() + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device) + with torch.no_grad(): + return self.icdf(u) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device) + return self.icdf(u) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + return ( + -binary_cross_entropy_with_logits(logits, value, reduction="none") + + self._cont_bern_log_norm() + ) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + cut_probs = self._cut_probs() + cdfs = ( + torch.pow(cut_probs, value) * torch.pow(1.0 - cut_probs, 1.0 - value) + + cut_probs + - 1.0 + ) / (2.0 * cut_probs - 1.0) + unbounded_cdfs = torch.where(self._outside_unstable_region(), cdfs, value) + return torch.where( + torch.le(value, 0.0), + torch.zeros_like(value), + torch.where(torch.ge(value, 1.0), torch.ones_like(value), unbounded_cdfs), + ) + + def icdf(self, value): + cut_probs = self._cut_probs() + return torch.where( + self._outside_unstable_region(), + ( + torch.log1p(-cut_probs + value * (2.0 * cut_probs - 1.0)) + - torch.log1p(-cut_probs) + ) + / (torch.log(cut_probs) - torch.log1p(-cut_probs)), + value, + ) + + def entropy(self): + log_probs0 = torch.log1p(-self.probs) + log_probs1 = torch.log(self.probs) + return ( + self.mean * (log_probs0 - log_probs1) + - self._cont_bern_log_norm() + - log_probs0 + ) + + @property + def _natural_params(self) -> tuple[Tensor]: + return (self.logits,) + + def _log_normalizer(self, x): + """computes the log normalizing constant as a function of the natural parameter""" + out_unst_reg = torch.max( + torch.le(x, self._lims[0] - 0.5), torch.gt(x, self._lims[1] - 0.5) + ) + cut_nat_params = torch.where( + out_unst_reg, x, (self._lims[0] - 0.5) * torch.ones_like(x) + ) + log_norm = torch.log( + torch.abs(torch.special.expm1(cut_nat_params)) + ) - torch.log(torch.abs(cut_nat_params)) + taylor = 0.5 * x + torch.pow(x, 2) / 24.0 - torch.pow(x, 4) / 2880.0 + return torch.where(out_unst_reg, log_norm, taylor) diff --git a/phivenv/Lib/site-packages/torch/distributions/dirichlet.py b/phivenv/Lib/site-packages/torch/distributions/dirichlet.py new file mode 100644 index 0000000000000000000000000000000000000000..5df1c7f66027577fc3b839fab511ae7325026977 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/dirichlet.py @@ -0,0 +1,134 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch import Tensor +from torch.autograd import Function +from torch.autograd.function import once_differentiable +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.types import _size + + +__all__ = ["Dirichlet"] + + +# This helper is exposed for testing. +def _Dirichlet_backward(x, concentration, grad_output): + total = concentration.sum(-1, True).expand_as(concentration) + grad = torch._dirichlet_grad(x, concentration, total) + return grad * (grad_output - (x * grad_output).sum(-1, True)) + + +class _Dirichlet(Function): + @staticmethod + def forward(ctx, concentration): + x = torch._sample_dirichlet(concentration) + ctx.save_for_backward(x, concentration) + return x + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + x, concentration = ctx.saved_tensors + return _Dirichlet_backward(x, concentration, grad_output) + + +class Dirichlet(ExponentialFamily): + r""" + Creates a Dirichlet distribution parameterized by concentration :attr:`concentration`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Dirichlet(torch.tensor([0.5, 0.5])) + >>> m.sample() # Dirichlet distributed with concentration [0.5, 0.5] + tensor([ 0.1046, 0.8954]) + + Args: + concentration (Tensor): concentration parameter of the distribution + (often referred to as alpha) + """ + + arg_constraints = { + "concentration": constraints.independent(constraints.positive, 1) + } + support = constraints.simplex + has_rsample = True + + def __init__( + self, + concentration: Tensor, + validate_args: Optional[bool] = None, + ) -> None: + if concentration.dim() < 1: + raise ValueError( + "`concentration` parameter must be at least one-dimensional." + ) + self.concentration = concentration + batch_shape, event_shape = concentration.shape[:-1], concentration.shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Dirichlet, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape + self.event_shape) + super(Dirichlet, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = ()) -> Tensor: + shape = self._extended_shape(sample_shape) + concentration = self.concentration.expand(shape) + return _Dirichlet.apply(concentration) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return ( + torch.xlogy(self.concentration - 1.0, value).sum(-1) + + torch.lgamma(self.concentration.sum(-1)) + - torch.lgamma(self.concentration).sum(-1) + ) + + @property + def mean(self) -> Tensor: + return self.concentration / self.concentration.sum(-1, True) + + @property + def mode(self) -> Tensor: + concentrationm1 = (self.concentration - 1).clamp(min=0.0) + mode = concentrationm1 / concentrationm1.sum(-1, True) + mask = (self.concentration < 1).all(dim=-1) + mode[mask] = torch.nn.functional.one_hot( + mode[mask].argmax(dim=-1), concentrationm1.shape[-1] + ).to(mode) + return mode + + @property + def variance(self) -> Tensor: + con0 = self.concentration.sum(-1, True) + return ( + self.concentration + * (con0 - self.concentration) + / (con0.pow(2) * (con0 + 1)) + ) + + def entropy(self): + k = self.concentration.size(-1) + a0 = self.concentration.sum(-1) + return ( + torch.lgamma(self.concentration).sum(-1) + - torch.lgamma(a0) + - (k - a0) * torch.digamma(a0) + - ((self.concentration - 1.0) * torch.digamma(self.concentration)).sum(-1) + ) + + @property + def _natural_params(self) -> tuple[Tensor]: + return (self.concentration,) + + def _log_normalizer(self, x): + return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1)) diff --git a/phivenv/Lib/site-packages/torch/distributions/distribution.py b/phivenv/Lib/site-packages/torch/distributions/distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..ebb26a909b52fc0306639c53aad525fd03159da4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/distribution.py @@ -0,0 +1,346 @@ +# mypy: allow-untyped-defs +import warnings +from typing import Optional +from typing_extensions import deprecated + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.utils import lazy_property +from torch.types import _size + + +__all__ = ["Distribution"] + + +class Distribution: + r""" + Distribution is the abstract base class for probability distributions. + + Args: + batch_shape (torch.Size): The shape over which parameters are batched. + event_shape (torch.Size): The shape of a single sample (without batching). + validate_args (bool, optional): Whether to validate arguments. Default: None. + """ + + has_rsample = False + has_enumerate_support = False + _validate_args = __debug__ + + @staticmethod + def set_default_validate_args(value: bool) -> None: + """ + Sets whether validation is enabled or disabled. + + The default behavior mimics Python's ``assert`` statement: validation + is on by default, but is disabled if Python is run in optimized mode + (via ``python -O``). Validation may be expensive, so you may want to + disable it once a model is working. + + Args: + value (bool): Whether to enable validation. + """ + if value not in [True, False]: + raise ValueError + Distribution._validate_args = value + + def __init__( + self, + batch_shape: torch.Size = torch.Size(), + event_shape: torch.Size = torch.Size(), + validate_args: Optional[bool] = None, + ) -> None: + self._batch_shape = batch_shape + self._event_shape = event_shape + if validate_args is not None: + self._validate_args = validate_args + if self._validate_args: + try: + arg_constraints = self.arg_constraints + except NotImplementedError: + arg_constraints = {} + warnings.warn( + f"{self.__class__} does not define `arg_constraints`. " + + "Please set `arg_constraints = {}` or initialize the distribution " + + "with `validate_args=False` to turn off validation." + ) + for param, constraint in arg_constraints.items(): + if constraints.is_dependent(constraint): + continue # skip constraints that cannot be checked + if param not in self.__dict__ and isinstance( + getattr(type(self), param), lazy_property + ): + continue # skip checking lazily-constructed args + value = getattr(self, param) + valid = constraint.check(value) + if not torch._is_all_true(valid): + raise ValueError( + f"Expected parameter {param} " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"of distribution {repr(self)} " + f"to satisfy the constraint {repr(constraint)}, " + f"but found invalid values:\n{value}" + ) + super().__init__() + + def expand(self, batch_shape: _size, _instance=None): + """ + Returns a new distribution instance (or populates an existing instance + provided by a derived class) with batch dimensions expanded to + `batch_shape`. This method calls :class:`~torch.Tensor.expand` on + the distribution's parameters. As such, this does not allocate new + memory for the expanded distribution instance. Additionally, + this does not repeat any args checking or parameter broadcasting in + `__init__.py`, when an instance is first created. + + Args: + batch_shape (torch.Size): the desired expanded size. + _instance: new instance provided by subclasses that + need to override `.expand`. + + Returns: + New distribution instance with batch dimensions expanded to + `batch_size`. + """ + raise NotImplementedError + + @property + def batch_shape(self) -> torch.Size: + """ + Returns the shape over which parameters are batched. + """ + return self._batch_shape + + @property + def event_shape(self) -> torch.Size: + """ + Returns the shape of a single sample (without batching). + """ + return self._event_shape + + @property + def arg_constraints(self) -> dict[str, constraints.Constraint]: + """ + Returns a dictionary from argument names to + :class:`~torch.distributions.constraints.Constraint` objects that + should be satisfied by each argument of this distribution. Args that + are not tensors need not appear in this dict. + """ + raise NotImplementedError + + @property + def support(self) -> Optional[constraints.Constraint]: + """ + Returns a :class:`~torch.distributions.constraints.Constraint` object + representing this distribution's support. + """ + raise NotImplementedError + + @property + def mean(self) -> Tensor: + """ + Returns the mean of the distribution. + """ + raise NotImplementedError + + @property + def mode(self) -> Tensor: + """ + Returns the mode of the distribution. + """ + raise NotImplementedError(f"{self.__class__} does not implement mode") + + @property + def variance(self) -> Tensor: + """ + Returns the variance of the distribution. + """ + raise NotImplementedError + + @property + def stddev(self) -> Tensor: + """ + Returns the standard deviation of the distribution. + """ + return self.variance.sqrt() + + def sample(self, sample_shape: _size = torch.Size()) -> Tensor: + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. + """ + with torch.no_grad(): + return self.rsample(sample_shape) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. + """ + raise NotImplementedError + + @deprecated( + "`sample_n(n)` will be deprecated. Use `sample((n,))` instead.", + category=FutureWarning, + ) + def sample_n(self, n: int) -> Tensor: + """ + Generates n samples or n batches of samples if the distribution + parameters are batched. + """ + return self.sample(torch.Size((n,))) + + def log_prob(self, value: Tensor) -> Tensor: + """ + Returns the log of the probability density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + def cdf(self, value: Tensor) -> Tensor: + """ + Returns the cumulative density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + def icdf(self, value: Tensor) -> Tensor: + """ + Returns the inverse cumulative density/mass function evaluated at + `value`. + + Args: + value (Tensor): + """ + raise NotImplementedError + + def enumerate_support(self, expand: bool = True) -> Tensor: + """ + Returns tensor containing all values supported by a discrete + distribution. The result will enumerate over dimension 0, so the shape + of the result will be `(cardinality,) + batch_shape + event_shape` + (where `event_shape = ()` for univariate distributions). + + Note that this enumerates over all batched tensors in lock-step + `[[0, 0], [1, 1], ...]`. With `expand=False`, enumeration happens + along dim 0, but with the remaining batch dimensions being + singleton dimensions, `[[0], [1], ..`. + + To iterate over the full Cartesian product use + `itertools.product(m.enumerate_support())`. + + Args: + expand (bool): whether to expand the support over the + batch dims to match the distribution's `batch_shape`. + + Returns: + Tensor iterating over dimension 0. + """ + raise NotImplementedError + + def entropy(self) -> Tensor: + """ + Returns entropy of distribution, batched over batch_shape. + + Returns: + Tensor of shape batch_shape. + """ + raise NotImplementedError + + def perplexity(self) -> Tensor: + """ + Returns perplexity of distribution, batched over batch_shape. + + Returns: + Tensor of shape batch_shape. + """ + return torch.exp(self.entropy()) + + def _extended_shape(self, sample_shape: _size = torch.Size()) -> torch.Size: + """ + Returns the size of the sample returned by the distribution, given + a `sample_shape`. Note, that the batch and event shapes of a distribution + instance are fixed at the time of construction. If this is empty, the + returned shape is upcast to (1,). + + Args: + sample_shape (torch.Size): the size of the sample to be drawn. + """ + if not isinstance(sample_shape, torch.Size): + sample_shape = torch.Size(sample_shape) + return torch.Size(sample_shape + self._batch_shape + self._event_shape) + + def _validate_sample(self, value: Tensor) -> None: + """ + Argument validation for distribution methods such as `log_prob`, + `cdf` and `icdf`. The rightmost dimensions of a value to be + scored via these methods must agree with the distribution's batch + and event shapes. + + Args: + value (Tensor): the tensor whose log probability is to be + computed by the `log_prob` method. + Raises + ValueError: when the rightmost dimensions of `value` do not match the + distribution's batch and event shapes. + """ + if not isinstance(value, torch.Tensor): + raise ValueError("The value argument to log_prob must be a Tensor") + + event_dim_start = len(value.size()) - len(self._event_shape) + if value.size()[event_dim_start:] != self._event_shape: + raise ValueError( + f"The right-most size of value must match event_shape: {value.size()} vs {self._event_shape}." + ) + + actual_shape = value.size() + expected_shape = self._batch_shape + self._event_shape + for i, j in zip(reversed(actual_shape), reversed(expected_shape)): + if i != 1 and j != 1 and i != j: + raise ValueError( + f"Value is not broadcastable with batch_shape+event_shape: {actual_shape} vs {expected_shape}." + ) + try: + support = self.support + except NotImplementedError: + warnings.warn( + f"{self.__class__} does not define `support` to enable " + + "sample validation. Please initialize the distribution with " + + "`validate_args=False` to turn off validation." + ) + return + assert support is not None + valid = support.check(value) + if not torch._is_all_true(valid): + raise ValueError( + "Expected value argument " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"to be within the support ({repr(support)}) " + f"of the distribution {repr(self)}, " + f"but found invalid values:\n{value}" + ) + + def _get_checked_instance(self, cls, _instance=None): + if _instance is None and type(self).__init__ != cls.__init__: + raise NotImplementedError( + f"Subclass {self.__class__.__name__} of {cls.__name__} that defines a custom __init__ method " + "must also define a custom .expand() method." + ) + return self.__new__(type(self)) if _instance is None else _instance + + def __repr__(self) -> str: + param_names = [k for k, _ in self.arg_constraints.items() if k in self.__dict__] + args_string = ", ".join( + [ + f"{p}: {self.__dict__[p] if self.__dict__[p].numel() == 1 else self.__dict__[p].size()}" + for p in param_names + ] + ) + return self.__class__.__name__ + "(" + args_string + ")" diff --git a/phivenv/Lib/site-packages/torch/distributions/exp_family.py b/phivenv/Lib/site-packages/torch/distributions/exp_family.py new file mode 100644 index 0000000000000000000000000000000000000000..b79666ed94c5fb48aba6d9ee2419bc84cf93d43a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/exp_family.py @@ -0,0 +1,65 @@ +# mypy: allow-untyped-defs +import torch +from torch import Tensor +from torch.distributions.distribution import Distribution + + +__all__ = ["ExponentialFamily"] + + +class ExponentialFamily(Distribution): + r""" + ExponentialFamily is the abstract base class for probability distributions belonging to an + exponential family, whose probability mass/density function has the form is defined below + + .. math:: + + p_{F}(x; \theta) = \exp(\langle t(x), \theta\rangle - F(\theta) + k(x)) + + where :math:`\theta` denotes the natural parameters, :math:`t(x)` denotes the sufficient statistic, + :math:`F(\theta)` is the log normalizer function for a given family and :math:`k(x)` is the carrier + measure. + + Note: + This class is an intermediary between the `Distribution` class and distributions which belong + to an exponential family mainly to check the correctness of the `.entropy()` and analytic KL + divergence methods. We use this class to compute the entropy and KL divergence using the AD + framework and Bregman divergences (courtesy of: Frank Nielsen and Richard Nock, Entropies and + Cross-entropies of Exponential Families). + """ + + @property + def _natural_params(self) -> tuple[Tensor, ...]: + """ + Abstract method for natural parameters. Returns a tuple of Tensors based + on the distribution + """ + raise NotImplementedError + + def _log_normalizer(self, *natural_params): + """ + Abstract method for log normalizer function. Returns a log normalizer based on + the distribution and input + """ + raise NotImplementedError + + @property + def _mean_carrier_measure(self) -> float: + """ + Abstract method for expected carrier measure, which is required for computing + entropy. + """ + raise NotImplementedError + + def entropy(self): + """ + Method to compute the entropy using Bregman divergence of the log normalizer. + """ + result = -self._mean_carrier_measure + nparams = [p.detach().requires_grad_() for p in self._natural_params] + lg_normal = self._log_normalizer(*nparams) + gradients = torch.autograd.grad(lg_normal.sum(), nparams, create_graph=True) + result += lg_normal + for np, g in zip(nparams, gradients): + result -= (np * g).reshape(self._batch_shape + (-1,)).sum(-1) + return result diff --git a/phivenv/Lib/site-packages/torch/distributions/exponential.py b/phivenv/Lib/site-packages/torch/distributions/exponential.py new file mode 100644 index 0000000000000000000000000000000000000000..e9cc2ad5bc8f0dc41bfd55e01cce3a13bcc33d78 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/exponential.py @@ -0,0 +1,93 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all +from torch.types import _Number, _size + + +__all__ = ["Exponential"] + + +class Exponential(ExponentialFamily): + r""" + Creates a Exponential distribution parameterized by :attr:`rate`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Exponential(torch.tensor([1.0])) + >>> m.sample() # Exponential distributed with rate=1 + tensor([ 0.1046]) + + Args: + rate (float or Tensor): rate = 1 / scale of the distribution + """ + + arg_constraints = {"rate": constraints.positive} + support = constraints.nonnegative + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self) -> Tensor: + return self.rate.reciprocal() + + @property + def mode(self) -> Tensor: + return torch.zeros_like(self.rate) + + @property + def stddev(self) -> Tensor: + return self.rate.reciprocal() + + @property + def variance(self) -> Tensor: + return self.rate.pow(-2) + + def __init__( + self, + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + (self.rate,) = broadcast_all(rate) + batch_shape = torch.Size() if isinstance(rate, _Number) else self.rate.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Exponential, _instance) + batch_shape = torch.Size(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Exponential, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + return self.rate.new(shape).exponential_() / self.rate + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return self.rate.log() - self.rate * value + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 1 - torch.exp(-self.rate * value) + + def icdf(self, value): + return -torch.log1p(-value) / self.rate + + def entropy(self): + return 1.0 - torch.log(self.rate) + + @property + def _natural_params(self) -> tuple[Tensor]: + return (-self.rate,) + + def _log_normalizer(self, x): + return -torch.log(-x) diff --git a/phivenv/Lib/site-packages/torch/distributions/fishersnedecor.py b/phivenv/Lib/site-packages/torch/distributions/fishersnedecor.py new file mode 100644 index 0000000000000000000000000000000000000000..3a649adf9162f738e67ff1725b01b2c199c0e46c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/fishersnedecor.py @@ -0,0 +1,107 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import nan, Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.gamma import Gamma +from torch.distributions.utils import broadcast_all +from torch.types import _Number, _size + + +__all__ = ["FisherSnedecor"] + + +class FisherSnedecor(Distribution): + r""" + Creates a Fisher-Snedecor distribution parameterized by :attr:`df1` and :attr:`df2`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = FisherSnedecor(torch.tensor([1.0]), torch.tensor([2.0])) + >>> m.sample() # Fisher-Snedecor-distributed with df1=1 and df2=2 + tensor([ 0.2453]) + + Args: + df1 (float or Tensor): degrees of freedom parameter 1 + df2 (float or Tensor): degrees of freedom parameter 2 + """ + + arg_constraints = {"df1": constraints.positive, "df2": constraints.positive} + support = constraints.positive + has_rsample = True + + def __init__( + self, + df1: Union[Tensor, float], + df2: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.df1, self.df2 = broadcast_all(df1, df2) + self._gamma1 = Gamma(self.df1 * 0.5, self.df1) + self._gamma2 = Gamma(self.df2 * 0.5, self.df2) + + if isinstance(df1, _Number) and isinstance(df2, _Number): + batch_shape = torch.Size() + else: + batch_shape = self.df1.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(FisherSnedecor, _instance) + batch_shape = torch.Size(batch_shape) + new.df1 = self.df1.expand(batch_shape) + new.df2 = self.df2.expand(batch_shape) + new._gamma1 = self._gamma1.expand(batch_shape) + new._gamma2 = self._gamma2.expand(batch_shape) + super(FisherSnedecor, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self) -> Tensor: + df2 = self.df2.clone(memory_format=torch.contiguous_format) + df2[df2 <= 2] = nan + return df2 / (df2 - 2) + + @property + def mode(self) -> Tensor: + mode = (self.df1 - 2) / self.df1 * self.df2 / (self.df2 + 2) + mode[self.df1 <= 2] = nan + return mode + + @property + def variance(self) -> Tensor: + df2 = self.df2.clone(memory_format=torch.contiguous_format) + df2[df2 <= 4] = nan + return ( + 2 + * df2.pow(2) + * (self.df1 + df2 - 2) + / (self.df1 * (df2 - 2).pow(2) * (df2 - 4)) + ) + + def rsample(self, sample_shape: _size = torch.Size(())) -> Tensor: + shape = self._extended_shape(sample_shape) + # X1 ~ Gamma(df1 / 2, 1 / df1), X2 ~ Gamma(df2 / 2, 1 / df2) + # Y = df2 * df1 * X1 / (df1 * df2 * X2) = X1 / X2 ~ F(df1, df2) + X1 = self._gamma1.rsample(sample_shape).view(shape) + X2 = self._gamma2.rsample(sample_shape).view(shape) + tiny = torch.finfo(X2.dtype).tiny + X2.clamp_(min=tiny) + Y = X1 / X2 + Y.clamp_(min=tiny) + return Y + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + ct1 = self.df1 * 0.5 + ct2 = self.df2 * 0.5 + ct3 = self.df1 / self.df2 + t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma() + t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value) + t3 = (ct1 + ct2) * torch.log1p(ct3 * value) + return t1 + t2 - t3 diff --git a/phivenv/Lib/site-packages/torch/distributions/gamma.py b/phivenv/Lib/site-packages/torch/distributions/gamma.py new file mode 100644 index 0000000000000000000000000000000000000000..e738ca2397e66880226ccf3efe0a646b076c52c9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/gamma.py @@ -0,0 +1,118 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all +from torch.types import _Number, _size + + +__all__ = ["Gamma"] + + +def _standard_gamma(concentration): + return torch._standard_gamma(concentration) + + +class Gamma(ExponentialFamily): + r""" + Creates a Gamma distribution parameterized by shape :attr:`concentration` and :attr:`rate`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Gamma(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # Gamma distributed with concentration=1 and rate=1 + tensor([ 0.1046]) + + Args: + concentration (float or Tensor): shape parameter of the distribution + (often referred to as alpha) + rate (float or Tensor): rate parameter of the distribution + (often referred to as beta), rate = 1 / scale + """ + + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } + support = constraints.nonnegative + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self) -> Tensor: + return self.concentration / self.rate + + @property + def mode(self) -> Tensor: + return ((self.concentration - 1) / self.rate).clamp(min=0) + + @property + def variance(self) -> Tensor: + return self.concentration / self.rate.pow(2) + + def __init__( + self, + concentration: Union[Tensor, float], + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.concentration, self.rate = broadcast_all(concentration, rate) + if isinstance(concentration, _Number) and isinstance(rate, _Number): + batch_shape = torch.Size() + else: + batch_shape = self.concentration.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Gamma, _instance) + batch_shape = torch.Size(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Gamma, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + value = _standard_gamma(self.concentration.expand(shape)) / self.rate.expand( + shape + ) + value.detach().clamp_( + min=torch.finfo(value.dtype).tiny + ) # do not record in autograd graph + return value + + def log_prob(self, value): + value = torch.as_tensor(value, dtype=self.rate.dtype, device=self.rate.device) + if self._validate_args: + self._validate_sample(value) + return ( + torch.xlogy(self.concentration, self.rate) + + torch.xlogy(self.concentration - 1, value) + - self.rate * value + - torch.lgamma(self.concentration) + ) + + def entropy(self): + return ( + self.concentration + - torch.log(self.rate) + + torch.lgamma(self.concentration) + + (1.0 - self.concentration) * torch.digamma(self.concentration) + ) + + @property + def _natural_params(self) -> tuple[Tensor, Tensor]: + return (self.concentration - 1, -self.rate) + + def _log_normalizer(self, x, y): + return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal()) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return torch.special.gammainc(self.concentration, self.rate * value) diff --git a/phivenv/Lib/site-packages/torch/distributions/generalized_pareto.py b/phivenv/Lib/site-packages/torch/distributions/generalized_pareto.py new file mode 100644 index 0000000000000000000000000000000000000000..358c2a56b766b55126f4fec3d1a563c8b12c3ae6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/generalized_pareto.py @@ -0,0 +1,150 @@ +# mypy: allow-untyped-defs +import math +from numbers import Number, Real + +import torch +from torch import inf, nan +from torch.distributions import constraints, Distribution +from torch.distributions.utils import broadcast_all + + +__all__ = ["GeneralizedPareto"] + + +class GeneralizedPareto(Distribution): + r""" + Creates a Generalized Pareto distribution parameterized by :attr:`loc`, :attr:`scale`, and :attr:`concentration`. + + The Generalized Pareto distribution is a family of continuous probability distributions on the real line. + Special cases include Exponential (when :attr:`loc` = 0, :attr:`concentration` = 0), Pareto (when :attr:`concentration` > 0, + :attr:`loc` = :attr:`scale` / :attr:`concentration`), and Uniform (when :attr:`concentration` = -1). + + This distribution is often used to model the tails of other distributions. This implementation is based on the + implementation in TensorFlow Probability. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = GeneralizedPareto(torch.tensor([0.1]), torch.tensor([2.0]), torch.tensor([0.4])) + >>> m.sample() # sample from a Generalized Pareto distribution with loc=0.1, scale=2.0, and concentration=0.4 + tensor([ 1.5623]) + + Args: + loc (float or Tensor): Location parameter of the distribution + scale (float or Tensor): Scale parameter of the distribution + concentration (float or Tensor): Concentration parameter of the distribution + """ + + arg_constraints = { + "loc": constraints.real, + "scale": constraints.positive, + "concentration": constraints.real, + } + has_rsample = True + + def __init__(self, loc, scale, concentration, validate_args=None): + self.loc, self.scale, self.concentration = broadcast_all( + loc, scale, concentration + ) + if ( + isinstance(loc, Number) + and isinstance(scale, Number) + and isinstance(concentration, Number) + ): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(GeneralizedPareto, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + super(GeneralizedPareto, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) + return self.icdf(u) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + z = self._z(value) + eq_zero = torch.isclose(self.concentration, torch.tensor(0.0)) + safe_conc = torch.where( + eq_zero, torch.ones_like(self.concentration), self.concentration + ) + y = 1 / safe_conc + torch.ones_like(z) + where_nonzero = torch.where(y == 0, y, y * torch.log1p(safe_conc * z)) + log_scale = ( + math.log(self.scale) if isinstance(self.scale, Real) else self.scale.log() + ) + return -log_scale - torch.where(eq_zero, z, where_nonzero) + + def log_survival_function(self, value): + if self._validate_args: + self._validate_sample(value) + z = self._z(value) + eq_zero = torch.isclose(self.concentration, torch.tensor(0.0)) + safe_conc = torch.where( + eq_zero, torch.ones_like(self.concentration), self.concentration + ) + where_nonzero = -torch.log1p(safe_conc * z) / safe_conc + return torch.where(eq_zero, -z, where_nonzero) + + def log_cdf(self, value): + return torch.log1p(-torch.exp(self.log_survival_function(value))) + + def cdf(self, value): + return torch.exp(self.log_cdf(value)) + + def icdf(self, value): + loc = self.loc + scale = self.scale + concentration = self.concentration + eq_zero = torch.isclose(concentration, torch.zeros_like(concentration)) + safe_conc = torch.where(eq_zero, torch.ones_like(concentration), concentration) + logu = torch.log1p(-value) + where_nonzero = loc + scale / safe_conc * torch.expm1(-safe_conc * logu) + where_zero = loc - scale * logu + return torch.where(eq_zero, where_zero, where_nonzero) + + def _z(self, x): + return (x - self.loc) / self.scale + + @property + def mean(self): + concentration = self.concentration + valid = concentration < 1 + safe_conc = torch.where(valid, concentration, 0.5) + result = self.loc + self.scale / (1 - safe_conc) + return torch.where(valid, result, nan) + + @property + def variance(self): + concentration = self.concentration + valid = concentration < 0.5 + safe_conc = torch.where(valid, concentration, 0.25) + result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc)) + return torch.where(valid, result, nan) + + def entropy(self): + ans = torch.log(self.scale) + self.concentration + 1 + return torch.broadcast_to(ans, self._batch_shape) + + @property + def mode(self): + return self.loc + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + lower = self.loc + upper = torch.where( + self.concentration < 0, lower - self.scale / self.concentration, inf + ) + return constraints.interval(lower, upper) diff --git a/phivenv/Lib/site-packages/torch/distributions/geometric.py b/phivenv/Lib/site-packages/torch/distributions/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7302841166cbad779bc9c867d348d86c4c4332 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/geometric.py @@ -0,0 +1,140 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.nn.functional import binary_cross_entropy_with_logits +from torch.types import _Number, Number + + +__all__ = ["Geometric"] + + +class Geometric(Distribution): + r""" + Creates a Geometric distribution parameterized by :attr:`probs`, + where :attr:`probs` is the probability of success of Bernoulli trials. + + .. math:: + + P(X=k) = (1-p)^{k} p, k = 0, 1, ... + + .. note:: + :func:`torch.distributions.geometric.Geometric` :math:`(k+1)`-th trial is the first success + hence draws samples in :math:`\{0, 1, \ldots\}`, whereas + :func:`torch.Tensor.geometric_` `k`-th trial is the first success hence draws samples in :math:`\{1, 2, \ldots\}`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Geometric(torch.tensor([0.3])) + >>> m.sample() # underlying Bernoulli has 30% chance 1; 70% chance 0 + tensor([ 2.]) + + Args: + probs (Number, Tensor): the probability of sampling `1`. Must be in range (0, 1] + logits (Number, Tensor): the log-odds of sampling `1`. + """ + + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.nonnegative_integer + + def __init__( + self, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + (self.probs,) = broadcast_all(probs) + else: + assert logits is not None # helps mypy + (self.logits,) = broadcast_all(logits) + probs_or_logits = probs if probs is not None else logits + if isinstance(probs_or_logits, _Number): + batch_shape = torch.Size() + else: + assert probs_or_logits is not None # helps mypy + batch_shape = probs_or_logits.size() + super().__init__(batch_shape, validate_args=validate_args) + if self._validate_args and probs is not None: + # Add an extra check beyond unit_interval + value = self.probs + valid = value > 0 + if not valid.all(): + invalid_value = value.data[~valid] + raise ValueError( + "Expected parameter probs " + f"({type(value).__name__} of shape {tuple(value.shape)}) " + f"of distribution {repr(self)} " + f"to be positive but found invalid values:\n{invalid_value}" + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Geometric, _instance) + batch_shape = torch.Size(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + super(Geometric, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self) -> Tensor: + return 1.0 / self.probs - 1.0 + + @property + def mode(self) -> Tensor: + return torch.zeros_like(self.probs) + + @property + def variance(self) -> Tensor: + return (1.0 / self.probs - 1.0) / self.probs + + @lazy_property + def logits(self) -> Tensor: + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self) -> Tensor: + return logits_to_probs(self.logits, is_binary=True) + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + tiny = torch.finfo(self.probs.dtype).tiny + with torch.no_grad(): + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .uniform_() + u = torch.rand(shape, dtype=self.probs.dtype, device=self.probs.device) + u = u.clamp(min=tiny) + else: + u = self.probs.new(shape).uniform_(tiny, 1) + return (u.log() / (-self.probs).log1p()).floor() + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value, probs = broadcast_all(value, self.probs) + probs = probs.clone(memory_format=torch.contiguous_format) + probs[(probs == 1) & (value == 0)] = 0 + return value * (-probs).log1p() + self.probs.log() + + def entropy(self): + return ( + binary_cross_entropy_with_logits(self.logits, self.probs, reduction="none") + / self.probs + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/gumbel.py b/phivenv/Lib/site-packages/torch/distributions/gumbel.py new file mode 100644 index 0000000000000000000000000000000000000000..7859b6b731c8c35412bfbe11a43eaba831bfec54 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/gumbel.py @@ -0,0 +1,91 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, ExpTransform +from torch.distributions.uniform import Uniform +from torch.distributions.utils import broadcast_all, euler_constant +from torch.types import _Number + + +__all__ = ["Gumbel"] + + +class Gumbel(TransformedDistribution): + r""" + Samples from a Gumbel Distribution. + + Examples:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Gumbel(torch.tensor([1.0]), torch.tensor([2.0])) + >>> m.sample() # sample from Gumbel distribution with loc=1, scale=2 + tensor([ 1.0124]) + + Args: + loc (float or Tensor): Location parameter of the distribution + scale (float or Tensor): Scale parameter of the distribution + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.loc, self.scale = broadcast_all(loc, scale) + finfo = torch.finfo(self.loc.dtype) + if isinstance(loc, _Number) and isinstance(scale, _Number): + base_dist = Uniform(finfo.tiny, 1 - finfo.eps, validate_args=validate_args) + else: + base_dist = Uniform( + torch.full_like(self.loc, finfo.tiny), + torch.full_like(self.loc, 1 - finfo.eps), + validate_args=validate_args, + ) + transforms = [ + ExpTransform().inv, + AffineTransform(loc=0, scale=-torch.ones_like(self.scale)), + ExpTransform().inv, + AffineTransform(loc=loc, scale=-self.scale), + ] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Gumbel, _instance) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + return super().expand(batch_shape, _instance=new) + + # Explicitly defining the log probability function for Gumbel due to precision issues + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + y = (self.loc - value) / self.scale + return (y - y.exp()) - self.scale.log() + + @property + def mean(self) -> Tensor: + return self.loc + self.scale * euler_constant + + @property + def mode(self) -> Tensor: + return self.loc + + @property + def stddev(self) -> Tensor: + return (math.pi / math.sqrt(6)) * self.scale + + @property + def variance(self) -> Tensor: + return self.stddev.pow(2) + + def entropy(self): + return self.scale.log() + (1 + euler_constant) diff --git a/phivenv/Lib/site-packages/torch/distributions/half_cauchy.py b/phivenv/Lib/site-packages/torch/distributions/half_cauchy.py new file mode 100644 index 0000000000000000000000000000000000000000..70cc3b693958b40b745842f54fb025ac46f2c305 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/half_cauchy.py @@ -0,0 +1,91 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +from torch import inf, Tensor +from torch.distributions import constraints +from torch.distributions.cauchy import Cauchy +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AbsTransform + + +__all__ = ["HalfCauchy"] + + +class HalfCauchy(TransformedDistribution): + r""" + Creates a half-Cauchy distribution parameterized by `scale` where:: + + X ~ Cauchy(0, scale) + Y = |X| ~ HalfCauchy(scale) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = HalfCauchy(torch.tensor([1.0])) + >>> m.sample() # half-cauchy distributed with scale=1 + tensor([ 2.3214]) + + Args: + scale (float or Tensor): scale of the full Cauchy distribution + """ + + arg_constraints = {"scale": constraints.positive} + support = constraints.nonnegative + has_rsample = True + base_dist: Cauchy + + def __init__( + self, + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + base_dist = Cauchy(0, scale, validate_args=False) + super().__init__(base_dist, AbsTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(HalfCauchy, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def scale(self) -> Tensor: + return self.base_dist.scale + + @property + def mean(self) -> Tensor: + return torch.full( + self._extended_shape(), + math.inf, + dtype=self.scale.dtype, + device=self.scale.device, + ) + + @property + def mode(self) -> Tensor: + return torch.zeros_like(self.scale) + + @property + def variance(self) -> Tensor: + return self.base_dist.variance + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + value = torch.as_tensor( + value, dtype=self.base_dist.scale.dtype, device=self.base_dist.scale.device + ) + log_prob = self.base_dist.log_prob(value) + math.log(2) + log_prob = torch.where(value >= 0, log_prob, -inf) + return log_prob + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 2 * self.base_dist.cdf(value) - 1 + + def icdf(self, prob): + return self.base_dist.icdf((prob + 1) / 2) + + def entropy(self): + return self.base_dist.entropy() - math.log(2) diff --git a/phivenv/Lib/site-packages/torch/distributions/half_normal.py b/phivenv/Lib/site-packages/torch/distributions/half_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..8faae4e1b852fca25de29d36453a5fff29dab91e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/half_normal.py @@ -0,0 +1,83 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +from torch import inf, Tensor +from torch.distributions import constraints +from torch.distributions.normal import Normal +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AbsTransform + + +__all__ = ["HalfNormal"] + + +class HalfNormal(TransformedDistribution): + r""" + Creates a half-normal distribution parameterized by `scale` where:: + + X ~ Normal(0, scale) + Y = |X| ~ HalfNormal(scale) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = HalfNormal(torch.tensor([1.0])) + >>> m.sample() # half-normal distributed with scale=1 + tensor([ 0.1046]) + + Args: + scale (float or Tensor): scale of the full Normal distribution + """ + + arg_constraints = {"scale": constraints.positive} + support = constraints.nonnegative + has_rsample = True + base_dist: Normal + + def __init__( + self, + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + base_dist = Normal(0, scale, validate_args=False) + super().__init__(base_dist, AbsTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(HalfNormal, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def scale(self) -> Tensor: + return self.base_dist.scale + + @property + def mean(self) -> Tensor: + return self.scale * math.sqrt(2 / math.pi) + + @property + def mode(self) -> Tensor: + return torch.zeros_like(self.scale) + + @property + def variance(self) -> Tensor: + return self.scale.pow(2) * (1 - 2 / math.pi) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + log_prob = self.base_dist.log_prob(value) + math.log(2) + log_prob = torch.where(value >= 0, log_prob, -inf) + return log_prob + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 2 * self.base_dist.cdf(value) - 1 + + def icdf(self, prob): + return self.base_dist.icdf((prob + 1) / 2) + + def entropy(self): + return self.base_dist.entropy() - math.log(2) diff --git a/phivenv/Lib/site-packages/torch/distributions/independent.py b/phivenv/Lib/site-packages/torch/distributions/independent.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0023bda6667b018c1552bace35dc8a2a8382ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/independent.py @@ -0,0 +1,137 @@ +# mypy: allow-untyped-defs +from typing import Generic, Optional, TypeVar + +import torch +from torch import Size, Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import _sum_rightmost +from torch.types import _size + + +__all__ = ["Independent"] + + +D = TypeVar("D", bound=Distribution) + + +class Independent(Distribution, Generic[D]): + r""" + Reinterprets some of the batch dims of a distribution as event dims. + + This is mainly useful for changing the shape of the result of + :meth:`log_prob`. For example to create a diagonal Normal distribution with + the same shape as a Multivariate Normal distribution (so they are + interchangeable), you can:: + + >>> from torch.distributions.multivariate_normal import MultivariateNormal + >>> from torch.distributions.normal import Normal + >>> loc = torch.zeros(3) + >>> scale = torch.ones(3) + >>> mvn = MultivariateNormal(loc, scale_tril=torch.diag(scale)) + >>> [mvn.batch_shape, mvn.event_shape] + [torch.Size([]), torch.Size([3])] + >>> normal = Normal(loc, scale) + >>> [normal.batch_shape, normal.event_shape] + [torch.Size([3]), torch.Size([])] + >>> diagn = Independent(normal, 1) + >>> [diagn.batch_shape, diagn.event_shape] + [torch.Size([]), torch.Size([3])] + + Args: + base_distribution (torch.distributions.distribution.Distribution): a + base distribution + reinterpreted_batch_ndims (int): the number of batch dims to + reinterpret as event dims + """ + + arg_constraints: dict[str, constraints.Constraint] = {} + base_dist: D + + def __init__( + self, + base_distribution: D, + reinterpreted_batch_ndims: int, + validate_args: Optional[bool] = None, + ) -> None: + if reinterpreted_batch_ndims > len(base_distribution.batch_shape): + raise ValueError( + "Expected reinterpreted_batch_ndims <= len(base_distribution.batch_shape), " + f"actual {reinterpreted_batch_ndims} vs {len(base_distribution.batch_shape)}" + ) + shape: Size = base_distribution.batch_shape + base_distribution.event_shape + event_dim: int = reinterpreted_batch_ndims + len(base_distribution.event_shape) + batch_shape = shape[: len(shape) - event_dim] + event_shape = shape[len(shape) - event_dim :] + self.base_dist = base_distribution + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Independent, _instance) + batch_shape = torch.Size(batch_shape) + new.base_dist = self.base_dist.expand( + batch_shape + self.event_shape[: self.reinterpreted_batch_ndims] + ) + new.reinterpreted_batch_ndims = self.reinterpreted_batch_ndims + super(Independent, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @property + def has_rsample(self) -> bool: # type: ignore[override] + return self.base_dist.has_rsample + + @property + def has_enumerate_support(self) -> bool: # type: ignore[override] + if self.reinterpreted_batch_ndims > 0: + return False + return self.base_dist.has_enumerate_support + + @constraints.dependent_property + def support(self): + result = self.base_dist.support + if self.reinterpreted_batch_ndims: + result = constraints.independent(result, self.reinterpreted_batch_ndims) + return result + + @property + def mean(self) -> Tensor: + return self.base_dist.mean + + @property + def mode(self) -> Tensor: + return self.base_dist.mode + + @property + def variance(self) -> Tensor: + return self.base_dist.variance + + def sample(self, sample_shape=torch.Size()) -> Tensor: + return self.base_dist.sample(sample_shape) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + return self.base_dist.rsample(sample_shape) + + def log_prob(self, value): + log_prob = self.base_dist.log_prob(value) + return _sum_rightmost(log_prob, self.reinterpreted_batch_ndims) + + def entropy(self): + entropy = self.base_dist.entropy() + return _sum_rightmost(entropy, self.reinterpreted_batch_ndims) + + def enumerate_support(self, expand=True): + if self.reinterpreted_batch_ndims > 0: + raise NotImplementedError( + "Enumeration over cartesian product is not implemented" + ) + return self.base_dist.enumerate_support(expand=expand) + + def __repr__(self): + return ( + self.__class__.__name__ + + f"({self.base_dist}, {self.reinterpreted_batch_ndims})" + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/inverse_gamma.py b/phivenv/Lib/site-packages/torch/distributions/inverse_gamma.py new file mode 100644 index 0000000000000000000000000000000000000000..269a97fe6223b1cfc573b956ed6f578c2daad04a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/inverse_gamma.py @@ -0,0 +1,91 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.gamma import Gamma +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import PowerTransform + + +__all__ = ["InverseGamma"] + + +class InverseGamma(TransformedDistribution): + r""" + Creates an inverse gamma distribution parameterized by :attr:`concentration` and :attr:`rate` + where:: + + X ~ Gamma(concentration, rate) + Y = 1 / X ~ InverseGamma(concentration, rate) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterinistic") + >>> m = InverseGamma(torch.tensor([2.0]), torch.tensor([3.0])) + >>> m.sample() + tensor([ 1.2953]) + + Args: + concentration (float or Tensor): shape parameter of the distribution + (often referred to as alpha) + rate (float or Tensor): rate = 1 / scale of the distribution + (often referred to as beta) + """ + + arg_constraints = { + "concentration": constraints.positive, + "rate": constraints.positive, + } + support = constraints.positive + has_rsample = True + base_dist: Gamma + + def __init__( + self, + concentration: Union[Tensor, float], + rate: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + base_dist = Gamma(concentration, rate, validate_args=validate_args) + neg_one = -base_dist.rate.new_ones(()) + super().__init__( + base_dist, PowerTransform(neg_one), validate_args=validate_args + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(InverseGamma, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def concentration(self) -> Tensor: + return self.base_dist.concentration + + @property + def rate(self) -> Tensor: + return self.base_dist.rate + + @property + def mean(self) -> Tensor: + result = self.rate / (self.concentration - 1) + return torch.where(self.concentration > 1, result, torch.inf) + + @property + def mode(self) -> Tensor: + return self.rate / (self.concentration + 1) + + @property + def variance(self) -> Tensor: + result = self.rate.square() / ( + (self.concentration - 1).square() * (self.concentration - 2) + ) + return torch.where(self.concentration > 2, result, torch.inf) + + def entropy(self): + return ( + self.concentration + + self.rate.log() + + self.concentration.lgamma() + - (1 + self.concentration) * self.concentration.digamma() + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/kl.py b/phivenv/Lib/site-packages/torch/distributions/kl.py new file mode 100644 index 0000000000000000000000000000000000000000..3faf8ada296aec3c49a5dce55ff10c28c9ac52de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/kl.py @@ -0,0 +1,972 @@ +# mypy: allow-untyped-defs +import math +import warnings +from functools import total_ordering +from typing import Callable + +import torch +from torch import inf, Tensor + +from .bernoulli import Bernoulli +from .beta import Beta +from .binomial import Binomial +from .categorical import Categorical +from .cauchy import Cauchy +from .continuous_bernoulli import ContinuousBernoulli +from .dirichlet import Dirichlet +from .distribution import Distribution +from .exp_family import ExponentialFamily +from .exponential import Exponential +from .gamma import Gamma +from .geometric import Geometric +from .gumbel import Gumbel +from .half_normal import HalfNormal +from .independent import Independent +from .laplace import Laplace +from .lowrank_multivariate_normal import ( + _batch_lowrank_logdet, + _batch_lowrank_mahalanobis, + LowRankMultivariateNormal, +) +from .multivariate_normal import _batch_mahalanobis, MultivariateNormal +from .normal import Normal +from .one_hot_categorical import OneHotCategorical +from .pareto import Pareto +from .poisson import Poisson +from .transformed_distribution import TransformedDistribution +from .uniform import Uniform +from .utils import _sum_rightmost, euler_constant as _euler_gamma + + +_KL_REGISTRY: dict[ + tuple[type, type], Callable +] = {} # Source of truth mapping a few general (type, type) pairs to functions. +_KL_MEMOIZE: dict[ + tuple[type, type], Callable +] = {} # Memoized version mapping many specific (type, type) pairs to functions. + +__all__ = ["register_kl", "kl_divergence"] + + +def register_kl(type_p, type_q): + """ + Decorator to register a pairwise function with :meth:`kl_divergence`. + Usage:: + + @register_kl(Normal, Normal) + def kl_normal_normal(p, q): + # insert implementation here + + Lookup returns the most specific (type,type) match ordered by subclass. If + the match is ambiguous, a `RuntimeWarning` is raised. For example to + resolve the ambiguous situation:: + + @register_kl(BaseP, DerivedQ) + def kl_version1(p, q): ... + @register_kl(DerivedP, BaseQ) + def kl_version2(p, q): ... + + you should register a third most-specific implementation, e.g.:: + + register_kl(DerivedP, DerivedQ)(kl_version1) # Break the tie. + + Args: + type_p (type): A subclass of :class:`~torch.distributions.Distribution`. + type_q (type): A subclass of :class:`~torch.distributions.Distribution`. + """ + if not isinstance(type_p, type) and issubclass(type_p, Distribution): + raise TypeError( + f"Expected type_p to be a Distribution subclass but got {type_p}" + ) + if not isinstance(type_q, type) and issubclass(type_q, Distribution): + raise TypeError( + f"Expected type_q to be a Distribution subclass but got {type_q}" + ) + + def decorator(fun): + _KL_REGISTRY[type_p, type_q] = fun + _KL_MEMOIZE.clear() # reset since lookup order may have changed + return fun + + return decorator + + +@total_ordering +class _Match: + __slots__ = ["types"] + + def __init__(self, *types): + self.types = types + + def __eq__(self, other): + return self.types == other.types + + def __le__(self, other): + for x, y in zip(self.types, other.types): + if not issubclass(x, y): + return False + if x is not y: + break + return True + + +def _dispatch_kl(type_p, type_q): + """ + Find the most specific approximate match, assuming single inheritance. + """ + matches = [ + (super_p, super_q) + for super_p, super_q in _KL_REGISTRY + if issubclass(type_p, super_p) and issubclass(type_q, super_q) + ] + if not matches: + return NotImplemented + # Check that the left- and right- lexicographic orders agree. + # mypy isn't smart enough to know that _Match implements __lt__ + # see: https://github.com/python/typing/issues/760#issuecomment-710670503 + left_p, left_q = min(_Match(*m) for m in matches).types # type: ignore[type-var] + right_q, right_p = min(_Match(*reversed(m)) for m in matches).types # type: ignore[type-var] + left_fun = _KL_REGISTRY[left_p, left_q] + right_fun = _KL_REGISTRY[right_p, right_q] + if left_fun is not right_fun: + warnings.warn( + f"Ambiguous kl_divergence({type_p.__name__}, {type_q.__name__}). " + f"Please register_kl({left_p.__name__}, {right_q.__name__})", + RuntimeWarning, + ) + return left_fun + + +def _infinite_like(tensor): + """ + Helper function for obtaining infinite KL Divergence throughout + """ + return torch.full_like(tensor, inf) + + +def _x_log_x(tensor): + """ + Utility function for calculating x log x + """ + return torch.special.xlogy(tensor, tensor) # produces correct result for x=0 + + +def _batch_trace_XXT(bmat): + """ + Utility function for calculating the trace of XX^{T} with X having arbitrary trailing batch dimensions + """ + n = bmat.size(-1) + m = bmat.size(-2) + flat_trace = bmat.reshape(-1, m * n).pow(2).sum(-1) + return flat_trace.reshape(bmat.shape[:-2]) + + +def kl_divergence(p: Distribution, q: Distribution) -> Tensor: + r""" + Compute Kullback-Leibler divergence :math:`KL(p \| q)` between two distributions. + + .. math:: + + KL(p \| q) = \int p(x) \log\frac {p(x)} {q(x)} \,dx + + Args: + p (Distribution): A :class:`~torch.distributions.Distribution` object. + q (Distribution): A :class:`~torch.distributions.Distribution` object. + + Returns: + Tensor: A batch of KL divergences of shape `batch_shape`. + + Raises: + NotImplementedError: If the distribution types have not been registered via + :meth:`register_kl`. + """ + try: + fun = _KL_MEMOIZE[type(p), type(q)] + except KeyError: + fun = _dispatch_kl(type(p), type(q)) + _KL_MEMOIZE[type(p), type(q)] = fun + if fun is NotImplemented: + raise NotImplementedError( + f"No KL(p || q) is implemented for p type {p.__class__.__name__} and q type {q.__class__.__name__}" + ) + return fun(p, q) + + +################################################################################ +# KL Divergence Implementations +################################################################################ + +# Same distributions + + +@register_kl(Bernoulli, Bernoulli) +def _kl_bernoulli_bernoulli(p, q): + t1 = p.probs * ( + torch.nn.functional.softplus(-q.logits) + - torch.nn.functional.softplus(-p.logits) + ) + t1[q.probs == 0] = inf + t1[p.probs == 0] = 0 + t2 = (1 - p.probs) * ( + torch.nn.functional.softplus(q.logits) - torch.nn.functional.softplus(p.logits) + ) + t2[q.probs == 1] = inf + t2[p.probs == 1] = 0 + return t1 + t2 + + +@register_kl(Beta, Beta) +def _kl_beta_beta(p, q): + sum_params_p = p.concentration1 + p.concentration0 + sum_params_q = q.concentration1 + q.concentration0 + t1 = q.concentration1.lgamma() + q.concentration0.lgamma() + (sum_params_p).lgamma() + t2 = p.concentration1.lgamma() + p.concentration0.lgamma() + (sum_params_q).lgamma() + t3 = (p.concentration1 - q.concentration1) * torch.digamma(p.concentration1) + t4 = (p.concentration0 - q.concentration0) * torch.digamma(p.concentration0) + t5 = (sum_params_q - sum_params_p) * torch.digamma(sum_params_p) + return t1 - t2 + t3 + t4 + t5 + + +@register_kl(Binomial, Binomial) +def _kl_binomial_binomial(p, q): + # from https://math.stackexchange.com/questions/2214993/ + # kullback-leibler-divergence-for-binomial-distributions-p-and-q + if (p.total_count < q.total_count).any(): + raise NotImplementedError( + "KL between Binomials where q.total_count > p.total_count is not implemented" + ) + kl = p.total_count * ( + p.probs * (p.logits - q.logits) + (-p.probs).log1p() - (-q.probs).log1p() + ) + inf_idxs = p.total_count > q.total_count + kl[inf_idxs] = _infinite_like(kl[inf_idxs]) + return kl + + +@register_kl(Categorical, Categorical) +def _kl_categorical_categorical(p, q): + t = p.probs * (p.logits - q.logits) + t[(q.probs == 0).expand_as(t)] = inf + t[(p.probs == 0).expand_as(t)] = 0 + return t.sum(-1) + + +@register_kl(ContinuousBernoulli, ContinuousBernoulli) +def _kl_continuous_bernoulli_continuous_bernoulli(p, q): + t1 = p.mean * (p.logits - q.logits) + t2 = p._cont_bern_log_norm() + torch.log1p(-p.probs) + t3 = -q._cont_bern_log_norm() - torch.log1p(-q.probs) + return t1 + t2 + t3 + + +@register_kl(Dirichlet, Dirichlet) +def _kl_dirichlet_dirichlet(p, q): + # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/ + sum_p_concentration = p.concentration.sum(-1) + sum_q_concentration = q.concentration.sum(-1) + t1 = sum_p_concentration.lgamma() - sum_q_concentration.lgamma() + t2 = (p.concentration.lgamma() - q.concentration.lgamma()).sum(-1) + t3 = p.concentration - q.concentration + t4 = p.concentration.digamma() - sum_p_concentration.digamma().unsqueeze(-1) + return t1 - t2 + (t3 * t4).sum(-1) + + +@register_kl(Exponential, Exponential) +def _kl_exponential_exponential(p, q): + rate_ratio = q.rate / p.rate + t1 = -rate_ratio.log() + return t1 + rate_ratio - 1 + + +@register_kl(ExponentialFamily, ExponentialFamily) +def _kl_expfamily_expfamily(p, q): + if not type(p) == type(q): + raise NotImplementedError( + "The cross KL-divergence between different exponential families cannot \ + be computed using Bregman divergences" + ) + p_nparams = [np.detach().requires_grad_() for np in p._natural_params] + q_nparams = q._natural_params + lg_normal = p._log_normalizer(*p_nparams) + gradients = torch.autograd.grad(lg_normal.sum(), p_nparams, create_graph=True) + result = q._log_normalizer(*q_nparams) - lg_normal + for pnp, qnp, g in zip(p_nparams, q_nparams, gradients): + term = (qnp - pnp) * g + result -= _sum_rightmost(term, len(q.event_shape)) + return result + + +@register_kl(Gamma, Gamma) +def _kl_gamma_gamma(p, q): + t1 = q.concentration * (p.rate / q.rate).log() + t2 = torch.lgamma(q.concentration) - torch.lgamma(p.concentration) + t3 = (p.concentration - q.concentration) * torch.digamma(p.concentration) + t4 = (q.rate - p.rate) * (p.concentration / p.rate) + return t1 + t2 + t3 + t4 + + +@register_kl(Gumbel, Gumbel) +def _kl_gumbel_gumbel(p, q): + ct1 = p.scale / q.scale + ct2 = q.loc / q.scale + ct3 = p.loc / q.scale + t1 = -ct1.log() - ct2 + ct3 + t2 = ct1 * _euler_gamma + t3 = torch.exp(ct2 + (1 + ct1).lgamma() - ct3) + return t1 + t2 + t3 - (1 + _euler_gamma) + + +@register_kl(Geometric, Geometric) +def _kl_geometric_geometric(p, q): + return -p.entropy() - torch.log1p(-q.probs) / p.probs - q.logits + + +@register_kl(HalfNormal, HalfNormal) +def _kl_halfnormal_halfnormal(p, q): + return _kl_normal_normal(p.base_dist, q.base_dist) + + +@register_kl(Laplace, Laplace) +def _kl_laplace_laplace(p, q): + # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf + scale_ratio = p.scale / q.scale + loc_abs_diff = (p.loc - q.loc).abs() + t1 = -scale_ratio.log() + t2 = loc_abs_diff / q.scale + t3 = scale_ratio * torch.exp(-loc_abs_diff / p.scale) + return t1 + t2 + t3 - 1 + + +@register_kl(LowRankMultivariateNormal, LowRankMultivariateNormal) +def _kl_lowrankmultivariatenormal_lowrankmultivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two Low Rank Multivariate Normals with\ + different event shapes cannot be computed" + ) + + term1 = _batch_lowrank_logdet( + q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril + ) - _batch_lowrank_logdet( + p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril + ) + term3 = _batch_lowrank_mahalanobis( + q._unbroadcasted_cov_factor, + q._unbroadcasted_cov_diag, + q.loc - p.loc, + q._capacitance_tril, + ) + # Expands term2 according to + # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ (pW @ pW.T + pD) + # = [inv(qD) - A.T @ A] @ (pD + pW @ pW.T) + qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2) + A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False) + term21 = (p._unbroadcasted_cov_diag / q._unbroadcasted_cov_diag).sum(-1) + term22 = _batch_trace_XXT( + p._unbroadcasted_cov_factor * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1) + ) + term23 = _batch_trace_XXT(A * p._unbroadcasted_cov_diag.sqrt().unsqueeze(-2)) + term24 = _batch_trace_XXT(A.matmul(p._unbroadcasted_cov_factor)) + term2 = term21 + term22 - term23 - term24 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(MultivariateNormal, LowRankMultivariateNormal) +def _kl_multivariatenormal_lowrankmultivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed" + ) + + term1 = _batch_lowrank_logdet( + q._unbroadcasted_cov_factor, q._unbroadcasted_cov_diag, q._capacitance_tril + ) - 2 * p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + term3 = _batch_lowrank_mahalanobis( + q._unbroadcasted_cov_factor, + q._unbroadcasted_cov_diag, + q.loc - p.loc, + q._capacitance_tril, + ) + # Expands term2 according to + # inv(qcov) @ pcov = [inv(qD) - inv(qD) @ qW @ inv(qC) @ qW.T @ inv(qD)] @ p_tril @ p_tril.T + # = [inv(qD) - A.T @ A] @ p_tril @ p_tril.T + qWt_qDinv = q._unbroadcasted_cov_factor.mT / q._unbroadcasted_cov_diag.unsqueeze(-2) + A = torch.linalg.solve_triangular(q._capacitance_tril, qWt_qDinv, upper=False) + term21 = _batch_trace_XXT( + p._unbroadcasted_scale_tril * q._unbroadcasted_cov_diag.rsqrt().unsqueeze(-1) + ) + term22 = _batch_trace_XXT(A.matmul(p._unbroadcasted_scale_tril)) + term2 = term21 - term22 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(LowRankMultivariateNormal, MultivariateNormal) +def _kl_lowrankmultivariatenormal_multivariatenormal(p, q): + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two (Low Rank) Multivariate Normals with\ + different event shapes cannot be computed" + ) + + term1 = 2 * q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum( + -1 + ) - _batch_lowrank_logdet( + p._unbroadcasted_cov_factor, p._unbroadcasted_cov_diag, p._capacitance_tril + ) + term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) + # Expands term2 according to + # inv(qcov) @ pcov = inv(q_tril @ q_tril.T) @ (pW @ pW.T + pD) + combined_batch_shape = torch._C._infer_size( + q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_cov_factor.shape[:-2] + ) + n = p.event_shape[0] + q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + p_cov_factor = p._unbroadcasted_cov_factor.expand( + combined_batch_shape + (n, p.cov_factor.size(-1)) + ) + p_cov_diag = torch.diag_embed(p._unbroadcasted_cov_diag.sqrt()).expand( + combined_batch_shape + (n, n) + ) + term21 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_cov_factor, upper=False) + ) + term22 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_cov_diag, upper=False) + ) + term2 = term21 + term22 + return 0.5 * (term1 + term2 + term3 - p.event_shape[0]) + + +@register_kl(MultivariateNormal, MultivariateNormal) +def _kl_multivariatenormal_multivariatenormal(p, q): + # From https://en.wikipedia.org/wiki/Multivariate_normal_distribution#Kullback%E2%80%93Leibler_divergence + if p.event_shape != q.event_shape: + raise ValueError( + "KL-divergence between two Multivariate Normals with\ + different event shapes cannot be computed" + ) + + half_term1 = q._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum( + -1 + ) - p._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + combined_batch_shape = torch._C._infer_size( + q._unbroadcasted_scale_tril.shape[:-2], p._unbroadcasted_scale_tril.shape[:-2] + ) + n = p.event_shape[0] + q_scale_tril = q._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + p_scale_tril = p._unbroadcasted_scale_tril.expand(combined_batch_shape + (n, n)) + term2 = _batch_trace_XXT( + torch.linalg.solve_triangular(q_scale_tril, p_scale_tril, upper=False) + ) + term3 = _batch_mahalanobis(q._unbroadcasted_scale_tril, (q.loc - p.loc)) + return half_term1 + 0.5 * (term2 + term3 - n) + + +@register_kl(Normal, Normal) +def _kl_normal_normal(p, q): + var_ratio = (p.scale / q.scale).pow(2) + t1 = ((p.loc - q.loc) / q.scale).pow(2) + return 0.5 * (var_ratio + t1 - 1 - var_ratio.log()) + + +@register_kl(OneHotCategorical, OneHotCategorical) +def _kl_onehotcategorical_onehotcategorical(p, q): + return _kl_categorical_categorical(p._categorical, q._categorical) + + +@register_kl(Pareto, Pareto) +def _kl_pareto_pareto(p, q): + # From http://www.mast.queensu.ca/~communications/Papers/gil-msc11.pdf + scale_ratio = p.scale / q.scale + alpha_ratio = q.alpha / p.alpha + t1 = q.alpha * scale_ratio.log() + t2 = -alpha_ratio.log() + result = t1 + t2 + alpha_ratio - 1 + result[p.support.lower_bound < q.support.lower_bound] = inf + return result + + +@register_kl(Poisson, Poisson) +def _kl_poisson_poisson(p, q): + return p.rate * (p.rate.log() - q.rate.log()) - (p.rate - q.rate) + + +@register_kl(TransformedDistribution, TransformedDistribution) +def _kl_transformed_transformed(p, q): + if p.transforms != q.transforms: + raise NotImplementedError + if p.event_shape != q.event_shape: + raise NotImplementedError + return kl_divergence(p.base_dist, q.base_dist) + + +@register_kl(Uniform, Uniform) +def _kl_uniform_uniform(p, q): + result = ((q.high - q.low) / (p.high - p.low)).log() + result[(q.low > p.low) | (q.high < p.high)] = inf + return result + + +# Different distributions +@register_kl(Bernoulli, Poisson) +def _kl_bernoulli_poisson(p, q): + return -p.entropy() - (p.probs * q.rate.log() - q.rate) + + +@register_kl(Beta, ContinuousBernoulli) +def _kl_beta_continuous_bernoulli(p, q): + return ( + -p.entropy() + - p.mean * q.logits + - torch.log1p(-q.probs) + - q._cont_bern_log_norm() + ) + + +@register_kl(Beta, Pareto) +def _kl_beta_infinity(p, q): + return _infinite_like(p.concentration1) + + +@register_kl(Beta, Exponential) +def _kl_beta_exponential(p, q): + return ( + -p.entropy() + - q.rate.log() + + q.rate * (p.concentration1 / (p.concentration1 + p.concentration0)) + ) + + +@register_kl(Beta, Gamma) +def _kl_beta_gamma(p, q): + t1 = -p.entropy() + t2 = q.concentration.lgamma() - q.concentration * q.rate.log() + t3 = (q.concentration - 1) * ( + p.concentration1.digamma() - (p.concentration1 + p.concentration0).digamma() + ) + t4 = q.rate * p.concentration1 / (p.concentration1 + p.concentration0) + return t1 + t2 - t3 + t4 + + +# TODO: Add Beta-Laplace KL Divergence + + +@register_kl(Beta, Normal) +def _kl_beta_normal(p, q): + E_beta = p.concentration1 / (p.concentration1 + p.concentration0) + var_normal = q.scale.pow(2) + t1 = -p.entropy() + t2 = 0.5 * (var_normal * 2 * math.pi).log() + t3 = ( + E_beta * (1 - E_beta) / (p.concentration1 + p.concentration0 + 1) + + E_beta.pow(2) + ) * 0.5 + t4 = q.loc * E_beta + t5 = q.loc.pow(2) * 0.5 + return t1 + t2 + (t3 - t4 + t5) / var_normal + + +@register_kl(Beta, Uniform) +def _kl_beta_uniform(p, q): + result = -p.entropy() + (q.high - q.low).log() + result[(q.low > p.support.lower_bound) | (q.high < p.support.upper_bound)] = inf + return result + + +# Note that the KL between a ContinuousBernoulli and Beta has no closed form + + +@register_kl(ContinuousBernoulli, Pareto) +def _kl_continuous_bernoulli_infinity(p, q): + return _infinite_like(p.probs) + + +@register_kl(ContinuousBernoulli, Exponential) +def _kl_continuous_bernoulli_exponential(p, q): + return -p.entropy() - torch.log(q.rate) + q.rate * p.mean + + +# Note that the KL between a ContinuousBernoulli and Gamma has no closed form +# TODO: Add ContinuousBernoulli-Laplace KL Divergence + + +@register_kl(ContinuousBernoulli, Normal) +def _kl_continuous_bernoulli_normal(p, q): + t1 = -p.entropy() + t2 = 0.5 * (math.log(2.0 * math.pi) + torch.square(q.loc / q.scale)) + torch.log( + q.scale + ) + t3 = (p.variance + torch.square(p.mean) - 2.0 * q.loc * p.mean) / ( + 2.0 * torch.square(q.scale) + ) + return t1 + t2 + t3 + + +@register_kl(ContinuousBernoulli, Uniform) +def _kl_continuous_bernoulli_uniform(p, q): + result = -p.entropy() + (q.high - q.low).log() + return torch.where( + torch.max( + torch.ge(q.low, p.support.lower_bound), + torch.le(q.high, p.support.upper_bound), + ), + torch.ones_like(result) * inf, + result, + ) + + +@register_kl(Exponential, Beta) +@register_kl(Exponential, ContinuousBernoulli) +@register_kl(Exponential, Pareto) +@register_kl(Exponential, Uniform) +def _kl_exponential_infinity(p, q): + return _infinite_like(p.rate) + + +@register_kl(Exponential, Gamma) +def _kl_exponential_gamma(p, q): + ratio = q.rate / p.rate + t1 = -q.concentration * torch.log(ratio) + return ( + t1 + + ratio + + q.concentration.lgamma() + + q.concentration * _euler_gamma + - (1 + _euler_gamma) + ) + + +@register_kl(Exponential, Gumbel) +def _kl_exponential_gumbel(p, q): + scale_rate_prod = p.rate * q.scale + loc_scale_ratio = q.loc / q.scale + t1 = scale_rate_prod.log() - 1 + t2 = torch.exp(loc_scale_ratio) * scale_rate_prod / (scale_rate_prod + 1) + t3 = scale_rate_prod.reciprocal() + return t1 - loc_scale_ratio + t2 + t3 + + +# TODO: Add Exponential-Laplace KL Divergence + + +@register_kl(Exponential, Normal) +def _kl_exponential_normal(p, q): + var_normal = q.scale.pow(2) + rate_sqr = p.rate.pow(2) + t1 = 0.5 * torch.log(rate_sqr * var_normal * 2 * math.pi) + t2 = rate_sqr.reciprocal() + t3 = q.loc / p.rate + t4 = q.loc.pow(2) * 0.5 + return t1 - 1 + (t2 - t3 + t4) / var_normal + + +@register_kl(Gamma, Beta) +@register_kl(Gamma, ContinuousBernoulli) +@register_kl(Gamma, Pareto) +@register_kl(Gamma, Uniform) +def _kl_gamma_infinity(p, q): + return _infinite_like(p.concentration) + + +@register_kl(Gamma, Exponential) +def _kl_gamma_exponential(p, q): + return -p.entropy() - q.rate.log() + q.rate * p.concentration / p.rate + + +@register_kl(Gamma, Gumbel) +def _kl_gamma_gumbel(p, q): + beta_scale_prod = p.rate * q.scale + loc_scale_ratio = q.loc / q.scale + t1 = ( + (p.concentration - 1) * p.concentration.digamma() + - p.concentration.lgamma() + - p.concentration + ) + t2 = beta_scale_prod.log() + p.concentration / beta_scale_prod + t3 = ( + torch.exp(loc_scale_ratio) + * (1 + beta_scale_prod.reciprocal()).pow(-p.concentration) + - loc_scale_ratio + ) + return t1 + t2 + t3 + + +# TODO: Add Gamma-Laplace KL Divergence + + +@register_kl(Gamma, Normal) +def _kl_gamma_normal(p, q): + var_normal = q.scale.pow(2) + beta_sqr = p.rate.pow(2) + t1 = ( + 0.5 * torch.log(beta_sqr * var_normal * 2 * math.pi) + - p.concentration + - p.concentration.lgamma() + ) + t2 = 0.5 * (p.concentration.pow(2) + p.concentration) / beta_sqr + t3 = q.loc * p.concentration / p.rate + t4 = 0.5 * q.loc.pow(2) + return ( + t1 + + (p.concentration - 1) * p.concentration.digamma() + + (t2 - t3 + t4) / var_normal + ) + + +@register_kl(Gumbel, Beta) +@register_kl(Gumbel, ContinuousBernoulli) +@register_kl(Gumbel, Exponential) +@register_kl(Gumbel, Gamma) +@register_kl(Gumbel, Pareto) +@register_kl(Gumbel, Uniform) +def _kl_gumbel_infinity(p, q): + return _infinite_like(p.loc) + + +# TODO: Add Gumbel-Laplace KL Divergence + + +@register_kl(Gumbel, Normal) +def _kl_gumbel_normal(p, q): + param_ratio = p.scale / q.scale + t1 = (param_ratio / math.sqrt(2 * math.pi)).log() + t2 = (math.pi * param_ratio * 0.5).pow(2) / 3 + t3 = ((p.loc + p.scale * _euler_gamma - q.loc) / q.scale).pow(2) * 0.5 + return -t1 + t2 + t3 - (_euler_gamma + 1) + + +@register_kl(Laplace, Beta) +@register_kl(Laplace, ContinuousBernoulli) +@register_kl(Laplace, Exponential) +@register_kl(Laplace, Gamma) +@register_kl(Laplace, Pareto) +@register_kl(Laplace, Uniform) +def _kl_laplace_infinity(p, q): + return _infinite_like(p.loc) + + +@register_kl(Laplace, Normal) +def _kl_laplace_normal(p, q): + var_normal = q.scale.pow(2) + scale_sqr_var_ratio = p.scale.pow(2) / var_normal + t1 = 0.5 * torch.log(2 * scale_sqr_var_ratio / math.pi) + t2 = 0.5 * p.loc.pow(2) + t3 = p.loc * q.loc + t4 = 0.5 * q.loc.pow(2) + return -t1 + scale_sqr_var_ratio + (t2 - t3 + t4) / var_normal - 1 + + +@register_kl(Normal, Beta) +@register_kl(Normal, ContinuousBernoulli) +@register_kl(Normal, Exponential) +@register_kl(Normal, Gamma) +@register_kl(Normal, Pareto) +@register_kl(Normal, Uniform) +def _kl_normal_infinity(p, q): + return _infinite_like(p.loc) + + +@register_kl(Normal, Gumbel) +def _kl_normal_gumbel(p, q): + mean_scale_ratio = p.loc / q.scale + var_scale_sqr_ratio = (p.scale / q.scale).pow(2) + loc_scale_ratio = q.loc / q.scale + t1 = var_scale_sqr_ratio.log() * 0.5 + t2 = mean_scale_ratio - loc_scale_ratio + t3 = torch.exp(-mean_scale_ratio + 0.5 * var_scale_sqr_ratio + loc_scale_ratio) + return -t1 + t2 + t3 - (0.5 * (1 + math.log(2 * math.pi))) + + +@register_kl(Normal, Laplace) +def _kl_normal_laplace(p, q): + loc_diff = p.loc - q.loc + scale_ratio = p.scale / q.scale + loc_diff_scale_ratio = loc_diff / p.scale + t1 = torch.log(scale_ratio) + t2 = ( + math.sqrt(2 / math.pi) * p.scale * torch.exp(-0.5 * loc_diff_scale_ratio.pow(2)) + ) + t3 = loc_diff * torch.erf(math.sqrt(0.5) * loc_diff_scale_ratio) + return -t1 + (t2 + t3) / q.scale - (0.5 * (1 + math.log(0.5 * math.pi))) + + +@register_kl(Pareto, Beta) +@register_kl(Pareto, ContinuousBernoulli) +@register_kl(Pareto, Uniform) +def _kl_pareto_infinity(p, q): + return _infinite_like(p.scale) + + +@register_kl(Pareto, Exponential) +def _kl_pareto_exponential(p, q): + scale_rate_prod = p.scale * q.rate + t1 = (p.alpha / scale_rate_prod).log() + t2 = p.alpha.reciprocal() + t3 = p.alpha * scale_rate_prod / (p.alpha - 1) + result = t1 - t2 + t3 - 1 + result[p.alpha <= 1] = inf + return result + + +@register_kl(Pareto, Gamma) +def _kl_pareto_gamma(p, q): + common_term = p.scale.log() + p.alpha.reciprocal() + t1 = p.alpha.log() - common_term + t2 = q.concentration.lgamma() - q.concentration * q.rate.log() + t3 = (1 - q.concentration) * common_term + t4 = q.rate * p.alpha * p.scale / (p.alpha - 1) + result = t1 + t2 + t3 + t4 - 1 + result[p.alpha <= 1] = inf + return result + + +# TODO: Add Pareto-Laplace KL Divergence + + +@register_kl(Pareto, Normal) +def _kl_pareto_normal(p, q): + var_normal = 2 * q.scale.pow(2) + common_term = p.scale / (p.alpha - 1) + t1 = (math.sqrt(2 * math.pi) * q.scale * p.alpha / p.scale).log() + t2 = p.alpha.reciprocal() + t3 = p.alpha * common_term.pow(2) / (p.alpha - 2) + t4 = (p.alpha * common_term - q.loc).pow(2) + result = t1 - t2 + (t3 + t4) / var_normal - 1 + result[p.alpha <= 2] = inf + return result + + +@register_kl(Poisson, Bernoulli) +@register_kl(Poisson, Binomial) +def _kl_poisson_infinity(p, q): + return _infinite_like(p.rate) + + +@register_kl(Uniform, Beta) +def _kl_uniform_beta(p, q): + common_term = p.high - p.low + t1 = torch.log(common_term) + t2 = ( + (q.concentration1 - 1) + * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) + / common_term + ) + t3 = ( + (q.concentration0 - 1) + * (_x_log_x(1 - p.high) - _x_log_x(1 - p.low) + common_term) + / common_term + ) + t4 = ( + q.concentration1.lgamma() + + q.concentration0.lgamma() + - (q.concentration1 + q.concentration0).lgamma() + ) + result = t3 + t4 - t1 - t2 + result[(p.high > q.support.upper_bound) | (p.low < q.support.lower_bound)] = inf + return result + + +@register_kl(Uniform, ContinuousBernoulli) +def _kl_uniform_continuous_bernoulli(p, q): + result = ( + -p.entropy() + - p.mean * q.logits + - torch.log1p(-q.probs) + - q._cont_bern_log_norm() + ) + return torch.where( + torch.max( + torch.ge(p.high, q.support.upper_bound), + torch.le(p.low, q.support.lower_bound), + ), + torch.ones_like(result) * inf, + result, + ) + + +@register_kl(Uniform, Exponential) +def _kl_uniform_exponetial(p, q): + result = q.rate * (p.high + p.low) / 2 - ((p.high - p.low) * q.rate).log() + result[p.low < q.support.lower_bound] = inf + return result + + +@register_kl(Uniform, Gamma) +def _kl_uniform_gamma(p, q): + common_term = p.high - p.low + t1 = common_term.log() + t2 = q.concentration.lgamma() - q.concentration * q.rate.log() + t3 = ( + (1 - q.concentration) + * (_x_log_x(p.high) - _x_log_x(p.low) - common_term) + / common_term + ) + t4 = q.rate * (p.high + p.low) / 2 + result = -t1 + t2 + t3 + t4 + result[p.low < q.support.lower_bound] = inf + return result + + +@register_kl(Uniform, Gumbel) +def _kl_uniform_gumbel(p, q): + common_term = q.scale / (p.high - p.low) + high_loc_diff = (p.high - q.loc) / q.scale + low_loc_diff = (p.low - q.loc) / q.scale + t1 = common_term.log() + 0.5 * (high_loc_diff + low_loc_diff) + t2 = common_term * (torch.exp(-high_loc_diff) - torch.exp(-low_loc_diff)) + return t1 - t2 + + +# TODO: Uniform-Laplace KL Divergence + + +@register_kl(Uniform, Normal) +def _kl_uniform_normal(p, q): + common_term = p.high - p.low + t1 = (math.sqrt(math.pi * 2) * q.scale / common_term).log() + t2 = (common_term).pow(2) / 12 + t3 = ((p.high + p.low - 2 * q.loc) / 2).pow(2) + return t1 + 0.5 * (t2 + t3) / q.scale.pow(2) + + +@register_kl(Uniform, Pareto) +def _kl_uniform_pareto(p, q): + support_uniform = p.high - p.low + t1 = (q.alpha * q.scale.pow(q.alpha) * (support_uniform)).log() + t2 = (_x_log_x(p.high) - _x_log_x(p.low) - support_uniform) / support_uniform + result = t2 * (q.alpha + 1) - t1 + result[p.low < q.support.lower_bound] = inf + return result + + +@register_kl(Independent, Independent) +def _kl_independent_independent(p, q): + if p.reinterpreted_batch_ndims != q.reinterpreted_batch_ndims: + raise NotImplementedError + result = kl_divergence(p.base_dist, q.base_dist) + return _sum_rightmost(result, p.reinterpreted_batch_ndims) + + +@register_kl(Cauchy, Cauchy) +def _kl_cauchy_cauchy(p, q): + # From https://arxiv.org/abs/1905.10965 + t1 = ((p.scale + q.scale).pow(2) + (p.loc - q.loc).pow(2)).log() + t2 = (4 * p.scale * q.scale).log() + return t1 - t2 + + +def _add_kl_info(): + """Appends a list of implemented KL functions to the doc for kl_divergence.""" + rows = [ + "KL divergence is currently implemented for the following distribution pairs:" + ] + for p, q in sorted( + _KL_REGISTRY, key=lambda p_q: (p_q[0].__name__, p_q[1].__name__) + ): + rows.append( + f"* :class:`~torch.distributions.{p.__name__}` and :class:`~torch.distributions.{q.__name__}`" + ) + kl_info = "\n\t".join(rows) + if kl_divergence.__doc__: + kl_divergence.__doc__ += kl_info diff --git a/phivenv/Lib/site-packages/torch/distributions/kumaraswamy.py b/phivenv/Lib/site-packages/torch/distributions/kumaraswamy.py new file mode 100644 index 0000000000000000000000000000000000000000..1ce8808ffea7c5c57951a2ebc9df3e32631b4e82 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/kumaraswamy.py @@ -0,0 +1,106 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import nan, Tensor +from torch.distributions import constraints +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, PowerTransform +from torch.distributions.uniform import Uniform +from torch.distributions.utils import broadcast_all, euler_constant + + +__all__ = ["Kumaraswamy"] + + +def _moments(a, b, n): + """ + Computes nth moment of Kumaraswamy using using torch.lgamma + """ + arg1 = 1 + n / a + log_value = torch.lgamma(arg1) + torch.lgamma(b) - torch.lgamma(arg1 + b) + return b * torch.exp(log_value) + + +class Kumaraswamy(TransformedDistribution): + r""" + Samples from a Kumaraswamy distribution. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Kumaraswamy(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Kumaraswamy distribution with concentration alpha=1 and beta=1 + tensor([ 0.1729]) + + Args: + concentration1 (float or Tensor): 1st concentration parameter of the distribution + (often referred to as alpha) + concentration0 (float or Tensor): 2nd concentration parameter of the distribution + (often referred to as beta) + """ + + arg_constraints = { + "concentration1": constraints.positive, + "concentration0": constraints.positive, + } + support = constraints.unit_interval + has_rsample = True + + def __init__( + self, + concentration1: Union[Tensor, float], + concentration0: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.concentration1, self.concentration0 = broadcast_all( + concentration1, concentration0 + ) + base_dist = Uniform( + torch.full_like(self.concentration0, 0), + torch.full_like(self.concentration0, 1), + validate_args=validate_args, + ) + transforms = [ + PowerTransform(exponent=self.concentration0.reciprocal()), + AffineTransform(loc=1.0, scale=-1.0), + PowerTransform(exponent=self.concentration1.reciprocal()), + ] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Kumaraswamy, _instance) + new.concentration1 = self.concentration1.expand(batch_shape) + new.concentration0 = self.concentration0.expand(batch_shape) + return super().expand(batch_shape, _instance=new) + + @property + def mean(self) -> Tensor: + return _moments(self.concentration1, self.concentration0, 1) + + @property + def mode(self) -> Tensor: + # Evaluate in log-space for numerical stability. + log_mode = ( + self.concentration0.reciprocal() * (-self.concentration0).log1p() + - (-self.concentration0 * self.concentration1).log1p() + ) + log_mode[(self.concentration0 < 1) | (self.concentration1 < 1)] = nan + return log_mode.exp() + + @property + def variance(self) -> Tensor: + return _moments(self.concentration1, self.concentration0, 2) - torch.pow( + self.mean, 2 + ) + + def entropy(self): + t1 = 1 - self.concentration1.reciprocal() + t0 = 1 - self.concentration0.reciprocal() + H0 = torch.digamma(self.concentration0 + 1) + euler_constant + return ( + t0 + + t1 * H0 + - torch.log(self.concentration1) + - torch.log(self.concentration0) + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/laplace.py b/phivenv/Lib/site-packages/torch/distributions/laplace.py new file mode 100644 index 0000000000000000000000000000000000000000..64eee907e4b7755a0a9e99b9e4c7ca432def2bd4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/laplace.py @@ -0,0 +1,104 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all +from torch.types import _Number, _size + + +__all__ = ["Laplace"] + + +class Laplace(Distribution): + r""" + Creates a Laplace distribution parameterized by :attr:`loc` and :attr:`scale`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Laplace(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # Laplace distributed with loc=0, scale=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of the distribution + scale (float or Tensor): scale of the distribution + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + + @property + def mean(self) -> Tensor: + return self.loc + + @property + def mode(self) -> Tensor: + return self.loc + + @property + def variance(self) -> Tensor: + return 2 * self.scale.pow(2) + + @property + def stddev(self) -> Tensor: + return (2**0.5) * self.scale + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, _Number) and isinstance(scale, _Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Laplace, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Laplace, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + finfo = torch.finfo(self.loc.dtype) + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .uniform_() + u = torch.rand(shape, dtype=self.loc.dtype, device=self.loc.device) * 2 - 1 + return self.loc - self.scale * u.sign() * torch.log1p( + -u.abs().clamp(min=finfo.tiny) + ) + u = self.loc.new(shape).uniform_(finfo.eps - 1, 1) + # TODO: If we ever implement tensor.nextafter, below is what we want ideally. + # u = self.loc.new(shape).uniform_(self.loc.nextafter(-.5, 0), .5) + return self.loc - self.scale * u.sign() * torch.log1p(-u.abs()) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 0.5 - 0.5 * (value - self.loc).sign() * torch.expm1( + -(value - self.loc).abs() / self.scale + ) + + def icdf(self, value): + term = value - 0.5 + return self.loc - self.scale * (term).sign() * torch.log1p(-2 * term.abs()) + + def entropy(self): + return 1 + torch.log(2 * self.scale) diff --git a/phivenv/Lib/site-packages/torch/distributions/lkj_cholesky.py b/phivenv/Lib/site-packages/torch/distributions/lkj_cholesky.py new file mode 100644 index 0000000000000000000000000000000000000000..a402a07be7918be8d571665aafe28aec2da1d82f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/lkj_cholesky.py @@ -0,0 +1,152 @@ +# mypy: allow-untyped-defs +""" +This closely follows the implementation in NumPyro (https://github.com/pyro-ppl/numpyro). + +Original copyright notice: + +# Copyright: Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 +""" + +import math +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import Beta, constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all + + +__all__ = ["LKJCholesky"] + + +class LKJCholesky(Distribution): + r""" + LKJ distribution for lower Cholesky factor of correlation matrices. + The distribution is controlled by ``concentration`` parameter :math:`\eta` + to make the probability of the correlation matrix :math:`M` generated from + a Cholesky factor proportional to :math:`\det(M)^{\eta - 1}`. Because of that, + when ``concentration == 1``, we have a uniform distribution over Cholesky + factors of correlation matrices:: + + L ~ LKJCholesky(dim, concentration) + X = L @ L' ~ LKJCorr(dim, concentration) + + Note that this distribution samples the + Cholesky factor of correlation matrices and not the correlation matrices + themselves and thereby differs slightly from the derivations in [1] for + the `LKJCorr` distribution. For sampling, this uses the Onion method from + [1] Section 3. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> l = LKJCholesky(3, 0.5) + >>> l.sample() # l @ l.T is a sample of a correlation 3x3 matrix + tensor([[ 1.0000, 0.0000, 0.0000], + [ 0.3516, 0.9361, 0.0000], + [-0.1899, 0.4748, 0.8593]]) + + Args: + dimension (dim): dimension of the matrices + concentration (float or Tensor): concentration/shape parameter of the + distribution (often referred to as eta) + + **References** + + [1] `Generating random correlation matrices based on vines and extended onion method` (2009), + Daniel Lewandowski, Dorota Kurowicka, Harry Joe. + Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008 + """ + + arg_constraints = {"concentration": constraints.positive} + support = constraints.corr_cholesky + + def __init__( + self, + dim: int, + concentration: Union[Tensor, float] = 1.0, + validate_args: Optional[bool] = None, + ) -> None: + if dim < 2: + raise ValueError( + f"Expected dim to be an integer greater than or equal to 2. Found dim={dim}." + ) + self.dim = dim + (self.concentration,) = broadcast_all(concentration) + batch_shape = self.concentration.size() + event_shape = torch.Size((dim, dim)) + # This is used to draw vectorized samples from the beta distribution in Sec. 3.2 of [1]. + marginal_conc = self.concentration + 0.5 * (self.dim - 2) + offset = torch.arange( + self.dim - 1, + dtype=self.concentration.dtype, + device=self.concentration.device, + ) + offset = torch.cat([offset.new_zeros((1,)), offset]) + beta_conc1 = offset + 0.5 + beta_conc0 = marginal_conc.unsqueeze(-1) - 0.5 * offset + self._beta = Beta(beta_conc1, beta_conc0) + super().__init__(batch_shape, event_shape, validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LKJCholesky, _instance) + batch_shape = torch.Size(batch_shape) + new.dim = self.dim + new.concentration = self.concentration.expand(batch_shape) + new._beta = self._beta.expand(batch_shape + (self.dim,)) + super(LKJCholesky, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + # This uses the Onion method, but there are a few differences from [1] Sec. 3.2: + # - This vectorizes the for loop and also works for heterogeneous eta. + # - Same algorithm generalizes to n=1. + # - The procedure is simplified since we are sampling the cholesky factor of + # the correlation matrix instead of the correlation matrix itself. As such, + # we only need to generate `w`. + y = self._beta.sample(sample_shape).unsqueeze(-1) + u_normal = torch.randn( + self._extended_shape(sample_shape), dtype=y.dtype, device=y.device + ).tril(-1) + u_hypersphere = u_normal / u_normal.norm(dim=-1, keepdim=True) + # Replace NaNs in first row + u_hypersphere[..., 0, :].fill_(0.0) + w = torch.sqrt(y) * u_hypersphere + # Fill diagonal elements; clamp for numerical stability + eps = torch.finfo(w.dtype).tiny + diag_elems = torch.clamp(1 - torch.sum(w**2, dim=-1), min=eps).sqrt() + w += torch.diag_embed(diag_elems) + return w + + def log_prob(self, value): + # See: https://mc-stan.org/docs/2_25/functions-reference/cholesky-lkj-correlation-distribution.html + # The probability of a correlation matrix is proportional to + # determinant ** (concentration - 1) = prod(L_ii ^ 2(concentration - 1)) + # Additionally, the Jacobian of the transformation from Cholesky factor to + # correlation matrix is: + # prod(L_ii ^ (D - i)) + # So the probability of a Cholesky factor is propotional to + # prod(L_ii ^ (2 * concentration - 2 + D - i)) = prod(L_ii ^ order_i) + # with order_i = 2 * concentration - 2 + D - i + if self._validate_args: + self._validate_sample(value) + diag_elems = value.diagonal(dim1=-1, dim2=-2)[..., 1:] + order = torch.arange(2, self.dim + 1, device=self.concentration.device) + order = 2 * (self.concentration - 1).unsqueeze(-1) + self.dim - order + unnormalized_log_pdf = torch.sum(order * diag_elems.log(), dim=-1) + # Compute normalization constant (page 1999 of [1]) + dm1 = self.dim - 1 + alpha = self.concentration + 0.5 * dm1 + denominator = torch.lgamma(alpha) * dm1 + numerator = torch.mvlgamma(alpha - 0.5, dm1) + # pi_constant in [1] is D * (D - 1) / 4 * log(pi) + # pi_constant in multigammaln is (D - 1) * (D - 2) / 4 * log(pi) + # hence, we need to add a pi_constant = (D - 1) * log(pi) / 2 + pi_constant = 0.5 * dm1 * math.log(math.pi) + normalize_term = pi_constant + numerator - denominator + return unnormalized_log_pdf - normalize_term diff --git a/phivenv/Lib/site-packages/torch/distributions/log_normal.py b/phivenv/Lib/site-packages/torch/distributions/log_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6e04dde876d9f954bb18ee5c427066fcaaee05 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/log_normal.py @@ -0,0 +1,74 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.normal import Normal +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import ExpTransform + + +__all__ = ["LogNormal"] + + +class LogNormal(TransformedDistribution): + r""" + Creates a log-normal distribution parameterized by + :attr:`loc` and :attr:`scale` where:: + + X ~ Normal(loc, scale) + Y = exp(X) ~ LogNormal(loc, scale) + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = LogNormal(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # log-normal distributed with mean=0 and stddev=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of log of distribution + scale (float or Tensor): standard deviation of log of the distribution + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.positive + has_rsample = True + base_dist: Normal + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + base_dist = Normal(loc, scale, validate_args=validate_args) + super().__init__(base_dist, ExpTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LogNormal, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def loc(self) -> Tensor: + return self.base_dist.loc + + @property + def scale(self) -> Tensor: + return self.base_dist.scale + + @property + def mean(self) -> Tensor: + return (self.loc + self.scale.pow(2) / 2).exp() + + @property + def mode(self) -> Tensor: + return (self.loc - self.scale.square()).exp() + + @property + def variance(self) -> Tensor: + scale_sq = self.scale.pow(2) + return scale_sq.expm1() * (2 * self.loc + scale_sq).exp() + + def entropy(self): + return self.base_dist.entropy() + self.loc diff --git a/phivenv/Lib/site-packages/torch/distributions/logistic_normal.py b/phivenv/Lib/site-packages/torch/distributions/logistic_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ce3b0debfe6d24660c825b82f490b2cd3ecebc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/logistic_normal.py @@ -0,0 +1,66 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +from torch import Tensor +from torch.distributions import constraints, Independent +from torch.distributions.normal import Normal +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import StickBreakingTransform + + +__all__ = ["LogisticNormal"] + + +class LogisticNormal(TransformedDistribution): + r""" + Creates a logistic-normal distribution parameterized by :attr:`loc` and :attr:`scale` + that define the base `Normal` distribution transformed with the + `StickBreakingTransform` such that:: + + X ~ LogisticNormal(loc, scale) + Y = log(X / (1 - X.cumsum(-1)))[..., :-1] ~ Normal(loc, scale) + + Args: + loc (float or Tensor): mean of the base distribution + scale (float or Tensor): standard deviation of the base distribution + + Example:: + + >>> # logistic-normal distributed with mean=(0, 0, 0) and stddev=(1, 1, 1) + >>> # of the base Normal distribution + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = LogisticNormal(torch.tensor([0.0] * 3), torch.tensor([1.0] * 3)) + >>> m.sample() + tensor([ 0.7653, 0.0341, 0.0579, 0.1427]) + + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.simplex + has_rsample = True + base_dist: Independent[Normal] + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + base_dist = Normal(loc, scale, validate_args=validate_args) + if not base_dist.batch_shape: + base_dist = base_dist.expand([1]) + super().__init__( + base_dist, StickBreakingTransform(), validate_args=validate_args + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LogisticNormal, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def loc(self) -> Tensor: + return self.base_dist.base_dist.loc + + @property + def scale(self) -> Tensor: + return self.base_dist.base_dist.scale diff --git a/phivenv/Lib/site-packages/torch/distributions/lowrank_multivariate_normal.py b/phivenv/Lib/site-packages/torch/distributions/lowrank_multivariate_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..7874e7416fbce3d2ae36ae749d920757038cb626 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/lowrank_multivariate_normal.py @@ -0,0 +1,251 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.multivariate_normal import _batch_mahalanobis, _batch_mv +from torch.distributions.utils import _standard_normal, lazy_property +from torch.types import _size + + +__all__ = ["LowRankMultivariateNormal"] + + +def _batch_capacitance_tril(W, D): + r""" + Computes Cholesky of :math:`I + W.T @ inv(D) @ W` for a batch of matrices :math:`W` + and a batch of vectors :math:`D`. + """ + m = W.size(-1) + Wt_Dinv = W.mT / D.unsqueeze(-2) + K = torch.matmul(Wt_Dinv, W).contiguous() + K.view(-1, m * m)[:, :: m + 1] += 1 # add identity matrix to K + return torch.linalg.cholesky(K) + + +def _batch_lowrank_logdet(W, D, capacitance_tril): + r""" + Uses "matrix determinant lemma":: + log|W @ W.T + D| = log|C| + log|D|, + where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute + the log determinant. + """ + return 2 * capacitance_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + D.log().sum( + -1 + ) + + +def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril): + r""" + Uses "Woodbury matrix identity":: + inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D), + where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the squared + Mahalanobis distance :math:`x.T @ inv(W @ W.T + D) @ x`. + """ + Wt_Dinv = W.mT / D.unsqueeze(-2) + Wt_Dinv_x = _batch_mv(Wt_Dinv, x) + mahalanobis_term1 = (x.pow(2) / D).sum(-1) + mahalanobis_term2 = _batch_mahalanobis(capacitance_tril, Wt_Dinv_x) + return mahalanobis_term1 - mahalanobis_term2 + + +class LowRankMultivariateNormal(Distribution): + r""" + Creates a multivariate normal distribution with covariance matrix having a low-rank form + parameterized by :attr:`cov_factor` and :attr:`cov_diag`:: + + covariance_matrix = cov_factor @ cov_factor.T + cov_diag + + Example: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = LowRankMultivariateNormal( + ... torch.zeros(2), torch.tensor([[1.0], [0.0]]), torch.ones(2) + ... ) + >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]` + tensor([-0.2102, -0.5429]) + + Args: + loc (Tensor): mean of the distribution with shape `batch_shape + event_shape` + cov_factor (Tensor): factor part of low-rank form of covariance matrix with shape + `batch_shape + event_shape + (rank,)` + cov_diag (Tensor): diagonal part of low-rank form of covariance matrix with shape + `batch_shape + event_shape` + + Note: + The computation for determinant and inverse of covariance matrix is avoided when + `cov_factor.shape[1] << cov_factor.shape[0]` thanks to `Woodbury matrix identity + `_ and + `matrix determinant lemma `_. + Thanks to these formulas, we just need to compute the determinant and inverse of + the small size "capacitance" matrix:: + + capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor + """ + + arg_constraints = { + "loc": constraints.real_vector, + "cov_factor": constraints.independent(constraints.real, 2), + "cov_diag": constraints.independent(constraints.positive, 1), + } + support = constraints.real_vector + has_rsample = True + + def __init__( + self, + loc: Tensor, + cov_factor: Tensor, + cov_diag: Tensor, + validate_args: Optional[bool] = None, + ) -> None: + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + event_shape = loc.shape[-1:] + if cov_factor.dim() < 2: + raise ValueError( + "cov_factor must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + if cov_factor.shape[-2:-1] != event_shape: + raise ValueError( + f"cov_factor must be a batch of matrices with shape {event_shape[0]} x m" + ) + if cov_diag.shape[-1:] != event_shape: + raise ValueError( + f"cov_diag must be a batch of vectors with shape {event_shape}" + ) + + loc_ = loc.unsqueeze(-1) + cov_diag_ = cov_diag.unsqueeze(-1) + try: + loc_, self.cov_factor, cov_diag_ = torch.broadcast_tensors( + loc_, cov_factor, cov_diag_ + ) + except RuntimeError as e: + raise ValueError( + f"Incompatible batch shapes: loc {loc.shape}, cov_factor {cov_factor.shape}, cov_diag {cov_diag.shape}" + ) from e + self.loc = loc_[..., 0] + self.cov_diag = cov_diag_[..., 0] + batch_shape = self.loc.shape[:-1] + + self._unbroadcasted_cov_factor = cov_factor + self._unbroadcasted_cov_diag = cov_diag + self._capacitance_tril = _batch_capacitance_tril(cov_factor, cov_diag) + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LowRankMultivariateNormal, _instance) + batch_shape = torch.Size(batch_shape) + loc_shape = batch_shape + self.event_shape + new.loc = self.loc.expand(loc_shape) + new.cov_diag = self.cov_diag.expand(loc_shape) + new.cov_factor = self.cov_factor.expand(loc_shape + self.cov_factor.shape[-1:]) + new._unbroadcasted_cov_factor = self._unbroadcasted_cov_factor + new._unbroadcasted_cov_diag = self._unbroadcasted_cov_diag + new._capacitance_tril = self._capacitance_tril + super(LowRankMultivariateNormal, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @property + def mean(self) -> Tensor: + return self.loc + + @property + def mode(self) -> Tensor: + return self.loc + + @lazy_property + def variance(self) -> Tensor: # type: ignore[override] + return ( + self._unbroadcasted_cov_factor.pow(2).sum(-1) + self._unbroadcasted_cov_diag + ).expand(self._batch_shape + self._event_shape) + + @lazy_property + def scale_tril(self) -> Tensor: + # The following identity is used to increase the numerically computation stability + # for Cholesky decomposition (see http://www.gaussianprocess.org/gpml/, Section 3.4.3): + # W @ W.T + D = D1/2 @ (I + D-1/2 @ W @ W.T @ D-1/2) @ D1/2 + # The matrix "I + D-1/2 @ W @ W.T @ D-1/2" has eigenvalues bounded from below by 1, + # hence it is well-conditioned and safe to take Cholesky decomposition. + n = self._event_shape[0] + cov_diag_sqrt_unsqueeze = self._unbroadcasted_cov_diag.sqrt().unsqueeze(-1) + Dinvsqrt_W = self._unbroadcasted_cov_factor / cov_diag_sqrt_unsqueeze + K = torch.matmul(Dinvsqrt_W, Dinvsqrt_W.mT).contiguous() + K.view(-1, n * n)[:, :: n + 1] += 1 # add identity matrix to K + scale_tril = cov_diag_sqrt_unsqueeze * torch.linalg.cholesky(K) + return scale_tril.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @lazy_property + def covariance_matrix(self) -> Tensor: + covariance_matrix = torch.matmul( + self._unbroadcasted_cov_factor, self._unbroadcasted_cov_factor.mT + ) + torch.diag_embed(self._unbroadcasted_cov_diag) + return covariance_matrix.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @lazy_property + def precision_matrix(self) -> Tensor: + # We use "Woodbury matrix identity" to take advantage of low rank form:: + # inv(W @ W.T + D) = inv(D) - inv(D) @ W @ inv(C) @ W.T @ inv(D) + # where :math:`C` is the capacitance matrix. + Wt_Dinv = ( + self._unbroadcasted_cov_factor.mT + / self._unbroadcasted_cov_diag.unsqueeze(-2) + ) + A = torch.linalg.solve_triangular(self._capacitance_tril, Wt_Dinv, upper=False) + precision_matrix = ( + torch.diag_embed(self._unbroadcasted_cov_diag.reciprocal()) - A.mT @ A + ) + return precision_matrix.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + W_shape = shape[:-1] + self.cov_factor.shape[-1:] + eps_W = _standard_normal(W_shape, dtype=self.loc.dtype, device=self.loc.device) + eps_D = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + return ( + self.loc + + _batch_mv(self._unbroadcasted_cov_factor, eps_W) + + self._unbroadcasted_cov_diag.sqrt() * eps_D + ) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + diff = value - self.loc + M = _batch_lowrank_mahalanobis( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + diff, + self._capacitance_tril, + ) + log_det = _batch_lowrank_logdet( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + self._capacitance_tril, + ) + return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + log_det + M) + + def entropy(self): + log_det = _batch_lowrank_logdet( + self._unbroadcasted_cov_factor, + self._unbroadcasted_cov_diag, + self._capacitance_tril, + ) + H = 0.5 * (self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + log_det) + if len(self._batch_shape) == 0: + return H + else: + return H.expand(self._batch_shape) diff --git a/phivenv/Lib/site-packages/torch/distributions/mixture_same_family.py b/phivenv/Lib/site-packages/torch/distributions/mixture_same_family.py new file mode 100644 index 0000000000000000000000000000000000000000..2797f51adb4dac70ef9eeecd64ac09a06570d9d3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/mixture_same_family.py @@ -0,0 +1,220 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch import Tensor +from torch.distributions import Categorical, constraints +from torch.distributions.constraints import MixtureSameFamilyConstraint +from torch.distributions.distribution import Distribution + + +__all__ = ["MixtureSameFamily"] + + +class MixtureSameFamily(Distribution): + r""" + The `MixtureSameFamily` distribution implements a (batch of) mixture + distribution where all component are from different parameterizations of + the same distribution type. It is parameterized by a `Categorical` + "selecting distribution" (over `k` component) and a component + distribution, i.e., a `Distribution` with a rightmost batch shape + (equal to `[k]`) which indexes each (batch of) component. + + Examples:: + + >>> # xdoctest: +SKIP("undefined vars") + >>> # Construct Gaussian Mixture Model in 1D consisting of 5 equally + >>> # weighted normal distributions + >>> mix = D.Categorical(torch.ones(5,)) + >>> comp = D.Normal(torch.randn(5,), torch.rand(5,)) + >>> gmm = MixtureSameFamily(mix, comp) + + >>> # Construct Gaussian Mixture Model in 2D consisting of 5 equally + >>> # weighted bivariate normal distributions + >>> mix = D.Categorical(torch.ones(5,)) + >>> comp = D.Independent(D.Normal( + ... torch.randn(5,2), torch.rand(5,2)), 1) + >>> gmm = MixtureSameFamily(mix, comp) + + >>> # Construct a batch of 3 Gaussian Mixture Models in 2D each + >>> # consisting of 5 random weighted bivariate normal distributions + >>> mix = D.Categorical(torch.rand(3,5)) + >>> comp = D.Independent(D.Normal( + ... torch.randn(3,5,2), torch.rand(3,5,2)), 1) + >>> gmm = MixtureSameFamily(mix, comp) + + Args: + mixture_distribution: `torch.distributions.Categorical`-like + instance. Manages the probability of selecting component. + The number of categories must match the rightmost batch + dimension of the `component_distribution`. Must have either + scalar `batch_shape` or `batch_shape` matching + `component_distribution.batch_shape[:-1]` + component_distribution: `torch.distributions.Distribution`-like + instance. Right-most batch dimension indexes component. + """ + + arg_constraints: dict[str, constraints.Constraint] = {} + has_rsample = False + + def __init__( + self, + mixture_distribution: Categorical, + component_distribution: Distribution, + validate_args: Optional[bool] = None, + ) -> None: + self._mixture_distribution = mixture_distribution + self._component_distribution = component_distribution + + if not isinstance(self._mixture_distribution, Categorical): + raise ValueError( + " The Mixture distribution needs to be an " + " instance of torch.distributions.Categorical" + ) + + if not isinstance(self._component_distribution, Distribution): + raise ValueError( + "The Component distribution need to be an " + "instance of torch.distributions.Distribution" + ) + + # Check that batch size matches + mdbs = self._mixture_distribution.batch_shape + cdbs = self._component_distribution.batch_shape[:-1] + for size1, size2 in zip(reversed(mdbs), reversed(cdbs)): + if size1 != 1 and size2 != 1 and size1 != size2: + raise ValueError( + f"`mixture_distribution.batch_shape` ({mdbs}) is not " + "compatible with `component_distribution." + f"batch_shape`({cdbs})" + ) + + # Check that the number of mixture component matches + km = self._mixture_distribution.logits.shape[-1] + kc = self._component_distribution.batch_shape[-1] + if km is not None and kc is not None and km != kc: + raise ValueError( + f"`mixture_distribution component` ({km}) does not" + " equal `component_distribution.batch_shape[-1]`" + f" ({kc})" + ) + self._num_component = km + + event_shape = self._component_distribution.event_shape + self._event_ndims = len(event_shape) + super().__init__( + batch_shape=cdbs, event_shape=event_shape, validate_args=validate_args + ) + + def expand(self, batch_shape, _instance=None): + batch_shape = torch.Size(batch_shape) + batch_shape_comp = batch_shape + (self._num_component,) + new = self._get_checked_instance(MixtureSameFamily, _instance) + new._component_distribution = self._component_distribution.expand( + batch_shape_comp + ) + new._mixture_distribution = self._mixture_distribution.expand(batch_shape) + new._num_component = self._num_component + new._event_ndims = self._event_ndims + event_shape = new._component_distribution.event_shape + super(MixtureSameFamily, new).__init__( + batch_shape=batch_shape, event_shape=event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @constraints.dependent_property + def support(self): + return MixtureSameFamilyConstraint(self._component_distribution.support) + + @property + def mixture_distribution(self) -> Categorical: + return self._mixture_distribution + + @property + def component_distribution(self) -> Distribution: + return self._component_distribution + + @property + def mean(self) -> Tensor: + probs = self._pad_mixture_dimensions(self.mixture_distribution.probs) + return torch.sum( + probs * self.component_distribution.mean, dim=-1 - self._event_ndims + ) # [B, E] + + @property + def variance(self) -> Tensor: + # Law of total variance: Var(Y) = E[Var(Y|X)] + Var(E[Y|X]) + probs = self._pad_mixture_dimensions(self.mixture_distribution.probs) + mean_cond_var = torch.sum( + probs * self.component_distribution.variance, dim=-1 - self._event_ndims + ) + var_cond_mean = torch.sum( + probs * (self.component_distribution.mean - self._pad(self.mean)).pow(2.0), + dim=-1 - self._event_ndims, + ) + return mean_cond_var + var_cond_mean + + def cdf(self, x): + x = self._pad(x) + cdf_x = self.component_distribution.cdf(x) + mix_prob = self.mixture_distribution.probs + + return torch.sum(cdf_x * mix_prob, dim=-1) + + def log_prob(self, x): + if self._validate_args: + self._validate_sample(x) + x = self._pad(x) + log_prob_x = self.component_distribution.log_prob(x) # [S, B, k] + log_mix_prob = torch.log_softmax( + self.mixture_distribution.logits, dim=-1 + ) # [B, k] + return torch.logsumexp(log_prob_x + log_mix_prob, dim=-1) # [S, B] + + def sample(self, sample_shape=torch.Size()): + with torch.no_grad(): + sample_len = len(sample_shape) + batch_len = len(self.batch_shape) + gather_dim = sample_len + batch_len + es = self.event_shape + + # mixture samples [n, B] + mix_sample = self.mixture_distribution.sample(sample_shape) + mix_shape = mix_sample.shape + + # component samples [n, B, k, E] + comp_samples = self.component_distribution.sample(sample_shape) + + # Gather along the k dimension + mix_sample_r = mix_sample.reshape( + mix_shape + torch.Size([1] * (len(es) + 1)) + ) + mix_sample_r = mix_sample_r.repeat( + torch.Size([1] * len(mix_shape)) + torch.Size([1]) + es + ) + + samples = torch.gather(comp_samples, gather_dim, mix_sample_r) + return samples.squeeze(gather_dim) + + def _pad(self, x): + return x.unsqueeze(-1 - self._event_ndims) + + def _pad_mixture_dimensions(self, x): + dist_batch_ndims = len(self.batch_shape) + cat_batch_ndims = len(self.mixture_distribution.batch_shape) + pad_ndims = 0 if cat_batch_ndims == 1 else dist_batch_ndims - cat_batch_ndims + xs = x.shape + x = x.reshape( + xs[:-1] + + torch.Size(pad_ndims * [1]) + + xs[-1:] + + torch.Size(self._event_ndims * [1]) + ) + return x + + def __repr__(self): + args_string = ( + f"\n {self.mixture_distribution},\n {self.component_distribution}" + ) + return "MixtureSameFamily" + "(" + args_string + ")" diff --git a/phivenv/Lib/site-packages/torch/distributions/multinomial.py b/phivenv/Lib/site-packages/torch/distributions/multinomial.py new file mode 100644 index 0000000000000000000000000000000000000000..5c9cc8019ebf79c46cc4a82443118a5b7af718f7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/multinomial.py @@ -0,0 +1,146 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch import inf, Tensor +from torch.distributions import Categorical, constraints +from torch.distributions.binomial import Binomial +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all + + +__all__ = ["Multinomial"] + + +class Multinomial(Distribution): + r""" + Creates a Multinomial distribution parameterized by :attr:`total_count` and + either :attr:`probs` or :attr:`logits` (but not both). The innermost dimension of + :attr:`probs` indexes over categories. All other dimensions index over batches. + + Note that :attr:`total_count` need not be specified if only :meth:`log_prob` is + called (see example below) + + .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, + and it will be normalized to sum to 1 along the last dimension. :attr:`probs` + will return this normalized value. + The `logits` argument will be interpreted as unnormalized log probabilities + and can therefore be any real number. It will likewise be normalized so that + the resulting probabilities sum to 1 along the last dimension. :attr:`logits` + will return this normalized value. + + - :meth:`sample` requires a single shared `total_count` for all + parameters and samples. + - :meth:`log_prob` allows different `total_count` for each parameter and + sample. + + Example:: + + >>> # xdoctest: +SKIP("FIXME: found invalid values") + >>> m = Multinomial(100, torch.tensor([ 1., 1., 1., 1.])) + >>> x = m.sample() # equal probability of 0, 1, 2, 3 + tensor([ 21., 24., 30., 25.]) + + >>> Multinomial(probs=torch.tensor([1., 1., 1., 1.])).log_prob(x) + tensor([-4.1338]) + + Args: + total_count (int): number of trials + probs (Tensor): event probabilities + logits (Tensor): event log probabilities (unnormalized) + """ + + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + total_count: int + + @property + def mean(self) -> Tensor: + return self.probs * self.total_count + + @property + def variance(self) -> Tensor: + return self.total_count * self.probs * (1 - self.probs) + + def __init__( + self, + total_count: int = 1, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + if not isinstance(total_count, int): + raise NotImplementedError("inhomogeneous total_count is not supported") + self.total_count = total_count + self._categorical = Categorical(probs=probs, logits=logits) + self._binomial = Binomial(total_count=total_count, probs=self.probs) + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Multinomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count + new._categorical = self._categorical.expand(batch_shape) + super(Multinomial, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @constraints.dependent_property(is_discrete=True, event_dim=1) + def support(self): + return constraints.multinomial(self.total_count) + + @property + def logits(self) -> Tensor: + return self._categorical.logits + + @property + def probs(self) -> Tensor: + return self._categorical.probs + + @property + def param_shape(self) -> torch.Size: + return self._categorical.param_shape + + def sample(self, sample_shape=torch.Size()): + sample_shape = torch.Size(sample_shape) + samples = self._categorical.sample( + torch.Size((self.total_count,)) + sample_shape + ) + # samples.shape is (total_count, sample_shape, batch_shape), need to change it to + # (sample_shape, batch_shape, total_count) + shifted_idx = list(range(samples.dim())) + shifted_idx.append(shifted_idx.pop(0)) + samples = samples.permute(*shifted_idx) + counts = samples.new(self._extended_shape(sample_shape)).zero_() + counts.scatter_add_(-1, samples, torch.ones_like(samples)) + return counts.type_as(self.probs) + + def entropy(self): + n = torch.tensor(self.total_count) + + cat_entropy = self._categorical.entropy() + term1 = n * cat_entropy - torch.lgamma(n + 1) + + support = self._binomial.enumerate_support(expand=False)[1:] + binomial_probs = torch.exp(self._binomial.log_prob(support)) + weights = torch.lgamma(support + 1) + term2 = (binomial_probs * weights).sum([0, -1]) + + return term1 + term2 + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + logits = logits.clone(memory_format=torch.contiguous_format) + log_factorial_n = torch.lgamma(value.sum(-1) + 1) + log_factorial_xs = torch.lgamma(value + 1).sum(-1) + logits[(value == 0) & (logits == -inf)] = 0 + log_powers = (logits * value).sum(-1) + return log_factorial_n - log_factorial_xs + log_powers diff --git a/phivenv/Lib/site-packages/torch/distributions/multivariate_normal.py b/phivenv/Lib/site-packages/torch/distributions/multivariate_normal.py new file mode 100644 index 0000000000000000000000000000000000000000..da20d6a6ad4ffd2e707b33e315f2f7c77d3c3075 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/multivariate_normal.py @@ -0,0 +1,269 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import _standard_normal, lazy_property +from torch.types import _size + + +__all__ = ["MultivariateNormal"] + + +def _batch_mv(bmat, bvec): + r""" + Performs a batched matrix-vector product, with compatible but different batch shapes. + + This function takes as input `bmat`, containing :math:`n \times n` matrices, and + `bvec`, containing length :math:`n` vectors. + + Both `bmat` and `bvec` may have any number of leading dimensions, which correspond + to a batch shape. They are not necessarily assumed to have the same batch shape, + just ones which can be broadcasted. + """ + return torch.matmul(bmat, bvec.unsqueeze(-1)).squeeze(-1) + + +def _batch_mahalanobis(bL, bx): + r""" + Computes the squared Mahalanobis distance :math:`\mathbf{x}^\top\mathbf{M}^{-1}\mathbf{x}` + for a factored :math:`\mathbf{M} = \mathbf{L}\mathbf{L}^\top`. + + Accepts batches for both bL and bx. They are not necessarily assumed to have the same batch + shape, but `bL` one should be able to broadcasted to `bx` one. + """ + n = bx.size(-1) + bx_batch_shape = bx.shape[:-1] + + # Assume that bL.shape = (i, 1, n, n), bx.shape = (..., i, j, n), + # we are going to make bx have shape (..., 1, j, i, 1, n) to apply batched tri.solve + bx_batch_dims = len(bx_batch_shape) + bL_batch_dims = bL.dim() - 2 + outer_batch_dims = bx_batch_dims - bL_batch_dims + old_batch_dims = outer_batch_dims + bL_batch_dims + new_batch_dims = outer_batch_dims + 2 * bL_batch_dims + # Reshape bx with the shape (..., 1, i, j, 1, n) + bx_new_shape = bx.shape[:outer_batch_dims] + for sL, sx in zip(bL.shape[:-2], bx.shape[outer_batch_dims:-1]): + bx_new_shape += (sx // sL, sL) + bx_new_shape += (n,) + bx = bx.reshape(bx_new_shape) + # Permute bx to make it have shape (..., 1, j, i, 1, n) + permute_dims = ( + list(range(outer_batch_dims)) + + list(range(outer_batch_dims, new_batch_dims, 2)) + + list(range(outer_batch_dims + 1, new_batch_dims, 2)) + + [new_batch_dims] + ) + bx = bx.permute(permute_dims) + + flat_L = bL.reshape(-1, n, n) # shape = b x n x n + flat_x = bx.reshape(-1, flat_L.size(0), n) # shape = c x b x n + flat_x_swap = flat_x.permute(1, 2, 0) # shape = b x n x c + M_swap = ( + torch.linalg.solve_triangular(flat_L, flat_x_swap, upper=False).pow(2).sum(-2) + ) # shape = b x c + M = M_swap.t() # shape = c x b + + # Now we revert the above reshape and permute operators. + permuted_M = M.reshape(bx.shape[:-1]) # shape = (..., 1, j, i, 1) + permute_inv_dims = list(range(outer_batch_dims)) + for i in range(bL_batch_dims): + permute_inv_dims += [outer_batch_dims + i, old_batch_dims + i] + reshaped_M = permuted_M.permute(permute_inv_dims) # shape = (..., 1, i, j, 1) + return reshaped_M.reshape(bx_batch_shape) + + +def _precision_to_scale_tril(P): + # Ref: https://nbviewer.jupyter.org/gist/fehiepsi/5ef8e09e61604f10607380467eb82006#Precision-to-scale_tril + Lf = torch.linalg.cholesky(torch.flip(P, (-2, -1))) + L_inv = torch.transpose(torch.flip(Lf, (-2, -1)), -2, -1) + Id = torch.eye(P.shape[-1], dtype=P.dtype, device=P.device) + L = torch.linalg.solve_triangular(L_inv, Id, upper=False) + return L + + +class MultivariateNormal(Distribution): + r""" + Creates a multivariate normal (also called Gaussian) distribution + parameterized by a mean vector and a covariance matrix. + + The multivariate normal distribution can be parameterized either + in terms of a positive definite covariance matrix :math:`\mathbf{\Sigma}` + or a positive definite precision matrix :math:`\mathbf{\Sigma}^{-1}` + or a lower-triangular matrix :math:`\mathbf{L}` with positive-valued + diagonal entries, such that + :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top`. This triangular matrix + can be obtained via e.g. Cholesky decomposition of the covariance. + + Example: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK) + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = MultivariateNormal(torch.zeros(2), torch.eye(2)) + >>> m.sample() # normally distributed with mean=`[0,0]` and covariance_matrix=`I` + tensor([-0.2102, -0.5429]) + + Args: + loc (Tensor): mean of the distribution + covariance_matrix (Tensor): positive-definite covariance matrix + precision_matrix (Tensor): positive-definite precision matrix + scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal + + Note: + Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or + :attr:`scale_tril` can be specified. + + Using :attr:`scale_tril` will be more efficient: all computations internally + are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or + :attr:`precision_matrix` is passed instead, it is only used to compute + the corresponding lower triangular matrices using a Cholesky decomposition. + """ + + arg_constraints = { + "loc": constraints.real_vector, + "covariance_matrix": constraints.positive_definite, + "precision_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } + support = constraints.real_vector + has_rsample = True + + def __init__( + self, + loc: Tensor, + covariance_matrix: Optional[Tensor] = None, + precision_matrix: Optional[Tensor] = None, + scale_tril: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + if loc.dim() < 1: + raise ValueError("loc must be at least one-dimensional.") + if (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) != 1: + raise ValueError( + "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." + ) + + if scale_tril is not None: + if scale_tril.dim() < 2: + raise ValueError( + "scale_tril matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1]) + self.scale_tril = scale_tril.expand(batch_shape + (-1, -1)) + elif covariance_matrix is not None: + if covariance_matrix.dim() < 2: + raise ValueError( + "covariance_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + covariance_matrix.shape[:-2], loc.shape[:-1] + ) + self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1)) + else: + assert precision_matrix is not None # helps mypy + if precision_matrix.dim() < 2: + raise ValueError( + "precision_matrix must be at least two-dimensional, " + "with optional leading batch dimensions" + ) + batch_shape = torch.broadcast_shapes( + precision_matrix.shape[:-2], loc.shape[:-1] + ) + self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1)) + self.loc = loc.expand(batch_shape + (-1,)) + + event_shape = self.loc.shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + if scale_tril is not None: + self._unbroadcasted_scale_tril = scale_tril + elif covariance_matrix is not None: + self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix) + else: # precision_matrix is not None + self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(MultivariateNormal, _instance) + batch_shape = torch.Size(batch_shape) + loc_shape = batch_shape + self.event_shape + cov_shape = batch_shape + self.event_shape + self.event_shape + new.loc = self.loc.expand(loc_shape) + new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril + if "covariance_matrix" in self.__dict__: + new.covariance_matrix = self.covariance_matrix.expand(cov_shape) + if "scale_tril" in self.__dict__: + new.scale_tril = self.scale_tril.expand(cov_shape) + if "precision_matrix" in self.__dict__: + new.precision_matrix = self.precision_matrix.expand(cov_shape) + super(MultivariateNormal, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @lazy_property + def scale_tril(self) -> Tensor: + return self._unbroadcasted_scale_tril.expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @lazy_property + def covariance_matrix(self) -> Tensor: + return torch.matmul( + self._unbroadcasted_scale_tril, self._unbroadcasted_scale_tril.mT + ).expand(self._batch_shape + self._event_shape + self._event_shape) + + @lazy_property + def precision_matrix(self) -> Tensor: + return torch.cholesky_inverse(self._unbroadcasted_scale_tril).expand( + self._batch_shape + self._event_shape + self._event_shape + ) + + @property + def mean(self) -> Tensor: + return self.loc + + @property + def mode(self) -> Tensor: + return self.loc + + @property + def variance(self) -> Tensor: + return ( + self._unbroadcasted_scale_tril.pow(2) + .sum(-1) + .expand(self._batch_shape + self._event_shape) + ) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + return self.loc + _batch_mv(self._unbroadcasted_scale_tril, eps) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + diff = value - self.loc + M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff) + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) + return -0.5 * (self._event_shape[0] * math.log(2 * math.pi) + M) - half_log_det + + def entropy(self): + half_log_det = ( + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1) + ) + H = 0.5 * self._event_shape[0] * (1.0 + math.log(2 * math.pi)) + half_log_det + if len(self._batch_shape) == 0: + return H + else: + return H.expand(self._batch_shape) diff --git a/phivenv/Lib/site-packages/torch/distributions/negative_binomial.py b/phivenv/Lib/site-packages/torch/distributions/negative_binomial.py new file mode 100644 index 0000000000000000000000000000000000000000..1107580b2e588a35dfc2f7d75b039497e2fea0d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/negative_binomial.py @@ -0,0 +1,147 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.gamma import Gamma +from torch.distributions.utils import ( + broadcast_all, + lazy_property, + logits_to_probs, + probs_to_logits, +) + + +__all__ = ["NegativeBinomial"] + + +class NegativeBinomial(Distribution): + r""" + Creates a Negative Binomial distribution, i.e. distribution + of the number of successful independent and identical Bernoulli trials + before :attr:`total_count` failures are achieved. The probability + of success of each Bernoulli trial is :attr:`probs`. + + Args: + total_count (float or Tensor): non-negative number of negative Bernoulli + trials to stop, although the distribution is still valid for real + valued count + probs (Tensor): Event probabilities of success in the half open interval [0, 1) + logits (Tensor): Event log-odds for probabilities of success + """ + + arg_constraints = { + "total_count": constraints.greater_than_eq(0), + "probs": constraints.half_open_interval(0.0, 1.0), + "logits": constraints.real, + } + support = constraints.nonnegative_integer + + def __init__( + self, + total_count: Union[Tensor, float], + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + ( + self.total_count, + self.probs, + ) = broadcast_all(total_count, probs) + self.total_count = self.total_count.type_as(self.probs) + else: + assert logits is not None # helps mypy + ( + self.total_count, + self.logits, + ) = broadcast_all(total_count, logits) + self.total_count = self.total_count.type_as(self.logits) + + self._param = self.probs if probs is not None else self.logits + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(NegativeBinomial, _instance) + batch_shape = torch.Size(batch_shape) + new.total_count = self.total_count.expand(batch_shape) + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(NegativeBinomial, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @property + def mean(self) -> Tensor: + return self.total_count * torch.exp(self.logits) + + @property + def mode(self) -> Tensor: + return ((self.total_count - 1) * self.logits.exp()).floor().clamp(min=0.0) + + @property + def variance(self) -> Tensor: + return self.mean / torch.sigmoid(-self.logits) + + @lazy_property + def logits(self) -> Tensor: + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self) -> Tensor: + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self) -> torch.Size: + return self._param.size() + + @lazy_property + def _gamma(self) -> Gamma: + # Note we avoid validating because self.total_count can be zero. + return Gamma( + concentration=self.total_count, + rate=torch.exp(-self.logits), + validate_args=False, + ) + + def sample(self, sample_shape=torch.Size()): + with torch.no_grad(): + rate = self._gamma.sample(sample_shape=sample_shape) + return torch.poisson(rate) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + + log_unnormalized_prob = self.total_count * F.logsigmoid( + -self.logits + ) + value * F.logsigmoid(self.logits) + + log_normalization = ( + -torch.lgamma(self.total_count + value) + + torch.lgamma(1.0 + value) + + torch.lgamma(self.total_count) + ) + # The case self.total_count == 0 and value == 0 has probability 1 but + # lgamma(0) is infinite. Handle this case separately using a function + # that does not modify tensors in place to allow Jit compilation. + log_normalization = log_normalization.masked_fill( + self.total_count + value == 0.0, 0.0 + ) + + return log_unnormalized_prob - log_normalization diff --git a/phivenv/Lib/site-packages/torch/distributions/normal.py b/phivenv/Lib/site-packages/torch/distributions/normal.py new file mode 100644 index 0000000000000000000000000000000000000000..f3d65eded6c34298fa7dfaab06b62d55775c5c7c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/normal.py @@ -0,0 +1,121 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import _standard_normal, broadcast_all +from torch.types import _Number, _size + + +__all__ = ["Normal"] + + +class Normal(ExponentialFamily): + r""" + Creates a normal (also called Gaussian) distribution parameterized by + :attr:`loc` and :attr:`scale`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Normal(torch.tensor([0.0]), torch.tensor([1.0])) + >>> m.sample() # normally distributed with loc=0 and scale=1 + tensor([ 0.1046]) + + Args: + loc (float or Tensor): mean of the distribution (often referred to as mu) + scale (float or Tensor): standard deviation of the distribution + (often referred to as sigma) + """ + + arg_constraints = {"loc": constraints.real, "scale": constraints.positive} + support = constraints.real + has_rsample = True + _mean_carrier_measure = 0 + + @property + def mean(self) -> Tensor: + return self.loc + + @property + def mode(self) -> Tensor: + return self.loc + + @property + def stddev(self) -> Tensor: + return self.scale + + @property + def variance(self) -> Tensor: + return self.stddev.pow(2) + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.loc, self.scale = broadcast_all(loc, scale) + if isinstance(loc, _Number) and isinstance(scale, _Number): + batch_shape = torch.Size() + else: + batch_shape = self.loc.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Normal, _instance) + batch_shape = torch.Size(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + super(Normal, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.normal(self.loc.expand(shape), self.scale.expand(shape)) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + eps = _standard_normal(shape, dtype=self.loc.dtype, device=self.loc.device) + return self.loc + eps * self.scale + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + # compute the variance + var = self.scale**2 + log_scale = ( + math.log(self.scale) + if isinstance(self.scale, _Number) + else self.scale.log() + ) + return ( + -((value - self.loc) ** 2) / (2 * var) + - log_scale + - math.log(math.sqrt(2 * math.pi)) + ) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + return 0.5 * ( + 1 + torch.erf((value - self.loc) * self.scale.reciprocal() / math.sqrt(2)) + ) + + def icdf(self, value): + return self.loc + self.scale * torch.erfinv(2 * value - 1) * math.sqrt(2) + + def entropy(self): + return 0.5 + 0.5 * math.log(2 * math.pi) + torch.log(self.scale) + + @property + def _natural_params(self) -> tuple[Tensor, Tensor]: + return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal()) + + def _log_normalizer(self, x, y): + return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y) diff --git a/phivenv/Lib/site-packages/torch/distributions/one_hot_categorical.py b/phivenv/Lib/site-packages/torch/distributions/one_hot_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..8ada13a534083178f9e48c9856168aac2c2e25f9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/one_hot_categorical.py @@ -0,0 +1,142 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.categorical import Categorical +from torch.distributions.distribution import Distribution +from torch.types import _size + + +__all__ = ["OneHotCategorical", "OneHotCategoricalStraightThrough"] + + +class OneHotCategorical(Distribution): + r""" + Creates a one-hot categorical distribution parameterized by :attr:`probs` or + :attr:`logits`. + + Samples are one-hot coded vectors of size ``probs.size(-1)``. + + .. note:: The `probs` argument must be non-negative, finite and have a non-zero sum, + and it will be normalized to sum to 1 along the last dimension. :attr:`probs` + will return this normalized value. + The `logits` argument will be interpreted as unnormalized log probabilities + and can therefore be any real number. It will likewise be normalized so that + the resulting probabilities sum to 1 along the last dimension. :attr:`logits` + will return this normalized value. + + See also: :func:`torch.distributions.Categorical` for specifications of + :attr:`probs` and :attr:`logits`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = OneHotCategorical(torch.tensor([ 0.25, 0.25, 0.25, 0.25 ])) + >>> m.sample() # equal probability of 0, 1, 2, 3 + tensor([ 0., 0., 0., 1.]) + + Args: + probs (Tensor): event probabilities + logits (Tensor): event log probabilities (unnormalized) + """ + + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = constraints.one_hot + has_enumerate_support = True + + def __init__( + self, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + self._categorical = Categorical(probs, logits) + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(OneHotCategorical, _instance) + batch_shape = torch.Size(batch_shape) + new._categorical = self._categorical.expand(batch_shape) + super(OneHotCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @property + def _param(self) -> Tensor: + return self._categorical._param + + @property + def probs(self) -> Tensor: + return self._categorical.probs + + @property + def logits(self) -> Tensor: + return self._categorical.logits + + @property + def mean(self) -> Tensor: + return self._categorical.probs + + @property + def mode(self) -> Tensor: + probs = self._categorical.probs + mode = probs.argmax(dim=-1) + return torch.nn.functional.one_hot(mode, num_classes=probs.shape[-1]).to(probs) + + @property + def variance(self) -> Tensor: + return self._categorical.probs * (1 - self._categorical.probs) + + @property + def param_shape(self) -> torch.Size: + return self._categorical.param_shape + + def sample(self, sample_shape=torch.Size()): + sample_shape = torch.Size(sample_shape) + probs = self._categorical.probs + num_events = self._categorical._num_events + indices = self._categorical.sample(sample_shape) + return torch.nn.functional.one_hot(indices, num_events).to(probs) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + indices = value.max(-1)[1] + return self._categorical.log_prob(indices) + + def entropy(self): + return self._categorical.entropy() + + def enumerate_support(self, expand=True): + n = self.event_shape[0] + values = torch.eye(n, dtype=self._param.dtype, device=self._param.device) + values = values.view((n,) + (1,) * len(self.batch_shape) + (n,)) + if expand: + values = values.expand((n,) + self.batch_shape + (n,)) + return values + + +class OneHotCategoricalStraightThrough(OneHotCategorical): + r""" + Creates a reparameterizable :class:`OneHotCategorical` distribution based on the straight- + through gradient estimator from [1]. + + [1] Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation + (Bengio et al., 2013) + """ + + has_rsample = True + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + samples = self.sample(sample_shape) + probs = self._categorical.probs # cached via @lazy_property + return samples + (probs - probs.detach()) diff --git a/phivenv/Lib/site-packages/torch/distributions/pareto.py b/phivenv/Lib/site-packages/torch/distributions/pareto.py new file mode 100644 index 0000000000000000000000000000000000000000..a682894a53c532b91979810b03e57395e9fdbb1d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/pareto.py @@ -0,0 +1,73 @@ +from typing import Optional, Union + +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.exponential import Exponential +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, ExpTransform +from torch.distributions.utils import broadcast_all +from torch.types import _size + + +__all__ = ["Pareto"] + + +class Pareto(TransformedDistribution): + r""" + Samples from a Pareto Type 1 distribution. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Pareto(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Pareto distribution with scale=1 and alpha=1 + tensor([ 1.5623]) + + Args: + scale (float or Tensor): Scale parameter of the distribution + alpha (float or Tensor): Shape parameter of the distribution + """ + + arg_constraints = {"alpha": constraints.positive, "scale": constraints.positive} + + def __init__( + self, + scale: Union[Tensor, float], + alpha: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.scale, self.alpha = broadcast_all(scale, alpha) + base_dist = Exponential(self.alpha, validate_args=validate_args) + transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand( + self, batch_shape: _size, _instance: Optional["Pareto"] = None + ) -> "Pareto": + new = self._get_checked_instance(Pareto, _instance) + new.scale = self.scale.expand(batch_shape) + new.alpha = self.alpha.expand(batch_shape) + return super().expand(batch_shape, _instance=new) + + @property + def mean(self) -> Tensor: + # mean is inf for alpha <= 1 + a = self.alpha.clamp(min=1) + return a * self.scale / (a - 1) + + @property + def mode(self) -> Tensor: + return self.scale + + @property + def variance(self) -> Tensor: + # var is inf for alpha <= 2 + a = self.alpha.clamp(min=2) + return self.scale.pow(2) * a / ((a - 1).pow(2) * (a - 2)) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self) -> constraints.Constraint: + return constraints.greater_than_eq(self.scale) + + def entropy(self) -> Tensor: + return (self.scale / self.alpha).log() + (1 + self.alpha.reciprocal()) diff --git a/phivenv/Lib/site-packages/torch/distributions/poisson.py b/phivenv/Lib/site-packages/torch/distributions/poisson.py new file mode 100644 index 0000000000000000000000000000000000000000..bc1a535d0e8e2fbca788bbf1567d5116fd451177 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/poisson.py @@ -0,0 +1,86 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.utils import broadcast_all +from torch.types import _Number, Number + + +__all__ = ["Poisson"] + + +class Poisson(ExponentialFamily): + r""" + Creates a Poisson distribution parameterized by :attr:`rate`, the rate parameter. + + Samples are nonnegative integers, with a pmf given by + + .. math:: + \mathrm{rate}^k \frac{e^{-\mathrm{rate}}}{k!} + + Example:: + + >>> # xdoctest: +SKIP("poisson_cpu not implemented for 'Long'") + >>> m = Poisson(torch.tensor([4])) + >>> m.sample() + tensor([ 3.]) + + Args: + rate (Number, Tensor): the rate parameter + """ + + arg_constraints = {"rate": constraints.nonnegative} + support = constraints.nonnegative_integer + + @property + def mean(self) -> Tensor: + return self.rate + + @property + def mode(self) -> Tensor: + return self.rate.floor() + + @property + def variance(self) -> Tensor: + return self.rate + + def __init__( + self, + rate: Union[Tensor, Number], + validate_args: Optional[bool] = None, + ) -> None: + (self.rate,) = broadcast_all(rate) + if isinstance(rate, _Number): + batch_shape = torch.Size() + else: + batch_shape = self.rate.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Poisson, _instance) + batch_shape = torch.Size(batch_shape) + new.rate = self.rate.expand(batch_shape) + super(Poisson, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def sample(self, sample_shape=torch.Size()): + shape = self._extended_shape(sample_shape) + with torch.no_grad(): + return torch.poisson(self.rate.expand(shape)) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + rate, value = broadcast_all(self.rate, value) + return value.xlogy(rate) - rate - (value + 1).lgamma() + + @property + def _natural_params(self) -> tuple[Tensor]: + return (torch.log(self.rate),) + + def _log_normalizer(self, x): + return torch.exp(x) diff --git a/phivenv/Lib/site-packages/torch/distributions/relaxed_bernoulli.py b/phivenv/Lib/site-packages/torch/distributions/relaxed_bernoulli.py new file mode 100644 index 0000000000000000000000000000000000000000..8fe1302e0abc8cc2d2f0dad0bfef02e7d71bfd0e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/relaxed_bernoulli.py @@ -0,0 +1,169 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import SigmoidTransform +from torch.distributions.utils import ( + broadcast_all, + clamp_probs, + lazy_property, + logits_to_probs, + probs_to_logits, +) +from torch.types import _Number, _size, Number + + +__all__ = ["LogitRelaxedBernoulli", "RelaxedBernoulli"] + + +class LogitRelaxedBernoulli(Distribution): + r""" + Creates a LogitRelaxedBernoulli distribution parameterized by :attr:`probs` + or :attr:`logits` (but not both), which is the logit of a RelaxedBernoulli + distribution. + + Samples are logits of values in (0, 1). See [1] for more details. + + Args: + temperature (Tensor): relaxation temperature + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + + [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random + Variables (Maddison et al., 2017) + + [2] Categorical Reparametrization with Gumbel-Softmax + (Jang et al., 2017) + """ + + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.real + + def __init__( + self, + temperature: Tensor, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: + self.temperature = temperature + if (probs is None) == (logits is None): + raise ValueError( + "Either `probs` or `logits` must be specified, but not both." + ) + if probs is not None: + is_scalar = isinstance(probs, _Number) + (self.probs,) = broadcast_all(probs) + else: + assert logits is not None # helps mypy + is_scalar = isinstance(logits, _Number) + (self.logits,) = broadcast_all(logits) + self._param = self.probs if probs is not None else self.logits + if is_scalar: + batch_shape = torch.Size() + else: + batch_shape = self._param.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(LogitRelaxedBernoulli, _instance) + batch_shape = torch.Size(batch_shape) + new.temperature = self.temperature + if "probs" in self.__dict__: + new.probs = self.probs.expand(batch_shape) + new._param = new.probs + if "logits" in self.__dict__: + new.logits = self.logits.expand(batch_shape) + new._param = new.logits + super(LogitRelaxedBernoulli, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._param.new(*args, **kwargs) + + @lazy_property + def logits(self) -> Tensor: + return probs_to_logits(self.probs, is_binary=True) + + @lazy_property + def probs(self) -> Tensor: + return logits_to_probs(self.logits, is_binary=True) + + @property + def param_shape(self) -> torch.Size: + return self._param.size() + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + probs = clamp_probs(self.probs.expand(shape)) + uniforms = clamp_probs( + torch.rand(shape, dtype=probs.dtype, device=probs.device) + ) + return ( + uniforms.log() - (-uniforms).log1p() + probs.log() - (-probs).log1p() + ) / self.temperature + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + diff = logits - value.mul(self.temperature) + return self.temperature.log() + diff - 2 * diff.exp().log1p() + + +class RelaxedBernoulli(TransformedDistribution): + r""" + Creates a RelaxedBernoulli distribution, parametrized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits` + (but not both). This is a relaxed version of the `Bernoulli` distribution, + so the values are in (0, 1), and has reparametrizable samples. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = RelaxedBernoulli(torch.tensor([2.2]), + ... torch.tensor([0.1, 0.2, 0.3, 0.99])) + >>> m.sample() + tensor([ 0.2951, 0.3442, 0.8918, 0.9021]) + + Args: + temperature (Tensor): relaxation temperature + probs (Number, Tensor): the probability of sampling `1` + logits (Number, Tensor): the log-odds of sampling `1` + """ + + arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real} + support = constraints.unit_interval + has_rsample = True + base_dist: LogitRelaxedBernoulli + + def __init__( + self, + temperature: Tensor, + probs: Optional[Union[Tensor, Number]] = None, + logits: Optional[Union[Tensor, Number]] = None, + validate_args: Optional[bool] = None, + ) -> None: + base_dist = LogitRelaxedBernoulli(temperature, probs, logits) + super().__init__(base_dist, SigmoidTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RelaxedBernoulli, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def temperature(self) -> Tensor: + return self.base_dist.temperature + + @property + def logits(self) -> Tensor: + return self.base_dist.logits + + @property + def probs(self) -> Tensor: + return self.base_dist.probs diff --git a/phivenv/Lib/site-packages/torch/distributions/relaxed_categorical.py b/phivenv/Lib/site-packages/torch/distributions/relaxed_categorical.py new file mode 100644 index 0000000000000000000000000000000000000000..44803412c51b5e089d47ae9988ed01c824fb6d31 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/relaxed_categorical.py @@ -0,0 +1,160 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.categorical import Categorical +from torch.distributions.distribution import Distribution +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import ExpTransform +from torch.distributions.utils import broadcast_all, clamp_probs +from torch.types import _size + + +__all__ = ["ExpRelaxedCategorical", "RelaxedOneHotCategorical"] + + +class ExpRelaxedCategorical(Distribution): + r""" + Creates a ExpRelaxedCategorical parameterized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits` (but not both). + Returns the log of a point in the simplex. Based on the interface to + :class:`OneHotCategorical`. + + Implementation based on [1]. + + See also: :func:`torch.distributions.OneHotCategorical` + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + + [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables + (Maddison et al., 2017) + + [2] Categorical Reparametrization with Gumbel-Softmax + (Jang et al., 2017) + """ + + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = ( + constraints.real_vector + ) # The true support is actually a submanifold of this. + has_rsample = True + + def __init__( + self, + temperature: Tensor, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + self._categorical = Categorical(probs, logits) + self.temperature = temperature + batch_shape = self._categorical.batch_shape + event_shape = self._categorical.param_shape[-1:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(ExpRelaxedCategorical, _instance) + batch_shape = torch.Size(batch_shape) + new.temperature = self.temperature + new._categorical = self._categorical.expand(batch_shape) + super(ExpRelaxedCategorical, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + def _new(self, *args, **kwargs): + return self._categorical._new(*args, **kwargs) + + @property + def param_shape(self) -> torch.Size: + return self._categorical.param_shape + + @property + def logits(self) -> Tensor: + return self._categorical.logits + + @property + def probs(self) -> Tensor: + return self._categorical.probs + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + uniforms = clamp_probs( + torch.rand(shape, dtype=self.logits.dtype, device=self.logits.device) + ) + gumbels = -((-(uniforms.log())).log()) + scores = (self.logits + gumbels) / self.temperature + return scores - scores.logsumexp(dim=-1, keepdim=True) + + def log_prob(self, value): + K = self._categorical._num_events + if self._validate_args: + self._validate_sample(value) + logits, value = broadcast_all(self.logits, value) + log_scale = torch.full_like( + self.temperature, float(K) + ).lgamma() - self.temperature.log().mul(-(K - 1)) + score = logits - value.mul(self.temperature) + score = (score - score.logsumexp(dim=-1, keepdim=True)).sum(-1) + return score + log_scale + + +class RelaxedOneHotCategorical(TransformedDistribution): + r""" + Creates a RelaxedOneHotCategorical distribution parametrized by + :attr:`temperature`, and either :attr:`probs` or :attr:`logits`. + This is a relaxed version of the :class:`OneHotCategorical` distribution, so + its samples are on simplex, and are reparametrizable. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = RelaxedOneHotCategorical(torch.tensor([2.2]), + ... torch.tensor([0.1, 0.2, 0.3, 0.4])) + >>> m.sample() + tensor([ 0.1294, 0.2324, 0.3859, 0.2523]) + + Args: + temperature (Tensor): relaxation temperature + probs (Tensor): event probabilities + logits (Tensor): unnormalized log probability for each event + """ + + arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} + support = constraints.simplex + has_rsample = True + base_dist: ExpRelaxedCategorical + + def __init__( + self, + temperature: Tensor, + probs: Optional[Tensor] = None, + logits: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + base_dist = ExpRelaxedCategorical( + temperature, probs, logits, validate_args=validate_args + ) + super().__init__(base_dist, ExpTransform(), validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(RelaxedOneHotCategorical, _instance) + return super().expand(batch_shape, _instance=new) + + @property + def temperature(self) -> Tensor: + return self.base_dist.temperature + + @property + def logits(self) -> Tensor: + return self.base_dist.logits + + @property + def probs(self) -> Tensor: + return self.base_dist.probs diff --git a/phivenv/Lib/site-packages/torch/distributions/studentT.py b/phivenv/Lib/site-packages/torch/distributions/studentT.py new file mode 100644 index 0000000000000000000000000000000000000000..0ce901f075315e45207719f6327a64046fbcdc3b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/studentT.py @@ -0,0 +1,127 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional, Union + +import torch +from torch import inf, nan, Tensor +from torch.distributions import Chi2, constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import _standard_normal, broadcast_all +from torch.types import _size + + +__all__ = ["StudentT"] + + +class StudentT(Distribution): + r""" + Creates a Student's t-distribution parameterized by degree of + freedom :attr:`df`, mean :attr:`loc` and scale :attr:`scale`. + + Example:: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = StudentT(torch.tensor([2.0])) + >>> m.sample() # Student's t-distributed with degrees of freedom=2 + tensor([ 0.1046]) + + Args: + df (float or Tensor): degrees of freedom + loc (float or Tensor): mean of the distribution + scale (float or Tensor): scale of the distribution + """ + + arg_constraints = { + "df": constraints.positive, + "loc": constraints.real, + "scale": constraints.positive, + } + support = constraints.real + has_rsample = True + + @property + def mean(self) -> Tensor: + m = self.loc.clone(memory_format=torch.contiguous_format) + m[self.df <= 1] = nan + return m + + @property + def mode(self) -> Tensor: + return self.loc + + @property + def variance(self) -> Tensor: + m = self.df.clone(memory_format=torch.contiguous_format) + m[self.df > 2] = ( + self.scale[self.df > 2].pow(2) + * self.df[self.df > 2] + / (self.df[self.df > 2] - 2) + ) + m[(self.df <= 2) & (self.df > 1)] = inf + m[self.df <= 1] = nan + return m + + def __init__( + self, + df: Union[Tensor, float], + loc: Union[Tensor, float] = 0.0, + scale: Union[Tensor, float] = 1.0, + validate_args: Optional[bool] = None, + ) -> None: + self.df, self.loc, self.scale = broadcast_all(df, loc, scale) + self._chi2 = Chi2(self.df) + batch_shape = self.df.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(StudentT, _instance) + batch_shape = torch.Size(batch_shape) + new.df = self.df.expand(batch_shape) + new.loc = self.loc.expand(batch_shape) + new.scale = self.scale.expand(batch_shape) + new._chi2 = self._chi2.expand(batch_shape) + super(StudentT, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + # NOTE: This does not agree with scipy implementation as much as other distributions. + # (see https://github.com/fritzo/notebooks/blob/master/debug-student-t.ipynb). Using DoubleTensor + # parameters seems to help. + + # X ~ Normal(0, 1) + # Z ~ Chi2(df) + # Y = X / sqrt(Z / df) ~ StudentT(df) + shape = self._extended_shape(sample_shape) + X = _standard_normal(shape, dtype=self.df.dtype, device=self.df.device) + Z = self._chi2.rsample(sample_shape) + Y = X * torch.rsqrt(Z / self.df) + return self.loc + self.scale * Y + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + y = (value - self.loc) / self.scale + Z = ( + self.scale.log() + + 0.5 * self.df.log() + + 0.5 * math.log(math.pi) + + torch.lgamma(0.5 * self.df) + - torch.lgamma(0.5 * (self.df + 1.0)) + ) + return -0.5 * (self.df + 1.0) * torch.log1p(y**2.0 / self.df) - Z + + def entropy(self): + lbeta = ( + torch.lgamma(0.5 * self.df) + + math.lgamma(0.5) + - torch.lgamma(0.5 * (self.df + 1)) + ) + return ( + self.scale.log() + + 0.5 + * (self.df + 1) + * (torch.digamma(0.5 * (self.df + 1)) - torch.digamma(0.5 * self.df)) + + 0.5 * self.df.log() + + lbeta + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/transformed_distribution.py b/phivenv/Lib/site-packages/torch/distributions/transformed_distribution.py new file mode 100644 index 0000000000000000000000000000000000000000..7ee1f5776fb5b20ee042bd6ad72aae6ddf871898 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/transformed_distribution.py @@ -0,0 +1,223 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.independent import Independent +from torch.distributions.transforms import ComposeTransform, Transform +from torch.distributions.utils import _sum_rightmost +from torch.types import _size + + +__all__ = ["TransformedDistribution"] + + +class TransformedDistribution(Distribution): + r""" + Extension of the Distribution class, which applies a sequence of Transforms + to a base distribution. Let f be the composition of transforms applied:: + + X ~ BaseDistribution + Y = f(X) ~ TransformedDistribution(BaseDistribution, f) + log p(Y) = log p(X) + log |det (dX/dY)| + + Note that the ``.event_shape`` of a :class:`TransformedDistribution` is the + maximum shape of its base distribution and its transforms, since transforms + can introduce correlations among events. + + An example for the usage of :class:`TransformedDistribution` would be:: + + # Building a Logistic Distribution + # X ~ Uniform(0, 1) + # f = a + b * logit(X) + # Y ~ f(X) ~ Logistic(a, b) + base_distribution = Uniform(0, 1) + transforms = [SigmoidTransform().inv, AffineTransform(loc=a, scale=b)] + logistic = TransformedDistribution(base_distribution, transforms) + + For more examples, please look at the implementations of + :class:`~torch.distributions.gumbel.Gumbel`, + :class:`~torch.distributions.half_cauchy.HalfCauchy`, + :class:`~torch.distributions.half_normal.HalfNormal`, + :class:`~torch.distributions.log_normal.LogNormal`, + :class:`~torch.distributions.pareto.Pareto`, + :class:`~torch.distributions.weibull.Weibull`, + :class:`~torch.distributions.relaxed_bernoulli.RelaxedBernoulli` and + :class:`~torch.distributions.relaxed_categorical.RelaxedOneHotCategorical` + """ + + arg_constraints: dict[str, constraints.Constraint] = {} + + def __init__( + self, + base_distribution: Distribution, + transforms: Union[Transform, list[Transform]], + validate_args: Optional[bool] = None, + ) -> None: + if isinstance(transforms, Transform): + self.transforms = [ + transforms, + ] + elif isinstance(transforms, list): + if not all(isinstance(t, Transform) for t in transforms): + raise ValueError( + "transforms must be a Transform or a list of Transforms" + ) + self.transforms = transforms + else: + raise ValueError( + f"transforms must be a Transform or list, but was {transforms}" + ) + + # Reshape base_distribution according to transforms. + base_shape = base_distribution.batch_shape + base_distribution.event_shape + base_event_dim = len(base_distribution.event_shape) + transform = ComposeTransform(self.transforms) + if len(base_shape) < transform.domain.event_dim: + raise ValueError( + f"base_distribution needs to have shape with size at least {transform.domain.event_dim}, but got {base_shape}." + ) + forward_shape = transform.forward_shape(base_shape) + expanded_base_shape = transform.inverse_shape(forward_shape) + if base_shape != expanded_base_shape: + base_batch_shape = expanded_base_shape[ + : len(expanded_base_shape) - base_event_dim + ] + base_distribution = base_distribution.expand(base_batch_shape) + reinterpreted_batch_ndims = transform.domain.event_dim - base_event_dim + if reinterpreted_batch_ndims > 0: + base_distribution = Independent( + base_distribution, reinterpreted_batch_ndims + ) + self.base_dist = base_distribution + + # Compute shapes. + transform_change_in_event_dim = ( + transform.codomain.event_dim - transform.domain.event_dim + ) + event_dim = max( + transform.codomain.event_dim, # the transform is coupled + base_event_dim + transform_change_in_event_dim, # the base dist is coupled + ) + assert len(forward_shape) >= event_dim + cut = len(forward_shape) - event_dim + batch_shape = forward_shape[:cut] + event_shape = forward_shape[cut:] + super().__init__(batch_shape, event_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(TransformedDistribution, _instance) + batch_shape = torch.Size(batch_shape) + shape = batch_shape + self.event_shape + for t in reversed(self.transforms): + shape = t.inverse_shape(shape) + base_batch_shape = shape[: len(shape) - len(self.base_dist.event_shape)] + new.base_dist = self.base_dist.expand(base_batch_shape) + new.transforms = self.transforms + super(TransformedDistribution, new).__init__( + batch_shape, self.event_shape, validate_args=False + ) + new._validate_args = self._validate_args + return new + + @constraints.dependent_property(is_discrete=False) + def support(self): + if not self.transforms: + return self.base_dist.support + support = self.transforms[-1].codomain + if len(self.event_shape) > support.event_dim: + support = constraints.independent( + support, len(self.event_shape) - support.event_dim + ) + return support + + @property + def has_rsample(self) -> bool: # type: ignore[override] + return self.base_dist.has_rsample + + def sample(self, sample_shape=torch.Size()): + """ + Generates a sample_shape shaped sample or sample_shape shaped batch of + samples if the distribution parameters are batched. Samples first from + base distribution and applies `transform()` for every transform in the + list. + """ + with torch.no_grad(): + x = self.base_dist.sample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + """ + Generates a sample_shape shaped reparameterized sample or sample_shape + shaped batch of reparameterized samples if the distribution parameters + are batched. Samples first from base distribution and applies + `transform()` for every transform in the list. + """ + x = self.base_dist.rsample(sample_shape) + for transform in self.transforms: + x = transform(x) + return x + + def log_prob(self, value): + """ + Scores the sample by inverting the transform(s) and computing the score + using the score of the base distribution and the log abs det jacobian. + """ + if self._validate_args: + self._validate_sample(value) + event_dim = len(self.event_shape) + log_prob = 0.0 + y = value + for transform in reversed(self.transforms): + x = transform.inv(y) + event_dim += transform.domain.event_dim - transform.codomain.event_dim + log_prob = log_prob - _sum_rightmost( + transform.log_abs_det_jacobian(x, y), + event_dim - transform.domain.event_dim, + ) + y = x + + log_prob = log_prob + _sum_rightmost( + self.base_dist.log_prob(y), event_dim - len(self.base_dist.event_shape) + ) + return log_prob + + def _monotonize_cdf(self, value): + """ + This conditionally flips ``value -> 1-value`` to ensure :meth:`cdf` is + monotone increasing. + """ + sign = 1 + for transform in self.transforms: + sign = sign * transform.sign + if isinstance(sign, int) and sign == 1: + return value + return sign * (value - 0.5) + 0.5 + + def cdf(self, value): + """ + Computes the cumulative distribution function by inverting the + transform(s) and computing the score of the base distribution. + """ + for transform in self.transforms[::-1]: + value = transform.inv(value) + if self._validate_args: + self.base_dist._validate_sample(value) + value = self.base_dist.cdf(value) + value = self._monotonize_cdf(value) + return value + + def icdf(self, value): + """ + Computes the inverse cumulative distribution function using + transform(s) and computing the score of the base distribution. + """ + value = self._monotonize_cdf(value) + value = self.base_dist.icdf(value) + for transform in self.transforms: + value = transform(value) + return value diff --git a/phivenv/Lib/site-packages/torch/distributions/transforms.py b/phivenv/Lib/site-packages/torch/distributions/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..21e3b4863b1a902f02c4f348463a1baed1574f93 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/transforms.py @@ -0,0 +1,1287 @@ +# mypy: allow-untyped-defs +import functools +import math +import operator +import weakref +from collections.abc import Sequence +from typing import Optional, Union + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import ( + _sum_rightmost, + broadcast_all, + lazy_property, + tril_matrix_to_vec, + vec_to_tril_matrix, +) +from torch.nn.functional import pad, softplus +from torch.types import _Number + + +__all__ = [ + "AbsTransform", + "AffineTransform", + "CatTransform", + "ComposeTransform", + "CorrCholeskyTransform", + "CumulativeDistributionTransform", + "ExpTransform", + "IndependentTransform", + "LowerCholeskyTransform", + "PositiveDefiniteTransform", + "PowerTransform", + "ReshapeTransform", + "SigmoidTransform", + "SoftplusTransform", + "TanhTransform", + "SoftmaxTransform", + "StackTransform", + "StickBreakingTransform", + "Transform", + "identity_transform", +] + + +class Transform: + """ + Abstract class for invertable transformations with computable log + det jacobians. They are primarily used in + :class:`torch.distributions.TransformedDistribution`. + + Caching is useful for transforms whose inverses are either expensive or + numerically unstable. Note that care must be taken with memoized values + since the autograd graph may be reversed. For example while the following + works with or without caching:: + + y = t(x) + t.log_abs_det_jacobian(x, y).backward() # x will receive gradients. + + However the following will error when caching due to dependency reversal:: + + y = t(x) + z = t.inv(y) + grad(z.sum(), [y]) # error because z is x + + Derived classes should implement one or both of :meth:`_call` or + :meth:`_inverse`. Derived classes that set `bijective=True` should also + implement :meth:`log_abs_det_jacobian`. + + Args: + cache_size (int): Size of cache. If zero, no caching is done. If one, + the latest single value is cached. Only 0 and 1 are supported. + + Attributes: + domain (:class:`~torch.distributions.constraints.Constraint`): + The constraint representing valid inputs to this transform. + codomain (:class:`~torch.distributions.constraints.Constraint`): + The constraint representing valid outputs to this transform + which are inputs to the inverse transform. + bijective (bool): Whether this transform is bijective. A transform + ``t`` is bijective iff ``t.inv(t(x)) == x`` and + ``t(t.inv(y)) == y`` for every ``x`` in the domain and ``y`` in + the codomain. Transforms that are not bijective should at least + maintain the weaker pseudoinverse properties + ``t(t.inv(t(x)) == t(x)`` and ``t.inv(t(t.inv(y))) == t.inv(y)``. + sign (int or Tensor): For bijective univariate transforms, this + should be +1 or -1 depending on whether transform is monotone + increasing or decreasing. + """ + + bijective = False + domain: constraints.Constraint + codomain: constraints.Constraint + + def __init__(self, cache_size: int = 0) -> None: + self._cache_size = cache_size + self._inv: Optional[weakref.ReferenceType[Transform]] = None + if cache_size == 0: + pass # default behavior + elif cache_size == 1: + self._cached_x_y = None, None + else: + raise ValueError("cache_size must be 0 or 1") + super().__init__() + + def __getstate__(self): + state = self.__dict__.copy() + state["_inv"] = None + return state + + @property + def event_dim(self) -> int: + if self.domain.event_dim == self.codomain.event_dim: + return self.domain.event_dim + raise ValueError("Please use either .domain.event_dim or .codomain.event_dim") + + @property + def inv(self) -> "Transform": + """ + Returns the inverse :class:`Transform` of this transform. + This should satisfy ``t.inv.inv is t``. + """ + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = _InverseTransform(self) + self._inv = weakref.ref(inv) + return inv + + @property + def sign(self) -> int: + """ + Returns the sign of the determinant of the Jacobian, if applicable. + In general this only makes sense for bijective transforms. + """ + raise NotImplementedError + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + if type(self).__init__ is Transform.__init__: + return type(self)(cache_size=cache_size) + raise NotImplementedError(f"{type(self)}.with_cache is not implemented") + + def __eq__(self, other): + return self is other + + def __ne__(self, other): + # Necessary for Python2 + return not self.__eq__(other) + + def __call__(self, x): + """ + Computes the transform `x => y`. + """ + if self._cache_size == 0: + return self._call(x) + x_old, y_old = self._cached_x_y + if x is x_old: + return y_old + y = self._call(x) + self._cached_x_y = x, y + return y + + def _inv_call(self, y): + """ + Inverts the transform `y => x`. + """ + if self._cache_size == 0: + return self._inverse(y) + x_old, y_old = self._cached_x_y + if y is y_old: + return x_old + x = self._inverse(y) + self._cached_x_y = x, y + return x + + def _call(self, x): + """ + Abstract method to compute forward transformation. + """ + raise NotImplementedError + + def _inverse(self, y): + """ + Abstract method to compute inverse transformation. + """ + raise NotImplementedError + + def log_abs_det_jacobian(self, x, y): + """ + Computes the log det jacobian `log |dy/dx|` given input and output. + """ + raise NotImplementedError + + def __repr__(self): + return self.__class__.__name__ + "()" + + def forward_shape(self, shape): + """ + Infers the shape of the forward computation, given the input shape. + Defaults to preserving shape. + """ + return shape + + def inverse_shape(self, shape): + """ + Infers the shapes of the inverse computation, given the output shape. + Defaults to preserving shape. + """ + return shape + + +class _InverseTransform(Transform): + """ + Inverts a single :class:`Transform`. + This class is private; please instead use the ``Transform.inv`` property. + """ + + def __init__(self, transform: Transform) -> None: + super().__init__(cache_size=transform._cache_size) + self._inv: Transform = transform # type: ignore[assignment] + + @constraints.dependent_property(is_discrete=False) + def domain(self): + assert self._inv is not None + return self._inv.codomain + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + assert self._inv is not None + return self._inv.domain + + @property + def bijective(self) -> bool: # type: ignore[override] + assert self._inv is not None + return self._inv.bijective + + @property + def sign(self) -> int: + assert self._inv is not None + return self._inv.sign + + @property + def inv(self) -> Transform: + return self._inv + + def with_cache(self, cache_size=1): + assert self._inv is not None + return self.inv.with_cache(cache_size).inv + + def __eq__(self, other): + if not isinstance(other, _InverseTransform): + return False + assert self._inv is not None + return self._inv == other._inv + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self._inv)})" + + def __call__(self, x): + assert self._inv is not None + return self._inv._inv_call(x) + + def log_abs_det_jacobian(self, x, y): + assert self._inv is not None + return -self._inv.log_abs_det_jacobian(y, x) + + def forward_shape(self, shape): + return self._inv.inverse_shape(shape) + + def inverse_shape(self, shape): + return self._inv.forward_shape(shape) + + +class ComposeTransform(Transform): + """ + Composes multiple transforms in a chain. + The transforms being composed are responsible for caching. + + Args: + parts (list of :class:`Transform`): A list of transforms to compose. + cache_size (int): Size of cache. If zero, no caching is done. If one, + the latest single value is cached. Only 0 and 1 are supported. + """ + + def __init__(self, parts: list[Transform], cache_size: int = 0) -> None: + if cache_size: + parts = [part.with_cache(cache_size) for part in parts] + super().__init__(cache_size=cache_size) + self.parts = parts + + def __eq__(self, other): + if not isinstance(other, ComposeTransform): + return False + return self.parts == other.parts + + @constraints.dependent_property(is_discrete=False) + def domain(self): + if not self.parts: + return constraints.real + domain = self.parts[0].domain + # Adjust event_dim to be maximum among all parts. + event_dim = self.parts[-1].codomain.event_dim + for part in reversed(self.parts): + event_dim += part.domain.event_dim - part.codomain.event_dim + event_dim = max(event_dim, part.domain.event_dim) + assert event_dim >= domain.event_dim + if event_dim > domain.event_dim: + domain = constraints.independent(domain, event_dim - domain.event_dim) + return domain + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + if not self.parts: + return constraints.real + codomain = self.parts[-1].codomain + # Adjust event_dim to be maximum among all parts. + event_dim = self.parts[0].domain.event_dim + for part in self.parts: + event_dim += part.codomain.event_dim - part.domain.event_dim + event_dim = max(event_dim, part.codomain.event_dim) + assert event_dim >= codomain.event_dim + if event_dim > codomain.event_dim: + codomain = constraints.independent(codomain, event_dim - codomain.event_dim) + return codomain + + @lazy_property + def bijective(self) -> bool: # type: ignore[override] + return all(p.bijective for p in self.parts) + + @lazy_property + def sign(self) -> int: # type: ignore[override] + sign = 1 + for p in self.parts: + sign = sign * p.sign + return sign + + @property + def inv(self) -> Transform: + inv = None + if self._inv is not None: + inv = self._inv() + if inv is None: + inv = ComposeTransform([p.inv for p in reversed(self.parts)]) + self._inv = weakref.ref(inv) + inv._inv = weakref.ref(self) + return inv + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return ComposeTransform(self.parts, cache_size=cache_size) + + def __call__(self, x): + for part in self.parts: + x = part(x) + return x + + def log_abs_det_jacobian(self, x, y): + if not self.parts: + return torch.zeros_like(x) + + # Compute intermediates. This will be free if parts[:-1] are all cached. + xs = [x] + for part in self.parts[:-1]: + xs.append(part(xs[-1])) + xs.append(y) + + terms = [] + event_dim = self.domain.event_dim + for part, x, y in zip(self.parts, xs[:-1], xs[1:]): + terms.append( + _sum_rightmost( + part.log_abs_det_jacobian(x, y), event_dim - part.domain.event_dim + ) + ) + event_dim += part.codomain.event_dim - part.domain.event_dim + return functools.reduce(operator.add, terms) + + def forward_shape(self, shape): + for part in self.parts: + shape = part.forward_shape(shape) + return shape + + def inverse_shape(self, shape): + for part in reversed(self.parts): + shape = part.inverse_shape(shape) + return shape + + def __repr__(self): + fmt_string = self.__class__.__name__ + "(\n " + fmt_string += ",\n ".join([p.__repr__() for p in self.parts]) + fmt_string += "\n)" + return fmt_string + + +identity_transform = ComposeTransform([]) + + +class IndependentTransform(Transform): + """ + Wrapper around another transform to treat + ``reinterpreted_batch_ndims``-many extra of the right most dimensions as + dependent. This has no effect on the forward or backward transforms, but + does sum out ``reinterpreted_batch_ndims``-many of the rightmost dimensions + in :meth:`log_abs_det_jacobian`. + + Args: + base_transform (:class:`Transform`): A base transform. + reinterpreted_batch_ndims (int): The number of extra rightmost + dimensions to treat as dependent. + """ + + def __init__( + self, + base_transform: Transform, + reinterpreted_batch_ndims: int, + cache_size: int = 0, + ) -> None: + super().__init__(cache_size=cache_size) + self.base_transform = base_transform.with_cache(cache_size) + self.reinterpreted_batch_ndims = reinterpreted_batch_ndims + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return IndependentTransform( + self.base_transform, self.reinterpreted_batch_ndims, cache_size=cache_size + ) + + @constraints.dependent_property(is_discrete=False) + def domain(self): + return constraints.independent( + self.base_transform.domain, self.reinterpreted_batch_ndims + ) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + return constraints.independent( + self.base_transform.codomain, self.reinterpreted_batch_ndims + ) + + @property + def bijective(self) -> bool: # type: ignore[override] + return self.base_transform.bijective + + @property + def sign(self) -> int: + return self.base_transform.sign + + def _call(self, x): + if x.dim() < self.domain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform(x) + + def _inverse(self, y): + if y.dim() < self.codomain.event_dim: + raise ValueError("Too few dimensions on input") + return self.base_transform.inv(y) + + def log_abs_det_jacobian(self, x, y): + result = self.base_transform.log_abs_det_jacobian(x, y) + result = _sum_rightmost(result, self.reinterpreted_batch_ndims) + return result + + def __repr__(self): + return f"{self.__class__.__name__}({repr(self.base_transform)}, {self.reinterpreted_batch_ndims})" + + def forward_shape(self, shape): + return self.base_transform.forward_shape(shape) + + def inverse_shape(self, shape): + return self.base_transform.inverse_shape(shape) + + +class ReshapeTransform(Transform): + """ + Unit Jacobian transform to reshape the rightmost part of a tensor. + + Note that ``in_shape`` and ``out_shape`` must have the same number of + elements, just as for :meth:`torch.Tensor.reshape`. + + Arguments: + in_shape (torch.Size): The input event shape. + out_shape (torch.Size): The output event shape. + cache_size (int): Size of cache. If zero, no caching is done. If one, + the latest single value is cached. Only 0 and 1 are supported. (Default 0.) + """ + + bijective = True + + def __init__( + self, + in_shape: torch.Size, + out_shape: torch.Size, + cache_size: int = 0, + ) -> None: + self.in_shape = torch.Size(in_shape) + self.out_shape = torch.Size(out_shape) + if self.in_shape.numel() != self.out_shape.numel(): + raise ValueError("in_shape, out_shape have different numbers of elements") + super().__init__(cache_size=cache_size) + + @constraints.dependent_property + def domain(self): + return constraints.independent(constraints.real, len(self.in_shape)) + + @constraints.dependent_property + def codomain(self): + return constraints.independent(constraints.real, len(self.out_shape)) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return ReshapeTransform(self.in_shape, self.out_shape, cache_size=cache_size) + + def _call(self, x): + batch_shape = x.shape[: x.dim() - len(self.in_shape)] + return x.reshape(batch_shape + self.out_shape) + + def _inverse(self, y): + batch_shape = y.shape[: y.dim() - len(self.out_shape)] + return y.reshape(batch_shape + self.in_shape) + + def log_abs_det_jacobian(self, x, y): + batch_shape = x.shape[: x.dim() - len(self.in_shape)] + return x.new_zeros(batch_shape) + + def forward_shape(self, shape): + if len(shape) < len(self.in_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.in_shape) + if shape[cut:] != self.in_shape: + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.in_shape}" + ) + return shape[:cut] + self.out_shape + + def inverse_shape(self, shape): + if len(shape) < len(self.out_shape): + raise ValueError("Too few dimensions on input") + cut = len(shape) - len(self.out_shape) + if shape[cut:] != self.out_shape: + raise ValueError( + f"Shape mismatch: expected {shape[cut:]} but got {self.out_shape}" + ) + return shape[:cut] + self.in_shape + + +class ExpTransform(Transform): + r""" + Transform via the mapping :math:`y = \exp(x)`. + """ + + domain = constraints.real + codomain = constraints.positive + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, ExpTransform) + + def _call(self, x): + return x.exp() + + def _inverse(self, y): + return y.log() + + def log_abs_det_jacobian(self, x, y): + return x + + +class PowerTransform(Transform): + r""" + Transform via the mapping :math:`y = x^{\text{exponent}}`. + """ + + domain = constraints.positive + codomain = constraints.positive + bijective = True + + def __init__(self, exponent: Tensor, cache_size: int = 0) -> None: + super().__init__(cache_size=cache_size) + (self.exponent,) = broadcast_all(exponent) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return PowerTransform(self.exponent, cache_size=cache_size) + + @lazy_property + def sign(self) -> int: # type: ignore[override] + return self.exponent.sign() # type: ignore[return-value] + + def __eq__(self, other): + if not isinstance(other, PowerTransform): + return False + return self.exponent.eq(other.exponent).all().item() + + def _call(self, x): + return x.pow(self.exponent) + + def _inverse(self, y): + return y.pow(1 / self.exponent) + + def log_abs_det_jacobian(self, x, y): + return (self.exponent * y / x).abs().log() + + def forward_shape(self, shape): + return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + + def inverse_shape(self, shape): + return torch.broadcast_shapes(shape, getattr(self.exponent, "shape", ())) + + +def _clipped_sigmoid(x): + finfo = torch.finfo(x.dtype) + return torch.clamp(torch.sigmoid(x), min=finfo.tiny, max=1.0 - finfo.eps) + + +class SigmoidTransform(Transform): + r""" + Transform via the mapping :math:`y = \frac{1}{1 + \exp(-x)}` and :math:`x = \text{logit}(y)`. + """ + + domain = constraints.real + codomain = constraints.unit_interval + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, SigmoidTransform) + + def _call(self, x): + return _clipped_sigmoid(x) + + def _inverse(self, y): + finfo = torch.finfo(y.dtype) + y = y.clamp(min=finfo.tiny, max=1.0 - finfo.eps) + return y.log() - (-y).log1p() + + def log_abs_det_jacobian(self, x, y): + return -F.softplus(-x) - F.softplus(x) + + +class SoftplusTransform(Transform): + r""" + Transform via the mapping :math:`\text{Softplus}(x) = \log(1 + \exp(x))`. + The implementation reverts to the linear function when :math:`x > 20`. + """ + + domain = constraints.real + codomain = constraints.positive + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, SoftplusTransform) + + def _call(self, x): + return softplus(x) + + def _inverse(self, y): + return (-y).expm1().neg().log() + y + + def log_abs_det_jacobian(self, x, y): + return -softplus(-x) + + +class TanhTransform(Transform): + r""" + Transform via the mapping :math:`y = \tanh(x)`. + + It is equivalent to + + .. code-block:: python + + ComposeTransform( + [ + AffineTransform(0.0, 2.0), + SigmoidTransform(), + AffineTransform(-1.0, 2.0), + ] + ) + + However this might not be numerically stable, thus it is recommended to use `TanhTransform` + instead. + + Note that one should use `cache_size=1` when it comes to `NaN/Inf` values. + + """ + + domain = constraints.real + codomain = constraints.interval(-1.0, 1.0) + bijective = True + sign = +1 + + def __eq__(self, other): + return isinstance(other, TanhTransform) + + def _call(self, x): + return x.tanh() + + def _inverse(self, y): + # We do not clamp to the boundary here as it may degrade the performance of certain algorithms. + # one should use `cache_size=1` instead + return torch.atanh(y) + + def log_abs_det_jacobian(self, x, y): + # We use a formula that is more numerically stable, see details in the following link + # https://github.com/tensorflow/probability/blob/master/tensorflow_probability/python/bijectors/tanh.py#L69-L80 + return 2.0 * (math.log(2.0) - x - softplus(-2.0 * x)) + + +class AbsTransform(Transform): + r"""Transform via the mapping :math:`y = |x|`.""" + + domain = constraints.real + codomain = constraints.positive + + def __eq__(self, other): + return isinstance(other, AbsTransform) + + def _call(self, x): + return x.abs() + + def _inverse(self, y): + return y + + +class AffineTransform(Transform): + r""" + Transform via the pointwise affine mapping :math:`y = \text{loc} + \text{scale} \times x`. + + Args: + loc (Tensor or float): Location parameter. + scale (Tensor or float): Scale parameter. + event_dim (int): Optional size of `event_shape`. This should be zero + for univariate random variables, 1 for distributions over vectors, + 2 for distributions over matrices, etc. + """ + + bijective = True + + def __init__( + self, + loc: Union[Tensor, float], + scale: Union[Tensor, float], + event_dim: int = 0, + cache_size: int = 0, + ) -> None: + super().__init__(cache_size=cache_size) + self.loc = loc + self.scale = scale + self._event_dim = event_dim + + @property + def event_dim(self) -> int: + return self._event_dim + + @constraints.dependent_property(is_discrete=False) + def domain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) + + @constraints.dependent_property(is_discrete=False) + def codomain(self): + if self.event_dim == 0: + return constraints.real + return constraints.independent(constraints.real, self.event_dim) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return AffineTransform( + self.loc, self.scale, self.event_dim, cache_size=cache_size + ) + + def __eq__(self, other): + if not isinstance(other, AffineTransform): + return False + + if isinstance(self.loc, _Number) and isinstance(other.loc, _Number): + if self.loc != other.loc: + return False + else: + if not (self.loc == other.loc).all().item(): # type: ignore[union-attr] + return False + + if isinstance(self.scale, _Number) and isinstance(other.scale, _Number): + if self.scale != other.scale: + return False + else: + if not (self.scale == other.scale).all().item(): # type: ignore[union-attr] + return False + + return True + + @property + def sign(self) -> Union[Tensor, int]: # type: ignore[override] + if isinstance(self.scale, _Number): + return 1 if float(self.scale) > 0 else -1 if float(self.scale) < 0 else 0 + return self.scale.sign() + + def _call(self, x): + return self.loc + self.scale * x + + def _inverse(self, y): + return (y - self.loc) / self.scale + + def log_abs_det_jacobian(self, x, y): + shape = x.shape + scale = self.scale + if isinstance(scale, _Number): + result = torch.full_like(x, math.log(abs(scale))) + else: + result = torch.abs(scale).log() + if self.event_dim: + result_size = result.size()[: -self.event_dim] + (-1,) + result = result.view(result_size).sum(-1) + shape = shape[: -self.event_dim] + return result.expand(shape) + + def forward_shape(self, shape): + return torch.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) + + def inverse_shape(self, shape): + return torch.broadcast_shapes( + shape, getattr(self.loc, "shape", ()), getattr(self.scale, "shape", ()) + ) + + +class CorrCholeskyTransform(Transform): + r""" + Transforms an uncontrained real vector :math:`x` with length :math:`D*(D-1)/2` into the + Cholesky factor of a D-dimension correlation matrix. This Cholesky factor is a lower + triangular matrix with positive diagonals and unit Euclidean norm for each row. + The transform is processed as follows: + + 1. First we convert x into a lower triangular matrix in row order. + 2. For each row :math:`X_i` of the lower triangular part, we apply a *signed* version of + class :class:`StickBreakingTransform` to transform :math:`X_i` into a + unit Euclidean length vector using the following steps: + - Scales into the interval :math:`(-1, 1)` domain: :math:`r_i = \tanh(X_i)`. + - Transforms into an unsigned domain: :math:`z_i = r_i^2`. + - Applies :math:`s_i = StickBreakingTransform(z_i)`. + - Transforms back into signed domain: :math:`y_i = sign(r_i) * \sqrt{s_i}`. + """ + + domain = constraints.real_vector + codomain = constraints.corr_cholesky + bijective = True + + def _call(self, x): + x = torch.tanh(x) + eps = torch.finfo(x.dtype).eps + x = x.clamp(min=-1 + eps, max=1 - eps) + r = vec_to_tril_matrix(x, diag=-1) + # apply stick-breaking on the squared values + # Note that y = sign(r) * sqrt(z * z1m_cumprod) + # = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod) + z = r**2 + z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1) + # Diagonal elements must be 1. + r = r + torch.eye(r.shape[-1], dtype=r.dtype, device=r.device) + y = r * pad(z1m_cumprod_sqrt[..., :-1], [1, 0], value=1) + return y + + def _inverse(self, y): + # inverse stick-breaking + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y_cumsum = 1 - torch.cumsum(y * y, dim=-1) + y_cumsum_shifted = pad(y_cumsum[..., :-1], [1, 0], value=1) + y_vec = tril_matrix_to_vec(y, diag=-1) + y_cumsum_vec = tril_matrix_to_vec(y_cumsum_shifted, diag=-1) + t = y_vec / (y_cumsum_vec).sqrt() + # inverse of tanh + x = (t.log1p() - t.neg().log1p()) / 2 + return x + + def log_abs_det_jacobian(self, x, y, intermediates=None): + # Because domain and codomain are two spaces with different dimensions, determinant of + # Jacobian is not well-defined. We return `log_abs_det_jacobian` of `x` and the + # flattened lower triangular part of `y`. + + # See: https://mc-stan.org/docs/2_18/reference-manual/cholesky-factors-of-correlation-matrices-1.html + y1m_cumsum = 1 - (y * y).cumsum(dim=-1) + # by taking diagonal=-2, we don't need to shift z_cumprod to the right + # also works for 2 x 2 matrix + y1m_cumsum_tril = tril_matrix_to_vec(y1m_cumsum, diag=-2) + stick_breaking_logdet = 0.5 * (y1m_cumsum_tril).log().sum(-1) + tanh_logdet = -2 * (x + softplus(-2 * x) - math.log(2.0)).sum(dim=-1) + return stick_breaking_logdet + tanh_logdet + + def forward_shape(self, shape): + # Reshape from (..., N) to (..., D, D). + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + N = shape[-1] + D = round((0.25 + 2 * N) ** 0.5 + 0.5) + if D * (D - 1) // 2 != N: + raise ValueError("Input is not a flattend lower-diagonal number") + return shape[:-1] + (D, D) + + def inverse_shape(self, shape): + # Reshape from (..., D, D) to (..., N). + if len(shape) < 2: + raise ValueError("Too few dimensions on input") + if shape[-2] != shape[-1]: + raise ValueError("Input is not square") + D = shape[-1] + N = D * (D - 1) // 2 + return shape[:-2] + (N,) + + +class SoftmaxTransform(Transform): + r""" + Transform from unconstrained space to the simplex via :math:`y = \exp(x)` then + normalizing. + + This is not bijective and cannot be used for HMC. However this acts mostly + coordinate-wise (except for the final normalization), and thus is + appropriate for coordinate-wise optimization algorithms. + """ + + domain = constraints.real_vector + codomain = constraints.simplex + + def __eq__(self, other): + return isinstance(other, SoftmaxTransform) + + def _call(self, x): + logprobs = x + probs = (logprobs - logprobs.max(-1, True)[0]).exp() + return probs / probs.sum(-1, True) + + def _inverse(self, y): + probs = y + return probs.log() + + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape + + +class StickBreakingTransform(Transform): + """ + Transform from unconstrained space to the simplex of one additional + dimension via a stick-breaking process. + + This transform arises as an iterated sigmoid transform in a stick-breaking + construction of the `Dirichlet` distribution: the first logit is + transformed via sigmoid to the first probability and the probability of + everything else, and then the process recurses. + + This is bijective and appropriate for use in HMC; however it mixes + coordinates together and is less appropriate for optimization. + """ + + domain = constraints.real_vector + codomain = constraints.simplex + bijective = True + + def __eq__(self, other): + return isinstance(other, StickBreakingTransform) + + def _call(self, x): + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + z = _clipped_sigmoid(x - offset.log()) + z_cumprod = (1 - z).cumprod(-1) + y = pad(z, [0, 1], value=1) * pad(z_cumprod, [1, 0], value=1) + return y + + def _inverse(self, y): + y_crop = y[..., :-1] + offset = y.shape[-1] - y.new_ones(y_crop.shape[-1]).cumsum(-1) + sf = 1 - y_crop.cumsum(-1) + # we clamp to make sure that sf is positive which sometimes does not + # happen when y[-1] ~ 0 or y[:-1].sum() ~ 1 + sf = torch.clamp(sf, min=torch.finfo(y.dtype).tiny) + x = y_crop.log() - sf.log() + offset.log() + return x + + def log_abs_det_jacobian(self, x, y): + offset = x.shape[-1] + 1 - x.new_ones(x.shape[-1]).cumsum(-1) + x = x - offset.log() + # use the identity 1 - sigmoid(x) = exp(-x) * sigmoid(x) + detJ = (-x + F.logsigmoid(x) + y[..., :-1].log()).sum(-1) + return detJ + + def forward_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] + 1,) + + def inverse_shape(self, shape): + if len(shape) < 1: + raise ValueError("Too few dimensions on input") + return shape[:-1] + (shape[-1] - 1,) + + +class LowerCholeskyTransform(Transform): + """ + Transform from unconstrained matrices to lower-triangular matrices with + nonnegative diagonal entries. + + This is useful for parameterizing positive definite matrices in terms of + their Cholesky factorization. + """ + + domain = constraints.independent(constraints.real, 2) + codomain = constraints.lower_cholesky + + def __eq__(self, other): + return isinstance(other, LowerCholeskyTransform) + + def _call(self, x): + return x.tril(-1) + x.diagonal(dim1=-2, dim2=-1).exp().diag_embed() + + def _inverse(self, y): + return y.tril(-1) + y.diagonal(dim1=-2, dim2=-1).log().diag_embed() + + +class PositiveDefiniteTransform(Transform): + """ + Transform from unconstrained matrices to positive-definite matrices. + """ + + domain = constraints.independent(constraints.real, 2) + codomain = constraints.positive_definite + + def __eq__(self, other): + return isinstance(other, PositiveDefiniteTransform) + + def _call(self, x): + x = LowerCholeskyTransform()(x) + return x @ x.mT + + def _inverse(self, y): + y = torch.linalg.cholesky(y) + return LowerCholeskyTransform().inv(y) + + +class CatTransform(Transform): + """ + Transform functor that applies a sequence of transforms `tseq` + component-wise to each submatrix at `dim`, of length `lengths[dim]`, + in a way compatible with :func:`torch.cat`. + + Example:: + + x0 = torch.cat([torch.range(1, 10), torch.range(1, 10)], dim=0) + x = torch.cat([x0, x0], dim=0) + t0 = CatTransform([ExpTransform(), identity_transform], dim=0, lengths=[10, 10]) + t = CatTransform([t0, t0], dim=0, lengths=[20, 20]) + y = t(x) + """ + + transforms: list[Transform] + + def __init__( + self, + tseq: Sequence[Transform], + dim: int = 0, + lengths: Optional[Sequence[int]] = None, + cache_size: int = 0, + ) -> None: + assert all(isinstance(t, Transform) for t in tseq) + if cache_size: + tseq = [t.with_cache(cache_size) for t in tseq] + super().__init__(cache_size=cache_size) + self.transforms = list(tseq) + if lengths is None: + lengths = [1] * len(self.transforms) + self.lengths = list(lengths) + assert len(self.lengths) == len(self.transforms) + self.dim = dim + + @lazy_property + def event_dim(self) -> int: # type: ignore[override] + return max(t.event_dim for t in self.transforms) + + @lazy_property + def length(self) -> int: + return sum(self.lengths) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return CatTransform(self.transforms, self.dim, self.lengths, cache_size) + + def _call(self, x): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == self.length + yslices = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + xslice = x.narrow(self.dim, start, length) + yslices.append(trans(xslice)) + start = start + length # avoid += for jit compat + return torch.cat(yslices, dim=self.dim) + + def _inverse(self, y): + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == self.length + xslices = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + yslice = y.narrow(self.dim, start, length) + xslices.append(trans.inv(yslice)) + start = start + length # avoid += for jit compat + return torch.cat(xslices, dim=self.dim) + + def log_abs_det_jacobian(self, x, y): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == self.length + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == self.length + logdetjacs = [] + start = 0 + for trans, length in zip(self.transforms, self.lengths): + xslice = x.narrow(self.dim, start, length) + yslice = y.narrow(self.dim, start, length) + logdetjac = trans.log_abs_det_jacobian(xslice, yslice) + if trans.event_dim < self.event_dim: + logdetjac = _sum_rightmost(logdetjac, self.event_dim - trans.event_dim) + logdetjacs.append(logdetjac) + start = start + length # avoid += for jit compat + # Decide whether to concatenate or sum. + dim = self.dim + if dim >= 0: + dim = dim - x.dim() + dim = dim + self.event_dim + if dim < 0: + return torch.cat(logdetjacs, dim=dim) + else: + return sum(logdetjacs) + + @property + def bijective(self) -> bool: # type: ignore[override] + return all(t.bijective for t in self.transforms) + + @constraints.dependent_property + def domain(self): + return constraints.cat( + [t.domain for t in self.transforms], self.dim, self.lengths + ) + + @constraints.dependent_property + def codomain(self): + return constraints.cat( + [t.codomain for t in self.transforms], self.dim, self.lengths + ) + + +class StackTransform(Transform): + """ + Transform functor that applies a sequence of transforms `tseq` + component-wise to each submatrix at `dim` + in a way compatible with :func:`torch.stack`. + + Example:: + + x = torch.stack([torch.range(1, 10), torch.range(1, 10)], dim=1) + t = StackTransform([ExpTransform(), identity_transform], dim=1) + y = t(x) + """ + + transforms: list[Transform] + + def __init__( + self, tseq: Sequence[Transform], dim: int = 0, cache_size: int = 0 + ) -> None: + assert all(isinstance(t, Transform) for t in tseq) + if cache_size: + tseq = [t.with_cache(cache_size) for t in tseq] + super().__init__(cache_size=cache_size) + self.transforms = list(tseq) + self.dim = dim + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return StackTransform(self.transforms, self.dim, cache_size) + + def _slice(self, z): + return [z.select(self.dim, i) for i in range(z.size(self.dim))] + + def _call(self, x): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == len(self.transforms) + yslices = [] + for xslice, trans in zip(self._slice(x), self.transforms): + yslices.append(trans(xslice)) + return torch.stack(yslices, dim=self.dim) + + def _inverse(self, y): + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == len(self.transforms) + xslices = [] + for yslice, trans in zip(self._slice(y), self.transforms): + xslices.append(trans.inv(yslice)) + return torch.stack(xslices, dim=self.dim) + + def log_abs_det_jacobian(self, x, y): + assert -x.dim() <= self.dim < x.dim() + assert x.size(self.dim) == len(self.transforms) + assert -y.dim() <= self.dim < y.dim() + assert y.size(self.dim) == len(self.transforms) + logdetjacs = [] + yslices = self._slice(y) + xslices = self._slice(x) + for xslice, yslice, trans in zip(xslices, yslices, self.transforms): + logdetjacs.append(trans.log_abs_det_jacobian(xslice, yslice)) + return torch.stack(logdetjacs, dim=self.dim) + + @property + def bijective(self) -> bool: # type: ignore[override] + return all(t.bijective for t in self.transforms) + + @constraints.dependent_property + def domain(self): + return constraints.stack([t.domain for t in self.transforms], self.dim) + + @constraints.dependent_property + def codomain(self): + return constraints.stack([t.codomain for t in self.transforms], self.dim) + + +class CumulativeDistributionTransform(Transform): + """ + Transform via the cumulative distribution function of a probability distribution. + + Args: + distribution (Distribution): Distribution whose cumulative distribution function to use for + the transformation. + + Example:: + + # Construct a Gaussian copula from a multivariate normal. + base_dist = MultivariateNormal( + loc=torch.zeros(2), + scale_tril=LKJCholesky(2).sample(), + ) + transform = CumulativeDistributionTransform(Normal(0, 1)) + copula = TransformedDistribution(base_dist, [transform]) + """ + + bijective = True + codomain = constraints.unit_interval + sign = +1 + + def __init__(self, distribution: Distribution, cache_size: int = 0) -> None: + super().__init__(cache_size=cache_size) + self.distribution = distribution + + @property + def domain(self) -> Optional[constraints.Constraint]: # type: ignore[override] + return self.distribution.support + + def _call(self, x): + return self.distribution.cdf(x) + + def _inverse(self, y): + return self.distribution.icdf(y) + + def log_abs_det_jacobian(self, x, y): + return self.distribution.log_prob(x) + + def with_cache(self, cache_size=1): + if self._cache_size == cache_size: + return self + return CumulativeDistributionTransform(self.distribution, cache_size=cache_size) diff --git a/phivenv/Lib/site-packages/torch/distributions/uniform.py b/phivenv/Lib/site-packages/torch/distributions/uniform.py new file mode 100644 index 0000000000000000000000000000000000000000..9e310dc23b5ab5f172554c0b35ef9653ac683844 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/uniform.py @@ -0,0 +1,108 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import nan, Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all +from torch.types import _Number, _size + + +__all__ = ["Uniform"] + + +class Uniform(Distribution): + r""" + Generates uniformly distributed random samples from the half-open interval + ``[low, high)``. + + Example:: + + >>> m = Uniform(torch.tensor([0.0]), torch.tensor([5.0])) + >>> m.sample() # uniformly distributed in the range [0.0, 5.0) + >>> # xdoctest: +SKIP + tensor([ 2.3418]) + + Args: + low (float or Tensor): lower range (inclusive). + high (float or Tensor): upper range (exclusive). + """ + + has_rsample = True + + @property + def arg_constraints(self): + # TODO allow (loc,scale) parameterization to allow independent constraints. + return { + "low": constraints.less_than(self.high), + "high": constraints.greater_than(self.low), + } + + @property + def mean(self) -> Tensor: + return (self.high + self.low) / 2 + + @property + def mode(self) -> Tensor: + return nan * self.high + + @property + def stddev(self) -> Tensor: + return (self.high - self.low) / 12**0.5 + + @property + def variance(self) -> Tensor: + return (self.high - self.low).pow(2) / 12 + + def __init__( + self, + low: Union[Tensor, float], + high: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.low, self.high = broadcast_all(low, high) + + if isinstance(low, _Number) and isinstance(high, _Number): + batch_shape = torch.Size() + else: + batch_shape = self.low.size() + super().__init__(batch_shape, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Uniform, _instance) + batch_shape = torch.Size(batch_shape) + new.low = self.low.expand(batch_shape) + new.high = self.high.expand(batch_shape) + super(Uniform, new).__init__(batch_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self): + return constraints.interval(self.low, self.high) + + def rsample(self, sample_shape: _size = torch.Size()) -> Tensor: + shape = self._extended_shape(sample_shape) + rand = torch.rand(shape, dtype=self.low.dtype, device=self.low.device) + return self.low + rand * (self.high - self.low) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + lb = self.low.le(value).type_as(self.low) + ub = self.high.gt(value).type_as(self.low) + return torch.log(lb.mul(ub)) - torch.log(self.high - self.low) + + def cdf(self, value): + if self._validate_args: + self._validate_sample(value) + result = (value - self.low) / (self.high - self.low) + return result.clamp(min=0, max=1) + + def icdf(self, value): + result = value * (self.high - self.low) + self.low + return result + + def entropy(self): + return torch.log(self.high - self.low) diff --git a/phivenv/Lib/site-packages/torch/distributions/utils.py b/phivenv/Lib/site-packages/torch/distributions/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53e1d79f73eaad160a38f08a0ee3e60616513fd1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/utils.py @@ -0,0 +1,221 @@ +from collections.abc import Sequence +from functools import update_wrapper +from typing import Any, Callable, Final, Generic, Optional, overload, TypeVar, Union + +import torch +import torch.nn.functional as F +from torch import SymInt, Tensor +from torch.overrides import is_tensor_like +from torch.types import _dtype, _Number, Device, Number + + +euler_constant: Final[float] = 0.57721566490153286060 # Euler Mascheroni Constant + +__all__ = [ + "broadcast_all", + "logits_to_probs", + "clamp_probs", + "probs_to_logits", + "lazy_property", + "tril_matrix_to_vec", + "vec_to_tril_matrix", +] + + +# FIXME: Use (*values: *Ts) -> tuple[Tensor for T in Ts] if Mapping-Type is ever added. +# See https://github.com/python/typing/issues/1216#issuecomment-2126153831 +def broadcast_all(*values: Union[Tensor, Number]) -> tuple[Tensor, ...]: + r""" + Given a list of values (possibly containing numbers), returns a list where each + value is broadcasted based on the following rules: + - `torch.*Tensor` instances are broadcasted as per :ref:`_broadcasting-semantics`. + - Number instances (scalars) are upcast to tensors having + the same size and type as the first tensor passed to `values`. If all the + values are scalars, then they are upcasted to scalar Tensors. + + Args: + values (list of `Number`, `torch.*Tensor` or objects implementing __torch_function__) + + Raises: + ValueError: if any of the values is not a `Number` instance, + a `torch.*Tensor` instance, or an instance implementing __torch_function__ + """ + if not all(is_tensor_like(v) or isinstance(v, _Number) for v in values): + raise ValueError( + "Input arguments must all be instances of Number, " + "torch.Tensor or objects implementing __torch_function__." + ) + if not all(is_tensor_like(v) for v in values): + options: dict[str, Any] = dict(dtype=torch.get_default_dtype()) + for value in values: + if isinstance(value, torch.Tensor): + options = dict(dtype=value.dtype, device=value.device) + break + new_values = [ + v if is_tensor_like(v) else torch.tensor(v, **options) for v in values + ] + return torch.broadcast_tensors(*new_values) + return torch.broadcast_tensors(*values) + + +def _standard_normal( + shape: Sequence[Union[int, SymInt]], + dtype: Optional[_dtype], + device: Optional[Device], +) -> Tensor: + if torch._C._get_tracing_state(): + # [JIT WORKAROUND] lack of support for .normal_() + return torch.normal( + torch.zeros(shape, dtype=dtype, device=device), + torch.ones(shape, dtype=dtype, device=device), + ) + return torch.empty(shape, dtype=dtype, device=device).normal_() + + +def _sum_rightmost(value: Tensor, dim: int) -> Tensor: + r""" + Sum out ``dim`` many rightmost dimensions of a given tensor. + + Args: + value (Tensor): A tensor of ``.dim()`` at least ``dim``. + dim (int): The number of rightmost dims to sum out. + """ + if dim == 0: + return value + required_shape = value.shape[:-dim] + (-1,) + return value.reshape(required_shape).sum(-1) + + +def logits_to_probs(logits: Tensor, is_binary: bool = False) -> Tensor: + r""" + Converts a tensor of logits into probabilities. Note that for the + binary case, each value denotes log odds, whereas for the + multi-dimensional case, the values along the last dimension denote + the log probabilities (possibly unnormalized) of the events. + """ + if is_binary: + return torch.sigmoid(logits) + return F.softmax(logits, dim=-1) + + +def clamp_probs(probs: Tensor) -> Tensor: + """Clamps the probabilities to be in the open interval `(0, 1)`. + + The probabilities would be clamped between `eps` and `1 - eps`, + and `eps` would be the smallest representable positive number for the input data type. + + Args: + probs (Tensor): A tensor of probabilities. + + Returns: + Tensor: The clamped probabilities. + + Examples: + >>> probs = torch.tensor([0.0, 0.5, 1.0]) + >>> clamp_probs(probs) + tensor([1.1921e-07, 5.0000e-01, 1.0000e+00]) + + >>> probs = torch.tensor([0.0, 0.5, 1.0], dtype=torch.float64) + >>> clamp_probs(probs) + tensor([2.2204e-16, 5.0000e-01, 1.0000e+00], dtype=torch.float64) + + """ + eps = torch.finfo(probs.dtype).eps + return probs.clamp(min=eps, max=1 - eps) + + +def probs_to_logits(probs: Tensor, is_binary: bool = False) -> Tensor: + r""" + Converts a tensor of probabilities into logits. For the binary case, + this denotes the probability of occurrence of the event indexed by `1`. + For the multi-dimensional case, the values along the last dimension + denote the probabilities of occurrence of each of the events. + """ + ps_clamped = clamp_probs(probs) + if is_binary: + return torch.log(ps_clamped) - torch.log1p(-ps_clamped) + return torch.log(ps_clamped) + + +T = TypeVar("T", contravariant=True) +R = TypeVar("R", covariant=True) + + +class lazy_property(Generic[T, R]): + r""" + Used as a decorator for lazy loading of class attributes. This uses a + non-data descriptor that calls the wrapped method to compute the property on + first call; thereafter replacing the wrapped method into an instance + attribute. + """ + + def __init__(self, wrapped: Callable[[T], R]) -> None: + self.wrapped: Callable[[T], R] = wrapped + update_wrapper(self, wrapped) # type:ignore[arg-type] + + @overload + def __get__( + self, instance: None, obj_type: Any = None + ) -> "_lazy_property_and_property[T, R]": ... + + @overload + def __get__(self, instance: T, obj_type: Any = None) -> R: ... + + def __get__( + self, instance: Union[T, None], obj_type: Any = None + ) -> "R | _lazy_property_and_property[T, R]": + if instance is None: + return _lazy_property_and_property(self.wrapped) + with torch.enable_grad(): + value = self.wrapped(instance) + setattr(instance, self.wrapped.__name__, value) + return value + + +class _lazy_property_and_property(lazy_property[T, R], property): + """We want lazy properties to look like multiple things. + + * property when Sphinx autodoc looks + * lazy_property when Distribution validate_args looks + """ + + def __init__(self, wrapped: Callable[[T], R]) -> None: + property.__init__(self, wrapped) + + +def tril_matrix_to_vec(mat: Tensor, diag: int = 0) -> Tensor: + r""" + Convert a `D x D` matrix or a batch of matrices into a (batched) vector + which comprises of lower triangular elements from the matrix in row order. + """ + n = mat.shape[-1] + if not torch._C._get_tracing_state() and (diag < -n or diag >= n): + raise ValueError(f"diag ({diag}) provided is outside [{-n}, {n - 1}].") + arange = torch.arange(n, device=mat.device) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + vec = mat[..., tril_mask] + return vec + + +def vec_to_tril_matrix(vec: Tensor, diag: int = 0) -> Tensor: + r""" + Convert a vector or a batch of vectors into a batched `D x D` + lower triangular matrix containing elements from the vector in row order. + """ + # +ve root of D**2 + (1+2*diag)*D - |diag| * (diag+1) - 2*vec.shape[-1] = 0 + n = ( + -(1 + 2 * diag) + + ((1 + 2 * diag) ** 2 + 8 * vec.shape[-1] + 4 * abs(diag) * (diag + 1)) ** 0.5 + ) / 2 + eps = torch.finfo(vec.dtype).eps + if not torch._C._get_tracing_state() and (round(n) - n > eps): + raise ValueError( + f"The size of last dimension is {vec.shape[-1]} which cannot be expressed as " + + "the lower triangular part of a square D x D matrix." + ) + n = round(n.item()) if isinstance(n, torch.Tensor) else round(n) + mat = vec.new_zeros(vec.shape[:-1] + torch.Size((n, n))) + arange = torch.arange(n, device=vec.device) + tril_mask = arange < arange.view(-1, 1) + (diag + 1) + mat[..., tril_mask] = vec + return mat diff --git a/phivenv/Lib/site-packages/torch/distributions/von_mises.py b/phivenv/Lib/site-packages/torch/distributions/von_mises.py new file mode 100644 index 0000000000000000000000000000000000000000..7b42f6aa0f4344899ec2f7db9942b006d9d605a9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/von_mises.py @@ -0,0 +1,218 @@ +# mypy: allow-untyped-defs +import math +from typing import Optional + +import torch +import torch.jit +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.distribution import Distribution +from torch.distributions.utils import broadcast_all, lazy_property + + +__all__ = ["VonMises"] + + +def _eval_poly(y, coef): + coef = list(coef) + result = coef.pop() + while coef: + result = coef.pop() + y * result + return result + + +_I0_COEF_SMALL = [ + 1.0, + 3.5156229, + 3.0899424, + 1.2067492, + 0.2659732, + 0.360768e-1, + 0.45813e-2, +] +_I0_COEF_LARGE = [ + 0.39894228, + 0.1328592e-1, + 0.225319e-2, + -0.157565e-2, + 0.916281e-2, + -0.2057706e-1, + 0.2635537e-1, + -0.1647633e-1, + 0.392377e-2, +] +_I1_COEF_SMALL = [ + 0.5, + 0.87890594, + 0.51498869, + 0.15084934, + 0.2658733e-1, + 0.301532e-2, + 0.32411e-3, +] +_I1_COEF_LARGE = [ + 0.39894228, + -0.3988024e-1, + -0.362018e-2, + 0.163801e-2, + -0.1031555e-1, + 0.2282967e-1, + -0.2895312e-1, + 0.1787654e-1, + -0.420059e-2, +] + +_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL] +_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE] + + +def _log_modified_bessel_fn(x, order=0): + """ + Returns ``log(I_order(x))`` for ``x > 0``, + where `order` is either 0 or 1. + """ + assert order == 0 or order == 1 + + # compute small solution + y = x / 3.75 + y = y * y + small = _eval_poly(y, _COEF_SMALL[order]) + if order == 1: + small = x.abs() * small + small = small.log() + + # compute large solution + y = 3.75 / x + large = x - 0.5 * x.log() + _eval_poly(y, _COEF_LARGE[order]).log() + + result = torch.where(x < 3.75, small, large) + return result + + +@torch.jit.script_if_tracing +def _rejection_sample(loc, concentration, proposal_r, x): + done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device) + while not done.all(): + u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device) + u1, u2, u3 = u.unbind() + z = torch.cos(math.pi * u1) + f = (1 + proposal_r * z) / (proposal_r + z) + c = concentration * (proposal_r - f) + accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0) + if accept.any(): + x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x) + done = done | accept + return (x + math.pi + loc) % (2 * math.pi) - math.pi + + +class VonMises(Distribution): + """ + A circular von Mises distribution. + + This implementation uses polar coordinates. The ``loc`` and ``value`` args + can be any real number (to facilitate unconstrained optimization), but are + interpreted as angles modulo 2 pi. + + Example:: + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # von Mises distributed with loc=1 and concentration=1 + tensor([1.9777]) + + :param torch.Tensor loc: an angle in radians. + :param torch.Tensor concentration: concentration parameter + """ + + arg_constraints = {"loc": constraints.real, "concentration": constraints.positive} + support = constraints.real + has_rsample = False + + def __init__( + self, + loc: Tensor, + concentration: Tensor, + validate_args: Optional[bool] = None, + ) -> None: + self.loc, self.concentration = broadcast_all(loc, concentration) + batch_shape = self.loc.shape + event_shape = torch.Size() + super().__init__(batch_shape, event_shape, validate_args) + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + log_prob = self.concentration * torch.cos(value - self.loc) + log_prob = ( + log_prob + - math.log(2 * math.pi) + - _log_modified_bessel_fn(self.concentration, order=0) + ) + return log_prob + + @lazy_property + def _loc(self) -> Tensor: + return self.loc.to(torch.double) + + @lazy_property + def _concentration(self) -> Tensor: + return self.concentration.to(torch.double) + + @lazy_property + def _proposal_r(self) -> Tensor: + kappa = self._concentration + tau = 1 + (1 + 4 * kappa**2).sqrt() + rho = (tau - (2 * tau).sqrt()) / (2 * kappa) + _proposal_r = (1 + rho**2) / (2 * rho) + # second order Taylor expansion around 0 for small kappa + _proposal_r_taylor = 1 / kappa + kappa + return torch.where(kappa < 1e-5, _proposal_r_taylor, _proposal_r) + + @torch.no_grad() + def sample(self, sample_shape=torch.Size()): + """ + The sampling algorithm for the von Mises distribution is based on the + following paper: D.J. Best and N.I. Fisher, "Efficient simulation of the + von Mises distribution." Applied Statistics (1979): 152-157. + + Sampling is always done in double precision internally to avoid a hang + in _rejection_sample() for small values of the concentration, which + starts to happen for single precision around 1e-4 (see issue #88443). + """ + shape = self._extended_shape(sample_shape) + x = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device) + return _rejection_sample( + self._loc, self._concentration, self._proposal_r, x + ).to(self.loc.dtype) + + def expand(self, batch_shape, _instance=None): + try: + return super().expand(batch_shape) + except NotImplementedError: + validate_args = self.__dict__.get("_validate_args") + loc = self.loc.expand(batch_shape) + concentration = self.concentration.expand(batch_shape) + return type(self)(loc, concentration, validate_args=validate_args) + + @property + def mean(self) -> Tensor: + """ + The provided mean is the circular one. + """ + return self.loc + + @property + def mode(self) -> Tensor: + return self.loc + + @lazy_property + def variance(self) -> Tensor: # type: ignore[override] + """ + The provided variance is the circular one. + """ + return ( + 1 + - ( + _log_modified_bessel_fn(self.concentration, order=1) + - _log_modified_bessel_fn(self.concentration, order=0) + ).exp() + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/weibull.py b/phivenv/Lib/site-packages/torch/distributions/weibull.py new file mode 100644 index 0000000000000000000000000000000000000000..d94714eff64b9300fd9f941f54db85c7db3e1dd0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/weibull.py @@ -0,0 +1,95 @@ +# mypy: allow-untyped-defs +from typing import Optional, Union + +import torch +from torch import Tensor +from torch.distributions import constraints +from torch.distributions.exponential import Exponential +from torch.distributions.gumbel import euler_constant +from torch.distributions.transformed_distribution import TransformedDistribution +from torch.distributions.transforms import AffineTransform, PowerTransform +from torch.distributions.utils import broadcast_all + + +__all__ = ["Weibull"] + + +class Weibull(TransformedDistribution): + r""" + Samples from a two-parameter Weibull distribution. + + Example: + + >>> # xdoctest: +IGNORE_WANT("non-deterministic") + >>> m = Weibull(torch.tensor([1.0]), torch.tensor([1.0])) + >>> m.sample() # sample from a Weibull distribution with scale=1, concentration=1 + tensor([ 0.4784]) + + Args: + scale (float or Tensor): Scale parameter of distribution (lambda). + concentration (float or Tensor): Concentration parameter of distribution (k/shape). + validate_args (bool, optional): Whether to validate arguments. Default: None. + """ + + arg_constraints = { + "scale": constraints.positive, + "concentration": constraints.positive, + } + support = constraints.positive + + def __init__( + self, + scale: Union[Tensor, float], + concentration: Union[Tensor, float], + validate_args: Optional[bool] = None, + ) -> None: + self.scale, self.concentration = broadcast_all(scale, concentration) + self.concentration_reciprocal = self.concentration.reciprocal() + base_dist = Exponential( + torch.ones_like(self.scale), validate_args=validate_args + ) + transforms = [ + PowerTransform(exponent=self.concentration_reciprocal), + AffineTransform(loc=0, scale=self.scale), + ] + super().__init__(base_dist, transforms, validate_args=validate_args) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Weibull, _instance) + new.scale = self.scale.expand(batch_shape) + new.concentration = self.concentration.expand(batch_shape) + new.concentration_reciprocal = new.concentration.reciprocal() + base_dist = self.base_dist.expand(batch_shape) + transforms = [ + PowerTransform(exponent=new.concentration_reciprocal), + AffineTransform(loc=0, scale=new.scale), + ] + super(Weibull, new).__init__(base_dist, transforms, validate_args=False) + new._validate_args = self._validate_args + return new + + @property + def mean(self) -> Tensor: + return self.scale * torch.exp(torch.lgamma(1 + self.concentration_reciprocal)) + + @property + def mode(self) -> Tensor: + return ( + self.scale + * ((self.concentration - 1) / self.concentration) + ** self.concentration.reciprocal() + ) + + @property + def variance(self) -> Tensor: + return self.scale.pow(2) * ( + torch.exp(torch.lgamma(1 + 2 * self.concentration_reciprocal)) + - torch.exp(2 * torch.lgamma(1 + self.concentration_reciprocal)) + ) + + def entropy(self): + return ( + euler_constant * (1 - self.concentration_reciprocal) + + torch.log(self.scale * self.concentration_reciprocal) + + 1 + ) diff --git a/phivenv/Lib/site-packages/torch/distributions/wishart.py b/phivenv/Lib/site-packages/torch/distributions/wishart.py new file mode 100644 index 0000000000000000000000000000000000000000..19ea33fdd70d06af629598d2fb636239732b2e90 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/distributions/wishart.py @@ -0,0 +1,342 @@ +# mypy: allow-untyped-defs +import math +import warnings +from typing import Optional, Union + +import torch +from torch import nan, Tensor +from torch.distributions import constraints +from torch.distributions.exp_family import ExponentialFamily +from torch.distributions.multivariate_normal import _precision_to_scale_tril +from torch.distributions.utils import lazy_property +from torch.types import _Number, _size, Number + + +__all__ = ["Wishart"] + +_log_2 = math.log(2) + + +def _mvdigamma(x: Tensor, p: int) -> Tensor: + assert x.gt((p - 1) / 2).all(), "Wrong domain for multivariate digamma function." + return torch.digamma( + x.unsqueeze(-1) + - torch.arange(p, dtype=x.dtype, device=x.device).div(2).expand(x.shape + (-1,)) + ).sum(-1) + + +def _clamp_above_eps(x: Tensor) -> Tensor: + # We assume positive input for this function + return x.clamp(min=torch.finfo(x.dtype).eps) + + +class Wishart(ExponentialFamily): + r""" + Creates a Wishart distribution parameterized by a symmetric positive definite matrix :math:`\Sigma`, + or its Cholesky decomposition :math:`\mathbf{\Sigma} = \mathbf{L}\mathbf{L}^\top` + + Example: + >>> # xdoctest: +SKIP("FIXME: scale_tril must be at least two-dimensional") + >>> m = Wishart(torch.Tensor([2]), covariance_matrix=torch.eye(2)) + >>> m.sample() # Wishart distributed with mean=`df * I` and + >>> # variance(x_ij)=`df` for i != j and variance(x_ij)=`2 * df` for i == j + + Args: + df (float or Tensor): real-valued parameter larger than the (dimension of Square matrix) - 1 + covariance_matrix (Tensor): positive-definite covariance matrix + precision_matrix (Tensor): positive-definite precision matrix + scale_tril (Tensor): lower-triangular factor of covariance, with positive-valued diagonal + Note: + Only one of :attr:`covariance_matrix` or :attr:`precision_matrix` or + :attr:`scale_tril` can be specified. + Using :attr:`scale_tril` will be more efficient: all computations internally + are based on :attr:`scale_tril`. If :attr:`covariance_matrix` or + :attr:`precision_matrix` is passed instead, it is only used to compute + the corresponding lower triangular matrices using a Cholesky decomposition. + 'torch.distributions.LKJCholesky' is a restricted Wishart distribution.[1] + + **References** + + [1] Wang, Z., Wu, Y. and Chu, H., 2018. `On equivalence of the LKJ distribution and the restricted Wishart distribution`. + [2] Sawyer, S., 2007. `Wishart Distributions and Inverse-Wishart Sampling`. + [3] Anderson, T. W., 2003. `An Introduction to Multivariate Statistical Analysis (3rd ed.)`. + [4] Odell, P. L. & Feiveson, A. H., 1966. `A Numerical Procedure to Generate a SampleCovariance Matrix`. JASA, 61(313):199-203. + [5] Ku, Y.-C. & Bloomfield, P., 2010. `Generating Random Wishart Matrices with Fractional Degrees of Freedom in OX`. + """ + + support = constraints.positive_definite + has_rsample = True + _mean_carrier_measure = 0 + + @property + def arg_constraints(self): + return { + "covariance_matrix": constraints.positive_definite, + "precision_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + "df": constraints.greater_than(self.event_shape[-1] - 1), + } + + def __init__( + self, + df: Union[Tensor, Number], + covariance_matrix: Optional[Tensor] = None, + precision_matrix: Optional[Tensor] = None, + scale_tril: Optional[Tensor] = None, + validate_args: Optional[bool] = None, + ) -> None: + assert (covariance_matrix is not None) + (scale_tril is not None) + ( + precision_matrix is not None + ) == 1, ( + "Exactly one of covariance_matrix or precision_matrix or scale_tril may be specified." + ) + + param = next( + p + for p in (covariance_matrix, precision_matrix, scale_tril) + if p is not None + ) + + if param.dim() < 2: + raise ValueError( + "scale_tril must be at least two-dimensional, with optional leading batch dimensions" + ) + + if isinstance(df, _Number): + batch_shape = torch.Size(param.shape[:-2]) + self.df = torch.tensor(df, dtype=param.dtype, device=param.device) + else: + batch_shape = torch.broadcast_shapes(param.shape[:-2], df.shape) + self.df = df.expand(batch_shape) + event_shape = param.shape[-2:] + + if self.df.le(event_shape[-1] - 1).any(): + raise ValueError( + f"Value of df={df} expected to be greater than ndim - 1 = {event_shape[-1] - 1}." + ) + + if scale_tril is not None: + self.scale_tril = param.expand(batch_shape + (-1, -1)) + elif covariance_matrix is not None: + self.covariance_matrix = param.expand(batch_shape + (-1, -1)) + elif precision_matrix is not None: + self.precision_matrix = param.expand(batch_shape + (-1, -1)) + + if self.df.lt(event_shape[-1]).any(): + warnings.warn( + "Low df values detected. Singular samples are highly likely to occur for ndim - 1 < df < ndim." + ) + + super().__init__(batch_shape, event_shape, validate_args=validate_args) + self._batch_dims = [-(x + 1) for x in range(len(self._batch_shape))] + + if scale_tril is not None: + self._unbroadcasted_scale_tril = scale_tril + elif covariance_matrix is not None: + self._unbroadcasted_scale_tril = torch.linalg.cholesky(covariance_matrix) + else: # precision_matrix is not None + self._unbroadcasted_scale_tril = _precision_to_scale_tril(precision_matrix) + + # Chi2 distribution is needed for Bartlett decomposition sampling + self._dist_chi2 = torch.distributions.chi2.Chi2( + df=( + self.df.unsqueeze(-1) + - torch.arange( + self._event_shape[-1], + dtype=self._unbroadcasted_scale_tril.dtype, + device=self._unbroadcasted_scale_tril.device, + ).expand(batch_shape + (-1,)) + ) + ) + + def expand(self, batch_shape, _instance=None): + new = self._get_checked_instance(Wishart, _instance) + batch_shape = torch.Size(batch_shape) + cov_shape = batch_shape + self.event_shape + new._unbroadcasted_scale_tril = self._unbroadcasted_scale_tril.expand(cov_shape) + new.df = self.df.expand(batch_shape) + + new._batch_dims = [-(x + 1) for x in range(len(batch_shape))] + + if "covariance_matrix" in self.__dict__: + new.covariance_matrix = self.covariance_matrix.expand(cov_shape) + if "scale_tril" in self.__dict__: + new.scale_tril = self.scale_tril.expand(cov_shape) + if "precision_matrix" in self.__dict__: + new.precision_matrix = self.precision_matrix.expand(cov_shape) + + # Chi2 distribution is needed for Bartlett decomposition sampling + new._dist_chi2 = torch.distributions.chi2.Chi2( + df=( + new.df.unsqueeze(-1) + - torch.arange( + self.event_shape[-1], + dtype=new._unbroadcasted_scale_tril.dtype, + device=new._unbroadcasted_scale_tril.device, + ).expand(batch_shape + (-1,)) + ) + ) + + super(Wishart, new).__init__(batch_shape, self.event_shape, validate_args=False) + new._validate_args = self._validate_args + return new + + @lazy_property + def scale_tril(self) -> Tensor: + return self._unbroadcasted_scale_tril.expand( + self._batch_shape + self._event_shape + ) + + @lazy_property + def covariance_matrix(self) -> Tensor: + return ( + self._unbroadcasted_scale_tril + @ self._unbroadcasted_scale_tril.transpose(-2, -1) + ).expand(self._batch_shape + self._event_shape) + + @lazy_property + def precision_matrix(self) -> Tensor: + identity = torch.eye( + self._event_shape[-1], + device=self._unbroadcasted_scale_tril.device, + dtype=self._unbroadcasted_scale_tril.dtype, + ) + return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand( + self._batch_shape + self._event_shape + ) + + @property + def mean(self) -> Tensor: + return self.df.view(self._batch_shape + (1, 1)) * self.covariance_matrix + + @property + def mode(self) -> Tensor: + factor = self.df - self.covariance_matrix.shape[-1] - 1 + factor[factor <= 0] = nan + return factor.view(self._batch_shape + (1, 1)) * self.covariance_matrix + + @property + def variance(self) -> Tensor: + V = self.covariance_matrix # has shape (batch_shape x event_shape) + diag_V = V.diagonal(dim1=-2, dim2=-1) + return self.df.view(self._batch_shape + (1, 1)) * ( + V.pow(2) + torch.einsum("...i,...j->...ij", diag_V, diag_V) + ) + + def _bartlett_sampling(self, sample_shape=torch.Size()): + p = self._event_shape[-1] # has singleton shape + + # Implemented Sampling using Bartlett decomposition + noise = _clamp_above_eps( + self._dist_chi2.rsample(sample_shape).sqrt() + ).diag_embed(dim1=-2, dim2=-1) + + i, j = torch.tril_indices(p, p, offset=-1) + noise[..., i, j] = torch.randn( + torch.Size(sample_shape) + self._batch_shape + (int(p * (p - 1) / 2),), + dtype=noise.dtype, + device=noise.device, + ) + chol = self._unbroadcasted_scale_tril @ noise + return chol @ chol.transpose(-2, -1) + + def rsample( + self, sample_shape: _size = torch.Size(), max_try_correction=None + ) -> Tensor: + r""" + .. warning:: + In some cases, sampling algorithm based on Bartlett decomposition may return singular matrix samples. + Several tries to correct singular samples are performed by default, but it may end up returning + singular matrix samples. Singular samples may return `-inf` values in `.log_prob()`. + In those cases, the user should validate the samples and either fix the value of `df` + or adjust `max_try_correction` value for argument in `.rsample` accordingly. + """ + + if max_try_correction is None: + max_try_correction = 3 if torch._C._get_tracing_state() else 10 + + sample_shape = torch.Size(sample_shape) + sample = self._bartlett_sampling(sample_shape) + + # Below part is to improve numerical stability temporally and should be removed in the future + is_singular = self.support.check(sample) + if self._batch_shape: + is_singular = is_singular.amax(self._batch_dims) + + if torch._C._get_tracing_state(): + # Less optimized version for JIT + for _ in range(max_try_correction): + sample_new = self._bartlett_sampling(sample_shape) + sample = torch.where(is_singular, sample_new, sample) + + is_singular = ~self.support.check(sample) + if self._batch_shape: + is_singular = is_singular.amax(self._batch_dims) + + else: + # More optimized version with data-dependent control flow. + if is_singular.any(): + warnings.warn("Singular sample detected.") + + for _ in range(max_try_correction): + sample_new = self._bartlett_sampling(is_singular[is_singular].shape) + sample[is_singular] = sample_new + + is_singular_new = ~self.support.check(sample_new) + if self._batch_shape: + is_singular_new = is_singular_new.amax(self._batch_dims) + is_singular[is_singular.clone()] = is_singular_new + + if not is_singular.any(): + break + + return sample + + def log_prob(self, value): + if self._validate_args: + self._validate_sample(value) + nu = self.df # has shape (batch_shape) + p = self._event_shape[-1] # has singleton shape + return ( + -nu + * ( + p * _log_2 / 2 + + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) + .log() + .sum(-1) + ) + - torch.mvlgamma(nu / 2, p=p) + + (nu - p - 1) / 2 * torch.linalg.slogdet(value).logabsdet + - torch.cholesky_solve(value, self._unbroadcasted_scale_tril) + .diagonal(dim1=-2, dim2=-1) + .sum(dim=-1) + / 2 + ) + + def entropy(self): + nu = self.df # has shape (batch_shape) + p = self._event_shape[-1] # has singleton shape + return ( + (p + 1) + * ( + p * _log_2 / 2 + + self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1) + .log() + .sum(-1) + ) + + torch.mvlgamma(nu / 2, p=p) + - (nu - p - 1) / 2 * _mvdigamma(nu / 2, p=p) + + nu * p / 2 + ) + + @property + def _natural_params(self) -> tuple[Tensor, Tensor]: + nu = self.df # has shape (batch_shape) + p = self._event_shape[-1] # has singleton shape + return -self.precision_matrix / 2, (nu - p - 1) / 2 + + def _log_normalizer(self, x, y): + p = self._event_shape[-1] + return (y + (p + 1) / 2) * ( + -torch.linalg.slogdet(-2 * x).logabsdet + _log_2 * p + ) + torch.mvlgamma(y + (p + 1) / 2, p=p) diff --git a/phivenv/Lib/site-packages/torch/export/__init__.py b/phivenv/Lib/site-packages/torch/export/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4b24f74db6053372e3109dfdb8f1f6458536ae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/__init__.py @@ -0,0 +1,605 @@ +import builtins +import copy +import dataclasses +import inspect +import os +import sys +import typing +import warnings +import zipfile +from collections.abc import Iterator +from enum import auto, Enum +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +import torch.utils._pytree as pytree +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager +from torch.types import FileLike +from torch.utils._pytree import ( + FlattenFunc, + FromDumpableContextFn, + ToDumpableContextFn, + UnflattenFunc, +) + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # Do not import unconditionally, as they import sympy and importing sympy is very slow + from torch._ops import OpOverload + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + + +__all__ = [ + "Constraint", + "Dim", + "ExportBackwardSignature", + "ExportGraphSignature", + "ExportedProgram", + "CustomDecompTable", + "ModuleCallEntry", + "ModuleCallSignature", + "default_decompositions", + "dims", + "export", + "export_for_training", + "load", + "register_dataclass", + "save", + "unflatten", + "FlatArgsAdapter", + "UnflattenedModule", + "AdditionalInputs", + "draft_export", +] + +# To make sure export specific custom ops are loaded +import torch.export.custom_ops + +from .decomp_utils import CustomDecompTable +from .dynamic_shapes import AdditionalInputs, Constraint, Dim, dims, ShapesCollection +from .exported_program import ( + default_decompositions, + ExportedProgram, + ModuleCallEntry, + ModuleCallSignature, +) +from .graph_signature import ExportBackwardSignature, ExportGraphSignature +from .unflatten import FlatArgsAdapter, unflatten, UnflattenedModule + + +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +def export_for_training( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + strict: bool = False, + preserve_module_call_signature: tuple[str, ...] = (), +) -> ExportedProgram: + """ + :func:`export_for_training` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the all ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. This API is intended for PT2 quantization training use cases + and will soon be the default IR of torch.export.export in the near future. To read further about + the motivation behind this change, please refer to + https://dev-discuss.pytorch.org/t/why-pytorch-does-not-need-a-new-standardized-operator-set/2206 + With this API, and :func:`run_decompositions()`, you should be able to get inference IR with + your custom decomposition behaviour. + + **Soundness Guarantee** + + See :func:`export()` docstring for more details. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When enabled (default), the export function will trace the program through + TorchDynamo which will ensure the soundness of the resulting graph. Otherwise, the + exported program will not validate the implicit assumptions baked into the graph and + may cause behavior divergence between the original model and the exported one. This is + useful when users need to workaround bugs in the tracer, or simply want incrementally + enable safety in their models. Note that this does not affect the resulting IR spec + to be different and the model will be serialized in the same way regardless of what value + is passed here. + WARNING: This option is experimental and use this at your own risk. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. The metadata will be used when calling + torch.export.unflatten to preserve the original calling conventions of modules. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export_for_training + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) + return _export_for_training( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + ) + + +def export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + strict: bool = False, + preserve_module_call_signature: tuple[str, ...] = (), +) -> ExportedProgram: + """ + :func:`export` takes any nn.Module along with example inputs, and produces a traced graph representing + only the Tensor computation of the function in an Ahead-of-Time (AOT) fashion, + which can subsequently be executed with different inputs or serialized. The + traced graph (1) produces normalized operators in the functional ATen operator set + (as well as any user-specified custom operators), (2) has eliminated all Python control + flow and data structures (with certain exceptions), and (3) records the set of + shape constraints needed to show that this normalization and control-flow elimination + is sound for future inputs. + + **Soundness Guarantee** + + While tracing, :func:`export()` takes note of shape-related assumptions + made by the user program and the underlying PyTorch operator kernels. + The output :class:`ExportedProgram` is considered valid only when these + assumptions hold true. + + Tracing makes assumptions on the shapes (not values) of input tensors. + Such assumptions must be validated at graph capture time for :func:`export` + to succeed. Specifically: + + - Assumptions on static shapes of input tensors are automatically validated without additional effort. + - Assumptions on dynamic shape of input tensors require explicit specification + by using the :func:`Dim` API to construct dynamic dimensions and by associating + them with example inputs through the ``dynamic_shapes`` argument. + + If any assumption can not be validated, a fatal error will be raised. When that happens, + the error message will include suggested fixes to the specification that are needed + to validate the assumptions. For example :func:`export` might suggest the + following fix to the definition of a dynamic dimension ``dim0_x``, say appearing in the + shape associated with input ``x``, that was previously defined as ``Dim("dim0_x")``:: + + dim = Dim("dim0_x", max=5) + + This example means the generated code requires dimension 0 of input ``x`` to be less + than or equal to 5 to be valid. You can inspect the suggested fixes to dynamic dimension + definitions and then copy them verbatim into your code without needing to change the + ``dynamic_shapes`` argument to your :func:`export` call. + + Args: + mod: We will trace the forward method of this module. + + args: Example positional inputs. + + kwargs: Optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + strict: When disabled (default), the export function will trace the program through + Python runtime, which by itself will not validate some of the implicit assumptions + baked into the graph. It will still validate most critical assumptions like shape + safety. When enabled (by setting ``strict=True``), the export function will trace + the program through TorchDynamo which will ensure the soundness of the resulting + graph. TorchDynamo has limited Python feature coverage, thus you may experience more + errors. Note that toggling this argument does not affect the resulting IR spec to be + different and the model will be serialized in the same way regardless of what value + is passed here. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. The metadata will be used when calling + torch.export.unflatten to preserve the original calling conventions of modules. + + Returns: + An :class:`ExportedProgram` containing the traced callable. + + **Acceptable input/output types** + + Acceptable types of inputs (for ``args`` and ``kwargs``) and outputs include: + + - Primitive types, i.e. ``torch.Tensor``, ``int``, ``float``, ``bool`` and ``str``. + - Dataclasses, but they must be registered by calling :func:`register_dataclass` first. + - (Nested) Data structures comprising of ``dict``, ``list``, ``tuple``, ``namedtuple`` and + ``OrderedDict`` containing all above types. + + """ + from ._trace import _export + + if not isinstance(mod, torch.nn.Module): + raise ValueError( + f"Expected `mod` to be an instance of `torch.nn.Module`, got {type(mod)}." + ) + if isinstance(mod, torch.jit.ScriptModule): + raise ValueError( + "Exporting a ScriptModule is not supported. " + "Maybe try converting your ScriptModule to an ExportedProgram " + "using `TS2EPConverter(mod, args, kwargs).convert()` instead." + ) + + try: + return _export( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + pre_dispatch=True, + ) + except Exception as e: + draft_export_msg = ( + "The error above occurred when calling torch.export.export. If you would " + "like to view some more information about this error, and get a list " + "of all other errors that may occur in your export call, you can " + "replace your `export()` call with `draft_export()`." + ) + + # For errors that we know can be caught by draft-export, add the message + # to ask users to try out draft-export + if isinstance( + e, + ( + torch.fx.experimental.symbolic_shapes.GuardOnDataDependentSymNode, + torch._subclasses.fake_tensor.UnsupportedOperatorException, + torch._dynamo.exc.UserError, + torch.fx.experimental.symbolic_shapes.ConstraintViolationError, + ), + ): + new_msg = str(e) + "\n\n" + draft_export_msg + e.args = (new_msg,) + elif isinstance(e, RuntimeError) and "no fake impl registered" in str(e): + new_msg = str(e) + "\n\n" + draft_export_msg + e.args = (new_msg,) + raise e + + +DEFAULT_PICKLE_PROTOCOL = 2 + + +def save( + ep: ExportedProgram, + f: FileLike, + *, + extra_files: Optional[dict[str, Any]] = None, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + """ + + .. warning:: + Under active development, saved files may not be usable in newer versions + of PyTorch. + + Saves an :class:`ExportedProgram` to a file-like object. It can then be + loaded using the Python API :func:`torch.export.load `. + + Args: + ep (ExportedProgram): The exported program to save. + + f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to + implement write and flush) or a string containing a file name. + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of f. + + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + pickle_protocol: can be specified to override the default protocol + + Example:: + + import torch + import io + + + class MyModule(torch.nn.Module): + def forward(self, x): + return x + 10 + + + ep = torch.export.export(MyModule(), (torch.randn(5),)) + + # Save to file + torch.export.save(ep, "exported_program.pt2") + + # Save to io.BytesIO buffer + buffer = io.BytesIO() + torch.export.save(ep, buffer) + + # Save with extra files + extra_files = {"foo.txt": b"bar".decode("utf-8")} + torch.export.save(ep, "exported_program.pt2", extra_files=extra_files) + + """ + if not isinstance(ep, ExportedProgram): + raise TypeError( + f"The 'ep' parameter must be an instance of 'ExportedProgram', got '{type(ep).__name__}' instead." + ) + + from torch.export.pt2_archive._package import package_pt2 + + package_pt2( + f, + exported_programs={"model": ep}, + extra_files=extra_files, + pickle_protocol=pickle_protocol, + opset_version=opset_version, + ) + + +def load( + f: FileLike, + *, + extra_files: Optional[dict[str, Any]] = None, + expected_opset_version: Optional[dict[str, int]] = None, +) -> ExportedProgram: + """ + + .. warning:: + Under active development, saved files may not be usable in newer versions + of PyTorch. + + Loads an :class:`ExportedProgram` previously saved with + :func:`torch.export.save `. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + extra_files (Optional[Dict[str, Any]]): The extra filenames given in + this map would be loaded and their content would be stored in the + provided map. + + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + Returns: + An :class:`ExportedProgram` object + + Example:: + + import torch + import io + + # Load ExportedProgram from file + ep = torch.export.load("exported_program.pt2") + + # Load ExportedProgram from io.BytesIO object + with open("exported_program.pt2", "rb") as f: + buffer = io.BytesIO(f.read()) + buffer.seek(0) + ep = torch.export.load(buffer) + + # Load with extra files. + extra_files = {"foo.txt": ""} # values will be replaced with data + ep = torch.export.load("exported_program.pt2", extra_files=extra_files) + print(extra_files["foo.txt"]) + print(ep(torch.randn(5))) + """ + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + extra_files = extra_files or {} + + from torch.export.pt2_archive._package import load_pt2, PT2ArchiveContents + + try: + pt2_contents = load_pt2( + f, + expected_opset_version=expected_opset_version, + ) + except RuntimeError: + pt2_contents = PT2ArchiveContents({}, {}, {}) + + if len(pt2_contents.exported_programs) > 0 or len(pt2_contents.extra_files) > 0: + for k, v in pt2_contents.extra_files.items(): + extra_files[k] = v + + return pt2_contents.exported_programs["model"] + + # TODO: For backward compatibility, we support loading a zip file from 2.7. Delete this path in 2.9(?) + warnings.warn( + "This version of file is deprecated. Please generate a new pt2 saved file." + ) + with zipfile.ZipFile(f, "r") as zipf: + # Check the version + version = zipf.read("version").decode().split(".") + from torch._export.serde.schema import ( + SCHEMA_VERSION, # todo change archive version to schema version + ) + + assert len(version) == len(SCHEMA_VERSION), ( + "Version in the saved file has incorrect length, double check if the file is generated by torch.export.save()" + ) + if version[0] != str(SCHEMA_VERSION[0]): + raise RuntimeError( + f"Serialized version {version} does not match our current " + f"schema version {SCHEMA_VERSION}." + ) + + from torch._export.serde.serialize import deserialize, SerializedArtifact + + # Load serialized_ep and serialized_state_dict from the zip file + + serialized_exported_program: Optional[bytes] = None + serialized_state_dict: Optional[bytes] = None + serialized_constants: Optional[bytes] = None + serialized_example_inputs: Optional[bytes] = None + + for file_info in zipf.infolist(): + file_content = zipf.read(file_info.filename) + + if file_info.filename == "serialized_exported_program.json": + serialized_exported_program = file_content + elif file_info.filename == "serialized_state_dict.json": + warnings.warn("This version of file is deprecated") + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.json": + warnings.warn("This version of file is deprecated") + serialized_constants = file_content + elif file_info.filename == "serialized_state_dict.pt": + serialized_state_dict = file_content + elif file_info.filename == "serialized_constants.pt": + serialized_constants = file_content + elif file_info.filename == "serialized_example_inputs.pt": + serialized_example_inputs = file_content + elif file_info.filename.startswith("extra_files"): + filename = file_info.filename.split("/", 1)[1] + extra_files[filename] = file_content.decode("utf-8") + + assert serialized_exported_program is not None + assert serialized_state_dict is not None + assert serialized_constants is not None + assert serialized_example_inputs is not None + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_state_dict, + serialized_constants, + serialized_example_inputs, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + + return ep + + +def draft_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + preserve_module_call_signature: tuple[str, ...] = (), + strict: bool = False, +) -> ExportedProgram: + """ + A version of torch.export.export which is designed to consistently produce + an ExportedProgram, even if there are potential soundness issues, and to + generate a report listing the issues found. + """ + from ._draft_export import draft_export + + return draft_export( + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + strict=strict, + ) + + +def register_dataclass( + cls: type[Any], + *, + serialized_type_name: Optional[str] = None, +) -> None: + """ + Registers a dataclass as a valid input/output type for :func:`torch.export.export`. + + Args: + cls: the dataclass type to register + serialized_type_name: The serialized name for the dataclass. This is + required if you want to serialize the pytree TreeSpec containing this + dataclass. + + Example:: + + import torch + from dataclasses import dataclass + + + @dataclass + class InputDataClass: + feature: torch.Tensor + bias: int + + + @dataclass + class OutputDataClass: + res: torch.Tensor + + + torch.export.register_dataclass(InputDataClass) + torch.export.register_dataclass(OutputDataClass) + + + class Mod(torch.nn.Module): + def forward(self, x: InputDataClass) -> OutputDataClass: + res = x.feature + x.bias + return OutputDataClass(res=res) + + + ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),)) + print(ep) + + """ + pytree.register_dataclass(cls, serialized_type_name=serialized_type_name) diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69186e20358b9f8c64a83c0255b202fc49dffe4a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_draft_export.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_draft_export.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25c59df6bd5c43798cfbe043a12b927b36af12ce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_draft_export.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1899b4596d33804cd197a1a9ecf66f3977d7a3eb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_remove_auto_functionalized_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1806eb769be5f0b546d86c85cadac786556f8caa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_remove_effect_tokens_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_safeguard.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_safeguard.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df125d3f5b2a895dcca2b6e58820eaa8bf53818b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_safeguard.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_swap.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_swap.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..973175e3c5474b591ade51075bd53b2e6ed91f20 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_swap.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_trace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03469bfb352717f61a874c979f998c71da105dd4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_tree_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_tree_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f10a8a6eb6e88331b5d46b6ce2e41d8670d9e44c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_tree_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_unlift.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_unlift.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2177fa0a05a5b07cfad19b26f3bb20bafbd30019 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_unlift.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83fab5c2e4ca23cec74c4ed3b62f47d41adf376c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/_wrapper_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/custom_obj.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/custom_obj.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a942e2a373bd7173b72b33ca2da4b31e64e3f51c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/custom_obj.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/custom_ops.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/custom_ops.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f08aa6c0c089ceba5075cfb95a4bd58102c1565 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/custom_ops.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/decomp_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/decomp_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee11994cd02b600a05eb2a30cd8e8baee9e8116f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/decomp_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8611851e0345bd2880ed6d52ccac7b3bb48cf681 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/dynamic_shapes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/exported_program.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/exported_program.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e3d02295ebfc0c5c5a2161d5fc649b9e3edd08e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/exported_program.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/graph_signature.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/graph_signature.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51245776d3ec96d5d5b9a0fad68d2e93868dfd18 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/graph_signature.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/__pycache__/unflatten.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/__pycache__/unflatten.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c5e1ce0a9178f4dc2836a5e3093719721b64ae5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/__pycache__/unflatten.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/_draft_export.py b/phivenv/Lib/site-packages/torch/export/_draft_export.py new file mode 100644 index 0000000000000000000000000000000000000000..d2856a577dc72fe8fafbb103d886a16e7440ce77 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_draft_export.py @@ -0,0 +1,512 @@ +import getpass +import json +import logging +import os +import re +import tempfile +from dataclasses import dataclass +from enum import IntEnum +from typing import Any, Callable, Optional, Union + +import torch +import torch._logging._internal +import torch._logging.structured +import torch.utils._pytree as pytree +from torch._export.passes.insert_custom_op_guards import ( + get_op_profiles, + insert_custom_op_guards, + OpProfile, +) + +from ._trace import _export +from .dynamic_shapes import _DimHint, _DimHintType, Dim +from .exported_program import ExportedProgram + + +log = logging.getLogger(__name__) + + +class FailureType(IntEnum): + MISSING_FAKE_KERNEL = 1 + DATA_DEPENDENT_ERROR = 2 + GUARD_ADDED = 3 + MISMATCHED_FAKE_KERNEL = 4 + + def __str__(self) -> str: + return self.name + + +def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str]) -> str: + res = "" + for frame in stack: + if frame["filename"] not in str_to_filename: + continue + + res += f""" + File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}""" # type: ignore[index] + + res += f"\n {stack[-1]['loc']}" + return res + + +def prettify_frame_locals( + loc: str, locals: dict[str, Any], symbols: dict[str, Any] +) -> str: + local_str = "\n".join(f" {k}: {v}" for k, v in locals.items()) + res = f""" + Locals: +{local_str} +""" + if any(v is not None for v in symbols.values()): + symbol_str = "\n".join( + f" {k}: {v}" for k, v in symbols.items() if v is not None + ) + res += f""" + Symbols: +{symbol_str} +""" + return res + + +def get_loc(filename: str, lineno: int) -> Optional[str]: + try: + with open(filename) as f: + for i, line in enumerate(f): + if i == lineno - 1: + return line.strip() + except FileNotFoundError: + pass + return None + + +class FailureReport: + def __init__( + self, failure_type: FailureType, data: dict[str, Any], xfail: bool = False + ) -> None: + self.failure_type: FailureType = failure_type + self.data: dict[str, Any] = data + self.xfail: bool = xfail + + def __repr__(self) -> str: + return f"FailureReport(failure_type={self.failure_type}, xfail={self.xfail}, data={self.data})" + + def print(self, str_to_filename: dict[int, str]) -> str: + if self.failure_type == FailureType.MISSING_FAKE_KERNEL: + op = self.data["op"] + + return f"""Missing fake kernel. + torch.ops.{op} is missing a fake kernel implementation. + + Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a meta implementation. +""" # noqa: B950 + + elif self.failure_type == FailureType.GUARD_ADDED: + locals_info = ( + prettify_frame_locals(**self.data["frame_locals"]) + if self.data["frame_locals"] + else "" + ) + return f"""Guard Added. + A guard was added during tracing, which might've resulted in some incorrect + tracing or constraint violation error. + Specifically, this guard was added: {self.data["expr"]}, where {self.data["symbol_to_sources"]}. + This occurred at the following stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)}: + {locals_info} + And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n + Because of this, we have modified the dynamic shapes structure to be the + following. You can also use torch.export.Dim.AUTO instead to specify your + dynamic shapes, and we will automatically infer the dynamism for you. + ``` + dynamic_shapes = {self.data["new_dynamic_shapes"]} + ``` +""" + + elif self.failure_type == FailureType.DATA_DEPENDENT_ERROR: + locals_info = ( + prettify_frame_locals(**self.data["frame_locals"]) + if self.data["frame_locals"] + else "" + ) + return f"""Data dependent error. + When exporting, we were unable to evaluate the value of `{self.data["expr"]}`. + This was encountered {self.data["occurrences"]} times. + This occurred at the following user stacktrace: {prettify_stack(self.data["user_stack"], str_to_filename)} + {locals_info} + And the following framework stacktrace: {prettify_stack(self.data["stack"], str_to_filename)}\n + As a result, it was specialized to a constant (e.g. `{self.data["result"]}` in the 1st occurrence), and asserts were inserted into the graph. + + Please add `torch._check(...)` to the original code to assert this data-dependent assumption. + Please refer to https://docs.google.com/document/d/1kZ_BbB3JnoLbUZleDT6635dHs88ZVYId8jT-yTFgf3A/edit#heading=h.boi2xurpqa0o for more details. +""" # noqa: B950 + + elif self.failure_type == FailureType.MISMATCHED_FAKE_KERNEL: + op = self.data["op"] + reason = self.data["reason"] + return f"""Mismatched fake kernel. + torch.ops.{op} has a fake kernel implementation, but it has incorrect behavior, based on the real kernel. + The reason for the mismatch is: {reason}. + + Please refer to https://docs.google.com/document/d/1_W62p8WJOQQUzPsJYa7s701JXt0qf2OfLub2sbkHOaU/edit#heading=h.ahugy69p2jmz for more detailed instructions on how to write a fake implementation. +""" # noqa: B950 + + else: + raise ValueError(f"Unknown failure type: {self.failure_type}") + + +class DraftExportReport: + def __init__( + self, + failures: list[FailureReport], + str_to_filename: dict[int, str], + expressions_created: dict[int, dict[str, Any]], + op_profiles: dict[str, set[OpProfile]], + ): + self.failures: list[FailureReport] = failures + self.str_to_filename = str_to_filename + self.expressions_created: dict[int, dict[str, Any]] = expressions_created + self.op_profiles = op_profiles + + def successful(self) -> bool: + return len(self.failures) == 0 or all( + failure.xfail for failure in self.failures + ) + + def __repr__(self) -> str: + return f"DraftExportReport({self.failures})" + + def __str__(self) -> str: + WARNING_COLOR = "\033[93m" + GREEN_COLOR = "\033[92m" + END_COLOR = "\033[0m" + + if self.successful(): + return f"""{GREEN_COLOR} +############################################################################################## +Congratuations: No issues are found during export, and it was able to soundly produce a graph. +You can now change back to torch.export.export() +############################################################################################## +{END_COLOR}""" + + error = f"""{WARNING_COLOR} +################################################################################################### +WARNING: {len(self.failures)} issue(s) found during export, and it was not able to soundly produce a graph. +Please follow the instructions to fix the errors. +################################################################################################### + +""" + + for i, failure in enumerate(self.failures): + error += f"{i + 1}. {failure.print(self.str_to_filename)}\n" + error += END_COLOR + return error + + def apply_suggested_fixes(self) -> None: + raise NotImplementedError("Not implemented yet") + + +@dataclass +class ExpressionCreatedNode: + result_id: int + argument_ids: list[int] + record: dict[str, object] + visited: bool = False + + +class LogRecord: + def __init__(self) -> None: + self.log_count: dict[int, int] = {} + self.logs: list[tuple[str, dict[str, Any]]] = [] + + def _hash(self, element: tuple[str, dict[str, Any]]) -> int: + key, data = element + + if key == "missing_fake_kernel": + return hash((key, data["op"])) + elif key == "mismatched_fake_kernel": + return hash((key, data["op"], data["reason"])) + elif key == "propagate_real_tensors_provenance": + return hash((key, json.dumps(data["user_stack"]))) + elif key == "guard_added": + return hash((key, json.dumps(data["user_stack"]))) + elif key == "create_unbacked_symbol": + return hash((key, json.dumps(data["user_stack"]))) + + return hash((key, json.dumps(data))) + + def try_add(self, element: tuple[str, dict[str, str]]) -> bool: + hash_value = self._hash(element) + if hash_value in self.log_count: + self.log_count[hash_value] += 1 + return False + + self.log_count[hash_value] = 1 + self.logs.append(element) + return True + + def get_log_count(self, element: tuple[str, dict[str, Any]]) -> int: + return self.log_count[self._hash(element)] + + +class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler): + def __init__(self) -> None: + self.specific_log_keys = [ + "str", + "exported_program", + "propagate_real_tensors_provenance", + "guard_added", + "missing_fake_kernel", + "mismatched_fake_kernel", + "expression_created", + "create_unbacked_symbol", + ] + self.log_record: LogRecord = LogRecord() + self.expression_created_logs: dict[int, ExpressionCreatedNode] = {} + self.symbol_to_expressions: dict[str, list[dict[str, Any]]] = {} + self.logger = logging.getLogger("torch.__trace") + self.prev_get_dtrace = False + + if root_dir := os.environ.get(torch._logging._internal.DTRACE_ENV_VAR): + super().__init__(root_dir) + else: + sanitized_username = re.sub(r'[\\/:*?"<>|]', "_", getpass.getuser()) + root_dir = os.path.join( + tempfile.gettempdir(), + "export_" + sanitized_username, + ) + super().__init__(root_dir) + + self.setFormatter(torch._logging._internal.TorchLogsFormatter(trace=True)) + + def __enter__(self) -> "CaptureStructuredTrace": + self.log_record = LogRecord() + self.expression_created_logs = {} + + # Remove the lazy trace handler if it exists + possible_lazy_trace_handlers = [ + handler + for handler in self.logger.handlers + if isinstance(handler, torch._logging._internal.LazyTraceHandler) + ] + for handler in possible_lazy_trace_handlers: + self.logger.removeHandler(handler) + + self.logger.addHandler(self) + self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED + torch._logging._internal.GET_DTRACE_STRUCTURED = True + return self + + def __exit__(self, exc_type, exc_value, traceback) -> None: # type: ignore[no-untyped-def] + self.log_record = LogRecord() + self.expression_created_logs = {} + self.logger.removeHandler(self) + torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace + self.prev_get_dtrace = False + + def emit(self, record: Any) -> None: + def _log_expression_created( + emit_func: Callable[[Any], None], sym_node_id: int + ) -> None: + # Log all the relevant expression_created logs + if sym_node_id is None: + return + if res := self.expression_created_logs.get(sym_node_id, None): + # Don't log the expression if we have already + # printed it beforehand + if not res.visited: + res.visited = True + for arg in res.argument_ids: + _log_expression_created(emit_func, arg) + + emit_func(res.record) + + metadata = record.metadata + for key in self.specific_log_keys: + if key in metadata: + if self.log_record.try_add((key, metadata[key])): + if key == "expression_created": + # We don't want to log all expression_created logs, only + # the ones that are relevant to the + # guards/propagate_real_tensor + self.expression_created_logs[metadata[key]["result_id"]] = ( + ExpressionCreatedNode( + metadata[key]["result_id"], + metadata[key].get("argument_ids", []), + record, + ) + ) + return + + elif key == "propagate_real_tensors_provenance": + _log_expression_created( + super().emit, metadata[key].get("expr_node_id") + ) + + elif key == "guard_added": + if len(metadata[key]["symbol_to_sources"]) == 0: + # We only want to include guards added that are relevant to + # the symbolic shapes corresponding to the inputs which were + # specified in the dynamic_shapes arg. These have a source. + return + elif metadata[key]["prefix"] == "runtime_assert": + # This should've been captured by a + # propagate_real_tensors log + return + + _log_expression_created( + super().emit, metadata[key].get("expr_node_id") + ) + + super().emit(record) + + +def draft_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + *, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + preserve_module_call_signature: tuple[str, ...] = (), + strict: bool = False, + pre_dispatch: bool = True, +) -> ExportedProgram: + kwargs = kwargs or {} + dynamic_shapes = dynamic_shapes or {} + + capture_structured_log = CaptureStructuredTrace() + + with ( + torch._functorch.config.patch( + fake_tensor_propagate_real_tensors=True, + generate_fake_kernels_from_real_mismatches=True, + ), + capture_structured_log, + ): + try: + new_shapes = None + ep = _export( + mod, + args, + kwargs, + dynamic_shapes=dynamic_shapes, + strict=strict, + pre_dispatch=pre_dispatch, + preserve_module_call_signature=preserve_module_call_signature, + ) + except torch._dynamo.exc.UserError: + + def convert_dim_to_auto(dim: Any) -> Any: + if isinstance(dim, Dim): + return Dim.AUTO(min=dim.min, max=dim.max) + elif isinstance(dim, _DimHint) and dim.type == _DimHintType.DYNAMIC: + return Dim.AUTO(min=dim.min, max=dim.max) + return dim + + new_shapes = pytree.tree_map(convert_dim_to_auto, dynamic_shapes) + ep = _export( + mod, + args, + kwargs, + dynamic_shapes=new_shapes, + strict=strict, + pre_dispatch=pre_dispatch, + preserve_module_call_signature=preserve_module_call_signature, + ) + + torch._logging.dtrace_structured("exported_program", payload_fn=lambda: str(ep)) + + str_to_filename: dict[int, str] = {} + failures: list[FailureReport] = [] + incorrect_custom_ops: set[str] = set() + expressions_created: dict[int, dict[str, Any]] = {} + + for log_name, log_contents in capture_structured_log.log_record.logs: + failure_type = None + + if log_name == "str": + str_to_filename[log_contents[1]] = log_contents[0] # type: ignore[index] + continue + + elif log_name == "propagate_real_tensors_provenance": + log_contents["occurrences"] = ( + capture_structured_log.log_record.get_log_count( + (log_name, log_contents) + ) + ) + + failure_type = FailureType.DATA_DEPENDENT_ERROR + + elif log_name == "guard_added": + if new_shapes is None: + continue + + failure_type = FailureType.GUARD_ADDED + log_contents["new_dynamic_shapes"] = new_shapes + elif log_name == "missing_fake_kernel": + failure_type = FailureType.MISSING_FAKE_KERNEL + incorrect_custom_ops.add(log_contents["op"]) + + elif log_name == "mismatched_fake_kernel": + failure_type = FailureType.MISMATCHED_FAKE_KERNEL + incorrect_custom_ops.add(log_contents["op"]) + + else: + continue + + assert failure_type is not None + failures.append( + FailureReport( + failure_type, + log_contents, + ) + ) + + for k, v in capture_structured_log.expression_created_logs.items(): + if v.visited: + expressions_created[k] = v.record + + op_profiles = get_op_profiles(ep.graph_module, incorrect_custom_ops) + report = DraftExportReport( + failures, str_to_filename, expressions_created, op_profiles + ) + + # Add asserts around custom ops + insert_custom_op_guards(ep.graph_module, incorrect_custom_ops) + + ep._report = report + if not report.successful(): + log_filename = capture_structured_log.stream.name + + warning_msg = f""" +################################################################################################### +WARNING: {len(report.failures)} issue(s) found during export, and it was not able to soundly produce a graph. +To view the report of failures in an html page, please run the command: + `tlparse {log_filename} --export` +Or, you can view the errors in python by inspecting `print(ep._report)`. +""" + + if len(report.op_profiles) > 0: + warning_msg += f""" +While tracing we found {len(report.op_profiles)} operator(s) which do not have a fake kernel registered. +If you intend to retrace the exported graph or run it with fake tensors, please run it under the +following context manager, which will register a fake kernel for those operators. +``` +with torch._library.fake_profile.unsafe_generate_fake_kernels(ep._report.op_profiles): + # run with fake tensors +``` +""" + + warning_msg += """#################################################################################################""" + + log.warning(warning_msg) + + else: + log.info( + """ +############################################################################################## +Congratuations: No issues are found during export, and it was able to soundly produce a graph. +You can now change back to torch.export.export() +############################################################################################## + """ + ) + + return ep diff --git a/phivenv/Lib/site-packages/torch/export/_remove_auto_functionalized_pass.py b/phivenv/Lib/site-packages/torch/export/_remove_auto_functionalized_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..e45748eccb9e60e1ccc4c2c2355b676ed432e15c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_remove_auto_functionalized_pass.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +import torch +from torch._higher_order_ops.auto_functionalize import ( + auto_functionalized, + auto_functionalized_v2, +) +from torch._inductor.fx_passes.post_grad import decompose_auto_functionalized +from torch.export import ExportedProgram +from torch.fx import Graph + + +def remove_self_clone(graph: Graph) -> None: + for node in graph.nodes: + if node.target == torch.ops.aten.copy_.default and node.args[0] == node.args[1]: + node.replace_all_uses_with(node.args[0]) + graph.erase_node(node) + + +def unsafe_remove_auto_functionalized_pass( + ep: ExportedProgram, +) -> ExportedProgram: + """ + This pass removes an instances of the higher order op 'auto_functionalized', + and modifies the calling EP inplace to have the original mutator op. + This pass doesn't perform safety checks to make sure that this inplace mutation is safe. + """ + + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + for node in ep.graph.nodes: + if ( + node.op == "call_function" and node.target is auto_functionalized + ) or ( + node.op == "call_function" and node.target is auto_functionalized_v2 + ): + func = node.args[0] + assert isinstance(func, torch._ops.OpOverload) + # re-inplace everything + node.meta["only_clone_these_tensors"] = [] + decompose_auto_functionalized(ep.graph) + remove_self_clone(ep.graph) + ep.graph.eliminate_dead_code() + + return ep diff --git a/phivenv/Lib/site-packages/torch/export/_remove_effect_tokens_pass.py b/phivenv/Lib/site-packages/torch/export/_remove_effect_tokens_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..2da1e0d8d31548c99fe8a6347b23e34892ce036f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_remove_effect_tokens_pass.py @@ -0,0 +1,167 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch._higher_order_ops.effects import _get_schema, with_effects + +from .exported_program import ExportedProgram +from .graph_signature import ( + CustomObjArgument, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TokenArgument, +) + + +def _remove_effect_tokens_from_graph_helper( + ep, num_tokens, input_token_names, output_token_names +): + inputs_to_lifted_custom_objs = ep.graph_signature.inputs_to_lifted_custom_objs + + output_node = None + with_effect_nodes: list[torch.fx.Node] = [] + + # Output node need to check its args agianst output_token_names (collected from output_spec) + # Therefore, we only need to find the top-levele output node + output_node = next(reversed(ep.graph_module.graph.find_nodes(op="output"))) + for module in ep.graph_module.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + for node in module.graph.nodes: + if not (node.op == "call_function" and node.target is with_effects): + continue + + with_effect_nodes.append(node) + + # Remove tokens from outputs + assert output_node is not None + output_args = output_node.args[0] + assert len(output_args) >= num_tokens + out_token_nodes = output_args[:num_tokens] + output_node.args = (tuple(output_args[num_tokens:]),) + for out_token in out_token_nodes: + assert out_token.name in output_token_names + out_token.users.clear() + ep.graph.erase_node(out_token) + + # Replace with_effects(token, func, args) with just func(args) + for node in reversed(with_effect_nodes): + func = node.args[1] + assert isinstance(func, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)) + + if func == torch.ops.higher_order.call_torchbind: + custom_obj_meta = node.args[2].meta["val"] # type: ignore[union-attr] + assert isinstance(custom_obj_meta, CustomObjArgument) + if custom_obj_meta.fake_val: + custom_obj = custom_obj_meta.fake_val + elif node.args[2].name in inputs_to_lifted_custom_objs: # type: ignore[union-attr] + custom_obj = ep.constants[ + inputs_to_lifted_custom_objs[node.args[2].name] # type: ignore[union-attr] + ] + else: + raise RuntimeError(f"Unable to find custom obj for node {node}") + schema = _get_schema(func, (custom_obj,) + node.args[3:]) + else: + schema = _get_schema(func, node.args[2:]) + + with ep.graph.inserting_before(node): + new_node = ep.graph.call_function(func, node.args[2:], node.kwargs) + for k, v in node.meta.items(): + new_node.meta[k] = v + if k == "unbacked_bindings": + # Remove the extra layer for effect token + old_bindings = new_node.meta[k] + new_bindings = { + k: path[1:] if path else path for k, path in old_bindings.items() + } + new_node.meta[k] = new_bindings + + node.replace_all_uses_with(new_node) + + # Update user getitem nodes + for user in list(new_node.users.keys()): + assert user.target == operator.getitem + # getitem(with_effects, 0) == token + if user.args[1] == 0: + ep.graph.erase_node(user) + + if len(schema.returns) == 1: + # If the function has 1 return then it will just directly return the + # result -- we don't need a getitem. So we can replace all the + # getitem(with_effects, 1) with just the note itself. + for user in list(new_node.users.keys()): + assert user.args[1] == 1 + user.replace_all_uses_with(new_node) + + new_node.meta["val"] = node.meta["val"][1] + elif len(schema.returns) > 1: + # If the function has more than 1 return then since we got rid of + # the 1st return value (the token), we need to bump all the other + # getitem calls by 1 down + for user in list(new_node.users.keys()): + assert user.args[1] >= 1 + user.args = (user.args[0], user.args[1] - 1) + + new_node.meta["val"] = node.meta["val"][1:] + else: + assert len(schema.returns) == 0 + assert len(new_node.users) == 0 + new_node.meta["val"] = None + + ep.graph.erase_node(node) + + # Remove tokens from inputs + placeholders = [node for node in ep.graph.nodes if node.op == "placeholder"] + assert len(placeholders) >= num_tokens + inp_token_nodes = placeholders[:num_tokens] + for inp_token in inp_token_nodes: + assert inp_token.name in input_token_names + ep.graph.erase_node(inp_token) + + ep.graph.eliminate_dead_code() + + +def _remove_effect_tokens(ep: ExportedProgram) -> ExportedProgram: + """ + Removes the existance of tokens from the exported program, including: + - Removes the input and output tokens + - Replaces with_effects(token, func, args) with just func(args) + + This function does an inplace modification on the given ExportedProgram. + """ + num_tokens: int = 0 + input_token_names: list[str] = [] + new_input_specs: list[InputSpec] = [] + for inp in ep.graph_signature.input_specs: + if inp.kind == InputKind.TOKEN: + num_tokens += 1 + assert isinstance(inp.arg, TokenArgument) + input_token_names.append(inp.arg.name) + else: + new_input_specs.append(inp) + + num_out_tokens: int = 0 + new_output_specs: list[OutputSpec] = [] + output_token_names: list[OutputSpec] = [] + for out in ep.graph_signature.output_specs: + if out.kind == OutputKind.TOKEN: + num_out_tokens += 1 + output_token_names.append(out.arg.name) + else: + new_output_specs.append(out) + + # Update graph signature + ep.graph_signature.input_specs = new_input_specs + ep.graph_signature.output_specs = new_output_specs + + assert num_tokens == num_out_tokens + + with ep.graph_module._set_replace_hook(ep.graph_signature.get_replace_hook()): + _remove_effect_tokens_from_graph_helper( + ep, num_tokens, input_token_names, output_token_names + ) + + return ep diff --git a/phivenv/Lib/site-packages/torch/export/_safeguard.py b/phivenv/Lib/site-packages/torch/export/_safeguard.py new file mode 100644 index 0000000000000000000000000000000000000000..24ad485ebe7e2a51425fbfca53da267ea2c45551 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_safeguard.py @@ -0,0 +1,44 @@ +# mypy: allow-untyped-defs +import torch +from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode +from torch.overrides import TorchFunctionMode + + +class AutogradStateOpsFailSafeguard(TorchFunctionMode): + """ + Detect grad state ops during exporting the graph and fail the process by + raising an error, to avoid unexpected behavior. Those grad mode ops could be: + `torch.no_grad` + `torch.enable_grad` + `torch.set_grad_enabled` + + Export with predispatch mode is exempted. + """ + + def __torch_function__(self, func, types, args=(), kwargs=None): + kwargs = kwargs or {} + unsupported_grad_mode_ops = [ + torch._C._set_grad_enabled, + ] + # It's only enabled while tracing, by confirming the torch dispatch mode is + # any active PROXY. This is to allow the autograd ops out of tracing. + current_state = torch._C.is_grad_enabled() + if func in unsupported_grad_mode_ops: + assert len(args) == 1 + changed_state = args[0] + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + # Intend to check if it's not the pre_dispatch mode. It's allowed to use + # autograd ops in pre_dispatch mode, e.g. `torch.no_grad` + if ( + mode + and isinstance(mode, ProxyTorchDispatchMode) + and not mode.pre_dispatch + and changed_state != current_state + ): + raise RuntimeError( + f"Encountered autograd state manager op {func} trying to change global autograd state " + "while exporting. This is unsafe because we don't capture this op in torch.export " + "today, hence we can't reflect the user intention soundly. You can fix this by " + "adding a torch.no_grad() context around the export call." + ) + return func(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/export/_swap.py b/phivenv/Lib/site-packages/torch/export/_swap.py new file mode 100644 index 0000000000000000000000000000000000000000..80679cc1edf7d91a4601ad0a7dbce6228730faa8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_swap.py @@ -0,0 +1,438 @@ +import logging +import operator +import types +from collections import defaultdict +from typing import Optional + +import torch +import torch.fx._pytree as fx_pytree +import torch.utils._pytree as pytree +from torch.export.exported_program import ( + ConstantArgument, + ExportedProgram, + ModuleCallSignature, +) +from torch.fx.passes.tools_common import legalize_graph, NodeList +from torch.fx.passes.utils.fuser_utils import erase_nodes, fuse_as_graphmodule + + +log = logging.getLogger(__name__) + + +def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]: + node_users = list(node.users.keys()) + getitem_users = set() + for user in node_users: + if user.op == "output": + continue + + assert user.op == "call_function" and user.target == operator.getitem, ( + f"Expected getitem node as user for {node}, instead got {user}" + ) + getitem_users.update(list(user.users.keys())) + return getitem_users + + +def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: + """ + We want to try to remove extraneous pytree flatten/unflatten calls between modules + calls. Instead of having the following: + graph(): + ... + %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {}) + %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (%foo, %_spec_1), kwargs = {}) + %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {}) + %tree_unflatten_1 : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%getitem_4], %_spec_2), kwargs = {}) + %getitem_5 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 0), kwargs = {}) + %getitem_7 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten_1, 1), kwargs = {}) + %getitem_6 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem_5, 0), kwargs = {}) + %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {}) + ... + + We could do the following, if we know that all the outputs of `foo` feed into `bar`: + graph(): + ... + %foo : [num_users=1] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {}) + %bar : [num_users=1] = call_module[target=bar](args = (%getitem_6,), kwargs = {}) + ... + + Currently this optimization only works for the case where all of the outputs + of `foo` go directly into `bar`, and `bar` has no other inputs. + """ # noqa: B950 + + log.debug("Trying to remove pytrees for module call %s", curr_module_node) + + curr_module_users = list(curr_module_node.users.keys()) + assert len(curr_module_users) == 1, ( + f"Expected only one user for module node, instead got {list(curr_module_users)}" + ) + flatten_node = curr_module_users[0] + assert ( + flatten_node.op == "call_function" + and flatten_node.target == fx_pytree.tree_flatten_spec + ) + + flatten_getitem_users = _get_getitem_users(flatten_node) + if len(flatten_getitem_users) != 1: + log.debug( + "More than one user found for flatten node, %s: %s. " + "Unable to fuse it with another unflatten call.", + flatten_node, + flatten_getitem_users, + ) + return + + unflatten_node = next(iter(flatten_getitem_users)) + if not ( + unflatten_node.op == "call_function" + and unflatten_node.target == pytree.tree_unflatten + ): + log.debug( + "Flatten node %s's user is not a pytree.tree_unflatten. " + "Instead it is: %s. Passing...", + flatten_node, + unflatten_node, + ) + return + + for i, arg in enumerate(unflatten_node.args[0]): # type: ignore[union-attr,arg-type] + if arg not in flatten_node.users: + log.debug( + "Module %s's outputs are not all directly used as inputs to " + "the subsequent module. Unable to fuse the connecting " + "flatten/unflatten. The inputs to the subsequent module are: %s. ", + curr_module_node, + unflatten_node.args[0], + ) + return + + if not ( + arg.op == "call_function" + and arg.target == operator.getitem + and arg.args[1] == i + ): + log.debug( + "Module %s's outputs are not all directly used in the same " + "order as outputted. Unable to fuse the connecting " + "flatten/unflatten. The inputs to the " + "subsequent module are: %s. ", + curr_module_node, + unflatten_node.args[0], + ) + return + + # Unflatten has two levels of getitem, because it gets the args and kwargs + unflatten_getitem_getitem_users = set() + unflatten_getitem_users = _get_getitem_users(unflatten_node) + for unflatten_getitem_user in unflatten_getitem_users: + unflatten_getitem_getitem_users.update( + list(unflatten_getitem_user.users.keys()) + ) + + if len(unflatten_getitem_getitem_users) != 1: + log.debug( + "More than one user found for unflatten node, %s: %s. " + "Unable to fuse it with another flatten call.", + unflatten_node, + unflatten_getitem_getitem_users, + ) + return + + next_module_node = next(iter(unflatten_getitem_getitem_users)) + if not (next_module_node.op == "call_module"): + log.debug( + "Unflatten node %s's user is not a call_module. " + "Instead it is: %s. Passing...", + unflatten_node, + next_module_node, + ) + return + + # Directly put the outputs of the current module into the next module + next_module_node.args = (curr_module_node,) + + +def _remove_extraneous_pytrees(gm: torch.fx.GraphModule) -> None: + """ + Remove extraneous pytree flatten/unflatten calls. + + We try a couple of optimizations here: + 1. Remove pytree flatten/unflatten calls between modules + 2. TODO: Remove module's in_spec + initial unflatten call + 3. TODO: Remove module's out_spec + final flatten call + """ + + for node in gm.graph.nodes: + if node.op == "call_module": + _try_remove_connecting_pytrees(node) + + gm.graph.eliminate_dead_code() + + +def _construct_inputs( + gm: torch.fx.GraphModule, + signature: ModuleCallSignature, + node_name_map: dict[str, torch.fx.Node], +) -> tuple[list[torch.fx.Node], dict[str, torch.fx.Node]]: + tree_unflatten_args: list[Optional[torch.fx.Node]] = [] + for input_ in signature.inputs: + if isinstance(input_, ConstantArgument) and input_.value is None: + # Constants should be directly embedded into the graph and not used + # as inputs + tree_unflatten_args.append(None) + elif input_.name not in node_name_map: + # For unused inputs + tree_unflatten_args.append(None) + else: + tree_unflatten_args.append(node_name_map[input_.name]) + + # Insert unflatten call + from .unflatten import _generate_unflatten + + unflatten_node = _generate_unflatten(gm, tree_unflatten_args, signature.in_spec) + + assert signature.in_spec.num_children == 2 + + args_spec = signature.in_spec.children_specs[0] + assert args_spec.context is None + args_node = gm.graph.call_function(operator.getitem, (unflatten_node, 0)) + args_nodes = [ + gm.graph.call_function(operator.getitem, (args_node, i)) + for i in range(args_spec.num_children) + ] + + kwargs_spec = signature.in_spec.children_specs[1] + assert kwargs_spec.context is not None + kwargs_node = gm.graph.call_function(operator.getitem, (unflatten_node, 1)) + kwargs_nodes = { + k: gm.graph.call_function(operator.getitem, (kwargs_node, k)) + for k in kwargs_spec.context + } + return args_nodes, kwargs_nodes + + +def _insert_call_module( + gm: torch.fx.GraphModule, + args_nodes: list[torch.fx.Node], + kwargs_nodes: dict[str, torch.fx.Node], + module_to_swap: torch.nn.Module, + name: str, +) -> torch.fx.Node: + from .unflatten import _assign_attr, _AttrKind + + _assign_attr(module_to_swap, gm, name, _AttrKind.MODULE) + module_node = gm.graph.call_module(name, tuple(args_nodes), kwargs_nodes) # type: ignore[arg-type] + return module_node + + +def _deconstruct_outputs( + gm: torch.fx.GraphModule, + signature: ModuleCallSignature, + module_node: torch.fx.Node, + node_name_map: dict[str, torch.fx.Node], + orig_outputs: tuple[torch.fx.Node, ...], +) -> None: + from .unflatten import _generate_flatten_spec + + flatten_node = _generate_flatten_spec(gm, module_node, signature.out_spec) + + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(flatten_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + + node_name_map[orig_output.name] = proxy_out + + +def _swap_module_helper( + gm: torch.fx.GraphModule, + modules_to_swap: dict[str, torch.nn.Module], + module_call_graph: dict[str, ModuleCallSignature], +) -> torch.fx.GraphModule: + log.debug("Starting graph:") + log.debug(gm.graph) + + legalize_graph(gm) + + partitions: dict[str, NodeList] = defaultdict(list) + + node_name_map: dict[str, torch.fx.Node] = { + node.name: node for node in gm.graph.nodes + } + + # TODO: Handle the duplicate module case + for node in gm.graph.nodes: + if nn_module_stack := node.meta.get("nn_module_stack"): + for path, _ in nn_module_stack.values(): + if path in modules_to_swap: + partitions[path].append(node) + break + + for name, nodes in partitions.items(): + """ + Given a graph like the following, and we want to swap out the submodule "foo": + graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=2] = placeholder[target=y] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {}), nn_module_stack = {"foo": ("foo", torch.nn.Module)} + %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %add), kwargs = {}), nn_module_stack = {"bar": ("bar", torch.nn.Module)} + return (sub,) + + We will first partition out foo's subgraph: + graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=2] = placeholder[target=y] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%y, %x), kwargs = {}) + return add + + And then insert an unflatten + call_module + flatten to replace the subgraph: + graph(): + %x : [num_users=1] = placeholder[target=x] + %y : [num_users=1] = placeholder[target=y] + + %_spec_0 : [num_users=1] = get_attr[target=_spec_0] + %tree_unflatten : [num_users=2] = call_function[target=torch.utils._pytree.tree_unflatten](args = ([%x, %y], %_spec_0), kwargs = {}) + %getitem : [num_users=2] = call_function[target=operator.getitem](args = (%tree_unflatten, 0), kwargs = {}) + %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 0), kwargs = {}) + %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%getitem, 1), kwargs = {}) + %getitem_3 : [num_users=0] = call_function[target=operator.getitem](args = (%tree_unflatten, 1), kwargs = {}) + %foo : [num_users=0] = call_module[target=foo](args = (%getitem_1, %getitem_2), kwargs = {}) + %_spec_1 : [num_users=1] = get_attr[target=_spec_1] + %tree_flatten_spec : [num_users=1] = call_function[target=torch.fx._pytree.tree_flatten_spec](args = (None, %_spec_1), kwargs = {}) + %getitem_4 : [num_users=1] = call_function[target=operator.getitem](args = (%tree_flatten_spec, 0), kwargs = {}) + + %sub : [num_users=1] = call_function[target=torch.ops.aten.sub.Tensor](args = (%y, %getitem_4), kwargs = {}) + return (%sub,) + + The `tree_unflatten` call will construct tensor inputs into the input + format needed by the swapped eager module. + The `call_module` node should now reference the swapped torch.nn.Module. + The `tree_flatten_spec` call will deconstruct the eager outputs of the + swapped module into tensors. + """ # noqa: B950 + + submod_name = name.replace(".", "_") + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, nodes, f"fused_{submod_name}" + ) + + log.debug("Fused subgraph nodes:") + log.debug(sub_gm.graph) + + signature: ModuleCallSignature = module_call_graph[name] + + args_nodes, kwargs_nodes = _construct_inputs(gm, signature, node_name_map) + module_node = _insert_call_module( + gm, args_nodes, kwargs_nodes, modules_to_swap[name], name + ) + _deconstruct_outputs(gm, signature, module_node, node_name_map, orig_outputs) + + erase_nodes(gm, nodes) + + log.debug("Swapped graph:") + log.debug(gm.graph) + + legalize_graph(gm) + + log.debug("Before removing extraneous pytrees:") + log.debug(gm.graph) + + _remove_extraneous_pytrees(gm) + log.debug("After removing extraneous pytrees:") + log.debug(gm.graph) + + gm.recompile() + + return gm + + +def _fix_input_output_signature( + gm: torch.fx.GraphModule, signature: ModuleCallSignature +) -> None: + """ + Given the unlifted module from calling ep.module(), we want to remove the + pytree processing from the graph module's PyTreeCodeGen and instead make it + nodes inside of the graph. This allows us to do some optimizations, like + remove these pytree calls if it is unnecessary, and makes the PyTree part + more obvious to graph passes. + """ + from torch.export.unflatten import _generate_flatten, _generate_unflatten + + # Remove the registered pytree codegen because we will take care of it + # through inserting pytree nodes into the graph + gm.graph._codegen = torch.fx.graph.CodeGen() + + old_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + + new_placeholders = [] + forward_arg_names = signature.forward_arg_names + if forward_arg_names is None: + forward_arg_names = [] + assert signature.in_spec.num_children == 2 + arg_spec = signature.in_spec.children_specs[0] + kwarg_spec = signature.in_spec.children_specs[1] + assert arg_spec.type == tuple + assert kwarg_spec.type == dict + for i in range(arg_spec.num_children): + forward_arg_names.append(f"arg_{i}") + forward_arg_names.extend(kwarg_spec.context) + + for arg in forward_arg_names: + with gm.graph.inserting_before(old_placeholders[0]): + new_placeholders.append(gm.graph.placeholder(arg)) + + # Insert flatten call for the inputs + with gm.graph.inserting_before(old_placeholders[0]): + flat_node = _generate_flatten(gm, tuple(new_placeholders)) + for i, old_placeholder in enumerate(old_placeholders): + old_placeholder.op = "call_function" + old_placeholder.target = operator.getitem + old_placeholder.args = (flat_node, i) + + # Insert unflatten call for the outputs + output_node = next(node for node in gm.graph.nodes if node.op == "output") + with gm.graph.inserting_before(output_node): + unflat = _generate_unflatten(gm, output_node.args[0], signature.out_spec) + output_node.args = (unflat,) + + gm.recompile() + + +def _swap_modules( + ep: ExportedProgram, modules_to_swap: dict[str, torch.nn.Module] +) -> torch.fx.GraphModule: + """ + Unlifts the given ExportedProgram into a fx.GraphModule, and then swaps + previously traced modules with new eager modules specified. Returns a + fx.GraphModule with a custom forward function. + + Args: + ep (ExportedProgram): Exported program to modify + modules_to_swap (Dict[str, torch.nn.Module]): Mapping from module fqn to + eager module to swap with. The specified module fqn should have also + been specified in the `preserve_module_call_signature` argument to + torch.export so that we know how to restore the calling convention + to this argument. + run_with_interpreter: Whether or not to run the graph using + fx.Interpreter. Setting to true will help result in better error + messages and easier debugging, but it has found to result in a QPS + drop. + """ + module_call_graph = { + entry.fqn: entry.signature for entry in ep.module_call_graph if entry.signature + } + + gm = ep.module() + gm.validate_inputs = False # type: ignore[assignment] + gm.graph.eliminate_dead_code() # type: ignore[operator, union-attr] + assert isinstance(gm, torch.fx.GraphModule) + _fix_input_output_signature(gm, ep.module_call_graph[0].signature) + + gm.module_call_graph = ep.module_call_graph + gm.train = types.MethodType(type(gm).train, gm) # type: ignore[assignment] + gm.eval = types.MethodType(type(gm).eval, gm) # type: ignore[assignment] + + assert isinstance(gm, torch.fx.GraphModule) + gm = _swap_module_helper(gm, modules_to_swap, module_call_graph) + + return gm diff --git a/phivenv/Lib/site-packages/torch/export/_trace.py b/phivenv/Lib/site-packages/torch/export/_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..5051035b4fe42906c4ffa573077007f2faaa8eb7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_trace.py @@ -0,0 +1,2267 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import dataclasses +import functools +import inspect +import logging +import re +import sys +import time +import warnings +from contextlib import contextmanager, nullcontext +from typing import Any, Callable, Optional, Union + +import torch +import torch._dynamo +import torch.fx +import torch.utils._pytree as pytree +from torch._dispatch.python import enable_python_dispatcher +from torch._dynamo.exc import UserError, UserErrorType +from torch._export.db.logging import ( + exportdb_error_message, + get_class_if_classified_error, +) +from torch._export.non_strict_utils import ( + _fakify_module_inputs, + _fakify_script_objects, + _gather_constant_attrs, + _NonStrictTorchFunctionHandler, + _override_builtin_ops, + make_constraints, + make_fake_inputs, + produce_guards_and_solve_constraints, +) +from torch._export.passes.collect_tracepoints_pass import CollectTracepointsPass +from torch._export.passes.lift_constants_pass import ( + _materialize_and_lift_constants, + ConstantAttrMap, +) +from torch._export.utils import ( + _collect_param_buffer_metadata, + _compiling_state_context, + _fakify_params_buffers, + _populate_param_buffer_metadata_to_new_gm, + _update_gm_meta_if_possible, + apply_runtime_assertion_pass, + placeholder_naming_pass, + placeholder_prefixes, +) +from torch._export.verifier import SpecViolationError +from torch._export.wrappers import _wrap_submodules +from torch._functorch._aot_autograd.input_output_analysis import ( + _graph_input_names, + _graph_output_names, +) +from torch._functorch._aot_autograd.schemas import GraphSignature +from torch._functorch._aot_autograd.subclass_utils import get_subclass_typing_container +from torch._functorch._aot_autograd.traced_function_transforms import ( + create_functional_call, +) +from torch._functorch._aot_autograd.utils import ( + create_tree_flattened_fn, + register_buffer_assignment_hook, +) +from torch._functorch.aot_autograd import ( + _detect_attribute_assignment, + aot_export_module, +) +from torch._guards import detect_fake_mode, tracing, TracingContext +from torch._library.fake_class_registry import FakeScriptObject +from torch._logging import dtrace_structured +from torch._subclasses.fake_tensor import FakeTensorMode +from torch._utils_internal import log_export_usage +from torch.export._unlift import _check_input_constraints_pre_hook +from torch.export.dynamic_shapes import ( + _check_dynamic_shapes, + _combine_args, + _DimHintType, + _IntWrapper, + _process_dynamic_shapes, +) +from torch.export.exported_program import OutputKind +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx.experimental.proxy_tensor import ( + get_proxy_slot, + make_fx, + PreDispatchTorchFunctionMode, + track_tensor_tree, +) +from torch.fx.experimental.symbolic_shapes import ( + ConstraintViolationError, + free_unbacked_symbols, + GuardOnDataDependentSymNode, + ShapeEnv, +) +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.utils._pytree import TreeSpec +from torch.utils._sympy.value_ranges import ValueRangeError + +from ._safeguard import AutogradStateOpsFailSafeguard +from ._wrapper_utils import _WrapperModule +from .exported_program import ( + _disable_prexisiting_fake_mode, + ExportedProgram, + InputKind, + ModuleCallEntry, + ModuleCallSignature, +) +from .graph_signature import _convert_to_export_graph_signature, ExportGraphSignature + + +log = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ExportDynamoConfig: + """ + Manage Export-specific configurations of Dynamo. + """ + + allow_rnn: bool = True + reorderable_logging_functions: set[Callable] = dataclasses.field( + default_factory=set + ) + # Emit runtime asserts after AOTAutograd instead. + # This isn't really necessary, and isn't much more efficient since the runtime asserts pass does CSE, + # but if we want to reason more about what guards/runtime asserts to emit, + # this makes it a bit cleaner to do from the export side. Also no real point in running this twice. + do_not_emit_runtime_asserts: bool = True + specialize_int: bool = True + specialize_float: bool = True + assume_static_by_default: bool = False + automatic_dynamic_shapes: bool = False + capture_dynamic_output_shape_ops: bool = True + capture_scalar_outputs: bool = True + prefer_deferred_runtime_asserts_over_guards: bool = False + + +@dataclasses.dataclass +class ATenExportArtifact: + gm: torch.fx.GraphModule + sig: ExportGraphSignature + constants: dict[str, _ConstantAttributeType] + + +@dataclasses.dataclass(frozen=True) +class ExportArtifact: + aten: ATenExportArtifact + in_spec: TreeSpec + out_spec: TreeSpec + fake_mode: FakeTensorMode + module_call_specs: dict[str, dict[str, pytree.TreeSpec]] + + +DEFAULT_EXPORT_DYNAMO_CONFIG = ExportDynamoConfig() +DEFAULT_EXPORT_DYNAMO_CONFIG.reorderable_logging_functions = { + logging.critical, + logging.debug, + logging.error, + logging.exception, + logging.info, + logging.log, + logging.warning, + print, + warnings.warn, +} + + +@contextmanager +def _ignore_backend_decomps(): + orig_mkldnn_flag = torch.backends.mkldnn.set_flags(False) + orig_nnpack_flag = torch.backends.nnpack.set_flags(False) + try: + yield + finally: + torch.backends.mkldnn.set_flags(*orig_mkldnn_flag) + torch.backends.nnpack.set_flags(*orig_nnpack_flag) + + +@contextmanager +def _disable_custom_triton_op_functional_decomposition(): + old = torch._functorch.config.decompose_custom_triton_ops + try: + torch._functorch.config.decompose_custom_triton_ops = False + yield torch._functorch.config.decompose_custom_triton_ops + finally: + torch._functorch.config.decompose_custom_triton_ops = old + + +def custom_triton_ops_decomposition_disabled(): + return not torch._functorch.config.decompose_custom_triton_ops + + +def _fixup_key(x): + return "L__self__" + _strip_root(x) + + +def _strip_root(x): + if isinstance(x, str) and x.startswith("_export_root"): + stripped = x[len("_export_root") :] + return stripped.removeprefix(".") + return x + + +def _rewrite_tracepoint_node(gm: torch.fx.GraphModule): + """ + In-place modifiy input graph module by replacing the export tracepoint with a new node + that has the same target and args, but with the _export_root stripped from path. + """ + for node in gm.graph.nodes: + if node.target == torch.ops.higher_order._export_tracepoint: + if "path" in node.kwargs: + path = _strip_root(node.kwargs["path"]) + with gm.graph.inserting_before(node): + new_node = gm.graph.create_node( + "call_function", + torch.ops.higher_order._export_tracepoint, + args=node.args, + kwargs={ + "path": path, + "kind": node.kwargs["kind"], + }, + ) + new_node.meta = node.meta + node.replace_all_uses_with(new_node) + gm.graph.erase_node(node) + + +def detect_shape_env(inputs: Any = None): + shape_envs = [] + + for i, flat_input in enumerate(inputs): + if isinstance(flat_input, torch.SymInt): + shape_envs.append((flat_input.node.shape_env, "symint input", i)) + + if shape_envs: + shape_env, desc1, i1 = shape_envs[0] + for m, desc2, i2 in shape_envs[1:]: + assert shape_env is m, ( + f"shape env ({shape_env}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n" + f"shape env from {desc1} {i1} allocated at:\n{shape_env.stack}\n" + f"shape env from {desc2} {i2} allocated at:\n{m.stack}" + ) + return shape_env + else: + return None + + +def _extract_fake_inputs(gm, args, kwargs): + """ + Given a graph module, extract fakified input tensors from the metadata of + its placeholders, and map them to the structure of given args and kwargs. + Also return the fake mode used to fakify those inputs. + """ + fake_inps: list[Any] = [] + fake_vals: list[Any] = [] + for node in gm.graph.nodes: + if node.op == "placeholder": + fake_inps.append(node.meta.get("val")) + else: + fake_vals.append(node.meta.get("example_value")) + + # We get both because now we might have a combination of symint and tensor + # inputs, and we want to check that the shape env is consistent between + # both. Unforunately we can't see what fake mode is attached to the shape + # env, then we can just compare fake modes. + detected_fake_mode = detect_fake_mode(fake_inps + fake_vals) + detected_shape_env = detect_shape_env(fake_inps + fake_vals) + + if detected_fake_mode: + if detected_shape_env: + assert detected_shape_env is detected_fake_mode.shape_env, ( + "Detected shape env does not match fake mode's shape env" + ) + fake_mode = detected_fake_mode + elif detected_shape_env: + fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True) + else: + fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) + + count = 0 + + def lookup_fake(x): + nonlocal count + val = fake_inps[count] if isinstance(x, (int, torch.Tensor)) else x + count += 1 + return val + + fake_args = pytree.tree_map(lookup_fake, args) + fake_kwargs = pytree.tree_map(lookup_fake, kwargs) + + return fake_args, fake_kwargs, fake_mode + + +def _replace_param_buffer_names(param_buffer_table, sig): + for spec in sig.input_specs: + if spec.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + spec.target = param_buffer_table[spec.target] + for spec in sig.output_specs: + if spec.kind in ( + OutputKind.BUFFER_MUTATION, + OutputKind.GRADIENT_TO_PARAMETER, + ): + spec.target = param_buffer_table[spec.target] + + +def _convert_to_positional_args(orig_arg_names, args, kwargs): + assert len(orig_arg_names) == len(args) + len(kwargs), ( + f"Total number of arg names is expected to be {len(orig_arg_names)} " + f"but got {len(args)} positional args, {len(kwargs)} kwargs." + ) + reordered_kwargs = [kwargs[kw_name] for kw_name in orig_arg_names[len(args) :]] + return ( + *args, + *reordered_kwargs, + ) + + +def _normalize_nn_module_stack(gm_torch_level, root_cls): + # Append a root module to every nn_module_stack. + root = "L['self']" + root_key = re.sub(r"[^a-zA-Z0-9]", "_", root) + for gm in gm_torch_level.modules(): + if not isinstance(gm, torch.fx.GraphModule): + continue + for node in gm.graph.nodes: + if node.op in ["placeholder", "output"]: + continue + add_root = True + if nn_module_stack := node.meta.get("nn_module_stack", {}): + path, ty = next(iter(nn_module_stack.values())) + # After deserializing the class `ty` might not exist anymore so + # it could be a string + if inspect.isclass(ty) and issubclass(ty, torch.nn.Module): + # TODO Figure out why sometimes we have root sometimes we don't. + if path == root and ty is root_cls: + add_root = False + else: + assert isinstance(ty, str) + if add_root: + + def normalize_path(path): + try: + parts = [] + + class Path: + def __getattr__(self, name): + if name != "_modules": + parts.append(name) + return self + + def __getitem__(self, idx): + parts.append(str(idx)) + return self + + eval(path, {"L": {"self": Path()}}) + return ".".join(parts) + except Exception: # TODO(zhxchen17) Remove this. + return path + + nn_module_stack = { + root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__), + **nn_module_stack, + } + node.meta["nn_module_stack"] = { + key: (normalize_path(path), ty) + for key, (path, ty) in nn_module_stack.items() + } + + +def _get_param_buffer_mapping( + original_module: torch.nn.Module, + traced_module: torch.nn.Module, +) -> dict[str, str]: + """ + Returns a mapping of parameter/buffer names from the new module to the + original model. This is to help with restoring the FQN for parameter/buffers + of a traced module to what the original module contains. + """ + + param_lookup: dict[int, str] = {} + buffer_lookup: dict[int, str] = {} + for name, param in original_module.named_parameters(remove_duplicate=False): + param_lookup[id(param)] = name + for name, buffer in original_module.named_buffers(remove_duplicate=False): + buffer_lookup[id(buffer)] = name + + param_buffer_table: dict[str, str] = {} + for dynamo_name, dynamo_param in traced_module.named_parameters( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_param) in param_lookup: + param_buffer_table[dynamo_name] = param_lookup[id(dynamo_param)] + + for dynamo_name, dynamo_buffer in traced_module.named_buffers( + remove_duplicate=False + ): + assert dynamo_name not in param_buffer_table + if id(dynamo_buffer) in buffer_lookup: + param_buffer_table[dynamo_name] = buffer_lookup[id(dynamo_buffer)] + + return param_buffer_table + + +def _preserve_requires_grad_pass( + gm: torch.fx.GraphModule, + sig: ExportGraphSignature, + fake_params_buffers: dict[str, torch.Tensor], + constants: dict[str, _ConstantAttributeType], + flat_fake_args: list[Any], +): + placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(sig.input_specs) == len(placeholders) + i = 0 + for node, spec in zip(placeholders, sig.input_specs): + if spec.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + assert spec.target is not None + node.meta["val"].requires_grad = fake_params_buffers[ + spec.target + ].requires_grad + elif spec.kind == InputKind.USER_INPUT: + fake_arg = flat_fake_args[i] + if isinstance(fake_arg, torch.Tensor): + node.meta["val"].requires_grad = fake_arg.requires_grad + i += 1 + elif spec.kind == InputKind.CONSTANT_TENSOR: + assert spec.target is not None + constant = constants[spec.target] + if isinstance(constant, torch.Tensor): + # If the tensor is not leaf, it should already have a correct requires grad field + if node.meta["val"].is_leaf: + node.meta["val"].requires_grad = constant.requires_grad + else: + assert node.meta["val"].requires_grad == constant.requires_grad + elif spec.kind in (InputKind.CUSTOM_OBJ, InputKind.TOKEN): + continue + else: + raise AssertionError(spec.kind) + + +def _remap_constants( + orig_constant_attrs: ConstantAttrMap, + graph_signature: ExportGraphSignature, + constants: dict[str, _ConstantAttributeType], +) -> None: + """Rewrite the graph signature and constants table to use the FQN from the original module.""" + remap_table: dict[str, list[str]] = {} + for name, value in constants.items(): + if value in orig_constant_attrs: + remap_table[name] = orig_constant_attrs[value] + + for spec in graph_signature.input_specs: + if spec.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + orig_target = spec.target + assert orig_target is not None + targets = remap_table.get(orig_target, [orig_target]) + spec.target = targets[0] + + constant = constants[orig_target] + del constants[orig_target] + for target in targets: + constants[target] = constant + + +def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None: + """ + When we run an interpreter-based pass over a GraphModule, execution of data-dependent operators + will produce example values with new unbacked symbols. To track that the new/old symbols are equivalent, + we used to rely on the unbacked_renamings mapping. This led to problematic metadata where the unbacked_bindings + keys mapped new symbols (u2) to paths containing old symbols (u0) in the example values, or worse, backed symbols + or constants (e.g. if the original unbacked was replaced/specialized). Additionally this created problems with + de/serialized programs, since we didn't comprehensively serialize ShapeEnv/unbacked renamings/node bindings. + + This pass attempts a simpler way of handling these for export, by throwing away the previously computed bindings, and re-running + the pattern match used in compute_unbacked_bindings. This ensures we keep the original symbols contained in the example values, + or delete bindings if they've been replaced/specialized. + """ + from torch._export.utils import _get_shape_env_from_gm + from torch.fx.experimental.symbolic_shapes import _free_unbacked_symbols_with_path + from torch.utils._sympy.symbol import symbol_is_type, SymT + + if (shape_env := _get_shape_env_from_gm(gm)) is None: + return + + base_unbacked_symbols = { + symbol + for symbol in shape_env.var_to_range + if symbol_is_type(symbol, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)) + and symbol not in shape_env.unbacked_renamings + } + for node in gm.graph.nodes: + node.meta.pop("unbacked_bindings", None) + if (val := node.meta.get("val")) is not None and ( + unbacked_bindings := _free_unbacked_symbols_with_path( + val, + (), + shape_env=shape_env, + pending=base_unbacked_symbols, + simplify=True, + ) + ): + node.meta["unbacked_bindings"] = unbacked_bindings + + +def _produce_aten_artifact( + *, + gm: torch.fx.GraphModule, + mod, + constant_attrs, + graph_signature, + pre_dispatch, + fake_args, + fake_kwargs, + fake_params_buffers, + _prettify_placeholder_names=True, +) -> ATenExportArtifact: + """ + This is a helper function that is shared between export_to_aten_ir and export_to_aten_ir_make_fx + to produce the aten artifact. (export compatible graph module + signature) + + It does: + 1. Applies runtime assertion pass + 2. Recompute unbacked_bindings pass + 3. Populate meta val when missing + 4. Lift constants as placeholders + 5. Replace raw autograd and autocast ops with HOPs + 6. Prettify names for placeholders + 7. Preserve requires_grad value on node meta val + """ + # Run runtime asserts pass before creating input/output specs, since size-related CSE/DCE might affect output signature. + # Overwrite output specs afterwards. + flat_fake_args = pytree.tree_leaves((fake_args, fake_kwargs)) + gm, graph_signature = apply_runtime_assertion_pass(gm, graph_signature) + + # Simplify unbacked_bindings by recomputing them. + # Useful for any pass that's interpreter-based and might call rebind_unbacked(), + # e.g. AOTAutograd in this case. + _replace_unbacked_bindings(gm) + + total_non_user_inputs = ( + len(graph_signature.parameters) + + len(graph_signature.buffers) + + len(graph_signature.input_tokens) + ) + set_missing_meta_vals(gm, flat_fake_args, total_non_user_inputs) + + export_graph_signature: Optional[ExportGraphSignature] + export_graph_signature = _convert_to_export_graph_signature( + graph_signature, gm, _get_non_persistent_buffers(mod) + ) + + # script objects are always stored in constants no matter whether they're initial inputs or + # they're lifted in aot" before rewrite_script_object_meta + constants = _materialize_and_lift_constants( + gm, export_graph_signature, constant_attrs + ) + + if pre_dispatch: + from torch._export.passes.replace_autocast_with_hop_pass import ( + replace_autocast_with_hop_pass, + ) + from torch._export.passes.replace_set_grad_with_hop_pass import ( + replace_set_grad_with_hop_pass, + ) + + # Note: replace_set_grad_with_hop_pass need to be after lift_constant_pass because + # a getattr of a constant tensor doesn't have meta["val"] until after lift_constant_pass. + # If replace_set_grad_with_hop_pass is before lift_constant_pass, + # and the constant_tensor is passed as input of the set grad hop, the placeholder's + # meta["val"] will be None and fails our verifier for placeholder. + gm, export_graph_signature = replace_set_grad_with_hop_pass( + gm, export_graph_signature + ) + + gm, export_graph_signature = replace_autocast_with_hop_pass( + gm, export_graph_signature + ) + + # Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes. + for _mod in gm.modules(): + if not isinstance(_mod, torch.fx.GraphModule): + continue + for node in _mod.graph.nodes: + if node.op in ["placeholder", "output"]: + node.meta.pop("nn_module_stack", None) + node.meta.pop("stack_trace", None) + + # Prettify names for placeholder nodes. + assert export_graph_signature is not None + if _prettify_placeholder_names: + placeholder_naming_pass( + gm, + export_graph_signature, + mod, + fake_args, + fake_kwargs, + fake_params_buffers, + constants, + ) + + _preserve_requires_grad_pass( + gm, export_graph_signature, fake_params_buffers, constants, flat_fake_args + ) + + return ATenExportArtifact( + gm, + export_graph_signature, + constants, + ) + + +def _rename_constants_nodes( + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, +) -> None: + """ + For strict mode, rename constants nodes that were previously annotated as buffers. + """ + # handle name collisions with existing constants + node_names = {node.name for node in gm.graph.nodes} + + def rename_constant(name): + if name in node_names: + n = 1 + while (dup_name := f"{name}_{n}") in node_names: + n += 1 + name = dup_name + node_names.add(name) + return name + + # use input specs to map names from buffers to constants + buffer_prefix = placeholder_prefixes[InputKind.BUFFER] + const_prefix = placeholder_prefixes[InputKind.CONSTANT_TENSOR] + buffer_to_constant = {} + for spec in graph_signature.input_specs: + if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith( + const_prefix + ): + if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants + c_name = rename_constant( + const_prefix + spec.arg.name[len(buffer_prefix) :] + ) + else: # lifted constant + c_name = rename_constant(const_prefix + spec.arg.name) + buffer_to_constant[spec.arg.name] = c_name + spec.arg.name = c_name + for spec in graph_signature.output_specs: + if spec.arg.name in buffer_to_constant: + spec.arg.name = buffer_to_constant[spec.arg.name] + + # Rename constants nodes for all modules + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.name in buffer_to_constant: + node.name = node.target = buffer_to_constant[node.name] + mod.recompile() + + +def _restore_state_dict( + original_module: torch.nn.Module, traced_module: torch.fx.GraphModule +) -> None: + """ + Restores the state dict of the traced module to that of the original module. + """ + param_buffer_table = _get_param_buffer_mapping(original_module, traced_module) + # Since the graph module is flattened (no module heirarchy), we + # need to noramlize the module by replacing "." with "_". If we + # don't, it will try to save the weight to a submodule which no + # longer exists. + for name, fqn in param_buffer_table.items(): + param_buffer_table[name] = fqn.replace(".", "_") + + # Replace state dict attr names with the fqn + for name, fqn in param_buffer_table.items(): + if not hasattr(traced_module, name): + continue + + attr = getattr(traced_module, name) + if isinstance(attr, torch.Tensor) and not isinstance(attr, torch.nn.Parameter): + traced_module.register_buffer(fqn, attr) + else: + setattr(traced_module, fqn, attr) + delattr(traced_module, name) + + # Replace graph getattr nodes with the correct name + for node in traced_module.graph.nodes: + if node.op == "get_attr": + attr_name = node.target + if attr_name in param_buffer_table: + node.target = param_buffer_table[attr_name] + + traced_module.recompile() + + +def _get_module_hierarchy(mod: torch.nn.Module) -> dict[str, str]: + return { + name: type(m).__name__ for name, m in mod.named_modules(remove_duplicate=False) + } + + +def _make_module_call_graph( + in_spec: TreeSpec, + out_spec: TreeSpec, + module_call_signatures: dict[str, ModuleCallSignature], + forward_arg_names: Optional[list[str]] = None, +) -> list[ModuleCallEntry]: + original = [ + ModuleCallEntry(fqn=fqn, signature=module_call_signatures.get(fqn)) + for fqn in _EXPORT_MODULE_HIERARCHY # type: ignore[union-attr] + ] + assert original[0].fqn == "" + original[0].signature = ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=in_spec, + out_spec=out_spec, + forward_arg_names=forward_arg_names, + ) + additional = [ + ModuleCallEntry(fqn=fqn, signature=signature) + for fqn, signature in module_call_signatures.items() + if fqn not in _EXPORT_MODULE_HIERARCHY # type: ignore[operator] + ] + return [*original, *additional] + + +def _export_to_torch_ir( + f: Callable, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + *, + preserve_module_call_signature: tuple[str, ...] = (), + disable_constraint_solver: bool = False, + allow_complex_guards_as_runtime_asserts: bool = False, + restore_fqn: bool = True, + _log_export_usage: bool = True, + same_signature: bool = True, +) -> torch.fx.GraphModule: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a torch.fx.GraphModule in torch IR. + """ + + if _log_export_usage: + log_export_usage(event="export.private_api", flags={"_export_to_torch_ir"}) + + if not isinstance(args, tuple): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", + ) + + kwargs = kwargs or {} + + # Map ints to a wrapper structure to help us mark it as dynamic, if it is + # dynamic. We will unwrap ints in fakify later. + args, kwargs = pytree.tree_map_only(int, _IntWrapper, (args, kwargs)) + + combined_args = _combine_args(f, args, kwargs) + _check_dynamic_shapes(combined_args, dynamic_shapes) + constraints = _process_dynamic_shapes(combined_args, dynamic_shapes) + + # Unwrap static ints -- in the case where we have an empty graph + # containing just integer computation, dynamo will run its generated + # bytecode with these args/kwargs, which will error because we cannot + # directly apply int operations on IntWrapper. So we will just unwrap + # them here. + args, kwargs = pytree.tree_map_only( + _IntWrapper, + lambda a: a.val + if a.dynamism is None or a.dynamism.type == _DimHintType.STATIC + else a, + (args, kwargs), + ) + + with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)): + try: + module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {} + ctx = nullcontext() + if not isinstance(f, torch.fx.GraphModule): + ctx = _wrap_submodules( # type: ignore[assignment] + f, preserve_module_call_signature, module_call_specs + ) + with ctx, _ignore_backend_decomps(): + gm_torch_level, _ = torch._dynamo.export( + f, + dynamic_shapes=dynamic_shapes, # type: ignore[arg-type] + constraints=constraints, # type: ignore[arg-type] + assume_static_by_default=True, + tracing_mode="symbolic", + disable_constraint_solver=disable_constraint_solver, + # currently the following 2 flags are tied together for export purposes, + # but untangle for sake of dynamo export api + prefer_deferred_runtime_asserts_over_guards=True, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + _log_export_usage=_log_export_usage, + same_signature=same_signature, + )( + *args, + **kwargs, + ) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + except GuardOnDataDependentSymNode as e: + raise UserError( # noqa: B904 + UserErrorType.ANTI_PATTERN, + f"Consider annotating your code using torch._check*(). {str(e)}", + case_name="constrain_as_size_example", + ) + + gm_torch_level.meta["module_call_specs"] = module_call_specs + + if isinstance(f, torch.nn.Module) and restore_fqn: + _restore_state_dict(f, gm_torch_level) + + return gm_torch_level + + +def _export_to_aten_ir( + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs: ConstantAttrMap, + produce_guards_callback=None, + *, + transform=lambda x: x, # TODO(zhxchen17) Revisit if this is needed later. + pre_dispatch=False, + decomp_table=None, + _check_autograd_state: bool = True, + _is_torch_jit_trace: bool = False, + _prettify_placeholder_names: bool = True, + decompose_custom_triton_ops: bool = False, +) -> ATenExportArtifact: + # [NOTE] If the user is exporting under training mode, we want to detect if there is any + # state change in the autograd global state and error. If the user is exporting under inference + # mode, we don't care. At predispatch level, we don't care about the state change. + is_grad_enabled = torch._C.is_grad_enabled() + grad_safe_guard = nullcontext() + # export_to_aten_ir is called when we decompose the ep into inference IR + # In that setting, we actually shouldn't check the state change as at this point, + # because the intention is specalizing to inference. + if _check_autograd_state: + if not pre_dispatch and is_grad_enabled: + grad_safe_guard = AutogradStateOpsFailSafeguard() # type: ignore[assignment] + + custom_triton_ops_decomposition_ctx = ( + nullcontext + if decompose_custom_triton_ops + else _disable_custom_triton_op_functional_decomposition + ) + # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, + # otherwise aot_export_module will error out because it sees a mix of fake_modes. + # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. + with ( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), + grad_safe_guard, + _ignore_backend_decomps(), + _compiling_state_context(), + custom_triton_ops_decomposition_ctx(), + ): + gm, graph_signature = transform(aot_export_module)( + mod, + fake_args, + trace_joint=False, + pre_dispatch=pre_dispatch, + decompositions=decomp_table, + kwargs=fake_kwargs, + ) + + def _maybe_fixup_gm_and_output_node_meta(old_gm, new_gm): + if isinstance(old_gm, torch.fx.GraphModule): + if hasattr(old_gm, "meta"): + new_gm.meta.update(old_gm.meta) + old_output_node = list(old_gm.graph.nodes)[-1] + new_output_node = list(new_gm.graph.nodes)[-1] + assert old_output_node.op == "output" and new_output_node.op == "output" + # make sure we don't override any meta + assert len(new_output_node.meta) == 0 + new_output_node.meta.update(old_output_node.meta) + + # TODO unfortunately preserving graph-level metadata and output node's meta + # is not working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + _maybe_fixup_gm_and_output_node_meta(mod, gm) + + # Run produce guards before we handle runtime asserts. + # This means we run the export solver before the runtime asserts pass. + # Right now this doesn't mean much - the export solver is only there for suggested fixes, + # and we won't even get to constraint solving if that's needed. + # But if in future we want to control what runtime asserts are emitted for export, + # or rely on produce_guards + solver for some simplification on runtime asserts, this probably makes sense. + if produce_guards_callback: + try: + produce_guards_callback(gm) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + + return _produce_aten_artifact( + gm=gm, + mod=mod, + constant_attrs=constant_attrs, + graph_signature=graph_signature, + pre_dispatch=pre_dispatch, + fake_args=fake_args, + fake_kwargs=fake_kwargs, + fake_params_buffers=fake_params_buffers, + _prettify_placeholder_names=_prettify_placeholder_names, + ) + + +def _get_forward_arg_names( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, +) -> list[str]: + """ + Gets the argument names to forward that are used, for restoring the + original signature when unlifting the exported program module. + - Positional args: retain the original argument names, and enumerate + *args as args_0, args_1, ... + - Keyword args: retain the original kwarg names in the order specified + by the user. This order seems to matter for the current state of + export lifted modules. + """ + sig = inspect.signature(mod.forward) + _args = sig.bind_partial(*args).arguments + + names: list[str] = [] + for name, value in _args.items(): + # handle variable number of positional args + if sig.parameters[name].kind == inspect._ParameterKind.VAR_POSITIONAL: + names.extend([f"{name}_{i}" for i, _ in enumerate(value)]) + else: + names.append(name) + # order of kwargs matters for input spec + if kwargs: + names.extend([kwarg for kwarg, _ in kwargs.items()]) + + return names + + +def _get_non_persistent_buffers(mod: torch.nn.Module) -> set[str]: + """ + Returns set of non-persistent buffers in a module and its submodules. + """ + result: set[str] = set() + for name, m in mod.named_modules(remove_duplicate=False): + if name: + result.update(f"{name}.{b}" for b in m._non_persistent_buffers_set) + else: + result.update(m._non_persistent_buffers_set) + return result + + +def _rewrite_dynamo_tensor_constants( + orig_mod_buffers: set[torch.Tensor], + traced_mod_buffers: dict[str, torch.Tensor], + graph_signature: ExportGraphSignature, + constants: dict[str, _ConstantAttributeType], +) -> None: + """ + Dynamo erroneously marks tensor attributes on modules as buffers. + Rewrite them to be tensor constants. + """ + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER: + assert spec.target is not None + value = traced_mod_buffers[spec.target] + if value not in orig_mod_buffers: + # This was a tensor constant erroneously marked as a buffer. + # Convert it into a constant in the graph signature, and add its + # value to the constants table. + spec.kind = InputKind.CONSTANT_TENSOR + constants[spec.target] = value # type: ignore[arg-type] + + +def _move_non_persistent_buffers_to_tensor_constants( + orig_mod: torch.nn.Module, + graph_signature: ExportGraphSignature, + constants: dict[str, _ConstantAttributeType], +) -> None: + """ + Moves non-persistent buffers to tensor constants. + """ + for spec in graph_signature.input_specs: + if spec.kind == InputKind.BUFFER and not spec.persistent: + assert spec.target is not None + assert spec.target not in constants + constants[spec.target] = orig_mod.get_buffer(spec.target) # type: ignore[arg-type] + + +def _verify_nn_module_stack(graph_module: torch.fx.GraphModule) -> None: + """ + Perform nn_module_stack checks on the graph. + Current constraints: + For the top level graph: + - populated for 'call_function', 'get_attr' + - None for 'placeholder', 'output' + For submodule graphs: + - None for 'placeholder', output' + + TODO(pianpwk): make this a consistent node-level check once nn_module_stack is populated for cond submodules. + """ + # Check top-level graph for all nodes, all graphs for placeholder & output nodes + for i, mod in enumerate([graph_module] + list(graph_module.modules())): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.op in ["call_function", "get_attr"]: + if i == 0: + if ( + nn_module_stack := node.meta.get("nn_module_stack", None) + ) is None: + raise SpecViolationError( + f"Node {node} of type {node.op} is missing nn_module_stack metadata" + ) + if not all( + isinstance(k, str) + and isinstance(v, tuple) + and len(v) == 2 + and all(isinstance(x, str) for x in v) + for k, v in nn_module_stack.items() + ): + raise SpecViolationError( + f"Node {node} of type {node.op} has incorrect nn_module_stack metadata format" + f"expected Dict[str, Tuple[str, str]], but got {nn_module_stack}" + ) + elif node.op in ["placeholder", "output"]: + if node.meta.get("nn_module_stack", None): + raise SpecViolationError( + f"Node {node} of type {node.op} contains nn_module_stack metadata, this should be None" + ) + + +def _verify_stack_trace(graph_module: torch.fx.GraphModule) -> None: + """ + Perform stack trace checks on the graph. + Constraints: + - None or non-empty str for 'call_function', 'get_attr' + - None for 'placeholder', 'output' + """ + for mod in [graph_module, *graph_module.modules()]: + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in graph_module.graph.nodes: + stack_trace = node.meta.get("stack_trace", None) + if node.op in ["call_function", "get_attr"]: + if not (stack_trace is None or isinstance(stack_trace, str)): + raise SpecViolationError( + f"Node {node} of type {node.op} has invalid stack_trace metadata, " + f"expected a string or None but instead found: {stack_trace}" + ) + elif node.op in ["placeholder", "output"]: + if stack_trace: + raise SpecViolationError( + f"Node {node} of type {node.op} contains stack_trace metadata, " + f"expected None but instead found: {stack_trace}" + ) + + +def _verify_placeholder_names( + gm: torch.fx.GraphModule, sig: ExportGraphSignature +) -> None: + """ + Performs a sanity check on the placeholder node names. + - User input nodes: no restrictions, should match the original forward() signature + - Params/buffers/constants/custom_obj/token nodes: should start with prefixes defined in + """ + name_to_kind = {spec.arg.name: spec.kind for spec in sig.input_specs} + for mod in gm.modules(): + if not isinstance(mod, torch.fx.GraphModule): + continue + for node in mod.graph.nodes: + if node.op == "placeholder": + if node.name not in name_to_kind: + continue + node_kind = name_to_kind[node.name] + prefix = placeholder_prefixes[node_kind] + if not node.name.startswith(prefix): + raise SpecViolationError( + f"Placeholder node name {node.name} does not follow spec for {node_kind}, name should have prefix: {prefix}" + ) + + +def get_ep_stats(ep: ExportedProgram) -> dict[str, Any]: + op_count = 0 + op_set = set() + for m in ep.graph_module.modules(): + if not isinstance(m, torch.fx.GraphModule): + continue + for node in m.graph.nodes: + if node.op != "call_function": + continue + op_count += 1 + assert hasattr(node.target, "__module__") + assert hasattr(node.target, "__name__") + op_set.add(f"{node.target.__module__}.{node.target.__name__}") + return {"op_count": op_count, "op_set": op_set} + + +_EXPORT_FLAGS: Optional[set[str]] = None +_EXPORT_MODULE_HIERARCHY: Optional[dict[str, str]] = None + + +def _log_export_wrapper(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY + try: + start = time.time() + ep = fn(*args, **kwargs) + end = time.time() + log_export_usage( + event="export.time", + metrics=end - start, + flags=_EXPORT_FLAGS, + **get_ep_stats(ep), + ) + except Exception as e: + t = type(e) + error_type = t.__module__ + "." + t.__qualname__ + case_name = get_class_if_classified_error(e) + if case_name is not None: + log.error(exportdb_error_message(case_name)) + log_export_usage( + event="export.error.classified", + type=error_type, + message=str(e), + flags=_EXPORT_FLAGS, + ) + else: + log_export_usage( + event="export.error.unclassified", + type=error_type, + message=str(e), + flags=_EXPORT_FLAGS, + ) + + if hasattr(e, "partial_fx_graph"): + print( + e.partial_fx_graph, + file=sys.stderr, + ) + + raise e + finally: + _EXPORT_FLAGS = None + _EXPORT_MODULE_HIERARCHY = None + + return ep + + return wrapper + + +def _process_jit_trace_inputs_for_export(example_inputs, example_kwarg_inputs): + if not isinstance(example_inputs, (tuple, list, dict)): + example_inputs = (example_inputs,) + + elif isinstance(example_inputs, list): + example_inputs = tuple(example_inputs) + + elif ( + isinstance(example_inputs, (torch.Tensor, dict)) + and example_kwarg_inputs is None + ): + example_inputs = (example_inputs,) + + if example_kwarg_inputs is None: + example_kwarg_inputs = {} + return example_inputs, example_kwarg_inputs + + +def _get_original_state_dict(mod: torch.nn.Module) -> dict[str, Any]: + # Explicitly not calling mode.state_dict() as we do not want the module state for serialization + # but the running module state so we can always match by id() the entries here with the graph inputs + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + original_state_dict = named_parameters | named_buffers + + non_persistent_buffers = _get_non_persistent_buffers(mod) + for k in non_persistent_buffers: + original_state_dict.pop(k, None) + + return original_state_dict + + +def _process_export_inputs(mod, args, kwargs, dynamic_shapes): + if not isinstance(args, tuple): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expecting `args` to be a tuple of example positional inputs, got {type(args)}", + ) + kwargs = kwargs if kwargs is not None else {} + _, original_in_spec = pytree.tree_flatten((args, kwargs)) + + if isinstance(dynamic_shapes, torch.export.AdditionalInputs): + verify_additional_inputs = dynamic_shapes.verify + dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + else: + verify_additional_inputs = lambda ep: None # noqa: E731 + if isinstance(dynamic_shapes, torch.export.ShapesCollection): + dynamic_shapes = dynamic_shapes.dynamic_shapes(mod, args, kwargs) + + return args, kwargs, original_in_spec, dynamic_shapes, verify_additional_inputs + + +def _get_module_call_graph( + export_artifact: ExportArtifact, + preserve_module_call_signature: tuple[str, ...], + strict_mode_export: bool, + forward_arg_names: Optional[list[str]] = None, +) -> tuple[torch.fx.GraphModule, list[ModuleCallEntry]]: + """ + In-place modify the graph module in export_artifact, remove _export_tracepoint nodes and + return module_call_graph. + """ + gm: torch.fx.GraphModule = export_artifact.aten.gm + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + module_call_specs: dict[str, dict[str, TreeSpec]] = ( + export_artifact.module_call_specs + ) + in_spec: TreeSpec = export_artifact.in_spec + out_spec: TreeSpec = export_artifact.out_spec + + # Make module signatures. + module_call_signatures: dict[str, ModuleCallSignature] = {} + for fqn, specs in module_call_specs.items(): + mod_fqn = _strip_root(fqn) if not strict_mode_export else fqn + module_call_signatures[mod_fqn] = ModuleCallSignature( + inputs=[], + outputs=[], + in_spec=specs["in_spec"], + out_spec=specs["out_spec"], + forward_arg_names=None, # we only propage forward_arg_names for the top level module + ) + + if len(preserve_module_call_signature) > 0: + if not strict_mode_export: + _rewrite_tracepoint_node(gm) + res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm) + assert res is not None + gm = res.graph_module + + assert _EXPORT_MODULE_HIERARCHY is not None + module_call_graph = _make_module_call_graph( + in_spec, + out_spec, + module_call_signatures, + forward_arg_names, + ) + return gm, module_call_graph + + +def _get_range_constraints( + mod: torch.nn.Module, + export_artifact: ExportArtifact, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=False, +): + gm: torch.fx.GraphModule = export_artifact.aten.gm + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + fake_mode: FakeTensorMode = export_artifact.fake_mode + num_lifted = next( + ( + i + for i, s in enumerate(export_graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ), + len(export_graph_signature.input_specs), + ) + combined_args = _combine_args( + mod, args, kwargs, _is_torch_jit_trace=_is_torch_jit_trace + ) + + # This is because we trace based on the kewargs passed in from user + # not based on the signature. I feel it would be better to just enforce + # one ordering at the start of tracing to avoid confusions, but that is + # bigger refactor, so do this to unblock for now. + if not _is_torch_jit_trace: + combined_args_traced_order = {} + for arg in combined_args: + if arg not in kwargs: + combined_args_traced_order[arg] = combined_args[arg] + + for key in kwargs: + combined_args_traced_order[key] = kwargs[key] + + combined_args = combined_args_traced_order + + range_constraints = make_constraints( + fake_mode, + gm, + combined_args, + dynamic_shapes, + num_lifted, + ) + return range_constraints + + +def _get_inline_constraints(fake_mode: FakeTensorMode): + assert fake_mode.shape_env is not None + return { + k: v + for k, v in fake_mode.shape_env.var_to_range.items() + if free_unbacked_symbols(k) + } + + +@contextmanager +def patch_forward(obj: torch.nn.Module, new_method): + """Helper method to make it easier to cleanly torch.export() a method on a + module that is not `forward`. + """ + # Save the original method + original_method = obj.forward + + # Patch the method + obj.forward = new_method.__get__(obj, obj.__class__) + + try: + yield + finally: + # Restore the original method + obj.forward = original_method + + +@contextmanager +def _temp_disable_texpr_fuser(): + original_state = torch._C._jit_texpr_fuser_enabled() + torch._C._jit_set_texpr_fuser_enabled(False) + try: + yield + finally: + torch._C._jit_set_texpr_fuser_enabled(original_state) + + +def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): + with _temp_disable_texpr_fuser(): + from torch.jit._trace import TopLevelTracedModule + + export_args, export_kwargs = _process_jit_trace_inputs_for_export(args, kwargs) + + if isinstance(traced_callable, (TopLevelTracedModule, torch._C.ScriptModule)): # type: ignore[operator] + return _export( + traced_callable, + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + elif isinstance(traced_callable, torch.ScriptMethod) and isinstance( + traced_callable.owner(), # type: ignore[operator] + (torch._C.ScriptModule, torch.nn.Module), + ): + with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator] + return _export( + traced_callable.owner(), # type: ignore[operator] + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + else: + return _export( + _WrapperModule(traced_callable), + export_args, + export_kwargs, + strict=False, + _is_torch_jit_trace=True, + ).module() + + +def _strict_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any], + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]], + preserve_module_call_signature: tuple[str, ...], + orig_in_spec: TreeSpec, + allow_complex_guards_as_runtime_asserts: bool, + _is_torch_jit_trace: bool, + _to_aten_func: Callable, +) -> ExportArtifact: + """ + _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir` + """ + + gm_torch_level = _export_to_torch_ir( + mod, + args, + kwargs, + dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + restore_fqn=False, # don't need to restore because we will do it later + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + _log_export_usage=False, + ) + + # We detect the fake_mode by looking at gm_torch_level's placeholders, this is the fake_mode created in dynamo. + ( + fake_args, + fake_kwargs, + dynamo_fake_mode, + ) = _extract_fake_inputs(gm_torch_level, args, kwargs) + + fake_params_buffers = _fakify_params_buffers(dynamo_fake_mode, gm_torch_level) + + # First, we want to pass through the graph to try populating + # val field for getattr if there is anything missing. + # This can happen when quantization adds extra params and forgets + # to update "val" + for node in gm_torch_level.graph.nodes: + if node.op == "get_attr" and "val" not in node.meta: + attr = getattr(gm_torch_level, node.target) + # Checks if it is not a HigherOrderOp branch or a module + if not isinstance(attr, torch.nn.Module): + assert dynamo_fake_mode is not None, ( + "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + ) + node.meta["val"] = dynamo_fake_mode.from_tensor( + attr, static_shapes=True + ) + + # Fix the graph output signature to be tuple if scalar + + # gm_torch_level.graph._codegen is made a _PyTreeCodeGen in rewrite_signature in eval_frame.py + assert isinstance(gm_torch_level.graph._codegen, torch.fx.graph._PyTreeCodeGen) + + # Calling gm_torch_level._out_spec is not safe because gm_torch_level might be + # a _LazyGraphModule, which does not populate _out_spec when calling recompile(). + # TODO: Fix recompile() in _LazyGraphModule. T207713214 + out_spec = orig_out_spec = gm_torch_level.graph._codegen.pytree_info.out_spec + + # Used to get rid of lint type error. + assert out_spec is not None + assert orig_out_spec is not None + + # aot_export expect the return type to always be a tuple. + if out_spec.type not in (list, tuple): + out_spec = pytree.TreeSpec(tuple, None, [out_spec]) + + orig_arg_names = gm_torch_level.graph._codegen.pytree_info.orig_args # type: ignore[attr-defined] + + gm_torch_level.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + orig_arg_names, + gm_torch_level._in_spec, + out_spec, + ) + ) + gm_torch_level.recompile() + + _normalize_nn_module_stack(gm_torch_level, type(mod)) + + params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + # When aot_export lifts the params, we lose metadata (e.g. source_fn_stack, stack_trace) + # from the param nodes as they are treated as fresh inputs + # Therefore, we manually extract them before calling into aot_export + # params_buffers_to_node_meta = _collect_param_buffer_metadata(gm_torch_level) + + constant_attrs = _gather_constant_attrs(mod) + param_buffer_table: dict[str, str] = _get_param_buffer_mapping(mod, gm_torch_level) + + # Dynamo does not track which buffers were registered as non-persistent. This info + # is available in the original module, so we transfer it to the traced module. Also, + # since we didn't restore original param/buffer names yet, we must use traced names. + non_persistent_buffers = _get_non_persistent_buffers(mod) + reverse_name_lookup = {orig: traced for traced, orig in param_buffer_table.items()} + gm_torch_level._non_persistent_buffers_set = { + reverse_name_lookup[name] + for name in non_persistent_buffers + if name in reverse_name_lookup + } + + tx = TracingContext(dynamo_fake_mode) + with dynamo_fake_mode, tracing(tx): + aten_export_artifact = _to_aten_func( + gm_torch_level, + # NOTE: graph module expects only positional args + _convert_to_positional_args(orig_arg_names, fake_args, fake_kwargs), + {}, + fake_params_buffers, + constant_attrs, + ) + + # Decompose for readability. + gm = aten_export_artifact.gm + export_graph_signature = aten_export_artifact.sig + constants = aten_export_artifact.constants + + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, export_graph_signature + ) + + # Do some cleanups on the graph module to restore the state dict to the + # expected form. Each of these steps should probably get fixed upstream. + # 1. Remove tensor constants that were added as buffers. + _rewrite_dynamo_tensor_constants( + orig_mod_buffers=set(mod.buffers()), + traced_mod_buffers=dict(gm_torch_level.named_buffers()), + graph_signature=export_graph_signature, + constants=constants, + ) + # 2. Restore FQN of param/buffers + _replace_param_buffer_names(param_buffer_table, export_graph_signature) + + # 3. Move non-persistent buffers to tensor constants + _move_non_persistent_buffers_to_tensor_constants( + mod, export_graph_signature, constants + ) + + # 4. Rewrite constants to have the same FQN as the original module. + _remap_constants(constant_attrs, export_graph_signature, constants) + + # 5. Rename constants nodes in graph module from buffers to constants + _rename_constants_nodes(gm, export_graph_signature) + + return ExportArtifact( + aten=aten_export_artifact, + in_spec=orig_in_spec, + out_spec=orig_out_spec, + fake_mode=dynamo_fake_mode, + module_call_specs=gm_torch_level.meta["module_call_specs"], + ) + + +def _export_to_aten_ir_make_fx( + mod: torch.nn.Module, + fake_args, + fake_kwargs, + fake_params_buffers, + constant_attrs: ConstantAttrMap, + produce_guards_callback=None, + transform=lambda x: x, +) -> ATenExportArtifact: + def _make_fx_helper(mod, args, kwargs, **flags): + kwargs = kwargs or {} + + named_parameters = dict(mod.named_parameters(remove_duplicate=False)) + named_buffers = dict(mod.named_buffers(remove_duplicate=False)) + + params_and_buffers = {**named_parameters, **named_buffers} + params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers) + params_and_buffers_flat = tuple(params_and_buffers_flat) + + param_len = len(named_parameters) + buffer_len = len(named_buffers) + params_len = len(params_and_buffers) + + functional_call = create_functional_call( + mod, params_spec, params_len, store_orig_mod=True + ) + + params_buffers_args: list[Any] = [] + params_buffers_args.extend(params_and_buffers_flat) + params_buffers_args.extend(args) + + flat_fn, out_spec = create_tree_flattened_fn( + functional_call, params_buffers_args, kwargs + ) + flat_args, in_spec = pytree.tree_flatten((params_buffers_args, kwargs)) + + @functools.wraps(flat_fn) + def wrapped_fn(*args): + return tuple(flat_fn(*args)) + + with enable_python_dispatcher(): + ctx = nullcontext() + non_strict_root = getattr(mod, "_export_root", None) + if non_strict_root is not None: + ctx = _detect_attribute_assignment(non_strict_root) # type: ignore[assignment] + + # For any buffer that is assigned, we want to associate it to the final proxy node + # that it is assigned to. This node can then be copied into the buffer. + assigned_buffers: dict[str, str] = {} + hook = register_buffer_assignment_hook( + non_strict_root, assigned_buffers + ) + + def custom_getattribute(self, attr, *, original_getattr, attrs_to_proxy): + """ + The idea here is that we override subclass getattr methods to proxy + inner tensors and metadata. Because of infinite loop shenanigans, we have + to manually construct the getattr proxy nodes without relying on torch function + system. + """ + out = original_getattr(self, attr) + if attr in attrs_to_proxy: + if torch._C._is_torch_function_mode_enabled(): + if isinstance(out, torch.Tensor): + # When we get here there is no guarantee that we will hit the + # PreDispatchTorchFunctionMode, so we manually peak into the torch + # function mode list and tweak the PreDispatchTorchFunctionMode. + # This has side effect of proxying stuff like + # proxy.node.meta["val"] = extract_val(val) because at that time, torch function + # mode is still active. It seems bad to turn it off inside proxy_tensor.py, so + # I guess we will just rely on DCE for now to remove extra stuff like detach + torch_function_mode_stack = ( + torch.overrides._get_current_function_mode_stack() + ) + for mode in torch_function_mode_stack: + if isinstance(mode, PreDispatchTorchFunctionMode): + tracer = mode.tracer + proxy = get_proxy_slot(self, tracer).proxy + inner_proxy = tracer.create_proxy( + "call_function", + torch.ops.export.access_subclass_inner_tensor.default, + (proxy, attr), + {}, + ) + track_tensor_tree( + out, inner_proxy, constant=None, tracer=tracer + ) + return out + + @contextmanager + def override_getattribute_for_subclasses(args): + """ + Context manager that temporarily monkey patches + tensor.__getattribute__ so that we can intercept it at + torch_function layer. + """ + + # Dictionary that tracks subclass type to original getattr function + # and the attributes we can proxy. + tensor_type_to_old_getattribute: dict[ + type[torch.Tensor], tuple[Callable, set[str]] + ] = {} + for arg in args: + subclass_types_to_instances: dict[ + type[torch.Tensor], list[type[torch.Tensor]] + ] = get_subclass_typing_container(arg) + for subclass_type in subclass_types_to_instances: + if subclass_type not in tensor_type_to_old_getattribute: + assert len(subclass_types_to_instances[subclass_type]) > 0 + instance = subclass_types_to_instances[subclass_type][0] + # Query subclass specific attrs + attrs_to_proxy = set(dir(instance)) - set(dir(torch.Tensor)) + tensor_type_to_old_getattribute[subclass_type] = ( + subclass_type.__getattribute__, # type: ignore[attr-defined] + attrs_to_proxy, + ) + + try: + for k, ( + old_getattr, + attrs_to_proxy, + ) in tensor_type_to_old_getattribute.items(): + custom = functools.partialmethod( + custom_getattribute, + original_getattr=old_getattr, + attrs_to_proxy=attrs_to_proxy, + ) + k.__getattribute__ = custom # type: ignore[assignment, attr-defined] + yield + finally: + for k, (old_getattr, _) in tensor_type_to_old_getattribute.items(): + k.__getattribute__ = old_getattr # type: ignore[method-assign, attr-defined] + + with ctx, override_getattribute_for_subclasses(flat_args): + gm = make_fx( + wrapped_fn, + record_module_stack=True, + pre_dispatch=True, + )(*flat_args) + + if non_strict_root is not None: + input_names = _graph_input_names(gm) + buffer_input_names = { + name: input_names[param_len + i] + for i, (name, buf) in enumerate(non_strict_root._buffers.items()) + if buf is not None + } + output_node = list(gm.graph.nodes)[-1] + # We copy nodes corresponding to buffer assignments to buffers in the graph. + for buf, name in assigned_buffers.items(): # type: ignore[possibly-undefined] + buf_node = _find_node(gm, buffer_input_names[buf]) + name_node = _find_node(gm, name) + with gm.graph.inserting_before(output_node): + new_node = gm.graph.create_node( + "call_function", + torch.ops.aten.copy_.default, + args=(buf_node, name_node), + ) + new_node.meta = name_node.meta + + hook.remove() # type: ignore[possibly-undefined] + + def _is_impure(node): + if node.op == "call_function" and node.target in ( + # In export, we ignore any op that is related to + # eager mode profiling call. The expectation is + # that either runtimes provide their own profiling + # OR user wrap the compiled region on a profiling in + # later stage. + torch.ops.profiler._record_function_enter.default, + torch.ops.profiler._record_function_enter_new.default, + torch.ops.profiler._record_function_exit._RecordFunction, + # In theory, we could fix this dead detach and getattr nodes + # from subclass tensors if we carefully rewrite track_tensor_tree + # in a way that it doesn't do any tensor methods. + torch.ops.aten.detach.default, + torch.ops.export.access_subclass_inner_tensor.default, + ): + return False + return True + + gm.graph.eliminate_dead_code(_is_impure) + + # create graph signature + input_names = _graph_input_names(gm) + output_names = _graph_output_names(gm) + sig = GraphSignature( + parameters=list(named_parameters), + buffers=list(named_buffers), + user_inputs=input_names[params_len:], + user_outputs=output_names, + inputs_to_parameters=dict(zip(input_names[0:param_len], named_parameters)), + inputs_to_buffers=dict( + zip(input_names[param_len : param_len + buffer_len], named_buffers) + ), + buffers_to_mutate={}, + user_inputs_to_mutate={}, + in_spec=in_spec, + out_spec=out_spec, # type: ignore[arg-type] + backward_signature=None, + input_tokens=[], + output_tokens=[], + ) + return gm, sig + + # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, + # otherwise aot_export_module will error out because it sees a mix of fake_modes. + # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. + with ( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), + _ignore_backend_decomps(), + _compiling_state_context(), + ): + gm, graph_signature = transform(_make_fx_helper)( + mod, + fake_args, + trace_joint=False, + kwargs=fake_kwargs, + ) + + # [NOTE] In training IR, we don't run + # any DCE as a result we preserve constant + # nodes in the graph. make_fx invariant is that + # they don't guarantee every node gets a meta['val'] + # field. Since the actual value is already hardcoded in + # graph, the node.meta here actually doesn't matter. But + # we do this to make spec verifier happy. + for node in gm.graph.nodes: + if ( + node.op == "call_function" + and len(node.users) == 0 + and "val" not in node.meta + ): + node.meta["val"] = None + + if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"): + gm.meta.update(mod.meta) + + # See comment in _export_to_aten_ir() + if produce_guards_callback: + try: + produce_guards_callback(gm) + except (ConstraintViolationError, ValueRangeError) as e: + raise UserError(UserErrorType.CONSTRAINT_VIOLATION, str(e)) # noqa: B904 + + return _produce_aten_artifact( + gm=gm, + mod=mod, + constant_attrs=constant_attrs, + graph_signature=graph_signature, + pre_dispatch=True, + fake_args=fake_args, + fake_kwargs=fake_kwargs, + fake_params_buffers=fake_params_buffers, + ) + + +def set_missing_meta_vals(gm, flat_args, num_params_buffers): + # Sets missing metadata to address two problems: + # 1. aot_export adds symint metadata for placeholders with int values; since + # these become specialized, we replace such metadata with the original values. + # 2. any tensor attributes that are not params / buffers, i.e., are constants + # need to have their metadata set before lifting them because it is needed + # for computing the exported program's signature. + index = 0 + for node in gm.graph.nodes: + if node.op == "placeholder": + if index >= num_params_buffers: + user_arg = flat_args[index - num_params_buffers] + if not isinstance(user_arg, torch.Tensor): + node.meta["val"] = user_arg + index += 1 + + +def _find_node(gm: torch.fx.GraphModule, name: str) -> torch.fx.Node: + return next(iter(node for node in gm.graph.nodes if node.name == name)) + + +def _non_strict_export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: dict[str, Any], + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]], + preserve_module_call_signature: tuple[str, ...], + orig_in_spec: TreeSpec, + allow_complex_guards_as_runtime_asserts: bool, + _is_torch_jit_trace: bool, + _to_aten_func: Callable, +) -> ExportArtifact: + """ + _to_aten_func can either be `_export_to_aten_ir_make_fx` or `_export_to_aten_ir` + """ + + out_spec: Optional[TreeSpec] = None + in_spec: Optional[TreeSpec] = None + + module_call_specs: dict[str, dict[str, pytree.TreeSpec]] = {} + + def _tuplify_outputs(aot_export): + def _aot_export_non_strict(mod, args, kwargs=None, **flags): + kwargs = kwargs or {} + + class Wrapper(torch.nn.Module): + def __init__(self, mod): + super().__init__() + self._export_root = mod + + def forward(self, *args, **kwargs): + nonlocal out_spec + nonlocal in_spec + mod = self._export_root + _, in_spec = pytree.tree_flatten((args, kwargs)) + if isinstance(mod, torch.fx.GraphModule): + # NOTE: We're going to run this graph module with an fx interpreter, + # which will not run any forward hooks. Thus, ideally, we should run + # all forward hooks here. But the general logic for running them is + # complicated (see nn/module.py), and probably not worth duplicating. + # Instead we only look for, and run, an export-specific forward hook. + if ( + _check_input_constraints_pre_hook + in mod._forward_pre_hooks.values() + ): + _check_input_constraints_pre_hook(mod, args, kwargs) + with torch.fx.traceback.preserve_node_meta(): + args = (*args, *kwargs.values()) + tree_out = torch.fx.Interpreter(mod).run(*args) + else: + tree_out = mod(*args, **kwargs) + flat_outs, out_spec = pytree.tree_flatten(tree_out) + return tuple(flat_outs) + + wrapped_mod = Wrapper(mod) + # Patch export_root to the signatures so that wrapper module correctly populates the + # in/out spec + new_preserved_call_signatures = [ + "_export_root." + i for i in preserve_module_call_signature + ] + ctx = nullcontext() + if not isinstance(mod, torch.fx.GraphModule): + ctx = _wrap_submodules( # type: ignore[assignment] + wrapped_mod, new_preserved_call_signatures, module_call_specs + ) + with ctx: + gm, sig = aot_export(wrapped_mod, args, kwargs=kwargs, **flags) + log.debug("Exported program from AOTAutograd:\n%s", gm) + + sig.parameters = pytree.tree_map(_strip_root, sig.parameters) + sig.buffers = pytree.tree_map(_strip_root, sig.buffers) + sig.inputs_to_buffers = pytree.tree_map(_strip_root, sig.inputs_to_buffers) + sig.inputs_to_parameters = pytree.tree_map( + _strip_root, sig.inputs_to_parameters + ) + sig.buffers_to_mutate = pytree.tree_map(_strip_root, sig.buffers_to_mutate) + + for node in gm.graph.nodes: + if "nn_module_stack" in node.meta: + nn_module_stack = node.meta["nn_module_stack"] + node.meta["nn_module_stack"] = { + _fixup_key(key): val + for key, val in pytree.tree_map( + _strip_root, nn_module_stack + ).items() + } + + return gm, sig + + return _aot_export_non_strict + + ( + fake_mode, + fake_args, + fake_kwargs, + equalities_inputs, + original_signature, + dynamic_shapes, + ) = make_fake_inputs( + mod, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=_is_torch_jit_trace, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, # for shape env initialization + ) + + fake_params_buffers = _fakify_params_buffers(fake_mode, mod) + + def _produce_guards_callback(gm): + return produce_guards_and_solve_constraints( + fake_mode=fake_mode, + gm=gm, + dynamic_shapes=dynamic_shapes, + equalities_inputs=equalities_inputs, + original_signature=original_signature, + _is_torch_jit_trace=_is_torch_jit_trace, + ) + + tx = TracingContext(fake_mode) + + # We also need to attach dynamo configs as these will be used in HOOs that + # use torch.compile, like cond + dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG) + dynamo_config["do_not_emit_runtime_asserts"] = ( + False # We want to emit runtime asserts + ) + + with ( + fake_mode, + _NonStrictTorchFunctionHandler(), + tracing(tx), + torch._dynamo.config.patch(dynamo_config), + ): + with ( + _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( + patched_mod, + new_fake_args, + new_fake_kwargs, + new_fake_constant_attrs, + map_fake_to_real, + ), + _fakify_module_inputs(fake_args, fake_kwargs, fake_mode), + _override_builtin_ops(), + ): + aten_export_artifact = _to_aten_func( # type: ignore[operator] + patched_mod, + new_fake_args, + new_fake_kwargs, + fake_params_buffers, + new_fake_constant_attrs, + produce_guards_callback=_produce_guards_callback, + transform=_tuplify_outputs, + ) + # aten_export_artifact.constants contains only fake script objects, we need to map them back + aten_export_artifact.constants = { + fqn: map_fake_to_real[obj] if isinstance(obj, FakeScriptObject) else obj + for fqn, obj in aten_export_artifact.constants.items() + } + + _move_non_persistent_buffers_to_tensor_constants( + mod, aten_export_artifact.sig, aten_export_artifact.constants + ) + + assert out_spec is not None + assert in_spec is not None + + return ExportArtifact( + aten=aten_export_artifact, + in_spec=in_spec, + out_spec=out_spec, + fake_mode=fake_mode, + module_call_specs=module_call_specs, + ) + + +@_log_export_wrapper +@_disable_prexisiting_fake_mode +def _export_for_training( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + *, + strict: bool = True, + preserve_module_call_signature: tuple[str, ...] = (), +) -> ExportedProgram: + global _EXPORT_MODULE_HIERARCHY + _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) + + ( + args, + kwargs, + orig_in_spec, + dynamic_shapes, + verify_additional_inputs, + ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) + + original_state_dict = _get_original_state_dict(mod) + + # Call the appropriate export function based on the strictness of tracing. + export_func = _strict_export if strict else _non_strict_export + + export_artifact = export_func( + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + orig_in_spec=orig_in_spec, + allow_complex_guards_as_runtime_asserts=False, + _is_torch_jit_trace=False, + _to_aten_func=_export_to_aten_ir_make_fx, + ) + + export_graph_signature = export_artifact.aten.sig + + forward_arg_names = _get_forward_arg_names(mod, args, kwargs) + inline_constraints = _get_inline_constraints(export_artifact.fake_mode) + # The unbacked symint symbols are updated in aot_export + # so we serialize them here instead of inside dynamo. + # Note: _get_range_constraints depends on "inline_constraints" to be set. + export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints + range_constraints = _get_range_constraints( + mod, + export_artifact, + args, + kwargs, + dynamic_shapes, + ) + # The returned the gm is in-place modified + gm, module_call_graph = _get_module_call_graph( + export_artifact, + preserve_module_call_signature, + strict, + forward_arg_names, + ) + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + _verify_placeholder_names(gm, export_graph_signature) + + _update_gm_meta_if_possible(gm, mod) + + from torch._export.verifier import TrainingIRVerifier + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=export_graph_signature, + state_dict=original_state_dict, + range_constraints=range_constraints, + module_call_graph=module_call_graph, + example_inputs=(args, kwargs), + constants=export_artifact.aten.constants, + verifiers=[TrainingIRVerifier], + ) + + verify_additional_inputs(exported_program) + return exported_program + + +@_log_export_wrapper +@_disable_prexisiting_fake_mode +def _export( + mod: torch.nn.Module, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None, + *, + strict: bool = True, + preserve_module_call_signature: tuple[str, ...] = (), + pre_dispatch: bool = False, + allow_complex_guards_as_runtime_asserts: bool = False, + _is_torch_jit_trace: bool = False, +) -> ExportedProgram: + """ + Traces either an nn.Module's forward function or just a callable with PyTorch + operations inside and produce a ExportedProgram. + + Args: + mod: the `nn.Module` to trace. + + args: example positional inputs. + + kwargs: optional example keyword inputs. + + dynamic_shapes: + An optional argument where the type should either be: + 1) a dict from argument names of ``f`` to their dynamic shape specifications, + 2) a tuple that specifies dynamic shape specifications for each input in original order. + If you are specifying dynamism on keyword args, you will need to pass them in the order that + is defined in the original function signature. + + The dynamic shape of a tensor argument can be specified as either + (1) a dict from dynamic dimension indices to :func:`Dim` types, where it is + not required to include static dimension indices in this dict, but when they are, + they should be mapped to None; or (2) a tuple / list of :func:`Dim` types or None, + where the :func:`Dim` types correspond to dynamic dimensions, and static dimensions + are denoted by None. Arguments that are dicts or tuples / lists of tensors are + recursively specified by using mappings or sequences of contained specifications. + + preserve_module_call_signature: A list of submodule paths for which the original + calling conventions are preserved as metadata. + + allow_complex_guards_as_runtime_asserts: + With the current dynamic shapes language for dims and derived dims, we can run into constraints + that are not expressible with the language. For example, flattening a matrix and adding to a vector, + both fully dynamic (i.e. x.reshape([-1]) + y) emits a guard s0 * s1 = s2, which is not expressible. + By default, we either raise a constraint violation error or specialize to static values. + If this flag is set to True, we avoid erroring out and instead allow complex constraints to exist as runtime + assertions in the graph. The sympy interpreter (torch/utils/_sympy/interp.py) will produce the math ops + required to compute and assert the value of the guard (e.g. sym_size_int, eq, _assert_scalar). + Additionally, if TORCH_DYNAMO_DO_NOT_EMIT_RUNTIME_ASSERTS=1 is specified, we will allow complex constraints + while not emitting runtime asserts, returning a cleaner graph with lesser guarantees around dynamic shapes. + + Returns: + An ExportedProgram containing the traced module. + """ + + from torch._utils_internal import export_training_ir_rollout_check + + global _EXPORT_FLAGS, _EXPORT_MODULE_HIERARCHY + _EXPORT_MODULE_HIERARCHY = _get_module_hierarchy(mod) + + flags = set() + flags.add("strict" if strict else "non_strict") + flags.add("pre_dispatch" if pre_dispatch else "aot_dispatch") + _EXPORT_FLAGS = flags + + log_export_usage(event="export.enter", flags=_EXPORT_FLAGS) + + dtrace_structured("export", payload_fn=lambda: "start!") + + # NOTE Export training IR rollout + # Old export calls export._trace(pre_dispatch=True) + # and there are still lot of internal/OSS callsites that + # use export._trace(pre_dispatch=True) directly. Therefore, + # it makes more sense to do the switch here. + # export_training_ir_rollout_check returns True in OSS + # while internally it returns False UNLESS otherwise specified. + if pre_dispatch and export_training_ir_rollout_check(): + ep = _export_for_training( + mod, + args, + kwargs, + dynamic_shapes, + strict=strict, + preserve_module_call_signature=preserve_module_call_signature, + ) + dtrace_structured("exported_program", payload_fn=lambda: str(ep)) + return ep + + ( + args, + kwargs, + original_in_spec, + dynamic_shapes, + verify_additional_inputs, + ) = _process_export_inputs(mod, args, kwargs, dynamic_shapes) + + original_state_dict = _get_original_state_dict(mod) + + # Call the appropriate export function based on the strictness of tracing. + export_func = _strict_export if strict else _non_strict_export + + export_artifact = export_func( # type: ignore[operator] + mod=mod, + args=args, + kwargs=kwargs, + dynamic_shapes=dynamic_shapes, + preserve_module_call_signature=preserve_module_call_signature, + orig_in_spec=original_in_spec, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + _is_torch_jit_trace=_is_torch_jit_trace, + _to_aten_func=functools.partial( + _export_to_aten_ir, + pre_dispatch=pre_dispatch, + _is_torch_jit_trace=_is_torch_jit_trace, + ), + ) + export_graph_signature: ExportGraphSignature = export_artifact.aten.sig + + forward_arg_names = ( + _get_forward_arg_names(mod, args, kwargs) if not _is_torch_jit_trace else None + ) + inline_constraints = _get_inline_constraints(export_artifact.fake_mode) + # The unbacked symint symbols are updated in aot_export + # so we serialize them here instead of inside dynamo. + # Note: this step must be before _get_range_constraints. + export_artifact.aten.gm.meta["inline_constraints"] = inline_constraints + range_constraints = _get_range_constraints( + mod, + export_artifact, + args, + kwargs, + dynamic_shapes, + _is_torch_jit_trace=_is_torch_jit_trace, + ) + gm, module_call_graph = _get_module_call_graph( + export_artifact, + preserve_module_call_signature, + strict, + forward_arg_names, + ) + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + if not _is_torch_jit_trace: + _verify_placeholder_names(gm, export_graph_signature) + + # Remove Proxy because they cannot be deepcopied or pickled. + torch._export.utils.remove_proxy_from_state_dict(original_state_dict, in_place=True) + + from torch._export.verifier import Verifier + + _update_gm_meta_if_possible(gm, mod) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=export_graph_signature, + state_dict=original_state_dict, + range_constraints=range_constraints, + module_call_graph=module_call_graph, + example_inputs=(args, kwargs), + constants=export_artifact.aten.constants, + verifiers=[Verifier], + ) + + dtrace_structured("exported_program", payload_fn=lambda: str(exported_program)) + + verify_additional_inputs(exported_program) + return exported_program diff --git a/phivenv/Lib/site-packages/torch/export/_tree_utils.py b/phivenv/Lib/site-packages/torch/export/_tree_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ff9e441d5945667cb12252e15e5546a73d9e602 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_tree_utils.py @@ -0,0 +1,64 @@ +from typing import Any, Callable, Optional + +from torch.utils._pytree import Context, TreeSpec + + +def reorder_kwargs(user_kwargs: dict[str, Any], spec: TreeSpec) -> dict[str, Any]: + """Reorder user-provided kwargs to match the order in `spec`. `spec` is + expected to be the in_spec of an exported program, i.e. the spec that + results from flattening `(args, kwargs)`. + + We need this to provide consistent input ordering, such so that users can + pass in foo(a=a, b=b) OR foo(b=b, a=a) and receive the same result. + """ + # Make sure that the spec is actually shaped like (args, kwargs) + assert spec.type is tuple + assert spec.num_children == 2 + kwargs_spec = spec.children_specs[1] + assert kwargs_spec.type is dict + + if set(user_kwargs) != set(kwargs_spec.context): + raise ValueError( + f"Ran into a kwarg keyword mismatch: " + f"Got the following keywords {list(user_kwargs)} but expected {kwargs_spec.context}" + ) + + reordered_kwargs = {} + for kw in kwargs_spec.context: + reordered_kwargs[kw] = user_kwargs[kw] + + return reordered_kwargs + + +def is_equivalent( + spec1: TreeSpec, + spec2: TreeSpec, + equivalence_fn: Callable[[Optional[type], Context, Optional[type], Context], bool], +) -> bool: + """Customizable equivalence check for two TreeSpecs. + + Arguments: + spec1: The first TreeSpec to compare + spec2: The second TreeSpec to compare + equivalence_fn: A function to determine the equivalence of two + TreeSpecs by examining their types and contexts. It will be called like: + + equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context) + + This function will be applied recursively to all children. + + Returns: + True if the two TreeSpecs are equivalent, False otherwise. + """ + if not equivalence_fn(spec1.type, spec1.context, spec2.type, spec2.context): + return False + + # Recurse on children + if len(spec1.children_specs) != len(spec2.children_specs): + return False + + for child_spec1, child_spec2 in zip(spec1.children_specs, spec2.children_specs): + if not is_equivalent(child_spec1, child_spec2, equivalence_fn): + return False + + return True diff --git a/phivenv/Lib/site-packages/torch/export/_unlift.py b/phivenv/Lib/site-packages/torch/export/_unlift.py new file mode 100644 index 0000000000000000000000000000000000000000..8af7fe76ed5761bc340fc5271ada2d13893b099f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_unlift.py @@ -0,0 +1,481 @@ +# mypy: allow-untyped-defs +import copy +import warnings +from collections.abc import Sequence +from itertools import chain +from typing import Any, Optional + +import torch +import torch.utils._pytree as pytree +from torch._export.non_strict_utils import ( + _enter_enable_graph_inputs_of_type_nn_module, + _exit_enable_graph_inputs_of_type_nn_module, + _get_graph_inputs_of_type_nn_module, +) +from torch._export.utils import _check_input_constraints_for_graph +from torch.export.unflatten import _assign_attr, _AttrKind +from torch.fx.experimental.proxy_tensor import _pytree_subclasses_that_lose_info +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.traceback import NodeSource, NodeSourceAction + +from ._remove_effect_tokens_pass import _remove_effect_tokens +from ._tree_utils import reorder_kwargs +from .exported_program import ( + ExportedProgram, + ExportGraphSignature, + InputKind, + OutputKind, +) + + +def eq_spec(self: pytree.TreeSpec, other: pytree.TreeSpec) -> bool: + """ + Refinement of TreeSpec.__eq__ where, e.g., torch.Size(...) matches tuple(...). + See _pytree_subclasses_that_lose_info in proxy_tensor.py for more details. + """ + + def _normalize_type(t): + return str(_pytree_subclasses_that_lose_info.get(t, t)) + + def _match_normalized_structure(a, b): + if a is b: + return True + if _normalize_type(a.type) != _normalize_type(b.type): + return False + if a.context != b.context: + return False + if len(a.children_specs) != len(b.children_specs): + return False + return all( + _match_normalized_structure(a, b) + for a, b in zip(a.children_specs, b.children_specs) + ) + + return _match_normalized_structure(self, other) + + +def _check_inputs_match(args, kwargs, in_spec: pytree.TreeSpec) -> list: + reordered_kwargs = reorder_kwargs(kwargs, in_spec) + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, reordered_kwargs) + ) + + if not eq_spec(received_spec, in_spec): + raise ValueError( # noqa: B904 + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}.\n" + "Please check that the inputs have the same number and type of " + "args and kwargs as the ones you used when tracing." + ) + + return flat_args_with_path + + +@torch._dynamo.disable +def _check_input_constraints_pre_hook(self, args, kwargs): + if not self.validate_inputs: + return + + flat_args_with_path = _check_inputs_match(args, kwargs, self._in_spec) + + _check_input_constraints_for_graph( + [node for node in self.graph.nodes if node.op == "placeholder"], + flat_args_with_path, + self.range_constraints, + ) + + +def _unlift_inputs_as_getattr( + gm: torch.fx.GraphModule, + lifted_inputs: Sequence[Optional[str]], +) -> tuple[dict[str, torch.fx.Node], dict[str, torch.fx.Node]]: + """ + Unlift inputs referring to params/buffers/constants as getattr nodes in the + graph + """ + unlifted_name_to_node = {} + input_name_to_node = {} + + placeholder_nodes = [node for node in gm.graph.nodes if node.op == "placeholder"] + assert len(lifted_inputs) == len(placeholder_nodes) + for input_node, lifted_node in zip(placeholder_nodes, lifted_inputs): + if lifted_node is None: + input_name_to_node[input_node.name] = input_node + + else: + with gm.graph.inserting_after(input_node): + # It is fine to ignore this warning because + # it is guaranteed that we will populate this + # attr later. + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + getattr_node = gm.graph.get_attr(lifted_node) + input_node.replace_all_uses_with(getattr_node) + metadata = input_node.meta + gm.graph.erase_node(input_node) + getattr_node.meta = metadata + getattr_node.meta["from_node"] = [ + NodeSource( + input_node, + "ExportedProgram.module().unlift()", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + ] + unlifted_name_to_node[lifted_node] = getattr_node + + return unlifted_name_to_node, input_name_to_node + + +def _insert_copy_for_mutations( + gm: torch.fx.GraphModule, + mutated_outputs: Sequence[Optional[str]], + unlifted_name_to_node: dict[str, torch.fx.Node], + input_name_to_node: dict[str, torch.fx.Node], +) -> None: + """ + Find the all the buffers and inputs that were mutated and insert copy_ + operators to reflect mutations. + """ + output_node = None + for node in gm.graph.nodes: + if node.op == "output": + output_node = node + break + assert output_node is not None + outputs = pytree.tree_flatten(output_node.args)[0] + assert len(outputs) == len(mutated_outputs) + + user_output_nodes = [] + return_nodes_to_copy = {} + for return_node, mutated_node_name in zip(outputs, mutated_outputs): + if mutated_node_name is None: + user_output_nodes.append(return_node) + continue + + if mutated_node_name in unlifted_name_to_node: + mutated_node = unlifted_name_to_node[mutated_node_name] + elif mutated_node_name in input_name_to_node: + mutated_node = input_name_to_node[mutated_node_name] + else: + raise RuntimeError( + f"Could not find {mutated_node_name} in either buffer or input nodes" + ) + + with gm.graph.inserting_before(output_node): + copy_node = gm.graph.call_function( + torch.ops.aten.copy_.default, (mutated_node, return_node) + ) + return_nodes_to_copy[return_node] = copy_node + + output_args = [ + return_nodes_to_copy[node] if node in return_nodes_to_copy else node + for node in user_output_nodes + ] + with gm.graph.inserting_before(output_node): + # Only return user outputs + new_output = gm.graph.output(tuple(output_args)) + output_node.replace_all_uses_with(new_output) + gm.graph.erase_node(output_node) + new_output.name = output_node.name + new_output.meta.update(output_node.meta) + new_output.meta["from_node"] = [ + NodeSource( + output_node, + "ExportedProgram.module().unlift()", + [NodeSourceAction.CREATE, NodeSourceAction.REPLACE], + ) + ] + + +def _get_codegen( + in_spec: pytree.TreeSpec, + out_spec: Optional[pytree.TreeSpec], + forward_arg_names: Optional[list[str]] = None, +) -> _PyTreeCodeGen: + """ + Create the codegen for the graph module based on the in/out specs + """ + if forward_arg_names: + names = forward_arg_names + else: + if ( + in_spec.type == tuple + and in_spec.num_children == 2 + and in_spec.children_specs[0].type == tuple + and in_spec.children_specs[1].type == dict + ): + # if in_spec contains the args (tuple) and kwargs (dict) + names = [f"arg_{i}" for i in range(in_spec.children_specs[0].num_children)] + # add kwarg names + names.extend(in_spec.children_specs[1].context) + else: + names = [f"arg_{i}" for i in range(in_spec.num_children)] + + return _PyTreeCodeGen( + _PyTreeInfo( + names, + in_spec, + out_spec, + ) + ) + + +def _unlift( + gm: torch.fx.GraphModule, + lifted_inputs: Sequence[Optional[str]], + mutated_outputs: Sequence[Optional[str]], + in_spec: pytree.TreeSpec, + out_spec: Optional[pytree.TreeSpec], + state_dict: dict[str, Any], + constants: dict[str, Any], + forward_arg_names: Optional[list[str]] = None, +): + """ + Args: + lifted_inputs: A list matching the graph module's input nodes. For + an input node that is referring to a lifted parameter/buffer, this + list will contain the fqn the corresponding attribute. Otherwise, this + list will contain None. This is used to unlift the lifted parameters as + get_attr nodes. + + mutated_outputs: A list matching the graph module's output nodes. For + an output node that is referring to a mutated buffer or user input, this + list will contain the name of the corresponding buffer or user input + that needs to be mutated. Otherwise, this list will contain None. This + is used to re-insert an inplace copy_ operator to copy the mutated + values back to the original node. + """ + unlifted_name_to_node, input_name_to_node = _unlift_inputs_as_getattr( + gm, lifted_inputs + ) + _insert_copy_for_mutations( + gm, mutated_outputs, unlifted_name_to_node, input_name_to_node + ) + gm.graph._codegen = _get_codegen(in_spec, out_spec, forward_arg_names) + gm.graph.lint() + gm.recompile() + return gm + + +def _register_attrs_to_new_gm( + new_gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + state_dict: dict[str, Any], + constants: dict[str, Any], +) -> None: + non_persistent_buffers = set(graph_signature.non_persistent_buffers) + for name in graph_signature.buffers: + if name in non_persistent_buffers: + persistent = False + value = constants[name] + else: + persistent = True + value = state_dict[name] + _assign_attr( + value, new_gm, name, attr_kind=_AttrKind.BUFFER, persistent=persistent + ) + for name in graph_signature.parameters: + value = state_dict[name] + _assign_attr( + value, + new_gm, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + # Technically this doesn't account for the aliased multiple constants but + # it is ok because we have a separate pass later in the stack that populates + # the final gm. + for name in chain( + graph_signature.lifted_custom_objs, graph_signature.lifted_tensor_constants + ): + value = constants[name] + _assign_attr( + value, + new_gm, + name, + attr_kind=_AttrKind.CONSTANT, + ) + + +class _StatefulGraphModuleFactory(type): + """ + Metaclass that ensures a private constructor for _StatefulGraphModule + """ + + def __call__(cls, *args, **kwargs): + raise TypeError( + f"{cls.__module__}.{cls.__qualname__} has no public constructor. " + ) + + def _create(cls, root, graph, range_constraints=None): + return super().__call__( + root, + graph, + range_constraints=range_constraints, + ) + + +class _StatefulGraphModule(torch.fx.GraphModule, metaclass=_StatefulGraphModuleFactory): + def __init__(self, root, graph, range_constraints=None): + super().__init__(root, graph) + # Need to fix up non-persistent buffers. + self.range_constraints = range_constraints or [] + self.validate_inputs = True + + +def _create_stateful_graph_module( + plain_graph_module: torch.fx.GraphModule, + range_constraints, + ep: ExportedProgram, +) -> _StatefulGraphModule: + stateful_gm = _StatefulGraphModule._create( + plain_graph_module, + plain_graph_module.graph, + range_constraints=range_constraints, + ) + + module_types = _get_graph_inputs_of_type_nn_module(ep.example_inputs) + stateful_gm.register_forward_pre_hook( + lambda *args, **kwargs: _enter_enable_graph_inputs_of_type_nn_module( + module_types + ) + ) + stateful_gm.register_forward_pre_hook( + _check_input_constraints_pre_hook, with_kwargs=True + ) + + stateful_gm.register_forward_hook( + lambda *args, **kwargs: _exit_enable_graph_inputs_of_type_nn_module( + module_types + ), + always_call=True, + ) + + # When we have a constant that has requires_grad=True, we need to detach it + # when we unlift as the tensors that require gradients should be registered + # via parameters. But this is problematic when we have aliasing two constants + # because when we call detach, they will become different tensors. This dict + # keeps track of this logic. + original_tensor_to_detached_tensor = {} + + # Fix up lifted tensor constants. + # fx.GraphModule() constructor silently turns a constant attribute of plain_graph_module + # into a buffer in stateful_gm and creates an inconsistency with graph_signature. + # We fix this by de-registering these buffers in lifted_tensor_constants + # and call _assign_attr(attr_kind=CONSTANT) to register them as constants. + for constant_fqn in ep.graph_signature.lifted_tensor_constants: + # Sometimes, the constant can require gradient, this is probably a bug in user code, + # e.g. `self.const = torch.randn(2, 2, requires_grad=True)`. + # We call detach on the constant_val since they're tensor contants and we don't need to + # compute their gradients anyway. + # Users should properly register it as parameter if they want it to require gradient. + buffer = stateful_gm.get_buffer(constant_fqn) + if buffer.requires_grad: + warnings.warn( + f"A model attribute `{constant_fqn}` requires gradient. " + f"but it's not properly registered as a parameter. " + f"torch.export will detach it and treat it as a constant tensor " + f"but please register it as parameter instead." + ) + detached_buffer = buffer.detach() + original_tensor_to_detached_tensor[buffer] = detached_buffer + buffer = detached_buffer + *prefix, field = constant_fqn.rsplit(".") + submod = torch.fx.graph_module._get_attr_via_attr_list(stateful_gm, prefix) + delattr(submod, field) + _assign_attr(buffer, stateful_gm, constant_fqn, attr_kind=_AttrKind.CONSTANT) + + # Constants are not preserved well when we create a new GraphModule unlike param/buffers + for const_name, value in ep.constants.items(): + if not torch.fx.graph_module._has_attr(stateful_gm, const_name): + if isinstance(value, torch.Tensor): + if value.requires_grad: + warnings.warn( + f"A model attribute `{const_name}` requires gradient " + f"but it's not properly registered as a parameter. " + f"torch.export will detach it and treat it as a constant tensor " + f"but please register it as parameter instead." + ) + if value in original_tensor_to_detached_tensor: + value = original_tensor_to_detached_tensor[value] + else: + detached_value = value.detach() + original_tensor_to_detached_tensor[value] = detached_value + value = detached_value + _assign_attr( + value, + stateful_gm, + const_name, + attr_kind=_AttrKind.CONSTANT, + ) + + # Fix up non-persistent buffers. torch.fx does not distinguish between + # persistent and non-persistent buffers, so we must restore that distinction + # here. + for buffer in ep.graph_signature.non_persistent_buffers: + _assign_attr( + plain_graph_module.get_buffer(buffer), + stateful_gm, + buffer, + attr_kind=_AttrKind.BUFFER, + persistent=False, + ) + + return stateful_gm + + +def _unlift_exported_program_lifted_states(ep: ExportedProgram) -> torch.nn.Module: + # TODO T206340015 + if ep.verifiers[0].dialect != "TRAINING": + ep = _remove_effect_tokens(ep) + new_gm = torch.fx.GraphModule(ep.graph_module, copy.deepcopy(ep.graph)) + _register_attrs_to_new_gm(new_gm, ep.graph_signature, ep.state_dict, ep.constants) + forward_arg_names = ( + sig.forward_arg_names if (sig := ep.module_call_graph[0].signature) else None + ) + lifted_inputs: list[Optional[str]] = [ + ( + in_spec.target + if in_spec.kind + in ( + InputKind.BUFFER, + InputKind.CONSTANT_TENSOR, + InputKind.PARAMETER, + InputKind.CUSTOM_OBJ, + ) + else None + ) + for in_spec in ep.graph_signature.input_specs + ] + + mutated_outputs: list[Optional[str]] = [ + ( + out_spec.target + if out_spec.kind + in (OutputKind.BUFFER_MUTATION, OutputKind.USER_INPUT_MUTATION) + else None + ) + for out_spec in ep.graph_signature.output_specs + ] + + for node in new_gm.graph.nodes: + node.meta["from_node"] = [ + NodeSource(node, "ExportedProgram.module()", NodeSourceAction.CREATE) + ] + + new_gm = _unlift( + new_gm, + lifted_inputs, + mutated_outputs, + ep.call_spec.in_spec, + ep.call_spec.out_spec, + ep.state_dict, + ep.constants, + forward_arg_names=forward_arg_names, + ) + unlift_gm = _create_stateful_graph_module(new_gm, ep.range_constraints, ep) + unlift_gm.meta.update(ep.graph_module.meta) + return unlift_gm diff --git a/phivenv/Lib/site-packages/torch/export/_wrapper_utils.py b/phivenv/Lib/site-packages/torch/export/_wrapper_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ae7c49d8b860936956a076b41a6df063c353994b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/_wrapper_utils.py @@ -0,0 +1,10 @@ +import torch + + +class _WrapperModule(torch.nn.Module): + def __init__(self, f): # type: ignore[no-untyped-def] + super().__init__() + self.f = f + + def forward(self, *args, **kwargs): # type: ignore[no-untyped-def] + return self.f(*args, **kwargs) diff --git a/phivenv/Lib/site-packages/torch/export/custom_obj.py b/phivenv/Lib/site-packages/torch/export/custom_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..12b04215c31fb79af34511606600a856bc5ba6a8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/custom_obj.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + + +__all__ = ["ScriptObjectMeta"] + + +@dataclass +class ScriptObjectMeta: + """ + Metadata which is stored on nodes representing ScriptObjects. + """ + + # Key into constants table to retrieve the real ScriptObject. + constant_name: str + + class_fqn: str diff --git a/phivenv/Lib/site-packages/torch/export/custom_ops.py b/phivenv/Lib/site-packages/torch/export/custom_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..e5b1ac8346251cf6ffa8aaa6fb941867e9bcd9bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/custom_ops.py @@ -0,0 +1,26 @@ +import torch + + +lib = torch.library.Library("export", "FRAGMENT") # noqa: TOR901 + +lib.define( + "access_subclass_inner_tensor(Tensor src_subclass_tensor, str attr) -> Tensor" +) + + +@torch.library.impl(lib, "access_subclass_inner_tensor", "Autograd") +# When running under torch.inference_mode(), we seem to skip AUtograd key +# so we should desugar this op as soon as we start tracing to post-dispatch. +@torch.library.impl(lib, "access_subclass_inner_tensor", "Python") +def _access_subclass_inner_tensor( + src_subclass_tensor: torch.Tensor, attr: str +) -> torch.Tensor: + from torch.utils._python_dispatch import is_traceable_wrapper_subclass + + assert is_traceable_wrapper_subclass(src_subclass_tensor) + val = getattr(src_subclass_tensor, attr, None) + if val is None or not isinstance(val, torch.Tensor): + raise RuntimeError( + f"Attribute {attr} is not a tensor or doesn't exist in {src_subclass_tensor}" + ) + return val diff --git a/phivenv/Lib/site-packages/torch/export/decomp_utils.py b/phivenv/Lib/site-packages/torch/export/decomp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ed5625cb9b65254563fcd5a457c2e4063786efa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/decomp_utils.py @@ -0,0 +1,156 @@ +# mypy: allow-untyped-defs +from typing import Callable + +import torch +from torch._export.utils import ( + _collect_all_valid_cia_ops, + _collect_all_valid_cia_ops_for_aten_namespace, + _get_decomp_for_cia, + _is_aten_op, +) + + +__all__ = ["CustomDecompTable"] + + +""" +Core ATen ops with Composite Implicit Autograd dispatch that should be excluded from decomposition +by default. The decomposition logic should eventually exclude all core-tagged CIA ops, but until all +backends are ready, this list allows opt-in one at a time. +""" +PRESERVED_ATEN_CIA_OPS = { + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec, +} + + +class CustomDecompTable(dict[torch._ops.OperatorBase, Callable]): + """ + This is a custom dictionary that is specifically used for handling decomp_table in export. + The reason we need this is because in the new world, you can only *delete* an op from decomp + table to preserve it. This is problematic for custom ops because we don't know when the custom + op will actually be loaded to the dispatcher. As a result, we need to record the custom ops operations + until we really need to materialize it (which is when we run decomposition pass.) + + Invariants we hold are: + 1. All aten decomp is loaded at the init time + 2. We materialize ALL ops when user ever reads from the table to make it more likely + that dispatcher picks up the custom op. + 3. If it is write operation, we don't necessarily materialize + 4. We load the final time during export, right before calling run_decompositions() + + """ + + def __init__(self): + super().__init__() + from torch._decomp import _core_aten_decompositions_post_autograd + + # For aten ops, we load them up in the beginning + self.decomp_table = _core_aten_decompositions_post_autograd() + + for op in _collect_all_valid_cia_ops_for_aten_namespace(): + if op not in PRESERVED_ATEN_CIA_OPS: + self.decomp_table[op] = _get_decomp_for_cia(op) + + # This is to track the *pending* deleted custom ops that haven't been materialized yet + self.deleted_custom_ops = set() + # When this is true, there shouldn't be any pending operations in the table. + self.has_materialized = False + + def __getitem__(self, key): + self._materialize_if_needed() + return self.decomp_table.__getitem__(key) + + def __setitem__(self, key, value) -> None: + self.decomp_table.__setitem__(key, value) + + if key in self.deleted_custom_ops: + self.deleted_custom_ops.remove(key) + + def keys(self): + self._materialize_if_needed() + return self.decomp_table.keys() + + def __delitem__(self, key) -> None: + self.pop(key) + + def update(self, other_dict): # type: ignore[override] + for k, v in other_dict.items(): + self.decomp_table.__setitem__(k, v) + + def __missing__(self, key) -> bool: + return not self.__contains__(key) + + def __contains__(self, key) -> bool: + self._materialize_if_needed() + return self.decomp_table.__contains__(key) + + def __len__(self) -> int: + self._materialize_if_needed() + return self.decomp_table.__len__() + + def __iter__(self): + self._materialize_if_needed() + return self.decomp_table.__iter__() + + def __reversed__(self): + self._materialize_if_needed() + return self.decomp_table.__reversed__() + + def copy(self) -> "CustomDecompTable": + new_dict = CustomDecompTable() + new_dict.decomp_table = self.decomp_table.copy() + new_dict.deleted_custom_ops = self.deleted_custom_ops.copy() + new_dict.has_materialized = self.has_materialized + return new_dict + + def pop(self, *args): + def _pop_if_can(key): + if _is_aten_op(key): + return self.decomp_table.pop(key) + + if key in self.decomp_table: + # Even if we materialized it, we should add it to the deleted + # custom ops list so that when we materialize next time, + # we should respect user's intention. + self.deleted_custom_ops.add(key) + return self.decomp_table.pop(key) + + if key in self.deleted_custom_ops: + raise KeyError(f"{key} doesn't exist in the table") + + self.deleted_custom_ops.add(key) + # We would come here when user pops off something that is + # not in the table. In this case, we just pretend that it + # was in the table. + return _get_decomp_for_cia(key) + + if len(args) == 1: + return _pop_if_can(args[0]) + + if len(args) == 2: + try: + return _pop_if_can(args[0]) + except KeyError: + return args[1] + + def items(self): + self._materialize_if_needed() + return self.decomp_table.items() + + def materialize(self) -> dict[torch._ops.OperatorBase, Callable]: + for op in _collect_all_valid_cia_ops(): + if _is_aten_op(op): + continue + elif op in self.decomp_table: + continue + elif op not in self.deleted_custom_ops: + self.decomp_table[op] = _get_decomp_for_cia(op) + + self.has_materialized = True + self.deleted_custom_ops = set() + return {**self.decomp_table} + + def _materialize_if_needed(self) -> None: + if not self.has_materialized: + self.materialize() diff --git a/phivenv/Lib/site-packages/torch/export/dynamic_shapes.py b/phivenv/Lib/site-packages/torch/export/dynamic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..6f5457ab794b14daa7b1b36ab0ee71c3d7306b86 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/dynamic_shapes.py @@ -0,0 +1,1355 @@ +# mypy: allow-untyped-defs +import dataclasses +import inspect +import logging +import sys +from collections import defaultdict +from enum import auto, Enum +from typing import Any, Callable, Optional, TYPE_CHECKING, Union + +import torch +from torch.utils._pytree import ( + _get_node_type, + BUILTIN_TYPES, + keystr, + LeafSpec, + MappingKey, + SequenceKey, + SUPPORTED_NODES, + tree_flatten, + tree_map, + tree_map_with_path, +) + +from .exported_program import ExportedProgram + + +if TYPE_CHECKING: + from sympy import Symbol + + from torch._guards import Source + from torch.fx.experimental.symbolic_shapes import ShapeEnv, StrictMinMaxConstraint + +__all__ = [ + "Constraint", + "Dim", + "dims", + "refine_dynamic_shapes_from_suggested_fixes", + "AdditionalInputs", +] + + +log = logging.getLogger(__name__) + + +class _DimHintType(Enum): + """ + Enum for dynamic shape hints. + - AUTO means automatic inference of shape (static or dynamic). + - STATIC means static shape (always specialized). + - DYNAMIC means dynamic, will error out if specialized. + """ + + AUTO = auto() + STATIC = auto() + DYNAMIC = auto() + + +@dataclasses.dataclass +class _DimHint: + type: _DimHintType + min: Optional[int] = None + max: Optional[int] = None + _factory: Optional[bool] = True + + @staticmethod + def AUTO(): + return _DimHint(_DimHintType.AUTO) + + @staticmethod + def DYNAMIC(): + return _DimHint(_DimHintType.DYNAMIC) + + @staticmethod + def STATIC(): + return _DimHint(_DimHintType.STATIC) + + def __call__(self, min=None, max=None) -> "_DimHint": + if not self._factory: + raise TypeError(f"'{type(self)}' object is not callable") + assert min is None or min >= 0, "min must be non-negative" + assert max is None or max >= 0, "max must be non-negative" + assert min is None or max is None or min <= max, "min must be <= max" + return _DimHint(self.type, min=min, max=max, _factory=False) + + +class Dim: + """ + The `Dim` class allows users to specify dynamism in their exported programs. By marking a dimension with a `Dim`, + the compiler associates the dimension with a symbolic integer containing a dynamic range. + + The API can be used in 2 ways: Dim hints (i.e. automatic dynamic shapes: `Dim.AUTO`, `Dim.DYNAMIC`, `Dim.STATIC`), + or named Dims (i.e. `Dim("name", min=1, max=2)`). + + Dim hints provide the lowest barrier to exportability, with the user only needing to specify if a dimension + if dynamic, static, or left for the compiler to decide (`Dim.AUTO`). The export process will automatically + infer the remaining constraints on min/max ranges and relationships between dimensions. + + Example:: + + class Foo(nn.Module): + def forward(self, x, y): + assert x.shape[0] == 4 + assert y.shape[0] >= 16 + return x @ y + + + x = torch.randn(4, 8) + y = torch.randn(8, 16) + dynamic_shapes = { + "x": {0: Dim.AUTO, 1: Dim.AUTO}, + "y": {0: Dim.AUTO, 1: Dim.AUTO}, + } + ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) + + Here, export would raise an exception if we replaced all uses of `Dim.AUTO` with `Dim.DYNAMIC`, + as x.shape[0] is constrained to be static by the model. + + More complex relations between dimensions may also be codegened as runtime assertion nodes by the compiler, + e.g. (x.shape[0] + y.shape[1]) % 4 == 0, to be raised if runtime inputs do not satisfy such constraints. + + You may also specify min-max bounds for Dim hints, e.g. `Dim.AUTO(min=16, max=32)`, `Dim.DYNAMIC(max=64)`, + with the compiler inferring the remaining constraints within the ranges. An exception will be raised if + the valid range is entirely outside the user-specified range. + + Named Dims provide a stricter way of specifying dynamism, where exceptions are raised if the compiler + infers constraints that do not match the user specification. For example, exporting the previous + model, the user would need the following `dynamic_shapes` argument:: + + s0 = Dim("s0") + s1 = Dim("s1", min=16) + dynamic_shapes = { + "x": {0: 4, 1: s0}, + "y": {0: s0, 1: s1}, + } + ep = torch.export(Foo(), (x, y), dynamic_shapes=dynamic_shapes) + + Named Dims also allow specification of relationships between dimensions, up to univariate linear relations. + For example, the following indicates one dimension is a multiple of another plus 4:: + + s0 = Dim("s0") + s1 = 3 * s0 + 4 + + """ + + AUTO = _DimHint.AUTO() + DYNAMIC = _DimHint.DYNAMIC() + STATIC = _DimHint.STATIC() + + def __init__( + self, name: str, *, min: Optional[int] = None, max: Optional[int] = None + ): + from torch.utils._sympy.numbers import int_oo + + _min = 0 if min is None else min + _max = int_oo if max is None else max + assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}" + assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}" + self.__name__ = name + self.min = _min + self.max = _max + + def __add__(self, other) -> "Dim": + # e.g., dim + 1 + if type(other) is not int: + raise NotImplementedError( + f"Attempted to add {other} to {self.__name__}, where an integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return self._derive(lambda x: x + other) + + def __radd__(self, other) -> "Dim": + return self + other + + def __sub__(self, other) -> "Dim": + # e.g., dim - 1 + if type(other) is not int: + raise NotImplementedError( + f"Attempted to subtract {other} from {self.__name__}, where an integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return self._derive(lambda x: x - other) + + def __rsub__(self, other) -> "Dim": + raise NotImplementedError( + f"Attempted to negate {self.__name__}. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + + def __mul__(self, other) -> "Dim": + # e.g., dim * 2 + if type(other) is not int or other <= 0: + raise NotImplementedError( + f"Attempted to multiply {other} with {self.__name__}, where a positive integer was expected. " + "(Only increasing linear operations with integer coefficients are supported.)" + ) + return self._derive(lambda x: x * other) + + def __rmul__(self, other) -> "Dim": + return self * other + + def _derived_name(self, fn) -> str: + from sympy import sympify + + return str(fn(sympify(self.__name__))) + + def _derive(self, fn) -> "Dim": + return _DerivedDim(self._derived_name(fn), self, fn) + + @staticmethod + def _readable(name: str, min_: int, max_: int) -> str: + from torch.utils._sympy.numbers import int_oo + + if min_ == 2: + min_ = None # type: ignore[assignment] + if max_ == int_oo: + max_ = None # type: ignore[assignment] + if min_ is None and max_ is None: + return f"Dim('{name}')" + if min_ is None: + return f"Dim('{name}', max={max_})" + if max_ is None: + return f"Dim('{name}', min={min_})" + return f"Dim('{name}', min={min_}, max={max_})" + + def __repr__(self): + return Dim._readable(self.__name__, self.min, self.max) + + +_Dim = Dim # TODO(pianpwk): remove after it's no longer internally breaking + + +class _StaticDim(Dim): + """ + Class for static :func:`Dim` types. + + This class is only for setting and checking static dim constraints, + and the user should never interact with it. + """ + + def __init__(self, value: int): + self.__name__ = str(value) + self.value = value + + @property + def min(self): # type: ignore[override] + return self.value # type: ignore[attr-defined] + + @property + def max(self): # type: ignore[override] + return self.value # type: ignore[attr-defined] + + +class _DerivedDim(Dim): + """ + Class for derived :func:`Dim` types. + + Currently we only support increasing linear expressions with integer coefficients. + In other words, a derived Dim can always be written in the form Ax + B, where + x is a regular Dim (i.e., non-derived Dim), A and B are integers, and A is positive. + (In particular, the latter ensures that x < y => Ax + B < Ay + B.) + These restrictions on the form of derived Dims makes the metatheory simpler: e.g., + it simplifies computing ranges for derived Dims, solving for underlying regular Dims, + deciding equalities between derived Dims, and so on. + + The function lambda x: Ax + B is expressed by `fn`, where x is a normal Dim, `root`. + The range of a derived Dim is computed by mapping `fn` over the range of its `root`. + """ + + def __init__(self, name: str, root: Dim, fn: Callable): + self.__name__ = name + self.root = root + self.fn = fn + + @property + def min(self): # type: ignore[override] + # assume that self.fn is an increasing function + # TODO(avik): use sympy value range analysis instead? + from sympy import Integer + + from torch.utils._sympy.numbers import int_oo + + if self.root.min is -int_oo: # type: ignore[attr-defined] + return -int_oo # fn not needed cuz increasing + + _min_symint = self.fn(Integer(self.root.min)) # type: ignore[attr-defined] + root = self.root # type: ignore[attr-defined] + assert _min_symint >= 0, ( + f"Expected derived min value of {self.__name__} to be >= 0. " + f"Please specify an appropriate min value for {root.__name__} " + f"(currently {root.min})." + ) + return int(_min_symint) + + @property + def max(self): # type: ignore[override] + # assume that self.fn is an increasing function + # TODO(avik): use sympy value range analysis instead? + from sympy import Integer + + from torch.utils._sympy.numbers import int_oo + + if self.root.max is int_oo: # type: ignore[attr-defined] + return int_oo # fn not needed cuz increasing + + _max_symint = self.fn(Integer(self.root.max)) # type: ignore[attr-defined] + root = self.root # type: ignore[attr-defined] + assert _max_symint <= sys.maxsize - 1, ( + f"Expected derived max value of {self.__name__} to be <= {sys.maxsize - 1}. " + f"Please specify an appropriate max value for {root.__name__} " + f"(currently {root.max})." + ) + return int(_max_symint) + + def _derive(self, fn): + # We support nesting, e.g., 2*dim + 1. + # This is implemented by composing operations on the same root. + # As a consequence, roots are always regular Dims (i.e., not derived Dims). + return _DerivedDim( + self._derived_name(fn), + self.root, + lambda x: fn(self.fn(x)), + ) + + def __repr__(self): + return self.__name__ + + +def dims( + *names: str, min: Optional[int] = None, max: Optional[int] = None +) -> tuple[Dim, ...]: + """ + Util to create multiple :func:`Dim` types. + + Returns: + A tuple of :func:`Dim` types. + """ + return tuple(Dim(name, min=min, max=max) for name in names) # type: ignore[misc] + + +@dataclasses.dataclass +class _ConstraintTarget: + """ + This represents input tensor dimensions. + """ + + t_id: int + dim: int + + +@dataclasses.dataclass +class _Constraint(_ConstraintTarget): + """ + This represents a Dim describing a constraint target. + + `name` is the name of the Dim. + `constraint_range` contains the min/max bounds of the Dim. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + + def _clone_with_range(self, lower=0, upper=None): + # Import sympy locally + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.numbers import int_oo + from torch.utils._sympy.value_ranges import ValueRanges + + if upper is None: + upper = int_oo + + constraint_range = StrictMinMaxConstraint( + vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), + warn_only=False, + ) + return _Constraint( + self.t_id, + self.dim, + self.name, + constraint_range, + ) + + def __ge__(self, lower): + return self._clone_with_range(lower=lower) + + def __gt__(self, lower): + return self._clone_with_range(lower=lower + 1) + + def __le__(self, upper): + return self._clone_with_range(upper=upper) + + def __lt__(self, upper): + return self._clone_with_range(upper=upper - 1) + + def __bool__(self): + # NOTE(avik): We do not support compound expressions like a <= x <= b. + # This is because Python implicitly desugars them into bool(a <= x) and bool(x <= b), + # and moreover, enforces that any overload of __bool__ must return True or False. + # FWIW, sympy also raises TypeError in this case. + raise TypeError( + "Cannot determine truth value of _Constraint. " + "If you are trying to combine _Constraint's with logical connectives, " + "you can specify them separately instead." + ) + + @property + def serializable_spec(self): + # We need a serialization compatible format of the constraint so that it + # can be savedin the graph module w/o breaking the module serialization. + # The saved constraints will be used directly for the post-exporting pass + # that converts constraints to runtime assertion. The saved constraints + # will not be saved in the serialized module. + # TODO: A better way is needed. Currently we use 't_id' to map the constraint, + # which is not reliable + return { + "t_id": self.t_id, + "dim": self.dim, + "min": self.constraint_range.vr.lower, + "max": self.constraint_range.vr.upper, + } + + +@dataclasses.dataclass +class _PhantomRoot: + """ + This represents the root of a derived Dim where the root does not directly + specify the shape of any input dimension, but the derived Dim does. + + e.g., the input shapes 2*dim and dim + 1 are related via a "phantom" dim. + + The fields `name`, `constraint_range`, and `val` carried by a phantom root + help create a symbol for it. Any derived dims with this phantom root are + backed by expressions over this symbol. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + val: int + + +@dataclasses.dataclass +class _DerivedConstraint(_ConstraintTarget): + """ + This represents a derived Dim, whose root is either a regular constraint target + (which directly specifies the shape of some input dimension) or a phantom root + (which does so indirectly). + + It can be thought of as a subclass of `_Constraint`, except that it does not + support <, <=, >, >= operations. + """ + + name: str + constraint_range: "StrictMinMaxConstraint" + root: Union[_ConstraintTarget, _PhantomRoot] + fn: Callable + + @property + def serializable_spec(self): + # same as _Constraint.serializable_spec + return { + "t_id": self.t_id, + "dim": self.dim, + "min": self.constraint_range.vr.lower, + "max": self.constraint_range.vr.upper, + } + + +@dataclasses.dataclass +class _RelaxedConstraint(_ConstraintTarget): + """ + This represents a dim marked with Dim.AUTO/DYNAMIC (i.e. mark_dynamic() or maybe_mark_dynamic()), + which leaves relations & min/max ranges for inference, instead of requiring explicit specification. + The intention is for constraint violations to not be raised if produce_guards() finds equalities or + relations between a _RelaxedConstraint and another type of _Constraint. + """ + + @property + def serializable_spec(self): + return { + "t_id": self.t_id, + "dim": self.dim, + } + + +Constraint = Union[_Constraint, _DerivedConstraint, _RelaxedConstraint] + + +@dataclasses.dataclass +class _IntWrapper: + """ + Dummy wrapper class to wrap around integer inputs so that when we parse the + dynamic_shapes structure, we can mark if any of the integers were marked as + dynamic. + """ + + val: int + # Disallow specifying dynamism + dynamism: Optional[Union[_DimHint, int]] = dataclasses.field( + init=False, default=None + ) + + +def _process_equalities( + constraint: Constraint, + get_sources: Callable[[int, int], list["Source"]], + shape_env: "ShapeEnv", + names: dict[str, tuple[int, int]], + source_pairs: list[tuple["Source", "Source"]], + derived_equalities: list[tuple["Source", Union["Source", "Symbol"], Callable]], + phantom_symbols: dict[str, "Symbol"], + relaxed_sources: set["Source"], +): + """ + Updates `source_pairs`, `derived_equalities`, and `phantom_symbols` (which become + fields of `EqualityConstraint`) based on a given input `constraint`. + """ + + sources = get_sources(constraint.t_id, constraint.dim) + if not sources: # empty sources due to unused shapes + return + + source, *other_sources = sources + # When t.size()[dim] maps to src0, src1, ..., srcN, we add + # constraints that make src0 "equal" to src1, ..., srcN. + source_pairs.extend((source, other_source) for other_source in other_sources) + if isinstance(constraint, _Constraint): + if constraint.name in names: + shared_t_id, shared_dim = names[constraint.name] + other_sources = get_sources(shared_t_id, shared_dim) + source_pairs.extend( + (source, other_source) for other_source in other_sources + ) + else: + names[constraint.name] = (constraint.t_id, constraint.dim) + elif isinstance(constraint, _DerivedConstraint): + # branch based on the root of the _DerivedConstraint + if not isinstance(constraint.root, _PhantomRoot): + # either root points to an input source + root = get_sources(constraint.root.t_id, constraint.root.dim)[0] + else: + # or root points to a phantom symbol + if constraint.root.name in phantom_symbols: + root = phantom_symbols[constraint.root.name] + else: + # create a phantom symbol in the shape env based on the _PhantomRoot + root = shape_env.create_symbol( + val=constraint.root.val, + source=torch._dynamo.source.ConstantSource(constraint.root.name), + dynamic_dim=torch.fx.experimental.symbolic_shapes.DimDynamic.DYNAMIC, + constraint_dim=constraint.root.constraint_range, + ) + phantom_symbols[constraint.root.name] = root + + fn = constraint.fn + # A derived equality (source, root, fn) informally corresponds to source = fn(root). + # Here source describes an input and root might describe another input or a phantom symbol. + derived_equalities.append((source, root, fn)) + elif isinstance(constraint, _RelaxedConstraint): + relaxed_sources.add(source) + + +def _tree_map_with_path( + func: Callable[..., Any], + tree: Any, + *dynamic_shapes: Any, + tree_name: Optional[str] = None, +) -> Any: + """ + Customized tree_map for mapping pytrees to dynamic_shapes. + + For built-in types (e.g., standard collections) this behaves exactly like tree_map. + + OTOH for a user-defined class C registered with pytree, we cannot assume that a C + containing tensors can be mapped to a C containing dynamic shapes (i.e., C may not + be a polymorphic container). In that case we use the flattened form of C instead. + Thus a C(**tensors) that flattens to (**tensors) will map to (**dynamic_shapes). + + Args: + func: function to apply to each (int, float, str, bool, None, torch.Tensor) + tree: input pytree + dynamic_shapes: zero or more (typically one) dynamic_shapes to match + + Returns: + output pytree mapping func to each (int, float, str, bool, None, torch.Tensor) + """ + + def is_leaf(t): + # BUILTIN_TYPES is a subset of SUPPORTED_NODES, the latter being all types + # registered with pytree. Types *not* in BUILTIN_TYPES include primitive types + # (int, float, str, bool, None, torch.Tensor), which are not in SUPPORTED_NODES, + # as well as user-defined classes registered with pytree, which are. + return _get_node_type(t) not in BUILTIN_TYPES + + def f(path, t, *dynamic_shapes): + typ = _get_node_type(t) + # typ is not in BUILTIN_TYPES + if typ in SUPPORTED_NODES: + # thus typ is a user-defined class registered with pytree, + # in which case flatten and recurse + return tree_map_with_path( + f, + SUPPORTED_NODES[typ].flatten_fn(t)[0], + *dynamic_shapes, + is_leaf=is_leaf, + ) + else: + return func(path, t, *dynamic_shapes) + + try: + return tree_map_with_path(f, tree, *dynamic_shapes, is_leaf=is_leaf) + except ValueError as e: + if "mismatch" in e.args[0]: + # When PyTree finds a structural mismatch between tree and dynamic_shapes, + # the error message is unfortunately quite horrible. Let's fix that. + assert dynamic_shapes, "Cannot be a mismatch if there is no dynamic_shapes" + assert tree_name, "Must provide a tree_name when there might be a mismatch" + + def _key(type_, context, i): + # derive a PyTree key given the type, context, and child # of a TreeSpec + if type_ is dict: + return MappingKey(context[i]) + if type_ in (list, tuple): + assert context is None + return SequenceKey(i) + raise AssertionError(f"Did not expect type {type_}") + + def raise_mismatch_error(msg): + from torch._dynamo.exc import UserError, UserErrorType + + raise UserError( + UserErrorType.INVALID_INPUT, + f"Detected mismatch between the structure of `{tree_name}` and `dynamic_shapes`: {msg}", + case_name="dynamic_shapes_validation", + ) + + def _compare(tree, dynamic_shapes, path): + # raise an error at the point where tree and dynamic_shapes differ, + # including the path to that point and the reason for the difference + rendered_path = keystr(path) + if isinstance(tree, LeafSpec): + return + if isinstance(dynamic_shapes, LeafSpec): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` is a {tree.type}, " + f"but `dynamic_shapes{rendered_path}` is not" + ) + if tree.type != dynamic_shapes.type: + raise_mismatch_error( + f"`{tree_name}{rendered_path}` is a {tree.type}, " + f"but `dynamic_shapes{rendered_path}` is a {dynamic_shapes.type}" + ) + if len(tree.children_specs) != len(dynamic_shapes.children_specs): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` has {len(tree.children_specs)} elements, " + f"but `dynamic_shapes{rendered_path}` has {len(dynamic_shapes.children_specs)} elements" + ) + if tree.type is dict: + # context, children could be out of order + if sorted(tree.context) != sorted(dynamic_shapes.context): + raise_mismatch_error( + f"`{tree_name}{rendered_path}` has keys {tree.context}, " + f"but `dynamic_shapes{rendered_path}` has keys {dynamic_shapes.context}" + ) + _remap = dict( + zip(dynamic_shapes.context, dynamic_shapes.children_specs) + ) + dynamic_shapes_children_specs = [_remap[k] for k in tree.context] + else: + dynamic_shapes_children_specs = dynamic_shapes.children_specs + for i, (tree_, dynamic_shapes_) in enumerate( + zip(tree.children_specs, dynamic_shapes_children_specs) + ): + _compare( + tree_, + dynamic_shapes_, + path + [_key(tree.type, tree.context, i)], + ) + + _, tree_spec = tree_flatten(tree, is_leaf=is_leaf) + for other_tree in dynamic_shapes: + _, other_tree_spec = tree_flatten(other_tree, is_leaf) + _compare(tree_spec, other_tree_spec, []) + raise + + +def _combine_args(f, args, kwargs, _is_torch_jit_trace=False) -> dict[str, Any]: + # combine args and kwargs following the signature of f, as it happens + # in the body of f when called with *args, **kwargs + if isinstance(f, ExportedProgram): + f = f.module() + if not _is_torch_jit_trace: + signature = ( + inspect.signature(f.forward) + if isinstance(f, torch.nn.Module) + else inspect.signature(f) + ) + kwargs = kwargs if kwargs is not None else {} + return signature.bind(*args, **kwargs).arguments + return args + + +class ShapesCollection: + """ + Builder for dynamic_shapes. + Used to assign dynamic shape specifications to tensors that appear in inputs. + + This is useful particularly when :func:`args` is a nested input structure, and it's + easier to index the input tensors, than to replicate the structure of :func:`args` in + the :func:`dynamic_shapes` specification. + + Example:: + + args = {"x": tensor_x, "others": [tensor_y, tensor_z]} + + dim = torch.export.Dim(...) + dynamic_shapes = torch.export.ShapesCollection() + dynamic_shapes[tensor_x] = (dim, dim + 1, 8) + dynamic_shapes[tensor_y] = {0: dim * 2} + # This is equivalent to the following (now auto-generated): + # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [{0: dim * 2}, None]} + + torch.export(..., args, dynamic_shapes=dynamic_shapes) + + To specify dynamism for integers, we need to first wrap the integers using + _IntWrapper so that we have a "unique identification tag" for each integer. + + Example:: + + args = {"x": tensor_x, "others": [int_x, int_y]} + # Wrap all ints with _IntWrapper + mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args) + + dynamic_shapes = torch.export.ShapesCollection() + dynamic_shapes[tensor_x] = (dim, dim + 1, 8) + dynamic_shapes[mapped_args["others"][0]] = Dim.DYNAMIC + + # This is equivalent to the following (now auto-generated): + # dynamic_shapes = {"x": (dim, dim + 1, 8), "others": [Dim.DYNAMIC, None]} + + torch.export(..., args, dynamic_shapes=dynamic_shapes) + """ + + def __init__(self): + self._shapes = {} + + def __setitem__(self, t, shape): + assert isinstance(t, (torch.Tensor, _IntWrapper)), ( + f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}" + ) + + # TODO(avik): check that shape is indeed a Shape + + t_id = id(t) + if t_id in self._shapes: + _shape = self._shapes[t_id] + assert shape == _shape, ( + f"Shapes assigned to input do not match: expected {_shape}, got {shape}" + ) + else: + self._shapes[id(t)] = shape + + def __getitem__(self, t): + t_id = id(t) + if t_id not in self._shapes: + self._shapes[t_id] = {} + return self._shapes[t_id] + + def __len__(self): + return len(self._shapes) + + def dynamic_shapes(self, m, args, kwargs=None): + """ + Generates the :func:`dynamic_shapes` pytree structure according to :func:`args` and :func:`kwargs`. + """ + + t_ids = set() + + def find_shape(path, t): + t_id = id(t) + if t_id in self._shapes: + t_ids.add(t_id) + return self._shapes[t_id] + else: + return None + + combined_args = _combine_args(m, args, kwargs) + dynamic_shapes = _tree_map_with_path(find_shape, combined_args) + if any(t_id not in t_ids for t_id in self._shapes): + raise ValueError( + "Some tensors that were assigned shapes were not found in args. " + "Maybe such tensors were copied when passing them as args? " + "Maybe such tensors are contained in classes that were not registered with pytree?" + ) + return dynamic_shapes + + +class AdditionalInputs: + """ + Infers dynamic_shapes based on additional inputs. + + This is useful particularly for deployment engineers who, on the one hand, may + have access to ample testing or profiling data that can provide a fair sense of + representative inputs for a model, but on the other hand, may not know enough + about the model to guess which input shapes should be dynamic. + + Input shapes that are different than the original are considered dynamic; conversely, + those that are the same as the original are considered static. Moreover, we verify + that the additional inputs are valid for the exported program. This guarantees that + tracing with them instead of the original would have generated the same graph. + + Example:: + + args0, kwargs0 = ... # example inputs for export + + # other representative inputs that the exported program will run on + dynamic_shapes = torch.export.AdditionalInputs() + dynamic_shapes.add(args1, kwargs1) + ... + dynamic_shapes.add(argsN, kwargsN) + + torch.export(..., args0, kwargs0, dynamic_shapes=dynamic_shapes) + """ + + def __init__(self): + self._examples = [] + + def add(self, args, kwargs=None): + """ + Additional input :func:`args` and :func:`kwargs`. + """ + + assert type(args) is tuple, f"Representative args {args} must be a tuple" + assert kwargs is None or type(kwargs) is dict, ( + f"Representative kwargs {kwargs} must be None or a dict" + ) + self._examples.append((args, kwargs)) + + def dynamic_shapes(self, m, args, kwargs=None): + """ + Infers a :func:`dynamic_shapes` pytree structure by merging shapes of the + original input :func:`args` and :func:`kwargs` and of each additional input + args and kwargs. + """ + + dynamic_shapes, *other_dynamic_shapes = [ + _tree_map_with_path( + lambda path, t: tuple(t.shape) if isinstance(t, torch.Tensor) else t, + _combine_args(m, args, kwargs), + ) + for args, kwargs in [(args, kwargs), *self._examples] + ] + + def _mark_dynamism(v, *other_vs): + if not all(type(v) == type(other) for other in other_vs): + raise ValueError( + "The following inputs were found to have differing types, " + f"so they cannot be marked as dynamic: {(v,) + other_vs}." + ) + + if isinstance(v, int) and not isinstance(v, bool): + if all(other_v == v for other_v in other_vs): + return None + else: + return Dim.DYNAMIC + else: + if not all(other_v == v for other_v in other_vs): + raise ValueError( + "The following inputs were found to have differing values, " + f"but they cannot be marked as dynamic: {(v,) + other_vs}." + ) + return None + + return tree_map( + _mark_dynamism, + dynamic_shapes, + *other_dynamic_shapes, + is_leaf=lambda i: type(i) is int, + ) + + def verify(self, ep): + """ + Verifies that an exported program is valid for each additional input. + """ + + epm = ep.module() + for args, kwargs in self._examples: + torch.export._unlift._check_input_constraints_pre_hook( + epm, args, kwargs or {} + ) + + +def _warn_on_None_dynamic_shape_dimension(): + msg = ( + "Using None as a dynamic shape dimension is deprecated. " + "Please use Dim.STATIC instead" + ) + # TODO(avik): raise an error in the future + log.warning(msg) + + +def _check_dynamic_shapes( + combined_args: dict[str, Any], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], +): + """ + Checks the dynamic_shapes specification for correctness, + using combined args + kwargs as reference for inputs structure. + """ + from torch._dynamo.exc import UserError, UserErrorType + + if dynamic_shapes is None or len(dynamic_shapes) == 0: + return + if isinstance(dynamic_shapes, (tuple, list)): + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + + bounds: dict[str, tuple[int, int]] = {} + + def check_same_bounds(dim): + if dim.__name__ in bounds: + min_, max_ = bounds[dim.__name__] + if dim.min != min_ or dim.max != max_: + this_ = Dim._readable(dim.__name__, min_, max_) + that_ = Dim._readable(dim.__name__, dim.min, dim.max) + raise UserError( + UserErrorType.INVALID_INPUT, + f"Found different definitions {this_} and {that_} " + f"for the same symbolic dimension {dim}!", + ) + else: + bounds[dim.__name__] = (dim.min, dim.max) + + def check_symbols(path, tensor, shape): + if isinstance(shape, dict): + for i, dim in shape.items(): + if isinstance(dim, Dim): + check_same_bounds(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected dimension mapped to index {i} in input tensor shape {shape} " + f"specified at `dynamic_shapes{keystr(path)}` " + f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, " + f" but got {dim} instead)", + case_name="dynamic_shapes_validation", + ) + elif isinstance(shape, (tuple, list)): + if len(shape) != len(tensor.shape): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Expected dynamic shape spec {shape} specified at `dynamic_shapes{keystr(path)}` " + f"to have the same length as the actual tensor shape {tensor.shape} " + f"(expected {len(tensor.shape)}, but got {len(shape)} instead)", + case_name="dynamic_shapes_validation", + ) + for i, dim in enumerate(shape): + if isinstance(dim, Dim): + check_same_bounds(dim) + elif dim is None: + _warn_on_None_dynamic_shape_dimension() + elif not (isinstance(dim, (int, _DimHint))): + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected dimension #{i} in input tensor shape {shape} " + f"specified at `dynamic_shapes{keystr(path)}` " + f"(expected None, an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC, " + f"but got {dim} instead)", + case_name="dynamic_shapes_validation", + ) + elif shape is not None: + raise UserError( + UserErrorType.INVALID_INPUT, + f"Unexpected input tensor shape {shape} specified at `dynamic_shapes{keystr(path)}` " + f"(expected either a list/tuple of dimensions, or a dict mapping indices to dimensions," + f" where each dimension is an int, a Dim, Dim.AUTO, Dim.STATIC, or Dim.DYNAMIC)", + case_name="dynamic_shapes_validation", + ) + + assert isinstance(dynamic_shapes, (dict, tuple, list)) + if isinstance(dynamic_shapes, dict): + got_keys = list(dynamic_shapes.keys()) + expected_arg_names = list(combined_args.keys()) + if sorted(got_keys) != sorted(expected_arg_names): + msg = ( + f"When `dynamic_shapes` is specified as a dict, its top-level keys " + f"must be the arg names {expected_arg_names} of `inputs`, but " + f"here they are {got_keys}. " + ) + if ( + len(combined_args) == 1 + and expected_arg_names[0] not in got_keys + and isinstance(combined_args[expected_arg_names[0]], dict) + ): + msg += ( + "Since here `inputs` is a list/tuple enclosing a single dict, " + "maybe you just forgot to enclose `dynamic_shapes` in a list/tuple?" + ) + else: + msg += ( + "Alternatively, you could also ignore arg names entirely " + "and specify `dynamic_shapes` as a list/tuple matching `inputs`." + ) + raise UserError( + UserErrorType.INVALID_INPUT, msg, case_name="dynamic_shapes_validation" + ) + + def check_shape(path, t, dynamic_shape): + if isinstance(t, torch.Tensor): + check_symbols(path, t, dynamic_shape) + elif isinstance(t, _IntWrapper): + if isinstance(dynamic_shape, _Dim): + raise ValueError( + "Unable to specify input integers as dynamic through named " + "Dims. Please use Dim.AUTO/DYNAMIC instead." + ) + assert dynamic_shape is None or isinstance(dynamic_shape, (int, _DimHint)) + else: + if dynamic_shape is not None: + rendered_path = keystr(path) + raise UserError( + UserErrorType.INVALID_INPUT, + f"Cannot associate shape {dynamic_shape} specified at `dynamic_shapes{rendered_path}` " + f"to non-tensor type {type(t)} at `inputs{rendered_path}` (expected None)", + case_name="dynamic_shapes_validation", + ) + + _tree_map_with_path(check_shape, combined_args, dynamic_shapes, tree_name="inputs") + + +def _process_dynamic_shapes( + combined_args: dict[str, Any], + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], +) -> list[Constraint]: + """ + Reads the dynamic_shapes specification and produces a list of constraints. + """ + from torch._dynamo.exc import UserError, UserErrorType + + if dynamic_shapes is None or len(dynamic_shapes) == 0: + # we run with dynamic by default, so no need to produce constraints + return [] + if isinstance(dynamic_shapes, (tuple, list)): + combined_args = type(dynamic_shapes)(combined_args.values()) # type: ignore[assignment, misc] + + # map of Dim names representing input shape dimensions to constraints on them + symbols: dict[str, list[Constraint]] = defaultdict(list) + # track roots that do not directly represent input shape dimensions + phantom_roots: dict[str, _PhantomRoot] = {} + derived_constraints_with_phantom_root: list[_DerivedConstraint] = [] + # list of constraints to return + constraints: list[Constraint] = [] + + def to_constraint(dim, tensor, i): + import sympy + + from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint + from torch.utils._sympy.solve import try_solve + from torch.utils._sympy.value_ranges import ValueRanges + + def root_value(): + # given tensor.shape[i] is the value of dim = fn(root), + # find the value of root + symbol = sympy.Symbol(dim.root.__name__, integer=True) + expr = dim.fn(symbol) + solution = try_solve(sympy.Eq(expr, tensor.shape[i]), symbol) + if solution is not None: + return int(solution[1]) + else: + raise UserError( # noqa: B904 + UserErrorType.CONSTRAINT_VIOLATION, + f"Expected shape[{i}] = {tensor.shape[i]} of input Tensor to be " + f"of the form {expr}, where {symbol} is an integer", + ) + + if isinstance(dim, _DerivedDim): + # generate a _DerivedConstraint where the root is: + # - either a _ConstraintTarget (if dim.root directly describes an input shape) + # - or a _PhantomRoot (otherwise) + dim_root = dim.root # type: ignore[attr-defined] + if dim_root.__name__ in symbols: + # root represents an input shape dimension + root_constraint = symbols[dim_root.__name__][0] + root = _ConstraintTarget( + root_constraint.t_id, + root_constraint.dim, + ) + elif dim_root.__name__ not in phantom_roots: + # create a phantom root + root = _PhantomRoot( # type: ignore[assignment] + name=dim_root.__name__, + constraint_range=StrictMinMaxConstraint( + vr=ValueRanges(lower=dim_root.min, upper=dim_root.max), + warn_only=False, + ), + val=root_value(), + ) + phantom_roots[dim_root.__name__] = root # type: ignore[assignment] + else: + root = phantom_roots[dim_root.__name__] # type: ignore[assignment] + constraint = _DerivedConstraint( + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.min, upper=dim.max), + warn_only=False, + ), + root, + dim.fn, # type: ignore[attr-defined] + ) + if isinstance(root, _PhantomRoot): + # NOTE(avik): since we have not processed all inputs yet, we may replace this + # with a root that does represent an input shape dimension later (see below) + derived_constraints_with_phantom_root.append(constraint) + elif isinstance(dim, _StaticDim): + constraint = _Constraint( # type: ignore[assignment] + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined] + warn_only=False, + ), + ) + else: + assert isinstance(dim, Dim) + constraint = _Constraint( # type: ignore[assignment] + id(tensor), + i, + dim.__name__, + StrictMinMaxConstraint( + vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined] + warn_only=False, + ), + ) + return constraint + + def _parse_tensor_dim(tensor, idx, dim) -> None: + def _create_static_dim(tensor, i, value): + return _StaticDim(value) + + if isinstance(dim, (int, Dim)): + if isinstance(dim, int): + dim = _create_static_dim(tensor, idx, dim) + constraint = to_constraint(dim, tensor, idx) + symbols[dim.__name__].append(constraint) + elif isinstance(dim, _DimHint): + if dim.type == _DimHintType.AUTO: + torch._dynamo.maybe_mark_dynamic(tensor, idx) + elif dim.type == _DimHintType.STATIC: + torch._dynamo.mark_static(tensor, idx) + elif dim.type == _DimHintType.DYNAMIC: + torch._dynamo.mark_dynamic(tensor, idx) + constraints.append(_RelaxedConstraint(id(tensor), idx)) + elif dim is None: + torch._dynamo.mark_static(tensor, idx) + + def update_symbols(path, tensor, shape): + # clean out decorators from user side, or previous export call + # we also delete these attributes in non_strict_utils.py/make_constraints() + tensor._dynamo_weak_dynamic_indices = set() + tensor._dynamo_dynamic_indices = set() + tensor._dynamo_dynamic_range = set() + tensor._dynamo_static_indices = set() + tensor._dynamo_unbacked_indices = set() + + if isinstance(shape, dict): + for i, dim in shape.items(): + _parse_tensor_dim(tensor, i, dim) + elif isinstance(shape, (tuple, list)): + for i, dim in enumerate(shape): + _parse_tensor_dim(tensor, i, dim) + elif shape is None: + for i in range(tensor.dim()): + _parse_tensor_dim(tensor, i, None) + + def assoc_shape(path, t, dynamic_shape): + if isinstance(t, torch.Tensor): + update_symbols(path, t, dynamic_shape) + elif isinstance(t, _IntWrapper): + # If tensor dimensions are marked as dynamic, the tensors themselves + # get marked using mark_dynamic. However since we can't mark + # integers as dynamic, we first wrap integers in this class, and + # then set the `dim` field of the class with the dynamic shapes dim + # to mark the integer as dynamic. + t.dynamism = dynamic_shape + + _tree_map_with_path(assoc_shape, combined_args, dynamic_shapes, tree_name="inputs") + + for derived_constraint_with_phantom_root in derived_constraints_with_phantom_root: + phantom_root_name = derived_constraint_with_phantom_root.root.name # type: ignore[union-attr] + if phantom_root_name in symbols: + # We found an input shape dimension corresponding to this name, so we + # do not need a phantom symbol for it after all. + # NOTE(avik): Overall we want to maintain the invariant that roots that + # are phantom symbols are really "phantom," i.e., they cannot be represented + # by any input source. This is important when we are deciding derived equalities, + # since we can focus our attention exclusively on input sources: deciding + # derived equalities involving phantom symbols are, in comparison, trivial. + derived_constraint_with_phantom_root.root = symbols[phantom_root_name][0] + + for dynamic_dims in symbols.values(): + constraints.extend(dynamic_dims) + + return constraints + + +def _get_dim_name_mapping( + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], +): + name_to_dim = {} + for dim in tree_flatten( + dynamic_shapes, + is_leaf=lambda x: isinstance(x, Dim), + )[0]: + if dim is None: + # NOTE: this must denote a non-Tensor or automatic at this point. + continue + if isinstance(dim, int): + continue + elif isinstance(dim, Dim): + name_to_dim[dim.__name__] = dim + if isinstance(dim, _DerivedDim): + name_to_dim[dim.root.__name__] = dim.root # type: ignore[attr-defined] + else: + assert isinstance(dim, _DimHint) + return name_to_dim + + +def refine_dynamic_shapes_from_suggested_fixes( + msg: str, + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]], +) -> Union[dict[str, Any], tuple[Any], list[Any]]: + """ + When exporting with :func:`dynamic_shapes`, export may fail with a ConstraintViolation error if the specification + doesn't match the constraints inferred from tracing the model. The error message may provide suggested fixes - + changes that can be made to :func:`dynamic_shapes` to export successfully. + + Example ConstraintViolation error message:: + + Suggested fixes: + + dim = Dim('dim', min=3, max=6) # this just refines the dim's range + dim = 4 # this specializes to a constant + dy = dx + 1 # dy was specified as an independent dim, but is actually tied to dx with this relation + + This is a helper function that takes the ConstraintViolation error message and the original :func:`dynamic_shapes` spec, + and returns a new :func:`dynamic_shapes` spec that incorporates the suggested fixes. + + Example usage:: + + try: + ep = export(mod, args, dynamic_shapes=dynamic_shapes) + except torch._dynamo.exc.UserError as exc: + new_shapes = refine_dynamic_shapes_from_suggested_fixes( + exc.msg, dynamic_shapes + ) + ep = export(mod, args, dynamic_shapes=new_shapes) + + """ + + import re + + import sympy + + from torch._dynamo.exc import UserError, UserErrorType + from torch.fx.experimental.symbolic_shapes import _is_supported_equivalence + + try: + shape_fixes_msg = msg.split("Suggested fixes:")[1].strip() + except Exception as exc: + raise UserError( + UserErrorType.INVALID_INPUT, + "Suggested fixes not found in error message given to refine_dynamic_shapes_from_suggested_fixes()", + ) from exc + + # build shape_fixes dictionary + shape_fixes = {} + for fix in shape_fixes_msg.split("\n"): + fix = fix.strip() + if match := re.match(r"(.*) = Dim\('(.*)'.*\)", fix): + name = match.group(1) + _min, _max = None, None + if match_min := re.match(r".* = Dim\('.*', min\=([0-9]+).*\)", fix): + _min = int(match_min.group(1)) + if match_max := re.match(r".* = Dim\('.*'.*max\=([0-9]+)\)", fix): + _max = int(match_max.group(1)) + shape_fixes[name] = Dim(name, min=_min, max=_max) + else: + name, expr = fix.split(" = ") + expr = sympy.sympify(expr) + if isinstance(expr, sympy.Number): + # static, integer + shape_fixes[name] = int(expr) # type: ignore[assignment] + else: + # relation or derived dim + shape_fixes[name] = expr + + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + # track derived dim roots + roots: set[str] = set() + for k, c in shape_fixes.items(): + assert isinstance(c, (int, Dim, _DerivedDim, sympy.Expr)) + if isinstance(c, sympy.Expr): # check dim/derived dim expression + assert _is_supported_equivalence(c) + shape_fixes[k] = c + roots.add(str(next(iter(c.free_symbols)))) + if isinstance(c, _DerivedDim): + roots.add(c.root.__name__) # type: ignore[attr-defined] + + # check keys are existing dims or new roots + for k, c in shape_fixes.items(): + assert k in name_to_dim or k in roots + + # cache so we don't produce multiple derived dim objects + derived_dim_cache: dict[str, _DerivedDim] = {} + + def apply_fixes(path, dim, dummy): + if dim is None or isinstance(dim, int): # not dynamic + return dim + elif dim.__name__ in shape_fixes: # directly fix + fix = shape_fixes[dim.__name__] + if isinstance(fix, sympy.Expr): # now derived or related + if str(fix) in derived_dim_cache: + return derived_dim_cache[str(fix)] + else: + symbol = next(iter(fix.free_symbols)) + # try to locate symbol + if symbol.name in shape_fixes: + root = shape_fixes[symbol.name] + else: + assert symbol.name in name_to_dim + root = name_to_dim[symbol.name] + # figure out value of fix + modulus, remainder = sympy.polys.polytools.div(fix, symbol) + dim = root + if modulus != 1: + dim = int(modulus) * dim + if remainder != 0: + dim = dim + int(remainder) + derived_dim_cache[str(fix)] = dim + return dim + else: + return fix + elif isinstance(dim, _DerivedDim) and dim.root.__name__ in shape_fixes: # type: ignore[attr-defined] + if dim.__name__ in derived_dim_cache: + return derived_dim_cache[dim.__name__] + else: # evaluate new derived value based on root + _dim = dim.fn(shape_fixes[dim.root.__name__]) # type: ignore[attr-defined] + derived_dim_cache[dim.__name__] = _dim + return _dim + return dim # unchanged dim + + return _tree_map_with_path(apply_fixes, dynamic_shapes, dynamic_shapes) diff --git a/phivenv/Lib/site-packages/torch/export/experimental/__init__.py b/phivenv/Lib/site-packages/torch/export/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..397330f7bff43d6847a57b53cb15430f89fef4b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/experimental/__init__.py @@ -0,0 +1,326 @@ +import copy +import dataclasses +import functools +import types +import typing +import typing_extensions + +import torch +from torch.export.exported_program import _decompose_exported_program + + +def _copy_graph_module_and_signature( + ep: torch.fx.GraphModule, +) -> tuple[torch.fx.GraphModule, torch.export.graph_signature.ExportGraphSignature]: + # copy.deepcopy lets the objects override __deepcopy__ methods with graph_copy() and node_copy(), + # and this can break placeholder names in some particular cases. + # For example, node copying will avoid Python keywords like 'input', suffixing and renaming to 'input_1'. + # So we manually overwrite placeholder names by reading the old graph. + gm = copy.deepcopy(ep.graph_module) + new_graph_signature = copy.deepcopy(ep.graph_signature) + + # iterate over old/new graph modules + for old_gm, new_gm in zip(ep.graph_module.modules(), gm.modules()): # type: ignore[union-attr] + old_phs = [node for node in old_gm.graph.nodes if node.op == "placeholder"] + new_phs = [node for node in new_gm.graph.nodes if node.op == "placeholder"] + # iterate over placeholders + assert len(old_phs) == len(new_phs) + for old_node, new_node in zip(old_phs, new_phs): + new_node.name = old_node.name + + return gm, new_graph_signature # type: ignore[return-value] + + +def _remove_detach_pass( + gm: torch.fx.GraphModule, sig: torch.export.graph_signature.ExportGraphSignature +) -> None: + with gm._set_replace_hook(sig.get_replace_hook()): + for node in list(reversed(gm.graph.nodes)): + if node.op != "call_function": + continue + if ( + node.target == torch.ops.aten.detach.default + and len(node.users) == 1 + and next(iter(node.users)).target == torch.ops.aten.detach.default + ): + next(iter(node.users)).replace_all_uses_with(node) + + gm.graph.eliminate_dead_code() + gm.recompile() + + +def _export_forward_backward( + ep: torch.export.ExportedProgram, joint_loss_index: int = 0 +) -> torch.export.ExportedProgram: + """ + WARNING: This API is highly unstable and will be subject to change in the future. + """ + from torch._decomp import core_aten_decompositions + + ep = _decompose_exported_program( + ep, + cia_to_decomp={}, + python_decomp_table=core_aten_decompositions(), + joint_loss_index=joint_loss_index, + # For serialization purpose, we don't want to decompose custom triton ops. + # If users would like to decompose custom triton ops, they could do it + # with run_decompositions() API. + decompose_custom_triton_ops=False, + ) + gm, new_graph_signature = _copy_graph_module_and_signature(ep) + _remove_detach_pass(gm, new_graph_signature) + + return ep._update(gm, new_graph_signature) + + +@typing.no_type_check +def _sticky_export(forward_func, dynamic_shapes_callback=None): + """ + Lazily export the model on first forward call. + Usage: + model.forward = _sticky_export(model.forward, dynamic_shapes_callback=callback) + """ + model = forward_func.__self__ + original_forward = forward_func.__func__ + + @functools.wraps(forward_func) + def wrapper(*args, **kwargs): + # Unpatch forward to avoid recursion during export + model.forward = types.MethodType(original_forward, model) + + dynamic_shapes_spec = None + if dynamic_shapes_callback: + dynamic_shapes_spec = dynamic_shapes_callback(*args, **kwargs) + + try: + exported = torch.export.export( + model, + args, + kwargs, + dynamic_shapes=dynamic_shapes_spec, + ).module() + wrapper._exported_artifact = exported + finally: + # Restore the wrapper after export + model.forward = wrapper + + return exported(*args, **kwargs) + + return wrapper + + +@dataclasses.dataclass +class _ExportMethod: + overloads: dict[str, torch.export.ExportedProgram] + fallbacks: list[torch.export.ExportedProgram] + + +_InputT = typing_extensions.ParamSpec("_InputT") +_RetT = typing.TypeVar("_RetT") + + +class _ExportPackage: + """ + An export package is a collection of torch.export()-ed PyTorch models consisting of + a list of exported methods and their corresponding overloads. ExportPackage is introduced + on top of torch.export() to support the following use cases: + - Exporting a model with multiple methods if a model has multiple independent parts. + - Exporting a function with multiple overloads based on tensor shapes or other metadata. + + ExportPackage is designed to contain multiple methods (associated with method names) and for + each method, it can have multiple overloads (associated with overload names). + + Here is an example of the data structure for an ExportPackage: + ``` + ExportPackage( + methods={ + "decoder": ExportMethod( + overloads={ + "prefill": ExportedProgram(...), + "decode": ExportedProgram(...), + }, + fallbacks=[], + ), + "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]), + }, + ) + ``` + + To export a model into an ExportPackage, users can use the exporter API provided by ExportPackage. + Exporter is a decorator that takes a callable and returns a wrapper. The wrapper will export the + function into an ExportPackage, when it's invoked with some sample inputs (similar to how + torch.compile() works). For more details, please refer to the document on .exporter() method. + + This design allows users to decouple the exported callables from the actual sample inputs which can + be helpful for use cases where the exported callable is hidden behind helper functions or when sample + inpusts are hard to get. + + NOTE: This is an experimental API and anything can be changed in the future. + + Example usage: + ``` + def fn(x): + return x + 1 + + def main(f, x): + x += 1 + ret = f(x) + return ret + 1 + + package = ExportPackage() + main(package.exporter(fn), torch.randn(3, 2)) + ``` + + """ + + def __init__(self) -> None: + self.methods: dict[str, _ExportMethod] = {} + + def _exporter( + self, + method: str, + fn: typing.Callable[_InputT, _RetT], + *, + fallback: str = "once", + ) -> typing.Callable[_InputT, _RetT]: + """ + A function/module decorator that sets up a callable to be exported later invoked. + By default the exporter will only trigger torch.export for once and error on + later invocations. To customize this behavior, users have the following two options: + 1. Call .define_overload() method on the returned wrapper to define an overload. + 2. Adjust the fallback policy using `fallback` argument. + + An "overload" is a named branch for an ExportMethod with a user defined precondition, + typically based on input tensor shapes. It's up to a downstream backend implementation + of ExportMethod to respect the precondition later in inference. + + define_overload() takes arguments like the following: + - A name, for indexing purposes in a backend. + - A callable (spec) that: + - Has the same model input signature as the original model code. + - Returns an optional dynamic shape spec. + + Exporter will only export an overload when the spec callable successfully returns + a result without rasing AssertionError. + + For example: + ``` + package = ExportPackage() + + + def prefill(x, xa, kv_cache): + assert x.shape[1] == 3 + assert kv_cache == {} + + + def decode(x, xa, kv_cache): + assert x.shape[1] > 1 + assert len(kv_cache) > 0 + return {...} # dynamic shape specs here + + + exporter = ( + package.exporter(decoder) + .define_overload("prefill", prefill) + .define_overload("decode", decode) + ) + ``` + + A "fallback" is exported when no overload precondition matches a given set of sample + inputs. Overloads should + Fallbacks don't have names and are ordered in a list. It's up to a backend to decide + which fallback is used amony multiple ones. + + A reference backend implementation of ExportMethod may look like the following: + ``` + def execute(method: ExportMethod, *args, **kwargs): + for overload in method.overloads: + if match_precondition(overload, *args, **kwargs): + return execute_overload(overload, *args, **kwargs) + for fallback in method.fallbacks: + if match_precondition(fallback, *args, **kwargs): + return execute_fallback(fallback, *args, **kwargs) + ``` + + Args: + method(str): The method name for an exported part of PyTorch model. This + will be saved together with the exported/compiled artifacts + in any serialization format and can be used as the key to + index ExportPackage methods later. + fn(callable): A PyTorch function/module to be exported. + fallback(str): The fallback policy to decide when to call torch.export + - "once" is the default policy. Under this policy a PyTorch program is assumed + to be only called once later and an error will be raised for subsequent + runs. + - "error" means the ExportMethod will never have any fallbacks, meaning + users should define all the possible overloads ahead of time. + + """ + + fallbacks: list[torch.export.ExportedProgram] = [] + specs: dict[str, typing.Callable[_InputT, typing.Any]] = {} + overloads: dict[str, torch.export.ExportedProgram] = {} + self.methods[method] = _ExportMethod(fallbacks=fallbacks, overloads=overloads) + + @functools.wraps(fn) + def _exporter_context(*args, **kwargs): # type: ignore[no-untyped-def] + import torch.export._wrapper_utils + + model: torch.nn.Module + if not isinstance(fn, torch.nn.Module): + model = torch.export._wrapper_utils._WrapperModule(fn) + else: + model = fn + + for k, v in specs.items(): + try: + if isinstance(fn, torch.nn.Module): + dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type] + else: + dynamic_shapes = v(*args, **kwargs) + except AssertionError: + continue + if k not in overloads: + ep = torch.export.export( + model, args, kwargs, dynamic_shapes=dynamic_shapes + ) + overloads[k] = ep + ep = overloads[k] + return ep.module()(*args, **kwargs) + + if fallback == "error": + raise RuntimeError( + f"Exporter: Cannot export fallback {fn} when fallback policy is set to 'error'," + + "please specify an overload or adjust the fallback policy." + ) + elif fallback == "once": + if len(fallbacks) > 0: + raise RuntimeError( + f"Exporter: Cannot export {fn} more than once, " + + "please specify an overload or adjust the fallback policy." + ) + else: + raise RuntimeError(f"Unknown fallback policy: {fallback}") + ep = torch.export.export(model, args, kwargs) + + fallbacks.append(ep) + return ep.module()(*args, **kwargs) + + if isinstance(fn, torch.nn.Module): + _exporter_context = torch._dynamo.eval_frame.OptimizedModule( # type: ignore[assignment] # noqa: F811 + fn, lambda _: _exporter_context + ) + + def _define_overload( + overload: str, spec: typing.Callable[_InputT, typing.Any] + ) -> typing.Any: + assert overload not in specs + assert callable(spec) + assert overload.isidentifier() + specs[overload] = spec + return _exporter_context + + assert not hasattr(fn, "_define_overload") + _exporter_context._define_overload = _define_overload # type: ignore[attr-defined] + + return _exporter_context diff --git a/phivenv/Lib/site-packages/torch/export/experimental/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/experimental/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dc2207f01affa82d00b81c50d3d0bc6b26c85b0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/experimental/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/exported_program.py b/phivenv/Lib/site-packages/torch/export/exported_program.py new file mode 100644 index 0000000000000000000000000000000000000000..75654a03bed67bcc994ac5e31f849a06aed54689 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/exported_program.py @@ -0,0 +1,1696 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import contextlib +import copy +import dataclasses +import functools +import operator +import types +import warnings +from collections import defaultdict, namedtuple +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any, Callable, final, Optional, TYPE_CHECKING, Union + +from torch._guards import tracing, TracingContext +from torch._higher_order_ops.utils import autograd_not_implemented +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_impls import ( + _deregister_op_impl, + _is_op_registered_to_fake_rule, + register_op_impl, +) +from torch._subclasses.fake_tensor import FakeTensorMode +from torch.fx._symbolic_trace import _ConstantAttributeType +from torch.fx._utils import first_call_function_nn_module_stack +from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo +from torch.fx.immutable_collections import immutable_dict, immutable_list +from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts + + +if TYPE_CHECKING: + # Import the following modules during type checking to enable code intelligence features, + # such as auto-completion in tools like pylance, even when these modules are not explicitly + # imported in user code. + + import sympy + + from torch.utils._sympy.value_ranges import ValueRanges + +import torch +import torch.utils._pytree as pytree +from torch._export.utils import ( + _collect_all_valid_cia_ops, + _collect_and_set_constant_attrs, + _collect_param_buffer_metadata, + _detect_fake_mode_from_gm, + _fakify_params_buffers, + _get_decomp_for_cia, + _is_preservable_cia_op, + _name_hoo_subgraph_placeholders, + _override_graph_signature_for_temp_registered_constants, + _overwrite_signature_for_non_persistent_buffers, + _populate_param_buffer_metadata_to_new_gm, + _register_constants_as_buffers, + _rename_without_collisions, + _special_op_to_preserve_cia, + placeholder_naming_pass, +) +from torch._export.verifier import Verifier +from torch._guards import detect_fake_mode +from torch._subclasses.fake_tensor import unset_fake_temporarily +from torch.export._tree_utils import is_equivalent, reorder_kwargs +from torch.export.decomp_utils import CustomDecompTable +from torch.fx._compatibility import compatibility +from torch.fx.passes.infra.pass_base import PassResult +from torch.fx.passes.infra.pass_manager import PassManager + +from .graph_signature import ( # noqa: F401 + ArgumentSpec, + ConstantArgument, + CustomObjArgument, + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + SymBoolArgument, + SymFloatArgument, + SymIntArgument, + TensorArgument, + TokenArgument, +) + + +__all__ = [ + "ExportedProgram", + "ModuleCallEntry", + "ModuleCallSignature", + "default_decompositions", +] + + +PassType = Callable[[torch.fx.GraphModule], Optional[PassResult]] + + +@dataclasses.dataclass +class ModuleCallSignature: + inputs: list[ArgumentSpec] + outputs: list[ArgumentSpec] + in_spec: pytree.TreeSpec + out_spec: pytree.TreeSpec + forward_arg_names: Optional[list[str]] = None + + def replace_all_uses_with(self, original_node, new_node): + for i in self.inputs: + if i.name == original_node.name: + i.name = new_node.name + for o in self.outputs: + if o.name == original_node.name: + o.name = new_node.name + + +@dataclasses.dataclass +class ModuleCallEntry: + fqn: str + signature: Optional[ModuleCallSignature] = None + + +def _disable_prexisiting_fake_mode(fn): + @functools.wraps(fn) + def wrapper(*args, **kwargs): + with unset_fake_temporarily(): + return fn(*args, **kwargs) + + return wrapper + + +def _fx_collection_equivalence_fn( + spec1_type: Optional[type], + spec1_context: pytree.Context, + spec2_type: Optional[type], + spec2_context: pytree.Context, +) -> bool: + """Treat containers and their immutable variants as the same type. Otherwise + compare as normal. + """ + if spec1_type is None or spec2_type is None: + return spec1_type is spec2_type and spec1_context == spec2_context + + if issubclass(spec1_type, (dict, immutable_dict)) and issubclass( + spec2_type, (dict, immutable_dict) + ): + return spec1_context == spec2_context + + if issubclass(spec1_type, (list, immutable_list)) and issubclass( + spec2_type, (list, immutable_list) + ): + return spec1_context == spec2_context + + return spec1_type is spec2_type and spec1_context == spec2_context + + +# This list is compiled from DispatchKey.cpp. +# The idea is that we use these keys to override +# CIA decomp in export +_AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE = [ + torch._C.DispatchKey.AutogradCPU, + torch._C.DispatchKey.AutogradCUDA, + torch._C.DispatchKey.AutogradMeta, + torch._C.DispatchKey.AutogradXLA, + torch._C.DispatchKey.AutogradLazy, + torch._C.DispatchKey.AutogradIPU, + torch._C.DispatchKey.AutogradXPU, + torch._C.DispatchKey.AutogradMPS, + torch._C.DispatchKey.AutogradHPU, + torch._C.DispatchKey.AutogradPrivateUse1, + torch._C.DispatchKey.AutogradPrivateUse2, + torch._C.DispatchKey.AutogradPrivateUse3, +] + + +# This list is compiled from DispatchKey.cpp. +# The idea is that we use these keys to add +# python kernels that directly uses default +# CIA decomp +# See NOTE Registering old CIA to Backend kernel +_BACKEND_KEYS_TO_OVERRIDE = [ + torch._C.DispatchKey.CPU, + torch._C.DispatchKey.CUDA, + torch._C.DispatchKey.Meta, + torch._C.DispatchKey.XLA, + torch._C.DispatchKey.Lazy, + torch._C.DispatchKey.IPU, + torch._C.DispatchKey.XPU, + torch._C.DispatchKey.MPS, + torch._C.DispatchKey.HPU, +] + + +@contextmanager +def _override_composite_implicit_decomp(cia_ops_to_callable): + # This function overrides CompositeImplicitAutograd decomp for + # functional composite ops that user specified. Ideally we want to not-decompose + # ALL composite ops but today's C++ functinalization relies on + # the fact that it is working with the opset after decomp is run. + # Hence we can only do it for functional ops. One caveat is that + # there are some composite ops that lie about their schema (claimed to be + # functional but not really aka dropout), for these cases, we just decompose. + saved_tables = {} + patched_ops = set() + for op_overload, decomp_callable in cia_ops_to_callable.items(): + saved_tables[op_overload] = op_overload.py_kernels.copy() + patched_ops.add(op_overload) + for override_dispatch_key in _AUTOGRAD_ALIAS_BACKEND_KEYS_TO_OVERRIDE: + if override_dispatch_key not in op_overload.py_kernels: + # TODO (tmanlaibaatar)https://github.com/pytorch/pytorch/issues/129430 + op_overload.py_impl(override_dispatch_key)( + autograd_not_implemented(op_overload, deferred_error=True) + ) + # See NOTE: Registering old CIA to Backend kernel + # It is important that we cache this before we override py_kernels. + orig_cia_callable = _get_decomp_for_cia(op_overload) + if torch._C.DispatchKey.CompositeImplicitAutograd in op_overload.py_kernels: + del op_overload.py_kernels[torch._C.DispatchKey.CompositeImplicitAutograd] + + op_overload.py_impl(torch._C.DispatchKey.CompositeImplicitAutograd)( + decomp_callable + ) + + # [NOTE] Directly registering fake tensor rule to CIA ops + # The problem we are facing here is if your CIA custom rule + # says we want to preserve the op, we will return NotImplemented. + # Unfortunately, this will invoke meta device tracing in fake tensor + # resulting in divergent behaviour for CIA kernels that has device based + # branching (one case is torch.ops.aten.scaled_dot_product.attention) + # To get around this issue, we register direct fake impl so that we + # run the kernel before we actually try to decompose the op in FakeTensorMode. + # Note that is a no-op in most cases, because: + # 1) In post dispatch tracing, CIA would have already decomposed + # 2) Most CIA impl are device agnostic. + def _force_dispatch_to_orig_cia_callable(fake_tensor_mode, op, *args, **kwargs): + orig_cia_callable = kwargs["original_callable"] + del kwargs["original_callable"] + with fake_tensor_mode: + return orig_cia_callable(*args, **kwargs) + + if not _is_op_registered_to_fake_rule(op_overload): + register_op_impl(op_overload)( + functools.partial( + _force_dispatch_to_orig_cia_callable, + original_callable=orig_cia_callable, + ) + ) + + for key in _BACKEND_KEYS_TO_OVERRIDE: + if key not in op_overload.py_kernels: + # [NOTE] Registering old CIA to Backend kernel + # We always register original CIA behavior to the backend keys kernel + # The reason is when we are fake tensor prop-ing or executing real kernel, + # we end up calling an operator on respective backend, which in python dispatcher, + # will resolve into CIA key. (see resolve_key in torch/_ops.py) + # As a result, this CIA now will call into the custom user defined + # CIA which can cause a problem. + # To make it more concrete, the case we are handling is: + # (1) there is a tensor constant we are performing constant propagation + # on during tracing + # (2) we invoke an op underneath autograd (either because we are below autograd, + # or we are tracing in inference mode), so one of the backend keys gets hit + # (3) the op we are invoking has a CIA impl that normally runs in eager mode + # (and the user wants to tweak this CIA impl during tracing, but during + # const-prop we want the original CIA to run + op_overload.py_impl(key)(orig_cia_callable) + + try: + yield + finally: + for op in patched_ops: + op.py_kernels.clear() + op.py_kernels.update(saved_tables[op]) + op._dispatch_cache.clear() + _deregister_op_impl(op) + + +def _split_decomp_table_to_cia_and_python_decomp( + decomp_table: dict[torch._ops.OperatorBase, Callable], +) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]: + all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) + cia_ops_to_callable = {} + + for op in list(decomp_table.keys()): + # TODO we are silently allowing non-safe(non-functional) ops through a crack + # due to core aten decomp table having non-functional entries. Once we have + # a tigher check around core aten decomp, we should warn users about them. + # Tracking issue: (https://github.com/pytorch/pytorch/issues/135759) + + # if it is a valid CIA op we can mess with in export, we check if it is: + # 1. Has been marked as to be decomposed. Example: + # decomp_table = decomp_table_to_core_aten() + # del decomp_table[aten.linear] + # In this case, user says decompose everything except for aten.linear + # 2. Has been marked with custom decomp behavour. Example: + # decomp_table = {aten.linear: some_op} + # For (1), we want to remove all the CIA ops that weren't handled by user as + # it suggests they are safe to decompose, so we should remove from preservable_list. + # for (2), we just plumb the custom decomp to AOTDIspatcher. + # In both cases, we want to remove this CIA op from the decomp_table as it is special + # handled. + if op in all_preservable_cia_ops: + cia_ops_to_callable[op] = decomp_table[op] + all_preservable_cia_ops.remove(op) + del decomp_table[op] + # If it is a custom op, we want to still preserve or do whatever + # with it if it is a functional CIA. The reason we don't remove + # from CIA list is because we don't query custom ops. + elif _is_preservable_cia_op(op): + op_name = op.name() + assert not op_name.startswith("aten"), "This should be a custom op" + cia_ops_to_callable[op] = decomp_table[op] + + # If we reached here, it means user intentionally deleted these CIA ops from + # decomp table. + for k in all_preservable_cia_ops: + cia_ops_to_callable[k] = _special_op_to_preserve_cia + + return cia_ops_to_callable, decomp_table + + +def default_decompositions() -> "CustomDecompTable": + """ + This is the default decomposition table which contains decomposition of + all ATEN operators to core aten opset. Use this API together with + :func:`run_decompositions()` + """ + return CustomDecompTable() + + +def _decompose_and_get_gm_with_new_signature_constants( + ep, + *, + cia_to_decomp: dict[torch._ops.OperatorBase, Callable], + python_decomp_table: dict[torch._ops.OperatorBase, Callable], + joint_loss_index: Optional[int], + decompose_custom_triton_ops, +): + from torch._export.passes.lift_constants_pass import _materialize_and_lift_constants + from torch._functorch.aot_autograd import aot_export_module + from torch.export._trace import ( + _disable_custom_triton_op_functional_decomposition, + _export_to_aten_ir, + _ignore_backend_decomps, + _verify_nn_module_stack, + _verify_placeholder_names, + _verify_stack_trace, + ) + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + def _is_joint_ir_decomp(ep, joint_loss_index): + return ( + joint_loss_index is not None + or ep.graph_signature.backward_signature is not None + ) + + if not _is_joint_ir_decomp(ep, joint_loss_index): + mod = ep.module() + + wrapped_params_buffers = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + + from torch._functorch._aot_autograd.subclass_parametrization import ( + unwrap_tensor_subclass_parameters, + ) + + # [NOTE] Unwrapping subclasses AOT + # In torch.compile, the subclass unwrapping/wrapping happen at runtime + # but at export, this is impossible as it is intented to be run on + # C++ environment. As a result, we unwrap subclass parameters AOT. After this, + # ExportedProgram state_dict won't be same as eager model because eager model + # could have subclass weights while ExportedProgram will have desugared versions. + # This is fine because run_decompositions is supposed to specialize to post-autograd + # graph where the subclass desugaring is supposed to happen. + unwrap_tensor_subclass_parameters(mod) + unwrapped_params_buffers = { + **dict(mod.named_parameters(remove_duplicate=False)), + **dict(mod.named_buffers(remove_duplicate=False)), + } + + # TODO T204030333 + fake_mode = _detect_fake_mode_from_gm(ep.graph_module) + if fake_mode is None: + fake_mode = FakeTensorMode(shape_env=ShapeEnv(), export=True) + + # Fix the graph output signature to be tuple if scalar + out_spec = mod._out_spec + + orig_arg_names = mod.graph._codegen.pytree_info.orig_args + + # aot_export expect the return type to always be a tuple. + if out_spec.type not in (list, tuple): + out_spec = pytree.TreeSpec(tuple, None, [out_spec]) + + mod.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo( + orig_arg_names, + mod._in_spec, + out_spec, + ) + ) + + mod.recompile() + + # the exported module will store constants & non-persistent buffers such that + # retracing treats them as persistent buffers, so we inform the constants lifting pass + # and overwrite the new graph signature using the previous program. + _collect_and_set_constant_attrs(ep.graph_signature, ep.constants, mod) + + # When we have a module with constant attributes, AotDispatcher doesn't actually + # wrap them as functional tensors, because dynamo would have already made it buffer. + # In non-strict case, however, AotDispatcher can intercept constants, causing it to not + # functionalize the operators that are operating on constant tensors. Since dynamo already + # wraps constants as buffers, we temporarily register the constants as buffers and undo this + # operation after AOTDispatcher is done. + temp_registered_constants = _register_constants_as_buffers( + mod, ep.state_dict, ep.graph_signature.non_persistent_buffers + ) + + # get params & buffers after excluding constants + fake_params_buffers = _fakify_params_buffers(fake_mode, mod) + + params_buffers_to_node_meta = _collect_param_buffer_metadata(mod) + + # TODO (tmanlaibaatar) Ideally run_decomp should just call _non_strict_export + # but due to special handling of constants as non-persistent buffers make it little + # diffucult. But we should unify this code path together. T206837815 + from torch._export.non_strict_utils import ( + _enable_graph_inputs_of_type_nn_module, + _fakify_script_objects, + ) + + retracing_args = [] + for node in mod.graph.nodes: + if node.op == "placeholder": + if isinstance(node.meta["val"], CustomObjArgument): + real_script_obj = None + if node.meta["val"].fake_val is None: + real_script_obj = ep.constants[node.meta["val"].name] + else: + real_script_obj = node.meta["val"].fake_val.real_obj + retracing_args.append(real_script_obj) + else: + retracing_args.append(node.meta["val"]) + + tx = TracingContext(fake_mode) + + with ( + fake_mode, + _override_composite_implicit_decomp( + cia_to_decomp, + ), + _enable_graph_inputs_of_type_nn_module(ep.example_inputs), + tracing(tx), + ): + retracing_args_unwrapped = pytree.tree_unflatten( + retracing_args, mod._in_spec + ) + # this requires empty kwargs, but not in pytree.flattened format + with _fakify_script_objects( + mod, + ( + *retracing_args_unwrapped[0], + *retracing_args_unwrapped[1].values(), + ), + {}, + fake_mode, + ) as ( + patched_mod, + new_fake_args, + new_fake_kwargs, + new_fake_constant_attrs, + map_fake_to_real, + ): + aten_export_artifact = _export_to_aten_ir( + patched_mod, + new_fake_args, + new_fake_kwargs, + fake_params_buffers, + new_fake_constant_attrs, + decomp_table=python_decomp_table, + _check_autograd_state=False, + _prettify_placeholder_names=False, + decompose_custom_triton_ops=decompose_custom_triton_ops, + ) + + # aten_export_artifact.constants contains only fake script objects, we need to map them back + aten_export_artifact.constants = { + fqn: ( + map_fake_to_real[obj] + if isinstance(obj, FakeScriptObject) + else obj + ) + for fqn, obj in aten_export_artifact.constants.items() + } + + gm = aten_export_artifact.gm + new_graph_signature = aten_export_artifact.sig + + # In the previous step, we assume constants as buffers for AOTDispatcher to + # functianalize properly, so undo that here + new_graph_signature = ( + _override_graph_signature_for_temp_registered_constants( + new_graph_signature, temp_registered_constants + ) + ) + + _populate_param_buffer_metadata_to_new_gm( + params_buffers_to_node_meta, gm, new_graph_signature + ) + + # overwrite signature for non-persistent buffers + new_graph_signature = _overwrite_signature_for_non_persistent_buffers( + ep.graph_signature, new_graph_signature + ) + + constants = _materialize_and_lift_constants( + gm, new_graph_signature, new_fake_constant_attrs + ) + + placeholder_naming_pass( + gm, + new_graph_signature, + patched_mod, + new_fake_args, + new_fake_kwargs, + fake_params_buffers, + constants, + ) + + _verify_nn_module_stack(gm) + _verify_stack_trace(gm) + _verify_placeholder_names(gm, new_graph_signature) + + gm, new_graph_signature = _remove_unneccessary_copy_op_pass( + gm, new_graph_signature + ) + + # When we apply parameterixzation rule to unwrap + # subclasses, the state dict will now have different + # desugared parameters. We need to manually filter those + # and update the ep.state_dict. Ideally, we should just return + # the state dict of ep.module but ep.module only stores params + # buffers that participate in forward. If we undo this behaviour, + # it would break some downstream users. + new_state_dict = { + **ep.state_dict, + **{ + name: p + for name, p in unwrapped_params_buffers.items() + if name not in wrapped_params_buffers + }, + } + + for name, p in wrapped_params_buffers.items(): + # Buffers can be persistent/non-persistent + if name not in new_state_dict: + assert not isinstance(p, torch.nn.Parameter) + + if name in new_state_dict: + if name not in unwrapped_params_buffers: + new_state_dict.pop(name) + + return gm, new_graph_signature, new_state_dict + + old_placeholders = [ + node for node in ep.graph_module.graph.nodes if node.op == "placeholder" + ] + fake_args = [node.meta["val"] for node in old_placeholders] + + buffers_to_remove = [name for name, _ in ep.graph_module.named_buffers()] + for name in buffers_to_remove: + delattr(ep.graph_module, name) + + # TODO(zhxhchen17) Return the new graph_signature directly. + fake_mode = detect_fake_mode(fake_args) + fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment] + custom_triton_ops_decomposition_ctx = ( + contextlib.nullcontext + if decompose_custom_triton_ops + else _disable_custom_triton_op_functional_decomposition + ) + with ( + _ignore_backend_decomps(), + fake_mode, + _override_composite_implicit_decomp(cia_to_decomp), + custom_triton_ops_decomposition_ctx(), + ): + gm, graph_signature = aot_export_module( + ep.graph_module, + fake_args, + decompositions=python_decomp_table, + trace_joint=True if joint_loss_index is not None else False, + output_loss_index=( + joint_loss_index if joint_loss_index is not None else None + ), + ) + gm.graph.eliminate_dead_code() + + # Update the signatures with the new placeholder names in case they + # changed when calling aot_export + def update_arg(old_arg, new_ph): + if isinstance(old_arg, ConstantArgument): + return old_arg + elif isinstance(old_arg, TensorArgument): + return TensorArgument(name=new_ph.name) + elif isinstance(old_arg, SymIntArgument): + return SymIntArgument(name=new_ph.name) + elif isinstance(old_arg, SymFloatArgument): + return SymFloatArgument(name=new_ph.name) + elif isinstance(old_arg, SymBoolArgument): + return SymBoolArgument(name=new_ph.name) + raise RuntimeError(f"Type of old_arg not supported: {type(old_arg)}") + + new_placeholders = [node for node in gm.graph.nodes if node.op == "placeholder"] + new_outputs = list(gm.graph.nodes)[-1].args[0] + + # rename the placeholders + assert len(new_placeholders) == len(old_placeholders) + for old_ph, new_ph in zip(old_placeholders, new_placeholders): + new_ph.name = new_ph.target = old_ph.name + + # handle name collisions with newly decomposed graph nodes + name_map = {ph.name: ph.name for ph in new_placeholders} + for node in gm.graph.nodes: + if node.op == "placeholder": + continue + node.name = _rename_without_collisions(name_map, node.name, node.name) + + # propagate names to higher order op subgraphs + _name_hoo_subgraph_placeholders(gm) + + # Run this pass before creating input/output specs, since size-related CSE/DCE might affect output signature. + # Overwrite output specs afterwards. + from torch._export.passes._node_metadata_hook import ( + _node_metadata_hook, + _set_node_metadata_hook, + ) + from torch._functorch._aot_autograd.input_output_analysis import _graph_output_names + + if not torch._dynamo.config.do_not_emit_runtime_asserts: + stack_trace = ( + 'File "torch/fx/passes/runtime_assert.py", line 24, ' + "in insert_deferred_runtime_asserts" + ) + shape_env = _get_shape_env(gm) + if shape_env is not None: + with _set_node_metadata_hook( + gm, functools.partial(_node_metadata_hook, stack_trace=stack_trace) + ): + insert_deferred_runtime_asserts( + gm, + shape_env, + f"exported program: {first_call_function_nn_module_stack(gm.graph)}", + export=True, + ) + + # update output specs + gm.recompile() + for i, name in enumerate(_graph_output_names(gm)): + if isinstance(new_outputs[i], torch.fx.Node): + new_outputs[i].name = name + + # To match the output target with correct input for input mutations + # need to find the old to new placeholder map + old_new_placeholder_map = { + spec.arg.name: new_placeholders[i].name + for i, spec in enumerate(ep.graph_signature.input_specs) + if not isinstance(spec.arg, ConstantArgument) + } + + input_specs = [ + InputSpec( + spec.kind, + update_arg(spec.arg, new_placeholders[i]), + spec.target, + spec.persistent, + ) + for i, spec in enumerate(ep.graph_signature.input_specs) + ] + + output_specs = [] + + # handle buffer & input mutations; these appear before loss output & gradients + # (1) ep.graph_signature.input_specs tells us types of inputs + # (2) graph_signature.user_inputs tells us node input names in order + # (3) graph_signature.user_inputs_to_mutate tells us buffer & input mutations + # map (3) -> (2) for input order, -> (1) for input type + user_inputs_index = {name: i for i, name in enumerate(graph_signature.user_inputs)} + mutation_names = list(graph_signature.user_inputs_to_mutate.keys()) + assert mutation_names == [node.name for node in new_outputs[: len(mutation_names)]] + for output_name, input_name in graph_signature.user_inputs_to_mutate.items(): + i = user_inputs_index[input_name] + input_spec = ep.graph_signature.input_specs[i] + assert input_spec.kind in (InputKind.USER_INPUT, InputKind.BUFFER) + output_kind = ( + OutputKind.BUFFER_MUTATION + if input_spec.kind == InputKind.BUFFER + else OutputKind.USER_INPUT_MUTATION + ) + target = ( + input_spec.target + if input_spec.kind == InputKind.BUFFER + else input_spec.arg.name + ) + output_specs.append( + OutputSpec( + kind=output_kind, + arg=TensorArgument(name=output_name), + target=target, + ) + ) + + # handle actual user outputs + for i, spec in enumerate(ep.graph_signature.output_specs): + output_specs.append( + OutputSpec( + OutputKind.LOSS_OUTPUT if i == joint_loss_index else spec.kind, + update_arg(spec.arg, new_outputs[len(mutation_names) + i]), + old_new_placeholder_map.get(spec.target, spec.target), + ) + ) + + if joint_loss_index is not None: + assert graph_signature.backward_signature is not None + gradients = graph_signature.backward_signature.gradients_to_user_inputs + assert len(graph_signature.user_inputs) == len(ep.graph_signature.input_specs) + specs = { + graph_signature.user_inputs[i]: spec + for i, spec in enumerate(ep.graph_signature.input_specs) + if isinstance(spec.arg, TensorArgument) + } + for i, node in enumerate(new_outputs[len(output_specs) :]): + source = gradients[node.name] + spec = specs[source] # type: ignore[index] + if spec.kind == InputKind.PARAMETER: + kind = OutputKind.GRADIENT_TO_PARAMETER + target = spec.target + elif spec.kind == InputKind.USER_INPUT: + kind = OutputKind.GRADIENT_TO_USER_INPUT + target = source + else: + raise AssertionError(f"Unknown input kind: {spec.kind}") + output_specs.append( + OutputSpec( + kind, + TensorArgument(name=node.name), + target, + ) + ) + + assert len(new_placeholders) == len(old_placeholders) + + new_graph_signature = ExportGraphSignature( + input_specs=input_specs, output_specs=output_specs + ) + # NOTE: aot_export adds symint metadata for placeholders with int + # values; since these become specialized, we replace such metadata with + # the original values. + # Also, set the param/buffer metadata back to the placeholders. + for old_node, new_node in zip(old_placeholders, new_placeholders): + if not isinstance(old_node.meta["val"], torch.Tensor): + new_node.meta["val"] = old_node.meta["val"] + + if ( + new_node.target in new_graph_signature.inputs_to_parameters + or new_node.target in new_graph_signature.inputs_to_buffers + ): + for k, v in old_node.meta.items(): + new_node.meta[k] = v + return gm, new_graph_signature, ep.state_dict + + +def _remove_unneccessary_copy_op_pass( + gm: torch.fx.GraphModule, new_graph_signature: ExportGraphSignature +) -> tuple[torch.fx.GraphModule, ExportGraphSignature]: + """ + Removes redundant copy_ node that was introduced due to mutated buffer. + """ + with gm._set_replace_hook(new_graph_signature.get_replace_hook()): + for node in gm.graph.nodes: + if node.op == "output": + args, _ = pytree.tree_flatten(node.args) + for out in args: + if ( + isinstance(out, torch.fx.Node) + and out.name in new_graph_signature.buffers_to_mutate + ): + if ( + out.op == "call_function" + and out.target == torch.ops.aten.copy.default + ): + out.replace_all_uses_with(out.args[1]) # type: ignore[arg-type] + gm.graph.erase_node(out) + gm.recompile() + return gm, new_graph_signature + + +def _common_getitem_elimination_pass( + gm: torch.fx.GraphModule, graph_signature, module_call_graph +): + with gm._set_replace_hook(graph_signature.get_replace_hook()): + for module in gm.modules(): + if not isinstance(module, torch.fx.GraphModule): + continue + + node_id: dict[torch.fx.Node, str] = {} + getitems: dict[str, torch.fx.Node] = {} + for node in list(module.graph.nodes): + if node.op == "call_function" and node.target == operator.getitem: + source, idx = node.args + new_id = f"{node_id[source]}.{idx}" + if new_id in getitems: + node.replace_all_uses_with(getitems[new_id]) + for entry in module_call_graph: + if entry.signature is not None: + entry.signature.replace_all_uses_with( + node, getitems[new_id] + ) + module.graph.erase_node(node) + else: + getitems[new_id] = node + node_id[node] = new_id + else: + node_id[node] = node.name + + +def _get_updated_module_call_graph( + old_gm: torch.fx.GraphModule, + old_graph_signature: ExportGraphSignature, + gm: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + old_module_call_graph: list[ModuleCallEntry], +): + new_module_call_graph = copy.deepcopy(old_module_call_graph) + + old_nodes = {node.name: node for node in old_gm.graph.nodes} + + old_graph_params_buffers = { + **old_graph_signature.inputs_to_parameters, + **old_graph_signature.inputs_to_buffers, + } + new_graph_params_buffers = { + **graph_signature.inputs_to_parameters, + **graph_signature.inputs_to_buffers, + } + + # use node-level provenance metadata to create a map + # from old node names to new node names + provenance: dict[str, str] = {} + + user_input_counter = 0 + old_user_input_names = [ + node.target for node in old_gm.graph.nodes if node.op == "placeholder" + ] + old_user_input_names = list( + filter( + lambda x: x not in old_graph_params_buffers + and x not in old_graph_signature.input_tokens, + old_user_input_names, + ) + ) + new_user_input_names = [ + node.target for node in gm.graph.nodes if node.op == "placeholder" + ] + + for node in gm.graph.nodes: + if history := node.meta.get("from_node", []): + provenance[history[-1].name] = node.name + + # For params and buffers, we might have applied parameterizaiton rule + # so that the names might have changed. But for user inputs, we know we + # must preserve the old name. + elif node.op == "placeholder": + if not ( + node.target in new_graph_params_buffers + or node.target in graph_signature.input_tokens + ): + if node.target in new_user_input_names: + assert isinstance(node.name, str) + old_name = old_user_input_names[user_input_counter] + assert isinstance(old_name, str) + provenance[old_name] = node.name + user_input_counter += 1 + + # For all the parameters and buffers, we first see + # if they are result of paramerizaitons and if they + # are, we log them and error later + old_param_to_desugared = defaultdict(list) + for name, target in new_graph_params_buffers.items(): + # if the parameters are not parametrized, the naming won't change. + if not target.startswith("parametrizations."): + # If we are in strict mode, we can't just reuse the param names + if name in old_graph_params_buffers: + provenance[name] = name + else: + old_target = ".".join(target.split(".")[1:-1]) + old_param_to_desugared[old_target].append(name) + + # map old names to new names in module call signatures + for entry in new_module_call_graph: + signature = entry.signature + if signature is None: + continue + for x in [*signature.inputs, *signature.outputs]: + # We noticed that submodule is taking subclass as input. we can't + # preserve signature here. + if x.name in old_param_to_desugared: + raise ValueError( + f"It looks like {x.name} is a tensor subclass. " + f"Preserving submodule that takes subclass parameter is not supported" + f" in inference IR because we desugar them, resulting in more tensors" + ) + + if x.name in provenance: + x.name = provenance[x.name] + + # This can happen when aten.to is called at graph boundaries. + # Basically aten.to at post-dispatch level can either be copy + # or alias. In the alias case, we will no-op it so it will + # disappear from the graph. If we detect such case, we should + # reuse the input to aten.to as the new input to the submodule. + # Technically this can happen for other maybe aliasing ops, + # but aten.to is probably the most common one. + elif x.name in old_nodes: + old_node = old_nodes[x.name] + if old_node.op == "call_function" and old_node.target in [ + torch.ops.aten.to.dtype_layout, + torch.ops.aten.to.device, + torch.ops.aten.to.dtype, + ]: + old_target = old_node.args[0].name + if old_target not in provenance: + raise ValueError( + f"It looks like {old_target} is a tensor subclass. " + f"Preserving submodule that takes subclass parameter is not supported" + f" in inference IR because we desugar them, resulting in more tensors" + ) + + x.name = provenance[old_target] + + return new_module_call_graph + + +def _decompose_exported_program( + ep, + *, + cia_to_decomp: dict[torch._ops.OperatorBase, Callable], + python_decomp_table: dict[torch._ops.OperatorBase, Callable], + joint_loss_index: Optional[int], + decompose_custom_triton_ops: bool, +): + ( + gm, + new_graph_signature, + state_dict, + ) = _decompose_and_get_gm_with_new_signature_constants( + ep, + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, + joint_loss_index=joint_loss_index, + decompose_custom_triton_ops=decompose_custom_triton_ops, + ) + + # The signatures of ep.module_call_graph refer to input / output nodes of + # the original graph module. However, the new graph module may have + # new nodes due to decompositions. So we need to update these signatures + # in the decomposed exported program's module_call_graph. + new_module_call_graph = _get_updated_module_call_graph( + ep.graph_module, + ep.graph_signature, + gm, + new_graph_signature, + ep.module_call_graph, + ) + + # TODO unfortunately preserving graph-level metadata is not + # working well with aot_export. So we manually copy it. + # (The node-level meta is addressed above.) + gm.meta.update(ep.graph_module.meta) + + new_range_constraints = _get_updated_range_constraints( + gm, + ep.range_constraints, + ) + + exported_program = ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=new_graph_signature, + state_dict=state_dict, + range_constraints=new_range_constraints, + module_call_graph=new_module_call_graph, + example_inputs=ep.example_inputs, + constants=ep.constants, + ) + return exported_program + + +class ExportedProgram: + """ + Package of a program from :func:`export`. It contains + an :class:`torch.fx.Graph` that represents Tensor computation, a state_dict containing + tensor values of all lifted parameters and buffers, and various metadata. + + You can call an ExportedProgram like the original callable traced by + :func:`export` with the same calling convention. + + To perform transformations on the graph, use ``.module`` property to access + an :class:`torch.fx.GraphModule`. You can then use + `FX transformation `_ + to rewrite the graph. Afterwards, you can simply use :func:`export` + again to construct a correct ExportedProgram. + """ + + def __init__( + self, + root: Union[torch.nn.Module, dict[str, Any]], + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, + state_dict: dict[str, Union[torch.Tensor, torch.nn.Parameter]], + range_constraints: "dict[sympy.Symbol, Any]", + module_call_graph: list[ModuleCallEntry], + example_inputs: Optional[tuple[tuple[Any, ...], dict[str, Any]]] = None, + constants: Optional[dict[str, _ConstantAttributeType]] = None, + *, + verifiers: Optional[list[type[Verifier]]] = None, + ): + # Remove codegen related things from the graph. It should just be a flat graph. + graph._codegen = torch.fx.graph.CodeGen() + self._graph_module = _create_graph_module_for_export(root, graph) + if isinstance(root, torch.fx.GraphModule): + self._graph_module.meta.update(root.meta) + + _common_getitem_elimination_pass( + self._graph_module, graph_signature, module_call_graph + ) + self._graph_signature: ExportGraphSignature = graph_signature + self._state_dict: dict[str, Any] = state_dict + self._range_constraints: dict[sympy.Symbol, ValueRanges] = range_constraints + assert module_call_graph is not None + self._module_call_graph: list[ModuleCallEntry] = module_call_graph + self._example_inputs = example_inputs + + self._constants = constants or {} + + verifiers = verifiers or [Verifier] + assert all(issubclass(v, Verifier) for v in verifiers) + self._verifiers = verifiers + # Validate should be always the last step of the constructor. + self.validate() + + @property + @compatibility(is_backward_compatible=False) + def graph_module(self): + return self._graph_module + + @graph_module.setter + @compatibility(is_backward_compatible=False) + def graph_module(self, value): + raise RuntimeError("Unable to set ExportedProgram's graph_module attribute.") + + @property + @compatibility(is_backward_compatible=False) + def graph(self): + return self.graph_module.graph + + @graph.setter + @compatibility(is_backward_compatible=False) + def graph(self, value): + raise RuntimeError("Unable to set ExportedProgram's graph attribute.") + + @property + @compatibility(is_backward_compatible=False) + def graph_signature(self): + return self._graph_signature + + @graph_signature.setter + @compatibility(is_backward_compatible=False) + def graph_signature(self, value): + raise RuntimeError("Unable to set ExportedProgram's graph_signature attribute.") + + @property + @compatibility(is_backward_compatible=False) + def state_dict(self): + return self._state_dict + + @state_dict.setter + @compatibility(is_backward_compatible=False) + def state_dict(self, value): + raise RuntimeError("Unable to set ExportedProgram's state_dict attribute.") + + @compatibility(is_backward_compatible=False) + def parameters(self) -> Iterator[torch.nn.Parameter]: + """ + Returns an iterator over original module's parameters. + """ + for _, param in self.named_parameters(): + yield param + + @compatibility(is_backward_compatible=False) + def named_parameters(self) -> Iterator[tuple[str, torch.nn.Parameter]]: + """ + Returns an iterator over original module parameters, yielding + both the name of the parameter as well as the parameter itself. + """ + for param_name in self.graph_signature.parameters: + yield param_name, self.state_dict[param_name] + + @compatibility(is_backward_compatible=False) + def buffers(self) -> Iterator[torch.Tensor]: + """ + Returns an iterator over original module buffers. + """ + for _, buf in self.named_buffers(): + yield buf + + @compatibility(is_backward_compatible=False) + def named_buffers(self) -> Iterator[tuple[str, torch.Tensor]]: + """ + Returns an iterator over original module buffers, yielding + both the name of the buffer as well as the buffer itself. + """ + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + for buffer_name in self.graph_signature.buffers: + if buffer_name in non_persistent_buffers: + yield buffer_name, self.constants[buffer_name] + else: + yield buffer_name, self.state_dict[buffer_name] + + @property + @compatibility(is_backward_compatible=False) + def range_constraints(self): + return self._range_constraints + + @range_constraints.setter + @compatibility(is_backward_compatible=False) + def range_constraints(self, value): + raise RuntimeError( + "Unable to set ExportedProgram's range_constraints attribute." + ) + + @property + @compatibility(is_backward_compatible=False) + def module_call_graph(self): + return self._module_call_graph + + @module_call_graph.setter + @compatibility(is_backward_compatible=False) + def module_call_graph(self, value): + raise RuntimeError( + "Unable to set ExportedProgram's module_call_graph attribute." + ) + + @property + @compatibility(is_backward_compatible=False) + def example_inputs(self): + return self._example_inputs + + @example_inputs.setter + @compatibility(is_backward_compatible=False) + def example_inputs(self, value): + # This is allowed + + if value is None: + self._example_inputs = value + return + + if not ( + isinstance(value, tuple) + and len(value) == 2 + and isinstance(value[0], tuple) + and isinstance(value[1], dict) + ): + raise ValueError( + "Example inputs should be a tuple containing example arguments (as " + "a tuple), and example kwargs (as a dictionary)." + ) + + args, kwargs = value + from ._unlift import _check_inputs_match + + _check_inputs_match(args, kwargs, self.call_spec.in_spec) + + self._example_inputs = value + + @property + @compatibility(is_backward_compatible=False) + def call_spec(self): + CallSpec = namedtuple("CallSpec", ["in_spec", "out_spec"]) + + if len(self.module_call_graph) == 0: + return CallSpec(in_spec=None, out_spec=None) + assert self.module_call_graph[0].fqn == "" + return CallSpec( + in_spec=self.module_call_graph[0].signature.in_spec, + out_spec=self.module_call_graph[0].signature.out_spec, + ) + + @call_spec.setter + @compatibility(is_backward_compatible=False) + def call_spec(self, value): + raise RuntimeError("Unable to set ExportedProgram's call_spec attribute.") + + @property + @compatibility(is_backward_compatible=False) + def verifier(self) -> Any: + return self._verifiers[0] + + @verifier.setter + @compatibility(is_backward_compatible=False) + def verifier(self, value): + raise RuntimeError("Unable to set ExportedProgram's verifier attribute.") + + @property + @compatibility(is_backward_compatible=False) + def dialect(self) -> str: + assert self._verifiers is not None + return self._verifiers[0].dialect + + @dialect.setter + @compatibility(is_backward_compatible=False) + def dialect(self, value): + raise RuntimeError("Unable to set ExportedProgram's dialect attribute.") + + @property + @compatibility(is_backward_compatible=False) + def verifiers(self): + return self._verifiers + + @verifiers.setter + @compatibility(is_backward_compatible=False) + def verifiers(self, value): + raise RuntimeError("Unable to set ExportedProgram's verifiers attribute.") + + @property + @compatibility(is_backward_compatible=False) + def tensor_constants(self): + return self._constants + + @tensor_constants.setter + @compatibility(is_backward_compatible=False) + def tensor_constants(self, value): + raise RuntimeError( + "Unable to set ExportedProgram's tensor_constants attribute." + ) + + @property + @compatibility(is_backward_compatible=False) + def constants(self): + return self._constants + + @constants.setter + @compatibility(is_backward_compatible=False) + def constants(self, value): + raise RuntimeError("Unable to set ExportedProgram's constants attribute.") + + def _get_flat_args_with_check(self, args, kwargs): + """Flatten args, kwargs using pytree, then, check specs. + + Args: + args: List[Any] original args passed to __call__ + kwargs: Dict[str, Any] original kwargs passed to __call + + Returns: + A tuple of (flat_args, received_spec) + flat_args is flattend args / kwargs + received_spec is the pytree spec produced while flattening the + tuple (args, kwargs) + """ + in_spec = self.call_spec.in_spec + if in_spec is not None: + kwargs = reorder_kwargs(kwargs, in_spec) + flat_args_with_path, received_spec = pytree.tree_flatten_with_path( + (args, kwargs) + ) + self._check_input_constraints(flat_args_with_path) + flat_args = tuple(x[1] for x in flat_args_with_path) + return flat_args, received_spec + + def _graph_module_flat_inputs(self, args: Any, kwargs: Any) -> Any: + """Transform args, kwargs of __call__ to args for graph_module. + + self.graph_module takes stuff from state dict as inputs. + The invariant is for ep: ExportedProgram is + ep(args, kwargs) == + ep.postprocess(ep.graph_module(ep.graph_module_flat_inputs(args, kwargs))) + """ + + in_spec = self.call_spec.in_spec + flat_args, received_spec = self._get_flat_args_with_check(args, kwargs) + if in_spec is not None and not is_equivalent( + received_spec, in_spec, _fx_collection_equivalence_fn + ): + raise ValueError( + "Trying to flatten user inputs with exported input tree spec: \n" + f"{in_spec}\n" + "but actually got inputs with tree spec of: \n" + f"{received_spec}" + ) + + additional_inputs = [] + for input_ in self.graph_signature.input_specs: + if input_.kind == InputKind.USER_INPUT: + continue + elif input_.kind in ( + InputKind.PARAMETER, + InputKind.BUFFER, + ): + if input_.persistent is False: + # This is a non-persistent buffer, grab it from our + # constants instead of the state dict. + additional_inputs.append(self.constants[input_.target]) + else: + additional_inputs.append(self.state_dict[input_.target]) + elif input_.kind in ( + InputKind.CONSTANT_TENSOR, + InputKind.CUSTOM_OBJ, + ): + additional_inputs.append(self.constants[input_.target]) + additional_inputs = tuple(additional_inputs) + + # NOTE: calling convention is first params, then buffers, then args as user supplied them. + # See: torch/_functorch/aot_autograd.py#L1034 + return additional_inputs + flat_args + + def __call__(self, *args: Any, **kwargs: Any) -> Any: + raise RuntimeError( + "Unable to call ExportedProgram directly. " + "You should use `exported_program.module()` instead." + ) + + def _postprocess_graph_module_outputs(self, res, orig_args, orig_kwargs): + """Process potential mutations to the input. + + Because self.graph_module is functional, so mutations has to be written + back after execution of graph_module. + """ + import torch._export.error as error + + flat_args, _ = self._get_flat_args_with_check(orig_args, orig_kwargs) + if self.call_spec.out_spec is not None: + buffer_mutation = self.graph_signature.buffers_to_mutate + user_input_mutation = self.graph_signature.user_inputs_to_mutate + num_mutated = len(buffer_mutation) + len(user_input_mutation) + mutated_values = res[:num_mutated] + + # Exclude dependency token from final result. + assertion_dep_token = self.graph_signature.assertion_dep_token + if assertion_dep_token is not None: + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + res = res[:assertion_dep_token_index] + + res = res[num_mutated:] + try: + res = pytree.tree_unflatten(res, self.call_spec.out_spec) + except Exception: + _, received_spec = pytree.tree_flatten(res) + raise error.InternalError( # noqa: B904 + "Trying to flatten user outputs with exported output tree spec: \n" + f"{self.call_spec.out_spec}\n" + "but actually got outputs with tree spec of: \n" + f"{received_spec}" + ) + finally: + user_inputs = [ + spec + for spec in self.graph_signature.input_specs + if spec.kind == InputKind.USER_INPUT + ] + for i, value in enumerate(mutated_values): + output_spec = self.graph_signature.output_specs[i] + if output_spec.kind == OutputKind.BUFFER_MUTATION: + assert output_spec.target is not None + self.state_dict[output_spec.target] = value + elif output_spec.kind == OutputKind.USER_INPUT_MUTATION: + assert output_spec.target is not None + index = next( + i + for i, spec in enumerate(user_inputs) + if spec.arg.name == output_spec.target + ) + flat_args[index].copy_(value) + else: + raise AssertionError(f"Unexpected kind: {output_spec.kind}") + return res + + def __str__(self) -> str: + graph_module = self.graph_module.print_readable( + print_output=False, colored=False + ).replace("\n", "\n ") + graph_signature = str(self.graph_signature).replace("\n", "\n ") + string = ( + "ExportedProgram:\n" + f" {graph_module}\n" + f"Graph signature: {graph_signature}\n" + f"Range constraints: {self.range_constraints}\n" + ) + return string + + def module(self) -> torch.nn.Module: + """ + Returns a self contained GraphModule with all the parameters/buffers inlined. + """ + from ._unlift import _unlift_exported_program_lifted_states + + module = _unlift_exported_program_lifted_states(self) + + def _train(self, mode: bool = True): + raise NotImplementedError("Calling train() is not supported yet.") + + def _eval(self, mode: bool = True): + raise NotImplementedError("Calling eval() is not supported yet.") + + module.train = types.MethodType(_train, module) # type: ignore[method-assign] + module.eval = types.MethodType(_eval, module) # type: ignore[method-assign] + return module + + def _num_lifted_params_buffers(self): + return next( + ( + i + for i, s in enumerate(self._graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ), + len(self._graph_signature.input_specs), + ) + + @_disable_prexisiting_fake_mode + def run_decompositions( + self, + decomp_table: Optional[dict[torch._ops.OperatorBase, Callable]] = None, + decompose_custom_triton_ops: bool = False, + ) -> "ExportedProgram": + """ + Run a set of decompositions on the exported program and returns a new + exported program. By default we will run the Core ATen decompositions to + get operators in the + `Core ATen Operator Set `_. + + For now, we do not decompose joint graphs. + + Args: + decomp_table: + An optional argument that specifies decomp behaviour for Aten ops + (1) If None, we decompose to core aten decompositions + (2) If empty, we don't decompose any operator + + + Some examples: + + If you don't want to decompose anything + + .. code-block:: python + + ep = torch.export.export(model, ...) + ep = ep.run_decompositions(decomp_table={}) + + If you want to get a core aten operator set except for certain operator, you can do following: + + .. code-block:: python + + ep = torch.export.export(model, ...) + decomp_table = torch.export.default_decompositions() + decomp_table[your_op] = your_custom_decomp + ep = ep.run_decompositions(decomp_table=decomp_table) + """ + _decomp_table = ( + default_decompositions() if decomp_table is None else dict(decomp_table) + ) + + if isinstance(_decomp_table, CustomDecompTable): + _decomp_table = _decomp_table.materialize() + + # Note [Seperating decomp_table into CIA decomps and non-CIA decomps] + # At this point, we have a decomp_table that contains decomp behaviour for + # both CIA and post-autograd ops. + # We need to separate the op into two categories: + # 1. CIA op: These are the ops that we want to override + # CompositeImplicitAutograd decomp for. For them, we need to use _override_composite_implicit_decomp + # context manager to plumb it through AOTDispatcher + # 2. Non-CIA op: These ops are only relevant after AOTDIspatcher runs, so just + # checking if they are statically functional is enough. + # For joint IR case tho, we need to use the old path because we can't register + # custom decomps this way because we can't use context manager as it installs + # autograd_error node. + ( + cia_to_decomp, + python_decomp_table, + ) = _split_decomp_table_to_cia_and_python_decomp(_decomp_table) + + return _decompose_exported_program( + self, + cia_to_decomp=cia_to_decomp, + python_decomp_table=python_decomp_table, + joint_loss_index=None, + decompose_custom_triton_ops=decompose_custom_triton_ops, + ) + + def _transform_do_not_use(self, *passes: PassType) -> "ExportedProgram": + pm = PassManager(list(passes)) + # Since we abstractly run the passes, we need to disable backend decomp here + # again. + from torch.export._trace import _ignore_backend_decomps + + with _ignore_backend_decomps(): + res = pm(self.graph_module) + transformed_gm = res.graph_module if res is not None else self.graph_module + assert transformed_gm is not None + + if transformed_gm is self.graph_module and not res.modified: + return self + + # TODO(zhxchen17) Remove this. + def _get_updated_graph_signature( + old_signature: ExportGraphSignature, + new_gm: torch.fx.GraphModule, + ) -> ExportGraphSignature: + """ + Update the graph signature's user_input/user_outputs. + """ + new_input_specs = [] + for i, node in enumerate(new_gm.graph.nodes): + if node.op != "placeholder": + break + + assert i < len(old_signature.input_specs), ( + "Number of inputs changed after transformation" + ) + old_input_spec = old_signature.input_specs[i] + arg = ( + old_input_spec.arg + if isinstance( + old_input_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_input_spec.arg)(node.name) + ) + new_input_specs.append( + InputSpec( + old_input_spec.kind, + arg, + old_input_spec.target, + old_input_spec.persistent, + ) + ) + + output_node = list(new_gm.graph.nodes)[-1] + assert output_node.op == "output" + + new_output_specs = [] + for i, node in enumerate(output_node.args[0]): + assert i < len(old_signature.output_specs), ( + "Number of outputs changed after transformation" + ) + old_output_spec = old_signature.output_specs[i] + arg = ( + old_output_spec.arg + if isinstance( + old_output_spec.arg, (ConstantArgument, CustomObjArgument) + ) + else type(old_output_spec.arg)(node.name) + ) + new_output_specs.append( + OutputSpec(old_output_spec.kind, arg, old_output_spec.target) + ) + + new_signature = ExportGraphSignature( + input_specs=new_input_specs, output_specs=new_output_specs + ) + return new_signature + + transformed_ep = ExportedProgram( + root=transformed_gm, + graph=transformed_gm.graph, + graph_signature=_get_updated_graph_signature( + self.graph_signature, transformed_gm + ), + state_dict=self.state_dict, + range_constraints=_get_updated_range_constraints( + transformed_gm, + self.range_constraints, + ), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + constants=self.constants, + verifiers=self.verifiers, + ) + transformed_ep.graph_module.meta.update(self.graph_module.meta) + transformed_ep.graph_module.meta.update(res.graph_module.meta) + return transformed_ep + + def _check_input_constraints(self, flat_args_with_path): + from torch._export.utils import _check_input_constraints_for_graph + + placeholders = [p for p in self.graph.nodes if p.op == "placeholder"] + input_placeholders = [ + p + for p, s in zip(placeholders, self.graph_signature.input_specs) + if s.kind == InputKind.USER_INPUT + ] + _check_input_constraints_for_graph( + input_placeholders, flat_args_with_path, self.range_constraints + ) + + @compatibility(is_backward_compatible=False) + def validate(self): + self._validate() + + # TODO: remove this + @final + def _validate(self): + assert len(self.verifiers) > 0, ( + "ExportedProgram must have at least one verifier." + ) + for v in self.verifiers: + v().check(self) + + # TODO(zhxchen17) Formalize this. + def _update( + self, + graph_module, + graph_signature, + *, + state_dict=None, + constants=None, + verifiers=None, + ) -> "ExportedProgram": + return ExportedProgram( + root=graph_module, + graph=graph_module.graph, + graph_signature=graph_signature, + state_dict=state_dict if state_dict is not None else self.state_dict, + range_constraints=copy.deepcopy(self.range_constraints), + module_call_graph=copy.deepcopy(self._module_call_graph), + example_inputs=self.example_inputs, + constants=constants if constants is not None else self.constants, + verifiers=verifiers if verifiers is not None else self.verifiers, + ) + + +def _get_shape_env(gm): + vals = [ + node.meta["val"] + for node in gm.graph.nodes + if node.meta.get("val", None) is not None + ] + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode(vals) + if fake_mode is not None: + return fake_mode.shape_env + for v in vals: + if isinstance(v, torch.SymInt): + return v.node.shape_env + + +def _get_updated_range_constraints( + gm: torch.fx.GraphModule, + old_range_constraints: "Optional[dict[sympy.Symbol, Any]]" = None, +) -> "dict[sympy.Symbol, Any]": + assert old_range_constraints is not None + + shape_env = _get_shape_env(gm) + if shape_env is None: + return {} + + range_constraints = copy.copy(old_range_constraints) + range_constraints = { + k: v for k, v in range_constraints.items() if k not in shape_env.replacements + } + # Only when we have an unbacked symint, and it's used as constructor inputs, + # runtime_var_to_range will make a difference compated to var_to_range. + # e.g. [2, oo) -> [0, oo) + for k, v in shape_env.var_to_range.items(): + if k not in shape_env.replacements and k not in range_constraints: + range_constraints[k] = v + return range_constraints + + +def _create_graph_module_for_export(root, graph): + try: + gm = torch.fx.GraphModule(root, graph) + except SyntaxError: + # If custom objects stored in memory are being used in the graph, + # the generated python code will result in a syntax error on the custom + # object, since it is unable to parse the in-memory object. However + # we can still run the graph eagerly through torch.fx.Interpreter, + # so we will bypass this error. + warnings.warn( + "Unable to execute the generated python source code from " + "the graph. The graph module will no longer be directly callable, " + "but you can still run the ExportedProgram, and if needed, you can " + "run the graph module eagerly using torch.fx.Interpreter." + ) + gm = torch.fx.GraphModule(root, torch.fx.Graph()) + gm._graph = graph + + return gm diff --git a/phivenv/Lib/site-packages/torch/export/graph_signature.py b/phivenv/Lib/site-packages/torch/export/graph_signature.py new file mode 100644 index 0000000000000000000000000000000000000000..011eec66a16a13ad732960db0e05967884514885 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/graph_signature.py @@ -0,0 +1,704 @@ +# mypy: allow-untyped-defs +import dataclasses +from collections.abc import Collection, Mapping +from enum import auto, Enum +from typing import Optional, TYPE_CHECKING, Union + +from torch._library.fake_class_registry import FakeScriptObject +from torch._subclasses.fake_tensor import is_fake + + +if TYPE_CHECKING: + import torch + from torch._functorch._aot_autograd.schemas import GraphSignature + +__all__ = [ + "ConstantArgument", + "CustomObjArgument", + "ExportBackwardSignature", + "ExportGraphSignature", + "InputKind", + "InputSpec", + "OutputKind", + "OutputSpec", + "SymIntArgument", + "SymFloatArgument", + "SymBoolArgument", + "TensorArgument", +] + + +@dataclasses.dataclass +class TensorArgument: + name: str + + +@dataclasses.dataclass +class TokenArgument: + name: str + + +@dataclasses.dataclass +class SymIntArgument: + name: str + + +@dataclasses.dataclass +class SymFloatArgument: + name: str + + +@dataclasses.dataclass +class SymBoolArgument: + name: str + + +@dataclasses.dataclass +class CustomObjArgument: + name: str + class_fqn: str + fake_val: Optional[FakeScriptObject] = None + + +@dataclasses.dataclass +class ConstantArgument: + name: str + value: Union[int, float, bool, str, None] + + +ArgumentSpec = Union[ + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + ConstantArgument, + CustomObjArgument, + TokenArgument, +] + + +class InputKind(Enum): + USER_INPUT = auto() + PARAMETER = auto() + BUFFER = auto() + CONSTANT_TENSOR = auto() + CUSTOM_OBJ = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class InputSpec: + kind: InputKind + arg: ArgumentSpec + target: Optional[str] + persistent: Optional[bool] = None + + def __post_init__(self): + if self.kind == InputKind.BUFFER: + assert self.persistent is not None, ( + "Failed to specify persistent flag on BUFFER." + ) + assert isinstance( + self.arg, + ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + ConstantArgument, + CustomObjArgument, + TokenArgument, + ), + ), f"got {type(self.arg)}" + + def __str__(self): + target = "" if self.target is None else f" target='{self.target}'" + persistent = "" if self.persistent is None else f" persistent={self.persistent}" + return f"{str(self.arg.name)}: {str(self.kind.name)}{target}{persistent}" + + +class OutputKind(Enum): + USER_OUTPUT = auto() + LOSS_OUTPUT = auto() + BUFFER_MUTATION = auto() + GRADIENT_TO_PARAMETER = auto() + GRADIENT_TO_USER_INPUT = auto() + USER_INPUT_MUTATION = auto() + TOKEN = auto() + + +@dataclasses.dataclass +class OutputSpec: + kind: OutputKind + arg: ArgumentSpec + target: Optional[str] + + def __post_init__(self): + assert isinstance( + self.arg, + ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + ConstantArgument, + TokenArgument, + CustomObjArgument, + ), + ), self.arg + + def __str__(self): + target = "" if self.target is None else f" target='{self.target}'" + return f"{str(self.arg.name)}: {str(self.kind.name)}{target}" + + +@dataclasses.dataclass +class ExportBackwardSignature: + gradients_to_parameters: dict[str, str] + gradients_to_user_inputs: dict[str, str] + loss_output: str + + +@dataclasses.dataclass +class ExportGraphSignature: + """ + :class:`ExportGraphSignature` models the input/output signature of Export Graph, + which is a fx.Graph with stronger invariants gurantees. + + Export Graph is functional and does not access "states" like parameters + or buffers within the graph via ``getattr`` nodes. Instead, :func:`export` + gurantees that parameters, buffers, and constant tensors are lifted out of + the graph as inputs. Similarly, any mutations to buffers are not included + in the graph either, instead the updated values of mutated buffers are + modeled as additional outputs of Export Graph. + + The ordering of all inputs and outputs are:: + + Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs] + Outputs = [*mutated_inputs, *flattened_user_outputs] + + e.g. If following module is exported:: + + class CustomModule(nn.Module): + def __init__(self) -> None: + super(CustomModule, self).__init__() + + # Define a parameter + self.my_parameter = nn.Parameter(torch.tensor(2.0)) + + # Define two buffers + self.register_buffer("my_buffer1", torch.tensor(3.0)) + self.register_buffer("my_buffer2", torch.tensor(4.0)) + + def forward(self, x1, x2): + # Use the parameter, buffers, and both inputs in the forward method + output = ( + x1 + self.my_parameter + ) * self.my_buffer1 + x2 * self.my_buffer2 + + # Mutate one of the buffers (e.g., increment it by 1) + self.my_buffer2.add_(1.0) # In-place addition + + return output + + + mod = CustomModule() + ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) + + Resulting Graph is non-functional:: + + graph(): + %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] + %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] + %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] + %x1 : [num_users=1] = placeholder[target=x1] + %x2 : [num_users=1] = placeholder[target=x2] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) + %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) + %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) + %add_ : [num_users=0] = call_function[target=torch.ops.aten.add_.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) + return (add_1,) + + Resulting ExportGraphSignature of the non-functional Graph would be:: + + # inputs + p_my_parameter: PARAMETER target='my_parameter' + b_my_buffer1: BUFFER target='my_buffer1' persistent=True + b_my_buffer2: BUFFER target='my_buffer2' persistent=True + x1: USER_INPUT + x2: USER_INPUT + + # outputs + add_1: USER_OUTPUT + + To get a functional Graph, you can use :func:`run_decompositions`:: + + mod = CustomModule() + ep = torch.export.export(mod, (torch.tensor(1.0), torch.tensor(2.0))) + ep = ep.run_decompositions() + + Resulting Graph is functional:: + + graph(): + %p_my_parameter : [num_users=1] = placeholder[target=p_my_parameter] + %b_my_buffer1 : [num_users=1] = placeholder[target=b_my_buffer1] + %b_my_buffer2 : [num_users=2] = placeholder[target=b_my_buffer2] + %x1 : [num_users=1] = placeholder[target=x1] + %x2 : [num_users=1] = placeholder[target=x2] + %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x1, %p_my_parameter), kwargs = {}) + %mul : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%add, %b_my_buffer1), kwargs = {}) + %mul_1 : [num_users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%x2, %b_my_buffer2), kwargs = {}) + %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%mul, %mul_1), kwargs = {}) + %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%b_my_buffer2, 1.0), kwargs = {}) + return (add_2, add_1) + + Resulting ExportGraphSignature of the functional Graph would be:: + + # inputs + p_my_parameter: PARAMETER target='my_parameter' + b_my_buffer1: BUFFER target='my_buffer1' persistent=True + b_my_buffer2: BUFFER target='my_buffer2' persistent=True + x1: USER_INPUT + x2: USER_INPUT + + # outputs + add_2: BUFFER_MUTATION target='my_buffer2' + add_1: USER_OUTPUT + + """ + + input_specs: list[InputSpec] + output_specs: list[OutputSpec] + + # A list of parameters uniquely identified by mangled fully qualified name + @property + def parameters(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.PARAMETER + if isinstance(s.target, str) + ) + + # A list of buffers uniquely identified by mangled fully qualified name + @property + def buffers(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if isinstance(s.target, str) + ) + + @property + def non_persistent_buffers(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.BUFFER + if s.persistent is False + if isinstance(s.target, str) + ) + + # A list of lifted constant tensors + @property + def lifted_tensor_constants(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + if isinstance(s.target, str) + ) + + @property + def lifted_custom_objs(self) -> Collection[str]: + return tuple( + s.target + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + if isinstance(s.target, str) + ) + + # Graph node names of pytree-flattened inputs of original program + @property + def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]: + user_inputs: list[Union[int, float, bool, None, str]] = [] + for s in self.input_specs: + if s.kind != InputKind.USER_INPUT: + continue + + if isinstance( + s.arg, + ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + CustomObjArgument, + ), + ): + user_inputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_inputs.append(s.arg.value) + else: + raise RuntimeError(f"{s.arg} is not a valid user inputs") + return tuple(user_inputs) + + # Graph node names of pytree-flattened outputs of original program + # For joint-graph purposes, will include the loss output. + @property + def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]: + user_outputs: list[Union[int, float, bool, None, str]] = [] + for s in self.output_specs: + if s.kind not in [ + OutputKind.USER_OUTPUT, + OutputKind.LOSS_OUTPUT, + ]: + continue + + if isinstance( + s.arg, + (TensorArgument, SymIntArgument, SymFloatArgument, SymBoolArgument), + ): + user_outputs.append(s.arg.name) + elif isinstance(s.arg, ConstantArgument): + user_outputs.append(s.arg.value) + elif isinstance(s.arg, CustomObjArgument): + user_outputs.append(s.arg.name) + else: + raise RuntimeError(f"{s.arg} is not a valid user output") + return tuple(user_outputs) + + # A dictionary mapping graph input node names to parameters. If a graph input + # name is found in this dictionary, it is guranteed to be a lifted parameter. + @property + def inputs_to_parameters(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.PARAMETER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph input node names to buffers. If a graph input + # name is found in this dictionary, it is guranteed to be a lifted buffer. + @property + def inputs_to_buffers(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) # type: ignore[union-attr, misc] + for s in self.input_specs + if s.kind == InputKind.BUFFER + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph output node names to buffers that are mutated in the + # original program. Buffers that are not mutated will not be found in this dictionary. + @property + def buffers_to_mutate(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.output_specs + if s.kind == OutputKind.BUFFER_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + @property + def user_inputs_to_mutate(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.output_specs + if s.kind == OutputKind.USER_INPUT_MUTATION + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + # A dictionary mapping graph input node names to lifted tensor constants. + @property + def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.CONSTANT_TENSOR + and isinstance(s.arg, TensorArgument) + and isinstance(s.target, str) + ) + + @property + def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]: + return _immutable_dict( + (s.arg.name, s.target) + for s in self.input_specs + if s.kind == InputKind.CUSTOM_OBJ + and isinstance(s.arg, CustomObjArgument) + and isinstance(s.target, str) + ) + + @property + def backward_signature(self) -> Optional[ExportBackwardSignature]: + loss_output = None + gradients_to_parameters: dict[str, str] = {} + gradients_to_user_inputs: dict[str, str] = {} + for spec in self.output_specs: + if spec.kind == OutputKind.LOSS_OUTPUT: + assert loss_output is None + assert isinstance(spec.arg, TensorArgument) + loss_output = spec.arg.name + elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_parameters[spec.arg.name] = spec.target + elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT: + assert isinstance(spec.target, str) + assert isinstance(spec.arg, TensorArgument) + gradients_to_user_inputs[spec.arg.name] = spec.target + + if loss_output is None: + return None + + return ExportBackwardSignature( + loss_output=loss_output, + gradients_to_parameters=gradients_to_parameters, + gradients_to_user_inputs=gradients_to_user_inputs, + ) + + # Map from assertion dependency token index to assertion dep token output + # name in output. The shape of output after aot_autograd will be like: + # (updated_inputs, user_outputs, dep_token). + @property + def assertion_dep_token(self) -> Optional[Mapping[int, str]]: + return None + + @property + def input_tokens(self) -> Collection[str]: + input_tokens = [] + for s in self.input_specs: + if s.kind == InputKind.TOKEN: + assert isinstance(s.arg, TokenArgument) + input_tokens.append(s.arg.name) + return tuple(input_tokens) + + @property + def output_tokens(self) -> Collection[str]: + output_tokens = [] + for s in self.output_specs: + if s.kind == OutputKind.TOKEN: + assert isinstance(s.arg, TokenArgument) + output_tokens.append(s.arg.name) + return tuple(output_tokens) + + def __post_init__(self) -> None: + assertion_dep_token = self.assertion_dep_token + if assertion_dep_token is None: + return + assert len(assertion_dep_token) == 1 + assertion_dep_token_index = next(iter(assertion_dep_token.keys())) + assert ( + len(self.user_outputs) + len(self.buffers_to_mutate) + == assertion_dep_token_index + ) + + def replace_all_uses(self, old: str, new: str): + """ + Replace all uses of the old name with new name in the signature. + """ + assert isinstance(old, str) + assert isinstance(new, str) + arg_types = ( + TensorArgument, + SymIntArgument, + SymFloatArgument, + SymBoolArgument, + CustomObjArgument, + TokenArgument, + ) + for o in self.output_specs: + if isinstance(o.arg, arg_types): + if o.arg.name == old: + o.arg.name = new + for i in self.input_specs: + if isinstance(i.arg, arg_types): + if i.arg.name == old: + i.arg.name = new + + def get_replace_hook(self, replace_inputs=False): + def _(old, new, user): + if user.op == "output": + self.replace_all_uses(old.name, new) + if replace_inputs and old.op == "placeholder": + self.replace_all_uses(old.name, new) + + return _ + + def __str__(self): + input_specs = "\n".join(str(s) for s in self.input_specs) + output_specs = "\n".join(str(s) for s in self.output_specs) + return f"\n# inputs\n{input_specs}\n\n# outputs\n{output_specs}\n" + + +def _immutable_dict(items): + """ + Creates a mapping where items cannot be added, deleted, or updated. + NOTE: The immutability is shallow (like tuple is an immutable collection). + """ + from types import MappingProxyType + + return MappingProxyType(dict(items)) + + +def _make_argument_spec(node, token_names) -> ArgumentSpec: + from torch import ScriptObject, SymBool, SymFloat, SymInt + from torch._library.fake_class_registry import FakeScriptObject + + if isinstance(node, (int, bool, float, type(None), str)): + # For const outputs we just directly return this + return ConstantArgument(name="", value=node) + + assert "val" in node.meta, ( + f"{node} is not a constant or a node with a 'val' metadata field" + ) + val = node.meta["val"] + if node.name in token_names: + return TokenArgument(name=node.name) + elif is_fake(val): + return TensorArgument(name=node.name) + elif isinstance(val, SymInt): + return SymIntArgument(name=node.name) + elif isinstance(val, SymFloat): + return SymFloatArgument(name=node.name) + elif isinstance(val, SymBool): + return SymBoolArgument(name=node.name) + elif isinstance(val, ScriptObject): + return CustomObjArgument(name=node.name, class_fqn=val._type().qualified_name()) # type: ignore[attr-defined] + elif isinstance(val, FakeScriptObject): + return CustomObjArgument( + name=node.name, class_fqn=val.script_class_name, fake_val=val + ) + elif isinstance(val, (int, bool, str, float, type(None))): + return ConstantArgument(name=node.name, value=val) + else: + raise AssertionError( + f"Encountered an unsupported object of type {type(val)} " + f"while writing the metadata for exported program" + ) + + +def _convert_to_export_graph_signature( + graph_signature: "GraphSignature", + gm: "torch.fx.GraphModule", + non_persistent_buffers: set[str], +) -> "ExportGraphSignature": + from torch.utils import _pytree as pytree + + is_joint = graph_signature.backward_signature is not None + + # unpack objects + user_inputs = set(graph_signature.user_inputs) + inputs_to_parameters = graph_signature.inputs_to_parameters + inputs_to_buffers = graph_signature.inputs_to_buffers + user_outputs = set(graph_signature.user_outputs) + buffer_mutations = graph_signature.buffers_to_mutate + user_input_mutations = graph_signature.user_inputs_to_mutate + grad_params = ( + graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr] + if is_joint + else {} + ) + grad_user_inputs = ( + graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr] + if is_joint + else {} + ) + loss_output = ( + graph_signature.backward_signature.loss_output # type: ignore[union-attr] + if is_joint + else None + ) + input_tokens = graph_signature.input_tokens + output_tokens = graph_signature.output_tokens + + inputs = [ + _make_argument_spec(node, input_tokens) + for node in gm.graph.nodes + if node.op == "placeholder" + ] + outputs = [ + _make_argument_spec(node, output_tokens) + for node in pytree.tree_leaves(next(iter(reversed(gm.graph.nodes))).args) + ] + + def to_input_spec(inp: ArgumentSpec) -> InputSpec: + if isinstance(inp, TokenArgument): + return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None) + + if not isinstance(inp, TensorArgument): + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) + name = inp.name + if name in user_inputs: + return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None) + elif name in inputs_to_parameters: + return InputSpec( + kind=InputKind.PARAMETER, + arg=inp, + target=inputs_to_parameters[name], # type: ignore[index] + ) + elif name in inputs_to_buffers: + return InputSpec( + kind=InputKind.BUFFER, + arg=inp, + target=inputs_to_buffers[name], # type: ignore[index] + persistent=(inputs_to_buffers[name] not in non_persistent_buffers), # type: ignore[index] + ) + else: + raise AssertionError(f"Unknown tensor input kind: {name}") + + def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec: + if isinstance(o, TokenArgument): + return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None) + + if not isinstance(o, TensorArgument): + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + name = o.name + if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens): + if name in buffer_mutations: + return OutputSpec( + kind=OutputKind.BUFFER_MUTATION, + arg=o, + target=buffer_mutations[name], # type: ignore[index] + ) + elif name in user_input_mutations: + return OutputSpec( + kind=OutputKind.USER_INPUT_MUTATION, + arg=o, + target=user_input_mutations[name], # type: ignore[index] + ) + else: + raise AssertionError(f"Unknown tensor mutation kind: {name}") + else: + if name in user_outputs: + return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None) + + elif name in grad_params: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_PARAMETER, + arg=o, + target=grad_params[name], + ) + elif name in grad_user_inputs: + return OutputSpec( + kind=OutputKind.GRADIENT_TO_USER_INPUT, + arg=o, + target=grad_user_inputs[name], + ) + elif name == loss_output: + return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None) + + else: + raise AssertionError(f"Unknown tensor output kind: {name}") + + input_specs = [to_input_spec(inp) for inp in inputs] + output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)] + return ExportGraphSignature(input_specs=input_specs, output_specs=output_specs) diff --git a/phivenv/Lib/site-packages/torch/export/passes/__init__.py b/phivenv/Lib/site-packages/torch/export/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0adfa5c9112b2a4bcb6d2689c6e5796a84ccd43a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/passes/__init__.py @@ -0,0 +1,70 @@ +from typing import Union + +import torch +import torch.utils._pytree as pytree +from torch.export.exported_program import ExportedProgram + + +__all__ = ["move_to_device_pass"] + + +def move_to_device_pass( + ep: ExportedProgram, location: Union[torch.device, str, dict[str, str]] +) -> ExportedProgram: + """ + Move the exported program to the given device. + + Args: + ep (ExportedProgram): The exported program to move. + location (Union[torch.device, str, Dict[str, str]]): The device to move the exported program to. + If a string, it is interpreted as a device name. + If a dict, it is interpreted as a mapping from + the existing device to the intended one + + Returns: + ExportedProgram: The moved exported program. + """ + + def _get_new_device( + curr_device: torch.device, + location: Union[torch.device, str, dict[str, str]], + ) -> str: + if isinstance(location, dict): + if str(curr_device) in location.keys(): + return location[str(curr_device)] + else: + return str(curr_device) + else: + return str(location) + + # move all the state_dict + for k, v in ep.state_dict.items(): + if isinstance(v, torch.nn.Parameter): + ep._state_dict[k] = torch.nn.Parameter( + v.to(_get_new_device(v.device, location)), + v.requires_grad, + ) + else: + ep._state_dict[k] = v.to(_get_new_device(v.device, location)) + + # move all the constants + for k, v in ep.constants.items(): + if isinstance(v, torch.Tensor): + ep._constants[k] = v.to(_get_new_device(v.device, location)) + + for node in ep.graph.nodes: + # move all the nodes kwargs with burnt-in device + if "device" in node.kwargs: + kwargs = node.kwargs.copy() + kwargs["device"] = _get_new_device(kwargs["device"], location) + node.kwargs = kwargs + # move all the tensor metadata + node.meta["val"] = pytree.tree_map( + lambda v: v.to(_get_new_device(v.device, location)) + if isinstance(v, torch.Tensor) + else v, + node.meta.get("val"), + ) + + ep.validate() + return ep diff --git a/phivenv/Lib/site-packages/torch/export/passes/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc708104942539835e040fe46d3f2877877d01c5 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/__init__.py b/phivenv/Lib/site-packages/torch/export/pt2_archive/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d684407fcb3dc733da829e5ff32993a8bb87c6dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/pt2_archive/__init__.py @@ -0,0 +1,4 @@ +from ._package import is_pt2_package, PT2ArchiveReader, PT2ArchiveWriter + + +__all__ = ["PT2ArchiveWriter", "PT2ArchiveReader", "is_pt2_package"] diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36c578cb88c2657ed1aabe1799a77395622ee6f2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dffb2c9171ff7e150c4c814e63efbdd41c88c2f8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/_package.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5a1b21b7bcbaf587c38d0b7cfb159daf92f3bfa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/_package_weights.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-39.pyc b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4ef8aa347533b82aca2ba6fb9a2123de8d80050 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/export/pt2_archive/__pycache__/constants.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/_package.py b/phivenv/Lib/site-packages/torch/export/pt2_archive/_package.py new file mode 100644 index 0000000000000000000000000000000000000000..8ee77497a380e0500a2e1e381dff466e30a27fd1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/pt2_archive/_package.py @@ -0,0 +1,685 @@ +import glob +import io +import json +import logging +import os +import tempfile +import zipfile +from dataclasses import dataclass +from typing import Any, IO, Optional, TYPE_CHECKING, Union +from typing_extensions import TypeAlias + +import torch +import torch.utils._pytree as pytree +from torch._export.serde.serialize import deserialize, serialize, SerializedArtifact +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ExportedProgram +from torch.export.pt2_archive._package_weights import ( + get_complete, + group_weights, + Weights, +) +from torch.export.pt2_archive.constants import ( + AOTINDUCTOR_DIR, + ARCHIVE_FORMAT_PATH, + ARCHIVE_FORMAT_VALUE, + ARCHIVE_VERSION_PATH, + ARCHIVE_VERSION_VALUE, + CONSTANTS_DIR, + CUSTOM_OBJ_FILENAME_PREFIX, + EXTRA_DIR, + MODELS_DIR, + MODELS_FILENAME_FORMAT, + SAMPLE_INPUTS_FILENAME_FORMAT, + WEIGHT_FILENAME_PREFIX, + WEIGHTS_DIR, +) +from torch.types import FileLike + + +if TYPE_CHECKING: + from torch.utils._ordered_set import OrderedSet + + +DEFAULT_PICKLE_PROTOCOL = 2 +AOTI_FILES: TypeAlias = Union[ + list[Union[str, Weights]], dict[str, list[Union[str, Weights]]] +] + + +logger: logging.Logger = logging.getLogger(__name__) + + +def is_pt2_package(serialized_model: Union[bytes, str]) -> bool: + """ + Check if the serialized model is a PT2 Archive package. + """ + try: + zip_reader = zipfile.ZipFile( + io.BytesIO(serialized_model) + if isinstance(serialized_model, bytes) + else serialized_model + ) + root_folder = zip_reader.namelist()[0].split(os.path.sep)[0] + archive_format_path = f"{root_folder}/{ARCHIVE_FORMAT_PATH}" + if archive_format_path in zip_reader.namelist(): + return zip_reader.read(archive_format_path) == b"pt2" + except Exception as ex: + logger.info("Model is not a PT2 package: %s", str(ex)) + return False + + +class PT2ArchiveWriter: + """ + Context manager for writing a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + self.archive_file = torch._C.PyTorchFileWriter(archive_path_or_buffer) # type: ignore[arg-type] + # NOTICE: version here is different from the archive_version + # this is the version of zip file format, which is used by PyTorchFileWriter, which write to /.data/version + # archive_version is the version of the PT2 archive spec, which write to /archive_version + self.archive_file.set_min_version(6) + + def __enter__(self) -> "PT2ArchiveWriter": + return self + + def __exit__(self, *args: Any) -> None: + if not self.has_record(ARCHIVE_FORMAT_PATH): + self.write_string(ARCHIVE_FORMAT_PATH, ARCHIVE_FORMAT_VALUE) + + if not self.has_record(ARCHIVE_VERSION_PATH): + self.write_string(ARCHIVE_VERSION_PATH, ARCHIVE_VERSION_VALUE) + + self.close() + + def has_record(self, name: str) -> bool: + """ + Check if a record exists in the archive. + """ + return name in self.archive_file.get_all_written_records() + + def count_prefix(self, prefix: str) -> int: + """ + Count the number of records that start with a given prefix. + """ + return sum( + 1 + for record in self.archive_file.get_all_written_records() + if record.startswith(prefix) + ) + + def write_bytes(self, name: str, data: bytes) -> None: + """ + Write a bytes object to the archive. + name: The destination file inside the archive. + data: The bytes object to write. + """ + assert isinstance(data, bytes), f"Expected bytes but got {type(data)}" + self.archive_file.write_record(name, data, len(data)) + + def write_string(self, name: str, data: str) -> None: + """ + Write a string object to the archive. + name: The destination file inside the archive. + data: The string object to write. + """ + assert isinstance(data, str), f"Expected string but got {type(data)}" + data_bytes = data.encode() + self.write_bytes(name, data_bytes) + + def write_file(self, name: str, file_path: str) -> None: + """ + Copy a file into the archive. + name: The destination file inside the archive. + file_path: The source file on disk. + """ + assert os.path.isfile(file_path), f"{file_path} is not a valid file path" + + with open(file_path, "rb") as f: + file_bytes = f.read() + self.write_bytes(name, file_bytes) + + def write_folder(self, archive_dir: str, folder_dir: str) -> None: + """ + Copy a folder into the archive. + archive_dir: The destination folder inside the archive. + folder_dir: The source folder on disk. + """ + assert os.path.isdir(folder_dir), f"{folder_dir} is not a valid directory path" + + file_paths = filter( + os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True) + ) + for file_path in file_paths: + filename = os.path.relpath(file_path, folder_dir) + archive_path = os.path.join(archive_dir, filename) + self.write_file(archive_path, file_path) + + def close(self) -> None: + """ + Close the archive. + """ + self.archive_file.write_end_of_file() + + +class PT2ArchiveReader: + """ + Context manager for reading a PT2 archive. + """ + + def __init__(self, archive_path_or_buffer: FileLike): + self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] + assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( + "Invalid archive format" + ) + + def __enter__(self) -> "PT2ArchiveReader": + return self + + def __exit__(self, *args: Any) -> None: + # torch._C.PyTorchFileReader doesn't have a close method + pass + + def read_bytes(self, name: str) -> bytes: + """ + Read a bytes object from the archive. + name: The source file inside the archive. + """ + return self.archive_file.get_record(name) + + def read_string(self, name: str) -> str: + """ + Read a string object from the archive. + name: The source file inside the archive. + """ + data = self.read_bytes(name) + return data.decode() + + def archive_version(self) -> int: + """ + Get the archive version. + """ + try: + archive_version = self.read_string(ARCHIVE_VERSION_PATH) + except Exception: + # if archive_version is not found, it means the archive is older than version 0. + # In this case, we assume the archive is version 0. + archive_version = "0" + + return int(archive_version) + + def get_file_names(self) -> list[str]: + """ + Get the file names in the archive. + """ + return self.archive_file.get_all_records() + + +def _package_aoti_files( + archive_writer: PT2ArchiveWriter, + aoti_files: Optional[AOTI_FILES], + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if aoti_files is None: + return + + if isinstance(aoti_files, list): + aoti_files = {"model": aoti_files} + + assert isinstance(aoti_files, dict) + + all_weights: dict[str, Weights] = {} # model_name -> weight + weights_configs: dict[ + str, dict[str, Any] + ] = {} # model_name -> (weight_name -> (filename, shape, stride, offset)) + + for model_name, files in aoti_files.items(): + num_so_files = 0 + weights_configs[model_name] = {} + + for file in files: + if file == "": + continue + + if isinstance(file, Weights): + all_weights[model_name] = file + continue + + if file.endswith(".so"): + num_so_files += 1 + if num_so_files > 1: + raise RuntimeError( + f"Multiple .so files found in {files}. " + "You might need to clear your cache " + "directory before calling aoti_compile again." + ) + + filename = os.path.basename(file) + if filename.startswith(CUSTOM_OBJ_FILENAME_PREFIX): + new_filepath = os.path.join(CONSTANTS_DIR, filename) + else: + new_filepath = os.path.join(AOTINDUCTOR_DIR, model_name, filename) + logger.debug( + "Saving AOTI generated file %s to archive in %s", file, new_filepath + ) + archive_writer.write_file( + str(new_filepath), + file, + ) + + if len(all_weights) > 0: + # Dedup weights + grouped_tensors: list[OrderedSet[tuple[str, str]]] = group_weights(all_weights) + for idx, group in enumerate(grouped_tensors): + filename = f"{WEIGHT_FILENAME_PREFIX}{idx}" + model_name, weight_name = get_complete(group, all_weights) + complete_tensor, _ = all_weights[model_name].get_weight(weight_name) + buffer = io.BytesIO() + torch.save(complete_tensor, buffer, pickle_protocol=pickle_protocol) + archive_writer.write_bytes( + os.path.join(WEIGHTS_DIR, filename), buffer.getvalue() + ) + for model_name, weight_name in group: + _, w_property = all_weights[model_name].get_weight(weight_name) + weights_configs[model_name][weight_name] = ( + filename, + w_property.shape, + w_property.stride, + w_property.offset, + ) + + for model_name, weights_config in weights_configs.items(): + archive_writer.write_string( + os.path.join(AOTINDUCTOR_DIR, model_name, "weights_config.json"), + json.dumps(weights_config), + ) + logger.debug("packaging weights_config for model %s", model_name) + logger.debug(weights_config) + + +def _package_exported_programs( + archive_writer: PT2ArchiveWriter, + exported_programs: Optional[Union[ExportedProgram, dict[str, ExportedProgram]]], + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> None: + if exported_programs is None: + return + + if isinstance(exported_programs, ExportedProgram): + exported_programs = {"model", exported_programs} # type: ignore[assignment] + + assert isinstance(exported_programs, dict) + + for model_name, ep in exported_programs.items(): + artifact: SerializedArtifact = serialize(ep, opset_version, pickle_protocol) + + archive_writer.write_bytes( + MODELS_FILENAME_FORMAT.format(model_name), artifact.exported_program + ) + # TODO:Consider dedup this with the weights saved in package_aoti_files + archive_writer.write_bytes(f"{WEIGHTS_DIR}{model_name}.pt", artifact.state_dict) + archive_writer.write_bytes( + f"{CONSTANTS_DIR}{model_name}.pt", artifact.constants + ) + archive_writer.write_bytes( + SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name), + artifact.example_inputs, + ) + + +def _package_extra_files( + archive_writer: PT2ArchiveWriter, extra_files: Optional[dict[str, Any]] +) -> None: + if extra_files is None: + return + + for extra_file_name, content in extra_files.items(): + archive_writer.write_string(f"{EXTRA_DIR}{extra_file_name}", content) + + +def package_pt2( + f: FileLike, + *, + exported_programs: Optional[ + Union[ExportedProgram, dict[str, ExportedProgram]] + ] = None, + aoti_files: Optional[AOTI_FILES] = None, + extra_files: Optional[dict[str, Any]] = None, + opset_version: Optional[dict[str, int]] = None, + pickle_protocol: int = DEFAULT_PICKLE_PROTOCOL, +) -> FileLike: + """ + Saves the artifacts to a PT2Archive format + (https://docs.google.com/document/d/1RQ4cmywilnFUT1VE-4oTGxwXdc8vowCSZsrRgo3wFA8/edit?tab=t.0#heading=h.v2y2jgnwc56a). + The artifact can then be loaded using ``load_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]) A file-like object (has to + implement write and flush) or a string containing a file name. + + exported_programs (Union[ExportedProgram, dict[str, ExportedProgram]]): + The exported program to save, or a dictionary mapping model name to an + exported program to save. The exported program will be saved under + models/*.json. If only one ExportedProgram is specified, this will + automatically be named "model". + + aoti_files (Union[list[str], dict[str, list[str]]): A list of files + generated by AOTInductor via + ``torch._inductor.aot_compile(..., {"aot_inductor.package": True})``, + or a dictionary mapping model name to its AOTInductor generated files. + If only one set of files is specified, this will automatically be named + "model". + + extra_files (Optional[Dict[str, Any]]): Map from filename to contents + which will be stored as part of the pt2. + + opset_version (Optional[Dict[str, int]]): A map of opset names + to the version of this opset + + pickle_protocol: can be specified to override the default protocol + + """ + assert not ( + exported_programs is None and aoti_files is None and extra_files is None + ), ( + "No value passed in for `exported_programs`, `aoti_files`, and " + "`extra_files`, implying that you do not plan on saving anything." + ) + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error + logger.warning( + "Expect archive file to be a file ending in .pt2, or is a buffer. " + "Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + with PT2ArchiveWriter(f) as archive_writer: + _package_exported_programs( + archive_writer, exported_programs, pickle_protocol=pickle_protocol + ) + _package_aoti_files( + archive_writer, + aoti_files, + pickle_protocol=pickle_protocol, + ) + _package_extra_files(archive_writer, extra_files) + + if isinstance(f, (io.IOBase, IO)): + f.seek(0) + return f + + +class AOTICompiledModel: + """ + Callable AOT Inductor loaded model from a .pt2 + """ + + def __init__(self, loader: torch._C._aoti.AOTIModelPackageLoader) -> None: + self.loader = loader + + def __call__(self, *args, **kwargs): # type: ignore[no-untyped-def] + call_spec = self.loader.get_call_spec() + in_spec = pytree.treespec_loads(call_spec[0]) + out_spec = pytree.treespec_loads(call_spec[1]) + flat_inputs = pytree.tree_flatten((args, reorder_kwargs(kwargs, in_spec)))[0] + flat_inputs = [x for x in flat_inputs if isinstance(x, torch.Tensor)] + flat_outputs = self.loader.boxed_run(flat_inputs) + return pytree.tree_unflatten(flat_outputs, out_spec) + + def get_metadata(self) -> dict[str, str]: + return self.loader.get_metadata() + + def load_constants( + self, + constants_map: dict[str, torch.Tensor], + *, + check_full_update: bool, + user_managed: bool = False, + ) -> None: + """ + Given a mapping of constant fqns to tensors, load the constants into the model. + You can use ``get_constant_fqns`` to get the list of constant fqns that + are needed in the compiled model. + + Args: + constants_map: A mapping of constant fqns to tensors. + check_full_update: Whether to add check to see if all the constants + are updated and have values. + """ + self.loader.load_constants( + constants_map, False, check_full_update, user_managed + ) + + def get_constant_fqns(self) -> list[str]: + return self.loader.get_constant_fqns() + + def __deepcopy__(self, memo: Optional[dict[Any, Any]]) -> "AOTICompiledModel": + logger.warning( + "AOTICompiledModel deepcopy warning: AOTICompiledModel.loader is not deepcopied." + ) + return AOTICompiledModel(self.loader) + + +@dataclass +class PT2ArchiveContents: + exported_programs: dict[str, ExportedProgram] + aoti_runners: dict[str, AOTICompiledModel] + extra_files: dict[str, Any] + + +def _load_exported_programs( + archive_reader: PT2ArchiveReader, + file_names: list[str], + expected_opset_version: Optional[dict[str, int]], +) -> dict[str, ExportedProgram]: + exported_program_files = [ + file for file in file_names if file.startswith(MODELS_DIR) + ] + exported_programs = {} + for file in exported_program_files: + prefix, suffix = MODELS_FILENAME_FORMAT.split( + "{}" + ) # split "models/{}.json" into "models/" and "json" + model_name = file[ + len(prefix) : -len(suffix) + ] # given "models/foo.json" we can now get "foo" + + weights_file = f"{WEIGHTS_DIR}{model_name}.pt" + constants_file = f"{CONSTANTS_DIR}{model_name}.pt" + sample_inputs_file = SAMPLE_INPUTS_FILENAME_FORMAT.format(model_name) + + serialized_exported_program = archive_reader.read_bytes(file) + serialized_weights = archive_reader.read_bytes(weights_file) + serialized_constants = archive_reader.read_bytes(constants_file) + serialized_sample_inputs = archive_reader.read_bytes(sample_inputs_file) + + artifact: SerializedArtifact = SerializedArtifact( + serialized_exported_program, + serialized_weights, + serialized_constants, + serialized_sample_inputs, + ) + + # Deserialize ExportedProgram + ep = deserialize(artifact, expected_opset_version) + exported_programs[model_name] = ep + + return exported_programs + + +def _load_extra_files( + archive_reader: PT2ArchiveReader, file_names: list[str] +) -> dict[str, Any]: + extra_files = [file for file in file_names if file.startswith(EXTRA_DIR)] + + extra_file_contents: dict[str, Any] = {} + for file in extra_files: + contents = archive_reader.read_string(file) + extra_file_contents[file[len(EXTRA_DIR) :]] = contents + + return extra_file_contents + + +def load_pt2( + f: FileLike, + *, + expected_opset_version: Optional[dict[str, int]] = None, + run_single_threaded: bool = False, + num_runners: int = 1, + device_index: int = -1, + load_weights_from_disk: bool = False, +) -> PT2ArchiveContents: # type: ignore[type-arg] + """ + Loads all the artifacts previously saved with ``package_pt2``. + + Args: + f (str | os.PathLike[str] | IO[bytes]): A file-like object (has to + implement write and flush) or a string containing a file name. + + expected_opset_version (Optional[Dict[str, int]]): A map of opset names + to expected opset versions + + num_runners (int): Number of runners to load AOTInductor artifacts + + run_single_threaded (bool): Whether the model should be run without + thread synchronization logic. This is useful to avoid conflicts with + CUDAGraphs. + + device_index (int): The index of the device to which the PT2 package is + to be loaded. By default, `device_index=-1` is used, which corresponds + to the device `cuda` when using CUDA. Passing `device_index=1` would + load the package to `cuda:1`, for example. + + Returns: + A ``PT2ArchiveContents`` object which contains all the objects in the PT2. + """ + + if not ( + (isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable()) + or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2")) + ): + # TODO: turn this into an error in 2.9 + logger.warning( + "Unable to load package. f must be a buffer or a file ending in " + ".pt2. Instead got {%s}", + f, + ) + + if isinstance(f, (str, os.PathLike)): + f = os.fspath(f) + + weights = {} + weight_maps = {} + with PT2ArchiveReader(f) as archive_reader: + version = archive_reader.read_string(ARCHIVE_VERSION_PATH) + if version != ARCHIVE_VERSION_VALUE: + raise ValueError( + f"Saved archive version {version} does not match our current " + f"archive version {ARCHIVE_VERSION_VALUE}." + ) + + file_names = archive_reader.get_file_names() + + exported_programs = _load_exported_programs( + archive_reader, file_names, expected_opset_version + ) + extra_files = _load_extra_files(archive_reader, file_names) + + # Get a list of AOTI model names + aoti_model_names: set[str] = set() + for file in file_names: + if file.startswith(AOTINDUCTOR_DIR): + file_end = file[ + len(AOTINDUCTOR_DIR) : + ] # remove data/aotinductor/ prefix + model_name = file_end.split("/")[ + 0 + ] # split "model_name/...cpp" into "model_name" + aoti_model_names.add(model_name) + if load_weights_from_disk and file.endswith("weights_config.json"): + weight_map = json.loads(archive_reader.read_string(file)) + weight_maps[model_name] = weight_map + elif load_weights_from_disk and file.startswith(WEIGHTS_DIR): + weight_file_name = file[ + len(WEIGHTS_DIR) : + ] # remove data/weights/ prefix + weight_bytes = archive_reader.read_bytes(file) + loaded_weight = torch.load(io.BytesIO(weight_bytes)) + weights[weight_file_name] = loaded_weight + + if isinstance(f, (io.IOBase, IO)): + if len(aoti_model_names) > 0: + # Workaround for AOTIModelPackageLoader not reading buffers + with tempfile.NamedTemporaryFile(suffix=".pt2") as tf: + f.seek(0) + tf.write(f.read()) + f.seek(0) + logger.debug("Writing buffer to tmp file located at %s.", tf.name) + + aoti_runners = { + model_name: AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + tf.name, + model_name, + run_single_threaded, + num_runners, + device_index, + ) + ) + for model_name in aoti_model_names + } + else: + aoti_runners = {} + else: + aoti_runners = { + model_name: AOTICompiledModel( + torch._C._aoti.AOTIModelPackageLoader( + f, model_name, run_single_threaded, num_runners, device_index + ) + ) + for model_name in aoti_model_names + } + + if weight_maps: + for model_name in aoti_model_names: + model_weights = {} + for weight_name, (file, shape, stride, storage_offset) in weight_maps[ + model_name + ].items(): + weight = weights[file] + model_weights[weight_name] = weight.as_strided( + shape, stride, storage_offset + ) + + # user_managed=True ensures the weights updates are shared by all runners. + aoti_runners[model_name].load_constants( + model_weights, check_full_update=True, user_managed=True + ) + + return PT2ArchiveContents(exported_programs, aoti_runners, extra_files) + + +def load_weights_to_pt2_contents( + pt2_contents: PT2ArchiveContents, weights_map: dict[str, Any] +) -> None: + """ + Load weights into the models in PT2 archive contents + + Args: + pt2_contents (PT2ArchiveContents): The contents of the PT2 archive. + """ + for model_name, weights in weights_map.items(): + if model_name not in pt2_contents.aoti_runners: + raise RuntimeError(f"Model {model_name} not found in PT2 archive contents.") + pt2_contents.aoti_runners[model_name].load_constants( + weights, check_full_update=True, user_managed=True + ) diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/_package_weights.py b/phivenv/Lib/site-packages/torch/export/pt2_archive/_package_weights.py new file mode 100644 index 0000000000000000000000000000000000000000..a49b869e9a4307394766ba6e726236081632ec9a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/pt2_archive/_package_weights.py @@ -0,0 +1,101 @@ +import collections + +import torch +from torch.utils._ordered_set import OrderedSet + + +def _end_ptr(tensor: torch.Tensor) -> int: + if tensor.nelement(): + stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size() + else: + stop = tensor.data_ptr() + return stop + + +class TensorProperties: + def __init__(self, tensor: torch.Tensor): + # info about underlying storage + self.storage_ptr = tensor.untyped_storage().data_ptr() + self.storage_size = tensor.untyped_storage().nbytes() + + # info to recover tensor + self.shape = tensor.shape + self.stride = tensor.stride() + self.offset = tensor.storage_offset() + + self.start = tensor.data_ptr() + self.end = _end_ptr(tensor) + + def is_complete(self) -> bool: + """ + Whehter the tensor completely overlaps with its underlying storage + """ + return ( + self.start == self.storage_ptr + and self.end == self.storage_ptr + self.storage_size + ) + + +class Weights(dict): + """ + A dictionary mapping from weight name to a tuple of (tensor, TensorProperties). + tensor represents the actual intial value of the weight. + TensorProperties represents the properties of the weight that are needed to recover the weight. + + We use two separate entries because `tensor` could be a clone of the original weight tensor, + so it doesn't have the same property as the original weight (such as underlying storage pointer). + """ + + def __init__(self, weight_dict: dict[str, tuple[torch.Tensor, TensorProperties]]): + super().__init__(weight_dict) + + def get_weight(self, name: str) -> tuple[torch.Tensor, TensorProperties]: + return self[name] + + def get_weight_properties(self, name: str) -> TensorProperties: + return self[name][1] + + +def get_complete( + group: OrderedSet[tuple[str, str]], models_weights: dict[str, Weights] +) -> tuple[str, str]: + """ + `group` is a (model_name, weight_name) tuple. + `model_weights` is a dictionary mapping from model name to its Weights. + + One of the tensor in `group` must be complete and they must share the + same underlying storage. + + Returns the name of the complete tensor in the `group`. If multiple + tensors are complete, returns an arbitrary one. + """ + + def get_tensor_properties(name_tuple: tuple[str, str]) -> TensorProperties: + # returns the tensor properties + (model_name, weight_name) = name_tuple + return models_weights[model_name].get_weight_properties(weight_name) + + for name_tuple in group: + tensor_property = get_tensor_properties(name_tuple) + if tensor_property.is_complete(): + return name_tuple + + raise RuntimeError("No complete tensor found in the group!") + + +def group_weights(all_weights: dict[str, Weights]) -> list[OrderedSet[tuple[str, str]]]: + """ + Group weights that share the same underlying storage. + + Returns a list of sets, each set contains a tuple of (model_name, weight_name). + """ + + weights_dict: dict[int, OrderedSet[tuple[str, str]]] = collections.defaultdict( + OrderedSet + ) # storage_key -> set(weight) + + for model_name, weights in all_weights.items(): + for weight_name, (_, properties) in weights.items(): + weights_dict[properties.storage_ptr].add((model_name, weight_name)) + + return list(weights_dict.values()) diff --git a/phivenv/Lib/site-packages/torch/export/pt2_archive/constants.py b/phivenv/Lib/site-packages/torch/export/pt2_archive/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6908d8276add83caf4aab48dfb100fd87833fd41 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/pt2_archive/constants.py @@ -0,0 +1,28 @@ +# Defined in torch/csrc/export/pt2_archive_constants.h +from torch._C._export import pt2_archive_constants + + +AOTINDUCTOR_DIR: str = pt2_archive_constants.AOTINDUCTOR_DIR +ARCHIVE_FORMAT_PATH: str = pt2_archive_constants.ARCHIVE_FORMAT_PATH +ARCHIVE_FORMAT_VALUE: str = pt2_archive_constants.ARCHIVE_FORMAT_VALUE +ARCHIVE_ROOT_NAME: str = pt2_archive_constants.ARCHIVE_ROOT_NAME +ARCHIVE_VERSION_PATH: str = pt2_archive_constants.ARCHIVE_VERSION_PATH +ARCHIVE_VERSION_VALUE: str = pt2_archive_constants.ARCHIVE_VERSION_VALUE +CONSTANTS_DIR: str = pt2_archive_constants.CONSTANTS_DIR +CUSTOM_OBJ_FILENAME_PREFIX: str = pt2_archive_constants.CUSTOM_OBJ_FILENAME_PREFIX +EXTRA_DIR: str = pt2_archive_constants.EXTRA_DIR +MODELS_DIR: str = pt2_archive_constants.MODELS_DIR +MODELS_FILENAME_FORMAT: str = pt2_archive_constants.MODELS_FILENAME_FORMAT +MODULE_INFO_PATH: str = pt2_archive_constants.MODULE_INFO_PATH +MTIA_DIR: str = pt2_archive_constants.MTIA_DIR +SAMPLE_INPUTS_DIR: str = pt2_archive_constants.SAMPLE_INPUTS_DIR +SAMPLE_INPUTS_FILENAME_FORMAT: str = pt2_archive_constants.SAMPLE_INPUTS_FILENAME_FORMAT +TENSOR_CONSTANT_FILENAME_PREFIX: str = ( + pt2_archive_constants.TENSOR_CONSTANT_FILENAME_PREFIX +) +WEIGHT_FILENAME_PREFIX: str = pt2_archive_constants.WEIGHT_FILENAME_PREFIX +WEIGHTS_DIR: str = pt2_archive_constants.WEIGHTS_DIR +XL_MODEL_WEIGHTS_DIR: str = pt2_archive_constants.XL_MODEL_WEIGHTS_DIR +XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH: str = ( + pt2_archive_constants.XL_MODEL_WEIGHTS_PARAM_CONFIG_PATH +) diff --git a/phivenv/Lib/site-packages/torch/export/unflatten.py b/phivenv/Lib/site-packages/torch/export/unflatten.py new file mode 100644 index 0000000000000000000000000000000000000000..1bf34bb1aa6ba4aa1c58fc46479cac56207c3309 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/export/unflatten.py @@ -0,0 +1,1665 @@ +# mypy: allow-untyped-defs +import abc +import copy +import logging +import operator +import re +from collections import defaultdict +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from enum import Enum +from typing import Any, Callable, cast, Optional, Union + +import torch +import torch.fx._pytree as fx_pytree +import torch.utils._pytree as pytree +from torch._library.fake_class_registry import FakeScriptObject +from torch.export._tree_utils import reorder_kwargs +from torch.export.exported_program import ( + ConstantArgument, + ExportedProgram, + ExportGraphSignature, + InputKind, + ModuleCallSignature, + SymBoolArgument, + SymFloatArgument, + SymIntArgument, + TensorArgument, +) +from torch.fx._symbolic_trace import is_fx_tracing +from torch.fx.graph_module import _get_attr, _get_attr_via_attr_list, _print_readable +from torch.utils._pytree import GetAttrKey, SequenceKey + +from ._remove_effect_tokens_pass import _remove_effect_tokens + + +log = logging.getLogger(__name__) + + +__all__ = [ + "FlatArgsAdapter", + "InterpreterModule", + "InterpreterModuleDispatcher", + "UnflattenedModule", + "unflatten", +] + + +class _AttrKind(Enum): + PARAMETER = "parameter" + BUFFER = "buffer" + CONSTANT = "constant" + MODULE = "module" + + +RUN_WITH_INTERPRETER = True + + +@contextmanager +def _disable_interpreter(): + global RUN_WITH_INTERPRETER + old_flag = RUN_WITH_INTERPRETER + RUN_WITH_INTERPRETER = False + try: + yield + finally: + RUN_WITH_INTERPRETER = old_flag + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr( + from_obj: Union[torch.Tensor, torch.ScriptObject, torch.nn.Module], + to_module: torch.nn.Module, + target: str, + attr_kind: _AttrKind, + persistent: bool = True, +): + *prefix, field = target.split(".") + # We need to generate all submodules of `to_module` that are at `prefix` and + # variants of `prefix` that differ only by call name. All of these submodules + # will then be assigned `from_obj` at `field` so that they can share this attribute. + # For example, if target is foo.bar.f, foo has another call name foo@1, + # and bar has other call names bar@1, bar@2, then we will assign f to + # foo.bar, foo.bar@1, foo.bar@2, foo@1.bar, foo@1.bar@1, foo@1.bar@2. + to_modules = {to_module} + for item in prefix: + ts: set[torch.nn.Module] = set() + for to_module in to_modules: + if not hasattr(to_module, item): + setattr(to_module, item, torch.nn.Module()) + ts.update( + t_call # type: ignore[misc] + for k, t_call in to_module._modules.items() + if _is_call_name(k, item) + ) + to_modules = ts + + for to_module in to_modules: + if attr_kind == _AttrKind.PARAMETER: + assert isinstance(from_obj, torch.nn.Parameter) + to_module.register_parameter(field, from_obj) + elif attr_kind == _AttrKind.BUFFER: + assert isinstance(from_obj, torch.Tensor) + to_module.register_buffer(field, from_obj, persistent=persistent) + elif attr_kind == _AttrKind.CONSTANT: + assert not isinstance(from_obj, FakeScriptObject), ( + "FakeScriptObject should only exist during tracing." + ) + assert isinstance( + from_obj, + ( + torch.Tensor, + torch.ScriptObject, + ), + ) + setattr(to_module, field, from_obj) + elif attr_kind == _AttrKind.MODULE: + assert isinstance(from_obj, torch.nn.Module) + setattr(to_module, field, from_obj) + + +class _SubmoduleBase: + _ty: Optional[str] + + def type_name(self) -> Optional[str]: + return self._ty + + +class InterpreterModule(_SubmoduleBase, torch.nn.Module): + """A module that uses torch.fx.Interpreter to execute instead of the usual + codegen that GraphModule uses. This provides better stack trace information + and makes it easier to debug execution. + """ + + graph_module: Optional[torch.fx.GraphModule] + + def __init__( + self, + graph: torch.fx.Graph, + ty: Optional[str] = None, + ): + super().__init__() + self.graph = graph + self._ty = ty + self.graph.owning_module = self # type: ignore[assignment] + self._run_with_interpreter = RUN_WITH_INTERPRETER + + def forward(self, *args, **kwargs): + assert self.graph_module is not None, "Didn't finalize this InterpreterModule" + if not is_fx_tracing() and ( + torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter + ): + # Dynamo cannot trace through torch.fx.Interpreter, so fall back to + # GraphModule codegen in this instance. + # Patch the codegened forward to run with this InterpreterModule, + # so attribute accesses, etc. are on this module instead. + return type(self.graph_module).forward(self, *args, **kwargs) + else: + if kwargs: + # Handle **kwargs. FX only natively supports positional + # arguments (through placeholders). So in order to pass in + # kwargs, we must correspond the names of the placeholders with + # the keys in the kwarg dict. + arg_list = list(args) + kwarg_names = self.arg_names[len(arg_list) :] + arg_list.extend( + kwargs[kwarg_name] + for kwarg_name in kwarg_names + if kwarg_name in kwargs + ) + + # Assert that the kwargs passed in exactly match the positional + # arguments specified by the GraphModule. This should be + # guaranteed by the unflattening process. + assert len(kwarg_names) == len(kwargs) + assert len(arg_list) == len(self.arg_names) + args = tuple(arg_list) + + return torch.fx.Interpreter(self, graph=self.graph).run( + *args, enable_io_processing=False + ) + + def finalize(self): + # We need to "finalize" because GraphModule populates its own state_dict + # based on the get_attrs observed in the graph. So we need to fully + # construct the graph and call _sink_params before generating this + # GraphModule. + + # need to set `graph_module` directly on the dict to avoid it getting + # registered as a submodule. + self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) + self.graph.lint() + + # Cache arg names for kwarg handling (see forward()) + self.arg_names = [] + for node in self.graph.nodes: + if node.op == "placeholder": + self.arg_names.append(node.target) + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "InterpreterModule", + print_output, + include_stride, + include_device, + colored, + ) + + +class InterpreterModuleDispatcher(_SubmoduleBase, torch.nn.Module): + """ + A module that carries a sequence of InterpreterModules corresponding to + a sequence of calls of that module. Each call to the module dispatches + to the next InterpreterModule, and wraps back around after the last. + """ + + def __init__(self, attrs: set[str], call_modules: list[InterpreterModule]): + super().__init__() + assert call_modules + self._modules = call_modules[0]._modules + for accessor in attrs: + setattr(self, accessor, getattr(call_modules[0], accessor)) + self._ty = call_modules[0]._ty + self._call_modules = call_modules + self._num_calls = 0 + + def forward(self, *args, **kwargs): + call_module = self._call_modules[self._num_calls] + self._num_calls = (self._num_calls + 1) % len(self._call_modules) + try: + return call_module(*args, **kwargs) + except Exception: + self._num_calls = 0 + raise + + def call_modules(self): + return self._call_modules + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + outputs = [ + mod.print_readable( + print_output, + include_stride, + include_device, + colored, + ) + for mod in self._call_modules + ] + return "\n".join(outputs) + + +class FlatArgsAdapter(abc.ABC): + """ + Adapts input arguments with ``input_spec`` to align ``target_spec``. + """ + + @abc.abstractmethod + def adapt( + self, + target_spec: pytree.TreeSpec, + input_spec: pytree.TreeSpec, + input_args: list[Any], + metadata: Optional[dict[str, Any]] = None, + obj: Optional[Any] = None, + ) -> list[Any]: + """NOTE: This adapter may mutate given ``input_args_with_path``.""" + ... + + +class UnflattenedModule(torch.nn.Module): + def __init__( + self, + export_module: ExportedProgram, + flat_args_adapter: Optional[FlatArgsAdapter] = None, + ): + super().__init__() + if export_module.graph_signature.backward_signature is not None: + raise ValueError("Unflattening on JointExportModule NYI") + + fqn_list = [entry.fqn for entry in export_module.module_call_graph] + assert fqn_list[0] == "" + export_graph = deepcopy(export_module.graph) + self.graph_signature = deepcopy(export_module.graph_signature) + self.graph = torch.fx.Graph() + self.graph.owning_module = self # type: ignore[assignment] + self.module_call_graph = deepcopy(export_module.module_call_graph) + self.flat_args_adapter = flat_args_adapter + + self.meta = export_module.graph_module.meta + self.meta["unflattened_module"] = self + + # Flag to indicate whether args have been adapted. + self.adapted = False + self._run_with_interpreter = RUN_WITH_INTERPRETER + + _inplace_buffer_and_input_mutations(export_graph, self.graph_signature) + + self.ivals = _IVals() + # for any intermediate value of a mutation that is read, track the mutation + seen_modules, seen_attrs = _outline_submodules(export_graph, self) + # for each read intermediate value of a mutation, find where it was created, + # and perform the mutation + self.ivals.update(seen_modules.values()) + # move attributes that correspond to graph arguments for HOPs + # from exported program to unflattened submodules + _copy_graph_attrs(export_module._graph_module, self, seen_attrs) + + self.range_constraints = export_module.range_constraints + self.equality_constraints: list = [] + + # aliasing/unused param or buffer issues: + # in strict-mode export, dynamo export will deduplicate aliased tensors, + # and ignore unused tensors. For aliasing, this causes issues when some aliases + # are unused, and we're unable to match the placeholder node to the correct FQN. + # This leads to the graph signature potentially having the wrong target FQN, + # and downstream issues where parameters are assigned to the wrong target attribute, + # mismatching the relevant placeholder node in the unflattened module. + # To resolve this we restore (_assign_attr) all aliased/unused tensors in + # the state_dict as module attributes, but only keep the used tensors in the + # graph's forward pass (_sink_params). + state_dict = export_module.state_dict + assigned_params: set[str] = set() # tracking unused params + id_to_param: dict[int, torch.nn.Parameter] = {} # handling weight-sharing + for name in self.graph_signature.parameters: # this loop adds used params + param = state_dict[name] + if id(param) not in id_to_param: + id_to_param[id(param)] = torch.nn.Parameter( + param.clone(), requires_grad=param.requires_grad + ) + + _assign_attr( + id_to_param[id(param)], + self, + name, + attr_kind=_AttrKind.PARAMETER, + ) + assigned_params.add(name) + + non_persistent_buffers = set(self.graph_signature.non_persistent_buffers) + assigned_buffers: set[str] = set() # tracking unused buffers + id_to_buffer: dict[int, tuple[torch.nn.Parameter, bool]] = {} + for name in self.graph_signature.buffers: # this loop adds used buffers + if name in non_persistent_buffers: + persistent = False + buffer = export_module.constants[name] + else: + persistent = True + buffer = state_dict[name] + + if id(buffer) not in id_to_buffer: + id_to_buffer[id(buffer)] = (buffer.clone(), persistent) + + _assign_attr( + id_to_buffer[id(buffer)][0], + self, + name, + attr_kind=_AttrKind.BUFFER, + persistent=persistent, + ) + assigned_buffers.add(name) + + # restore aliased/unused params and buffers + # these appear in state dict but not graph signature + for name, tensor in state_dict.items(): + if name in assigned_params or name in assigned_buffers: # already assigned + continue + + is_buffer = False + if id(tensor) in id_to_buffer or not isinstance( + tensor, torch.nn.Parameter + ): # aliased buffer + is_buffer = True + + if is_buffer: + if ( + id(tensor) not in id_to_buffer + ): # this is completely unused (not weight-sharing) + id_to_buffer[id(tensor)] = ( + tensor, + True, + ) # assign to respect original model + _assign_attr( + id_to_buffer[id(tensor)][0], + self, + name, + attr_kind=_AttrKind.BUFFER, + persistent=True, + ) + else: + if id(tensor) not in id_to_param: # this is unused + id_to_param[id(tensor)] = tensor + _assign_attr( + id_to_param[id(tensor)], + self, + name, + attr_kind=_AttrKind.PARAMETER, + ) + + # use id map so we don't double-clone aliased constants + id_to_const: dict[int, Union[torch.Tensor, torch._C.ScriptObject]] = {} + for fqn, constant in export_module.constants.items(): + if id(constant) not in id_to_const: + if isinstance(constant, torch.Tensor): + constant = constant.clone() + id_to_const[id(constant)] = constant + _constant = id_to_const[id(constant)] + _assign_attr( + _constant, + self, + fqn, + attr_kind=_AttrKind.CONSTANT, + ) + + # This is to handle parameters/buffers that point to the same tensor + # object id -> list of (node_name, target_name) + consts_map: dict[int, list[tuple[str, str]]] = defaultdict(list) + consts_targets: set[str] = set() + + def add_to_consts_map(obj_id, node_name, target_name): + name_list = consts_map[obj_id] + name_list.append((node_name, target_name)) + + added_params_buffers: set[str] = set() # track aliased/unused params, buffers + for s in self.graph_signature.input_specs: + if s.kind == InputKind.PARAMETER or ( + s.kind == InputKind.BUFFER and s.persistent + ): + assert hasattr(s.arg, "name") + assert isinstance(s.target, str) + add_to_consts_map( + id(export_module.state_dict[s.target]), s.arg.name, s.target + ) + consts_targets.add(s.target) + added_params_buffers.add(s.target) + elif ( + (s.kind == InputKind.BUFFER and not s.persistent) + or s.kind == InputKind.CONSTANT_TENSOR + or s.kind == InputKind.CUSTOM_OBJ + ): + assert hasattr(s.arg, "name") + assert isinstance(s.target, str) + add_to_consts_map( + id(export_module.constants[s.target]), s.arg.name, s.target + ) + consts_targets.add(s.target) + + # add constants that are aliased and don't appear in graph signature + for const_name, const in export_module.constants.items(): + if const_name not in consts_targets: + assert id(const) in consts_map, ( + "Constants should be either aliased or appear in graph signature" + ) + ph_name, _ = consts_map[id(const)][0] + add_to_consts_map(id(const), ph_name, const_name) + added_params_buffers.add(s.target) + + # add aliased/unused params and buffers that don't appear in graph signature + for fqn, tensor in export_module.state_dict.items(): + if fqn not in added_params_buffers: + if id(tensor) not in consts_map: + # completely unused (no weight-sharing), ignore. + # this weight doesn't appear in graph module, + # so won't cause FQN assignment issues + continue + ph_name, _ = consts_map[id(tensor)][0] + add_to_consts_map(id(tensor), ph_name, fqn) + + # node name -> list of possible targets + inputs_to_state: dict[str, list[str]] = {} + for node_target in consts_map.values(): + targets = [t[1] for t in node_target] + for n, _ in node_target: + inputs_to_state[n] = targets + + _sink_params(self, inputs_to_state, []) + + redirected_call_indices = _deduplicate_modules(seen_modules.values()) + fqn_list = [fqn for fqn in fqn_list if fqn not in redirected_call_indices] + + self._dispatch_modules(redirected_call_indices, consts_targets) + fqn_list = [fqn for fqn in fqn_list if "@" not in fqn] + + # Cache so we don't have to compute this every time. + # NOTE: this needs to be kept in sync with the placeholders in + # self.graph, but currently we have no way to guarantee that. + self.input_placeholders = [ + node for node in self.graph.nodes if node.op == "placeholder" + ] + self.check_input_constraints = True + # TODO(zhxchen17) We can register modules ahead of time instead of reorder later. + fqn_order = {fqn: i for i, fqn in enumerate(fqn_list)} + # In the case of legacy IR, we might be missing some modules from metadata. + for name, _ in self.named_modules(remove_duplicate=False): + if name not in fqn_order: + fqn_order[name] = len(fqn_order) + _reorder_submodules(self, fqn_order) + self.graph.lint() + self.finalize() + + def _print_graph(self): + for fqn, mod in self.named_modules(): + print(fqn + ":") + if hasattr(mod, "graph") and isinstance(mod.graph, torch.fx.Graph): + print(mod.graph) + + def _adapt_flat_args(self, flat_args, in_spec, input): + signature = self.module_call_graph[0].signature + if in_spec == signature.in_spec: + return flat_args + + if self.flat_args_adapter is None: + raise TypeError( + "There is no flat args adapter specified. " + "Are you sure you are calling this with the right arguments? " + ) + else: + flat_args = self.flat_args_adapter.adapt( + target_spec=signature.in_spec, + input_spec=in_spec, + input_args=flat_args, + metadata=self.meta, + obj=input, + ) + + if len(flat_args) != signature.in_spec.num_leaves: + raise TypeError( + f"Flat args adaption failed, number of args mismatch " + f"Adatped: {len(flat_args)} \n" + f"Exported module: {signature.in_spec.num_leaves}" + ) + return flat_args + + def process_forward_inputs(self, *args, **kwargs): + signature = self.module_call_graph[0].signature + + reordered_kwargs = kwargs + if kwargs: + reordered_kwargs = reorder_kwargs(kwargs, signature.in_spec) + + flat_args_with_path, in_spec = pytree.tree_flatten_with_path( + (args, reordered_kwargs) + ) + flat_args = [x[1] for x in flat_args_with_path] + + if is_fx_tracing(): + return flat_args + + if in_spec != signature.in_spec: + if not self.adapted: + print( + "Input treespec does not match with exported module's: \n" + f"Input treespec: {in_spec}. ", + f"Exported module treespec: {signature.in_spec}", + ) + print("Adapting flat arg to match exported module's treespec") + flat_args = self._adapt_flat_args(flat_args, in_spec, args) + self.adapted = True + + if self.check_input_constraints: + # Import here to avoid an unfortunate circular dependency. + # TODO(suo): untangle this. + from torch._export.utils import _check_input_constraints_for_graph + + if self.adapted is True: + # TODO(suo): The FlatArgsAdapter returns a list of flat args, + # which we don't have keypaths for. For now, just create a dummy + # keypath to associate with the arg. + new_flat_args_with_path = [ # type: ignore[var-annotated] + ((SequenceKey(idx=0), GetAttrKey(name="")), arg) + for arg in flat_args + ] + else: + new_flat_args_with_path = flat_args_with_path # type: ignore[assignment] + + _check_input_constraints_for_graph( + self.input_placeholders, new_flat_args_with_path, self.range_constraints + ) + + return flat_args + + def forward(self, *args, **kwargs): + flat_args = torch._dynamo.disable( + self.process_forward_inputs, + reason="do not trace into preprocessing the inputs", + )(*args, **kwargs) + signature = self.module_call_graph[0].signature + + if is_fx_tracing(): + return_val = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + # For scalar return value, fx.Graph wraps in a tuple + if isinstance(return_val, tuple) and len(return_val) == 1: + return return_val[0] + return return_val + + if torch.compiler.is_dynamo_compiling() or not self._run_with_interpreter: + tree_out = type(self.graph_module).forward(self, *flat_args) # type: ignore[union-attr] + else: + tree_out = torch.fx.Interpreter(self, graph=self.graph).run( + *flat_args, enable_io_processing=False + ) + return pytree.tree_unflatten(tree_out, signature.out_spec) + + def finalize(self): + self.__dict__["graph_module"] = torch.fx.GraphModule(self, self.graph) + self.graph.lint() + + def _dispatch_modules(self, redirected_call_indices, consts_targets): + """For a module whose call signatures are preserved, replace + multiple modules corresponding to multiple calls to that module + with a single dispatcher module that tracks which module to call. + """ + + # for each fqn whose module call signature is preserved, + # map that fqn to a list of called modules + called_modules = defaultdict(list) + for entry in self.module_call_graph: + if entry.fqn and entry.signature: + # some modules were removed and their fqns redirected to other + # fqns during deduplication + fqn = entry.fqn + mod = _get_attr(self, redirected_call_indices.get(fqn, fqn)) + base, idx = fqn.split("@") if "@" in fqn else [fqn, "0"] + called_modules[base].append((int(idx), mod)) + + attrs_map = defaultdict(set) + for target in consts_targets: + if "." in target: + orig_fqn, name = target.rsplit(".", 1) + attrs_map[orig_fqn].add(name) + else: + attrs_map[""].add(target) + + # replace multiple call modules with a single dispatcher module + for orig_fqn, indexed_call_modules in called_modules.items(): + call_modules = [mod for _, mod in sorted(indexed_call_modules)] + if len(call_modules) > 1: + for i in range(len(call_modules)): + fqn = _call_name(orig_fqn, i + 1) + if fqn not in redirected_call_indices: + *prefix, name = fqn.split(".") + _get_attr_via_attr_list(self, prefix)._modules.pop(name) + self.set_submodule( + orig_fqn, + InterpreterModuleDispatcher(attrs_map[orig_fqn], call_modules), + ) + + # elide call indices in call modules because they are + # tracked automatically inside the dispatcher module + def elide_call_indices(prefix, graph): + for node in graph.nodes: + if node.op == "call_module": + fqn = node.target.split("@")[0] + path = f"{prefix}.{fqn}" if prefix else fqn + if path in called_modules: + node.target = fqn + + for fqn, mod in self.named_modules(remove_duplicate=False): + if hasattr(mod, "graph"): + elide_call_indices(fqn, mod.graph) + elif hasattr(mod, "_call_modules"): + for mod_ in mod._call_modules: + assert hasattr(mod_, "graph") + elide_call_indices(fqn, mod_.graph) + + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + ): + return _print_readable( + self, + "UnflattenedModule", + print_output, + include_stride, + include_device, + colored, + ) + + +def unflatten( + module: ExportedProgram, flat_args_adapter: Optional[FlatArgsAdapter] = None +) -> UnflattenedModule: + """Unflatten an ExportedProgram, producing a module with the same module + hierarchy as the original eager module. This can be useful if you are trying + to use :mod:`torch.export` with another system that expects a module + hierachy instead of the flat graph that :mod:`torch.export` usually produces. + + .. note:: The args/kwargs of unflattened modules will not necessarily match + the eager module, so doing a module swap (e.g. :code:`self.submod = + new_mod`) will not necessarily work. If you need to swap a module out, you + need to set the :code:`preserve_module_call_signature` parameter of + :func:`torch.export.export`. + + Args: + module (ExportedProgram): The ExportedProgram to unflatten. + flat_args_adapter (Optional[FlatArgsAdapter]): Adapt flat args if input TreeSpec does not match with exported module's. + + Returns: + An instance of :class:`UnflattenedModule`, which has the same module + hierarchy as the original eager module pre-export. + """ + module = _remove_effect_tokens(module) + return UnflattenedModule(module, flat_args_adapter) + + +def _inplace_buffer_and_input_mutations( + graph: torch.fx.Graph, + graph_signature: ExportGraphSignature, +) -> None: + """Transform buffer and input mutations from their functionalized form + into copy_ nodes in the graph. + + Functionalization represents a buffer mutation by passing the buffer as + an input and output. For example, consider the eager code: + def forward(self, x): + self.buffer += x + return x * x + + This corresponds to a graph that looks like: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + mul = aten.mul(x, x) + return (mutated_buffer, mul) + + We want to inplace this into something that looks like the original + eager code: + def forward(self, buffer, x): + mutated_buffer = aten.add(buffer, x) + buffer.copy_(mutated_buffer) + mul = aten.mul(x, x) + return (mul,) + + Input mutations are handled similarly. + """ + output_node = next(iter(reversed(graph.nodes))) + assert output_node.op == "output" and len(output_node.args) == 1 + return_args = output_node.args[0] + + input_name_to_node = { + node.name: node for node in graph.nodes if node.op == "placeholder" + } + mutation_name_to_input_name = {} + + # Collect mutated buffers. + buffer_fqn_to_input_name = { + buffer_fqn: k for k, buffer_fqn in graph_signature.inputs_to_buffers.items() + } + mutation_name_to_input_name = { + k: buffer_fqn_to_input_name[buffer_fqn] + for k, buffer_fqn in graph_signature.buffers_to_mutate.items() + } + # Collect mutated user inputs. + mutation_name_to_input_name.update(graph_signature.user_inputs_to_mutate) + + num_mutations = len(mutation_name_to_input_name) + + for mutation in return_args[:num_mutations]: + input_name = mutation_name_to_input_name[mutation.name] + input_node = input_name_to_node[input_name] + + with graph.inserting_after(mutation): + # Create a copy_ node that inplaces the mutation. + new_node = graph.create_node( + "call_function", torch.ops.aten.copy_.default, (input_node, mutation) + ) + for k, v in mutation.meta.items(): + new_node.meta[k] = v + # Replace all uses of the previously functional mutation with + # our copy_ node. + mutation.replace_all_uses_with(new_node, lambda x: x is not new_node) + + # Remove the mutated buffer / input from the graph outputs, since we don't + # need to thread it through anymore. + user_outputs = tuple(return_args[num_mutations:]) + output_node.args = ((user_outputs),) + + +def _is_prefix(candidate, target): + """Check whether `candidate` is a prefix of `target`.""" + return len(candidate) < len(target) and target[: len(candidate)] == candidate + + +def _compute_accessor(parent_fqn: str, child_fqn: str) -> str: + if parent_fqn == "": + # Handle the root module correctly. + return child_fqn + + parent_split = parent_fqn.split(".") + child_split = child_fqn.split(".") + + # TODO: support skip connection by inlining the child module. + if child_split[: len(parent_split)] != parent_split: + raise RuntimeError( + f"Child module '{child_fqn}' is not a descendant of parent module '{parent_fqn}'." + "This is currently unsupported." + "Please try to make child module attach to parent module directly." + ) + return ".".join(child_split[len(parent_split) :]) + + +def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): + def graph_dump(graph: torch.fx.Graph) -> str: + ret = [] + nodes_idx: dict[int, int] = {} + + def arg_dump(arg) -> str: + if isinstance(arg, torch.fx.Node): + return "%" + str(nodes_idx[id(arg)]) + return str(arg) + + for i, node in enumerate(graph.nodes): + args_dump = [str(arg) for arg in pytree.tree_map(arg_dump, node.args)] + args_dump += [ + f"{key}={value}" + for key, value in pytree.tree_map(arg_dump, node.kwargs).items() + ] + target = node.target if node.op in ("call_function", "get_attr") else "" + ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") + nodes_idx[id(node)] = i + return "\n".join(ret) + + assert isinstance(x.graph, torch.fx.Graph) + assert isinstance(y.graph, torch.fx.Graph) + return graph_dump(x.graph) == graph_dump(y.graph) + + +def _add_spec(gm: torch.nn.Module, spec) -> str: + i = 0 + while hasattr(gm, f"_spec_{i}"): + i += 1 + name = f"_spec_{i}" + setattr(gm, name, spec) + return name + + +def _generate_flatten(gm: torch.fx.GraphModule, node) -> torch.fx.Node: + flatten = gm.graph.call_function(pytree.tree_flatten, (node,)) + getitem_0 = gm.graph.call_function(operator.getitem, (flatten, 0)) + return getitem_0 + + +def _generate_flatten_spec( + gm: Union[torch.fx.GraphModule, InterpreterModule, UnflattenedModule], node, spec +) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = gm.graph.get_attr(name) + return gm.graph.call_function(fx_pytree.tree_flatten_spec, (node, spec_node)) + + +def _generate_unflatten( + gm: Union[torch.fx.GraphModule, InterpreterModule, UnflattenedModule], nodes, spec +) -> torch.fx.Node: + name = _add_spec(gm, spec) + spec_node = gm.graph.get_attr(name) + return gm.graph.call_function(pytree.tree_unflatten, (nodes, spec_node)) + + +def _get_submodule(mod: torch.nn.Module, target: str): + *prefix, field = target.split(".") + + for item in prefix: + submod = getattr(mod, item, None) + + if submod is None: + return None + + if not isinstance(submod, torch.nn.Module): + return None + + mod = submod + + return getattr(mod, field, None) + + +def _add_submodule( + mod: torch.nn.Module, + target: str, + module_to_add: torch.nn.Module, + create_module: Optional[Callable[[str], torch.nn.Module]] = None, +): + *prefix, field = target.split(".") + + for i, item in enumerate(prefix): + submod = getattr(mod, item, None) + + if submod is None: + if create_module is not None: + submod = create_module(".".join(prefix[: i + 1])) + else: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, module_to_add) + + +def _call_name(base: str, n: int) -> str: + # Given n >= 0, generate call names to a submodule `base` of the form + # `base`, `base@1`, `base@2`, etc. + return base if n == 1 else f"{base}@{n - 1}" + + +def _is_call_name(call_name: str, base: str) -> bool: + # Recognize when call_name = _call_name(base, n) for some n >= 0. + return re.match(re.escape(base) + r"(@\d+)?$", call_name) is not None + + +class _ModuleFrame: + def __init__( + self, + flat_graph: torch.fx.Graph, + nodes: tuple[torch.fx.Node, ...], + seen_nodes, + seen_modules, + seen_attrs, + created_modules, + parent, + module_stack: list[tuple[str, Optional[str], int]], + module_id, + module_call_graph: dict[str, ModuleCallSignature], + module: Optional[Union[torch.fx.GraphModule, UnflattenedModule]] = None, + ): + self.flat_graph = flat_graph + self.nodes = nodes + self.seen_nodes = seen_nodes + self.seen_modules = seen_modules + self.seen_attrs = seen_attrs + self.created_modules = created_modules + self.parent = parent + self.module_stack = module_stack + self.module_id = module_id + + self.module_call_graph = module_call_graph + self.verbose = False + + self.fqn, ty, num_calls = self.module_stack[-1] + # generate call name for self.fqn + self.child_fqn = _call_name(self.fqn, num_calls + 1) + + self.module: Union[torch.fx.GraphModule, UnflattenedModule, InterpreterModule] + if module is not None: + self.module = module + self.ivals = module.ivals if hasattr(module, "ivals") else {} # type: ignore[var-annotated] + else: + self.module = self.created_modules.get( + self.fqn, + InterpreterModule(torch.fx.Graph(), ty=ty), + ) + self.ivals = parent.ivals + + self.graph = self.module.graph + + # Mapping of nodes in the flat graph to nodes in this graph. + self.node_map: dict[torch.fx.Node, torch.fx.Node] = {} + self.node_to_placeholder = {} + + self.parent_call_module: Optional[torch.fx.Node] = None + if parent is not None: + accessor = _compute_accessor(parent.fqn, self.child_fqn) + + def create_module(fqn): + path = f"{parent.fqn}.{fqn}" if parent.fqn else fqn + if path in self.created_modules: + return self.created_modules[path] + submod = InterpreterModule(torch.fx.Graph(), ty=ty) + self.created_modules[path] = submod + return submod + + _add_submodule(parent.module, accessor, self.module, create_module) + self.parent_call_module = parent.graph.call_module(accessor) + if self.seen_modules[self.module_id]: + base_module_frame = self.seen_modules[self.module_id][0] + self.module._modules = base_module_frame.module._modules + self.seen_modules[self.module_id].append( + _SubmoduleEntry( + parent_fqn=self.parent.fqn, + parent_module=self.parent.module, + parent_call_module=self.parent_call_module, + fqn=self.fqn, + call_idx=num_calls + 1, + module=self.module, + ) + ) + + signature = module_call_graph.get(self.child_fqn) + if signature is not None and self.parent is not None: + assert signature.in_spec.num_children == 2 + args_spec = signature.in_spec.children_specs[0] + kwargs_spec = signature.in_spec.children_specs[1] + assert args_spec.context is None + assert kwargs_spec.context is not None + + with self.graph.inserting_after(None): + arg_nodes = [ + self.graph.placeholder(f"_positional_arg_{idx}") + for idx in range(args_spec.num_children) + ] + kwarg_nodes = {} + for name in kwargs_spec.context: + kwarg_nodes[name] = self.graph.placeholder(name) + flat_args = _generate_flatten_spec( + self.module, + (tuple(arg_nodes), kwarg_nodes), + signature.in_spec, + ) + for idx, arg in enumerate(signature.inputs): + flat_arg_node = self.graph.create_node( + op="call_function", + target=operator.getitem, + args=(flat_args, idx), + name=( + arg.name + if not isinstance(arg, ConstantArgument) + else f"_constant_{idx}" + ), + ) + if isinstance(arg, ConstantArgument): + continue + + if arg.name in self.seen_nodes: + flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) + self.node_to_placeholder[self.seen_nodes[arg.name]] = ( + flat_arg_node + ) + + with self.parent.graph.inserting_before(self.parent_call_module): + input_nodes: list[Optional[torch.fx.Node]] = [] + for input in signature.inputs: + if isinstance(input, ConstantArgument): + input_nodes.append(input.value) # type: ignore[arg-type] + elif input.name not in self.seen_nodes: + input_nodes.append(None) + else: + assert isinstance( + input, + ( + TensorArgument, + SymIntArgument, + SymBoolArgument, + SymFloatArgument, + ), + ) + input_nodes.append( + self.parent.remap_input(self.seen_nodes[input.name]) + ) + + inputs_node = _generate_unflatten( + self.parent.module, + input_nodes, + signature.in_spec, + ) + + args_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 0) + ) + kwargs_node = self.parent.graph.call_function( + operator.getitem, (inputs_node, 1) + ) + arg_nodes = [ + self.parent.graph.call_function(operator.getitem, (args_node, i)) + for i in range(args_spec.num_children) + ] + kwarg_nodes = { + k: self.parent.graph.call_function( + operator.getitem, (kwargs_node, k) + ) + for k in kwargs_spec.context + } + assert self.parent_call_module is not None + self.parent_call_module.args = tuple(arg_nodes) + self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment] + + def add_placeholder(self, x): + assert self.fqn != "", f"Cannot add placeholder {x} to root module" + assert x.graph is self.flat_graph + # x is not in subgraph, create a new placeholder for subgraph + with self.graph.inserting_before(None): + placeholder_node = self.graph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelevant for + # the placeholder node + placeholder_node.meta = copy.copy(x.meta) + self.node_to_placeholder[x] = placeholder_node + + def copy_sym_call_function(self, x): + # This only exists because we deduplicate sym_size nodes in the flat export graph, + # and if preserve_module_call_signature is set, we may not be able to pass sym_size + # nodes, or their downstream users, as inputs to submodule calls. + # To avoid this we copy these call_function nodes with sym_type results. + # This should however only be done for sym_type nodes - call_function nodes on tensors + # should not be deduplicated in the first place. + args = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.args) + kwargs = pytree.tree_map_only(torch.fx.Node, self.remap_input, x.kwargs) + node = self.graph.call_function(x.target, args, kwargs) + node.meta = copy.copy(x.meta) + self.node_map[x] = node + return node + + def remap_input(self, x): + assert x.graph is self.flat_graph + if x in self.node_map: + return self.node_map[x] + self.print(f"remap_input({x})") + if x in self.node_to_placeholder: + return self.node_to_placeholder[x] + elif ( + x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None + # allow placeholder creation if we are not preserving module call signature + ): + self.add_placeholder(x) + if self.parent_call_module is not None: + # Important to *prepend* the output to match how we are + # inserting placeholder nodes. + with self.parent.graph.inserting_before(self.parent_call_module): + self.parent_call_module.insert_arg(0, self.parent.remap_input(x)) + return self.node_to_placeholder[x] + elif x.op == "call_function" and ( + x.target + in ( + torch.ops.aten.sym_size.int, + torch.ops.aten.item.default, + torch.ops.aten.unbind.int, + torch.ops.aten.sum.dim_IntList, + torch.ops.aten.view.default, + torch.ops.aten.diff.default, + ) + or (hasattr(x.target, "__module__") and x.target.__module__ == "_operator") + ): + # export deduplicates sym_size nodes, and may need to re-copy them + # if module call signature needs to be preserved + self.copy_sym_call_function(x) + return self.node_map[x] + elif self.module_call_graph.get(self.fqn) is not None: + # x is reading the intermediate value of a mutation, so record it; + # later we will find where it was created and perform the update + return self.ivals.read(self, x) # type: ignore[operator, union-attr] + else: + raise RuntimeError( + f"Could not run remap_input() on op type: {x.op} for node {x}" + ) + + def finalize_outputs(self): + self.created_modules.pop(self.fqn, None) + + orig_outputs = [] + + signature = self.module_call_graph.get(self.child_fqn) + if signature is not None and self.parent is not None: + for output in signature.outputs: + if isinstance( + output, + ( + TensorArgument, + SymIntArgument, + SymBoolArgument, + SymFloatArgument, + ConstantArgument, + ), + ): + if output.name in self.seen_nodes: + orig_outputs.append(self.seen_nodes[output.name]) + else: + orig_outputs.append(None) + else: + raise RuntimeError( + f"Unsupported data type for output node: {output}" + ) + + def get_actual_output_node(output): + if output is None: + return None + + seen_node = self.seen_nodes[output.name] + if seen_node in self.node_map: + return self.node_map[seen_node] + elif seen_node in self.node_to_placeholder: + return self.node_to_placeholder[seen_node] + else: + raise RuntimeError( + f"Could not find output node {output}. Graph: {self.graph}" + ) + + tree_out_node = _generate_unflatten( + self.module, + tuple(get_actual_output_node(output) for output in orig_outputs), + signature.out_spec, + ) + parent_out: Optional[torch.fx.Node] = _generate_flatten_spec( + self.parent.module, self.parent_call_module, signature.out_spec + ) + graph_outputs: Union[torch.fx.Node, list[torch.fx.Node]] = tree_out_node + else: + graph_outputs = [] + # Iterate through nodes we have copied into self.graph. + for orig_node in self.node_map.keys(): + for user_node in orig_node.users: + if user_node.name not in self.seen_nodes: + # external user node, need to expose as an output + orig_outputs.append(orig_node) + graph_outputs.append(self.node_map[orig_node]) + break + + parent_out = self.parent_call_module + if len(graph_outputs) == 1: + graph_outputs = graph_outputs[0] + + assert isinstance(graph_outputs, (list, torch.fx.Node)) + + self.graph.output(graph_outputs) + + # Rewrite outputs in parent module + if parent_out is None: + return + + parent_out.meta["val"] = ( + graph_outputs.meta.get("val") + if isinstance(graph_outputs, torch.fx.Node) + else [o.meta.get("val") for o in graph_outputs] + ) + + if len(orig_outputs) == 1 and signature is None: + self.parent.node_map[orig_outputs[0]] = parent_out + else: + for i, orig_output in enumerate(orig_outputs): + if orig_output is None: + continue + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(parent_out)[i].node # type: ignore[index] + proxy_out.meta["val"] = orig_output.meta.get("val") + self.parent.node_map[orig_output] = proxy_out + + def copy_node(self, node): + self.print("copying", node.format_node()) + self.node_map[node] = self.graph.node_copy(node, self.remap_input) + self.seen_nodes[node.name] = node + + def run_outer(self): + for i, node in enumerate(self.flat_graph.nodes): + self.print(i, node.meta.get("nn_module_stack"), node.format_node()) + + # Copy all graph inputs + node_idx: int = 0 + node = self.nodes[node_idx] + while node.op == "placeholder": + self.copy_node(node) + node_idx += 1 + node = self.nodes[node_idx] + + self.run_from(node_idx) + + # Copy graph outputs + for node in self.flat_graph.nodes: + if node.op == "output": + self.copy_node(node) + + def print(self, *args, **kwargs): + if self.verbose: + print(*args, **kwargs) + + def run_from(self, node_idx): + module_idx = 0 + # Walk through the graph, building up a new graph with the right submodules + while node_idx < len(self.nodes): + node = self.nodes[node_idx] + assert node.op != "placeholder" + + self.print() + self.print("STEP", node_idx, node.format_node()) + self.print(self.module_stack) + depth = len(self.module_stack) + if node.op == "output": + if depth == 1: + # We want the output node of the original graph to be handled + # specially by the outermost stack frame (in run_outer). So + # skip finalization here. + return node_idx + + # We've reached the end of the graph. Wrap up all the existing stack frames. + self.finalize_outputs() + return node_idx + + if len(node.meta.get("nn_module_stack", {})) == 0: + raise RuntimeError(f"Unable to find nn_module_stack for node {node}") + + nn_module_stack = node.meta["nn_module_stack"] + from torch._export.passes._node_metadata_hook import ( + _EMPTY_NN_MODULE_STACK_KEY, + ) + + if ( + len(nn_module_stack) == 1 + and _EMPTY_NN_MODULE_STACK_KEY in nn_module_stack + ): + # Empty case from the node_metadata_hook + node_module_stack = self.module_stack + else: + node_module_stack = [ + ( + path, + ty if path else None, + int(k.split("@")[-1]) if "@" in k else 0, + ) + for k, (path, ty) in node.meta["nn_module_stack"].items() + ] + + if node_module_stack[:depth] != self.module_stack: + # This means that the current module is done executing and the + # current node is the beginning of a new module. + # + # In this case, we should finalize this module and return without + # incrementing the node counter. + self.finalize_outputs() + self.print("outlining", self.fqn) + self.print(self.graph) + return node_idx + + assert node_module_stack is not None + + if _is_prefix(self.module_stack, node_module_stack): + # This means that the current node represents the execution of a new + # module. + next_module = node_module_stack[depth] + self.print("Creating new stack frame for", next_module) + # Run a nested version of module outliner from the current node + # counter. Once it is complete, continue from that point. + next_module_key = list(node.meta["nn_module_stack"].keys())[depth] + node_idx = _ModuleFrame( + self.flat_graph, + self.nodes, + self.seen_nodes, + self.seen_modules, + self.seen_attrs, + self.created_modules, + self, + self.module_stack + [next_module], + next_module_key.split("@")[0], + self.module_call_graph, + ).run_from(node_idx) + module_idx += 1 + continue + + # The only remaining possibility is that we are in the right stack + # frame. Copy the node into this frame's graph and increment the node counter. + assert node_module_stack == self.module_stack + + if node.op == "get_attr": + # this must be a graph argument for a HOP + self.seen_attrs[self.child_fqn].add(node.target) + + self.copy_node(node) + node_idx += 1 + + +@dataclass +class _SubmoduleEntry: + parent_fqn: str + parent_module: torch.nn.Module + parent_call_module: torch.fx.Node + fqn: str + call_idx: int + module: torch.nn.Module + + +def _outline_submodules(orig_graph: torch.fx.Graph, root_module: UnflattenedModule): + seen_nodes: dict[str, torch.fx.Node] = {} + seen_modules: dict[int, list[_SubmoduleEntry]] = defaultdict(list) + seen_attrs: dict[str, set[str]] = defaultdict(set) + created_modules: dict[str, torch.nn.Module] = {} + _ModuleFrame( + orig_graph, + tuple(orig_graph.nodes), + seen_nodes, + seen_modules, + seen_attrs, + created_modules, + None, + [("", None, 0)], + "", + { + entry.fqn: entry.signature + for entry in root_module.module_call_graph + if entry.signature + }, + module=root_module, + ).run_outer() + return seen_modules, seen_attrs + + +def _reorder_submodules( + parent: torch.nn.Module, fqn_order: dict[str, int], prefix: str = "" +): + # TODO Can be optimized by adding submodules ahead of time. + if prefix == "": + for fqn in list(fqn_order.keys())[1:]: + if _get_submodule(parent, fqn) is None: + _add_submodule(parent, fqn, torch.nn.Module()) + + children = [] + for name, child in list(parent._modules.items()): + if child is None: + continue + fqn = prefix + name + _reorder_submodules(child, fqn_order, prefix=fqn.split("@")[0] + ".") + delattr(parent, name) + children.append((fqn_order[fqn], name, child)) + children.sort(key=operator.itemgetter(0)) + for _, name, child in children: + parent.register_module(name, child) + + +class _IVals: + """ + Collect the intermediate values of mutations in a graph. + + Example: in the following graph, suppose that buf_in and buf_out + are the input and output values of a buffer. + + buf_in = placeholder() + ... + ival1 = f0(buf_in, ...) # inside self.n0(...) + ... + ival2 = f1(ival1, ...) # inside self.n1(...) + ... + buf_out = f2(ival2, ...) # inside self.n2(...) + return buf_out, ... + + Here ival1 and ival2 are intermediate values created inside + calls to n0 and n1 respectively, and used inside calls to + n1 and n2 respectively. + """ + + def __init__(self): + # for each fqn, set of node names corresponding to intermediate values + self.node_names_by_fqn = defaultdict(set) + + def _is_mutable(self, target): + if isinstance(target, torch._ops.OpOverload): + return target._schema.is_mutable + return False + + def read(self, mf, node): + """ + Read state corresponding to a given intermediate value. + """ + # we can assume that the node must be from a mutation + assert node.op == "call_function" + b = self._is_mutable(node.target) + print("Checking mutability", node.target, b) + if not b: + # so the mutation was functionalized; + # we will apply the original mutation later (see below) + fqn, _ = next(reversed(node.meta["nn_module_stack"].values())) + self.node_names_by_fqn[fqn].add(node.name) + return mf.remap_input(node.args[0]) + + def update(self, partitions): + """ + Update states corresponding to intermediate values that were read. + """ + for shared_submodules in partitions: + for entry in shared_submodules: + graph = entry.module.graph + node_names = self.node_names_by_fqn[entry.fqn] + nodes = [n for n in graph.nodes if n.name in node_names] + for node in nodes: + # so node must be from a functionalized mutation; + # we perform the original mutation now + with graph.inserting_after(node): + new_node = graph.create_node( + "call_function", + torch.ops.aten.copy_.default, + (node.args[0], node), + ) + new_node.meta = copy.copy(node.meta) + + +def _copy_graph_attrs( + gm: torch.fx.GraphModule, + root_module: UnflattenedModule, + seen_attrs: dict[str, set[str]], +): + for child_fqn, names in seen_attrs.items(): + module = _get_attr(root_module, child_fqn) if child_fqn else root_module + for name in names: + val = getattr(gm, name) + setattr(module, name, val) + + +def _deduplicate_modules(partitions): + redirected_call_indices = {} + for shared_submodules in partitions: + for i, entry in enumerate(shared_submodules): + child_fqn = _call_name(entry.fqn, entry.call_idx) + target = _compute_accessor(entry.parent_fqn, child_fqn) + deduplicated = False + # Iterate over all previously seen modules, and deduplicate if possible + for seen in shared_submodules[:i]: + if _check_graph_equivalence(seen.module, entry.module): + parent = entry.parent_module + # Since graphs are equivalent, we can deduplicate. + # There are two cases. + if seen.fqn == entry.fqn: + # Case 1: The current module has the same fqn as the seen module. + # In this case we have generated a call name that can be optimized away. + # So we remove the current module from the hierarchy and replace + # the current call name with the seen call name in the parent graph. + *prefix, name = target.split(".") + _get_attr_via_attr_list(parent, prefix)._modules.pop(name) + seen_child_fqn = _call_name(seen.fqn, seen.call_idx) + seen_target = _compute_accessor( + entry.parent_fqn, seen_child_fqn + ) + entry.parent_call_module.target = seen_target + redirected_call_indices[child_fqn] = seen_child_fqn + break + elif not deduplicated: + # Case 2: The current module has a different fqn than the seen module. + # In this case we replace the current module with the seen module. + # There should be nothing pointing to the current module any more, + # so it can be garbage collected. + # NOTE: We *do not* replace the current call name with the seen call name + # in the parent graph, because this will lose information on which fqn + # was actually called. However, it is possible that the current call name + # will be optimized away when we find another seen module with the same fqn, + # so we do not break out of the loop yet. + parent.set_submodule(target, seen.module) + deduplicated = True + + return redirected_call_indices + + +def _sink_params( + module: torch.nn.Module, + inputs_to_state: dict[str, list[str]], + scope: list[str], + module_id_to_inputs_removed: Optional[dict[int, set[str]]] = None, +): + """Sink params, buffers, and constants from graph inputs into get_attr nodes. + + Exported modules are purely functional, so they pass their parameters and + buffers in as inputs to the graph. + + To replicate eager's semantics, we need to get them from the module state + via get_attr instead. + + module: GraphModule, potentially containing nested submodules. + inputs_to_state: mapping graph input names to the corresponding key in the state_dict. + scope: tracks where we are in the module hierarchy, so that we can emit the + right `getattr(self, "foo.bar")` calls, etc. + module_id_to_inputs_removed: records inputs removed by child modules, mapping + the module object id to the list of placeholder node names in the child module + that were removed. + """ + if module_id_to_inputs_removed is None: + module_id_to_inputs_removed = defaultdict(set) + + if id(module) in module_id_to_inputs_removed: + return {id(module): module_id_to_inputs_removed[id(module)]} + + # We need to use _modules here instead of named_children(), because we + # explicitly want duplicate modules to show up in the traversal. + for name, submodule in module._modules.items(): + submod_id_to_inputs_removed = _sink_params( + cast("torch.nn.Module", submodule), + inputs_to_state, + scope + [name], + module_id_to_inputs_removed, + ) + for k, v in submod_id_to_inputs_removed.items(): + module_id_to_inputs_removed[k].update(v) + + graph = getattr(module, "graph", None) + if graph is None or len(graph.nodes) == 0: + # Not all modules have graphs defined, if they are empty modules with no operations (like ParameterList) + return module_id_to_inputs_removed + + assert isinstance(graph, torch.fx.Graph) + + inputs = list(filter(lambda n: n.op == "placeholder", graph.nodes)) + the_last_input = None if len(inputs) == 0 else inputs[-1] + + # Also remove from call_module nodes + call_module_nodes = filter(lambda n: n.op == "call_module", graph.nodes) + for node in call_module_nodes: + submodule = _get_attr(module, node.target) + # remove placeholder from call_module node arguments, only if we've + # erased the placeholder node in the corresponding _sink_params() call + if submodule is not None and id(submodule) in module_id_to_inputs_removed: + node.args = tuple( + filter( + lambda n: n.name not in module_id_to_inputs_removed[id(submodule)], + node.args, + ) + ) + + # Filter out inputs_to_state corresponding to current scope. + inputs_to_state_of_scope: dict[torch.fx.Node, list[str]] = {} + for node in inputs: + if node.name not in inputs_to_state: + continue + + state_name = None + for sn in inputs_to_state[node.name]: + sn_split = sn.split(".") + if sn_split[: len(scope)] == [x.split("@")[0] for x in scope]: + state_name = sn_split + break + + # If there's a mismatch between scope name and state name, then + # there must be multiple scopes pointing to the same state name, + # meaning some modules are shared. In such case, we can simply skip + # updating the current node because another later iteration will + # take care of this input node when the unique match between scope + # and state name occurs. To make sure this always happen, we should + # enforce the invariant that no placeholder node in the unflattened + # graph appears in inputs_to_state dict, which means all the extra + # input nodes have been handled. + if state_name is None: + continue + + inputs_to_state_of_scope[node] = state_name + + # Record name of remove inputs for return purpose. + inputs_removed: set[str] = set() + + for node, state_name in inputs_to_state_of_scope.items(): + if len(node.users) > 0: + attr_path = state_name[len(scope) :] + state_attr = _get_attr_via_attr_list(module, attr_path) + assert isinstance(state_attr, (torch.Tensor, torch.ScriptObject)) + + # Make sure the newly created get_attr node is placed after the last placeholder node + with graph.inserting_after(the_last_input): + new_node = graph.create_node("get_attr", ".".join(attr_path)) + + node.replace_all_uses_with(new_node, propagate_meta=True) + + graph.erase_node(node) + inputs_removed.add(node.name) + + if isinstance(module, InterpreterModule): + module.finalize() + + return {id(module): inputs_removed} diff --git a/phivenv/Lib/site-packages/torch/fft/__init__.py b/phivenv/Lib/site-packages/torch/fft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5c19325cb03fa6d67e51f68f6499239c66ac024a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fft/__init__.py @@ -0,0 +1,1442 @@ +import torch +from torch._C import _add_docstr, _fft # type: ignore[attr-defined] +from torch._torch_docs import common_args, factory_common_args + + +__all__ = [ + "fft", + "ifft", + "fft2", + "ifft2", + "fftn", + "ifftn", + "rfft", + "irfft", + "rfft2", + "irfft2", + "rfftn", + "irfftn", + "hfft", + "ihfft", + "fftfreq", + "rfftfreq", + "fftshift", + "ifftshift", + "Tensor", +] + +Tensor = torch.Tensor + +# Note: This not only adds the doc strings for the spectral ops, but +# connects the torch.fft Python namespace to the torch._C._fft builtins. + +fft = _add_docstr( + _fft.fft_fft, + r""" +fft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional discrete Fourier transform of :attr:`input`. + +Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: `X[i] = conj(X[-i])`. This function always returns both + the positive and negative frequency terms even though, for real inputs, the + negative frequencies are redundant. :func:`~torch.fft.rfft` returns the + more compact one-sided representation where only the positive frequencies + are returned. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the FFT. + dim (int, optional): The dimension along which to take the one dimensional FFT. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.fft`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Calling the backward transform (:func:`~torch.fft.ifft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifft` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> t = torch.arange(4) + >>> t + tensor([0, 1, 2, 3]) + >>> torch.fft.fft(t) + tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + + >>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j]) + >>> torch.fft.fft(t) + tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j]) +""".format(**common_args), +) + +ifft = _add_docstr( + _fft.fft_ifft, + r""" +ifft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional inverse discrete Fourier transform of :attr:`input`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the IFFT. + dim (int, optional): The dimension along which to take the one dimensional IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ifft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Calling the forward transform (:func:`~torch.fft.fft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + >>> torch.fft.ifft(t) + tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]) +""".format(**common_args), +) + +fft2 = _add_docstr( + _fft.fft_fft2, + r""" +fft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2 dimensional discrete Fourier transform of :attr:`input`. +Equivalent to :func:`~torch.fft.fftn` but FFTs only the last two dimensions by default. + +Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: ``X[i, j] = conj(X[-i, -j])``. This + function always returns all positive and negative frequency terms even + though, for real inputs, half of these values are redundant. + :func:`~torch.fft.rfft2` returns the more compact one-sided representation + where only the positive frequencies of the last dimension are returned. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.fft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ifft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` + between the two transforms. This is required to make + :func:`~torch.fft.ifft2` the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> fft2 = torch.fft.fft2(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.fft2` + here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls: + + >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) + >>> torch.testing.assert_close(fft2, two_ffts, check_stride=False) + +""".format(**common_args), +) + +ifft2 = _add_docstr( + _fft.fft_ifft2, + r""" +ifft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2 dimensional inverse discrete Fourier transform of :attr:`input`. +Equivalent to :func:`~torch.fft.ifftn` but IFFTs only the last two dimensions by default. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ifft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.fft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> ifft2 = torch.fft.ifft2(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.ifft2` + here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls: + + >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) + >>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False) + +""".format(**common_args), +) + +fftn = _add_docstr( + _fft.fft_fftn, + r""" +fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N dimensional discrete Fourier transform of :attr:`input`. + +Note: + The Fourier domain representation of any real signal satisfies the + Hermitian property: ``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])``. This + function always returns all positive and negative frequency terms even + though, for real inputs, half of these values are redundant. + :func:`~torch.fft.rfftn` returns the more compact one-sided representation + where only the positive frequencies of the last dimension are returned. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.fftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ifftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` + between the two transforms. This is required to make + :func:`~torch.fft.ifftn` the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> fftn = torch.fft.fftn(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.fftn` + here is equivalent to two one-dimensional :func:`~torch.fft.fft` calls: + + >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) + >>> torch.testing.assert_close(fftn, two_ffts, check_stride=False) + +""".format(**common_args), +) + +ifftn = _add_docstr( + _fft.fft_ifftn, + r""" +ifftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N dimensional inverse discrete Fourier transform of :attr:`input`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ifftn`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.fftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ifftn` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> x = torch.rand(10, 10, dtype=torch.complex64) + >>> ifftn = torch.fft.ifftn(x) + + The discrete Fourier transform is separable, so :func:`~torch.fft.ifftn` + here is equivalent to two one-dimensional :func:`~torch.fft.ifft` calls: + + >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) + >>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False) + +""".format(**common_args), +) + +rfft = _add_docstr( + _fft.fft_rfft, + r""" +rfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional Fourier transform of real-valued :attr:`input`. + +The FFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])`` so +the output contains only the positive frequencies below the Nyquist frequency. +To compute the full output, use :func:`~torch.fft.fft` + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the real input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the real FFT. + dim (int, optional): The dimension along which to take the one dimensional real FFT. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.rfft`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the FFT orthonormal) + + Calling the backward transform (:func:`~torch.fft.irfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> t = torch.arange(4) + >>> t + tensor([0, 1, 2, 3]) + >>> torch.fft.rfft(t) + tensor([ 6.+0.j, -2.+2.j, -2.+0.j]) + + Compare against the full output from :func:`~torch.fft.fft`: + + >>> torch.fft.fft(t) + tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) + + Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted. + At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair, + and therefore must always be real-valued. +""".format(**common_args), +) + +irfft = _add_docstr( + _fft.fft_irfft, + r""" +irfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.rfft`. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier +domain, as produced by :func:`~torch.fft.rfft`. By the Hermitian property, the +output will be real-valued. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`n`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal length :attr:`n`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + With default arguments, size of the transformed dimension should be (2^n + 1) as argument + `n` defaults to even output size = 2 * (transformed_dim_size - 1) + +Args: + input (Tensor): the input tensor representing a half-Hermitian signal + n (int, optional): Output signal length. This determines the length of the + output signal. If given, the input will either be zero-padded or trimmed to this + length before computing the real IFFT. + Defaults to even output: ``n=2*(input.size(dim) - 1)``. + dim (int, optional): The dimension along which to take the one dimensional real IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.irfft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Calling the forward transform (:func:`~torch.fft.rfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.linspace(0, 1, 5) + >>> t + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + >>> T = torch.fft.rfft(t) + >>> T + tensor([ 2.5000+0.0000j, -0.6250+0.8602j, -0.6250+0.2031j]) + + Without specifying the output length to :func:`~torch.fft.irfft`, the output + will not round-trip properly because the input is odd-length: + + >>> torch.fft.irfft(T) + tensor([0.1562, 0.3511, 0.7812, 1.2114]) + + So, it is recommended to always pass the signal length :attr:`n`: + + >>> roundtrip = torch.fft.irfft(T, t.numel()) + >>> torch.testing.assert_close(roundtrip, t, check_stride=False) + +""".format(**common_args), +) + +rfft2 = _add_docstr( + _fft.fft_rfft2, + r""" +rfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2-dimensional discrete Fourier transform of real :attr:`input`. +Equivalent to :func:`~torch.fft.rfftn` but FFTs only the last two dimensions by default. + +The FFT of a real signal is Hermitian-symmetric, ``X[i, j] = conj(X[-i, -j])``, +so the full :func:`~torch.fft.fft2` output contains redundant information. +:func:`~torch.fft.rfft2` instead omits the negative frequencies in the last +dimension. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.rfft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.irfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 10) + >>> rfft2 = torch.fft.rfft2(t) + >>> rfft2.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.fft2`, we have all + elements up to the Nyquist frequency. + + >>> fft2 = torch.fft.fft2(t) + >>> torch.testing.assert_close(fft2[..., :6], rfft2, check_stride=False) + + The discrete Fourier transform is separable, so :func:`~torch.fft.rfft2` + here is equivalent to a combination of :func:`~torch.fft.fft` and + :func:`~torch.fft.rfft`: + + >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) + >>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False) + +""".format(**common_args), +) + +irfft2 = _add_docstr( + _fft.fft_irfft2, + r""" +irfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.rfft2`. +Equivalent to :func:`~torch.fft.irfftn` but IFFTs only the last two dimensions by default. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier +domain, as produced by :func:`~torch.fft.rfft2`. By the Hermitian property, the +output will be real-valued. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal shape :attr:`s`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.irfft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.rfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 9) + >>> T = torch.fft.rfft2(t) + + Without specifying the output length to :func:`~torch.fft.irfft2`, the output + will not round-trip properly because the input is odd-length in the last + dimension: + + >>> torch.fft.irfft2(T).size() + torch.Size([10, 8]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.irfft2(T, t.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.testing.assert_close(roundtrip, t, check_stride=False) + +""".format(**common_args), +) + +rfftn = _add_docstr( + _fft.fft_rfftn, + r""" +rfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N-dimensional discrete Fourier transform of real :attr:`input`. + +The FFT of a real signal is Hermitian-symmetric, +``X[i_1, ..., i_n] = conj(X[-i_1, ..., -i_n])`` so the full +:func:`~torch.fft.fftn` output contains redundant information. +:func:`~torch.fft.rfftn` instead omits the negative frequencies in the +last dimension. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.rfftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.irfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfftn` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 10) + >>> rfftn = torch.fft.rfftn(t) + >>> rfftn.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.fftn`, we have all + elements up to the Nyquist frequency. + + >>> fftn = torch.fft.fftn(t) + >>> torch.testing.assert_close(fftn[..., :6], rfftn, check_stride=False) + + The discrete Fourier transform is separable, so :func:`~torch.fft.rfftn` + here is equivalent to a combination of :func:`~torch.fft.fft` and + :func:`~torch.fft.rfft`: + + >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) + >>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False) + +""".format(**common_args), +) + +irfftn = _add_docstr( + _fft.fft_irfftn, + r""" +irfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.rfftn`. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the Fourier +domain, as produced by :func:`~torch.fft.rfftn`. By the Hermitian property, the +output will be real-valued. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal shape :attr:`s`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.irfftn`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the real IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.rfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.irfftn` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.rand(10, 9) + >>> T = torch.fft.rfftn(t) + + Without specifying the output length to :func:`~torch.fft.irfft`, the output + will not round-trip properly because the input is odd-length in the last + dimension: + + >>> torch.fft.irfftn(T).size() + torch.Size([10, 8]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.irfftn(T, t.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.testing.assert_close(roundtrip, t, check_stride=False) + +""".format(**common_args), +) + +hfft = _add_docstr( + _fft.fft_hfft, + r""" +hfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the one dimensional discrete Fourier transform of a Hermitian +symmetric :attr:`input` signal. + +Note: + + :func:`~torch.fft.hfft`/:func:`~torch.fft.ihfft` are analogous to + :func:`~torch.fft.rfft`/:func:`~torch.fft.irfft`. The real FFT expects + a real signal in the time-domain and gives a Hermitian symmetry in the + frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in + the time-domain and real-valued in the frequency-domain. For this reason, + special care needs to be taken with the length argument :attr:`n`, in the + same way as with :func:`~torch.fft.irfft`. + +Note: + Because the signal is Hermitian in the time-domain, the result will be + real in the frequency domain. Note that some input frequencies must be + real-valued to satisfy the Hermitian property. In these cases the imaginary + component will be ignored. For example, any imaginary component in + ``input[0]`` would result in one or more complex frequency terms which + cannot be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`n`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. So, it is recommended to always pass the signal length :attr:`n`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + With default arguments, size of the transformed dimension should be (2^n + 1) as argument + `n` defaults to even output size = 2 * (transformed_dim_size - 1) + +Args: + input (Tensor): the input tensor representing a half-Hermitian signal + n (int, optional): Output signal length. This determines the length of the + real output. If given, the input will either be zero-padded or trimmed to this + length before computing the Hermitian FFT. + Defaults to even output: ``n=2*(input.size(dim) - 1)``. + dim (int, optional): The dimension along which to take the one dimensional Hermitian FFT. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.hfft`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) + + Calling the backward transform (:func:`~torch.fft.ihfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + Taking a real-valued frequency signal and bringing it into the time domain + gives Hermitian symmetric output: + + >>> t = torch.linspace(0, 1, 5) + >>> t + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + >>> T = torch.fft.ifft(t) + >>> T + tensor([ 0.5000-0.0000j, -0.1250-0.1720j, -0.1250-0.0406j, -0.1250+0.0406j, + -0.1250+0.1720j]) + + Note that ``T[1] == T[-1].conj()`` and ``T[2] == T[-2].conj()`` is + redundant. We can thus compute the forward transform without considering + negative frequencies: + + >>> torch.fft.hfft(T[:3], n=5) + tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000]) + + Like with :func:`~torch.fft.irfft`, the output length must be given in order + to recover an even length output: + + >>> torch.fft.hfft(T[:3]) + tensor([0.1250, 0.2809, 0.6250, 0.9691]) +""".format(**common_args), +) + +ihfft = _add_docstr( + _fft.fft_ihfft, + r""" +ihfft(input, n=None, dim=-1, norm=None, *, out=None) -> Tensor + +Computes the inverse of :func:`~torch.fft.hfft`. + +:attr:`input` must be a real-valued signal, interpreted in the Fourier domain. +The IFFT of a real signal is Hermitian-symmetric, ``X[i] = conj(X[-i])``. +:func:`~torch.fft.ihfft` represents this in the one-sided form where only the +positive frequencies below the Nyquist frequency are included. To compute the +full output, use :func:`~torch.fft.ifft`. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimension. + +Args: + input (Tensor): the real input tensor + n (int, optional): Signal length. If given, the input will either be zero-padded + or trimmed to this length before computing the Hermitian IFFT. + dim (int, optional): The dimension along which to take the one dimensional Hermitian IFFT. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ihfft`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the IFFT orthonormal) + + Calling the forward transform (:func:`~torch.fft.hfft`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> t = torch.arange(5) + >>> t + tensor([0, 1, 2, 3, 4]) + >>> torch.fft.ihfft(t) + tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j]) + + Compare against the full output from :func:`~torch.fft.ifft`: + + >>> torch.fft.ifft(t) + tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j, + -0.5000+0.6882j]) +""".format(**common_args), +) + +hfft2 = _add_docstr( + _fft.fft_hfft2, + r""" +hfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2-dimensional discrete Fourier transform of a Hermitian symmetric +:attr:`input` signal. Equivalent to :func:`~torch.fft.hfftn` but only +transforms the last two dimensions by default. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the time +domain. By the Hermitian property, the Fourier transform will be real-valued. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the Hermitian FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.hfft2`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ihfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft2` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + Starting from a real frequency-space signal, we can generate a + Hermitian-symmetric time-domain signal: + >>> T = torch.rand(10, 9) + >>> t = torch.fft.ihfft2(T) + + Without specifying the output length to :func:`~torch.fft.hfftn`, the + output will not round-trip properly because the input is odd-length in the + last dimension: + + >>> torch.fft.hfft2(t).size() + torch.Size([10, 10]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.hfft2(t, T.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.allclose(roundtrip, T) + True + +""".format(**common_args), +) + +ihfft2 = _add_docstr( + _fft.fft_ihfft2, + r""" +ihfft2(input, s=None, dim=(-2, -1), norm=None, *, out=None) -> Tensor + +Computes the 2-dimensional inverse discrete Fourier transform of real +:attr:`input`. Equivalent to :func:`~torch.fft.ihfftn` but transforms only the +two last dimensions by default. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the Hermitian IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: last two dimensions. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ihfft2`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.hfft2`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfft2` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> T = torch.rand(10, 10) + >>> t = torch.fft.ihfft2(t) + >>> t.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.ifft2`, the + Hermitian time-space signal takes up only half the space. + + >>> fftn = torch.fft.ifft2(t) + >>> torch.allclose(fftn[..., :6], rfftn) + True + + The discrete Fourier transform is separable, so :func:`~torch.fft.ihfft2` + here is equivalent to a combination of :func:`~torch.fft.ifft` and + :func:`~torch.fft.ihfft`: + + >>> two_ffts = torch.fft.ifft(torch.fft.ihfft(t, dim=1), dim=0) + >>> torch.allclose(t, two_ffts) + True + +""".format(**common_args), +) + +hfftn = _add_docstr( + _fft.fft_hfftn, + r""" +hfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the n-dimensional discrete Fourier transform of a Hermitian symmetric +:attr:`input` signal. + +:attr:`input` is interpreted as a one-sided Hermitian signal in the time +domain. By the Hermitian property, the Fourier transform will be real-valued. + +Note: + :func:`~torch.fft.hfftn`/:func:`~torch.fft.ihfftn` are analogous to + :func:`~torch.fft.rfftn`/:func:`~torch.fft.irfftn`. The real FFT expects + a real signal in the time-domain and gives Hermitian symmetry in the + frequency-domain. The Hermitian FFT is the opposite; Hermitian symmetric in + the time-domain and real-valued in the frequency-domain. For this reason, + special care needs to be taken with the shape argument :attr:`s`, in the + same way as with :func:`~torch.fft.irfftn`. + +Note: + Some input frequencies must be real-valued to satisfy the Hermitian + property. In these cases the imaginary component will be ignored. + For example, any imaginary component in the zero-frequency term cannot + be represented in a real output and so will always be ignored. + +Note: + The correct interpretation of the Hermitian input depends on the length of + the original data, as given by :attr:`s`. This is because each input shape + could correspond to either an odd or even length signal. By default, the + signal is assumed to be even length and odd signals will not round-trip + properly. It is recommended to always pass the signal shape :attr:`s`. + +Note: + Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + With default arguments, the size of last dimension should be (2^n + 1) as argument + `s` defaults to even output size = 2 * (last_dim_size - 1) + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the real FFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Defaults to even output in the last dimension: + ``s[-1] = 2*(input.size(dim[-1]) - 1)``. + dim (Tuple[int], optional): Dimensions to be transformed. + The last dimension must be the half-Hermitian compressed dimension. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the forward transform + (:func:`~torch.fft.hfftn`), these correspond to: + + * ``"forward"`` - normalize by ``1/n`` + * ``"backward"`` - no normalization + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian FFT orthonormal) + + Where ``n = prod(s)`` is the logical FFT size. + Calling the backward transform (:func:`~torch.fft.ihfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfftn` + the exact inverse. + + Default is ``"backward"`` (no normalization). + +Keyword args: + {out} + +Example: + + Starting from a real frequency-space signal, we can generate a + Hermitian-symmetric time-domain signal: + >>> T = torch.rand(10, 9) + >>> t = torch.fft.ihfftn(T) + + Without specifying the output length to :func:`~torch.fft.hfftn`, the + output will not round-trip properly because the input is odd-length in the + last dimension: + + >>> torch.fft.hfftn(t).size() + torch.Size([10, 10]) + + So, it is recommended to always pass the signal shape :attr:`s`. + + >>> roundtrip = torch.fft.hfftn(t, T.size()) + >>> roundtrip.size() + torch.Size([10, 9]) + >>> torch.allclose(roundtrip, T) + True + +""".format(**common_args), +) + +ihfftn = _add_docstr( + _fft.fft_ihfftn, + r""" +ihfftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor + +Computes the N-dimensional inverse discrete Fourier transform of real :attr:`input`. + +:attr:`input` must be a real-valued signal, interpreted in the Fourier domain. +The n-dimensional IFFT of a real signal is Hermitian-symmetric, +``X[i, j, ...] = conj(X[-i, -j, ...])``. :func:`~torch.fft.ihfftn` represents +this in the one-sided form where only the positive frequencies below the +Nyquist frequency are included in the last signal dimension. To compute the +full output, use :func:`~torch.fft.ifftn`. + +Note: + Supports torch.half on CUDA with GPU Architecture SM53 or greater. + However it only supports powers of 2 signal length in every transformed dimensions. + +Args: + input (Tensor): the input tensor + s (Tuple[int], optional): Signal size in the transformed dimensions. + If given, each dimension ``dim[i]`` will either be zero-padded or + trimmed to the length ``s[i]`` before computing the Hermitian IFFT. + If a length ``-1`` is specified, no padding is done in that dimension. + Default: ``s = [input.size(d) for d in dim]`` + dim (Tuple[int], optional): Dimensions to be transformed. + Default: all dimensions, or the last ``len(s)`` dimensions if :attr:`s` is given. + norm (str, optional): Normalization mode. For the backward transform + (:func:`~torch.fft.ihfftn`), these correspond to: + + * ``"forward"`` - no normalization + * ``"backward"`` - normalize by ``1/n`` + * ``"ortho"`` - normalize by ``1/sqrt(n)`` (making the Hermitian IFFT orthonormal) + + Where ``n = prod(s)`` is the logical IFFT size. + Calling the forward transform (:func:`~torch.fft.hfftn`) with the same + normalization mode will apply an overall normalization of ``1/n`` between + the two transforms. This is required to make :func:`~torch.fft.ihfftn` + the exact inverse. + + Default is ``"backward"`` (normalize by ``1/n``). + +Keyword args: + {out} + +Example: + + >>> T = torch.rand(10, 10) + >>> ihfftn = torch.fft.ihfftn(T) + >>> ihfftn.size() + torch.Size([10, 6]) + + Compared against the full output from :func:`~torch.fft.ifftn`, we have all + elements up to the Nyquist frequency. + + >>> ifftn = torch.fft.ifftn(t) + >>> torch.allclose(ifftn[..., :6], ihfftn) + True + + The discrete Fourier transform is separable, so :func:`~torch.fft.ihfftn` + here is equivalent to a combination of :func:`~torch.fft.ihfft` and + :func:`~torch.fft.ifft`: + + >>> two_iffts = torch.fft.ifft(torch.fft.ihfft(t, dim=1), dim=0) + >>> torch.allclose(ihfftn, two_iffts) + True + +""".format(**common_args), +) + +fftfreq = _add_docstr( + _fft.fft_fftfreq, + r""" +fftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Computes the discrete Fourier Transform sample frequencies for a signal of size :attr:`n`. + +Note: + By convention, :func:`~torch.fft.fft` returns positive frequency terms + first, followed by the negative frequencies in reverse order, so that + ``f[-i]`` for all :math:`0 < i \leq n/2`` in Python gives the negative + frequency terms. For an FFT of length :attr:`n` and with inputs spaced in + length unit :attr:`d`, the frequencies are:: + + f = [0, 1, ..., (n - 1) // 2, -(n // 2), ..., -1] / (d * n) + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. :func:`~torch.fft.fftfreq` follows NumPy's + convention of taking it to be negative. + +Args: + n (int): the FFT length + d (float, optional): The sampling length scale. + The spacing between individual samples of the FFT input. + The default assumes unit spacing, dividing that result by the actual + spacing gives the result in physical frequency units. + +Keyword Args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example: + + >>> torch.fft.fftfreq(5) + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + + For even input, we can see the Nyquist frequency at ``f[2]`` is given as + negative: + + >>> torch.fft.fftfreq(4) + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + +""".format(**factory_common_args), +) + +rfftfreq = _add_docstr( + _fft.fft_rfftfreq, + r""" +rfftfreq(n, d=1.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor + +Computes the sample frequencies for :func:`~torch.fft.rfft` with a signal of size :attr:`n`. + +Note: + :func:`~torch.fft.rfft` returns Hermitian one-sided output, so only the + positive frequency terms are returned. For a real FFT of length :attr:`n` + and with inputs spaced in length unit :attr:`d`, the frequencies are:: + + f = torch.arange((n + 1) // 2) / (d * n) + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. Unlike :func:`~torch.fft.fftfreq`, + :func:`~torch.fft.rfftfreq` always returns it as positive. + +Args: + n (int): the real FFT length + d (float, optional): The sampling length scale. + The spacing between individual samples of the FFT input. + The default assumes unit spacing, dividing that result by the actual + spacing gives the result in physical frequency units. + +Keyword Args: + {out} + {dtype} + {layout} + {device} + {requires_grad} + +Example: + + >>> torch.fft.rfftfreq(5) + tensor([0.0000, 0.2000, 0.4000]) + + >>> torch.fft.rfftfreq(4) + tensor([0.0000, 0.2500, 0.5000]) + + Compared to the output from :func:`~torch.fft.fftfreq`, we see that the + Nyquist frequency at ``f[2]`` has changed sign: + >>> torch.fft.fftfreq(4) + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + +""".format(**factory_common_args), +) + +fftshift = _add_docstr( + _fft.fft_fftshift, + r""" +fftshift(input, dim=None) -> Tensor + +Reorders n-dimensional FFT data, as provided by :func:`~torch.fft.fftn`, to have +negative frequency terms first. + +This performs a periodic shift of n-dimensional data such that the origin +``(0, ..., 0)`` is moved to the center of the tensor. Specifically, to +``input.shape[dim] // 2`` in each selected dimension. + +Note: + By convention, the FFT returns positive frequency terms first, followed by + the negative frequencies in reverse order, so that ``f[-i]`` for all + :math:`0 < i \leq n/2` in Python gives the negative frequency terms. + :func:`~torch.fft.fftshift` rearranges all frequencies into ascending order + from negative to positive with the zero-frequency term in the center. + +Note: + For even lengths, the Nyquist frequency at ``f[n/2]`` can be thought of as + either negative or positive. :func:`~torch.fft.fftshift` always puts the + Nyquist term at the 0-index. This is the same convention used by + :func:`~torch.fft.fftfreq`. + +Args: + input (Tensor): the tensor in FFT order + dim (int, Tuple[int], optional): The dimensions to rearrange. + Only dimensions specified here will be rearranged, any other dimensions + will be left in their original order. + Default: All dimensions of :attr:`input`. + +Example: + + >>> f = torch.fft.fftfreq(4) + >>> f + tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) + + >>> torch.fft.fftshift(f) + tensor([-0.5000, -0.2500, 0.0000, 0.2500]) + + Also notice that the Nyquist frequency term at ``f[2]`` was moved to the + beginning of the tensor. + + This also works for multi-dimensional transforms: + + >>> x = torch.fft.fftfreq(5, d=1/5) + 0.1 * torch.fft.fftfreq(5, d=1/5).unsqueeze(1) + >>> x + tensor([[ 0.0000, 1.0000, 2.0000, -2.0000, -1.0000], + [ 0.1000, 1.1000, 2.1000, -1.9000, -0.9000], + [ 0.2000, 1.2000, 2.2000, -1.8000, -0.8000], + [-0.2000, 0.8000, 1.8000, -2.2000, -1.2000], + [-0.1000, 0.9000, 1.9000, -2.1000, -1.1000]]) + + >>> torch.fft.fftshift(x) + tensor([[-2.2000, -1.2000, -0.2000, 0.8000, 1.8000], + [-2.1000, -1.1000, -0.1000, 0.9000, 1.9000], + [-2.0000, -1.0000, 0.0000, 1.0000, 2.0000], + [-1.9000, -0.9000, 0.1000, 1.1000, 2.1000], + [-1.8000, -0.8000, 0.2000, 1.2000, 2.2000]]) + + :func:`~torch.fft.fftshift` can also be useful for spatial data. If our + data is defined on a centered grid (``[-(N//2), (N-1)//2]``) then we can + use the standard FFT defined on an uncentered grid (``[0, N)``) by first + applying an :func:`~torch.fft.ifftshift`. + + >>> x_centered = torch.arange(-5, 5) + >>> x_uncentered = torch.fft.ifftshift(x_centered) + >>> fft_uncentered = torch.fft.fft(x_uncentered) + + Similarly, we can convert the frequency domain components to centered + convention by applying :func:`~torch.fft.fftshift`. + + >>> fft_centered = torch.fft.fftshift(fft_uncentered) + + The inverse transform, from centered Fourier space back to centered spatial + data, can be performed by applying the inverse shifts in reverse order: + + >>> x_centered_2 = torch.fft.fftshift(torch.fft.ifft(torch.fft.ifftshift(fft_centered))) + >>> torch.testing.assert_close(x_centered.to(torch.complex64), x_centered_2, check_stride=False) + + +""", +) + +ifftshift = _add_docstr( + _fft.fft_ifftshift, + r""" +ifftshift(input, dim=None) -> Tensor + +Inverse of :func:`~torch.fft.fftshift`. + +Args: + input (Tensor): the tensor in FFT order + dim (int, Tuple[int], optional): The dimensions to rearrange. + Only dimensions specified here will be rearranged, any other dimensions + will be left in their original order. + Default: All dimensions of :attr:`input`. + +Example: + + >>> f = torch.fft.fftfreq(5) + >>> f + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + + A round-trip through :func:`~torch.fft.fftshift` and + :func:`~torch.fft.ifftshift` gives the same result: + + >>> shifted = torch.fft.fftshift(f) + >>> torch.fft.ifftshift(shifted) + tensor([ 0.0000, 0.2000, 0.4000, -0.4000, -0.2000]) + +""", +) diff --git a/phivenv/Lib/site-packages/torch/fft/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fft/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b72d5eec61bd9c25b51899265c4e7a080f7d575 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fft/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/func/__init__.py b/phivenv/Lib/site-packages/torch/func/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6afd8ad0c45a77d545f52f596d2e75c2ddef20e9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/func/__init__.py @@ -0,0 +1,31 @@ +from torch._functorch.apis import grad, grad_and_value, vmap +from torch._functorch.batch_norm_replacement import replace_all_batch_norm_modules_ +from torch._functorch.eager_transforms import ( + debug_unwrap, + functionalize, + hessian, + jacfwd, + jacrev, + jvp, + linearize, + vjp, +) +from torch._functorch.functional_call import functional_call, stack_module_state + + +__all__ = [ + "grad", + "grad_and_value", + "vmap", + "replace_all_batch_norm_modules_", + "functionalize", + "hessian", + "jacfwd", + "jacrev", + "jvp", + "linearize", + "vjp", + "functional_call", + "stack_module_state", + "debug_unwrap", +] diff --git a/phivenv/Lib/site-packages/torch/func/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/func/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf47847de3c05bf692a7dd6ef636553f5f41c86f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/func/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/futures/__init__.py b/phivenv/Lib/site-packages/torch/futures/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7c3ab69808880c880d40d37bb18325775b3d63dc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/futures/__init__.py @@ -0,0 +1,335 @@ +# mypy: allow-untyped-defs +from __future__ import annotations + +from typing import Callable, cast, Generic, Optional, TypeVar, Union + +import torch + + +__all__ = ["Future", "collect_all", "wait_all"] + + +T = TypeVar("T") +S = TypeVar("S") + + +class _PyFutureMeta(type(torch._C.Future), type(Generic)): # type: ignore[misc, no-redef] + pass + + +class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta): + r""" + Wrapper around a ``torch._C.Future`` which encapsulates an asynchronous + execution of a callable, e.g. :meth:`~torch.distributed.rpc.rpc_async`. It + also exposes a set of APIs to add callback functions and set results. + + .. warning:: GPU support is a beta feature, subject to changes. + """ + + def __init__( + self, *, devices: Optional[list[Union[int, str, torch.device]]] = None + ): + r""" + Create an empty unset ``Future``. If the future is intended to hold + values containing CUDA tensors, (a superset of) their CUDA devices must + be specified at construction. (This is only supported if + ``torch.cuda.is_available()`` returns ``True``). This is needed to + ensure proper CUDA stream synchronization. The child futures, returned + by the ``then`` method, will inherit these devices. + + Args: + devices(``List[Union[int, str, torch.device]]``, optional): the set + of devices on which tensors contained in this future's value are + allowed to reside and on which callbacks are allowed to operate. + """ + if devices is None: + devices = [] + super().__init__([torch.device(d) for d in devices]) + + def done(self) -> bool: + r""" + Return ``True`` if this ``Future`` is done. A ``Future`` is done if it + has a result or an exception. + + If the value contains tensors that reside on GPUs, ``Future.done()`` + will return ``True`` even if the asynchronous kernels that are + populating those tensors haven't yet completed running on the device, + because at such stage the result is already usable, provided one + performs the appropriate synchronizations (see :meth:`wait`). + """ + return super().done() + + def wait(self) -> T: + r""" + Block until the value of this ``Future`` is ready. + + If the value contains tensors that reside on GPUs, then an additional + synchronization is performed with the kernels (executing on the device) + which may be asynchronously populating those tensors. Such sync is + non-blocking, which means that ``wait()`` will insert the necessary + instructions in the current streams to ensure that further operations + enqueued on those streams will be properly scheduled after the async + kernels but, once that is done, ``wait()`` will return, even if those + kernels are still running. No further synchronization is required when + accessing and using the values, as long as one doesn't change streams. + + Returns: + The value held by this ``Future``. If the function (callback or RPC) + creating the value has thrown an error, this ``wait`` method will + also throw an error. + """ + return super().wait() + + def value(self) -> T: + r""" + Obtain the value of an already-completed future. + + This method should only be called after a call to :meth:`wait` has + completed, or inside a callback function passed to :meth:`then`. In + other cases this ``Future`` may not yet hold a value and calling + ``value()`` could fail. + + If the value contains tensors that reside on GPUs, then this method will + *not* perform any additional synchronization. This should be done + beforehand, separately, through a call to :meth:`wait` (except within + callbacks, for which it's already being taken care of by :meth:`then`). + + Returns: + The value held by this ``Future``. If the function (callback or RPC) + creating the value has thrown an error, this ``value()`` method will + also throw an error. + """ + return super().value() + + def then(self, callback: Callable[[Future[T]], S]) -> Future[S]: + r""" + Append the given callback function to this ``Future``, which will be run + when the ``Future`` is completed. Multiple callbacks can be added to + the same ``Future``, but the order in which they will be executed cannot + be guaranteed (to enforce a certain order consider chaining: + ``fut.then(cb1).then(cb2)``). The callback must take one argument, which + is the reference to this ``Future``. The callback function can use the + :meth:`value` method to get the value. Note that if this ``Future`` is + already completed, the given callback will be run immediately inline. + + If the ``Future``'s value contains tensors that reside on GPUs, the + callback might be invoked while the async kernels that are populating + those tensors haven't yet finished executing on the device. However, the + callback will be invoked with some dedicated streams set as current + (fetched from a global pool) which will be synchronized with those + kernels. Hence any operation performed by the callback on these tensors + will be scheduled on the device after the kernels complete. In other + words, as long as the callback doesn't switch streams, it can safely + manipulate the result without any additional synchronization. This is + similar to the non-blocking behavior of :meth:`wait`. + + Similarly, if the callback returns a value that contains tensors that + reside on a GPU, it can do so even if the kernels that are producing + these tensors are still running on the device, as long as the callback + didn't change streams during its execution. If one wants to change + streams, one must be careful to re-synchronize them with the original + streams, that is, those that were current when the callback was invoked. + + Args: + callback(``Callable``): a ``Callable`` that takes this ``Future`` as + the only argument. + + Returns: + A new ``Future`` object that holds the return value of the + ``callback`` and will be marked as completed when the given + ``callback`` finishes. + + .. note:: Note that if the callback function throws, either + through the original future being completed with an exception and + calling ``fut.wait()``, or through other code in the callback, the + future returned by ``then`` will be marked appropriately with the + encountered error. However, if this callback later completes + additional futures, those futures are not marked as completed with + an error and the user is responsible for handling completion/waiting + on those futures independently. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> def callback(fut): + ... print(f"RPC return value is {fut.wait()}.") + >>> fut = torch.futures.Future() + >>> # The inserted callback will print the return value when + >>> # receiving the response from "worker1" + >>> cb_fut = fut.then(callback) + >>> chain_cb_fut = cb_fut.then( + ... lambda x : print(f"Chained cb done. {x.wait()}") + ... ) + >>> fut.set_result(5) + RPC return value is 5. + Chained cb done. None + """ + return cast(Future[S], super().then(callback)) + + def add_done_callback(self, callback: Callable[[Future[T]], None]) -> None: + r""" + Append the given callback function to this ``Future``, which will be run + when the ``Future`` is completed. Multiple callbacks can be added to + the same ``Future``, but the order in which they will be executed cannot + be guaranteed. The callback must take one argument, which is the + reference to this ``Future``. The callback function can use the + :meth:`value` method to get the value. Note that if this ``Future`` is + already completed, the given callback will be run inline. + + We recommend that you use the :meth:`then` method as it provides a way + to synchronize after your callback has completed. ``add_done_callback`` + can be cheaper if your callback does not return anything. But both + :meth:`then` and ``add_done_callback`` use the same callback + registration API under the hood. + + With respect to GPU tensors, this method behaves in the same way as + :meth:`then`. + + Args: + callback(``Future``): a ``Callable`` that takes in one argument, + which is the reference to this ``Future``. + + .. note:: Note that if the callback function throws, either + through the original future being completed with an exception and + calling ``fut.wait()``, or through other code in the callback, + error handling must be carefully taken care of. For example, if + this callback later completes additional futures, those futures are + not marked as completed with an error and the user is responsible + for handling completion/waiting on those futures independently. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> def callback(fut): + ... print("This will run after the future has finished.") + ... print(fut.wait()) + >>> fut = torch.futures.Future() + >>> fut.add_done_callback(callback) + >>> fut.set_result(5) + This will run after the future has finished. + 5 + """ + super().add_done_callback(callback) + + def set_result(self, result: T) -> None: + r""" + Set the result for this ``Future``, which will mark this ``Future`` as + completed and trigger all attached callbacks. Note that a ``Future`` + cannot be marked completed twice. + + If the result contains tensors that reside on GPUs, this method can be + called even if the asynchronous kernels that are populating those + tensors haven't yet completed running on the device, provided that the + streams on which those kernels were enqueued are set as the current ones + when this method is called. Put simply, it's safe to call this method + immediately after launching those kernels, without any additional + synchronization, as long as one doesn't change streams in between. This + method will record events on all the relevant current streams and will + use them to ensure proper scheduling for all the consumers of this + ``Future``. + + Args: + result (object): the result object of this ``Future``. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> import threading + >>> import time + >>> def slow_set_future(fut, value): + ... time.sleep(0.5) + ... fut.set_result(value) + >>> fut = torch.futures.Future() + >>> t = threading.Thread( + ... target=slow_set_future, + ... args=(fut, torch.ones(2) * 3) + ... ) + >>> t.start() + >>> print(fut.wait()) + tensor([3., 3.]) + >>> t.join() + """ + super().set_result(result) + + def set_exception(self, result: T) -> None: + r""" + Set an exception for this ``Future``, which will mark this ``Future`` as + completed with an error and trigger all attached callbacks. Note that + when calling wait()/value() on this ``Future``, the exception set here + will be raised inline. + + Args: + result (BaseException): the exception for this ``Future``. + + Example:: + + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> fut = torch.futures.Future() + >>> fut.set_exception(ValueError("foo")) + >>> fut.wait() + Traceback (most recent call last): + ... + ValueError: foo + """ + assert isinstance(result, Exception), ( + f"{result} is of type {type(result)}, not an Exception." + ) + + def raise_error(fut_result): + raise fut_result + + super()._set_unwrap_func(raise_error) + self.set_result(result) # type: ignore[arg-type] + + +def collect_all(futures: list[Future]) -> Future[list[Future]]: + r""" + Collects the provided :class:`~torch.futures.Future` objects into a single + combined :class:`~torch.futures.Future` that is completed when all of the + sub-futures are completed. + + Args: + futures (list): a list of :class:`~torch.futures.Future` objects. + + Returns: + Returns a :class:`~torch.futures.Future` object to a list of the passed + in Futures. + + Example:: + >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_FUTURES) + >>> fut0 = torch.futures.Future() + >>> fut1 = torch.futures.Future() + >>> fut = torch.futures.collect_all([fut0, fut1]) + >>> fut0.set_result(0) + >>> fut1.set_result(1) + >>> fut_list = fut.wait() + >>> print(f"fut0 result = {fut_list[0].wait()}") + fut0 result = 0 + >>> print(f"fut1 result = {fut_list[1].wait()}") + fut1 result = 1 + """ + return cast( + Future[list[Future]], + torch._C._collect_all(cast(list[torch._C.Future], futures)), + ) + + +def wait_all(futures: list[Future]) -> list: + r""" + Waits for all provided futures to be complete, and returns + the list of completed values. If any of the futures encounters an error, + the method will exit early and report the error not waiting for other + futures to complete. + + Args: + futures (list): a list of :class:`~torch.futures.Future` object. + + Returns: + A list of the completed :class:`~torch.futures.Future` results. This + method will throw an error if ``wait`` on any + :class:`~torch.futures.Future` throws. + """ + return [ + fut.wait() + for fut in torch._C._collect_all(cast(list[torch._C.Future], futures)).wait() + ] diff --git a/phivenv/Lib/site-packages/torch/futures/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/futures/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dee952fbc0acc0625b6b77a01a957392862e8cd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/futures/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__init__.py b/phivenv/Lib/site-packages/torch/fx/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9a40aabc4a5174da71bc653ee10d179b00943bca --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/__init__.py @@ -0,0 +1,116 @@ +r''' +FX is a toolkit for developers to use to transform ``nn.Module`` +instances. FX consists of three main components: a **symbolic tracer,** +an **intermediate representation**, and **Python code generation**. A +demonstration of these components in action: + +:: + + import torch + + + # Simple module for demonstration + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return self.linear(x + self.param).clamp(min=0.0, max=1.0) + + + module = MyModule() + + from torch.fx import symbolic_trace + + # Symbolic tracing frontend - captures the semantics of the module + symbolic_traced: torch.fx.GraphModule = symbolic_trace(module) + + # High-level intermediate representation (IR) - Graph representation + print(symbolic_traced.graph) + """ + graph(): + %x : [num_users=1] = placeholder[target=x] + %param : [num_users=1] = get_attr[target=param] + %add : [num_users=1] = call_function[target=operator.add](args = (%x, %param), kwargs = {}) + %linear : [num_users=1] = call_module[target=linear](args = (%add,), kwargs = {}) + %clamp : [num_users=1] = call_method[target=clamp](args = (%linear,), kwargs = {min: 0.0, max: 1.0}) + return clamp + """ + + # Code generation - valid Python code + print(symbolic_traced.code) + """ + def forward(self, x): + param = self.param + add = x + param; x = param = None + linear = self.linear(add); add = None + clamp = linear.clamp(min = 0.0, max = 1.0); linear = None + return clamp + """ + +The **symbolic tracer** performs "symbolic execution" of the Python +code. It feeds fake values, called Proxies, through the code. Operations +on theses Proxies are recorded. More information about symbolic tracing +can be found in the :func:`symbolic_trace` and :class:`Tracer` +documentation. + +The **intermediate representation** is the container for the operations +that were recorded during symbolic tracing. It consists of a list of +Nodes that represent function inputs, callsites (to functions, methods, +or :class:`torch.nn.Module` instances), and return values. More information +about the IR can be found in the documentation for :class:`Graph`. The +IR is the format on which transformations are applied. + +**Python code generation** is what makes FX a Python-to-Python (or +Module-to-Module) transformation toolkit. For each Graph IR, we can +create valid Python code matching the Graph's semantics. This +functionality is wrapped up in :class:`GraphModule`, which is a +:class:`torch.nn.Module` instance that holds a :class:`Graph` as well as a +``forward`` method generated from the Graph. + +Taken together, this pipeline of components (symbolic tracing -> +intermediate representation -> transforms -> Python code generation) +constitutes the Python-to-Python transformation pipeline of FX. In +addition, these components can be used separately. For example, +symbolic tracing can be used in isolation to capture a form of +the code for analysis (and not transformation) purposes. Code +generation can be used for programmatically generating models, for +example from a config file. There are many uses for FX! + +Several example transformations can be found at the +`examples `__ +repository. +''' + +from torch.fx import immutable_collections +from torch.fx._symbolic_trace import ( # noqa: F401 + PH, + ProxyableClassMeta, + symbolic_trace, + Tracer, + wrap, +) +from torch.fx.graph import CodeGen, Graph # noqa: F401 +from torch.fx.graph_module import GraphModule +from torch.fx.interpreter import Interpreter, Transformer +from torch.fx.node import has_side_effect, map_arg, Node +from torch.fx.proxy import Proxy +from torch.fx.subgraph_rewriter import replace_pattern + + +__all__ = [ + "symbolic_trace", + "Tracer", + "wrap", + "Graph", + "GraphModule", + "Interpreter", + "Transformer", + "Node", + "Proxy", + "replace_pattern", + "has_side_effect", + "map_arg", +] diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fab641bfea31288d9ea8fb7075ea94d0cc8d2261 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..895b00ba721a4200f0a0e438f47659fc46f4ae55 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/_compatibility.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f2b9837da9315c8f7cc7f1a6f48d2f1505e5afa Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/_graph_pickler.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bf2f57be86df10ed1cd953500718f8a669918df Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/_lazy_graph_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdd9a225139dc4cab187262cca7cfd242c4d7920 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/_pytree.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1406116c2b66a04aa27e320aadbd086dd47cb98 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/_symbolic_trace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6a3288764959a56fac76a8446f561cd48ec390e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/annotate.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/annotate.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25567ae7d430127125d56fe96afff2fc66e38be0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/annotate.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..350881ad45e5dbf35c5fee92a653a5bf6a143da1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/graph.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/graph.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6b26e52ef63bdfa93b9ff247dba20433be9e6dd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/graph.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/graph_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/graph_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c84a0e36a3744a3d0b1464de34d593f9dda9ab59 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/graph_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/immutable_collections.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/immutable_collections.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e916b6d6bd7b071abcb1a28b450b696542a85d0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/immutable_collections.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/interpreter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/interpreter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7826a4f52d28f7291ab9a1f8c9d776ce1f9213df Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/interpreter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/node.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/node.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4eefc0a385e6185d40b11747a2e37b3d9142173 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/node.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/operator_schemas.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/operator_schemas.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..837a7f0b646246e1fc6dde774d0c087f46291b8a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/operator_schemas.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/proxy.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/proxy.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b48ce3f450007937d9e2a55dfb95b1adff676468 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/proxy.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df35cc201284cbedad4070dd2f7d6cab6fd81e30 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/subgraph_rewriter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/tensor_type.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/tensor_type.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2df7fba299788356df0b1ef888bb63b9ea7bd85f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/tensor_type.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d63d2f8b2263552b319ebb73a248ec22702e7de Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/__pycache__/traceback.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/_compatibility.py b/phivenv/Lib/site-packages/torch/fx/_compatibility.py new file mode 100644 index 0000000000000000000000000000000000000000..a178fc9a80574adf9b635c4a5e83fa62414a5d0d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/_compatibility.py @@ -0,0 +1,39 @@ +import textwrap +from typing import Any, Callable, TypeVar + + +_BACK_COMPAT_OBJECTS: dict[Any, None] = {} +_MARKED_WITH_COMPATIBILITY: dict[Any, None] = {} + + +_T = TypeVar("_T") + + +def compatibility(is_backward_compatible: bool) -> Callable[[_T], _T]: + if is_backward_compatible: + + def mark_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") + docstring += """ +.. note:: + Backwards-compatibility for this API is guaranteed. +""" + fn.__doc__ = docstring + _BACK_COMPAT_OBJECTS.setdefault(fn) + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_back_compat + else: + + def mark_not_back_compat(fn: _T) -> _T: + docstring = textwrap.dedent(getattr(fn, "__doc__", None) or "") + docstring += """ +.. warning:: + This API is experimental and is *NOT* backward-compatible. +""" + fn.__doc__ = docstring + _MARKED_WITH_COMPATIBILITY.setdefault(fn) + return fn + + return mark_not_back_compat diff --git a/phivenv/Lib/site-packages/torch/fx/_graph_pickler.py b/phivenv/Lib/site-packages/torch/fx/_graph_pickler.py new file mode 100644 index 0000000000000000000000000000000000000000..d8f7a519e62be98be8da70bc8d5c34cb8d2f0a6c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/_graph_pickler.py @@ -0,0 +1,608 @@ +import dataclasses +import importlib +import io +import pickle +from abc import abstractmethod +from typing import Any, Callable, NewType, Optional, TypeVar, Union +from typing_extensions import override, Self + +import torch +import torch.utils._pytree as pytree +from torch._guards import TracingContext +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, Tensor +from torch._subclasses.meta_utils import ( + MetaConverter, + MetaTensorDesc, + MetaTensorDescriber, +) +from torch.fx.experimental.sym_node import SymNode +from torch.fx.experimental.symbolic_shapes import ShapeEnv +from torch.utils._mode_utils import no_dispatch + + +_SymNodeT = TypeVar("_SymNodeT", torch.SymInt, torch.SymFloat) + + +def _ops_filter_safe(name: str) -> bool: + """ + An ops filter which allows pickle-safe ops. Pickle-safe ops are built-in + ones where it will be possible to unpickle on any machine which has PyTorch. + """ + # TODO: This list is pretty pessimistic right now. What's the full list? + return name.startswith( + ( + "torch.ops.aten", + "torch.ops.fbgemm", + ) + ) + + +@dataclasses.dataclass +class Options: + # A filter for which ops will cause the pickler to raise a + # BypassFxGraphCache exception. If None then all ops are allowed. + ops_filter: Optional[Callable[[str], bool]] = _ops_filter_safe + + +class GraphPickler(pickle.Pickler): + """ + GraphPickler is a Pickler which helps pickling fx graph - in particular + GraphModule. + """ + + def __init__(self, file: io.BytesIO, options: Optional[Options] = None) -> None: + super().__init__(file) + self.options = options or Options() + + # This abomination is so we can pass external decoding state to the + # unpickler functions. We serialize _unpickle_state as a persistent + # external item and when we deserialize it we return the common state + # object. + self._unpickle_state = _UnpickleStateToken(object()) + + # This is used to describe tensors. It needs to be common across the + # pickle so that duplicates and views are properly handled. + self._meta_tensor_describer = MetaTensorDescriber(copy_data=False) + + @override + def reducer_override( + self, obj: object + ) -> tuple[Callable[..., Any], tuple[Any, ...]]: + # This function is supposed to return either NotImplemented (meaning to + # do the default pickle behavior) or a pair of (unpickle callable, data + # to pass to unpickle). + + # We could instead teach individual classes how to pickle themselves but + # that has a few problems: + # + # 1. If we have some special needs (maybe for this use-case we don't + # want to fully serialize every field) then we're adding private + # details to a public interface. + # + # 2. If we need to have some common shared data (such as a + # FakeTensorMode) which is passed to each value it's harder to + # support. + + # These are the types that need special handling. See the individual + # *PickleData classes for details on pickling that particular type. + if isinstance(obj, FakeTensor): + return _TensorPickleData.reduce_helper(self, obj) + elif isinstance(obj, torch.fx.GraphModule): + return _GraphModulePickleData.reduce_helper(self, obj) + elif isinstance(obj, (torch._ops.OperatorBase, torch._ops.OpOverloadPacket)): + return _OpPickleData.reduce_helper(self, obj) + elif isinstance(obj, ShapeEnv): + return _ShapeEnvPickleData.reduce_helper(self, obj) + elif isinstance(obj, torch.SymInt): + return _SymNodePickleData.reduce_helper(self, obj) + elif isinstance(obj, torch._guards.TracingContext): + return _TracingContextPickleData.reduce_helper(self, obj) + else: + # We should never get a raw Node! + assert not isinstance(obj, torch.fx.Node) + if reduce := _TorchNumpyPickleData.reduce_helper(self, obj): + return reduce + + # returning `NotImplemented` causes pickle to revert to the default + # behavior for this object. + return NotImplemented + + @override + def persistent_id(self, obj: object) -> Optional[str]: + if obj is self._unpickle_state: + return "unpickle_state" + else: + return None + + @classmethod + def dumps(cls, obj: object, options: Optional[Options] = None) -> bytes: + """ + Pickle an object. + """ + with io.BytesIO() as stream: + pickler = cls(stream, options) + pickler.dump(obj) + return stream.getvalue() + + @staticmethod + def loads(data: bytes, fake_mode: FakeTensorMode) -> object: + """ + Unpickle an object. + """ + state = _UnpickleState(fake_mode) + with io.BytesIO(data) as stream: + unpickler = _GraphUnpickler(stream, state) + return unpickler.load() + + +class _UnpickleState: + def __init__(self, fake_mode: FakeTensorMode) -> None: + self.fake_mode = fake_mode + self.meta_converter: MetaConverter[FakeTensor] = MetaConverter() + + +# This token is passed when pickling to indicate that we want to use the +# unpickler's _UnpickleState as a parameter in that position. +_UnpickleStateToken = NewType("_UnpickleStateToken", object) + + +class _GraphUnpickler(pickle.Unpickler): + def __init__(self, stream: io.BytesIO, unpickle_state: _UnpickleState) -> None: + super().__init__(stream) + self._unpickle_state = unpickle_state + + @override + def persistent_load(self, pid: object) -> object: + if pid == "unpickle_state": + return self._unpickle_state + else: + raise pickle.UnpicklingError("Invalid persistent ID") + + +class _ShapeEnvPickleData: + data: dict[str, object] + + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: ShapeEnv + ) -> tuple[ + Callable[[Self, _UnpickleState], ShapeEnv], tuple[Self, _UnpickleStateToken] + ]: + return cls.unpickle, (cls(obj), pickler._unpickle_state) + + def __init__(self, env: ShapeEnv) -> None: + # In theory pickle should recognize that a given ShapeEnv was already + # pickled and reuse the resulting _ShapeEnvPickleData (so two objects + # pointing at the same ShapeEnv get the same ShapeEnv out). + assert not env._translation_validation_enabled + self.data = env.__dict__.copy() + del self.data["tracked_fakes"] + del self.data["fake_tensor_cache"] + + def unpickle(self, unpickle_state: _UnpickleState) -> ShapeEnv: + # Fill in the existing ShapeEnv rather than creating a new one + assert unpickle_state.fake_mode + assert unpickle_state.fake_mode.shape_env + + for k, v in self.data.items(): + setattr(unpickle_state.fake_mode.shape_env, k, v) + + return unpickle_state.fake_mode.shape_env + + +class _SymNodePickleData: + @classmethod + def reduce_helper( + cls, + pickler: GraphPickler, + obj: _SymNodeT, + ) -> tuple[ + Callable[[Self, _UnpickleState], _SymNodeT], tuple[Self, _UnpickleStateToken] + ]: + args = (cls(obj.node), pickler._unpickle_state) + if isinstance(obj, torch.SymInt): + return _SymNodePickleData.unpickle_sym_int, args + else: + raise NotImplementedError(f"Unhandled SymNode type {type(obj)}") + + def __init__(self, node: SymNode) -> None: + self.expr = node._expr + self.shape_env = node.shape_env + self.pytype = node.pytype + self.hint = node._hint + + def _to_sym_node(self) -> SymNode: + from torch.fx.experimental.sym_node import SymNode + + assert self.shape_env is not None + return SymNode(self.expr, self.shape_env, self.pytype, self.hint) + + def unpickle_sym_int(self, unpickle_state: _UnpickleState) -> torch.SymInt: + return torch.SymInt(self._to_sym_node()) + + +class _TensorPickleData: + metadata: MetaTensorDesc[FakeTensor] + + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: FakeTensor + ) -> tuple[ + Callable[[Self, _UnpickleState], FakeTensor], tuple[Self, _UnpickleStateToken] + ]: + return cls.unpickle, ( + cls(pickler._meta_tensor_describer, obj), + pickler._unpickle_state, + ) + + def __init__(self, describer: MetaTensorDescriber, t: Tensor) -> None: + # THINGS TO WORRY ABOUT: + # 1. Need to make sure that two tensors with the same id end up with the + # same id on the other side of the wire. + + metadata = describer.describe_tensor(t) + + # view_func is fine if it's either None or a _FakeTensorViewFunc. A + # custom one (which is basically a lambda) can't be serialized. + assert not metadata.view_func or isinstance( + metadata.view_func, torch._subclasses.meta_utils._FakeTensorViewFunc + ) + self.metadata = dataclasses.replace(metadata, fake_mode=None) + + # Some debugging/verification + for k in MetaTensorDesc._UNSERIALIZABLE: + if k in ("fake_mode", "view_func"): + continue + assert getattr(self.metadata, k) is None, ( + f"not None: {k}: {getattr(self.metadata, k)}" + ) + + def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor: + # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py? + metadata = dataclasses.replace( + self.metadata, + fake_mode=unpickle_state.fake_mode, + ) + + def with_fake( + make_meta_t: Callable[[], torch.Tensor], device: Union[torch.device, str] + ) -> FakeTensor: + with no_dispatch(): + return FakeTensor( + unpickle_state.fake_mode, + make_meta_t(), + device, + ) + + return unpickle_state.meta_converter.meta_tensor( + metadata, + unpickle_state.fake_mode.shape_env, + with_fake, + None, + None, + ) + + +class _TorchNumpyPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: object + ) -> Optional[ + tuple[ + Callable[[Self, _UnpickleState], object], tuple[Self, _UnpickleStateToken] + ] + ]: + if data := cls.from_object(obj): + return (cls.unpickle, (data, pickler._unpickle_state)) + else: + return None + + def __init__(self, mod: str, name: str) -> None: + self.mod = mod + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> Callable[..., object]: + np = getattr(importlib.import_module(self.mod), self.name) + return torch._dynamo.variables.misc.get_np_to_tnp_map()[np] + + @classmethod + def from_object(cls, tnp: object) -> Optional[Self]: + if not callable(tnp): + return None + + tnp_to_np = torch._dynamo.variables.misc.get_tnp_to_np_map() + try: + if not (np := tnp_to_np.get(tnp)): + return None + except TypeError: + return None + + if not (mod := getattr(np, "__module__", None)): + mod = "numpy" + + if not (name := getattr(np, "__name__", None)): + return None + + assert np == getattr(importlib.import_module(mod), name) + return cls(mod, name) + + +class _GraphModulePickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: torch.fx.GraphModule + ) -> tuple[ + Callable[[Self, _UnpickleState], torch.fx.GraphModule], + tuple[Self, _UnpickleStateToken], + ]: + return cls.unpickle, ( + cls(obj, pickler.options), + pickler._unpickle_state, + ) + + def __init__(self, gm: torch.fx.GraphModule, options: Options) -> None: + # Need to do this to ensure the code is created for later pickling. + if isinstance(gm, torch.fx._lazy_graph_module._LazyGraphModule): + _python_code = gm._real_recompile() + else: + _python_code = gm.recompile() + self.gm_dict = gm.__dict__.copy() + del self.gm_dict["_graph"] + self.graph = _GraphPickleData(gm._graph, options) + + def unpickle(self, unpickle_state: _UnpickleState) -> torch.fx.GraphModule: + gm = torch.fx.GraphModule.__new__(torch.fx.GraphModule) + gm.__dict__ = self.gm_dict + gm._graph = self.graph.unpickle(gm, unpickle_state) + return gm + + +class _NodePickleData: + def __init__( + self, + node: torch.fx.Node, + mapping: dict[torch.fx.Node, "_NodePickleData"], + options: Options, + ) -> None: + self.args = pytree.tree_map_only(torch.fx.Node, lambda n: mapping[n], node.args) + self.kwargs = pytree.tree_map_only( + torch.fx.Node, lambda n: mapping[n], node.kwargs + ) + # -- self.graph = node.graph + self.name = node.name + self.op = node.op + self.target = _OpPickleData.pickle(node.target, options) + # self.input_nodes = node._input_nodes + # self.users = node.users + self.type = node.type + # self.sort_key = node._sort_key + # self.repr_fn = node._repr_fn + # self.meta = node.meta + self.meta = node.meta + + def unpickle( + self, + graph: torch.fx.Graph, + mapping: dict["_NodePickleData", torch.fx.Node], + unpickle_state: _UnpickleState, + ) -> torch.fx.Node: + args = pytree.tree_map_only(_NodePickleData, lambda n: mapping[n], self.args) + kwargs = pytree.tree_map_only( + _NodePickleData, lambda n: mapping[n], self.kwargs + ) + target = self.target.unpickle(unpickle_state) + assert callable(target) or isinstance(target, str) + node = graph.create_node(self.op, target, args, kwargs, self.name, self.type) + node.meta = self.meta + return node + + +class _OpPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, op: object + ) -> tuple[Callable[[_UnpickleState], object], tuple[_UnpickleStateToken]]: + result = cls.pickle(op, pickler.options) + return (result.unpickle, (pickler._unpickle_state,)) + + @classmethod + def pickle(cls, op: object, options: Options) -> "_OpPickleData": + if isinstance(op, str): + return _OpStrPickleData(op) + + name = torch.fx.Node._pretty_print_target(op) + if isinstance(op, torch._ops.OpOverload): + return cls._pickle_op(name, _OpOverloadPickleData, options) + elif isinstance(op, torch._ops.OpOverloadPacket): + return cls._pickle_op(name, _OpOverloadPacketPickleData, options) + elif name.startswith(("builtins.", "math.", "torch.")): + root, detail = name.split(".", 1) + return _OpBuiltinPickleData(root, detail) + elif name.startswith("operator."): + _, detail = name.split(".", 1) + return _OpOperatorPickleData(detail) + else: + # TODO: raise a BypassFxGraphCache so we will just bypass this one... + raise NotImplementedError(f"TARGET: {type(op)} {op} {name}") + + @staticmethod + def _pickle_op( + name: str, + datacls: Union[ + type["_OpOverloadPickleData"], type["_OpOverloadPacketPickleData"] + ], + options: Options, + ) -> "_OpPickleData": + if (ops_filter := options.ops_filter) and not ops_filter(name): + from torch._inductor.codecache import BypassFxGraphCache + + raise BypassFxGraphCache(f"Unable to pickle non-standard op: {name}") + return datacls(name) + + @abstractmethod + def unpickle(self, unpickle_state: _UnpickleState) -> object: + pass + + @classmethod + def _lookup_global_by_name(cls, name: str) -> object: + """ + Like `globals()[name]` but supports dotted names. + """ + if "." in name: + mod, rest = name.split(".", 1) + root = globals()[mod] + return cls._getattr_by_name(root, rest) + else: + return globals()[name] + + @staticmethod + def _getattr_by_name(root: object, name: str) -> object: + """ + Like `getattr(root, name)` but supports dotted names. + """ + while "." in name: + mod, name = name.split(".", 1) + root = getattr(root, mod) + return getattr(root, name) + + +class _OpStrPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> str: + return self.name + + +class _OpOverloadPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverload: + obj = self._lookup_global_by_name(self.name) + assert isinstance(obj, torch._ops.OpOverload) + return obj + + +class _OpOverloadPacketPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> torch._ops.OpOverloadPacket: + obj = self._lookup_global_by_name(self.name) + assert isinstance(obj, torch._ops.OpOverloadPacket) + return obj + + +class _OpBuiltinPickleData(_OpPickleData): + def __init__(self, root: str, name: str) -> None: + self.root = root + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> object: + if self.root == "builtins": + return __builtins__.get(self.name) # type: ignore[attr-defined] + elif self.root == "math": + import math + + return self._getattr_by_name(math, self.name) + elif self.root == "torch": + return self._getattr_by_name(torch, self.name) + else: + raise NotImplementedError + + +class _OpOperatorPickleData(_OpPickleData): + def __init__(self, name: str) -> None: + self.name = name + + def unpickle(self, unpickle_state: _UnpickleState) -> object: + import operator + + return self._getattr_by_name(operator, self.name) + + +class _GraphPickleData: + def __init__(self, graph: torch.fx.Graph, options: Options) -> None: + self.tracer_cls = graph._tracer_cls + self.tracer_extras = graph._tracer_extras + + nodes: dict[torch.fx.Node, _NodePickleData] = {} + for node in graph.nodes: + nodes[node] = _NodePickleData(node, nodes, options) + self.nodes = tuple(nodes.values()) + + # Unpickled variables: + # self._used_names = graph._used_names + # -- self._insert = self._root.prepend + # self._len = graph._len + # self._graph_namespace = graph._graph_namespace + # self._owning_module = graph._owning_module + # self._codegen = graph._codegen + # self._co_fields: Dict[str, Any] = graph._co_fields + # -- self._find_nodes_lookup_table = _FindNodesLookupTable() + + def unpickle( + self, gm: torch.fx.GraphModule, unpickle_state: _UnpickleState + ) -> torch.fx.Graph: + graph = torch.fx.Graph(gm, self.tracer_cls, self.tracer_extras) + + nodes: dict[_NodePickleData, torch.fx.Node] = {} + for nd in self.nodes: + nodes[nd] = nd.unpickle(graph, nodes, unpickle_state) + + return graph + + +class _TracingContextPickleData: + @classmethod + def reduce_helper( + cls, pickler: GraphPickler, obj: torch._guards.TracingContext + ) -> tuple[ + Callable[[Self, _UnpickleState], torch._guards.TracingContext], + tuple[Self, _UnpickleStateToken], + ]: + return ( + cls.unpickle, + ( + cls(obj), + pickler._unpickle_state, + ), + ) + + def __init__(self, context: TracingContext) -> None: + # TODO: Do we really need all of this? + self.module_context = context.module_context + self.frame_summary_stack = context.frame_summary_stack + self.loc_in_frame = context.loc_in_frame + self.aot_graph_name = context.aot_graph_name + self.params_flat = context.params_flat + self.params_flat_unwrap_subclasses = context.params_flat_unwrap_subclasses + self.params_unwrapped_to_flat_index = context.params_unwrapped_to_flat_index + self.output_strides = context.output_strides + self.force_unspec_int_unbacked_size_like = ( + context.force_unspec_int_unbacked_size_like + ) + # Not saved (because it's difficult and maybe not needed?): + # self.fw_metadata = context.fw_metadata + # self.guards_context = None + # self.global_context = None + # self.fake_mode = None + # self.fakify_first_call = None + # self.hop_dispatch_set_cache = None + # self.tensor_to_context = context.tensor_to_context + + def unpickle(self, unpickle_state: _UnpickleState) -> TracingContext: + context = TracingContext(unpickle_state.fake_mode) + context.module_context = self.module_context + context.frame_summary_stack = self.frame_summary_stack + context.loc_in_frame = self.loc_in_frame + context.aot_graph_name = self.aot_graph_name + context.params_flat = self.params_flat + context.params_flat_unwrap_subclasses = self.params_flat_unwrap_subclasses + context.params_unwrapped_to_flat_index = self.params_unwrapped_to_flat_index + context.output_strides = self.output_strides + context.force_unspec_int_unbacked_size_like = ( + self.force_unspec_int_unbacked_size_like + ) + return context diff --git a/phivenv/Lib/site-packages/torch/fx/_lazy_graph_module.py b/phivenv/Lib/site-packages/torch/fx/_lazy_graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..7c97b34c3a2f8cc3d29f82a9dfe492258055898a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/_lazy_graph_module.py @@ -0,0 +1,185 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +from torch.fx.graph_module import ( + _format_import_block, + GraphModule, + reduce_graph_module, + reduce_package_graph_module, +) +from torch.package import PackageExporter, sys_importer + +from ._compatibility import compatibility + + +_use_lazy_graph_module_flag = False +_force_skip_lazy_graph_module_flag = False + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _force_skip_lazy_graph_module(): + """ + Skip using lazy graph module disregarding the setting of _use_lazy_graph_module. + Use to skip _LazyGraphModule when testing inductor torchscript related backend. + + torch.jit.script a _LazyGraphModule results in following error: + https://gist.github.com/shunting314/5143654c8084aed84ecd19b818258a69 + """ + try: + global _force_skip_lazy_graph_module_flag + prior = _force_skip_lazy_graph_module_flag + _force_skip_lazy_graph_module_flag = True + yield + finally: + _force_skip_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +@contextmanager +def _use_lazy_graph_module(should_use: bool): + try: + global _use_lazy_graph_module_flag + prior = _use_lazy_graph_module_flag + _use_lazy_graph_module_flag = ( + should_use and not _force_skip_lazy_graph_module_flag + ) + yield + finally: + _use_lazy_graph_module_flag = prior + + +@compatibility(is_backward_compatible=False) +def _get_graph_module_cls(): + return _LazyGraphModule if _use_lazy_graph_module_flag else GraphModule + + +def _make_graph_module(*args, graph_module_cls=None, **kwargs): + if graph_module_cls is None: + graph_module_cls = _get_graph_module_cls() + + return graph_module_cls(*args, **kwargs) + + +@compatibility(is_backward_compatible=False) +class _LazyGraphModule(GraphModule): + """ + The main difference between _LazyGraphModule and GraphModule is how recompile happens. + GraphModule will do a 'recompile' call to generate python code and the forward method when it's + constructed. Later on if the graph get updated, recompile method can be called again to refresh + the saved python code and forward method. + + However in some cases especially in inductor, the recompilation can be a waste since we never + check the python code for the graph module or call its forward method. A few more concreate + examples regarding pattern matching fx passes in inductor: + 1. some passes will update the graph to be compiled and then call recompile on the GraphModule. + 2. some passes will trace small pattern function to search it in the graph being compiled and + replace the match with the traced graph of a replacement function. The pattern graph and + replacement graph are quite small but there are large amount of them. Doing GraphModule.recompile + for them in GraphModule.__init__ is also a waste of time. + + However simply skip calling GraphModule.recompile in these scenarios is also dangeruous. + People may want to check the python code or call the GraphModule's forward method for debugging purposes. + + The way _LazyGraphModule solves it is, we override the recompile method to just mark the + need for recompilation but does not do the actual recompilation. Later on if people really + access the compiled python code or call the GraphModule's forward method, we do the real + recompilation. + """ + + @classmethod + def from_graphmodule(cls, gm: GraphModule): + if isinstance(gm, _LazyGraphModule): + return gm + else: + return _LazyGraphModule(gm, gm.graph) + + @staticmethod + def force_recompile(gm): + """ + Sometimes we need force a recompile as a workaround + - we want to do the real recompilation before symbolic_trace to avoid error: + https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + """ + if isinstance(gm, _LazyGraphModule): + gm.real_recompile() + + def real_recompile(self): + if self._needs_recompile(): + self._real_recompile() + + @classmethod + def _needs_recompile(cls): + return cls.forward is cls._lazy_forward + + def _lazy_forward(self, *args, **kwargs): + # Call self.real_recompile() rather than self._real_recompile() here. + # The _lazy_forward method may be saved and call repeatedly. + # Calling self.real_recompile can make sure we skip recompilation if + # we have already done so. + self.real_recompile() + assert not self._needs_recompile() + + # call `__call__` rather than 'forward' since recompilation may + # install a wrapper for `__call__` to provide a customized error + # message. + return self(*args, **kwargs) + + forward = _lazy_forward + + # TODO: we shold handle __reduce_deploy__ the same way as __reduce_package__, + # or __reduce__ by calling _real_recompile. But I don't find a good way + # to test __reduce_deploy__ out. Also it's very unlikely that LazyGraphModule + # will be used in torch::deploy. So it's skipped for now. + + def __reduce_package__(self, exporter: PackageExporter): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Follow GraphModule.__reduce__ but call 'self._real_recompile' rather + than 'self.recompile' since for a _LazyGraphModule, self.recompile just + mark the need of recompilation and does not return the PythonCode object. + """ + python_code = self._real_recompile() + dict_without_graph = self.__dict__.copy() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _real_recompile(self): + return super().recompile() + + @classmethod + def recompile(cls): + cls.forward = cls._lazy_forward + + @property + def code(self) -> str: + self.real_recompile() + return super().code + + def __str__(self) -> str: + """ + str(GraphModule) will access the _code attribute. Make sure recompile + happens so _code attribute is available. + """ + self.real_recompile() + return super().__str__() diff --git a/phivenv/Lib/site-packages/torch/fx/_pytree.py b/phivenv/Lib/site-packages/torch/fx/_pytree.py new file mode 100644 index 0000000000000000000000000000000000000000..84e35127330791566a4bf6319f9b257d48ef7454 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/_pytree.py @@ -0,0 +1,113 @@ +from collections import namedtuple +from typing import Any, Callable, Optional, TypeVar +from typing_extensions import NamedTuple + +import torch.return_types +from torch.utils._pytree import PyTree, tree_flatten, TreeSpec + + +FlattenFuncSpec = Callable[[PyTree, TreeSpec], list] +FlattenFuncExactMatchSpec = Callable[[PyTree, TreeSpec], bool] + +SUPPORTED_NODES: dict[type[Any], FlattenFuncSpec] = {} +SUPPORTED_NODES_EXACT_MATCH: dict[type[Any], Optional[FlattenFuncExactMatchSpec]] = {} + +_T = TypeVar("_T") +_K = TypeVar("_K") +_V = TypeVar("_V") + + +def register_pytree_flatten_spec( + cls: type[Any], + flatten_fn_spec: FlattenFuncSpec, + flatten_fn_exact_match_spec: Optional[FlattenFuncExactMatchSpec] = None, +) -> None: + SUPPORTED_NODES[cls] = flatten_fn_spec + SUPPORTED_NODES_EXACT_MATCH[cls] = flatten_fn_exact_match_spec + + +def _deregister_pytree_flatten_spec( + cls: type[Any], +) -> None: + del SUPPORTED_NODES[cls] + del SUPPORTED_NODES_EXACT_MATCH[cls] + + +def tree_flatten_spec( + pytree: PyTree, + spec: TreeSpec, +) -> list[Any]: + if spec.is_leaf(): + return [pytree] + # I guess these exist for BC, FC reasons. + # In general, we should be able to directly + # use pytree tree flattener to flatten them, + # as export serializes the pytree seperately. + # Will remove it in follow up PR. + if spec.type in SUPPORTED_NODES: + flatten_fn_spec = SUPPORTED_NODES[spec.type] + child_pytrees = flatten_fn_spec(pytree, spec) + result = [] + for child, child_spec in zip(child_pytrees, spec.children_specs): + flat = tree_flatten_spec(child, child_spec) + result += flat + return result + flat_result, real_spec = tree_flatten(pytree) + if spec != real_spec: + raise RuntimeError( + f"Real spec {real_spec} of object {pytree} is different from expected spec {spec}. " + f"Please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml" + ) + return flat_result + + +def _dict_flatten_spec(d: dict[_K, _V], spec: TreeSpec) -> list[_V]: + return [d[k] for k in spec.context] + + +def _list_flatten_spec(d: list[_T], spec: TreeSpec) -> list[_T]: + return [d[i] for i in range(spec.num_children)] + + +def _tuple_flatten_spec(d: tuple[_T, ...], spec: TreeSpec) -> list[_T]: + return [d[i] for i in range(spec.num_children)] + + +def _namedtuple_flatten_spec(d: NamedTuple, spec: TreeSpec) -> list[Any]: + return [d[i] for i in range(spec.num_children)] + + +def _dict_flatten_spec_exact_match(d: dict[_K, _V], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _list_flatten_spec_exact_match(d: list[_T], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _tuple_flatten_spec_exact_match(d: tuple[_T, ...], spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +def _namedtuple_flatten_spec_exact_match(d: NamedTuple, spec: TreeSpec) -> bool: + return len(d) == spec.num_children + + +register_pytree_flatten_spec(dict, _dict_flatten_spec, _dict_flatten_spec_exact_match) +register_pytree_flatten_spec(list, _list_flatten_spec, _list_flatten_spec_exact_match) +register_pytree_flatten_spec( + tuple, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, +) +for return_type in torch.return_types.all_return_types: + register_pytree_flatten_spec( + return_type, + _tuple_flatten_spec, + _tuple_flatten_spec_exact_match, + ) +register_pytree_flatten_spec( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten_spec, + _namedtuple_flatten_spec_exact_match, +) diff --git a/phivenv/Lib/site-packages/torch/fx/_symbolic_trace.py b/phivenv/Lib/site-packages/torch/fx/_symbolic_trace.py new file mode 100644 index 0000000000000000000000000000000000000000..0bc438e496c3eb661d64c69d10323312e572a274 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/_symbolic_trace.py @@ -0,0 +1,1335 @@ +# mypy: allow-untyped-defs +import builtins +import collections +import contextlib +import copy +import functools +import inspect +import math +import os +import warnings +from itertools import chain +from types import CodeType, FunctionType, ModuleType +from typing import Any, Callable, get_args, NamedTuple, Optional, Union +from typing_extensions import TypeAlias + +import torch +import torch.utils._pytree as pytree +from torch._C import ScriptObject # type: ignore[attr-defined] +from torch._library.fake_class_registry import FakeScriptObject + +from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module +from .graph import _PyTreeCodeGen, _PyTreeInfo, Graph +from .graph_module import GraphModule +from .node import Argument, base_types, map_aggregate +from .proxy import ParameterProxy, Proxy, Scope, ScopeContextManager, TracerBase + + +HAS_VARSTUFF = inspect.CO_VARARGS | inspect.CO_VARKEYWORDS + +# These need to run in global scope to handle nested calls correctly +_orig_module_call: Callable = torch.nn.Module.__call__ +_orig_module_getattr: Callable = torch.nn.Module.__getattr__ + +_proxyable_classes: dict[type, None] = {} + +_is_fx_tracing_flag = False + +_ConstantAttributeType: TypeAlias = Union[ + torch.Tensor, torch.ScriptObject, FakeScriptObject, pytree.TreeSpec +] + +_constant_attribute_types = get_args(_ConstantAttributeType) + + +def is_fx_tracing(): + return _is_fx_tracing_flag + + +@compatibility(is_backward_compatible=True) +class ProxyableClassMeta(type): + """ + ProxyableClassMeta allows you to make construction of a given Python class + symbolically traceable. For example:: + + import torch + import torch.fx + + + class TensorPair(metaclass=torch.fx.ProxyableClassMeta): + def __init__(self, left, right): + self.left, self.right = left, right + + def add(self, other): + l = self.left + other.left + r = self.right + other.right + return TensorPair(l, r) + + def mul(self, other): + l = self.left * other.left + r = self.right * other.right + return TensorPair(l, r) + + + def use_tensor_pair_ctor(x: TensorPair, y: torch.Tensor): + s = x.add(TensorPair(y, y)) + return s.mul(x) + + + x = TensorPair(torch.randn(5, 3), torch.randn(5, 3)) + y = torch.randn(5, 3) + ref_out = use_tensor_pair_ctor(x, y) + + traced = torch.fx.symbolic_trace(use_tensor_pair_ctor) + print(traced.code) + ''' + def forward(self, x : __main___TensorPair, y : torch.Tensor): + tensor_pair = __main___TensorPair(y, y); y = None + add = x.add(tensor_pair); tensor_pair = None + mul = add.mul(x); add = x = None + return mul + ''' + + From this example, we can see that construction of a class (``TensorPair``) + defined with ``ProxyableClassMeta`` as metaclass can be recorded in symbolic + tracing. + """ + + def __init__(cls, name, bases, attrs): + _proxyable_classes.setdefault(cls) + super().__init__(name, bases, attrs) + + def __call__(cls, *args, **kwargs): + instance = cls.__new__(cls) # type: ignore[call-overload] + + if not is_fx_tracing(): + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + found_proxies = [] + + def check_proxy(a): + if isinstance(a, Proxy): + found_proxies.append(a) + + map_aggregate(args, check_proxy) + map_aggregate(kwargs, check_proxy) + + if len(found_proxies) != 0: + tracer = found_proxies[0].tracer + return tracer.create_proxy("call_function", cls, args, kwargs) + else: + cls.__init__(instance, *args, **kwargs) # type: ignore[misc] + return instance + + +def _patch_function(fn: FunctionType, nargs: int) -> FunctionType: + co = fn.__code__ + co_flags = co.co_flags & ~HAS_VARSTUFF + co_args: tuple + if hasattr(co, "co_qualname"): + # Python-3.11+ code signature + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_qualname, # type: ignore[attr-defined] + co.co_firstlineno, + co.co_lnotab, + co.co_exceptiontable, # type: ignore[attr-defined] + co.co_freevars, + co.co_cellvars, + ) + elif hasattr(co, "co_posonlyargcount"): + co_args = ( + nargs, + 0, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + else: + co_args = ( + nargs, + 0, + co.co_nlocals, + co.co_stacksize, + co_flags, + co.co_code, + co.co_consts, + co.co_names, + co.co_varnames, + co.co_filename, + co.co_name, + co.co_firstlineno, + co.co_lnotab, + co.co_freevars, + co.co_cellvars, + ) + new_code = CodeType(*co_args) # type: ignore[arg-type] + return FunctionType( + new_code, fn.__globals__, fn.__name__, fn.__defaults__, fn.__closure__ + ) + + # we need to insert placeholder nodes for *args and **kwargs + # we can't call this function normally, otherwise it would try to unpack them + # instead, let's make python think that args and kwargs are normal variables + + +@compatibility(is_backward_compatible=False) +class PHBase: + """ + Object representing an input placeholder to `concrete_args` + """ + + def __repr__(self): + return "PH" + + +PH = PHBase() + + +@compatibility(is_backward_compatible=False) +class PHWithMeta(PHBase): + """ + Object representing an input placeholder to `concrete_args` + """ + + def __init__(self, ph_key: Optional[str] = None): + super().__init__() + + # Provide a hey for user to identify placeholder node during analysis + self.ph_key = ph_key + + +def _transfer_attrs(fr, to): + for attr_name in dir(fr): + attr_val = getattr(fr, attr_name) + if ( + not callable(attr_val) + and not attr_name.startswith("__") + and not hasattr(to, attr_name) + ): + setattr(to, attr_name, attr_val) + + +@compatibility(is_backward_compatible=True) +class Tracer(TracerBase): + # Reference: https://github.com/pytorch/pytorch/issues/54354 + # The first line of this docstring overrides the one Sphinx generates for the + # documentation. We need it so that Sphinx doesn't leak `math`s path from the + # build environment (e.g. ` None: + # This method's signature is overridden by the first line of this class' + # docstring. If this method's signature is modified, the signature that + # overrides it also should be modified accordingly. + + """ + Construct a Tracer object. + + Args: + + autowrap_modules (Tuple[ModuleType]): defaults to `(math, )`, + Python modules whose functions should be wrapped automatically + without needing to use fx.wrap(). Backward-compatibility for + this parameter is guaranteed. + + autowrap_functions (Tuple[Callable, ...]): defaults to `()`, + Python functions that should be wrapped automatically without + needing to use fx.wrap(). Backward compatibility for this + parameter is guaranteed. + + param_shapes_constant (bool): When this flag is set, calls to shape, + size and a few other shape like attributes of a module's parameter + will be evaluated directly, rather than returning a new Proxy value + for an attribute access. Backward compatibility for this parameter + is guaranteed. + """ + + super().__init__() + + # Functions we will eagerly wrap when we see them while tracing + # this captures both `math.sqrt()` and `from math import sqrt` automatically + self._autowrap_function_ids: set[int] = { + id(value) + for name, value in chain.from_iterable( + m.__dict__.items() for m in autowrap_modules + ) + if not name.startswith("_") and callable(value) + } + self._autowrap_function_ids.update({id(f) for f in autowrap_functions}) + + # Python modules to apply autowrap to at the start, in addition to + # modules we see while tracing + self._autowrap_search: list[ModuleType] = list(autowrap_modules) + self.param_shapes_constant = param_shapes_constant + + self.submodule_paths: Optional[dict[torch.nn.Module, str]] = None + self.root_module_name: str = "" + # Maps the containing module's name to the operator name + self.scope = Scope("", None) + # Records the module call stack + self.module_stack = collections.OrderedDict() + self.num_calls: dict[str, int] = {} + # Mapping of node name to module scope + self.node_name_to_scope: dict[str, tuple[str, type]] = {} + + _qualname_counter: dict[str, int] = collections.defaultdict(int) + + @compatibility(is_backward_compatible=True) + def get_fresh_qualname(self, prefix: str) -> str: + """ + Gets a fresh name for a prefix and returns it. This function ensures + that it will not clash with an existing attribute on the graph. + """ + # The idea here is that if the module doesn't have this prefix at all we + # should reset the counter to start from the beginning + # It's a ... little bit hacky (doesn't cover all cases) but the precise + # naming of the prefixes isn't a correctness issue, just a niceness + # issue + qualname = f"{prefix}0" + if not hasattr(self.root, qualname): + self._qualname_counter[prefix] = 0 + return qualname + + i = self._qualname_counter[prefix] + while True: + qualname = f"{prefix}{i}" + i += 1 + if not hasattr(self.root, qualname): + break + self._qualname_counter[prefix] = i + + return qualname + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> "Argument": + """ + A method to specify the behavior of tracing when preparing values to + be used as arguments to nodes in the ``Graph``. + + By default, the behavior includes: + + #. Iterate through collection types (e.g. tuple, list, dict) and recursively + call ``create_args`` on the elements. + #. Given a Proxy object, return a reference to the underlying IR ``Node`` + #. Given a non-Proxy Tensor object, emit IR for various cases: + + * For a Parameter, emit a ``get_attr`` node referring to that Parameter + * For a non-Parameter Tensor, store the Tensor away in a special + attribute referring to that attribute. + + This method can be overridden to support more types. + + Args: + + a (Any): The value to be emitted as an ``Argument`` in the ``Graph``. + + + Returns: + + The value ``a`` converted into the appropriate ``Argument`` + """ + # The base tracer is used to construct Graphs when there is no associated + # module hierarchy, so it can never create parameter references. + # The default tracer adds the ability to refer to parameters when + # tracing modules. + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + raise NameError("parameter is not a member of this module") + elif isinstance(a, torch.Tensor): + for n_, p_ in self.root.named_buffers(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + elif isinstance(a, torch.nn.Module): + for n_, p_ in self.root.named_modules(): + if a is p_: + return self.create_node("get_attr", n_, (), {}) + # For NamedTuple instances that appear literally as args, we emit + # a node to construct the NamedTuple and use that Node as the argument. + if isinstance(a, tuple) and hasattr(a, "_fields"): + args = tuple(self.create_arg(elem) for elem in a) + return self.create_node("call_function", a.__class__, args, {}) + + # Tensors do not have a reliable string repr() from which they can be + # constructed (and we probably don't want to rely on that, either), so + # for any constant Tensor values we encounter, first search for if they + # are an attribute of some module in the module hierarchy. If so, emit + # a get_attr to retrieve that tensor. Otherwise, we'll store away the + # tensor value into a special attribute on the Module s.t. we can + # retrieve it with a get_attr. + if isinstance(a, _constant_attribute_types): + qualname: Optional[str] = self.tensor_attrs.get(a) + + # Tensor was not found in the Module hierarchy, stow it away in a + # special attribute and set the qualname to refer to that + if not qualname: + if isinstance(a, torch.Tensor): + base_name = "_tensor_constant" + elif isinstance(a, (FakeScriptObject, ScriptObject)): + base_name = "_torchbind_obj" + elif isinstance(a, pytree.TreeSpec): + base_name = "_tree_spec_constant" + else: + raise RuntimeError( + f"cannot create constant arg for {a} of type {type(a)}." + ) + qualname = self.get_fresh_qualname(base_name) + assert isinstance(qualname, str) + self.tensor_attrs[a] = qualname + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + if type(a) in _proxyable_classes: + # This is an instance of a proxyable class for which we did not + # witness its construction. Intern this as a constant attribute + + # TODO: binary search + qualname = self.get_fresh_qualname(f"_{a.__class__.__name__}_constant_") + assert isinstance(qualname, str) + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + + return super().create_arg(a) + + @compatibility(is_backward_compatible=True) + def is_leaf_module(self, m: torch.nn.Module, module_qualified_name: str) -> bool: + """ + A method to specify whether a given ``nn.Module`` is a "leaf" module. + + Leaf modules are the atomic units that appear in + the IR, referenced by ``call_module`` calls. By default, + Modules in the PyTorch standard library namespace (torch.nn) + are leaf modules. All other modules are traced through and + their constituent ops are recorded, unless specified otherwise + via this parameter. + + Args: + + m (Module): The module being queried about + module_qualified_name (str): The path to root of this module. For example, + if you have a module hierarchy where submodule ``foo`` contains + submodule ``bar``, which contains submodule ``baz``, that module will + appear with the qualified name ``foo.bar.baz`` here. + """ + return ( + m.__module__.startswith("torch.nn") + or m.__module__.startswith("torch.ao.nn") + ) and not isinstance(m, torch.nn.Sequential) + + @compatibility(is_backward_compatible=True) + def path_of_module(self, mod: torch.nn.Module) -> str: + """ + Helper method to find the qualified name of ``mod`` in the Module hierarchy + of ``root``. For example, if ``root`` has a submodule named ``foo``, which has + a submodule named ``bar``, passing ``bar`` into this function will return + the string "foo.bar". + + Args: + + mod (str): The ``Module`` to retrieve the qualified name for. + """ + # Prefer the O(1) algorithm + if self.submodule_paths: + path = self.submodule_paths.get(mod) + if path is None: + raise NameError("module is not installed as a submodule") + assert isinstance(path, str) + return path + # O(N^2) fallback in the case that we didn't store the submodule + # paths. + else: + for n, p in self.root.named_modules(): + if mod is p: + return n + raise NameError("module is not installed as a submodule") + + @compatibility(is_backward_compatible=True) + def call_module( + self, + m: torch.nn.Module, + forward: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + """ + Method that specifies the behavior of this ``Tracer`` when it encounters + a call to an ``nn.Module`` instance. + + By default, the behavior is to check if the called module is a leaf module + via ``is_leaf_module``. If it is, emit a ``call_module`` node referring to + ``m`` in the ``Graph``. Otherwise, call the ``Module`` normally, tracing through + the operations in its ``forward`` function. + + This method can be overridden to--for example--create nested traced + GraphModules, or any other behavior you would want while tracing across + ``Module`` boundaries. + + Args: + + m (Module): The module for which a call is being emitted + forward (Callable): The forward() method of the ``Module`` to be invoked + args (Tuple): args of the module callsite + kwargs (Dict): kwargs of the module callsite + + Return: + + The return value from the Module call. In the case that a ``call_module`` + node was emitted, this is a ``Proxy`` value. Otherwise, it is whatever + value was returned from the ``Module`` invocation. + """ + module_qualified_name = self.path_of_module(m) + with ScopeContextManager( + self.scope, Scope(module_qualified_name, type(m)) + ) as _scope: + # module_stack is an ordered dict so writing then deleting the + # entry is equivalent to push/pop on a list + num_calls = self.num_calls.get(module_qualified_name, 0) + module_key = ( + f"{_scope.module_path}@{num_calls}" + if num_calls > 0 + else _scope.module_path + ) + self.module_stack[module_key] = (module_qualified_name, _scope.module_type) + self.num_calls[module_qualified_name] = num_calls + 1 + if not self.is_leaf_module(m, module_qualified_name): + ret_val = forward(*args, **kwargs) + else: + ret_val = self.create_proxy( + "call_module", module_qualified_name, args, kwargs + ) + key, _ = self.module_stack.popitem(last=True) + assert key == module_key, f" Unexpected key {key}" + + return ret_val + + @compatibility(is_backward_compatible=False) + def getattr(self, attr: str, attr_val: Any, parameter_proxy_cache: dict[str, Any]): + """ + Method that specifies the behavior of this ``Tracer`` when we call getattr + on a call to an ``nn.Module`` instance. + + By default, the behavior is to return a proxy value for the attribute. It + also stores the proxy value in the ``parameter_proxy_cache``, so that future + calls will reuse the proxy rather than creating a new one. + + This method can be overridden to --for example-- not return proxies when + querying parameters. + + Args: + + attr (str): The name of the attribute being queried + attr_val (Any): The value of the attribute + parameter_proxy_cache (Dict[str, Any]): A cache of attr names to proxies + + Return: + + The return value from the getattr call. + """ + + def maybe_get_proxy_for_attr( + attr_val, collection_to_search, parameter_proxy_cache + ): + for n, p in collection_to_search: + if attr_val is p: + if n not in parameter_proxy_cache: + kwargs = {} + if ( + "proxy_factory_fn" + in inspect.signature(self.create_proxy).parameters + ): + kwargs["proxy_factory_fn"] = ( + None + if not self.param_shapes_constant + else lambda node: ParameterProxy( + self, node, n, attr_val + ) + ) + val_proxy = self.create_proxy("get_attr", n, (), {}, **kwargs) # type: ignore[arg-type] + parameter_proxy_cache[n] = val_proxy + return parameter_proxy_cache[n] + return None + + if isinstance(attr_val, torch.nn.Parameter): + maybe_parameter_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_parameters(), parameter_proxy_cache + ) + if maybe_parameter_proxy is not None: + return maybe_parameter_proxy + + if self.proxy_buffer_attributes and isinstance(attr_val, torch.Tensor): + maybe_buffer_proxy = maybe_get_proxy_for_attr( + attr_val, self.root.named_buffers(), parameter_proxy_cache + ) + if maybe_buffer_proxy is not None: + return maybe_buffer_proxy + + return attr_val + + # This method will be refactored + @compatibility(is_backward_compatible=False) + def create_args_for_root(self, root_fn, is_module, concrete_args=None): + """ + Create ``placeholder`` nodes corresponding to the signature of the ``root`` + Module. This method introspects root's signature and emits those + nodes accordingly, also supporting ``*args`` and ``**kwargs``. + """ + # In some cases, a function or method has been decorated with a wrapper + # defined via ``functools.wraps``. In this case, the outer code object + # will likely not contain the actual parameters we care about, so unwrap + # the function to get to the innermost callable. + fn_for_analysis = inspect.unwrap(root_fn) + co = fn_for_analysis.__code__ + total_args = co.co_argcount + co.co_kwonlyargcount + orig_args = list(co.co_varnames) + names_iter = iter(co.co_varnames) + args: list[Any] = [] + skip_arg_idx = 0 + if is_module: + if total_args == 0: + raise RuntimeError( + "``self`` argument cannot be part of *args expansion!" + ) + skip_arg_idx = 1 + next(names_iter) # skip self + args.append(self.root) + + sig = inspect.signature(fn_for_analysis) + + # This covers the very specific case where we are passing in flat + # concrete_args as a tuple, but our traced fn takes (*args, **kwargs). + # In this case, just take the concrete_args and pass them through. + name_idx = 0 + if ( + isinstance(concrete_args, tuple) + and len(concrete_args) > 0 + and (co.co_flags & HAS_VARSTUFF) + and total_args == 1 + ): + for concrete_arg in concrete_args: + out = self.create_proxy("placeholder", f"input_{name_idx}", (), {}) + if isinstance(concrete_arg, PHBase): + if concrete_arg != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=concrete_arg, to=out.node) + args.append(out) + name_idx += 1 + return root_fn, args + + arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)] + if isinstance(concrete_args, tuple): + if len(arg_names) != len(concrete_args): + raise RuntimeError( + f"Tracing expected {len(arg_names)} arguments but got {len(concrete_args)} concrete arguments" + ) + concrete_args = dict(zip(arg_names, concrete_args)) + + def proxy_placeholder(name): + return self._proxy_placeholder(name, concrete_args, sig, fn_for_analysis) + + args.extend(proxy_placeholder(names) for names in arg_names) + + if co.co_kwonlyargcount > 0 or co.co_flags & HAS_VARSTUFF: + # TODO: type annotations for *args and **kwargs + if co.co_flags & inspect.CO_VARARGS: + args.append(proxy_placeholder("*" + next(names_iter))) + if co.co_flags & inspect.CO_VARKEYWORDS: + args.append(proxy_placeholder("**" + next(names_iter))) + root_fn = _patch_function(root_fn, len(args)) + + flat_args, in_spec = pytree.tree_flatten(tuple(args)) + if not all(child.is_leaf() for child in in_spec.children_specs): + # In the case that we have pytree-flattened inputs in + # `concrete_args`, generate a flattening wrapper around the + # original root function and return that. + self.graph._codegen = _PyTreeCodeGen( + _PyTreeInfo(orig_args[:total_args], in_spec, None) + ) + + def flatten_fn(*args): + tree_args = pytree.tree_unflatten(list(args), in_spec) + tree_out = root_fn(*tree_args) + out_args, out_spec = pytree.tree_flatten(tree_out) + assert isinstance(self.graph._codegen, _PyTreeCodeGen) + self.graph._codegen.pytree_info = ( + self.graph._codegen.pytree_info._replace(out_spec=out_spec) + ) + return out_args + + return flatten_fn, flat_args + return root_fn, args + + @compatibility(is_backward_compatible=True) + def trace( + self, + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[dict[str, Any]] = None, + ) -> Graph: + """ + Trace ``root`` and return the corresponding FX ``Graph`` representation. ``root`` + can either be an ``nn.Module`` instance or a Python callable. + + Note that after this call, ``self.root`` may be different from the ``root`` passed + in here. For example, when a free function is passed to ``trace()``, we will + create an ``nn.Module`` instance to use as the root and add embedded constants + to. + + + Args: + + root (Union[Module, Callable]): Either a ``Module`` or a function to be + traced through. Backwards-compatibility for this parameter is + guaranteed. + concrete_args (Optional[Dict[str, any]]): Concrete arguments that should + not be treated as Proxies. This parameter is experimental and + its backwards-compatibility is *NOT* guaranteed. + + Returns: + + A ``Graph`` representing the semantics of the passed-in ``root``. + """ + global _is_fx_tracing_flag + old_is_fx_tracing_flag = _is_fx_tracing_flag + _is_fx_tracing_flag = True + try: + if isinstance(root, torch.nn.Module): + # do real recompilation for _LazyGraphModule before retracing since the trace + # method can not trace the _lazy_forward method. Got error: + # https://gist.github.com/shunting314/75549c2e82ae07ac1139c94a3583d259 + # without this. + from torch.fx._lazy_graph_module import _LazyGraphModule + + _LazyGraphModule.force_recompile(root) + + self.root = root + + assert hasattr(type(root), self.traced_func_name), ( + f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + ) + + fn = getattr(type(root), self.traced_func_name) + self.root_module_name = root._get_name() + self.submodule_paths = {mod: name for name, mod in root.named_modules()} + else: + self.root = torch.nn.Module() + fn = root + + tracer_cls: Optional[type[Tracer]] = getattr(self, "__class__", None) + self.graph = Graph(tracer_cls=tracer_cls) + if hasattr(fn, "__code__"): + code = fn.__code__ + self.graph._co_fields = { + "co_name": code.co_name, + "co_filename": code.co_filename, + "co_firstlineno": code.co_firstlineno, + } + + # When we encounter a Tensor value that's not a parameter, we look if it + # is some other attribute on the model. Construct a dict mapping Tensor + # values to the qualified name here for efficiency. This is used downstream + # in create_arg + self.tensor_attrs: dict[ + _ConstantAttributeType, + str, + ] = {} + + def collect_tensor_attrs(m: torch.nn.Module, prefix_atoms: list[str]): + for k, v in m.__dict__.items(): + if isinstance(v, _constant_attribute_types): + self.tensor_attrs[v] = ".".join(prefix_atoms + [k]) + for k, v in m.named_children(): + collect_tensor_attrs(v, prefix_atoms + [k]) + + collect_tensor_attrs(self.root, []) + + assert isinstance(fn, FunctionType) + + fn_globals = fn.__globals__ # run before it gets patched + fn, args = self.create_args_for_root( + fn, isinstance(root, torch.nn.Module), concrete_args + ) + + parameter_proxy_cache: dict[ + str, Proxy + ] = {} # Reduce number of get_attr calls + + # Method dispatch on parameters is not recorded unless it's directly used. + # Thus, we need to insert a proxy when __getattr__ requests a parameter. + @functools.wraps(_orig_module_getattr) + def module_getattr_wrapper(mod, attr): + attr_val = _orig_module_getattr(mod, attr) + return self.getattr(attr, attr_val, parameter_proxy_cache) + + @functools.wraps(_orig_module_call) + def module_call_wrapper(mod, *args, **kwargs): + def forward(*args, **kwargs): + return _orig_module_call(mod, *args, **kwargs) + + _autowrap_check( + patcher, # type: ignore[has-type] + getattr(getattr(mod, "forward", mod), "__globals__", {}), + self._autowrap_function_ids, + ) + return self.call_module(mod, forward, args, kwargs) + + with _new_patcher() as patcher: + # allow duplicate patches to support the case of nested calls + patcher.patch_method( + torch.nn.Module, + "__getattr__", + module_getattr_wrapper, + deduplicate=False, + ) + patcher.patch_method( + torch.nn.Module, + "__call__", + module_call_wrapper, + deduplicate=False, + ) + _patch_wrapped_functions(patcher) + _autowrap_check(patcher, fn_globals, self._autowrap_function_ids) + for module in self._autowrap_search: + _autowrap_check( + patcher, module.__dict__, self._autowrap_function_ids + ) + self.create_node( + "output", + "output", + (self.create_arg(fn(*args)),), + {}, + type_expr=fn.__annotations__.get("return", None), + ) + + self.submodule_paths = None + except RuntimeError as e: + if isinstance(e.args[0], str) and "data-dependent" in e.args[0]: + partial_fx_graph = self.graph.python_code( + root_module="self", + verbose=True, + ).src + e.partial_fx_graph = partial_fx_graph # type: ignore[attr-defined] + raise + + raise + finally: + _is_fx_tracing_flag = old_is_fx_tracing_flag + return self.graph + + def __deepcopy__(self, memo): + # _autowrap_search contains modules, which cannot be deepcopied. + new_tracer = Tracer.__new__(Tracer) + + for k, v in self.__dict__.items(): + if k in {"_autowrap_search"}: + new_obj = copy.copy(v) + else: + new_obj = copy.deepcopy(v, memo) + + new_tracer.__dict__[k] = new_obj + + return new_tracer + + def _proxy_placeholder(self, name, concrete_args, sig, fn_for_analysis): + if concrete_args is not None and name in concrete_args: + cnt = 0 + + def replace_ph(x): + nonlocal cnt + cnt += 1 + param = sig.parameters[name] + default: tuple[Any, ...] = ( + () if param.default is inspect.Parameter.empty else (param.default,) + ) + out = self.create_proxy( + "placeholder", f"{name}_{str(cnt)}", default, {} + ) + if isinstance(x, PHBase): + if x != PH: + # Transfer attrs in the case where you're using a placeholder other + # than the singleton PH (PH has no attributes to transfer). + # Proxies were created out of the placeholders. + # Transfer any metadata (put on the placeholders in the form of + # attributes set by the user) from the placeholder to the + # underlying nodes (the proxy is unwrapped by the user, but + # the metadata should hold). + _transfer_attrs(fr=x, to=out.node) + + return out + # Union[int, bool] == bool in Python <= 3.6 + if type(x) == bool or type(x) in base_types and type(x) != torch.Tensor: + torch._assert( + out == x, + f"{name} has been specialized to have value {x} but got another value", + ) + elif x is None: + args = ( + out, + f"{name} has been specialized to have value None but got another value", + ) + self.create_proxy("call_function", _assert_is_none, args, {}) + else: + warnings.warn( + f"Was not able to add assertion to guarantee correct input {name} to " + f"specialized function. It is up to the user to make sure that your inputs match the " + f"inputs you specialized the function with." + ) + + return x + + return pytree.tree_map(replace_ph, concrete_args[name]) + if name[0] == "*": + default: tuple[Any, ...] = () + else: + param = sig.parameters[name] + default = ( # type: ignore[assignment] + () if param.default is inspect.Parameter.empty else (param.default,) + ) + return self.create_proxy( + "placeholder", + name, + default, + {}, + type_expr=fn_for_analysis.__annotations__.get(name, None), + ) + + +# Dictionary of (id(globals dict), function name) => globals_dict to patch for +# the purposes of the wrap() API. +# We key by the globals dict id and function name to ensure we're wrapping a given +# function only once. +_wrapped_fns_to_patch: dict[tuple[int, str], dict] = {} + +# List of methods on classes to wrap (class type, function name) +# this currently only works for Tensor.* methods that aren't traced properly +_wrapped_methods_to_patch: list[tuple[type, str]] = [] + +if os.environ.get("FX_PATCH_GETITEM") == "1": + # This change is needed to trace models like PositionalEmbedding from BERT: + # https://github.com/pytorch/benchmark/blob/master/torchbenchmark/models/BERT_pytorch/bert_pytorch/model/embedding/position.py + # but causes issues in quantization documented here: + # https://github.com/pytorch/pytorch/issues/50710 + # once that is fixed we can make this the default behavior. + _wrapped_methods_to_patch.append((torch.Tensor, "__getitem__")) + + +def _find_proxy(*objects_to_search): + """ + Recursively search a data structure for a Proxy() and return it, + return None if not found. + """ + proxy = None + + def find_proxy(x): + nonlocal proxy + if isinstance(x, Proxy): + proxy = x + + map_aggregate(objects_to_search, find_proxy) + return proxy + + +def _create_wrapped_func(orig_fn): + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Given an closed-over ``orig_function`` to invoke, search the args and kwargs for + a Proxy object. If there is one, emit a ``call_function`` node to preserve the + call to this leaf function directly. Otherwise, just return the results of + this function call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return_proxy = proxy.tracer.create_proxy( + "call_function", orig_fn, args, kwargs + ) + return_proxy.node.meta["is_wrapped"] = True + return return_proxy + return orig_fn(*args, **kwargs) + + return wrapped + + +def _create_wrapped_method(cls, name): + orig_fn = getattr(cls, name) + + @functools.wraps(orig_fn) + def wrapped(*args, **kwargs): + """ + Search the args and kwargs for a Proxy object. If there is one, + emit a ``call_method`` node to preserve the call to this method + directly. Otherwise, just return the results of this function + call, as this function is not being traced. + """ + proxy = _find_proxy(args, kwargs) + if proxy is not None: + return proxy.tracer.create_proxy("call_method", name, args, kwargs) + return orig_fn(*args, **kwargs) + + return wrapped + + +class _PatchedFn(NamedTuple): + frame_dict: Any + fn_name: str + orig_fn: Any + new_fn: Any + + def revert(self): + raise NotImplementedError + + def patch(self): + raise NotImplementedError + + +class _PatchedFnSetItem(_PatchedFn): + def revert(self): + self.frame_dict[self.fn_name] = self.orig_fn + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + +class _PatchedFnDel(_PatchedFn): + def revert(self): + del self.frame_dict[self.fn_name] + + def patch(self): + self.frame_dict[self.fn_name] = self.new_fn + + +class _PatchedFnSetAttr(_PatchedFn): + def revert(self): + setattr(self.frame_dict, self.fn_name, self.orig_fn) + + def patch(self): + setattr(self.frame_dict, self.fn_name, self.new_fn) + + +class _Patcher: + def __init__(self) -> None: + super().__init__() + self.patches_made: list[_PatchedFn] = [] + self.visited: set[int] = set() + + def patch( + self, + frame_dict: dict[str, Any], + name: str, + new_fn: Callable, + deduplicate: bool = True, + ): + """ + Replace frame_dict[name] with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + if name not in frame_dict and hasattr(builtins, name): + self.patches_made.append(_PatchedFnDel(frame_dict, name, None, new_fn)) + self.patches_made[-1].patch() + elif getattr(frame_dict[name], "__fx_already_patched", False): + return # already patched, no need to do it again + else: + self.patches_made.append( + _PatchedFnSetItem(frame_dict, name, frame_dict[name], new_fn) + ) + self.patches_made[-1].patch() + + def patch_method( + self, cls: type, name: str, new_fn: Callable, deduplicate: bool = True + ): + """ + Replace object_or_dict.name with new_fn until we exit the context manager. + """ + new_fn.__fx_already_patched = deduplicate # type: ignore[attr-defined] + orig_fn = getattr(cls, name) + if getattr(orig_fn, "__fx_already_patched", False): + return # already patched, no need to do it again + self.patches_made.append(_PatchedFnSetAttr(cls, name, orig_fn, new_fn)) + self.patches_made[-1].patch() + + def visit_once(self, thing: Any): + """Return True on the first call to with thing, otherwise false""" + idx = id(thing) + if idx in self.visited: + return False + self.visited.add(idx) + return True + + def revert_all_patches(self): + """ + Remove all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.revert() + return self.patches_made + + def reapply_all_patches(self): + """ + Patch all the stored patcheds. It doesn't modify patches_made. + """ + for patch in self.patches_made: + patch.patch() + return self.patches_made + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + """ + Undo all the changes made via self.patch() and self.patch_method() + """ + while self.patches_made: + # unpatch in reverse order to handle duplicates correctly + self.patches_made.pop().revert() + self.visited.clear() + + +CURRENT_PATCHER: Optional[_Patcher] = None + + +@contextlib.contextmanager +def _new_patcher(): + global CURRENT_PATCHER + prior_patcher = CURRENT_PATCHER + try: + CURRENT_PATCHER = _Patcher() + yield CURRENT_PATCHER + finally: + # Clear all the patches made by when using current patcher. + assert CURRENT_PATCHER is not None + CURRENT_PATCHER.revert_all_patches() + CURRENT_PATCHER = prior_patcher + + +@contextlib.contextmanager +def _maybe_revert_all_patches(): + current_patcher = CURRENT_PATCHER + patches_made = None + patches_removed = None + try: + if current_patcher is not None: + patches_removed = current_patcher.revert_all_patches() + yield + finally: + if current_patcher is not None: + patches_made = current_patcher.reapply_all_patches() + assert patches_made == patches_removed, ( + "CURRENT_PATCHER was changed during a revert_all_patches" + ) + + +def _patch_wrapped_functions(patcher: _Patcher): + """ + Go through ``_wrapped_fn_patch_table`` and, for each frame object, wrap + the listed global functions in the `_create_wrapped_func` wrapper. + """ + for (_, name), frame_dict in _wrapped_fns_to_patch.copy().items(): + if name not in frame_dict and hasattr(builtins, name): + orig_fn = getattr(builtins, name) + else: + orig_fn = frame_dict[name] + patcher.patch(frame_dict, name, _create_wrapped_func(orig_fn)) + + for cls, name in _wrapped_methods_to_patch: + patcher.patch_method(cls, name, _create_wrapped_method(cls, name)) + + +def _autowrap_check( + patcher: _Patcher, frame_dict: dict[str, Any], function_ids: set[int] +): + """ + Some methods, like `math.sqrt` are common enough we want to automatically wrap them as we see them. + This method searches a scope for them and patches them if found. + """ + if patcher.visit_once(frame_dict): + for name, value in frame_dict.items(): + if ( + not name.startswith("_") + and callable(value) + and id(value) in function_ids + ): + patcher.patch(frame_dict, name, _create_wrapped_func(value)) + + +@compatibility(is_backward_compatible=True) +def wrap(fn_or_name: Union[str, Callable]): + """ + This function can be called at module-level scope to register fn_or_name as a "leaf function". + A "leaf function" will be preserved as a CallFunction node in the FX trace instead of being + traced through:: + + # foo/bar/baz.py + def my_custom_function(x, y): + return x * x + y * y + + + torch.fx.wrap("my_custom_function") + + + def fn_to_be_traced(x, y): + # When symbolic tracing, the below call to my_custom_function will be inserted into + # the graph rather than tracing it. + return my_custom_function(x, y) + + This function can also equivalently be used as a decorator:: + + # foo/bar/baz.py + @torch.fx.wrap + def my_custom_function(x, y): + return x * x + y * y + + A wrapped function can be thought of a "leaf function", analogous to the concept of + "leaf modules", that is, they are functions that are left as calls in the FX trace + rather than traced through. + + Args: + + fn_or_name (Union[str, Callable]): The function or name of the global function to insert into the + graph when it's called + """ + if not callable(fn_or_name) and not isinstance(fn_or_name, str): + raise RuntimeError( + "Unsupported type for global function! Must be either a callable or " + "string name" + ) + + if callable(fn_or_name): + assert not isinstance(fn_or_name, str) # to make mypy happy + fn_name = fn_or_name.__name__ + else: + assert isinstance(fn_or_name, str), ( + "fn_or_name must be a global function or string name" + ) + fn_name = fn_or_name + + currentframe = inspect.currentframe() + assert currentframe is not None + f = currentframe.f_back + assert f is not None + if f.f_code.co_name != "": + raise NotImplementedError("wrap must be called at the top level of a module") + + # consider implementing Callable version of this via _autowrap_function_ids / _autowrap_search + # semantics would be slightly different, but would add support `from x import wrapped_function` + _wrapped_fns_to_patch[(id(f.f_globals), fn_name)] = f.f_globals + return fn_or_name + + +@compatibility(is_backward_compatible=True) +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + concrete_args: Optional[dict[str, Any]] = None, +) -> GraphModule: + """ + Symbolic tracing API + + Given an ``nn.Module`` or function instance ``root``, this function will return a ``GraphModule`` + constructed by recording operations seen while tracing through ``root``. + + ``concrete_args`` allows you to partially specialize your function, whether it's to remove control flow or data structures. + + For example:: + + def f(a, b): + if b == True: + return a + else: + return a * 2 + + FX can typically not trace through this due to the presence of control + flow. However, we can use `concrete_args` to specialize on the value of + `b` to trace through this:: + + f = fx.symbolic_trace(f, concrete_args={"b": False}) + assert f(3, False) == 6 + + Note that although you can still pass in different values of `b`, they will be ignored. + + We can also use `concrete_args` to eliminate data-structure handling from + our function. This will use pytrees to flatten your input. To avoid + overspecializing, pass in `fx.PH` for values that shouldn't be + specialized. For example:: + + def f(x): + out = 0 + for v in x.values(): + out += v + return out + + + f = fx.symbolic_trace( + f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}} + ) + assert f({"a": 1, "b": 2, "c": 4}) == 7 + + + Args: + root (Union[torch.nn.Module, Callable]): Module or function to be traced and converted + into a Graph representation. + concrete_args (Optional[Dict[str, any]]): Inputs to be partially specialized + + Returns: + GraphModule: a Module created from the recorded operations from ``root``. + """ + tracer = Tracer() + graph = tracer.trace(root, concrete_args) + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + return _make_graph_module(tracer.root, graph, name) + + +@wrap +def _assert_is_none(value, msg): + assert value is None, msg diff --git a/phivenv/Lib/site-packages/torch/fx/_utils.py b/phivenv/Lib/site-packages/torch/fx/_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..eb36e081d1c25e26da25e20f722604be57db9de0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/_utils.py @@ -0,0 +1,67 @@ +# mypy: allow-untyped-defs +import sys +from typing import Optional + +import torch +from torch._logging import LazyString + + +def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs): + """ + Returns a LazyString that formats the graph code. + """ + + def format_name(): + if maybe_id is not None: + return f"{name} {maybe_id}" + else: + return name + + if "print_output" not in kwargs: + kwargs["print_output"] = False + + if "colored" in kwargs: + try: + if not sys.stdout.isatty(): + kwargs["colored"] = False + except AttributeError: + kwargs["colored"] = False + + return LazyString( + lambda: _format_graph_code( + f"===== {format_name()} =====\n", + gm.forward.__code__.co_filename, + gm.print_readable(**kwargs), + ) + ) + + +def _format_graph_code(name, filename, graph_str): + """ + Returns a string that formats the graph code. + """ + return f"TRACED GRAPH\n {name} {filename} {graph_str}\n" + + +def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[dict]: + """ + Returns the nn_module_stack of the first call_function node. + """ + for node in graph.nodes: + if node.op == "call_function" and "nn_module_stack" in node.meta: + return node.meta["nn_module_stack"] + return None + + +def get_node_context(node, num_nodes=2) -> str: + """ + Returns a string of the last num_nodes nodes in the graph. + """ + node_contexts = [] + cur = node + for _ in range(num_nodes): + node_contexts.append(cur.format_node()) + if cur.op == "root": + break + cur = cur.prev + return "\n".join(node_contexts[::-1]) diff --git a/phivenv/Lib/site-packages/torch/fx/annotate.py b/phivenv/Lib/site-packages/torch/fx/annotate.py new file mode 100644 index 0000000000000000000000000000000000000000..67d9b7feb2d81096dea84c3c05db5c7e6bcba202 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/annotate.py @@ -0,0 +1,36 @@ +# mypy: allow-untyped-defs +from torch.fx.proxy import Proxy + +from ._compatibility import compatibility + + +@compatibility(is_backward_compatible=False) +def annotate(val, type): + """ + Annotates a Proxy object with a given type. + + This function annotates a val with a given type if a type of the val is a torch.fx.Proxy object + Args: + val (object): An object to be annotated if its type is torch.fx.Proxy. + type (object): A type to be assigned to a given proxy object as val. + Returns: + The given val. + Raises: + RuntimeError: If a val already has a type in its node. + """ + if isinstance(val, Proxy): + if val.node.type: + raise RuntimeError( + f"Tried to annotate a value that already had a type on it!" + f" Existing type is {val.node.type} " + f"and new type is {type}. " + f"This could happen if you tried to annotate a function parameter " + f"value (in which case you should use the type slot " + f"on the function signature) or you called " + f"annotate on the same value twice" + ) + else: + val.node.type = type + return val + else: + return val diff --git a/phivenv/Lib/site-packages/torch/fx/config.py b/phivenv/Lib/site-packages/torch/fx/config.py new file mode 100644 index 0000000000000000000000000000000000000000..2539e748df4aa8016359bd1b068baa7653fcf686 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/config.py @@ -0,0 +1,6 @@ +# Whether to disable showing progress on compilation passes +# Need to add a new config otherwise wil get a circular import if dynamo config is imported here +disable_progress = True + +# If True this also shows the node names in each pass, for small models this is great but larger models it's quite noisy +verbose_progress = False diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__init__.py b/phivenv/Lib/site-packages/torch/fx/experimental/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8414485b57d2eeb305c5ee67062020585d1af365 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50d201eb3bae768cb28b37360f19a9840e1221eb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_backward_state.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_config.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_config.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df7f0d16674ebfe171db1c9201b4d105125ba2bb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_config.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7538cfe177cd3745498c83a660d29e34957efc99 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_constant_symnode.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd37b225924ce102d5ca627daceb4df46d269cbd Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/_dynamism.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..280802002e7d0dc1f02d6b7e8490d54966327ee3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/accelerator_partitioner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46a42b8930e07799979180994f7dfb0a7164d9c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/const_fold.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/debug.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/debug.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ca3a6fb9001ba86a3dee73cb57c0f304376c8c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/debug.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57e4651f5db3120f9075223250e8dcf9422e684b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/graph_gradual_typechecker.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0aef1dc200977cebe738462030f382639b239393 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/merge_matmul.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43c39ef7d8404162e451be273017f8a1726aadac Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/meta_tracer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a731fec1afd259166d3ce51b11e9f61c435a2d4d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/normalize.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..578e453cd538dc671c0a4f9ad9e5f05369a81a62 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/optimization.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59dba285fc40d676c81eb0e7fd739d923e9ae055 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/partitioner_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..514629327c166e744fc98a55851b00809e19f73a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/proxy_tensor.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/recording.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/recording.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92ca6c5c264e57b58d8120060ddb37c3932cf38e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/recording.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32991ece7301e8179da030f2cafdd9df0aa1cb8c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/refinement_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fcba7f7ffad9bb6e0bce53fb223a3c98ff5be06a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/rewriter.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67f5a932aabf95eaf372349e6ed2b0edae7907f3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/schema_type_annotation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..012b74920186c2a91650d0d0d0d42d3da1306734 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/sym_node.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..811a2acdae8e02358731dcc36b883928ce169377 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/unify_refinements.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/validator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/validator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c3ee06258eb96f4de7bc94cb31cda61e1e784c2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/__pycache__/validator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/_backward_state.py b/phivenv/Lib/site-packages/torch/fx/experimental/_backward_state.py new file mode 100644 index 0000000000000000000000000000000000000000..cdc9705413e9c29714d6e165c4b2ab3b34796124 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/_backward_state.py @@ -0,0 +1,27 @@ +import torch.fx + + +class BackwardState: + """ + BackwardState is used to pass Python hooks from the forwards pass + into the backwards pass in Dynamo+Compiled Autograd. + + It is created by TorchDynamo and has special handling there. + Dynamo will pass an empty BackwardState to the forwards, then populate + members on it (via setattr) only after the forwards graph is finished. + Later on, in CompileAutograd we will inline and add the needed guards + on the BackwardState. + + BackwardState is identified and has special handling in AOTAutograd. + During AOTAutograd: + 1) BackwardState is an input to the forwards graph + 2) It must only be used in the backwards + 3) It will be empty in the forwards + 4) In the forwards we add a wrapper to save it + 5) In the backwards it becomes an input + 6) There can only be one per graph + + BackwardState requires CompiledAutograd. + """ + + proxy: torch.fx.Proxy diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/_config.py b/phivenv/Lib/site-packages/torch/fx/experimental/_config.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf80c41115f4a691d03c1c563e574039a031565 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/_config.py @@ -0,0 +1,106 @@ +import os +import sys +from typing import Optional + + +# [@compile_ignored: debug] Fails hard instead of graph breaking on guard on data dependent errors. +no_data_dependent_graph_break = ( + os.environ.get("TORCHDYNAMO_NO_DATA_DEPENDENT_GRAPH_BREAK", "0") == "1" +) +# [@compile_ignored: debug] Uses z3 for validating the guard optimizations transformations. +translation_validation = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION", "0") == "1" +) +# Timeout (in milliseconds) for z3 finding a solution. +# [@compile_ignored: debug] +translation_validation_timeout = int( + os.environ.get("TORCHDYNAMO_TRANSLATION_VALIDATION_TIMEOUT", "600000") +) +# Disables bisection for translation validation. +# +# Translation validation bisection is enabled by default, if translation validation +# is also enabled. This should help finding guard simplification issues. However, +# since validation uses Z3 for bisecting, it might take a lot of time. +# +# Set this configuration option so as to avoid bisecting. +# [@compile_ignored: debug] +translation_validation_no_bisect = ( + os.environ.get("TORCHDYNAMO_TRANSLATION_NO_BISECT", "0") == "1" +) +# Checks whether replaying ShapeEnv events on a freshly constructed one yields +# the a ShapeEnv with the same state. This should be used only in testing. +check_shape_env_recorded_events = False + +# TODO: Perhaps consider allowing unions for the configs below (so you can hit +# multiple reps at the same time) + +# Give extended debug information if the string representation of a guard +# matches this. For example, set this to "Ne(s0, 10)" and whenever we issue +# this guard, we will generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_guard_added = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED", None +) + +# Give extended debug information when a particular symbol is allocated. For +# example, set this to "u2" and whenever we create this symbol, we will +# generate full Python and C++ backtrace +# [@compile_ignored: debug] +extended_debug_create_symbol = os.environ.get( + "TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL", None +) + +# Give extended debug information (C++ backtrace) for all extended debug +# settings as well as errors. The C++ backtrace is slow and very spammy so we +# don't include it by default even when you're requesting extended debug. +# [@compile_ignored: debug] +extended_debug_cpp = os.environ.get("TORCHDYNAMO_EXTENDED_DEBUG_CPP", "") != "" + +# Give extended debug information (line of code) when a torch function +# is called during export. This is useful for showing progress and detecting +# where export might be stuck. Currently only works for strict=False. +# [@compile_ignored: debug] +extended_debug_current_loc = ( + os.environ.get("TORCHEXPORT_EXTENDED_DEBUG_CURRENT_LOC", "0") == "1" +) + +# [@compile_ignored: debug] Show a warning for every specialization +print_specializations = False + +# wraps (un)equalities with 'Not' class after recording the correct expression +# in the FX graph. This should incorrectly construct the divisible and replacement +# lists, and incorrectly issue guards. +inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY = False + +# [@compile_ignored: debug] Validate that ShapeEnv's version key is updated correctly +validate_shape_env_version_key = False + +# If we produce more than this many guards on a symbol, force the symbol to +# get specialized and bail out if this many guards mention this particular +# symbol. This may be slightly more aggressive than the true number of guards +# issued (as we test if we've hit the limit on-the-fly, whereas we may +# do further simplifications at final guard issuance time that make guards +# irrelevant.) +symbol_guard_limit_before_specialize: Optional[int] = None + +# This flag changes whether we should use the same symbolic variable to represent input sizes that are the same. +use_duck_shape = True + +# Controls the registration of torch.nonzero() on the meta device. +# When True, nonzero returns a tensor with shape (self.numel(), self.dim()) +# assuming all elements are none-zero. +# Default is False to prevent unintended registration. Set to True to enable. +meta_nonzero_assume_all_nonzero = False + +# Applies size-oblivious reasoning to backed symbols. This allocates a [0, inf] range for backed size symbols, +# and relies on size-oblivious semantics to avoid 0/1 specialization guards by marking them size-like. +# Currently an experimental option for export. +backed_size_oblivious = False + +# Skip dtype check in meta registrations. Only used for systems that does its own dtype checking. +skip_dtype_check_in_meta_registrations = False + +from torch.utils._config_module import install_config_module + + +install_config_module(sys.modules[__name__]) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/_constant_symnode.py b/phivenv/Lib/site-packages/torch/fx/experimental/_constant_symnode.py new file mode 100644 index 0000000000000000000000000000000000000000..828f96b3b46dc0b2956c87c68ae48094e19f84a7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/_constant_symnode.py @@ -0,0 +1,69 @@ +from typing import * # noqa: F403 + + +# Python version of c10/core/ConstantSymNodeImpl.cpp +# This needs to exist because the Python version of nested int is not compatible +# with the C++ version of constant symnode. +class ConstantIntNode: + def __init__(self, val: int): + self.val = val + + def is_constant(self) -> bool: + return True + + def maybe_as_int(self) -> int: + return self.val + + def is_int(self) -> bool: + return True + + def is_float(self) -> bool: + return False + + def is_bool(self) -> bool: + return False + + def is_nested_int(self) -> bool: + return False + + def clone(self) -> "ConstantIntNode": + return self + + def _str(self) -> str: + return str(self.val) + + def __str__(self) -> str: + return self._str() + + def __repr__(self) -> str: + return self._str() + + def _graph_repr(self) -> str: + return self._str() + + def mul(self, other: Any) -> Any: + return other.mul(self) + + def eq(self, other: Any) -> Any: + return other.eq(self) + + def ne(self, other: Any) -> Any: + return other.ne(self) + + def gt(self, other: Any) -> Any: + return other.lt(self) + + def lt(self, other: Any) -> Any: + return other.gt(self) + + def le(self, other: Any) -> Any: + return other.ge(self) + + def ge(self, other: Any) -> Any: + return other.le(self) + + def is_symbolic(self) -> bool: + return False + + def constant_int(self) -> int: + return self.val diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/_dynamism.py b/phivenv/Lib/site-packages/torch/fx/experimental/_dynamism.py new file mode 100644 index 0000000000000000000000000000000000000000..971eac52026b977d6842e5fd564662461665153d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/_dynamism.py @@ -0,0 +1,118 @@ +import re +from typing import Any, Callable, Union + +import torch +from torch.utils._pytree import tree_flatten_with_path, tree_map + + +KeyPath = tuple[Any, ...] +NonTensorShapeFn = Callable[[Union[int, float]], tuple[Any, ...]] + +__all__ = [ + "normalize_source_name", + "module_to_nested_dict", + "track_dynamism_across_examples", + "clone_and_convert_to_meta", +] + + +def normalize_source_name(name: str) -> str: + # Match attribute access like .x and replace with ['x'] + return re.sub(r"\.([a-zA-Z_][a-zA-Z0-9_]*)", r"['\1']", name) + + +def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]: + """Recursively converts an nn.Module into a nested dictionary with explicit 'parameters' and 'modules' keys.""" + self_dict: dict[str, Any] = {} + + self_dict["_parameters"] = {} + self_dict["_modules"] = {} + + for attr_name in dir(module): + try: + if not attr_name.startswith("_") and not callable( + getattr(module, attr_name) + ): + attr_value = getattr(module, attr_name) + if ( + not isinstance(attr_value, torch.nn.Module) + and isinstance(attr_value, (int, float, torch.Tensor)) + and type(attr_value) is not bool + ): + self_dict[attr_name] = attr_value + except NotImplementedError: + # Skip attributes that raise NotImplementedError since they won't + # contain any dynamism anyways. + continue + + for name, param in module.named_parameters(recurse=False): + self_dict["_parameters"][name] = param + for name, buffer in module.named_buffers(recurse=False): + self_dict["_parameters"][name] = buffer + + for name, submodule in module.named_children(): + self_dict["_modules"][name] = module_to_nested_dict(submodule) + + return self_dict + + +def track_dynamism_across_examples( + example_inputs: list[Any], +) -> dict[Any, Any]: + """ + This function analyzes a list of example inputs to determine the dynamism of their shapes. + It tracks whether the dimensions of tensors or non-tensor values change across + different examples. The function returns a dictionary where each key represents + a path to a value in the input examples, and the corresponding value is a tuple + indicating which dimensions are dynamic (i.e., change across examples). This + helps in understanding how the structure of data varies across different instances. + """ + tracking: dict[KeyPath, tuple[list[set[Any]], bool]] = {} + + for ex in example_inputs: + if "self" in ex and isinstance(ex["self"], torch.nn.Module): + ex["self"] = module_to_nested_dict(ex["self"]) + leaves_with_paths, _ = tree_flatten_with_path(ex) + for key_path, value in leaves_with_paths: + if not isinstance(value, (int, float, torch.Tensor)): + continue + if isinstance(value, torch.Tensor): + shape: tuple[int | float, ...] = tuple(value.shape) + is_tensor = True + else: + shape = (value,) + is_tensor = False + if key_path not in tracking: + tracking[key_path] = ([set() for _ in range(len(shape))], is_tensor) + else: + dim_sets, flag = tracking[key_path] + if flag != is_tensor: + pass + while len(dim_sets) < len(shape): + dim_sets.append(set()) + for i, dim in enumerate(shape): + tracking[key_path][0][i].add(dim) + + output: dict[Any, Any] = {} + for key_path, (dim_sets, _is_tensor) in tracking.items(): + final_dyn = tuple(len(s) > 1 for s in dim_sets) + key_str = "L" + "".join(f"{str(k)}" for k in key_path) + key = key_path[0].key # type: ignore[attr-defined] + if key not in output: + output[key] = {} + output[key][key_str] = final_dyn + return output + + +def clone_and_convert_to_meta(example_input: Any) -> Any: + """ + This function takes a list of example inputs and for each tensor, clones it and converts it to device=meta. + For non-tensor values, it keeps the reference. It uses pytree to handle nested structures recursively. + """ + + def transform_fn(value: Any) -> Any: + if isinstance(value, torch.Tensor): + return value.clone().to(device="meta") + return value + + return tree_map(transform_fn, example_input) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/accelerator_partitioner.py b/phivenv/Lib/site-packages/torch/fx/experimental/accelerator_partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcec276db0e98b8aaccaa8ef772474c6eac3d81 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/accelerator_partitioner.py @@ -0,0 +1,1080 @@ +# mypy: allow-untyped-defs +import operator +from collections import deque +from typing import NamedTuple + +import torch +from torch.fx.experimental.partitioner_utils import ( + Device, + get_extra_size_of, + get_latency_of_partitioned_graph, + get_partition_to_latency_mapping, + NodeLatency, + Partition, + PartitionerConfig, + PartitionMode, +) +from torch.fx.graph_module import GraphModule +from torch.fx.node import map_arg, Node +from torch.fx.passes.graph_manipulation import get_size_of_all_nodes +from torch.fx.passes.split_module import split_module + + +class DAGNode: + """DAGNode class maintains useful information for a partition (submodule), + and its input submodules and output submodules. + """ + + def __init__( + self, + submodule_node: Node, + input_nodes: list[Node], + output_nodes: list[Node], + logical_device_ids: list[int], + size_bytes: int, + ) -> None: + self.submodule_node: Node = submodule_node + self.input_nodes: list[Node] = input_nodes + self.output_nodes: list[Node] = output_nodes + self.logical_device_ids: list[int] = logical_device_ids + self.size_bytes = size_bytes + + def __str__(self) -> str: + return str(self.submodule_node) + + +class DAG: + """DAG class contains all the DAG nodes""" + + def __init__(self) -> None: + self.nodes: list[DAGNode] = [] + + def create_node( + self, + submodule_node: Node, + input_nodes: list[Node], + output_nodes: list[Node], + logical_devices: list[int], + size_bytes: int, + ) -> None: + node = DAGNode( + submodule_node, input_nodes, output_nodes, logical_devices, size_bytes + ) + self.nodes.append(node) + + +class PartitionResult(NamedTuple): + """NameTuple used for returning DAG and a new fx module""" + + dag: DAG + module_with_submodules: GraphModule + + +"""Followings are some helper functions for partition manipulation""" + + +def reset_partition_device(partitions): + for partition in partitions: + partition.logical_device_ids = [] + + +def combine_two_partitions( + partition_0: Partition, partition_1: Partition, partitions: list[Partition] +) -> None: + """Given a list of partitions and its two partitions, + combine these two partitions into a new one appending to the partitions + and remove the previous two partitions from the list of partitions + """ + partition = Partition(len(partitions)) + partition.nodes = partition_0.nodes.union(partition_1.nodes) + partition.recalculate_mem_size() + partitions.append(partition) + partitions.remove(partition_0) + partitions.remove(partition_1) + reorganize_partitions(partitions) + return + + +def set_parents_and_children(partitions: list[Partition]) -> None: + """Given a list of partitions, mark parents and children for each partition""" + # Go through all nodes in a partition. + # If a node's user is in other partition, + # then the other partition is this partition's children. + # This partition is the other partition's parent + for partition in partitions: + partition.children = set() + partition.parents = set() + for partition in partitions: + for node in partition.nodes: + # For each node in the current partition, find its users + users = node.users + for n in users: + # Find which the partition the user node belongs to. + # Note that if the node itself is also belongs to that partition, + # that partition is not the child of the current partition + for p in partitions: + if p != partition and n in p.nodes and node not in p.nodes: + partition.children.add(p) + p.parents.add(partition) + return + + +def reorganize_partitions(partitions: list[Partition]) -> None: + """Given a list of partitions, reorganize partition id, + its parents and its children for each partition + """ + # Rearrange partition ids + for i, partition in enumerate(partitions): + partition.partition_id = i + set_parents_and_children(partitions) + return + + +def get_bfs_level_partition(partitions: list[Partition]) -> None: + """Given a list of partitions, + mark the bfs level for each partition + """ + current_level: set[Partition] = set() + visited: set[Partition] = set() + for partition in partitions: + # If a partition has no parent, it should be in root level + if len(partition.parents) == 0: + current_level.add(partition) + next_level: set[Partition] = set() + level = 0 + # bfs + while current_level: + partition = current_level.pop() + partition.bfs_level = level + visited.add(partition) + children = partition.children + for child in children: + if child not in next_level: + next_level.add(child) + if not current_level: + current_level = next_level.copy() + next_level = set() + level += 1 + return + + +def get_node_to_partition_mapping(partitions: list[Partition]) -> dict[Node, int]: + """Given a list of partitions,return node to partition mapping""" + node_to_partition: dict[Node, int] = {} + for partition in partitions: + for node in partition.nodes: + node_to_partition[node] = partition.partition_id + return node_to_partition + + +def get_logical_id_to_device(devices: list[Device]) -> dict[int, Device]: + """Get a mapping from device logical ID to Device object.""" + logical_id_to_device: dict[int, Device] = {} + for d in devices: + logical_id_to_device[d.logical_id] = d + return logical_id_to_device + + +def get_device_partition_stats( + partitions: list[Partition], devices: list[Device] +) -> tuple[dict[Device, list[Partition]], dict[Device, int], list[Partition]]: + """Given a list of partitions and a list of devices, returns: + 1. A mapping from device to partitions on it; + 2. A mapping from device to its remaining memory size; + 3. A list of partitions that do not have a device. + """ + # logical id to device + logical_id_to_device = get_logical_id_to_device(devices) + # Track partitions on device + device_to_partitions: dict[Device, list[Partition]] = {} + # Track device's left mem size + device_to_left_mem_bytes: dict[Device, int] = {} + for d in devices: + device_to_partitions[d] = [] + device_to_left_mem_bytes[d] = d.available_mem_bytes + + # Deal with the partitions that already have a device + # and also collect all partitions without a device (no_device_partitions) + no_device_partitions = [] + for partition in partitions: + if partition.logical_device_ids != []: + for logical_id in partition.logical_device_ids: + device = logical_id_to_device[logical_id] + device_to_partitions[device].append(partition) + device_to_left_mem_bytes[device] -= partition.used_mem_bytes + else: + no_device_partitions.append(partition) + + return ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) + + +def get_device_to_partitions_mapping( + partitions: list[Partition], devices: list[Device] +): + """Given a list of partitions and a list of devices, + map each partition into a device. + """ + + def calculate_extra_mem_bytes_needed_for( + partition: Partition, partitions: list[Partition] + ): + all_nodes: set[Node] = set() + for p in partitions: + all_nodes = all_nodes.union(p.nodes) + if len(all_nodes) == 0: + return partition.used_mem_bytes + all_nodes = all_nodes.union(partition.nodes) + extra_size_needed = 0 + for node in partition.nodes: + extra_size_needed += get_extra_size_of(node, all_nodes) + return extra_size_needed + + def find_device_for(partition: Partition): + """Given a partition, find a logical device for the partition + The algorithm is to put the partition on the device + that has just enough mem left for that partition. + device_to_left_mem_bytes is a dictionary between device and its left mem size + sorted by its left mem size + """ + for d in device_to_left_mem_bytes: + extra_size_needed = calculate_extra_mem_bytes_needed_for( + partition, device_to_partitions[d] + ) + if extra_size_needed < device_to_left_mem_bytes[d]: + device_to_partitions[d].append(partition) + partition.logical_device_ids.append(d.logical_id) + device_to_left_mem_bytes[d] -= extra_size_needed + return True + return False + + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(partitions, devices) + + # Find devices for all the partitions without a device + found_device = True + for partition in no_device_partitions: + device_to_left_mem_bytes = dict( + sorted(device_to_left_mem_bytes.items(), key=operator.itemgetter(1)) + ) + found_device = find_device_for(partition) + if not found_device: + break + return found_device + + +def check_dependency(partition): + """Given a partition,check if there is a circular dependency on + this partition using bfs + """ + visited: set[Partition] = {partition} + queue: deque[Partition] = deque([partition]) + while queue: + p = queue.popleft() + for child in p.children: + if child == partition: + return True + else: + if child not in visited: + visited.add(child) + queue.append(child) + return False + + +class Partitioner: + """A fx module may not fit into one device. + Partitioner class helps partition one fx module into submodules (partitions), + so that the submodules can be executed crossing different accelerators. + The main function of this class is self.partition_graph. + It partitions the fx module based on the scheme specified in partition_config + A DAG structure is returned + along with a new fx module with submodule nodes. + """ + + def __init__(self) -> None: + self.partitions: list[Partition] = [] + self.node_to_partition: dict[Node, int] = {} + self.devices: list[Device] = [] + + def partition_graph( + self, + fx_module: GraphModule, + torch_module: torch.nn.Module, + partitioner_config: PartitionerConfig, + ) -> PartitionResult: + """Given the fx module, torch module and partitioner_config, + find the partitions, do the partitions, + and then return a DAG and a new fx module with submodule nodes (partitions) + """ + self.graph_module = fx_module + self.torch_module = torch_module + self.devices = partitioner_config.devices + if len(self.devices) == 0: + raise RuntimeError("No devices") + # Tag the size in bytes to all nodes in the graph_module. + get_size_of_all_nodes(self.graph_module) + # Check if there are op nodes in the fx module + nodes = self.graph_module.graph.nodes + if all(node.op in {"placeholder", "get_attr", "output"} for node in nodes): + raise RuntimeError("No Partition since no operations in the module") + # Calculate total size of the fx module + total_size_of_graph = 0 + for node in nodes: + if node.op == "output": + break + total_size_of_graph += node.size_bytes.total_size + # Find the device with the max mem size + device_with_max_mem = max(self.devices, key=lambda d: d.available_mem_bytes) + # AOT based partition + if partitioner_config.mode == PartitionMode.aot_based: + self.aot_based_partition( + partitioner_config.node_to_partition_mapping, + partitioner_config.partition_to_logical_device_mapping, + ) + # Single partition if the whole module can be fit into one device + elif total_size_of_graph <= device_with_max_mem.available_mem_bytes: + self.find_single_partition( + total_size_of_graph, logical_device_id=device_with_max_mem.logical_id + ) + elif total_size_of_graph > sum(d.available_mem_bytes for d in self.devices): + raise RuntimeError("Devices have no enough memory for the module") + else: + # Sparse nn based partition + if partitioner_config.mode == PartitionMode.sparse_nn: + available_mem_bytes = self.devices[0].available_mem_bytes + if not all( + device.available_mem_bytes == available_mem_bytes + for device in self.devices + ): + raise RuntimeError("All devices must have same memory size!") + # sparse_nn_partition only support same memory size + # TODO: add different size support for sparse_nn_partition + self.sparse_nn_partition(available_mem_bytes) + # Cost aware partition + elif partitioner_config.mode == PartitionMode.cost_aware: + self.cost_aware_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + # KL based partition + elif partitioner_config.mode == PartitionMode.kl_based: + self.kl_based_partition( + partitioner_config.transfer_rate_bytes_per_sec, + partitioner_config.node_to_latency_mapping, + ) + else: + self.size_based_partition() + + # Saturate host if possible. + if partitioner_config.saturate_host: + self.saturate_host() + + # Partition the graph module based on the partition assignment. + module_with_submodules = self.do_partition() + + # The DAG contains DAGNodes with info of each partition's input nodes, output nodes + # and how partitions are connected. + dag = self.dump_dag(module_with_submodules) + ret = PartitionResult(dag, module_with_submodules) + return ret + + def find_single_partition( + self, total_size_of_graph, logical_device_id: int = 0 + ) -> None: + """Fit the whole fx module into one device""" + partition_0 = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op == "output": + # Skip the output node, but there can + # be nodes after the output in certain cases. + continue + partition_0.nodes.add(node) + partition_0.used_mem_bytes = total_size_of_graph + partition_0.logical_device_ids = [logical_device_id] + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def size_based_partition(self) -> None: + """This method is to partition the fx module based on memory size. + It uses greedy approach. The result may not be the best. + The basic idea is: + Step 1: + Find a device which has enough memory to fit the current node, create a empty partition + with the size of that device. + Then keep adding the following nodes into the partition until the partition is full. + Step 2: + Repeat Step 1 until no device left + Step 3: + If some nodes are left, create a partition for each left node (single node partition). + and then try to map those partitions into logical devices with enough mem left. + """ + + def find_device_based_on_size(node) -> Device: + """Given a node, this function is to find a logical device + that could fit the node. + """ + mem_size_needed = get_extra_size_of(node, set()) + device = Device("", -1, -1) + for d in self.devices: + if ( + d not in occupied_devices + and d.available_mem_bytes >= mem_size_needed + ): + device = d + break + if device.available_mem_bytes < 0: + raise RuntimeError(str(node) + "is too large to fit any device") + occupied_devices.append(device) + return device + + # Track partition and its left mem size + partition_to_left_mem_bytes: dict[Partition, int] = {} + # Track all the devices that have been used + occupied_devices: list[Device] = [] + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if there are devices left + if len(self.partitions) <= len(self.devices): + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + # Check if the current partition is the very first partition + if partition.used_mem_bytes == 0: + # Find a device to fit the first node, return available mem size + device = find_device_based_on_size(node) + occupied_devices.append(device) + # Update partition and its left mem size + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) + # Update available mem for the current partition + partition.logical_device_ids.append(device.logical_id) + else: + # The current partition is not the first partition + # Check if the current node can fit into current partition + if ( + partition_to_left_mem_bytes[partition] + < total_size_of_input_nodes + ): + # Check if no device is left + if len(self.partitions) == len(self.devices): + # No device is left + # Create the first single node partition for the current node + self.create_single_node_partition(node) + continue + # Some devices are still left + # Create a new partition with a mem size that is enough for the current node + device = find_device_based_on_size(node) + partition = self.create_partition() + total_size_of_input_nodes = get_extra_size_of( + node, partition.nodes + ) + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) + partition.logical_device_ids.append(device.logical_id) + partition.add_node(node) + partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes + # Create single node partitions if no device is left + else: + self.create_single_node_partition(node) + reorganize_partitions(self.partitions) + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + # Mapping all partitions into device + found_partition_to_device_mapping = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_partition_to_device_mapping: + raise RuntimeError("Cannot Get a Valid Partition to Logical Device Mapping") + return + + def saturate_host(self) -> None: + """Saturate host by assigning replicates to unused devices with enough memory. + It uses a greedy approach to find a next available set of devices to place all split + partitions: For each used device, it searches for an idle device with minimal memory + size that can hold all the partition located on that device; If the search is successful + for all used devices, it then assigns the new devices' logical ID to the corresponding + partition. + """ + ( + device_to_partitions, + device_to_left_mem_bytes, + no_device_partitions, + ) = get_device_partition_stats(self.partitions, self.devices) + + assert len(no_device_partitions) == 0, ( + f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + ) + + # Devices that hold partitions + used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] + # Track replicates of the assigned devices + replicated_device_to_used_device: dict[Device, Device] = {} + + while len(used_devices) * 2 + len(replicated_device_to_used_device) <= len( + self.devices + ): + # Success flag for this round + success = True + # Devices that have not been assigned + idle_devices = [ + d + for d in self.devices + if d not in used_devices and d not in replicated_device_to_used_device + ] + # Temporary mapping from replicated device to original device + temp_replicate_mapping = {} + + # Find a new device to replicate all partitions on an used device + for used_device in used_devices: + # Idle devices that have enough memory + available_devices = [ + d + for d in idle_devices + if d.available_mem_bytes + >= used_device.available_mem_bytes + - device_to_left_mem_bytes[used_device] + ] + if len(available_devices) == 0: + success = False + break + new_device = min(available_devices, key=lambda d: d.available_mem_bytes) + idle_devices.remove(new_device) + temp_replicate_mapping[new_device] = used_device + + if not success: + break + replicated_device_to_used_device.update(temp_replicate_mapping) + + # Update logical device IDs assigned to the partitions + for ( + replicate_device, + original_device, + ) in replicated_device_to_used_device.items(): + logical_id = replicate_device.logical_id + for partition in device_to_partitions[original_device]: + partition.logical_device_ids.append(logical_id) + for p in self.partitions: + print(p.logical_device_ids) + + def do_partition(self) -> GraphModule: + """Return a new fx module with submodule nodes (partitions).""" + module_with_submodules = split_module( + self.graph_module, + self.torch_module, + lambda node: self.node_to_partition[node], + ) + return module_with_submodules + + def dump_dag(self, module_with_submodules: GraphModule) -> DAG: + """Return the dag structure and the new fx module with submodules.""" + dag = DAG() + for node in module_with_submodules.graph.nodes: + if node.op == "output": + break + if node.op in {"placeholder", "get_attr"}: + continue + if node.target == operator.__getitem__: + continue + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # When a node has two or more output nodes, + # it outputs its result to 'getitem' nodes. + # Those 'getitem' nodes are the output node for this node. + # Otherwise, the output node is this node itself. + if len(node.users) > 1: + output_nodes = list(node.users) + else: + output_nodes = [node] + partition_id = int(node.name.rsplit("_", 1)[-1]) + device_ids = self.partitions[partition_id].logical_device_ids + size_bytes = self.partitions[partition_id].used_mem_bytes + dag.create_node( + node, list(input_nodes), output_nodes, device_ids, size_bytes + ) + return dag + + def create_partition(self) -> Partition: + """Create a partition and append it to self.partitions.""" + partition_id = len(self.partitions) + partition = Partition(partition_id) + self.partitions.append(partition) + return partition + + def create_single_node_partition(self, node): + """Create a partition for a single node""" + partition = self.create_partition() + partition.add_node(node) + return + + def sparse_nn_partition(self, available_mem_bytes: int) -> None: + """This method partition a sparse nn module. + It is size based partition but different from size_based_partition, + it only works when all the devices have same memory size (available_mem_bytes). + In the future, devices with different mem sizes will be supported like size_based_partition. + It first traverse all the nodes and do the partitions based on the same memory size. + If the current partition has no enough memory left for a new op node + (call_module, call_method, call_function), a new partition is created. + When crossing the boundary between non-embedding nodes and embedding nodes, + a new partition is created regardlessly. + For example, if the current node is a non-embedding node but the next node is an + embedding node, a new partition is created for the next node. + After the partition, the partitions are combined as much as possible. + The rule is that a non-embedding partition only + combines with another non-embedding one. + So as the embedding partitions. + """ + + def combine_partitions_based_on_size( + partitions: list[Partition], available_mem_bytes: int + ) -> None: + """Combining small partitions together to keep as less partitions as possible. + Here is an example of the algorithm to do this: + Assume some partitions, we first sort them based on partition used memory size. + [(partition_4, 1), (partition_3, 1), (partition_2, 2), (partition_1, 7), (partition_0, 9)] + The available memory is 10. + step 1: self.find_partition_to_combine_based_on_size() + First, mark bfs level for each partition + Second, look the smallest partition, partition_4: 10 - 1 = 9 + It means any partition has a used memory equal or less than 9 could combine this partition + We go from the largest and selection partition_0. + Check the bfs level for two partitions, if the level difference is less than 2, + it can be combined. + step 2: repeat step 1 until no partitions can be combined + """ + find_combination = True + while find_combination: + # Sort partitions based on memory size + sorted_partitions = sorted(partitions, key=lambda p: p.used_mem_bytes) + # Mark bfs level + get_bfs_level_partition(self.partitions) + find_combination, partitions = find_partition_to_combine_based_on_size( + sorted_partitions, available_mem_bytes, partitions + ) + return + + def calculate_mem_bytes_needed(p1, p2): + """Given two partitions, calculate how many mem bytes + are needed if two partitions are combined + """ + nodes = p1.nodes.union(p2.nodes) + mem_bytes_needed = 0 + for node in nodes: + mem_bytes_needed += get_extra_size_of(node, nodes) + return mem_bytes_needed + + def find_partition_to_combine_based_on_size( + sorted_partitions: list[Partition], + available_mem_bytes: int, + partitions: list[Partition], + ) -> tuple[bool, list[Partition]]: + """step 1 in combine_partition_based_on_size()""" + find_combination = False + smallest_partition = sorted_partitions.pop(0) + for p in sorted_partitions[::-1]: + if abs(smallest_partition.bfs_level - p.bfs_level) <= 1: + # Calculate how many bytes needed if combined + mem_bytes_needed = calculate_mem_bytes_needed(p, smallest_partition) + if mem_bytes_needed <= available_mem_bytes: + combine_two_partitions(p, smallest_partition, self.partitions) + partitions.remove(smallest_partition) + partitions.remove(p) + partitions.append(self.partitions[-1]) + find_combination = True + break + return find_combination, partitions + + def reset_partition_in_sparse_nn(partition, new_partition=True): + """If crossing the boundary between non-embedding nodes and + embedding nodes, create a new partition + """ + if in_embedding_region: + embedding_partitions.append(partition) + else: + non_embedding_partitions.append(partition) + if new_partition: + partition = self.create_partition() + partition.left_mem_bytes = available_mem_bytes + return partition + return None + + def is_embedding_node(node: Node) -> bool: + """Check if a node is an embedding node""" + if node.op == "call_module": + submodule = self.graph_module + for atom in str(node.target).split("."): + if not hasattr(submodule, atom): + raise RuntimeError( + f"Module {submodule} has no attribute {atom}" + ) + submodule = getattr(submodule, atom) + if "Embedding" in str(submodule): + return True + return False + + # Track embedding partitions and non-embedding partitions separately + embedding_partitions: list[Partition] = [] + non_embedding_partitions: list[Partition] = [] + # A Flag to check the boundary + in_embedding_region: bool = False + partition = self.create_partition() + for node in self.graph_module.graph.nodes: + if node.op in {"call_module", "call_method", "call_function"}: + # Check if crossing the boundary between embedding nodes and non embedding nodes + if is_embedding_node(node) != in_embedding_region: + # Crossing the boundary + # Check if the current partition is an empty partition + if partition.used_mem_bytes != 0: + # The current partition isn't an empty partition. Create a new one. + partition = reset_partition_in_sparse_nn(partition) + in_embedding_region = not in_embedding_region + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if ( + total_size_of_input_nodes + partition.used_mem_bytes + > available_mem_bytes + ): + partition = reset_partition_in_sparse_nn(partition) + total_size_of_input_nodes = get_extra_size_of(node, partition.nodes) + if total_size_of_input_nodes > available_mem_bytes: + raise RuntimeError( + node.target + "is too large to fit into a device" + ) + partition.add_node(node) + reset_partition_in_sparse_nn(partition, new_partition=False) + # Set parents and children for partitions + set_parents_and_children(self.partitions) + # Combining non-embedding partitions + combine_partitions_based_on_size(non_embedding_partitions, available_mem_bytes) + # Combining embedding partitions + combine_partitions_based_on_size(embedding_partitions, available_mem_bytes) + total_size_of_non_embedding_partitions = 0 + for partition in non_embedding_partitions: + total_size_of_non_embedding_partitions += partition.used_mem_bytes + # Check if devices are enough for all partitions + if len(embedding_partitions) > len(self.devices): + msg = ( + "Need " + + str(len(embedding_partitions)) + + " devices, but only " + + str(len(self.devices)) + + " provided" + ) + raise RuntimeError(msg) + occupied_devices = [] + for i, partition in enumerate(embedding_partitions): + # Check if all non-embedding partitions can fit into embedding partition devices + if ( + total_size_of_non_embedding_partitions + partition.used_mem_bytes + > available_mem_bytes + ): + raise RuntimeError( + "partition_" + + str(partition.partition_id) + + "(embedding partition) and non embedding partitions can not fit into one device" + ) + else: + # Add logical device to the partition + partition.logical_device_ids = [self.devices[i].logical_id] + occupied_devices.append(self.devices[i].logical_id) + # Add logical devices to the non_embedding_partitions + for partition in non_embedding_partitions: + partition.logical_device_ids = occupied_devices + # Get the node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def cost_aware_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: dict[Node, NodeLatency], + ) -> None: + """This method is to partition the fx module based on the cost. + The cost is the total latency of running the whole fx module. + In partitioner_utils.py, the cost model is built. + The cost aware partition algorithm is: + #1. At every beginning, each node is a partition. + Then we map all the partitions to the devices + and calculate the cost + #2. Then try to pre-combine any two of the partitions if the two + partitions can be combined. + (the bfs level is less than 2 or two partitions are connected and + can find partition to device mapping) + See if any partition pair could reduce the current cost. + Choose the pair that shows the minimum cost and then combine them + #3. Repeat #2 until the cost cannot be reduced. + """ + + def try_combining_partitions(p0_index, p1_index, partitions) -> float: + """Given two partitions and a list of partitions, combine these two partitions + and see what is the cost of the modified partition list + """ + p0 = partitions[p0_index] + p1 = partitions[p1_index] + """If two partitions' bfs level are less than 2 or two partitions are connected to each other, + then they can be combined + """ + if ( + (abs(p0.bfs_level - p1.bfs_level) <= 1) + or (p0 in p1.parents) + or p0 in (p1.children) + ): + combine_two_partitions(p0, p1, partitions) + # Check if a circular dependency exists after combining + if check_dependency(partitions[-1]): + return float("inf") + # Check if the modified partition list can be mapped to devices after combination + reset_partition_device(partitions) + found_deivce = get_device_to_partitions_mapping( + partitions, self.devices + ) + if not found_deivce: + return float("inf") + # Calculate the new cost + partition_to_latency_mapping = get_partition_to_latency_mapping( + partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + return cost + # If two partition can not be combined, the cost is inf + return float("inf") + + def search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) -> bool: + """Given transfer rate between partitions and each node's latency, + find two partitions to combine so the cost of the partitions can + be reduced. + The algorithm is : + 1. Go through all the partition pairs and see + if any pair of partitions can be combined. + 2. Calculate the cost after the combination. + 3. Select the minimum cost and combine its corresponding partition pair. + """ + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + if len(self.partitions) == 1: + return False + partition_pair: list[int] = [] + for i in range(len(self.partitions) - 1): + for j in range(i + 1, len(self.partitions)): + # Try to combine the partition pair + # and see the new cost after combination + new_cost = try_combining_partitions(i, j, self.partitions[:]) + if new_cost <= cost: + partition_pair = [i, j] + cost = new_cost + reorganize_partitions(self.partitions) + # If a partition pair is found, combine them + if len(partition_pair) != 0: + p0 = self.partitions[partition_pair[0]] + p1 = self.partitions[partition_pair[1]] + combine_two_partitions(p0, p1, self.partitions) + get_bfs_level_partition(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return len(partition_pair) != 0 + + for node in self.graph_module.graph.nodes: + if node.op not in {"placeholder", "get_attr", "output"}: + self.create_single_node_partition(node) + # Set up parent partitions and children partitions for each partition + set_parents_and_children(self.partitions) + # Get bfs level for each partition + get_bfs_level_partition(self.partitions) + find_combination = True + while find_combination: + # Search for a pair partition to generate the minimum new cost, + # then combine them + find_combination = search_combination( + transfer_rate_bytes_per_sec, node_to_latency_mapping + ) + # Make sure all partitions are set up correctly + reorganize_partitions(self.partitions) + # Set up node to partition mapping + self.node_to_partition = get_node_to_partition_mapping(self.partitions) + return + + def kl_based_partition( + self, + transfer_rate_bytes_per_sec: float, + node_to_latency_mapping: dict[Node, NodeLatency], + ) -> None: + """This function is a cost aware partition based + on Kernighan-Lin algorithm. + First, the graph is partitioned using size_based_partition. + Then, each node is swapped with any other node in a different + partition, and at the same time, the cost is estimated after + the swapping. + For example, we have nodes n0, n1, n2, n3 and n4. + Using size_based_partition, n0 and n1 are in Partition p0. + n2, n3 and n4 in Partition p1. The current cost is estimated. + We first tried using n0 to swap with n2 from the other partition. + Then we see that swapping n0 and n2 shows a lower cost + than the current cost and it is the minimum among other pairs like + (n0, None)(This means moving n0 to Partition without swapping other nodes), + (n0, n3) and (n0, n4). We swap n0 and n2 and set the new cost + as the current cost. + Then We repeat this process for all the other nodes until all swapping pairs + are tried. + """ + + def swap_nodes(n0, n1, p0, p1): + # Either n0 or n1 could be None + # That means we simply move the node + # to another partition + if n0 is not None: + p0.remove_node(n0) + p1.add_node(n0) + if n1 is not None: + p0.add_node(n1) + p1.remove_node(n1) + + def try_swap_nodes( + n0, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + cost = float("inf") + swap_nodes(n0, n1, p0, p1) + # Reorganize partitions after swapping + reorganize_partitions(self.partitions) + # Check if there is a circular dependency after swapping + if (not check_dependency(p0)) and (not check_dependency(p1)): + reset_partition_device(self.partitions) + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Check if all partitions can be mapped to logical devices after swapping + found_device = get_device_to_partitions_mapping( + self.partitions, self.devices + ) + if not found_device: + cost = float("inf") + else: + cost = get_latency_of_partitioned_graph( + self.partitions, + partition_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Swap back and reset all partitions back to original + swap_nodes(n1, n0, p0, p1) + reorganize_partitions(self.partitions) + reset_partition_device(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + return cost + + def swap_node_to_partition( + node, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ): + """This function helps to swap one node from partition p0 + with all the nodes in another partition p1 + """ + p1_nodes = list(p1.nodes) + [None] + min_cost = float("inf") + node_pair: list[Node] = [] + for n1 in p1_nodes: + # Ignore the node if it is not a op node + if n1 is not None and n1.op in {"placeholder", "get_attr"}: + continue + # Try swapping node in p0 with n1 in p1 + cost = try_swap_nodes( + node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec + ) + if cost < min_cost: + node_pair = [node, n1] + min_cost = cost + return cost, node_pair # type: ignore[possibly-undefined] + + # First use size_base_partition + self.size_based_partition() + partition_to_latency_mapping = get_partition_to_latency_mapping( + self.partitions, node_to_latency_mapping + ) + # Calculate the cost of the partitions + cost = get_latency_of_partitioned_graph( + self.partitions, partition_to_latency_mapping, transfer_rate_bytes_per_sec + ) + # Keep tracking the node pair that shows the better cost + node_pair: list[Node] = [] + # Keep tracking the partition pair of node pair + partition_pair: list[Partition] = [] + # Collect all the op nodes from the graph + op_nodes = [ + n + for n in self.graph_module.graph.nodes + if n.op not in {"placeholder", "get_attr", "output"} + ] + for node in op_nodes: + # Find which partition the current node belongs + p0_index = self.node_to_partition[node] + p0 = self.partitions[p0_index] + # Go through all the other partitions to swap + # with other nodes from those partitions + for p1_index, _ in enumerate(self.partitions): + if p0_index != p1_index: + p1 = self.partitions[p1_index] + new_cost, new_node_pair = swap_node_to_partition( + node, + p0, + p1, + node_to_latency_mapping, + transfer_rate_bytes_per_sec, + ) + # Update the cost + # Track the swapped node pair and their partitions + if new_cost < cost: + cost = new_cost + node_pair = new_node_pair + partition_pair = [p0, p1] + # Do the swapping after trying all the nodes from a partition + if len(node_pair) != 0: + swap_nodes( + node_pair[0], node_pair[1], partition_pair[0], partition_pair[1] + ) + reorganize_partitions(self.partitions) + get_device_to_partitions_mapping(self.partitions, self.devices) + reorganize_partitions(self.partitions) + # Mapping the device to the partition + get_device_to_partitions_mapping(self.partitions, self.devices) + return + + def aot_based_partition( + self, node_to_partition_mapping, partition_to_logical_device_mapping + ): + """This function helps to rebuild the partitions given the nodes and its + corresponding partition id + """ + partition_id_to_partition_mapping: dict[int, Partition] = {} + self.node_to_partition = node_to_partition_mapping + for node in self.node_to_partition: + partition_id = self.node_to_partition[node] + # If the requested partition has not been created, create the partition + if partition_id not in partition_id_to_partition_mapping: + partition = Partition(partition_id) + self.partitions.append(partition) + partition_id_to_partition_mapping[partition_id] = partition + partition.logical_device_ids = partition_to_logical_device_mapping[ + partition_id + ] + else: + partition = partition_id_to_partition_mapping[ + self.node_to_partition[node] + ] + # Add the current node into the partition + partition.add_node(node) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/const_fold.py b/phivenv/Lib/site-packages/torch/fx/experimental/const_fold.py new file mode 100644 index 0000000000000000000000000000000000000000..7f22ee4fbad1ea8f92c3d0888ca3a12f05257f61 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/const_fold.py @@ -0,0 +1,311 @@ +# mypy: allow-untyped-defs +import re +from typing import Callable, Optional, Union + +import torch.fx +from torch.fx.node import map_arg +from torch.fx.passes.split_module import split_module + + +__all__ = [ + "FoldedGraphModule", + "get_unique_attr_name_in_module", + "split_const_subgraphs", +] + + +class FoldedGraphModule(torch.fx.GraphModule): + """ + FoldedGraphModule is a GraphModule which also contains another + `const_subgraph_module` representing a subgraph which has all const attr + inputs and which can be run once before running the main standard + `graph`. The `const_output_names` are the ordered list names of attrs which + represent what each respective output from the const_subgraph should be set + on which attrs. + """ + + def __init__( + self, + root: torch.nn.Module, + graph: torch.fx.Graph, + const_subgraph: Optional[torch.fx.Graph] = None, + fx_const_folded_attrs_name: Optional[str] = None, + device_for_folded_attrs: str = "cuda", + ): + super().__init__(root, graph) + self.const_subgraph_module = ( + None + if const_subgraph is None + else torch.fx.GraphModule(root, const_subgraph) + ) + self.has_folding_been_run = False + self.fx_const_folded_attrs_name = fx_const_folded_attrs_name + self.device_for_folded_attrs = device_for_folded_attrs + + def __call__(self, *args, **kwargs): + if not self.has_folding_been_run: + self.run_folding() + return super().__call__(*args) + + def run_folding(self): + # If there's no const subgraph module or attr output names to use, return + # early as there is no const folding to perform. + if ( + self.const_subgraph_module is None + or self.fx_const_folded_attrs_name is None + ): + return + + assert not self.has_folding_been_run + self.has_folding_been_run = True + + # Actually run const folding subgraph. Note that single attr const fold + # subgraphs output a single Tensor while multiple outputs are returned as + # Tuple[Tensor,]. + folded_attrs = self.const_subgraph_module() + + def _create_param(i): + return torch.nn.Parameter( + i.detach().clone() + if not isinstance(i, int) + else torch.Tensor([i]).to(device=self.device_for_folded_attrs), + requires_grad=i.requires_grad if isinstance(i, torch.Tensor) else False, + ) + + params = ( + torch.nn.ParameterList([_create_param(i) for i in folded_attrs]) + if isinstance(folded_attrs, tuple) + else _create_param(folded_attrs) + ) + setattr(self, self.fx_const_folded_attrs_name, params) + + +def _inline_module(gm: torch.fx.GraphModule, inline_mod_name: str): + """ + Given `gm` and some graph module which is called with target name `inline_mod_name`, + this helper will inline all of the nodes from that called graph module into `gm`. + """ + # Fetch the inner graph module that we want to inline inside `gm`. + inline_mod = dict(gm.named_modules())[inline_mod_name] + assert isinstance(inline_mod, torch.fx.GraphModule) + call_mod_node_to_replace = None + for node in gm.graph.nodes: + if node.op == "call_module" and node.target == inline_mod_name: + call_mod_node_to_replace = node + break + assert call_mod_node_to_replace is not None + + # Now actually do the swap. Note that we have to keep track of new nodes that are + # copied into `gm` -- we do this via replacement_mapping. + call_mod_args = call_mod_node_to_replace.args + call_mod_kwargs = call_mod_node_to_replace.kwargs + + replacement_mapping: dict[torch.fx.Node, torch.fx.Node] = {} + ph_count = 0 + + def replacement_fn(node): + new_node = replacement_mapping[node] + new_node.meta = node.meta.copy() + return new_node + + for inline_node in inline_mod.graph.nodes: + if inline_node.op == "placeholder": + replacement_mapping[inline_node] = ( + call_mod_kwargs[inline_node.name] + if inline_node.name in call_mod_kwargs + else call_mod_args[ph_count] + ) + + ph_count += 1 + continue + + if inline_node.op == "output": + outputs = inline_node.args[0] + output_replacements = map_arg(outputs, replacement_fn) + call_mod_node_to_replace.replace_all_uses_with(output_replacements) + continue + + with gm.graph.inserting_before(call_mod_node_to_replace): + new_node = gm.graph.node_copy(inline_node, replacement_fn) + replacement_mapping[inline_node] = new_node + + gm.graph.eliminate_dead_code() + + +def get_unique_attr_name_in_module(mod_traced: torch.fx.GraphModule, name: str) -> str: + """ + Make sure the name is unique (in a module) and can represents an attr. + """ + # Delete all characters that are illegal in a Python identifier. + name = re.sub("[^0-9a-zA-Z_]+", "_", name) + if name[0].isdigit(): + name = f"_{name}" + # Now make sure it is in fact unique to the module by incrementing suffix value. + while hasattr(mod_traced, name): + match = re.match(r"(.*)_(\d+)$", name) + if match is None: + name = name + "_1" + else: + base, num = match.group(1, 2) + name = f"{base}_{int(num) + 1}" + + return name + + +def split_const_subgraphs( + module: Union[torch.nn.Module, torch.fx.GraphModule], + skip_folding_node_fn: Optional[Callable[[torch.fx.Node], bool]] = None, + device_for_folded_attrs: str = "cpu", +) -> FoldedGraphModule: + """ + Looks through `module` for any nodes that have all constant attribute inputs + and separates them out into their own constant subgraph, and returns a + FoldedGraphModule which runs that constant subgraph on the first run to set + attributes on the module prior to running the non-constant portion of the + graph. + """ + if not isinstance(module, torch.fx.GraphModule): + mod_traced = torch.fx.symbolic_trace(module) + else: + mod_traced = module + + # Build up a list of const_nodes, defined as nodes that are themselves + # get_attrs, or have all get_attr or other constant node inputs. + const_nodes: set[torch.fx.Node] = set() + found_const_folding = False + for node in mod_traced.graph.nodes: + # Skip over placeholders/outputs because they can't be const folded and + # we don't want to add tags to them. + if node.op in {"placeholder", "output"}: + continue + + # If the node itself is constant, or all of its inputs are constant, + # then tag it as constant. + if node.op != "get_attr" and not set(node.all_input_nodes).issubset( + const_nodes + ): + continue + + # If provided skip folding function says to skip, then skip. + if skip_folding_node_fn and skip_folding_node_fn(node): + continue + + # Skip folding side-effectful functions + if node.is_impure(): + continue + + # Must be a constant foldable node at this point. + const_nodes.add(node) + if node.op != "get_attr": + found_const_folding = True + + # If we did not find any const folding then return early without a const fold subgraph. + if not found_const_folding: + return FoldedGraphModule(mod_traced, mod_traced.graph) + + # Partition the module into two: submod_0 for constant folding subgraph, and + # submod_1 for the rest. + def mod_partition(node: torch.fx.Node): + return 0 if node in const_nodes else 1 + + split = split_module(mod_traced, module, mod_partition) + + const_mod_name, non_const_mod_name = "submod_0", "submod_1" + # Safely get submod_1 in case there are no non-const nodes + const_gm, non_const_gm = split.submod_0, getattr(split, non_const_mod_name, None) + + # The module that a call_module node refers to gets copied to submodules during split. + # The path to the module also gets inlined, i.e. mod.a.b -> mod_a_b. Here we need to + # attach inlined modules to `split` as it's the owning module now. + for node in non_const_gm.graph.nodes if non_const_gm else []: + if node.op == "call_module": + setattr(split, node.target, getattr(non_const_gm, node.target)) + for node in const_gm.graph.nodes: + if node.op == "call_module": + setattr(split, node.target, getattr(const_gm, node.target)) + + # split_module currently does not use get_attrs for attrs. Instead it passes + # them in as args from the parent module, which used get_attrs. Here we set + # them as get_attrs inside const_gm, allowing for running folding without + # somehow a priori knowing the attrs that should be passed as args. We can + # unconditionally do this for all placeholders because we know all + # placeholders to const_gm must be constants accessible via get_attr. + call_const_gm_args = None + for node in split.graph.nodes: + if node.op == "call_module": + if node.target == const_mod_name: + call_const_gm_args = node.args + break + assert call_const_gm_args is not None + + # Here we do the actual replacement of placeholders to get_attrs. Note that here we + # set the const_gm.graph into a new root_const_gm with split as the root module, + # because we are fetching attributes directly from the root module, instead of + # fetching them from const_gm. Example: The const_gm must have some format like: + # graph(): + # %inp : [num_users=1] = placeholder[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp, %inp), kwargs = {}) + # return add + # We replace that with the following, which does not have any placeholders: + # graph(): + # %inp_1 : [num_users=1] = get_attr[target=const_inp] + # %add : [num_users=1] = call_function[target=operator.add](args = (%inp_1, %inp_1), kwargs = {}) + # return add + root_const_gm = torch.fx.GraphModule(split, const_gm.graph) + + # The order of placeholders in the const_gm graph should match the order of + # args in the outer module, so we can simply use an index for the + # placeholder mapping + ph_idx = 0 + for node in root_const_gm.graph.nodes: + if node.op == "output": + multiple_outputs = isinstance(node.args[0], tuple) + continue + if node.op != "placeholder": + continue + assert ph_idx < len(call_const_gm_args) + in_node = call_const_gm_args[ph_idx] + ph_idx += 1 + assert in_node.op == "get_attr" + with root_const_gm.graph.inserting_before(node): + new_node = root_const_gm.graph.get_attr(in_node.target) + new_node.meta = node.meta.copy() + node.replace_all_uses_with(new_node) + root_const_gm.graph.erase_node(node) + assert "multiple_outputs" in locals() + + # Now find the call to const_gm inside split, and replace it with a getattr to the + # folded tensor(s) that result from constant folding. Note that we don't need to + # worry about whether this is one or more tensors because the original graph + # correctly uses getitem to extract individual tensors if there are multiple folded. + fx_const_folded_attrs_name = get_unique_attr_name_in_module( + mod_traced, "_FX_CONST_FOLDED_ATTRS" + ) + setattr( + split, + fx_const_folded_attrs_name, + torch.nn.ParameterList() if multiple_outputs else torch.nn.Parameter(), # type: ignore[possibly-undefined] + ) + for node in split.graph.nodes: + if node.op == "call_module" and node.target == const_mod_name: + with node.graph.inserting_before(node): + folded_attrs = node.graph.get_attr(fx_const_folded_attrs_name) + folded_attrs.meta = node.meta.copy() + node.replace_all_uses_with(folded_attrs) + break + + # Finally, inline the non-constant submod (if it exists) into the split submod. + # This is so that the original caller who may have passed in a graph module will + # get back out a graph module whose graph is traced to the same granularity. + if hasattr(split, non_const_mod_name): + _inline_module(split, non_const_mod_name) + + split.graph.eliminate_dead_code() + + return FoldedGraphModule( + split, + split.graph, + root_const_gm.graph, + fx_const_folded_attrs_name, + device_for_folded_attrs, + ) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/debug.py b/phivenv/Lib/site-packages/torch/fx/experimental/debug.py new file mode 100644 index 0000000000000000000000000000000000000000..0192b03d88149c412cc42a9ba54e28445c229c44 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/debug.py @@ -0,0 +1,33 @@ +from collections.abc import Sequence + +import torch.fx as fx + + +__all__ = ["set_trace"] + + +def set_trace(gm: fx.GraphModule) -> fx.GraphModule: + """ + Sets a breakpoint in `gm`'s generated python code. It drops into pdb when + `gm` gets run. + + Args: + gm: graph module to insert breakpoint. It is then recompiled for it to + take effect. + + Returns: + the `gm` with breakpoint inserted. + """ + + def insert_pdb(body: Sequence[str]) -> list[str]: + return ["import pdb; pdb.set_trace()\n", *body] + + with gm.graph.on_generate_code( + make_transformer=lambda cur_transform: ( + # new code transformer to register + lambda body: (insert_pdb(cur_transform(body) if cur_transform else body)) + ) + ): + gm.recompile() + + return gm diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/graph_gradual_typechecker.py b/phivenv/Lib/site-packages/torch/fx/experimental/graph_gradual_typechecker.py new file mode 100644 index 0000000000000000000000000000000000000000..4254b0d6574b1946e8d88cd7280059d9fd2f4817 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/graph_gradual_typechecker.py @@ -0,0 +1,1024 @@ +# mypy: allow-untyped-defs +import itertools +import operator +from functools import reduce +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +import sympy + +import torch +from torch.fx.experimental.refinement_types import Equality +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, is_consistent, is_more_precise, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_INFERENCE_RULES: dict[Target, Callable] = {} +_REFINEMENT_RULES: dict[Target, Callable] = {} +_RULES: dict[Target, Callable] = {} + +__all__ = [ + "GraphTypeChecker", + "Refine", + "adaptiveavgpool2d_check", + "adaptiveavgpool2d_inference_rule", + "add_inference_rule", + "all_eq", + "bn2d_inference_rule", + "broadcast_types", + "calculate_out_dimension", + "conv2d_inference_rule", + "conv_refinement_rule", + "conv_rule", + "element_wise_eq", + "expand_to_tensor_dim", + "first_two_eq", + "flatten_check", + "flatten_inference_rule", + "flatten_refinement_rule", + "get_attr_inference_rule", + "get_greatest_upper_bound", + "get_parameter", + "linear_check", + "linear_inference_rule", + "linear_refinement_rule", + "maxpool2d_check", + "maxpool2d_inference_rule", + "register_algebraic_expressions_inference_rule", + "register_inference_rule", + "register_refinement_rule", + "relu_inference_rule", + "reshape_inference_rule", + "transpose_inference_rule", +] + + +def expand_to_tensor_dim(t, n): + """ + Expand a type to the desired tensor dimension if possible + Raise an error otherwise. + - t is the given type + - n is a number of dimensions to expand to + """ + if t == Dyn: + dims = [Dyn] * n + return TensorType(tuple(dims)) + elif isinstance(t, TensorType): + if len(t.__args__) != n: + raise TypeError( + f"Cannot extend tensor. Tensor {t} has rank {len(t.__args__)}. It should have rank {n}" + ) + return t + else: + raise TypeError(f"Cannot match the type {t}") + + +def broadcast_types(t1, t2): + """ + Applies broadcasting to both given types such that they + become consistent with eachother and returns two new + resulting types + """ + + # if either type is Dyn, do nothing since the types are already consistent + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return t1, t2 + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + s1 = len(t1.__args__) + s2 = len(t2.__args__) + + new_t1 = list(t1.__args__) + new_t2 = list(t2.__args__) + + # We make the types the same length which is the first requirement + # for consistency + if s1 > s2: + for i in range(s1 - s2): + new_t2.insert(0, 1) + + elif s2 > s1: + for i in range(s2 - s1): + new_t1.insert(0, 1) + + # we replace occurrences of "1" with each tensor with + # the corresponding type from the other tensor + for i, (x, y) in enumerate(zip(new_t1, new_t2)): + if x == 1: + new_t1[i] = y + elif y == 1: + new_t2[i] = x + + # at this point our tensors should be consistent + # and we can apply the element-wise operation and find the right dimension + # for the output of the operation + (t1, t2) = TensorType(tuple(new_t1)), TensorType(tuple(new_t2)) + return (t1, t2) + else: + raise TypeError(f"Cannot broadcast types {t1} and {t2}") + + +def register_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _INFERENCE_RULES: + raise RuntimeError(f"Inference rule already registered for {call_target}!") + _INFERENCE_RULES[call_target] = fn + return fn + + return register + + +def register_refinement_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _REFINEMENT_RULES: + raise RuntimeError(f"Refinement rule already registered for {call_target}!") + _REFINEMENT_RULES[call_target] = fn + return fn + + return register + + +def register_algebraic_expressions_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _RULES: + raise RuntimeError(f"Rule already registered for {call_target}!") + _RULES[call_target] = fn + return fn + + return register + + +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def add_inference_rule(n: Node): + """ + Apply the addition inference rule. This includes: + - scalar addition + - broadcasting semantics + + Note that we always return the least precise type between + the operands (after applying broadcasting) to be the final type of the operation + + Note that we do not modify the operand types themselves after applying broadcasting + to them. We only use them to calculate the final type + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + t1 = n.args[0].type + t2 = n.args[1].type + + # handle scalar addition + if t1 == int and isinstance(t2, TensorType): + n.type = t2 + return n.type + + # handle scalar addition + elif t2 == int and isinstance(t1, TensorType): + n.type = t1 + return n.type + + # we bring the new types to the point where + # we can check for consistency + # any inconsistency would not have been caused + # by broadcasting at this point + (new_t1, new_t2) = broadcast_types(t1, t2) + + if new_t1 != t1 or new_t2 != t2: + n.meta["broadcast"] = True + n.meta[str(n.args[0])] = new_t1 + n.meta[str(n.args[1])] = new_t2 + + else: + n.meta["broadcast"] = False + + new_t1 = t1 if not n.meta["broadcast"] else new_t1 + new_t2 = t2 if not n.meta["broadcast"] else new_t2 + + # we check for consistency between the new types + if is_consistent(new_t1, new_t2): + # we return the less precise type because + # broadcasting may have happened + # for operands with shape [1,2,Dyn] and [1,2,1] + # we have to assign the node [1,2,Dyn] + if is_more_precise(new_t1, new_t2): + n.type = new_t2 + else: + n.type = new_t1 + return n.type + else: + raise TypeError( + f"Cannot add arguments {n.args[0]} ({n.args[0].type}) and {n.args[1]} ({n.args[1].type}) in node {n}." + f" Types should match " + ) + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, traced): + """ + The current getattr rule only handles the shape attribute + Can be extended to other attributes + The most representitive type we have is "Dyn" but the system + can be extended with more types, such as a type to represent shapes + """ + attr_name = n.args[1] + + if attr_name == "shape": + n.type = Dyn + else: + raise TypeError("Not yet implemented") + + # TODO. We leave it like this till we add a type to represent tensor sizes + return n.type + + +@register_inference_rule(torch.transpose) +def transpose_inference_rule(n: Node): + """ + We check that dimensions for the transpose operations + are within range of the tensor type of the node + """ + if n.target == torch.transpose: + assert isinstance(n.args[0], Node) + t = n.args[0].type + + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + dim1, dim2 = n.args[1], n.args[2] + + if t == Dyn: + n.type = Dyn + return n.type + + elif isinstance(t, TensorType): + if 0 <= dim1 < len(t.__args__) and 0 <= dim2 < len(t.__args__): + new_type = list(t.__args__) + new_type[dim1], new_type[dim2] = new_type[dim2], new_type[dim1] + final = TensorType(new_type) + n.type = get_greatest_upper_bound(n.type, final) + return n.type + else: + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) + else: + raise TypeError( + f"Cannot transpose {dim1} and {dim2} in type {t} for node {n}" + ) + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node): + """ + Without dynamism, the rule checks that the + product of the elements of the argument tensor + type is equal to the product of the elements + of the required shape. We gradualize this rule + by adding a case to handle fully dynamic input + as well as input where some of the tensor dimensions + are unknown. In this case we check for divisibility + """ + assert isinstance(n.args[0], Node) + t1 = n.args[0].type + + assert isinstance(n.args[1], list) + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) + + # if we do not know the original tensor dimension, + # we return the required dimension + if t1 == Dyn: + n.type = t2_type + return t2_type + + # if any of the dimensions are unknown, + # we check for divisibility + elif isinstance(t1, TensorType): + assert isinstance(t1, TensorType) + a = [e if e != Dyn else 1 for e in t1.__args__] + p1 = reduce(operator.mul, a) + p2 = reduce(operator.mul, t2) + if p1 % p2 == 0 or p2 % p1 == 0: + n.type = t2_type + return t2_type + else: + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + else: + raise TypeError(f"Cannot reshape in node {n} from {t1} to {t2_type}") + + +@register_inference_rule(BatchNorm2d) +def bn2d_inference_rule(n: Node, module_instance): + """ + Given a BatchNorm2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, x_3, x_4) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - t is consistent with t' + - x_2 is consistent with the module's num_features + - x_2' is consistent with the module's num_features + output type: the more precise type of t and t' + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + n.type = expand_to_tensor_dim(n.type, 4) + + # we check the conditions on the incoming argument + # and any existing annotation + # we also check for consistency between both annotations + if ( + is_consistent(arg_type.__args__[1], module_instance.num_features) + and is_consistent(n.type.__args__[1], module_instance.num_features) + and is_consistent(arg_type, n.type) + ): + # we choose the more precise type + # to be the node type + # so if an incoming argument has more type information + # we set this node's type to be the argument type + n.type = get_greatest_upper_bound(arg_type, n.type) + return n.type + else: + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) + + +def calculate_out_dimension(d_in, module_instance, index): + """ + For calculating h_in and w_out according to the conv2D documentation + """ + padding = ( + (module_instance.padding, module_instance.padding) + if isinstance(module_instance.padding, int) + else module_instance.padding + ) + kernel_size = ( + (module_instance.kernel_size, module_instance.kernel_size) + if isinstance(module_instance.kernel_size, int) + else module_instance.kernel_size + ) + stride = ( + (module_instance.stride, module_instance.stride) + if isinstance(module_instance.stride, int) + else module_instance.stride + ) + dilation = ( + (module_instance.dilation, module_instance.dilation) + if isinstance(module_instance.dilation, int) + else module_instance.dilation + ) + + DIMENSION_TYPES = (int, sympy.Symbol) + + if d_in == Dyn: + return Dyn + + elif isinstance(d_in, DIMENSION_TYPES): + n = d_in + 2 * padding[index] - dilation[index] * (kernel_size[index] - 1) - 1 + + return (n // stride[0]) + 1 + + else: + raise TypeError( + f"{d_in} in {module_instance} must be a number or Dyn. Received {type(d_in)}" + ) + + +def get_greatest_upper_bound(type1, type2): + """ + Get the most precise type that's consistent with the given types + """ + if type1 == Dyn: + return type2 + elif type2 == Dyn: + return type1 + elif isinstance(type1, TensorType) and isinstance(type2, TensorType): + if not is_consistent(type1, type2): + raise TypeError(f"Inconsistent types {type1}, {type2}") + gub = [ + t1 if is_more_precise(t1, t2) else t2 + for (t1, t2) in zip(type1.__args__, type2.__args__) + ] + return TensorType(tuple(gub)) + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance): + """ + Given a Conv2D instance and a node check the following conditions: + - the input type can be expanded to a size 4 tensor: t = (x_1, x_2, H, W) + - the current node type can be expanded to a size 4 tensor: t' = (x_1', x_2', x_3', x_4') + - x_2 is consistent with the module's in_channels + - let o = (x_1, out_channels, H_out, W_out) + then the output is the greatest upper bound of o and the existing node type t'. + """ + assert isinstance(n.args[0], Node) + n.args[0].type = expand_to_tensor_dim(n.args[0].type, 4) + arg_type = n.args[0].type + curr_node_type = expand_to_tensor_dim(n.type, 4) + + if is_consistent(arg_type.__args__[1], module_instance.in_channels): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType( + (arg_type.__args__[0], module_instance.out_channels, h_out, w_out) + ) + gub = get_greatest_upper_bound(new_type, curr_node_type) + n.type = gub + return n.type + else: + raise TypeError( + f"Cannot apply {module_instance} with input type {arg_type} and existing type {n.type} on {n}" + ) + + +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + n.type = get_greatest_upper_bound(n.args[0].type, n.type) + return n.type + + +def maxpool2d_check(typ, module_instance): + """ + Applies the maxpool2d shape information to the input + this affects the last two dimensions + """ + new_type_list = list(typ.__args__) + if len(new_type_list) == 4 or len(new_type_list) == 3: + w_in = new_type_list[-1] + h_in = new_type_list[-2] + + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + + new_type_list[-1] = w_out + new_type_list[-2] = h_out + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f"Wrong size {typ} for {module_instance}") + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool2d_inference_rule(n: Node, module_instance): + """ + Given a MaxPool2D instance and a node check the following conditions: + - Input size matches size 3 or 4 + - Current node type is consistent with the output type we will calculate + - Input size matches output size and the last two dimensions of the output + are w_out and h_out. The remaining dimensions are the same as the input + - Our final result is the greatest upper bound of the output we calculate + and the current node type. + """ + assert isinstance(n.args[0], Node) + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output = maxpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output, n.type) + return n.type + + +def linear_check(tensor_type, module_instance): + """ + Checks that an input tensor type satisfies the conditions for linear operation + and returns the output type based on in and out features given by module_instance + """ + if len(tensor_type.__args__) >= 2: + if is_consistent(module_instance.in_features, tensor_type.__args__[-1]): + new_type_args = list(tensor_type.__args__) + new_type_args[-1] = module_instance.out_features + return TensorType(tuple(new_type_args)) + else: + raise TypeError( + f"Inconsistent {module_instance.in_features} and {tensor_type.__args__[-1]} in {module_instance}" + ) + else: + raise TypeError(f"Type {tensor_type} must have rank 2 or more.") + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance): + """ + Applies the shape information to the input then gets the greatest upper bound + of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = linear_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(output_type, n.type) + return n.type + + +def adaptiveavgpool2d_check(tensor_type, module_instance): + output_size = module_instance.output_size + if isinstance(output_size, int): + output_size = [output_size, output_size] + elif isinstance(output_size, tuple): + output_size = list(output_size) + if output_size[0] is None: + output_size[0] = output_size[1] + if output_size[1] is None: + output_size[1] = output_size[0] + + new_type_list = list(tensor_type.__args__) + + if len(tensor_type.__args__) == 4 or len(tensor_type.__args__) == 3: + new_type_list[-1] = output_size[1] + new_type_list[-2] = output_size[0] + + return TensorType(tuple(new_type_list)) + + else: + raise TypeError(f"Tensor ranks must be 3 or 4. Got {tensor_type}") + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptiveavgpool2d_inference_rule(n: Node, module_instance): + """ + The input and output sizes should be the same except for the last + two dimensions taken from the input, which represent width and height + """ + assert isinstance(n.args[0], Node) + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + if isinstance(n.args[0].type, TensorType): + output_type = adaptiveavgpool2d_check(n.args[0].type, module_instance) + n.type = get_greatest_upper_bound(n.type, output_type) + return n.type + + +def flatten_check(tensor_type, start_dim, end_dim): + l = len(tensor_type.__args__) + + start_dim = l if start_dim == -1 else abs(start_dim) + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + if 0 <= start_dim <= (l - 1) and 0 <= end_dim <= l and start_dim < end_dim: + my_args = list(tensor_type.__args__) + lhs = my_args[0:start_dim] + rhs = my_args[end_dim:] + mid = my_args[start_dim:end_dim] + if Dyn in mid: + mid = [Dyn] + else: + mid = [reduce(operator.mul, my_args[start_dim:end_dim])] + new_type_list = lhs + mid + rhs + return TensorType(tuple(new_type_list)) + else: + raise TypeError( + f"Incompatible dimensions {start_dim}, {end_dim - 1} in type {tensor_type}" + ) + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node): + """ + Applies the flatten shape information to the input then gets the + greatest upper bound of the resulting type and the existing type + """ + assert isinstance(n.args[0], Node) + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if n.args[0].type == Dyn and isinstance(n.type, TensorType): + n.args[0].type = expand_to_tensor_dim(n.args[0].type, len(n.type.__args__)) + + if isinstance(n.args[0].type, TensorType): + output_type = flatten_check(n.args[0].type, start_dim, end_dim) + n.type = get_greatest_upper_bound(output_type, n.type) + + return n.type + + +class GraphTypeChecker: + def __init__(self, env, traced): + self.env = env + self.traced = traced + + def type_check(self): + """ + A gradual type checker for graphs + Effect: every node's field type will be + populated with a type after type-checking is done + """ + graph = self.traced.graph + + # type check every node with gradual type rules + # if any node does not type check return false + for n in graph.nodes: + self.type_check_node(n) + return True + + def type_check_node(self, n: Node): + """ + Type check a given fx node. + Current operations: + - Reshape + - Transpose + - Add + - Relu + - conv2d + - batchnorm2d + - flatten + - maxpool2d + - adaptiveavgpool2d + - linear + """ + if n.type is None: + n.type = Dyn + + if n.op == "placeholder": + return n.type + + elif n.op == "get_attr": + t = get_parameter(self.traced, n.target) # type: ignore[arg-type] + if isinstance(t.data, torch.Tensor): + n.type = TensorType(t.data.shape) + return n.type + + elif n.op == "call_function": + if n.target == getattr: + assert getattr in _INFERENCE_RULES + return _INFERENCE_RULES[n.target](n, self.traced) + + elif n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target](n) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)](n, module_instance) + else: + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") + + +@register_refinement_rule(Conv2d) +def conv_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(torch.nn.Linear) +def linear_refinement_rule(n: Node): + """ + The equality constraints are between the first dimension of + the input and output + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + res = [Equality(arg_type.__args__[0], n.type.__args__[0])] + return res + + +@register_refinement_rule(BatchNorm2d) +@register_refinement_rule(torch.nn.ReLU) +def all_eq(n: Node): + """ + For operations where the input shape is equal to the output shape + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[i], args2[i]) for i in range(len(args1))] + return res + + +@register_refinement_rule(torch.nn.AdaptiveAvgPool2d) +@register_refinement_rule(torch.nn.MaxPool2d) +def first_two_eq(n: Node): + """ + For operations where the first two dimensions of the input and output shape + are equal + """ + res = [] + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + args1 = arg_type.__args__ + args2 = n.type.__args__ + res = [Equality(args1[0], args2[0]), Equality(args1[1], args2[1])] + return res + + +@register_refinement_rule(torch.add) +@register_refinement_rule(operator.add) +def element_wise_eq(n: Node): + """ + For element-wise operations and handles broadcasting. + Note that after applying broadcasting to the arguments + we are able to determine if certain dimensions have not been broadcast + if they are symbolicallu equal. + + in this case, we can establish equality between those dimensions and the + corresponding output dimensions. + + Note that it takes two iterations for this result. One iteration to establish + equality between certain dimensions of the operands (requiring the whole solver + including unification) and another iteration to establish equality between the operands + and the resulting type, requiring another round of constraint generation and unificaiton. + """ + res = [] + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + arg_type1 = n.args[0].type + arg_type2 = n.args[1].type + if ( + isinstance(arg_type1, TensorType) + and isinstance(arg_type2, TensorType) + and isinstance(n.type, TensorType) + ): + args1, args2 = broadcast_types(arg_type1, arg_type2) + # by this point, we know that args1 and args2 are the same size. + a1 = args1.__args__ + a2 = args2.__args__ + a3 = n.type.__args__ + + # we would be here in the second iteration where we establish equality + # between operand type dimensions and the resulting type dimensions + r = [] + for x, y, z in zip(a1, a2, a3): + if x == y: + r.append(Equality(x, z)) + res = r + return res + + +@register_refinement_rule(torch.flatten) +def flatten_refinement_rule(n: Node): + """ + Generates equality constraints between the dimensions of the input and output + that will not be involved in the flatten operation + """ + assert isinstance(n.args[0], Node) + + eq_const = [] + + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + if isinstance(n.type, TensorType) and isinstance(n.args[0].type, TensorType): + l = len(n.type.__args__) + arg_type = n.args[0].type + start_dim = l if start_dim == -1 else start_dim + end_dim = l + end_dim + 1 if end_dim < 0 else end_dim + 1 + + for t1, t2 in zip(n.type.__args__[0:start_dim], arg_type.__args__[0:start_dim]): + eq_const.append(Equality(t1, t2)) + + for t1, t2 in zip(n.type.__args__[end_dim:], arg_type.__args__[end_dim:]): + eq_const.append(Equality(t1, t2)) + return eq_const + + +@register_algebraic_expressions_inference_rule(Conv2d) +def conv_rule(n: Node, module_instance): + """ + Represents the outout in terms of an algrbraic expression w.r.t + the input when possible + """ + assert isinstance(n.args[0], Node) + arg_type = n.args[0].type + if isinstance(arg_type, TensorType) and isinstance(n.type, TensorType): + w_in = arg_type.__args__[3] + h_in = arg_type.__args__[2] + h_out = calculate_out_dimension(h_in, module_instance, 0) + w_out = calculate_out_dimension(w_in, module_instance, 1) + new_type = TensorType((n.type.__args__[0], n.type.__args__[1], h_out, w_out)) + n.type = new_type + return new_type + + +class Refine: + """ + Symbolic shape inference. + Generates constraints over type variables. + Currently all constraints are equality constraints. + """ + + def __init__(self, traced): + self.constraints = [] + self.traced = traced + self.symbol_iter = itertools.count(start=0, step=1) + + def refine(self): + """ + Generates constraints for + every node in the graph based on + the operation. + """ + graph = self.traced.graph + for n in graph.nodes: + self.refine_node(n) + return True + + def symbolic_relations(self): + """ + Infers algebraic relations + """ + graph = self.traced.graph + for n in graph.nodes: + self.infer_symbolic_relations(n) + return True + + def replace_dyn_with_fresh_var(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if typ == Dyn: + new_symbol = Var(next(self.symbol_iter)) + return new_symbol + elif isinstance(typ, TensorType): + new_args = [self.replace_dyn_with_fresh_var(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.replace_dyn_with_fresh_var(t) for t in typ] + elif isinstance(typ, tuple): + return (self.replace_dyn_with_fresh_var(t) for t in typ) + else: + return typ + + def convert_to_sympy_symbols(self, typ): + """ + Replace all unknown types with fresh type variables. + """ + if isinstance(typ, Var): + return sympy.symbols(str(typ)) + elif isinstance(typ, TensorType): + new_args = [self.convert_to_sympy_symbols(a) for a in typ.__args__] + return TensorType(tuple(new_args)) + elif isinstance(typ, list): + return [self.convert_to_sympy_symbols(t) for t in typ] + elif isinstance(typ, tuple): + return (self.convert_to_sympy_symbols(t) for t in typ) + else: + return typ + + def refine_node(self, n: Node): + """ + Returns a list of equality constraints for + call_module and call_function nodes. + Models the relation between input and output dimensions + using constraints in case they are both tensors. + All operations used in resnet50 are defined. + """ + if n.type is None: + n.type = Dyn + + n.type = self.replace_dyn_with_fresh_var(n.type) + + if n.op == "call_function": + if n.target in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[n.target](n) + else: + pass + + if n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _REFINEMENT_RULES: + self.constraints += _REFINEMENT_RULES[type(module_instance)](n) + else: + pass + + if n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + + def infer_symbolic_relations(self, n: Node): + n.type = self.convert_to_sympy_symbols(n.type) + if n.op == "call_function": + if n.target in _RULES: + return _RULES[n.target](n) + else: + pass + + if n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _RULES: + return _RULES[type(module_instance)](n, module_instance) + else: + pass + + if n.op == "output": + + def get_node_type(a): + return a.type + + n.type = torch.fx.node.map_arg(n.args[0], get_node_type) + return n.type + + else: + pass + + +def get_parameter(traced, target: str): + """ + Returns the parameter given by ``target`` if it exists, + otherwise throws an error. + + See the docstring for ``get_submodule`` for a more detailed + explanation of this method's functionality as well as how to + correctly specify ``target``. + + Args: + target: The fully-qualified string name of the Parameter + to look for. (See ``get_submodule`` for how to specify a + fully-qualified string.) + + Returns: + torch.nn.Parameter: The Parameter referenced by ``target`` + + Raises: + AttributeError: If the target string references an invalid + path or resolves to something that is not an + ``nn.Parameter`` + """ + module_path, _, param_name = target.rpartition(".") + + mod: torch.nn.Module = traced.get_submodule(module_path) + + if not hasattr(mod, param_name): + raise AttributeError(mod._get_name() + " has no attribute `" + param_name + "`") + + param: torch.nn.Parameter = getattr(mod, param_name) + + return param diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/merge_matmul.py b/phivenv/Lib/site-packages/torch/fx/experimental/merge_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..3d03327149b360b907df5f5cd5392bd4fbd1c590 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/merge_matmul.py @@ -0,0 +1,177 @@ +# mypy: allow-untyped-defs +import itertools +import operator + +import torch +from torch.fx._symbolic_trace import symbolic_trace +from torch.fx.node import Node +from torch.fx.passes.tools_common import legalize_graph + + +def split_result_tensors( + result: torch.Tensor, inputs: list[torch.Tensor] +) -> tuple[torch.Tensor, ...]: + """ + A free function for use in the merge_matmul graph transformation below that + splits the output from a merged matmul into the individual results for each + input tensor. + + Arguments: + result: The merged matmul result tensor. + inputs: The list of inputs that were merged into one for the matmul. + + Returns: + List of matmul results for each input tensor. + """ + # When fx tracer is running, x.shape[0] will be torch.fx.Attribute but we + # need an int even when tracing + if isinstance(result, torch.fx.Proxy): + splits = [0] * len(inputs) + else: + splits = [x.shape[0] for x in inputs] + + return torch.split(result, splits) + + +def may_depend_on(a: Node, b: Node, search_depth: int = 6): + """ + Determine if one node depends on another in a torch.fx.Graph. + + Arguments: + a: The node that may have a dependency on b. + b: The node that a may have a dependency on. + search_depth: In the case of an indirect dependency, this function + searches upto this many nodes away in search of a + data dependency. If none is found, the function + makes the conservative assumption that there is a + dependency. + + Returns: + True if a may depend on b, False if it definitely does not. + """ + # Equivalence is defined as dependence. + if a == b: + return True + + # If a has no inputs, it cannot depend on b. + if len(a.all_input_nodes) == 0: + return False + + # If the search depth has been exhausted and no conclusion has been + # reached, assume that there is a data dependency. + if search_depth == 0: + return True + + # Recursively check all inputs of a. + for inp in a.all_input_nodes: + if may_depend_on(inp, b, search_depth - 1): + return True + + return False + + +def are_nodes_independent(nodes: list[Node]): + """ + Check if all of the given nodes are pairwise-data independent. + + Arguments: + nodes: The nodes to check for data dependencies. + + Returns: + True if any pair in nodes has a data dependency. + """ + # For each pair in nodes: + for i, j in itertools.combinations(nodes, 2): + if may_depend_on(i, j) or may_depend_on(j, i): + return False + + return True + + +def merge_matmul(in_mod: torch.nn.Module): + """ + A graph transformation that merges matrix multiplication operations that share the same right-hand + side operand into one large matrix multiplication. + ____ _________ _________ + ---- | | | | M| A * C | + M| A | T| B | * K| C | = |---------| + ---- , | | | | T| B * C | + K ---- --------- --------- + K R R + """ + gm = symbolic_trace(in_mod) + + rhs_users: dict[Node, list[Node]] = {} + lhs_users: dict[Node, list[Node]] = {} + + # Populate rhs_users and lhs_users - maps from LHS/RHS matrix multiply operands to + # the matmul of which they are the LHS/RHS. + for node in gm.graph.nodes: + if node.op != "call_function" or node.target is not torch.matmul: + continue + + lhs, rhs = node.args + + # TODO: Properly handle aliasing caused by get_attr. For now, + # use the attribute name as the operand if the node is a + # get_attr. + lhs = lhs.target if lhs.op == "get_attr" else lhs + rhs = rhs.target if rhs.op == "get_attr" else rhs + + lhs_users.setdefault(lhs, []).append(node) + rhs_users.setdefault(rhs, []).append(node) + + for rhs, mms in rhs_users.items(): + # There must be at least matmuls for a merge to make sense. + if len(mms) < 2: + continue + + # All matmuls must not depend on each other directly or indirectly + # in order for the merge to be possible. + if not are_nodes_independent(mms): + continue + + lhs_vals = [mm.args[0] for mm in mms] + + # Merge the matmul. + # Collect a list of LHS operands and the single RHS operand. + lhs = [gm.graph.get_attr(l) if isinstance(l, str) else l for l in lhs_vals] + rhs = gm.graph.get_attr(rhs) if isinstance(rhs, str) else rhs + + # Concatenate all the LHS operands. + merge_mm_cat = gm.graph.call_function(torch.cat, (lhs,), {}) + + # Multiply the concatenated LHS operands with the one RHS. This will produce + # the same results as all the individual matmuls involving rhs in the original graph, + # but they will all be concatenated together. + merge_mm = gm.graph.call_function( + torch.matmul, + ( + merge_mm_cat, + rhs, + ), + {}, + ) + + # Split the result of the merged matmul using the shapes of the LHS operands + # to ascertain how large each chunk should be. + merge_mm_split = gm.graph.call_function( + split_result_tensors, (merge_mm, lhs), {} + ) + merge_mm_res = [ + gm.graph.call_function(operator.getitem, (merge_mm_split, out), {}) + for out in range(len(lhs)) + ] + + # Replace all uses of the original, unmerged matmuls with the equivalent split chunk from the merged matmul. + for old, new in zip(mms, merge_mm_res): + old.replace_all_uses_with(new) + gm.graph.erase_node(old) + + # All of the new nodes created above were inserted at the end, so we need to sort + # the nodes topologically to make sure all definitions precede uses. + legalize_graph(gm) + + gm.recompile() + gm.graph.lint() + return gm diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/meta_tracer.py b/phivenv/Lib/site-packages/torch/fx/experimental/meta_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb90a686ffdc0a15007a8a708ca9dba9f592fbf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/meta_tracer.py @@ -0,0 +1,311 @@ +# mypy: allow-untyped-defs +import builtins +import functools +import warnings +from typing import Any, Callable, Optional, Union + +import torch +import torch.fx + + +def embedding_override(self, input): + return torch.empty(*input.shape, self.weight.shape[-1], device="meta") + + +def nn_layernorm_override(self, input): + return input + + +def torch_relu_override(x): + return x + + +def torch_nn_relu_override(self, x): + return x + + +def functional_relu_override(x, inplace=False): + assert not inplace, "dont support inplace functional.relu for metatensor analysis" + return x + + +def torch_where_override(condition, x, y): + # torch.where returns the broadcasted tensor of condition, x, and y, + # so hack it by using addition + return condition.to(device="meta") + x.to(device="meta") + y.to(device="meta") + + +def torch_abs_override(input, *, out=None): + assert out is None, "Dont support in-place abs for MetaTensor analysis" + return input + + +manual_meta_overrides: dict[Callable, Callable] = { + torch.nn.Embedding: embedding_override, + torch.nn.LayerNorm: nn_layernorm_override, + torch.relu: torch_relu_override, + torch.nn.functional.relu: functional_relu_override, + torch.nn.ReLU: torch_nn_relu_override, + torch.where: torch_where_override, + torch.abs: torch_abs_override, +} + + +def gen_constructor_wrapper(target): + @functools.wraps(target) + def wrapper(*args, **kwargs): + proxy = None + + def check_has_proxy(v): + if isinstance(v, torch.fx.Proxy): + nonlocal proxy + proxy = v + + torch.fx.node.map_aggregate(args, check_has_proxy) + torch.fx.node.map_aggregate(kwargs, check_has_proxy) + + if proxy is not None: + return proxy.tracer.create_proxy("call_function", target, args, kwargs) + else: + return target(*args, **kwargs) + + return wrapper, target + + +class MetaProxy(torch.fx.Proxy): + def install_tensor_meta(self, tensor_meta): + self._tensor_meta = tensor_meta + + def size(self, dim=None): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.size(*[dim] if dim else []) + return self.tracer.create_proxy( + "call_method", "size", (self, dim) if dim else (self,), {} + ) + + def dim(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.dim() + return self.tracer.create_proxy("call_method", "dim", (self,), {}) + + @property + def shape(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.shape + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "shape"), {} + ) + + @property + def dtype(self): + if hasattr(self, "_tensor_meta") and self._tensor_meta is not None: + return self._tensor_meta.dtype + return self.tracer.create_proxy( + "call_function", builtins.getattr, (self, "dtype"), {} + ) + + @property + def device(self): + # Hack so we can track when devices are used. During meta-tensor propagation, + # replace these values with a constant 'meta' + return MetaDeviceAttribute(self, "device") + + def __getattr__(self, k): + if k == "_tensor_meta": + return self.__getattribute__(k) + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return MetaAttribute(self, k) + + +class MetaAttribute(MetaProxy): + def __init__(self, root, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node = None + + @property + def node(self): # type: ignore[override] + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) + + +class MetaDeviceAttribute(MetaAttribute): + pass + + +def proxys_to_metas(v): + if isinstance(v, MetaDeviceAttribute): + return "meta" + if isinstance(v, torch.fx.Proxy): + assert isinstance(v, MetaProxy), f"Expected MetaProxy but got {type(v)}" + assert hasattr(v, "_tensor_meta"), "MetaProxy does not have an associated meta" + return v._tensor_meta + return v + + +class MetaTracer(torch.fx.Tracer): + allow_insert_stateless_mods: bool = True + + _TORCH_METHODS_TO_PATCH = ["arange", "zeros", "ones", "full_like", "eye"] + + def create_proxy( + self, + kind, + target, + args, + kwargs, + name=None, + type_expr=None, + proxy_factory_fn=None, + ): + rv = super().create_proxy( + kind, target, args, kwargs, name, type_expr, proxy_factory_fn + ) + + if kind == "placeholder" and target in self.meta_args: + rv.install_tensor_meta(self.meta_args[target]) + return rv + + if target in self.orig_fns: + # NOTE: tensor constructors in PyTorch define the `device` argument as + # *kwargs-only*. That is why this works. If you add methods to + # _TORCH_METHODS_TO_PATCH that do not define `device` as kwarg-only, + # this will break and you will likely see issues where we cannot infer + # the size of the output. + if "device" in kwargs: + kwargs["device"] = "meta" + + try: + args_metas = torch.fx.node.map_aggregate(args, proxys_to_metas) + kwargs_metas = torch.fx.node.map_aggregate(kwargs, proxys_to_metas) + + if kind == "call_function": + meta_target = manual_meta_overrides.get(target, target) + meta_out = meta_target(*args_metas, **kwargs_metas) + elif kind == "call_method": + meta_target = getattr(args_metas[0], target) # type: ignore[index] + meta_out = meta_target(*args_metas[1:], **kwargs_metas) # type: ignore[index] + elif kind == "call_module": + assert hasattr(self, "orig_forward") + self._disable_module_getattr = True + try: + mod = self.root.get_submodule(target) + mod_type = type(mod) + if mod_type in manual_meta_overrides: + meta_out = manual_meta_overrides[mod_type]( + mod, *args_metas, **kwargs_metas + ) # type: ignore[misc, arg-type] + else: + meta_out = self.orig_forward(*args_metas, **kwargs_metas) + finally: + self._disable_module_getattr = False + elif kind == "get_attr": + self._disable_module_getattr = True + try: + attr_itr = self.root + atoms = target.split(".") + for atom in atoms: + attr_itr = getattr(attr_itr, atom) + assert isinstance(attr_itr, torch.Tensor) + meta_out = attr_itr.to(device="meta") + finally: + self._disable_module_getattr = False + else: + return rv + + # TODO + assert isinstance(rv, torch.fx.Proxy), "Dont support composite output yet" + rv.install_tensor_meta(meta_out) + except Exception as e: + warnings.warn(f"Could not compute metadata for {kind} target {target}: {e}") + + return rv + + def getattr(self, attr, attr_val, parameter_proxy_cache): + if getattr(self, "_disable_module_getattr", False): + return attr_val + else: + return super().getattr(attr, attr_val, parameter_proxy_cache) + + def call_module(self, m, forward, args, kwargs): + self.orig_forward = forward + return super().call_module(m, forward, args, kwargs) + + def _insert_module_as_submodule(self, mod: torch.nn.Module) -> str: + """ + Helper method which tries to insert a module that was not declared as submodule. + """ + idx = 0 + mod_name = mod.__class__.__name__.lower() + path = f"{mod_name}_{idx}" + while hasattr(self.root, path): + path = f"{mod_name}_{idx}" + idx += 1 + + self.root.add_module(path, mod) + return path + + def path_of_module(self, mod: torch.nn.Module) -> str: + try: + return super().path_of_module(mod) + except NameError: + if ( + self.allow_insert_stateless_mods + and len(list(mod.parameters())) == 0 + and len(list(mod.buffers())) == 0 + ): + path = self._insert_module_as_submodule(mod) + self.prev_module = path + return path + raise + + def proxy(self, node): + return MetaProxy(node, self) + + def trace(self, root, meta_args: dict[str, torch.Tensor], concrete_args=None): # type: ignore[override] + assert isinstance(meta_args, dict) + self.meta_args = meta_args + + self.patched_torch_methods = { + target: gen_constructor_wrapper(getattr(torch, target)) + for target in self._TORCH_METHODS_TO_PATCH + } + self.orig_fns = set() + + for name, (wrapper, orig) in self.patched_torch_methods.items(): + setattr(torch, name, wrapper) + self.orig_fns.add(orig) + + try: + graph = super().trace(root, concrete_args) + graph._tracer_extras = {"meta_args": meta_args} + return graph + finally: + for name, (_, orig) in self.patched_torch_methods.items(): + setattr(torch, name, orig) + + +def symbolic_trace( + root: Union[torch.nn.Module, Callable[..., Any]], + meta_args: Optional[dict[str, torch.Tensor]] = None, + concrete_args: Optional[dict[str, Any]] = None, +) -> torch.fx.GraphModule: + tracer = MetaTracer() + graph = tracer.trace(root, meta_args, concrete_args) # type: ignore[arg-type] + name = ( + root.__class__.__name__ if isinstance(root, torch.nn.Module) else root.__name__ + ) + gm = torch.fx.GraphModule(tracer.root, graph, name) + return gm diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b19ff219e8597c1c212e9773520f35a6a619756 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d4c2170b712373974607f4dd1f79451a762ae9a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4dd6f84b96ef7879b1da35db6148db0db3ee07c8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_generator.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad7b56d3bbafe6813e03ec752a2b1dd150ad2648 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/constraint_transformation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7dea56f41c9aa86080813e10a6a62be62dea2f7b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/operation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10a693cfd7dae01d555dcc12660f819277d9404c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/transform_to_z3.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4149581deb466aa3aae27ab88564fadd1928191 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/util.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78bced9972e17128d65d13eb014d4c729a2167f7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/__pycache__/z3_types.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py new file mode 100644 index 0000000000000000000000000000000000000000..02e6d6cf475ffce53ce6a63242255811ff082648 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint.py @@ -0,0 +1,643 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType + + +class Constraint: + pass + + +class Conj(Constraint): + def __init__(self, conjuncts): + """ + :param conjuncts: Conjunction of constraints + """ + self.conjucts = conjuncts + + def __eq__(self, other): + if isinstance(other, Conj): + return self.conjucts == other.conjucts and self.conjucts == other.conjucts + else: + return False + + def __repr__(self): + return f"And({self.conjucts})" + + +class Disj(Constraint): + def __init__(self, disjuncts): + """ + :param disjuncts: Disjunction of constraints + """ + self.disjuncts = disjuncts + + def __eq__(self, other): + if isinstance(other, Disj): + return ( + self.disjuncts == other.disjuncts and self.disjuncts == other.disjuncts + ) + else: + return False + + def __repr__(self): + return f"Or({self.disjuncts})" + + +class Prod(Constraint): + def __init__(self, products): + """ + :param products: lists of dimensions to multiply + """ + self.products = products + + def __eq__(self, other): + if isinstance(other, Prod): + return self.products == other.products and self.products == other.products + else: + return False + + def __repr__(self): + return f"Product({self.products})" + + +class T(Constraint): + """ + True + """ + + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, T) + + def __repr__(self): + return "True" + + +class F(Constraint): + """ + False + """ + + def __init__(self) -> None: + pass + + def __eq__(self, other): + return isinstance(other, F) + + def __repr__(self): + return "False" + + +class BinaryConstraint(Constraint): + """ + Represents all binary operations + """ + + def __init__(self, lhs, rhs, op): + """ + :param lhs: lhs of the constraint + :param rhs: rhs of the constraint + :param op: string representing the operation + """ + self.lhs = lhs + self.rhs = rhs + self.op = op + + def __eq__(self, other): + if isinstance(other, BinaryConstraint): + return ( + self.lhs == other.lhs and self.rhs == other.rhs and self.op == other.op + ) + else: + return False + + def __repr__(self): + return f"({self.lhs} {self.op} {self.rhs})" + + +class BinConstraintT(BinaryConstraint): + """ + Binary constraints about tensors + """ + + def __init__(self, lhs, rhs, op): + assert (isinstance(lhs, (TVar, TensorType, int)) or lhs == Dyn) and ( + isinstance(rhs, (TVar, TensorType, int)) or rhs == Dyn + ) + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + +class BinConstraintD(BinaryConstraint): + """ + Binary constraints about dimensions + """ + + def __init__(self, lhs, rhs, op): + assert is_algebraic_expression(lhs) or is_dim(lhs) or is_bool_expr(lhs) + assert is_algebraic_expression(rhs) or is_dim(rhs) or is_bool_expr(rhs) + + super().__init__(lhs, rhs, op) + + def __eq__(self, other): + return super().__eq__(other) + + +class TGreatestUpperBound(Constraint): + """ + Greatest Upper bound for tensors with dynamic type + """ + + def __init__(self, res, rhs1, rhs2): + """ + :param res: tensor variable that stores the result of the outout + :param rhs1: tensor or tensor variable + :param rhs2: tensor or tensor variabke + """ + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f"{self.res} = {self.rhs1}\u2294*{self.rhs2}" + + def __eq__(self, other): + if isinstance(other, TGreatestUpperBound): + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) + else: + return False + + +class DGreatestUpperBound(Constraint): + """ + Greatest Upper bound for dimensions + """ + + def __init__(self, res, rhs1, rhs2): + """ + :param res: Dimension variable to store the result + :param rhs1: dimension variable 1 + :param rhs2: dimension variable 2 + """ + assert is_dim(res) + assert is_dim(rhs1) + assert is_dim(rhs2) + + self.res = res + self.rhs1 = rhs1 + self.rhs2 = rhs2 + + def __repr__(self): + return f"{self.res} = {self.rhs1}\u2294{self.rhs2}" + + def __eq__(self, other): + if isinstance(other, DGreatestUpperBound): + return ( + self.res == other.res + and self.rhs1 == other.rhs1 + and self.rhs2 == other.rhs2 + ) + else: + return False + + +class CanReshape(Constraint): + """ + can_reshape constraint + """ + + def __init__(self, src, target): + """ + :param src: tensor variable + :param target: tensor + """ + self.src = src + self.target = target + + def __repr__(self): + return f"can-reshape({self.src}, {self.target})" + + def __eq__(self, other): + if isinstance(other, CanReshape): + return self.src == other.src and self.target == other.target + else: + return False + + +class IndexSelect(Constraint): + def __init__(self, tensor_size, input_var, dim_replace, index, output): + """ + Args: + input_var: input to index_select + tensor_size: tensor size we are considering + dim_replace: the dimension of the output at "index" + index: location of the dimensions to replace in the input + output: variable to store the result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(dim_replace, DVar) or dim_replace == Dyn + assert isinstance(index, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.dim_replace = dim_replace + self.index = index + self.output = output + + def __repr__(self): + return ( + f" {self.output} = " + f"IndexSelect({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.dim_replace}, " + f"{self.index})" + ) + + def __eq__(self, other): + if isinstance(other, IndexSelect): + return ( + self.tensor_size == other.tensor_size + and self.dim_replace == other.dim_replace + and self.index == other.index + and self.output == other.output + and self.input_var == other.input_var + ) + else: + return False + + +class Transpose(Constraint): + def __init__(self, tensor_size, input_var, index1, index2, output): + """ + Args: + tensor_size: current tensor size + input_var: variable to hold input + index1: dimension 1 + index2: dimension 2 + output: output that stores result + """ + assert isinstance(input_var, TVar) + assert isinstance(output, TVar) + assert isinstance(index1, int) + assert isinstance(index2, int) + + self.input_var = input_var + self.tensor_size = tensor_size + self.index1 = index1 + self.index2 = index2 + self.output = output + + def __repr__(self): + return ( + f" {self.output} = " + f"Transpose({self.input_var}, " + f"tensor_size: {self.tensor_size}, " + f"{self.index1}, " + f"{self.index2})" + ) + + def __eq__(self, other): + if isinstance(other, Transpose): + return ( + self.tensor_size == other.tensor_size + and self.index1 == other.index1 + and self.index2 == other.index2 + and self.output == other.output + and self.input_var == other.input_var + ) + else: + return False + + +class GetItem(Constraint): + def __init__(self, tensor_size, index, res, input_var): + """ + Constraint for getting item given a tensor size + :param tensor_size: actual number + :param index: actual number representing the index + :param res: dimension variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, DVar) + + self.res = res + self.tensor_size = tensor_size + self.index = index + self.input_var = input_var + + def __repr__(self): + return f" {self.res} = GetItem({self.input_var}, tensor_size: {self.tensor_size}, {self.index})" + + def __eq__(self, other): + if isinstance(other, GetItem): + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index == other.index + and self.input_var == other.input_var + ) + else: + return False + + +class GetItemTensor(Constraint): + def __init__(self, tensor_size, index_tuple, res, input_var): + """ + Constraint for getting item given a tensor size + However, when the argument is a tuple, we will + expect a tensor + :param tensor_size: actual number representing the rank + :param index_tuple: tuple for indexing + :param res: tensor variable to carry the item we get + :param input_var: a tensor variable from which we will get item + """ + assert isinstance(res, TVar) + + self.res = res + self.tensor_size = tensor_size + self.index_tuple = index_tuple + self.input_var = input_var + + def __repr__(self): + return f" {self.res} = GetItemT({self.input_var}, tensor_size: {self.tensor_size}, {self.index_tuple})" + + def __eq__(self, other): + if isinstance(other, GetItemTensor): + return ( + self.res == other.res + and self.tensor_size == other.tensor_size + and self.index_tuple == other.index_tuple + and self.input_var == other.input_var + ) + else: + return False + + +class CalcConv(Constraint): + def __init__( + self, + conv_result, + input_var, + c_out, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): + """ + :param conv_result: the convolution result + :param input_var: input to convolution + :param c_out: output chanel type + :param kernel: kernel tuple + """ + self.conv_result = conv_result + self.input_var = input_var + self.c_out = c_out + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return ( + f"{self.conv_result} =" + f" calc-conv({self.input_var}," + f" {self.c_out}, {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) + + def __eq__(self, other): + if isinstance(other, CalcConv): + return ( + self.conv_result == other.conv_result + and self.input_var == other.input_var + and self.c_out == other.c_out + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation + and self.matching_constraint == other.matching_constraint + ) + else: + return False + + +class CalcMaxPool(Constraint): + def __init__( + self, + maxpool_result, + input_var, + kernel, + padding, + stride, + dilation, + matching_constraint_vars, + ): + """ + :param maxpool_result: the result of maxpool + :param input_var: input to convolution + :param kernel: kernel tuple + """ + self.maxpool_result = maxpool_result + self.input_var = input_var + self.kernel = kernel + self.padding = padding + self.stride = stride + self.dilation = dilation + self.matching_constraint = matching_constraint_vars + + def __repr__(self): + return ( + f"{self.maxpool_result} =" + f" calc-maxpool({self.input_var}," + f" {self.kernel}, " + f"{self.padding}, {self.stride}," + f" {self.dilation})" + ) + + def __eq__(self, other): + if isinstance(other, CalcMaxPool): + return ( + self.maxpool_result == other.maxpool_result + and self.input_var == other.input_var + and self.kernel == other.kernel + and self.padding == other.padding + and self.stride == other.stride + and self.dilation == other.dilation + and self.matching_constraint == other.matching_constraint + ) + else: + return False + + +class ApplyBroadcasting(Constraint): + def __init__(self, res1, res2, input1, input2): + """ + :param res1: resulting tensor 1 + :param res2: resulting tensor 2 + :param input1: tensor variable 1 + :param input2: tensor variable 2 + """ + self.res1 = res1 + self.res2 = res2 + self.input1 = input1 + self.input2 = input2 + + def __eq__(self, other): + if isinstance(other, ApplyBroadcasting): + return ( + self.res1 == other.res1 + and self.res2 == other.res2 + and self.input1 == other.input1 + and self.input2 == other.input2 + ) + else: + return False + + def __repr__(self): + return ( + f"{self.res1}, {self.res2} =" + f" apply-broadcasting({self.input1}," + f" {self.input2})" + ) + + +class CalcProduct(Constraint): + """ + Given correct dimensions, calculate the product for flatten accounting for Dyn + """ + + def __init__(self, start, end, flattened, dims_to_flatten): + """ + :param start: start index + :param end: end index + :param flattened: variable to store the product + :param dims_to_flatten: the type which we will flatten + """ + assert isinstance(dims_to_flatten, list) + assert isinstance(flattened, TVar) + assert isinstance(start, int) + assert isinstance(end, int) + + self.start = start + self.end = end + self.dims_to_flatten = dims_to_flatten + self.flattened = flattened + + def __eq__(self, other): + if isinstance(other, CalcProduct): + return ( + self.start == other.start + and self.end == other.end + and self.dims_to_flatten == other.dims_to_flatten + and self.flattened == other.flattened + ) + + else: + return False + + def __repr__(self): + return f"{self.flattened} = CalcProduct({self.start}, {self.end}, {self.dims_to_flatten})" + + +class TVar: + """ + Tensor variable with no tensor constructor + """ + + def __init__(self, tvar): + """ + :param tvar: tensor variable + """ + self.tvar = tvar + + def __repr__(self): + return f"TV({self.tvar})" + + def __eq__(self, other): + if isinstance(other, TVar): + return self.tvar == other.tvar + else: + return False + + +class DVar: + """ + Dimension variable + """ + + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f"DV({self.c})" + + def __eq__(self, other): + if isinstance(other, DVar): + return self.c == other.c + else: + return False + + +class BVar: + """ + Boolean variable + """ + + def __init__(self, c): + """ + :param c: character or number + """ + self.c = c + + def __repr__(self): + return f"BV({self.c})" + + def __eq__(self, other): + if isinstance(other, BVar): + return self.c == other.c + else: + return False + + +def is_algebraic_expression(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_add, op_sub, op_div, op_mul, op_mod] + else: + return isinstance(constraint, Prod) + + +def is_bool_expr(constraint): + if isinstance(constraint, BinConstraintD): + return constraint.op in [op_gt, op_lt, op_neq, op_eq] + else: + return isinstance(constraint, (BVar, Conj, Disj)) + + +def is_dim(d): + return isinstance(d, (DVar, int)) or d == Dyn diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py new file mode 100644 index 0000000000000000000000000000000000000000..099db84593a53e0c84ea384bbe397d2259983ecc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_generator.py @@ -0,0 +1,1562 @@ +# mypy: allow-untyped-defs +import operator +import warnings +from collections.abc import Iterable +from typing import Callable, TypeVar +from typing_extensions import ParamSpec + +import torch +from torch.fx._symbolic_trace import _assert_is_none +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + BinConstraintT, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_matching, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_bvar, + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, + gen_tvar, +) +from torch.fx.node import Node, Target +from torch.fx.tensor_type import Dyn, TensorType +from torch.nn.modules.batchnorm import BatchNorm2d +from torch.nn.modules.conv import Conv2d + + +_T = TypeVar("_T") +_P = ParamSpec("_P") + +_INFERENCE_RULES: dict[Target, Callable] = {} + +MAX_TENSOR_RANK = 4 + +__all__ = [ + "ConstraintGenerator", + "adaptive_inference_rule", + "add_layer_norm_constraints", + "add_linear_constraints", + "arange_inference_rule", + "assert_inference_rule", + "batchnorm_inference_rule", + "bmm_inference_rule", + "broadcasting_inference_rule", + "conv2d_inference_rule", + "cumsum_inference_rule", + "embedding_inference_rule", + "embedding_inference_rule_functional", + "eq_inference_rule", + "equality_inference_rule", + "expand_inference_rule", + "flatten_inference_rule", + "full_inference_rule", + "gen_broadcasting_constraints", + "gen_embedding_rules", + "gen_layer_norm_constraints", + "generate_flatten_constraints", + "get_attr_inference_rule", + "getitem_inference_rule", + "gt_inference_rule", + "index_select_inference_rule", + "layer_norm_functional", + "layer_norm_inference_rule", + "linear_constraints", + "linear_inference_rule", + "lt_inference_rule", + "masked_fill_inference_rule", + "maxpool_inference_rule", + "neq_inference_rule", + "range_check", + "register_inference_rule", + "relu_inference_rule", + "reshape_inference_rule", + "size_inference_rule", + "tensor_inference_rule", + "torch_dim_inference_rule", + "torch_linear_inference_rule", + "transpose_inference_rule", + "type_inference_rule", + "view_inference_rule", +] + + +def register_inference_rule( + call_target: Target, +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: + def register(fn: Callable[_P, _T]) -> Callable[_P, _T]: + if call_target in _INFERENCE_RULES: + raise RuntimeError(f"Inference rule already registered for {call_target}!") + _INFERENCE_RULES[call_target] = fn + return fn + + return register + + +def generate_flatten_constraints(start_dim, end_dim, input, flattened, n, counter): + d, counter = gen_tensor_dims(n, counter) + c1 = BinConstraintT(input, TensorType(d), op_eq) + start_dim = n if start_dim == -1 else abs(start_dim) + end_dim = n + end_dim + 1 if end_dim < 0 else end_dim + 1 + c2 = CalcProduct(start_dim, end_dim, flattened, d) + nat_constraints = gen_nat_constraints(d) + return Conj([c1, c2, *nat_constraints]), counter + + +@register_inference_rule(getattr) +def get_attr_inference_rule(n: Node, symbols, constraints, counter): + """ + If the attribute is "device" then the tensor shape is preserved + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], str) + output, counter = gen_tvar(counter) + symbols[n] = output + + input = symbols[n.args[0]] + attr = n.args[1] + + if attr == "device": + return [BinConstraintT(input, output, op_eq)], counter + else: + raise NotImplementedError("Not yet implemented") + + +@register_inference_rule(torch.bmm) +def bmm_inference_rule(n: Node, symbols, constraints, counter): + """ + Constraints that match the input to a size 3 tensor + and switch the dimensions according to the rules + of batch multiplication + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + bmm_output, counter = gen_tvar(counter) + symbols[n] = bmm_output + + bmm_input1 = symbols[n.args[0]] + bmm_input2 = symbols[n.args[1]] + + dims_input1, counter = gen_tensor_dims(3, counter) + dims_input2, counter = gen_tensor_dims(3, counter) + + inputs_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_output, Dyn, op_eq), + ] + ) + + input1_dyn = Conj( + [ + BinConstraintT(bmm_input1, Dyn, op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input2[0], Dyn, dims_input2[2]]), op_eq + ), + ] + ) + + input2_dyn = Conj( + [ + BinConstraintT(bmm_input2, Dyn, op_eq), + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT( + bmm_output, TensorType([dims_input1[0], dims_input1[1], Dyn]), op_eq + ), + ] + ) + + consistency_constraints = [ + BinConstraintD(dims_input1[0], dims_input2[0], op_consistency) + ] + + batch_size, counter = gen_dvar(counter) + + inputs_are_tensors = Conj( + [ + BinConstraintT(bmm_input1, TensorType(dims_input1), op_eq), + BinConstraintT(bmm_input2, TensorType(dims_input2), op_eq), + BinConstraintT( + bmm_output, + TensorType([batch_size, dims_input1[1], dims_input2[2]]), + op_eq, + ), + *consistency_constraints, + DGreatestUpperBound(batch_size, dims_input1[0], dims_input2[0]), + ] + ) + + return [Disj([inputs_dyn, input1_dyn, input2_dyn, inputs_are_tensors])], counter + + +@register_inference_rule("index_select") +def index_select_inference_rule(n: Node, symbols, constraints, counter): + """ + We constrain the second argument to a vector or Dyn. + The output replaces the input with the shape of the vector + at the position given by the index (first argument) + """ + # print(n.args) + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], Node) + + index_select, counter = gen_tvar(counter) + symbols[n] = index_select + + dims, counter = gen_tensor_dims(1, counter) + + # equality constraint + is_size_1 = BinConstraintT(symbols[n.args[2]], TensorType(dims), op_eq) + is_dyn = BinConstraintT(symbols[n.args[2]], Dyn, op_eq) + + c2 = Conj( + [ + is_size_1, + Disj( + [ + IndexSelect( + i + 1, symbols[n.args[0]], dims[0], n.args[1], index_select + ) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + c3 = Conj( + [ + is_dyn, + Disj( + [ + IndexSelect(i + 1, symbols[n.args[0]], Dyn, n.args[1], index_select) + for i in range(MAX_TENSOR_RANK) + ] + ), + ] + ) + + return [Disj([c2, c3])], counter + + +@register_inference_rule("expand") +def expand_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the exact constraints as we do for tensor additions but we constraint + the rank of this expression to be equal to len(n.args[1:]) so that only + those cases get considered for the output + """ + assert isinstance(n.args[0], Node) + + # define the output for expand + expand, counter = gen_tvar(counter) + symbols[n] = expand + + # since we do not have two nodes here, we will construct an argument variable + e1 = symbols[n.args[0]] + e2, counter = gen_tvar(counter) + + e2_nat_constraints = [] + for arg in n.args[1:]: + assert isinstance(arg, (Node, int)) + if isinstance(arg, Node): + assert isinstance(symbols[arg], DVar) + e2_nat_constraints.append(BinConstraintD(0, symbols[arg], op_leq)) + + e2_constraint = BinConstraintT( + e2, + TensorType( + [arg if isinstance(arg, int) else symbols[arg] for arg in n.args[1:]] + ), + op_eq, + ) + + constraints, counter = gen_broadcasting_constraints( + e1, e2, symbols, counter, expand + ) + + # constraint the output size + dims, counter = gen_tensor_dims(len(n.args[1:]), counter) + nat_constraints = gen_nat_constraints(dims) + c = [ + BinConstraintT(expand, TensorType(dims), op_eq), + *nat_constraints, + e2_constraint, + *e2_nat_constraints, + ] + constraints += c + + return constraints, counter + + +@register_inference_rule(torch.nn.functional.gelu) +@register_inference_rule(torch.nn.functional.dropout) +@register_inference_rule(torch.nn.functional.softmax) +@register_inference_rule("detach") +@register_inference_rule("to") +@register_inference_rule("int") +@register_inference_rule("long") +@register_inference_rule("contiguous") +@register_inference_rule(torch.ones) +@register_inference_rule(torch.zeros) +def equality_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + output, counter = gen_tvar(counter) + symbols[n] = output + + if isinstance(n.args[0], Node): + input = symbols[n.args[0]] + if isinstance(input, TVar): + return [BinConstraintT(input, output, op_eq)], counter + + # then we have dimension variables + else: + for arg in n.args: + assert isinstance(symbols[arg], DVar) + my_size = [symbols[arg] for arg in n.args] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + + elif isinstance(n.args[0], tuple): + # then the tuple is the size + assert len(n.args[0]) <= 4 + my_size = [symbols[arg] for arg in n.args[0]] + return [BinConstraintT(output, TensorType(my_size), op_eq)], counter + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule("transpose") +def transpose_inference_rule(n: Node, symbols, constraints, counter): + """ + Can be considered as a sequence of two index selects, so we generate constraints accordingly + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], int) + assert isinstance(n.args[2], int) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + assert isinstance(from_arg, TVar) + + # input and output are dyn + is_dyn = Conj( + [BinConstraintT(from_arg, Dyn, op_eq), BinConstraintT(output, Dyn, op_eq)] + ) + + # or input is a tensor and we actually do the replacement + c3 = Disj( + [ + Transpose(i + 1, from_arg, n.args[1], n.args[2], output) + for i in range(MAX_TENSOR_RANK) + ] + ) + + return [Disj([is_dyn, c3])], counter + + +@register_inference_rule("type_as") +def type_inference_rule(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + output, counter = gen_tvar(counter) + symbols[n] = output + + from_arg = symbols[n.args[0]] + to_arg = symbols[n.args[1]] + + assert isinstance(from_arg, TVar) + assert isinstance(to_arg, TVar) + + return [ + BinConstraintT(from_arg, to_arg, op_consistency), + BinConstraintT(output, to_arg, op_eq), + ], counter + + +@register_inference_rule("masked_fill_") +def masked_fill_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to addition. For now we implement the constraints when + the argument is a boolean tensor. There is also a case for when + it is a condition. We will leave this out for now. + """ + + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], Node) + + # We will retrieve the type variables from the symbol table + # and confirm they are tensor variables + + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + if isinstance(e1, TVar) and isinstance(e2, TVar): + masked_fill_tensor, counter = gen_tvar(counter) + symbols[n] = masked_fill_tensor + return gen_broadcasting_constraints( + e1, e2, symbols, counter, masked_fill_tensor + ) + else: + raise NotImplementedError("Not yet implemented") + + +@register_inference_rule(torch.nn.functional.embedding) +def embedding_inference_rule_functional(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + embedding_dim_weights = symbols[n.args[1]] + + # will treat this as a static shape. So we will not use matching. + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT( + embedding_dim_weights, TensorType(weight_dims), op_eq + ) + embedding_dim = weight_dims[1] + constraints, counter = gen_embedding_rules(n, symbols, embedding_dim, counter) + return [equality_constraint] + constraints, counter + + +@register_inference_rule(torch.nn.modules.sparse.Embedding) +def embedding_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + The output shape differs from the input shape in the last dimension + """ + assert isinstance(n.args[0], Node) + return gen_embedding_rules(n, symbols, module_instance.embedding_dim, counter) + + +def gen_embedding_rules(n: Node, symbols, embedding_dim, counter): + embedding_output, counter = gen_tvar(counter) + symbols[n] = embedding_output + embedding_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(embedding_input, Dyn, op_eq) + output_dyn = BinConstraintT(embedding_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + + for i in range(1, MAX_TENSOR_RANK): + new_dims, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims) + + # we consider all tensor sizes and append embedding_dim to the end of the output dimension in all cases + c_tensor_i = Conj( + [ + BinConstraintT(embedding_input, TensorType(new_dims), op_eq), + BinConstraintT( + embedding_output, TensorType(new_dims + [embedding_dim]), op_eq + ), + ] + + nat_constraints + ) + c2.append(c_tensor_i) + + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.tensor) +def tensor_inference_rule(n: Node, symbols, constraints, counter): + """ + If the tensor is a scalar, we will skip it since we + do not support scalars yet. We will add support in the future + if it's needed. For our examples so far, scalars are not needed. + """ + return [], counter + + +@register_inference_rule("reshape") +@register_inference_rule("view") +def view_inference_rule(n: Node, symbols, constraints, counter): + """ + Similar to reshape but with an extra condition on the strides + """ + assert isinstance(n.args[0], Node) + + # generate the new variable + my_view, counter = gen_tvar(counter) + symbols[n] = my_view + + src_var = symbols[n.args[0]] + t2 = [ + symbols[elem] if isinstance(elem, Node) else elem for elem in n.args[1:] + ] # target shape + t2_type = [] + num_constraints = [] + + for t in t2: + if t == -1: + var, counter = gen_dvar(counter) + t2_type.append(var) + num_constraints.append(BinConstraintD(var, Dyn, op_neq)) + + else: + num_constraints.append(BinConstraintD(t, Dyn, op_neq)) + t2_type.append(t) + + t2_type = TensorType(t2_type) # type: ignore[assignment] + + c1 = BinConstraintT(my_view, t2_type, op_eq) + c2 = CanReshape(src_var, t2_type) + + # TODO: add the extra check mentioned here: + # https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch.Tensor.view + + return [c1, c2] + num_constraints, counter # type: ignore[operator] + + +@register_inference_rule("size") +def size_inference_rule(n: Node, symbols, constraints, counter): + """ + The constraint is just lhs = rhs. + Ex: size = input_ids.size() + """ + + if len(n.args) == 1: + # generate the new variable + size, counter = gen_tvar(counter) + symbols[n] = size + input = symbols[n.args[0]] + c = BinConstraintT(input, size, op_eq) + return [c], counter + + elif len(n.args) == 2: + # TODO: review this rule; should input = dyn; output = dyn be included here? + if isinstance(n.args[1], int): + # generate the new variable + size_index, counter = gen_dvar(counter) + symbols[n] = size_index + input = symbols[n.args[0]] + c2 = [ + GetItem(i + 1, n.args[1], size_index, input) + for i in range(MAX_TENSOR_RANK) + ] + c3 = BinConstraintD(0, size_index, op_leq) + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(size_index, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + else: + raise NotImplementedError + + else: + raise NotImplementedError + + +def range_check(i, n): + """ + Checks if an index i is within range of a size n list + Args: + i: index + n: list size + + Returns: Boolean + """ + if i >= 0: + return T() if i < n else F() + else: + return T() if i >= n else F() + + +@register_inference_rule(torch.cumsum) +def cumsum_inference_rule(n: Node, symbols, constraints, counter): + """ + Input and output shapes should be equal + We should verify that the index is valid + """ + assert isinstance(n.args[0], Node) + arg_1 = n.args[1] if len(n.args) > 1 else n.kwargs["dim"] + assert isinstance(arg_1, int) + + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims), op_eq), + BinConstraintT(output, TensorType(new_dims), op_eq), + ] + + [range_check(arg_1, i)] + + nat_constraints + ) + + c2.append(c_tensor_i) + dyn_or_tensor = Disj([c1, Disj(c2)]) + return [dyn_or_tensor], counter + + +@register_inference_rule(_assert_is_none) +def assert_inference_rule(n: Node, symbols, constraints, counter): + assert len(n.users) == 0 + return [], counter + + +@register_inference_rule(operator.getitem) +def getitem_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # dimension output case + if isinstance(n.args[1], int): + # create and store the new dimension variable + get_item_output, counter = gen_dvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + # if the input is dynamic, we accept any index and return + # a dynamic dimension as output + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintD(get_item_output, Dyn, op_eq) + c1 = Conj([input_dyn, output_dyn]) + + # if the input is a tensor, + # generate a getItem constraint which will be expanded based on the + # tensor dimension. + + c2 = [ + GetItem(i + 1, n.args[1], get_item_output, get_item_arg) + for i in range(MAX_TENSOR_RANK) + ] + + # since the output is a dimension, we make sure it's a natural number + # added as a conjunction to the disjunction of c2 + c3 = BinConstraintD(0, get_item_output, op_leq) + return [Disj([c1, Conj([Disj(c2), c3])])], counter + + # tensor output case + elif isinstance(n.args[1], tuple): + # create and store the new tensor variable + get_item_output, counter = gen_tvar(counter) + symbols[n] = get_item_output + + # retrieve arg variables + if n.args[0] in symbols: + get_item_arg = symbols[n.args[0]] + assert isinstance(get_item_arg, TVar) + + input_dyn = BinConstraintT(get_item_arg, Dyn, op_eq) + output_dyn = BinConstraintT(get_item_output, Dyn, op_eq) # type: ignore[assignment] + c1 = Conj([input_dyn, output_dyn]) + + c2 = [ + GetItemTensor(i + 1, n.args[1], get_item_output, get_item_arg) # type: ignore[misc] + for i in range(MAX_TENSOR_RANK) + ] + else: + # TODO: we should figure out why there is a key-error here. + return [], counter + + return [Disj([c1, *c2])], counter + + else: + raise RuntimeError("Method not yet implemented") + + +@register_inference_rule(operator.gt) +def gt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + gt_tensor, counter = gen_tvar(counter) + symbols[n] = gt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, gt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + gt_constraint = BinConstraintD(e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + elif isinstance(e1, TVar) and isinstance(e2, int): + # then we made the wrong assumption about the argument being a tensor + # so we should fix the assumption + warnings.warn( + f"Made the wrong assumption for node {n}. Correctness not guaranteed." + ) + + new_e1, counter = gen_dvar(counter) + symbols[n.args[0]] = new_e1 + symbols[n.args[0]] + + gt_constraint = BinConstraintD(new_e1, e2, op_gt) + + my_gt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_gt, gt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise NotImplementedError("Method not yet implemented") + + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(operator.eq) +def eq_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + eq_tensor, counter = gen_tvar(counter) + symbols[n] = eq_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, eq_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + eq_constraint = BinConstraintD(e1, e2, op_eq) + + my_eq, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_eq, eq_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError("Method not yet implemented") + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(operator.ne) +def neq_inference_rule(n: Node, symbols, constraints, counter): + """ + Translates to inconsistent in gradual types. + To prove inequality, we should prove that + tensors are either different sizes or + disagree on at least one dimension + + This is a WIP (works when the condition + is false. We are working on making this operation work + when the condition is true as well) + """ + assert isinstance(n.args[0], Node) + assert isinstance(n.args[1], tuple) + + # implementing for size 3 and 4 + if len(n.args[1]) == 3: + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + + lhs = symbols[n.args[0]] + + b, counter = gen_tensor_dims(4, counter) + input_is_size3 = BinConstraintT(lhs, TensorType([b[0], b[1], b[2]]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b[0], op_neq) + neq_2 = BinConstraintD(d2, b[1], op_neq) + neq_3 = BinConstraintD(d3, b[2], op_neq) + + # dimensions inconsistent + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b[0], Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b[1], Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b[2], Dyn, op_neq), neq_3] + ) + + dims_inconsistent = Disj( + [dims_inconsistent1, dims_inconsistent2, dims_inconsistent3] + ) + + # we are covering size 3 and 4 only for now + ne_constraint = Conj([input_is_size3, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + elif len(n.args[1]) == 4: + assert isinstance(n.args[1][0], (Node, int)) + assert isinstance(n.args[1][1], (Node, int)) + assert isinstance(n.args[1][2], (Node, int)) + assert isinstance(n.args[1][3], (Node, int)) + + lhs = symbols[n.args[0]] + + b1, counter = gen_dvar(counter) + b2, counter = gen_dvar(counter) + b3, counter = gen_dvar(counter) + b4, counter = gen_dvar(counter) + + input_is_size4 = BinConstraintT(lhs, TensorType([b1, b2, b3, b4]), op_eq) + + d1 = n.args[1][0] if isinstance(n.args[1][0], int) else symbols[n.args[1][0]] + d2 = n.args[1][1] if isinstance(n.args[1][1], int) else symbols[n.args[1][1]] + d3 = n.args[1][2] if isinstance(n.args[1][2], int) else symbols[n.args[1][2]] + d4 = n.args[1][3] if isinstance(n.args[1][3], int) else symbols[n.args[1][3]] + + # dimensions not equal + my_ne, counter = gen_bvar(counter) + neq_1 = BinConstraintD(d1, b1, op_neq) + neq_2 = BinConstraintD(d2, b2, op_neq) + neq_3 = BinConstraintD(d3, b3, op_neq) + neq_4 = BinConstraintD(d4, b4, op_neq) + + # dimensions to inconsistent + dims_inconsistent1 = Conj( + [BinConstraintD(d1, Dyn, op_neq), BinConstraintD(b1, Dyn, op_neq), neq_1] + ) + dims_inconsistent2 = Conj( + [BinConstraintD(d2, Dyn, op_neq), BinConstraintD(b2, Dyn, op_neq), neq_2] + ) + dims_inconsistent3 = Conj( + [BinConstraintD(d3, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_3] + ) + dims_inconsistent4 = Conj( + [BinConstraintD(d4, Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq), neq_4] + ) + + dims_inconsistent = Disj( + [ + dims_inconsistent1, + dims_inconsistent2, + dims_inconsistent3, + dims_inconsistent4, + ] + ) + + ne_constraint = Conj([input_is_size4, dims_inconsistent]) + + my_ne, counter = gen_bvar(counter) + + equality_constraint = BinConstraintD(my_ne, ne_constraint, op_eq) + + else: + raise NotImplementedError("Method not yet implemented") + + return [equality_constraint], counter + + +@register_inference_rule(operator.lt) +def lt_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], (Node, int)) + assert isinstance(n.args[1], (Node, int)) + + # We make sure this node will not be used again. We do not + # generate a constraint about that node. Only about the operands. + + e1 = symbols[n.args[0]] if isinstance(n.args[0], Node) else n.args[0] + e2 = symbols[n.args[1]] if isinstance(n.args[1], Node) else n.args[1] + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(e1, TVar) and isinstance(e2, TVar): + lt_tensor, counter = gen_tvar(counter) + symbols[n] = lt_tensor + return gen_broadcasting_constraints(e1, e2, symbols, counter, lt_tensor) + + elif isinstance(e1, DVar) and isinstance(e2, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + + else: + raise RuntimeError("Sort Mismatch") + + elif isinstance(n.args[0], Node) and not isinstance(n.args[1], Node): + if isinstance(e1, DVar): + # This is meant to be used for flow analysis only + lt_constraint = BinConstraintD(e1, e2, op_lt) + + my_lt, counter = gen_bvar(counter) + equality_constraint = BinConstraintD(my_lt, lt_constraint, op_eq) + return [equality_constraint], counter + else: + raise NotImplementedError("Method not yet implemented") + + else: + raise NotImplementedError("Method not yet implemented") + + +@register_inference_rule(torch.full) +def full_inference_rule(n: Node, symbols, constraints, counter): + full, counter = gen_tvar(counter) + symbols[n] = full + res = [] + + assert isinstance(n.args[0], Iterable) + for arg in n.args[0]: + dim = arg if isinstance(arg, int) else symbols[arg] + res.append(dim) + c = BinConstraintT(full, TensorType(list(res)), op_eq) # type: ignore[arg-type] + return [c], counter + + +# TODO normalize index +@register_inference_rule(torch.arange) +def arange_inference_rule(n: Node, symbols, constraints, counter): + start = 0 + step = 1 + + if len(n.args) == 1: + end = symbols[n.args[0]] + else: + raise NotImplementedError("Not yet implemented") + + # int((end - start) / step) + d1, counter = gen_dvar(counter) + size_constraint = BinConstraintD( + d1, BinConstraintD(BinConstraintD(end, start, op_sub), step, op_div), op_eq + ) + arange, counter = gen_tvar(counter) + symbols[n] = arange + + # either the a parameter is a number or it is Dyn + c1 = Disj( + [ + BinConstraintD(end, Dyn, op_eq), + BinConstraintD(start, Dyn, op_eq), + BinConstraintD(step, Dyn, op_eq), + ] + ) + c2 = BinConstraintD(d1, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + c11 = Conj( + [ + BinConstraintD(end, Dyn, op_neq), + BinConstraintD(start, Dyn, op_neq), + BinConstraintD(step, Dyn, op_neq), + ] + ) + c22 = BinConstraintD(d1, Dyn, op_neq) + both_numbers = Conj([c11, c22, size_constraint]) + + return [ + BinConstraintT(arange, TensorType([d1]), op_eq), + Disj([both_dyn, both_numbers]), + ], counter + + +def gen_broadcasting_constraints(e1, e2, symbols, counter, output_var): + # additional vars that don't correspond to expressions + e11, counter = gen_tvar(counter) + e22, counter = gen_tvar(counter) + + # generate constraints + c1 = TGreatestUpperBound(output_var, e11, e22) + c2 = ApplyBroadcasting(e11, e22, e1, e2) + c3 = BinConstraintT(e11, e22, op_consistency) + return [c1, c2, c3], counter + + +@register_inference_rule(operator.mul) +@register_inference_rule(torch.ne) +@register_inference_rule("ne") +@register_inference_rule(torch.add) +@register_inference_rule(operator.add) +def broadcasting_inference_rule(n: Node, symbols, constraints, counter): + op_code = None + if n.target == operator.add or n.target == torch.add: + op_code = op_add + elif n.target == operator.mul: + op_code = op_mul + + if isinstance(n.args[0], Node) and isinstance(n.args[1], Node): + if isinstance(symbols[n.args[0]], TVar) and isinstance( + symbols[n.args[1]], TVar + ): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + e2 = symbols[n.args[1]] + + return gen_broadcasting_constraints(e1, e2, symbols, counter, my_output) + else: + raise NotImplementedError("Method not yet implemented") + + elif isinstance(n.args[0], Node) and isinstance(n.args[1], (int, float)): + if isinstance(symbols[n.args[0]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + return [BinConstraintT(my_output, e1, op_eq)], counter + elif isinstance(symbols[n.args[0]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e1 = symbols[n.args[0]] + + # we will propagate the runtime value here since this is regular addition + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e1, n.args[1], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) + return [c], counter + + elif isinstance(n.args[1], Node) and isinstance(n.args[0], (int, float)): + if isinstance(symbols[n.args[1]], TVar): + my_output, counter = gen_tvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + return [BinConstraintT(my_output, e2, op_eq)], counter + elif isinstance(symbols[n.args[1]], DVar): + my_output, counter = gen_dvar(counter) + symbols[n] = my_output + e2 = symbols[n.args[1]] + + # we will propagate the runtime value here since this is regular addition + c = Conj( + [ + BinConstraintD( + my_output, BinConstraintD(e2, n.args[0], op_code), op_eq + ), + BinConstraintD(0, my_output, op_leq), + ] + ) + return [c], counter + + else: + raise NotImplementedError("Method not yet implemented") + + else: + # TODO generate add constraints for scalar addition + raise NotImplementedError("Addition not yet implemented") + + +@register_inference_rule(torch.flatten) +def flatten_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + flattened, counter = gen_tvar(counter) + symbols[n] = flattened + + input = symbols[n.args[0]] + + # set the default start and end dims + start_dim = 1 + end_dim = -1 + + if len(n.args) > 1: + assert isinstance(n.args[1], int) + start_dim = n.args[1] + + if len(n.args) > 2: + assert isinstance(n.args[2], int) + end_dim = n.args[2] + + c1 = BinConstraintT(input, Dyn, op_eq) + c2 = BinConstraintT(flattened, Dyn, op_eq) + both_dyn = Conj([c1, c2]) + + const = [] + for i in range(1, MAX_TENSOR_RANK + 1): + c, counter = generate_flatten_constraints( + start_dim, end_dim, input, flattened, i, counter + ) + const.append(c) + + return [Disj([both_dyn, *const])], counter + + +@register_inference_rule(torch.nn.functional.layer_norm) +def layer_norm_functional(n: Node, symbols, constraints, counter): + """ + We generate the constraint: input = output + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints(n, n.args[1], symbols, counter) + + +@register_inference_rule(torch.nn.LayerNorm) +def layer_norm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + Input should be consistent with the normalized_shape + """ + assert isinstance(n.args[0], Node) + return gen_layer_norm_constraints( + n, module_instance.normalized_shape, symbols, counter + ) + + +def gen_layer_norm_constraints(n: Node, normalized_shape, symbols, counter): + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintT(output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs, counter = gen_tensor_dims(i, counter) + nat_constraints = gen_nat_constraints(new_dims_rhs) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs), op_eq), + BinConstraintT(output, TensorType(new_dims_rhs), op_eq), + ] + + add_layer_norm_constraints(new_dims_rhs, list(normalized_shape)) + + nat_constraints + ) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + + +@register_inference_rule(torch.nn.Dropout) +@register_inference_rule(torch.nn.ReLU) +def relu_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output shapes should be equal. + """ + assert isinstance(n.args[0], Node) + output, counter = gen_tvar(counter) + symbols[n] = output + input = symbols[n.args[0]] + assert isinstance(input, TVar) + return [BinConstraintT(input, output, op_eq)], counter + + +@register_inference_rule(torch.nn.Linear) +def linear_inference_rule(n: Node, module_instance, symbols, constraints, counter): + """ + Input and output sizes should be the same except for the last dimension + If the input is Dyn, then so should the output + """ + assert isinstance(n.args[0], Node) + return linear_constraints( + n, module_instance.in_features, module_instance.out_features, symbols, counter + ) + + +@register_inference_rule("dim") +def torch_dim_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + my_dim, counter = gen_dvar(counter) + symbols[n] = my_dim + input = symbols[n.args[0]] + + input_dyn = BinConstraintT(input, Dyn, op_eq) + output_dyn = BinConstraintD(my_dim, Dyn, op_eq) + + c1 = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + + c_tensor_i = Conj( + [ + BinConstraintT(input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintD(my_dim, i, op_eq), + ] + ) + c1.append(c_tensor_i) + + return [Disj([Conj([input_dyn, output_dyn]), Disj(c1)])], counter + + +@register_inference_rule(torch._C._nn.linear) +def torch_linear_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + weight_dims, counter = gen_tensor_dims(2, counter) + equality_constraint = BinConstraintT( + symbols[n.args[1]], TensorType(weight_dims), op_eq + ) + constraints, counter = linear_constraints( + n, weight_dims[1], weight_dims[0], symbols, counter + ) + return [equality_constraint] + constraints, counter + + +def linear_constraints(n: Node, in_features, out_features, symbols, counter): + linear_output, counter = gen_tvar(counter) + symbols[n] = linear_output + linear_input = symbols[n.args[0]] + + input_dyn = BinConstraintT(linear_input, Dyn, op_eq) + output_dyn = BinConstraintT(linear_output, Dyn, op_eq) + + c1 = Conj([input_dyn, output_dyn]) + + c2 = [] + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj( + [ + BinConstraintT(linear_input, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(linear_output, TensorType(new_dims_rhs_2), op_eq), + ] + + add_linear_constraints( + new_dims_rhs_1, new_dims_rhs_2, in_features, out_features + ) + + nat_constraints + ) + c2.append(c_tensor_i) + return [Disj([c1, Disj(c2)])], counter + + +def add_layer_norm_constraints(input_dim, normalized_dim): + """ + The constraints say that the type has te form: [*, 1024, 1024] + while the normalized_dim have the form [1024, 1024] + Args: + input_dim: Input shape of layer norm + normalized_dim: normalized_dim parameter of the module instance + + """ + + # in this case we return false since there's a pattern mismatch + if len(normalized_dim) > len(input_dim): + return [F()] + + else: + constraints = [] + for i, n in zip(reversed(input_dim), reversed(normalized_dim)): + constraints.append(BinConstraintD(i, n, op_consistency)) + return constraints + + +def add_linear_constraints(dims1, dims2, in_features, out_features): + assert len(dims1) == len(dims2) + constraints = [] + for i in range(len(dims1)): + if i == len(dims1) - 1: + constraints.append(BinConstraintD(dims1[i], in_features, op_consistency)) + constraints.append(BinConstraintD(dims2[i], out_features, op_eq)) + else: + constraints.append(BinConstraintD(dims1[i], dims2[i], op_eq)) + + return constraints + + +@register_inference_rule(torch.reshape) +def reshape_inference_rule(n: Node, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + my_reshape, counter = gen_tvar(counter) + symbols[n] = my_reshape + + src_var = symbols[n.args[0]] + t2 = n.args[1] + t2_type = TensorType([Dyn if elem == -1 else elem for elem in t2]) # type: ignore[union-attr] + c1 = BinConstraintT(my_reshape, t2_type, op_eq) # type: ignore[union-attr] + c2 = CanReshape(src_var, t2_type) + + return [c1, c2], counter + + +@register_inference_rule(BatchNorm2d) +def batchnorm_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + # generate the new variable + batchnorm_output, counter = gen_tvar(counter) + symbols[n] = batchnorm_output + batchnorm_input = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + c1 = BinConstraintT(batchnorm_input, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT(batchnorm_input, batchnorm_output, op_eq) + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(torch.nn.AdaptiveAvgPool2d) +def adaptive_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + avg_pool, counter = gen_tvar(counter) + + symbols[n] = avg_pool + input_var = symbols[n.args[0]] + + # dim vars + d1, counter = gen_dvar(counter) + d2, counter = gen_dvar(counter) + d3, counter = gen_dvar(counter) + d4, counter = gen_dvar(counter) + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + c2 = BinConstraintT( + avg_pool, + TensorType( + [d1, d2, module_instance.output_size[0], module_instance.output_size[1]] + ), + op_eq, + ) + + return [c1, c2, *nat_constraints], counter + + +@register_inference_rule(Conv2d) +def conv2d_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + + my_conv, counter = gen_tvar(counter) + symbols[n] = my_conv + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + # c1 = Matching(input_var, TensorType([d1, d2, d3, d4])) + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + # c2 = DConsistency(module_instance.in_channels, d2) + c2 = BinConstraintD(module_instance.in_channels, d2, op_consistency) + + c3 = CalcConv( + my_conv, + input_var, + module_instance.out_channels, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, c3, *nat_constraints], counter + + +@register_inference_rule(torch.nn.MaxPool2d) +def maxpool_inference_rule(n: Node, module_instance, symbols, constraints, counter): + assert isinstance(n.args[0], Node) + maxpool, counter = gen_tvar(counter) + symbols[n] = maxpool + input_var = symbols[n.args[0]] + + # dim vars + [d1, d2, d3, d4], counter = gen_tensor_dims(MAX_TENSOR_RANK, counter) + + c1 = BinConstraintT(input_var, TensorType([d1, d2, d3, d4]), op_matching) + + c2 = CalcMaxPool( + maxpool, + input_var, + module_instance.kernel_size, + module_instance.padding, + module_instance.stride, + module_instance.dilation, + [d1, d2, d3, d4], + ) + + nat_constraints = gen_nat_constraints([d1, d2, d3, d4]) + + return [c1, c2, *nat_constraints], counter + + +class ConstraintGenerator: + def __init__(self, traced, graph=None): + self.traced = traced # traced or tracer.root + self.traced_params = dict(self.traced.named_parameters()) + self.constraints = [] + self.symbol_dict = {} + self.graph = traced.graph if hasattr(traced, "graph") else graph + + def generate_constraints(self, counter=0): + """ + Iterate through every node and generate constraints + Effect: self.constraints will be populated with the final constraints + """ + graph = self.graph + + all_constraints = [] + + for n in graph.nodes: + (constraints, counter) = self.generate_constraints_node(n, counter) + all_constraints += constraints + + return Conj(all_constraints), counter + + def generate_constraints_node(self, n: Node, counter): + """ + Generate constraints the given node: + Currently supported operations: + - Reshape + - Add + - conv2d + """ + + if n.op == "placeholder": + x, counter = gen_tvar(counter) + self.symbol_dict[n] = x + + my_type = n.type + + if n.type != Dyn and (not isinstance(n.type, TensorType)): + if n.type == torch.nn.parameter.Parameter: + # since we have a parameter, the shape must be static + assert "example_value" in n.meta + my_type = TensorType(n.meta["example_value"].size()) + else: + my_type = Dyn + + c1 = BinConstraintT(my_type, x, op_precision) + c2 = BinConstraintT(x, MAX_TENSOR_RANK, op_leq) + return [c1, c2], counter + + elif n.op == "call_function": + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "call_module": + module_instance = self.traced.get_submodule(n.target) + if type(module_instance) in _INFERENCE_RULES: + return _INFERENCE_RULES[type(module_instance)]( + n, module_instance, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for class {type(module_instance)}!" + ) + + elif n.op == "call_method": + if n.target in _INFERENCE_RULES: + return _INFERENCE_RULES[n.target]( + n, self.symbol_dict, self.constraints, counter + ) + else: + raise RuntimeError( + f"No inference rule registered for target {n.target}!" + ) + + elif n.op == "get_attr": + t = self.traced_params.get(n.target, None) + + if isinstance(t, torch.Tensor): + if len(t.shape) > 0: + res = list(t.shape) + attr_type = TensorType(res) + output, counter = gen_tvar(counter) + self.symbol_dict[n] = output + return [BinConstraintT(output, attr_type, op_eq)], counter + else: + # scalar? + return [], counter + else: + return [], counter + + elif n.op == "output": + return [], counter + + else: + raise NotImplementedError(f"Method {n.op} not yet implemented") diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py new file mode 100644 index 0000000000000000000000000000000000000000..b6567e58ed25ad3cba3d30640e8dbee7d2be992b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/constraint_transformation.py @@ -0,0 +1,1322 @@ +# mypy: ignore-errors +import copy +import itertools +from typing import Callable + +from torch.fx.experimental.migrate_gradual_types.constraint import ( + ApplyBroadcasting, + BinConstraintD, + CalcConv, + CalcMaxPool, + CalcProduct, + CanReshape, + Conj, + Constraint, + DGreatestUpperBound, + Disj, + DVar, + F, + GetItem, + GetItemTensor, + IndexSelect, + Prod, + T, + TGreatestUpperBound, + Transpose, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + BinConstraintT, + MAX_TENSOR_RANK, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_consistency, + op_div, + op_eq, + op_leq, + op_matching, + op_mod, + op_mul, + op_neq, + op_precision, + op_sub, +) +from torch.fx.experimental.migrate_gradual_types.util import ( + gen_dvar, + gen_nat_constraints, + gen_tensor_dims, +) +from torch.fx.tensor_type import Dyn, TensorType + + +_TRANSFORMATION_RULES: dict[Constraint, Callable] = {} + + +def register_transformation_rule(call_target): + def register(fn): + if call_target in _TRANSFORMATION_RULES: + raise RuntimeError( + f"Transformation rule already registered for {call_target}!" + ) + _TRANSFORMATION_RULES[call_target] = fn + return fn + + return register + + +def valid_index(index, dims): + """ + Given a list of dimensions, checks if an index is valid in the list + """ + try: + dims[index] + return T() + except IndexError: + return F() + + +@register_transformation_rule(Transpose) +def transform_transpose(constraint, counter): + """ + Similar to a sequence of two index-selects + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index1 = valid_index(constraint.index1, dims) + is_valid_index2 = valid_index(constraint.index2, dims) + new_dims = copy.deepcopy(dims) + nat_constraints = gen_nat_constraints(dims) + + if is_valid_index1 == T() and is_valid_index2 == T(): + new_dims[constraint.index1] = dims[constraint.index2] + new_dims[constraint.index2] = dims[constraint.index1] + + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index1, + is_valid_index2, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) + return transformed_constraint, counter + + +@register_transformation_rule(IndexSelect) +def transform_index_select(constraint, counter): + """ + The constraints consider the given tensor size, checks if the index is valid + and if so, generates a constraint for replacing the input dimension + with the required dimension + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + is_valid_index = valid_index(constraint.index, dims) + nat_constraints = gen_nat_constraints(dims) + + # if the index is valid then replace the input dimension with the new dimension + # otherwise the dimension will not be replaced and the clause will contain False + if is_valid_index == T(): + new_dims = copy.deepcopy(dims) + new_dims[constraint.index] = constraint.dim_replace + + transformed_constraint = Conj( + [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + BinConstraintT(constraint.output, TensorType(new_dims), op_eq), + ] + ) + + # print(constraints) + return transformed_constraint, counter + + +@register_transformation_rule(GetItem) +def transform_get_item(constraint, counter): + """ + generate an equality of the form: + t = [a1, ..., an] + then generate constraints that check if the given index is valid + given this particular tensor size. + If the index is valid, generate a constraint to get the item + Note that we already handled the Dyn input case in the previous + step. + Args: + constraint: GetItem which assumes we are getting an item from a tensor (not Dyn) + counter: variable tracking + Returns: simplified constraints for GetItem + + """ + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + is_valid_index = valid_index(constraint.index, dims) + + all_constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + *nat_constraints, + is_valid_index, + ] + + # if the index is valid, we generate a constraint for getting an item + # otherwise this clause will have been UNSAT due to the wrong index + if is_valid_index == T(): + all_constraints.append( + BinConstraintD(constraint.res, dims[constraint.index], op_eq) + ) + + return Conj(all_constraints), counter + + +def valid_index_tensor(index, dims): + """ + if the slice instances exceed the length of the dimensions + then this is a type error so we return False + """ + slice_count = 0 + for s in index: + if isinstance(s, slice): + slice_count += 1 + if slice_count > len(dims): + return F() + else: + return T() + + +@register_transformation_rule(GetItemTensor) +def transform_get_item_tensor(constraint, counter): + """ + When the index is a tuple, then the output will be a tensor + TODO: we have to check if this is the case for all HF models + + The cases we are covering here are a tuple with one of: + - slice with default argument + - None + + None appends 1 to the input tensor dimensions + so each occurrence of 'None' increases the rank by 1 + + slice with default arguments does not change the rank + """ + assert isinstance(constraint.index_tuple, tuple) + + # generate a result tensor of the expected size + dims, counter = gen_tensor_dims(constraint.tensor_size, counter) + nat_constraints = gen_nat_constraints(dims) + + # generate a place-holder list of the right rank + # where "slice" does not contribute to the rank and "None" does + none_c = constraint.index_tuple.count(None) + resulting_tensor_dims = (none_c + len(dims)) * [None] + + dim_index = 0 + for i in range(len(constraint.index_tuple)): + # append 1 to the right location of the resulting tensor + if constraint.index_tuple[i] is None: + resulting_tensor_dims[i] = 1 + + elif constraint.index_tuple[i] == slice(None, None, None): + pass + + else: + raise NotImplementedError("Method not yet implemented") + + # append the remaining dimensions to the right location + dim_index = 0 + for i in range(len(resulting_tensor_dims)): + if resulting_tensor_dims[i] is None: + resulting_tensor_dims[i] = dims[dim_index] + dim_index += 1 + + # check if the index is valid + is_valid_index = valid_index_tensor(constraint.index_tuple, dims) + + # check if the resulting tensor is within bounds + if len(resulting_tensor_dims) > 4: + return F(), counter + + else: + constraints = [ + BinConstraintT(constraint.input_var, TensorType(dims), op_eq), + BinConstraintT(constraint.res, TensorType(resulting_tensor_dims), op_eq), + *nat_constraints, + is_valid_index, + ] + return Conj(constraints), counter + + +@register_transformation_rule(BinConstraintT) +def generate_binconstraint_t(constraint, counter): + """ + Transform binary constraints for tensors + """ + + # precision constraints + if constraint.op == op_precision: + if constraint.lhs == Dyn: + return T(), counter + elif isinstance(constraint.lhs, TensorType): + is_fully_static = all(d != Dyn for d in constraint.lhs.__args__) + if is_fully_static: + return BinConstraintT(constraint.lhs, constraint.rhs, op_eq), counter + else: + new_dims = [] + + for _ in range(len(constraint.lhs.__args__)): + dim, counter = gen_dvar(counter) + new_dims.append(dim) + + new_dim_constraints = ( + [ + BinConstraintD(old_dim, new_dim, op_precision) + for new_dim, old_dim in zip(new_dims, constraint.lhs.__args__) + ] + + [BinConstraintT(constraint.rhs, TensorType(new_dims), op_eq)] + + [BinConstraintD(1, new_dim, op_leq) for new_dim in new_dims] + ) + return Conj(new_dim_constraints), counter + + # matching + elif constraint.op == op_matching: + assert isinstance(constraint.rhs, TensorType) + d1 = constraint.rhs.__args__[0] + d2 = constraint.rhs.__args__[1] + d3 = constraint.rhs.__args__[2] + d4 = constraint.rhs.__args__[3] + + conj = [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintD(d1, Dyn, op_eq), + BinConstraintD(d2, Dyn, op_eq), + BinConstraintD(d3, Dyn, op_eq), + BinConstraintD(d4, Dyn, op_eq), + ] + return ( + Disj( + [ + Conj(conj), + BinConstraintT(constraint.lhs, TensorType([d1, d2, d3, d4]), op_eq), + ] + ), + counter, + ) + + elif constraint.op == op_consistency: + c_dyn = Disj( + [ + BinConstraintT(constraint.lhs, Dyn, op_eq), + BinConstraintT(constraint.rhs, Dyn, op_eq), + ] + ) + ( + ( + c_tensor_1, + c_tensor_2, + c_tensor_3, + c_tensor_4, + ), + counter, + ) = gen_consistency_constraints(constraint, counter) + + return Disj([c_dyn, c_tensor_1, c_tensor_2, c_tensor_3, c_tensor_4]), counter + + elif constraint.op == op_leq: + assert isinstance(constraint.rhs, int) + disj = [BinConstraintT(constraint.lhs, Dyn, op_eq)] + for i in range(1, constraint.rhs + 1): + dims = [] + for _ in range(1, i + 1): + dim_var, counter = gen_dvar(counter) + dims.append(dim_var) + disj.append(BinConstraintT(constraint.lhs, TensorType(dims), op_eq)) + return Disj(disj), counter + else: + return constraint, counter + + +@register_transformation_rule(BinConstraintD) +def generate_binconstraint_d(constraint, counter): + """ + Transform binary constraints for dimensions + """ + if constraint.op == op_precision: + if isinstance(constraint.lhs, int): + return BinConstraintD(constraint.lhs, constraint.rhs, op_eq), counter + elif constraint.lhs == Dyn: + return T(), counter + + elif constraint.op == op_consistency: + return ( + Disj( + [ + BinConstraintD(constraint.lhs, constraint.rhs, op_eq), + BinConstraintD(constraint.rhs, Dyn, op_eq), + BinConstraintD(constraint.lhs, Dyn, op_eq), + ] + ), + counter, + ) + + else: + return constraint, counter + + +@register_transformation_rule(Conj) +def generate_conj(constraint, counter): + """ + Transform conjunctions + """ + new = [] + for c in constraint.conjucts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Conj(new), counter + + +@register_transformation_rule(Disj) +def generate_disj(constraint, counter): + """ + Transform disjunctions + """ + new = [] + for c in constraint.disjuncts: + new_c, counter = transform_constraint(c, counter) + new.append(new_c) + return Disj(new), counter + + +@register_transformation_rule(TGreatestUpperBound) +def generate_gub(constraint, counter): + """ + Transform greatest upper bound for tensors. Results in equality and Greatest Upper Bound + on dimensions + """ + c1 = Conj( + [ + Disj( + [ + BinConstraintT(constraint.rhs1, Dyn, op_eq), + BinConstraintT(constraint.rhs2, Dyn, op_eq), + ] + ), + BinConstraintT(constraint.res, Dyn, op_eq), + ] + ) + + [c2, c3, c4, c5], counter = gen_greatest_upper_bound(constraint, counter) + + return Disj([c1, c2, c3, c4, c5]), counter + + +@register_transformation_rule(DGreatestUpperBound) +def generate_d_gub(constraint, counter): + """ + Transform greatest upper bound for dimensions into equality constraints + """ + c1 = Conj( + [ + BinConstraintD(constraint.rhs1, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs2, op_eq), + ] + ) + c2 = Conj( + [ + BinConstraintD(constraint.rhs2, Dyn, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) + c3 = Conj( + [ + BinConstraintD(constraint.rhs2, constraint.rhs1, op_eq), + BinConstraintD(constraint.res, constraint.rhs1, op_eq), + ] + ) + return Disj([c1, c2, c3]), counter + + +@register_transformation_rule(CalcConv) +def generate_calc_conv(constraint, counter): + d, counter = gen_tensor_dims(4, counter) + conv_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the convolution result is a tensor of size 4 + c1 = BinConstraintT(constraint.conv_result, conv_result, op_eq) + + # the second dimension of the output is equal to the output channels + c2 = Conj( + [ + BinConstraintD(d[1], constraint.c_out, op_eq), + BinConstraintD(d[1], Dyn, op_neq), + ] + ) + + # the input corresponds to the output in the first dimension of the convolution + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcMaxPool) +def generate_calc_maxpool(constraint, counter): + """ + Transform maxpool constraints + """ + d, counter = gen_tensor_dims(4, counter) + maxpool_result = TensorType([d[0], d[1], d[2], d[3]]) + + # the maxpool result is a tensor of size 4 + c1 = BinConstraintT(constraint.maxpool_result, maxpool_result, op_eq) + + # the input corresponds to the output in the first and second dimension of maxpool + c2 = BinConstraintD(constraint.matching_constraint[1], d[1], op_eq) + c3 = BinConstraintD(constraint.matching_constraint[0], d[0], op_eq) + c4, c5 = calc_last_two_dims(constraint, d) + + leq_constraints = Conj( + [ + BinConstraintD(0, d[0], op_leq), + BinConstraintD(0, d[1], op_leq), + BinConstraintD(0, d[2], op_leq), + BinConstraintD(0, d[3], op_leq), + ] + ) + + return Conj([c1, c2, c3, c4, c5, leq_constraints]), counter + + +@register_transformation_rule(CalcProduct) +def generate_calc_product(constraint, counter): + """ + Transform flatten constraints + """ + start = constraint.start + end = constraint.end + dims = constraint.dims_to_flatten + flattened = constraint.flattened + n = len(constraint.dims_to_flatten) + + # this will be evaluated right here + boundary_check = 0 <= start and start < end and end <= n + + c_boundary = T() if boundary_check else F() + + lhs = dims[0:start] + rhs = dims[end:] + mid = dims[start:end] + + all_possibilities = generate_all_int_dyn_dim_possibilities(mid) + + all_constraints = [] + + for p in all_possibilities: + p = list(p) + # this tells us there is a dynamic variable + contains_dyn = not all(constraint.op == op_neq for constraint in p) + if contains_dyn: + mid_var = [Dyn] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ) + ] + + p + ) + ) + else: + new_var, counter = gen_dvar(counter) + mid_eq_prod = Conj( + [ + BinConstraintD(new_var, Prod(mid), op_eq), + BinConstraintD(new_var, Dyn, op_neq), + ] + ) + mid_var = [new_var] + total_constraints = lhs + mid_var + rhs + if len(total_constraints) > 4: + all_constraints.append(F()) + else: + all_constraints.append( + Conj( + [ + BinConstraintT( + flattened, TensorType(lhs + mid_var + rhs), op_eq + ), + mid_eq_prod, + ] + + p + ) + ) + + return Conj([Disj(all_constraints), c_boundary]), counter + + +@register_transformation_rule(CanReshape) +def generate_reshape(constraint, counter): + """ + Transform reshape constraints + """ + d, counter = gen_tensor_dims(4, counter) + + d1 = d[0] + d2 = d[1] + d3 = d[2] + d4 = d[3] + + target = constraint.target.__args__ + + is_fully_static = all(d != Dyn for d in target) + + # dynamic tensor + c1_dyn = BinConstraintT(constraint.src, Dyn, op_eq) + c2_tensor1 = BinConstraintT(constraint.src, TensorType([d1]), op_eq) + c2_tensor2 = BinConstraintT(constraint.src, TensorType([d1, d2]), op_eq) + c2_tensor3 = BinConstraintT(constraint.src, TensorType([d1, d2, d3]), op_eq) + c2_tensor4 = BinConstraintT(constraint.src, TensorType([d1, d2, d3, d4]), op_eq) + + d1_eq_dyn = BinConstraintD(d1, Dyn, op_eq) + d1_neq_dyn = BinConstraintD(d1, Dyn, op_neq) + + d2_eq_dyn = BinConstraintD(d2, Dyn, op_eq) + d2_neq_dyn = BinConstraintD(d2, Dyn, op_neq) + + d3_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d3_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + d4_eq_dyn = BinConstraintD(d3, Dyn, op_eq) + d4_neq_dyn = BinConstraintD(d3, Dyn, op_neq) + + nat_d1 = BinConstraintD(0, d1, op_leq) + nat_d2 = BinConstraintD(0, d2, op_leq) + nat_d3 = BinConstraintD(0, d3, op_leq) + nat_d4 = BinConstraintD(0, d4, op_leq) + + if is_fully_static: + # size 1 tensor + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, BinConstraintD(d1, Prod(target), op_eq)]))] + ) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # size 2 tensor + all_tensor_2 = Conj( + [c2_tensor2, gen_all_reshape_possibilities([d1, d2], target)] + ) + + # size 3 tensor + all_tensor_3 = Conj( + [c2_tensor3, gen_all_reshape_possibilities([d1, d2, d3], target)] + ) + + # size 4 tensor + all_tensor_4 = Conj( + [c2_tensor4, gen_all_reshape_possibilities([d1, d2, d3, d4], target)] + ) + + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) + + # then there must be exactly one occurrence of dyn + else: + new_target = [n for n in target if n != Dyn] + + # tensor 1 + c3_tensor1 = Disj( + [d1_eq_dyn, (Conj([d1_neq_dyn, is_dim_div_by_target(new_target, d1)]))] + ) + all_tensor_1 = Conj([c2_tensor1, c3_tensor1]) + + # tensor 2 + c21 = Disj([d1_eq_dyn, d2_eq_dyn]) + c22 = Conj( + [d1_neq_dyn, d2_neq_dyn, is_dim_div_by_target(new_target, Prod([d1, d2]))] + ) + all_tensor_2 = Conj([c2_tensor2, Disj([c21, c22])]) + + # tensor 3 + c31 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn]) + c32 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3])), + ] + ) + all_tensor_3 = Conj([c2_tensor3, Disj([c31, c32])]) + + # tensor 4 + c41 = Disj([d1_eq_dyn, d2_eq_dyn, d3_eq_dyn, d4_eq_dyn]) + c42 = Conj( + [ + d1_neq_dyn, + d2_neq_dyn, + d3_neq_dyn, + d4_neq_dyn, + is_dim_div_by_target(new_target, Prod([d1, d2, d3, d4])), + ] + ) + all_tensor_4 = Conj([c2_tensor4, Disj([c41, c42])]) + + return ( + Conj( + [ + Disj( + [c1_dyn, all_tensor_1, all_tensor_2, all_tensor_3, all_tensor_4] + ), + nat_d1, + nat_d2, + nat_d3, + nat_d4, + ] + ), + counter, + ) + + +@register_transformation_rule(ApplyBroadcasting) +def generate_broadcasting(constraint, counter): + """ + Transform broadcasting constraints + """ + e11, e12 = constraint.res1, constraint.res2 + e1, e2 = constraint.input1, constraint.input2 + + e1_dyn = BinConstraintT(e1, Dyn, op_eq) + e2_dyn = BinConstraintT(e2, Dyn, op_eq) + + # Introduce dimensions + e1_equal_e11 = BinConstraintT(e1, e11, op_eq) + e2_equal_e12 = BinConstraintT(e2, e12, op_eq) + + # dyn possibility + e1_dyn_constraint = Conj([e1_dyn, e1_equal_e11, e2_equal_e12]) + e2_dyn_constraint = Conj([e2_dyn, e1_equal_e11, e2_equal_e12]) + + # tensor possibility + # generate dimensions to create tensors of size 1 + final_tensor_1_constraint, _, _, nat_dims_1, counter = gen_broadcasting_constraints( + e1, e2, e11, e12, 1, counter + ) + + # generate dimensions to create tensors of size 2 + ( + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + nat_dims_2, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 2, counter) + + # generate dimensions to create tensors of size 3 + ( + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + nat_dims_3, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 3, counter) + + # generate dimensions to create tensors of size 4 + ( + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2, + nat_dims_4, + counter, + ) = gen_broadcasting_constraints(e1, e2, e11, e12, 4, counter) + + final_result = Disj( + [ + e1_dyn_constraint, + e2_dyn_constraint, + final_tensor_1_constraint, + final_tensor_2_constraint_no_padding, + final_tensor_2_constraint_padding_arg1, + final_tensor_2_constraint_padding_arg2, + final_tensor_3_constraint_no_padding, + final_tensor_3_constraint_padding_arg1, + final_tensor_3_constraint_padding_arg2, + final_tensor_4_constraint_no_padding, + final_tensor_4_constraint_padding_arg1, + final_tensor_4_constraint_padding_arg2, + ] + ) + + return ( + Conj([final_result, *nat_dims_1, *nat_dims_2, *nat_dims_3, *nat_dims_4]), + counter, + ) + + +def transform_constraint(constraint: Constraint, counter: int): + """ + Transforms a constraint into a simpler constraint. + Ex: precision and consistency are transformed to equality + Args: + constraint: constraint to be transformed + counter: for variable tracking + + Returns: Constraint + + """ + if type(constraint) in _TRANSFORMATION_RULES: + return _TRANSFORMATION_RULES[type(constraint)](constraint, counter) + + else: + return constraint, counter + + +def calc_last_two_dims(constraint, d: list[DVar]): + """ + Generates constraints for the last two dimensions of a convolution or a maxpool output + Args: + constraint: CalcConv or CalcMaxPool + d: The list of output dimensions + + Returns: Constraints for calculating the last two dimensions of the output + + """ + + assert isinstance(constraint, (CalcConv, CalcMaxPool)) + + b3 = constraint.matching_constraint[2] + b4 = constraint.matching_constraint[3] + + b3_dyn = Conj([BinConstraintD(d[2], Dyn, op_eq), BinConstraintD(b3, Dyn, op_eq)]) + b4_dyn = Conj([BinConstraintD(d[3], Dyn, op_eq), BinConstraintD(b4, Dyn, op_eq)]) + + d3_not_dyn = Conj( + [BinConstraintD(d[2], Dyn, op_neq), BinConstraintD(b3, Dyn, op_neq)] + ) + d4_not_dyn = Conj( + [BinConstraintD(d[3], Dyn, op_neq), BinConstraintD(b4, Dyn, op_neq)] + ) + + # transform parameters into tuples incase they are not already + padding = ( + (constraint.padding, constraint.padding) + if isinstance(constraint.padding, int) + else constraint.padding + ) + kernel = ( + (constraint.kernel, constraint.kernel) + if isinstance(constraint.kernel, int) + else constraint.kernel + ) + stride = ( + (constraint.stride, constraint.stride) + if isinstance(constraint.stride, int) + else constraint.stride + ) + dilation = ( + (constraint.dilation, constraint.dilation) + if isinstance(constraint.dilation, int) + else constraint.dilation + ) + + f1 = BinConstraintD(b3, BinConstraintD(2, padding[0], op_mul), op_add) + f2 = BinConstraintD(dilation[0], BinConstraintD(kernel[0], 1, op_sub), op_mul) + f3 = BinConstraintD( + BinConstraintD(BinConstraintD(f1, f2, op_sub), 1, op_sub), stride[0], op_div + ) + f4 = BinConstraintD(f3, 1, op_add) + + c4 = Disj([b3_dyn, Conj([d3_not_dyn, BinConstraintD(d[2], f4, op_eq)])]) + + f11 = BinConstraintD(b4, BinConstraintD(2, padding[1], op_mul), op_add) + f22 = BinConstraintD(dilation[1], BinConstraintD(kernel[1], 1, op_sub), op_mul) + f33 = BinConstraintD( + BinConstraintD(BinConstraintD(f11, f22, op_sub), 1, op_sub), stride[1], op_div + ) + f44 = BinConstraintD(f33, 1, op_add) + + c5 = Disj([b4_dyn, Conj([d4_not_dyn, BinConstraintD(d[3], f44, op_eq)])]) + + return c4, c5 + + +def generate_all_int_dyn_dim_possibilities(my_list: list[DVar]): + """ + Generate all possibilities of being equal or not equal to dyn for my_list + Args: + my_list: List of tensor dimensions + + Returns: A list of a list of constraints. Each list of constraints corresponds to + one possibility about the values of the dimension variables + """ + # generate all possibilities of being equal or not equal to dyn for my_list + eq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_eq) for i in range(len(my_list)) + ] + neq_possibilities = [ + BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list)) + ] + + d_possibilities = [list(i) for i in zip(eq_possibilities, neq_possibilities)] + all_possibilities = list(itertools.product(*d_possibilities)) + return all_possibilities + + +def is_target_div_by_dim(target: list[int], dim: list[DVar]): + """ + Generate constraints to check if the target dimensions are divisible by the input dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(Prod(target), dim, op_mod), 0, op_eq) + + +def is_dim_div_by_target(target: list[int], dim: list[DVar]): + """ + Generate constraints to check if the input dimensions is divisible by the target dimensions + Args: + target: Target dimensions + dim: Input dimensions + + Returns: Constraints to check divisibility + + """ + return BinConstraintD(BinConstraintD(dim, Prod(target), op_mod), 0, op_eq) + + +def gen_all_reshape_possibilities(list_of_dims, target): + """ + Consider all possibilities what the input dimensions could be (number or dynamic) + Then generate the appropriate constraints using multiplication or mod depending on the possibility + The possibilities we consider here are the cross product of being equal to dyn or not equal to dyn + for the input. Target is fixed because at most one dimension could be dyn. + We have different cases for this. + + Args: + list_of_dims: The input list of dimensions + target: The tensor we want to reshape to + + Returns: A disjunction of transformed reshape constraints + + """ + all_possibilities = generate_all_int_dyn_dim_possibilities(list_of_dims) + + all_constraints = [] + + for p in all_possibilities: + to_multiply = [] + + p = list(p) + + for constraint in p: + assert isinstance(constraint, BinConstraintD) + if constraint.op == op_neq: + to_multiply.append(constraint.lhs) + + if not to_multiply: + all_constraints.append(Conj(p)) + + elif len(to_multiply) < len(list_of_dims): + all_constraints.append( + Conj(p + [is_target_div_by_dim(target, Prod(to_multiply))]) + ) + else: + all_constraints.append( + Conj(p + [BinConstraintD(Prod(list_of_dims), Prod(target), op_eq)]) + ) + + return Disj(all_constraints) + + +def broadcast_dim(tensor_input1, tensor_input2, res1, res2, index, padding=False): + """ + Apply broadcasting to the 'index' dimension of tensor_input1. + Args: + tensor_input1: should represent [d1, ..., d_index, ...] where d_index = 1 + tensor_input2: represents the second input + res1: broadcasted result 1 + res2: broadcasted result 2 + index: the index to broadcast + padding: If padding was used, then tensor_input1[index] does not exist + + Returns: + + """ + if tensor_input1[index] is None: + assert padding + + if not padding: + # then the inputs are the same length so they all have dimensions at "index" + return Conj( + [ + BinConstraintD(tensor_input1[index], 1, op_eq), + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) + + else: + # we don't set the input dimension to 1, since it doesn't exist. + return Conj( + [ + BinConstraintD(res1[index], res2[index], op_eq), + BinConstraintD(res2[index], tensor_input2[index], op_eq), + ] + ) + + +def apply_padding( + e1_var: TVar, + e11: BinConstraintT, + e2: BinConstraintT, + e12: BinConstraintT, + d2: list[DVar], + d11: list[DVar], + d12: list[DVar], + counter: int, +): + """ + We are considering the possibility where one input has less dimensions than + another input, so we apply padding to the broadcasted results + + Args: + e1_var: Variable representing the first input where padding will be + e11: constraint of the form e11 = Tensortype[d1, ..., dn] + e2: constraint of the form e2 = Tensortype[d1, ..., dn] + e12: constraint of the form e11 = Tensortype[d1, ..., dn] + d2: Tensor variables for the second input + d11: Tensor variables for the broadcasted first input + d12: Tensor variables for the broadcasted second input + counter: variable tracking + + Returns: A new constraint whose goal is to apply padding to the broadcasted result + + """ + + res = [] + + # pad the shorter input with None so we can pass it to the broadcasting helper function + for i in range(1, len(d2)): + d1, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(d1 + d2 + d11 + d12) + + e1 = BinConstraintT(e1_var, TensorType(d1), op_eq) + + simulate_padding = [None] * (len(d2) - i) + + assert len(simulate_padding + d1) == len(d2) + + # for every padding size, we also consider broadcasting + broadcast_padding = [ + broadcast_dim(simulate_padding, d2, d11, d12, j, True) + for j in range(len(d2) - i) + ] + + # we consider the possibilities for broadcasting for every dimension. Since we already + # padded d1, we do not consider it while broadcasting + all_broadcasting_possibilities = ( + generate_all_broadcasting_possibilities_no_padding( + d1, d2[(len(d2) - i) :], d11[(len(d2) - i) :], d12[(len(d2) - i) :] + ) + ) + # combine all constraints into a conjunction + c = Conj( + [ + e1, + e11, + e2, + e12, + *broadcast_padding, + all_broadcasting_possibilities, + *nat_constraints, + ] + ) + res.append(c) + + return Disj(res), counter + + +def no_broadcast_dim_with_index( + d1: list[DVar], d2: list[DVar], d3: list[DVar], d4: list[DVar], i: int +): + """ + Args: + d1: input 1 + d2: input 2 + d3: simulated broadcasting for input 1 + d4: simulated broadcasting for input 2 + i: the rank of the resulting tensor addition + + Returns: Constraints for when no broadcasting occurs + """ + return Conj( + [ + Disj( + [ + Conj( + [ + BinConstraintD(d1[i], 1, op_eq), + BinConstraintD(d2[i], 1, op_eq), + ] + ), + Conj( + [ + BinConstraintD(d1[i], 1, op_neq), + BinConstraintD(d2[i], 1, op_neq), + ] + ), + ] + ), + BinConstraintD(d1[i], d3[i], op_eq), + BinConstraintD(d2[i], d4[i], op_eq), + ] + ) + + +def gen_lists_of_dims(num_tensors: int, dim_size: int, counter: int): + """ + Generate lists of DVar to represent tensor dimensions + Args: + num_tensors: the required number of tensors + dim_size: the number of dimensions for each tensor + counter: variable tracking + + Returns: A list of a list of tensor dimensions + + """ + res = [] + + for _ in range(num_tensors): + dims, counter = gen_tensor_dims(dim_size, counter) + res.append(dims) + + return res, counter + + +def create_equality_constraints_for_broadcasting( + e1: TVar, + e2: TVar, + e11: TVar, + e12: TVar, + d1: list[DVar], + d2: list[DVar], + d11: list[DVar], + d12: list[DVar], +): + """ + Create equality constraints for when no broadcasting occurs + Args: + e1: Input 1 + e2: Input 2 + e11: Broadcasted input 1 + e12: Broadcasted input 2 + d1: Variables that store dimensions for e1 + d2: Variables that store dimensions for e2 + d11: Variables that store dimensions for e11 + d12: Variables that store dimensions for e22 + + Returns: Four equality constraints + + """ + + e1_tensor = BinConstraintT(e1, TensorType(d1), op_eq) + e11_tensor = BinConstraintT(e11, TensorType(d11), op_eq) + e2_tensor = BinConstraintT(e2, TensorType(d2), op_eq) + e12_tensor = BinConstraintT(e12, TensorType(d12), op_eq) + return [e1_tensor, e11_tensor, e2_tensor, e12_tensor] + + +def gen_consistency_constraints(constraint: Constraint, counter: int): + """ + Args: + constraint: Consistency constraint on tensors + counter: for variable tracking + + Returns: Equality and consistency constraints on dimensions + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + new_dims_rhs_1, counter = gen_tensor_dims(i, counter) + new_dims_rhs_2, counter = gen_tensor_dims(i, counter) + + nat_constraints = gen_nat_constraints(new_dims_rhs_1 + new_dims_rhs_2) + + c_tensor_i = Conj( + [ + BinConstraintT(constraint.lhs, TensorType(new_dims_rhs_1), op_eq), + BinConstraintT(constraint.rhs, TensorType(new_dims_rhs_2), op_eq), + ] + + [ + BinConstraintD(d1, d2, op_consistency) + for d1, d2 in zip(new_dims_rhs_1, new_dims_rhs_2) + ] + + nat_constraints + ) + + all_constraints.append(c_tensor_i) + + return all_constraints, counter + + +def gen_greatest_upper_bound(constraint: TGreatestUpperBound, counter: int): + """ + Args: + constraint: Greatest upper bound on tensors + counter: variable tracking + + Returns: A set of equality constraints and DGreatestUpperBound constraints + + """ + + all_constraints = [] + + for i in range(1, MAX_TENSOR_RANK + 1): + c = [] + dims1, counter = gen_tensor_dims(i, counter) + c1tensor = TensorType(dims1) + + dims2, counter = gen_tensor_dims(i, counter) + c2tensor = TensorType(dims2) + + dims3, counter = gen_tensor_dims(i, counter) + c3tensor = TensorType(dims3) + + c += [ + BinConstraintT(constraint.rhs1, c1tensor, op_eq), + BinConstraintT(constraint.rhs2, c2tensor, op_eq), + BinConstraintT(constraint.res, c3tensor, op_eq), + ] + gen_nat_constraints(dims1 + dims2 + dims3) + + assert ( + len(c3tensor.__args__) == len(c1tensor.__args__) == len(c2tensor.__args__) + ) + for i in range(len(c3tensor.__args__)): + c.append( + DGreatestUpperBound( + c3tensor.__args__[i], c1tensor.__args__[i], c2tensor.__args__[i] + ) + ) + + all_constraints.append(Conj(c)) + return all_constraints, counter + + +def generate_all_broadcasting_possibilities_no_padding( + d1: list[DVar], d2: list[DVar], d11: list[DVar], d12: list[DVar] +): + """ + Generate broadcasting constraints assuming no padding. Broadcasting can happen at any dimension. + We look at all combinations for all dimensions in d1 and d2 + Args: + d1: input1 dimensions + d2: input2 dimensions + d11: broadcasted input1 dimensions + d12: broadcasted input2 dimensions + + Returns: broadcasting constraints relating the input dimensions to the broadcasted dimensions + + """ + + size = len(d1) + + res2 = [] + + for i in range(size): + t1 = broadcast_dim(d1, d2, d11, d12, i) + t2 = broadcast_dim(d2, d1, d12, d11, i) + t3 = no_broadcast_dim_with_index(d1, d2, d11, d12, i) + + res2.append(Disj([t1, t2, t3])) + + return Conj(res2) + + +def gen_broadcasting_constraints( + e1: TVar, e2: TVar, e11: TVar, e12: TVar, i: int, counter: int +): + """ + Simulates broadcasting on e1 and e2 and returns the results + respectively in e11 and e12. Because of gradual types, + e1 and e2 may not be equal. Similarly, e11 and e12 may not + be equal. e11 and e12 should be guaranteed to be consistent + as they represent the shapes of the tensors to be added after + broadcasting. + Args: + e1: TVar representing the type of input 1 + e2: TVar representing the type of input 2 + e11: TVar representing the representing broadcasted input 1 + e12: TVar representing the representing broadcasted input 2 + i: The rank of the resulting type of addition + counter: for variable tracking + + Returns: Simplified broadcasting constraints + + """ + dims, counter = gen_lists_of_dims(4, i, counter) + [d1, d2, d3, d4] = dims + nat_dims_i = gen_nat_constraints(list(itertools.chain.from_iterable(dims))) + + initialize_tensors_constraints = create_equality_constraints_for_broadcasting( + e1, e2, e11, e12, d1, d2, d3, d4 + ) + + [e1_tensor, e11_tensor, e2_tensor, e12_tensor] = initialize_tensors_constraints + + # without padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_no_padding = Conj( + [ + *initialize_tensors_constraints, + generate_all_broadcasting_possibilities_no_padding(d1, d2, d3, d4), + ] + ) + + # with padding, broadcast all possibilities for tensors of size i + final_tensor_constraint_padding_arg1, counter = apply_padding( + e1, e11_tensor, e2_tensor, e12_tensor, d2, d3, d4, counter + ) + + final_tensor_constraint_padding_arg2, counter = apply_padding( + e2, e12_tensor, e1_tensor, e11_tensor, d1, d4, d3, counter + ) + + return ( + final_tensor_constraint_no_padding, + final_tensor_constraint_padding_arg1, + final_tensor_constraint_padding_arg2, + nat_dims_i, + counter, + ) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py new file mode 100644 index 0000000000000000000000000000000000000000..2ca1705cdcf49282a59e9f0366249747e9079a61 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/operation.py @@ -0,0 +1,14 @@ +op_add = "+" +op_sub = "-" +op_mul = "*" +op_div = "/" +op_eq = "=" +op_neq = "!=" +op_imp = "=>" +op_matching = "\u22b3" # (contains) +op_consistency = "~" +op_precision = "\u2291" # (square image of or equal to) +op_leq = "\u2264" # less-than or equal to +op_lt = "<" +op_gt = ">" +op_mod = "%" diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py new file mode 100644 index 0000000000000000000000000000000000000000..0bb1b3ec52d6cc5c78b45c84107ab3fd1ef562a4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/transform_to_z3.py @@ -0,0 +1,446 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BinConstraintT, + BVar, + Conj, + Disj, + DVar, + F, + is_algebraic_expression, + is_bool_expr, + is_dim, + Prod, + T, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.constraint_generator import ( + ConstraintGenerator, +) +from torch.fx.experimental.migrate_gradual_types.constraint_transformation import ( + transform_constraint, +) +from torch.fx.experimental.migrate_gradual_types.operation import ( + op_add, + op_div, + op_eq, + op_gt, + op_leq, + op_lt, + op_mod, + op_mul, + op_neq, + op_sub, +) +from torch.fx.tensor_type import Dyn, TensorType + + +try: + import z3 # type: ignore[import] + + from torch.fx.experimental.migrate_gradual_types.z3_types import ( + D, + tensor_type, + z3_dyn, + ) + + HAS_Z3 = True + + def transform_to_z3(constraint, counter, dimension_dict): + if isinstance(constraint, Conj): + conjuncts = [] + for c in constraint.conjucts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + conjuncts.append(new_c) + return z3.And(conjuncts), counter + + elif isinstance(constraint, Disj): + disjuncts = [] + for c in constraint.disjuncts: + new_c, counter = transform_to_z3(c, counter, dimension_dict) + disjuncts.append(new_c) + return z3.Or(disjuncts), counter + + elif isinstance(constraint, T): + return True, counter + + elif isinstance(constraint, F): + return False, counter + + elif isinstance(constraint, BinConstraintT): + if constraint.op == op_eq: + lhs, counter = transform_var(constraint.lhs, counter, dimension_dict) + rhs, counter = transform_var(constraint.rhs, counter, dimension_dict) + return (lhs == rhs), counter + + else: + raise NotImplementedError("Method not yet implemented") + + elif isinstance(constraint, BinConstraintD): + if constraint.op == op_eq: + if isinstance(constraint.lhs, BVar) and is_bool_expr(constraint.rhs): + transformed_rhs, counter = transform_to_z3( + constraint.rhs, counter, dimension_dict + ) + transformed_lhs = z3.Bool(constraint.lhs.c) + return transformed_lhs == transformed_rhs, counter + + elif is_dim(constraint.lhs) and is_dim(constraint.rhs): + # with dimension transformations we consider the encoding + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) + return lhs == rhs, counter + + else: + # then we have an algebraic expression which means that we disregard the + # first element of the encoding + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs == rhs, counter + + # The assumption here is that the LHS and RHS must be dimensions + elif constraint.op == op_neq: + assert is_dim(constraint.lhs) + assert is_dim(constraint.rhs) + lhs, counter = transform_dimension( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_dimension( + constraint.rhs, counter, dimension_dict + ) + if constraint.rhs == Dyn or constraint.lhs == Dyn: + if constraint.rhs == Dyn: + return lhs.arg(0) == 1, counter + elif constraint.lhs == Dyn: + return rhs.arg(0) == 1, counter + + # if one of the instances is a number + elif isinstance(constraint.lhs, int) or isinstance(constraint.rhs, int): + if isinstance(constraint.lhs, int): + return ( + z3.Or( + [ + rhs.arg(0) == 0, + z3.And([rhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) + + elif isinstance(constraint.rhs, int): + return ( + z3.Or( + [ + lhs.arg(0) == 0, + z3.And([lhs.arg(0) == 1, lhs.arg(1) != rhs.arg(1)]), + ] + ), + counter, + ) + + else: + return ( + z3.Or( + [ + z3.And([lhs.arg(0) == 0, rhs.arg(0) != 0]), + z3.And([lhs.arg(0) != 0, rhs.arg(0) == 0]), + z3.And( + [ + lhs.arg(0) != 0, + rhs.arg(0) != 0, + lhs.arg(1) != rhs.arg(1), + ] + ), + ] + ), + counter, + ) + + elif constraint.op == op_leq: + # if the dimensions are not dyn, this will come into effect + # there would have been another constraint specifying if a given dimension + # is dyn or not + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs <= rhs, counter + + elif constraint.op == op_gt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs > rhs, counter + + elif constraint.op == op_lt: + assert is_dim(constraint.lhs) and is_dim(constraint.rhs) + lhs, counter = transform_algebraic_expression( + constraint.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + constraint.rhs, counter, dimension_dict + ) + return lhs < rhs, counter + + else: + raise NotImplementedError("operation not yet implemented") + + else: + raise NotImplementedError("Operation not yet implemented") + + def transform_var(tensor, counter, dimension_dict): + """ + Transforms tensor variables to a format understood by z3 + Args: + tensor: Tensor variable or a tensor type potentially with variable dimensions + Returns: Transformed variable to a z3 format + + """ + if isinstance(tensor, TensorType): + res = [] + for t in tensor.__args__: + transformed, counter = transform_dimension(t, counter, dimension_dict) + res.append(transformed) + + assert len(res) <= 4 + if len(tensor.__args__) == 1: + return tensor_type.tensor1(res[0]), counter + elif len(tensor.__args__) == 2: + return tensor_type.tensor2(res[0], res[1]), counter + elif len(tensor.__args__) == 3: + return tensor_type.tensor3(res[0], res[1], res[2]), counter + elif len(tensor.__args__) == 4: + return tensor_type.tensor4(res[0], res[1], res[2], res[3]), counter + + elif tensor == Dyn: + return z3_dyn, counter + + elif isinstance(tensor, TVar): + return z3.Const(tensor.tvar, tensor_type), counter + + def transform_dimension(dimension, counter, dimension_dict): + """ + Takes a dimension variable or a number and transforms it to a tuple + according to our scheme + Args: + dimension: The dimension to be transformed + counter: variable tracking + + Returns: tuple and the current counter + + """ + if dimension == Dyn: + counter += 1 + return D(0, z3.Int(counter)), counter + elif isinstance(dimension, int): + return D(1, dimension), counter + elif isinstance(dimension, DVar): + if dimension.c in dimension_dict: + return ( + D(z3.Int(dimension_dict[dimension.c]), z3.Int(dimension.c)), + counter, + ) + else: + counter += 1 + dimension_dict[dimension.c] = counter + return D(z3.Int(counter), z3.Int(dimension.c)), counter + + def transform_algebraic_expression(expr, counter, dimension_dict): + """ + Transforms an algebraic expression to z3 format + Args: + expr: An expression is either a dimension variable or an algebraic-expression + + + Returns: the transformed expression + + """ + assert is_algebraic_expression(expr) or is_dim(expr) + + if is_dim(expr): + transformed, counter = transform_dimension(expr, counter, dimension_dict) + return transformed.arg(1), counter + + elif isinstance(expr, Prod): + dims = [] + for dim in expr.products: + assert is_dim(dim) + d, counter = transform_dimension(dim, counter, dimension_dict) + dims.append(d.arg(1)) + return z3.Product(dims), counter + + elif is_algebraic_expression(expr): + lhs, counter = transform_algebraic_expression( + expr.lhs, counter, dimension_dict + ) + rhs, counter = transform_algebraic_expression( + expr.rhs, counter, dimension_dict + ) + + if expr.op == op_sub: + c = lhs - rhs + + elif expr.op == op_add: + c = lhs + rhs + + elif expr.op == op_div: + c = lhs / rhs + + elif expr.op == op_mul: + c = lhs * rhs + + elif expr.op == op_mod: + c = lhs % rhs + + else: + raise NotImplementedError("operation not yet implemented") + + return c, counter + + else: + raise RuntimeError + + def transform_all_constraints(traced, counter=0): + """ + Given a trace, generates constraints and transforms them to z3 format + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(traced) + new_constraints, counter = generator.generate_constraints(counter) + + # print(new_constraints.conjucts[0]) + # print(*new_constraints.conjucts, sep='\n') + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + # print(new_constraints) + # print(new_constraints.conjucts) + # new_constraints.conjucts = new_constraints.conjucts[:-1] + # print(*new_constraints.conjucts, sep='\n') + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + # print(transformed) + return transformed + + def iterate_till_fixed_point(constraints, counter): + """ + Transform constraints till reaching a fixed point + """ + old_c = None + while old_c != constraints: + old_c = constraints + constraints, counter = transform_constraint(constraints, counter) + return constraints, counter + + def transform_all_constraints_trace_time(tracer_root, graph, node, counter=0): + """ + Takes a node and a graph and generates two sets of constraints. + One set constraints the node's constraints and another set + constraints the negation of the node's constraints + Args: + tracer_root: the root for getting the module instances + graph: the graph so far in the tracing process + node: node that represents a conditional + counter: variable tracking + + Returns: Two sets of constraints. One with a conjunction with the + the conditional constraint and the other with a conjunction with + its negation. + + """ + dimension_dict = {} # type: ignore[var-annotated] + + generator = ConstraintGenerator(tracer_root, graph) + new_constraints, counter = generator.generate_constraints(counter) + + condition_constraint = new_constraints.conjucts[-1] + + # we know the constraint is a conjunction where the last constraint is about the conditional + # so remove the last constraint + new_constraints.conjucts = new_constraints.conjucts[:-1] + + # transform precision, matching, consistency till obtaining a fixed point + new_constraints, counter = iterate_till_fixed_point(new_constraints, counter) + + # since the function returns a list of one element, we get the first element + # we are only interested in the RHS in this case because the LHS just stores + # the result + + # we make sure the constraint is of the form: + # c = b where b is a boolean expression + # and we consider b (constraint.rhs) for transformation + assert isinstance(condition_constraint.lhs, BVar) + assert is_bool_expr(condition_constraint.rhs) + condition_constraint_rhs = condition_constraint.rhs + + # transform the condition constraint + condition_constraint_rhs, counter = iterate_till_fixed_point( + condition_constraint_rhs, counter + ) + + transformed, counter = transform_to_z3(new_constraints, counter, dimension_dict) + + transformed_condition_constraint, counter = transform_to_z3( + condition_constraint_rhs, counter, dimension_dict + ) + + negation_transformed_condition_constraint = z3.Not( + transformed_condition_constraint + ) + + return z3.And([transformed, transformed_condition_constraint]), z3.And( + [transformed, negation_transformed_condition_constraint] + ) + + def evaluate_conditional_with_constraints( + tracer_root, graph, node, counter=0, user_constraints=None + ): + """ + Given an IR and a node representing a conditional, evaluate the conditional + and its negation + Args: + tracer_root: Tracer root for module instances + node: The node to be evaluated + + Returns: the results of evaluating the condition and the negation with + the rest of the constraints + + """ + + ( + transformed_positive, + transformed_negative, + ) = transform_all_constraints_trace_time(tracer_root, graph, node, counter) + + s = z3.Solver() + s.add(transformed_positive) + if user_constraints is not None: + s.add(user_constraints) + condition = s.check() + + s = z3.Solver() + s.add(transformed_negative) + if user_constraints is not None: + s.add(user_constraints) + negation = s.check() + return condition, negation + +except ImportError: + HAS_Z3 = False diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/util.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ee98dc8406652773f73f8442e55b9bd0a15aaa87 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/util.py @@ -0,0 +1,59 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.migrate_gradual_types.constraint import ( + BinConstraintD, + BVar, + DVar, + TVar, +) +from torch.fx.experimental.migrate_gradual_types.operation import op_leq + + +def gen_tvar(curr): + """ + Generate a tensor variable + :param curr: The current counter + :return: a tensor variable and the updated counter + """ + curr += 1 + return TVar(curr), curr + + +def gen_dvar(curr): + """ + Generate a dimension variable + :param curr: the current counter + :return: a dimension variable and an updated counter + """ + curr += 1 + return DVar(curr), curr + + +def gen_bvar(curr): + """ + Generate a boolean variable + :param curr: the current counter + :return: a boolean variable and an updated counter + """ + curr += 1 + return BVar(curr), curr + + +def gen_tensor_dims(n, curr): + """ + Generate a list of tensor dimensions + :param n: the number of dimensions + :param curr: the current counter + :return: a list of dimension variables and an updated counter + """ + dims = [] + for _ in range(n): + dvar, curr = gen_dvar(curr) + dims.append(dvar) + return dims, curr + + +def gen_nat_constraints(list_of_dims): + """ + Generate natural number constraints for dimensions + """ + return [BinConstraintD(0, d, op_leq) for d in list_of_dims] diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py new file mode 100644 index 0000000000000000000000000000000000000000..1cd70e4146bc401e0face5f94512d712ba3c3953 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/migrate_gradual_types/z3_types.py @@ -0,0 +1,30 @@ +try: + import z3 # type: ignore[import] + + HAS_Z3 = True + # dynamic type + dyn = z3.DeclareSort("Dyn") + dyn_type = z3.Const("dyn", dyn) + + # dimension + dim = z3.Datatype("dim") + dim.declare("dim", ("0", z3.IntSort()), ("1", z3.IntSort())) + dim = dim.create() + + # tensors + tensor_type = z3.Datatype("TensorType") + tensor_type.declare("Dyn", ("dyn", dyn)) + tensor_type.declare("tensor1", ("0", dim)) + tensor_type.declare("tensor2", ("0", dim), ("1", dim)) + tensor_type.declare("tensor3", ("0", dim), ("1", dim), ("2", dim)) + tensor_type.declare("tensor4", ("0", dim), ("1", dim), ("2", dim), ("3", dim)) + tensor_type = tensor_type.create() + + # create dimension + D = dim.dim + + z3_dyn = tensor_type.Dyn(dyn_type) + + +except ImportError: + HAS_Z3 = False diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/normalize.py b/phivenv/Lib/site-packages/torch/fx/experimental/normalize.py new file mode 100644 index 0000000000000000000000000000000000000000..5b05833704514c927c26bfaf698d78fbd59b6ba4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/normalize.py @@ -0,0 +1,163 @@ +# mypy: allow-untyped-defs +import operator +from typing import Any, Callable, Optional + +import torch +import torch.fx +import torch.fx as fx +from torch.fx import Proxy, Transformer +from torch.fx.node import Argument, map_aggregate, Node, Target +from torch.fx.operator_schemas import ( + create_type_hint, + normalize_function, + normalize_module, +) + +from .schema_type_annotation import AnnotateTypesWithSchema + + +class NormalizeArgs(Transformer): + """ + Normalize arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and rewritten to exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. Also populates default + values. Does not support positional-only parameters or varargs + parameters (*args, **kwargs). + + If the nodes have 'type' metadata, it will use it to disambiguate + overloads. Otherwise, it will throw an error. + + Example usage: + m = torchvision.models.resnet18() + traced = torch.fx.symbolic_trace(m) + traced = NormalizeArgs(traced).transform() + """ + + def __init__( + self, module: torch.fx.GraphModule, normalize_to_only_use_kwargs: bool = True + ): + super().__init__(module) + self.node_map: dict[Proxy, Node] = {} + self.normalize_to_only_use_kwargs = normalize_to_only_use_kwargs + + def run_node(self, n: Node) -> Any: + args, kwargs = self.fetch_args_kwargs_from_env(n) + + def get_type(arg): + if isinstance(arg, fx.Node): + return n.meta["type"] if "type" in n.meta else None + return type(arg) + + arg_types = map_aggregate(n.args, get_type) + assert isinstance(arg_types, tuple) + arg_types = tuple([create_type_hint(i) for i in arg_types]) + kwarg_types = {k: get_type(v) for k, v in kwargs.items()} + if n.op == "call_function": + out = self.call_function(n.target, args, kwargs, arg_types, kwarg_types) + else: + out = super().run_node(n) + if n.op != "output": + self.node_map[out] = n + out.node.meta = n.meta + out.node.type = n.type + return out + + def call_function( + self, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Any], + arg_types: Optional[tuple[Any, ...]] = None, + kwarg_types: Optional[dict[str, Any]] = None, + ): + assert callable(target) + new_args_and_kwargs = normalize_function( + target, + args, # type: ignore[arg-type] + kwargs, + arg_types, # type: ignore[arg-type] + kwarg_types, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return self.tracer.create_proxy( + "call_function", target, new_args, new_kwargs + ) + else: + return super().call_function(target, args, kwargs) + + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + assert isinstance(target, str) + new_args_and_kwargs = normalize_module( + self.module, + target, + args, # type: ignore[arg-type] + kwargs, + self.normalize_to_only_use_kwargs, + ) + if new_args_and_kwargs: + new_args, new_kwargs = new_args_and_kwargs + return super().call_module(target, new_args, new_kwargs) + else: + return super().call_module(target, args, kwargs) + + +class NormalizeOperators(AnnotateTypesWithSchema): + """ + Normalize callsites that are different ways of "spelling" the same + invocation into a single, canonical call. Currently supports: + + 1. Normalize operators (e.g. operator.add) to the `torch` ops they + ultimately invoke (e.g. torch.add) when it is possible to statically + reason that + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = NormalizeOperators(traced).transform() + """ + + binary_magic_method_remap: dict[ + Callable[[Any, Any], Any], Callable[[Any, Any], Any] + ] = { + torch.add: operator.add, + torch.mul: operator.mul, + torch.sub: operator.sub, + torch.div: operator.truediv, + torch.floor_divide: operator.floordiv, + torch.remainder: operator.mod, + torch.eq: operator.eq, + torch.ne: operator.ne, + torch.lt: operator.lt, + torch.le: operator.le, + torch.gt: operator.gt, + torch.ge: operator.ge, + } + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + # Normalize operators according to the magic methods implemented on tensors here: + # https://github.com/pytorch/pytorch/blob/28c5d90b679c6b38bf4183ec99f16d933c2f1bcd/tools/autograd/templates/python_variable_methods.cpp#L1137 # noqa: B950 + + assert callable(target) + + if target in self.binary_magic_method_remap: + if len(args) != 2: + return super().call_function(target, args, kwargs) + lhs, rhs = args + + return super().call_function( + target=self.binary_magic_method_remap[target], + args=(lhs, rhs), + kwargs={}, + ) + + return super().call_function(target, args, kwargs) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/optimization.py b/phivenv/Lib/site-packages/torch/fx/experimental/optimization.py new file mode 100644 index 0000000000000000000000000000000000000000..f90d63df2b2d4cb1fd317d373df782da12d50e78 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/optimization.py @@ -0,0 +1,486 @@ +# mypy: allow-untyped-defs +import copy +import logging +import operator +import time +from collections import defaultdict +from collections.abc import Iterable +from enum import Enum +from typing import Any, cast, Optional + +import torch +import torch.fx as fx +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.mkldnn as th_mkldnn +from torch.fx.node import Argument, Target +from torch.fx.passes.shape_prop import ShapeProp +from torch.nn.utils.fusion import fuse_conv_bn_eval, fuse_linear_bn_eval + + +__all__ = [ + "matches_module_pattern", + "replace_node_module", + "fuse", + "remove_dropout", + "extract_subgraph", + "modules_to_mkldnn", + "reset_modules", + "MklSubgraph", + "gen_mkl_autotuner", + "use_mkl_length", + "UnionFind", + "optimize_for_inference", +] + + +def _parent_name(target: str) -> tuple[str, str]: + """ + Splits a qualname into parent path and last atom. + For example, `foo.bar.baz` -> (`foo.bar`, `baz`) + """ + *parent, name = target.rsplit(".", 1) + return parent[0] if parent else "", name + + +# Works for length 2 patterns with 2 modules +def matches_module_pattern( + pattern: Iterable[type], node: fx.Node, modules: dict[str, Any] +): + if len(node.args) == 0: + return False + nodes: tuple[Any, fx.Node] = (node.args[0], node) + for expected_type, current_node in zip(pattern, nodes): + if not isinstance(current_node, fx.Node): + return False + if current_node.op != "call_module": + return False + if not isinstance(current_node.target, str): + return False + if current_node.target not in modules: + return False + if type(modules[current_node.target]) is not expected_type: + return False + return True + + +def replace_node_module( + node: fx.Node, modules: dict[str, Any], new_module: torch.nn.Module +): + assert isinstance(node.target, str) + parent_name, name = _parent_name(node.target) + modules[node.target] = new_module + setattr(modules[parent_name], name, new_module) + + +def fuse(model: torch.nn.Module, inplace=False, no_trace=False) -> torch.nn.Module: + """ + Fuses convolution/BN and linear/BN layers for inference purposes. + Will deepcopy your model by default, but can modify the model inplace as well. + """ + patterns = [ + (nn.Conv1d, nn.BatchNorm1d), + (nn.Conv2d, nn.BatchNorm2d), + (nn.Conv3d, nn.BatchNorm3d), + (nn.Linear, nn.BatchNorm1d), + ] + if not inplace: + model = copy.deepcopy(model) + if not no_trace or not isinstance(model, torch.fx.GraphModule): + fx_model = fx.symbolic_trace(model) + else: + fx_model = model + modules = dict(fx_model.named_modules()) + new_graph = copy.deepcopy(fx_model.graph) + + for pattern in patterns: + for node in new_graph.nodes: + if matches_module_pattern(pattern, node, modules): + if len(node.args[0].users) > 1: + # Output of conv/linear is used by other nodes + continue + first_layer = modules[node.args[0].target] + bn = modules[node.target] + if not bn.track_running_stats: + continue + if pattern[0] in [nn.Conv1d, nn.Conv2d, nn.Conv3d]: + fused_layer = fuse_conv_bn_eval(first_layer, bn) + else: # nn.Linear + fused_layer = fuse_linear_bn_eval(first_layer, bn) + replace_node_module(node.args[0], modules, fused_layer) + node.replace_all_uses_with(node.args[0]) + new_graph.erase_node(node) + return fx.GraphModule(fx_model, new_graph) + + +def remove_dropout(model: nn.Module) -> nn.Module: + """ + Removes all dropout layers from the module. + """ + fx_model = fx.symbolic_trace(model) + + class DropoutRemover(torch.fx.Transformer): + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + if isinstance(self.submodules[target], nn.Dropout): + assert len(args) == 1 + return args[0] + else: + return super().call_module(target, args, kwargs) + + return DropoutRemover(fx_model).transform() + + +def extract_subgraph( + orig_module: nn.Module, + nodes: list[fx.Node], + inputs: list[fx.Node], + outputs: list[fx.Node], +): + """ + Given lists of nodes from an existing graph that represent a subgraph, returns a submodule that executes that subgraph. + """ + new_graph = fx.Graph() + env: dict[fx.Node, fx.Node] = {} + for input in inputs: + new_node = new_graph.placeholder(input.name) + env[input] = new_node + for node in nodes: + new_node = new_graph.node_copy(node, lambda x: env[x]) + env[node] = new_node + new_graph.output([env[output] for output in outputs]) + new_graph.lint() + return fx.GraphModule(orig_module, new_graph) + + +mkldnn_supported = [ + nn.Conv2d, + nn.Linear, + nn.BatchNorm2d, + nn.ReLU, + nn.MaxPool2d, + nn.AvgPool2d, + nn.AdaptiveAvgPool2d, + torch.relu, + torch.transpose, + torch.sigmoid, + F.relu, + F.avg_pool2d, + F.adaptive_avg_pool2d, +] +# These are operators that may not be convertible into MKLDNN ops (e.g. the +# args are scalar values). Thus, we only include them in the subgraph if their +# arguments are already in MKLDNN. +# TODO: Determine whether this can be removed after type inference. +mkldnn_supported_unknown = [operator.add, operator.mul] +mkldnn_map = { + nn.Conv2d: th_mkldnn.MkldnnConv2d, + nn.Linear: th_mkldnn.MkldnnLinear, + nn.BatchNorm2d: lambda a, _: th_mkldnn.MkldnnBatchNorm(a), +} + + +def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]): + """ + For each node, if it's a module that can be preconverted into MKLDNN, + then we do so and create a mapping to allow us to convert from the MKLDNN + version of the module to the original. + """ + old_modules: dict[nn.Module, nn.Module] = {} + for node in nodes: + if node.op == "call_module": + assert isinstance(node.target, str) + cur_module = modules[node.target] + if type(cur_module) in mkldnn_map: + new_module = mkldnn_map[type(cur_module)](cur_module, torch.float) + assert isinstance(new_module, nn.Module) + old_modules[new_module] = copy.deepcopy(cur_module) + replace_node_module(node, modules, new_module) + return old_modules + + +def reset_modules( + nodes: list[fx.Node], + modules: dict[str, nn.Module], + old_modules: dict[nn.Module, nn.Module], +): + """ + Maps each module that's been changed with `modules_to_mkldnn` back to its + original. + """ + for node in nodes: + if node.op == "call_module": + assert isinstance(node.target, str) + cur_module = modules[node.target] + if cur_module in old_modules: + replace_node_module(node, modules, old_modules[cur_module]) + + +class MklSubgraph: + def __init__(self, fx_graph: fx.Graph): + self.fx_graph = fx_graph + self.nodes: list[fx.Node] = [] + self.start_nodes: list[fx.Node] = [] + self.end_nodes: list[fx.Node] = [] + + +def gen_mkl_autotuner(example_inputs, iters=10, warmup=1): + """ + This generates a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by running it with the example_inputs. + + Example usage: + heuristic = gen_mkl_autotuner(example_inputs, iters=10) + fast_model = optimization.optimize_for_inference(model, heuristic) + """ + fx_model = None + old_modules = None + + def use_mkl_heuristic(graph: MklSubgraph) -> bool: + nonlocal fx_model, old_modules + input_nodes = graph.start_nodes + if fx_model is None: + fx_model = graph.fx_graph.owning_module + old_modules = graph.fx_graph.old_modules # type: ignore[attr-defined] + ShapeProp(fx_model).propagate(example_inputs) + sample_inputs = [torch.randn(node.shape) for node in input_nodes] # type: ignore[attr-defined] + output_args = cast(list[fx.Node], [node.args[0] for node in graph.end_nodes]) + submodule = extract_subgraph(fx_model, graph.nodes, input_nodes, output_args) + + def benchmark(f): + for _ in range(warmup): + f() + begin = time.time() + for _ in range(iters): + f() + return time.time() - begin + + mkl_time = benchmark( + lambda: [ + i.to_dense() for i in submodule(*[i.to_mkldnn() for i in sample_inputs]) + ] + ) + + reset_modules( + submodule.graph.nodes, dict(submodule.named_modules()), old_modules + ) + no_mkl_time = benchmark(lambda: submodule(*sample_inputs)) + return mkl_time < no_mkl_time + + return use_mkl_heuristic + + +def use_mkl_length(graph: MklSubgraph) -> bool: + """ + This is a heuristic that can be passed into `optimize_for_inference` that + determines whether a subgraph should be run in MKL by checking if there + are more than 2 nodes in it + """ + return len(graph.nodes) > 2 + + +class UnionFind: + def __init__(self, n): + self.parent: list[Optional[int]] = [None] * n + self.size: list[int] = [0] * n + + def make_set(self, v: int): + self.parent[v] = v + self.size[v] = 1 + + def find(self, v: int) -> int: + par = self.parent[v] + if v == par: + return v + assert par is not None + self.parent[v] = self.find(par) + return cast(int, self.parent[v]) + + def join(self, a: int, b: int): + a, b = self.find(a), self.find(b) + if a == b: + return a + if self.size[a] < self.size[b]: + a, b = b, a + self.parent[b] = a + self.size[a] += self.size[b] + + +def optimize_for_inference( + model: torch.nn.Module, + pass_config: Optional[dict[str, Any]] = None, + tracer: type[fx.Tracer] = fx.Tracer, +) -> torch.nn.Module: + """ + Performs a set of optimization passes to optimize a model for the + purposes of inference. Specifically, the passes that are run are: + 1. Conv/BN fusion + 2. Dropout removal + 3. MKL layout optimizations + + The third optimization takes a function `use_mkl_heuristic` that's used + to determine whether a subgraph should be explicitly run in MKL layout. + + Note: As FX does not currently handle aliasing, this pass currently + assumes nothing aliases. If that isn't true, use at your own risk. + """ + default_pass_config = { + "conv_bn_fuse": True, + "remove_dropout": True, + "mkldnn_layout_optimize": {"heuristic": use_mkl_length}, + } + if pass_config is None: + pass_config = {} + default_pass_config.update(pass_config) + + if default_pass_config["conv_bn_fuse"]: + model = fuse(model) + if default_pass_config["remove_dropout"]: + model = remove_dropout(model) + if default_pass_config["mkldnn_layout_optimize"] is False: + return model + if not isinstance(default_pass_config["mkldnn_layout_optimize"], dict): + raise RuntimeError("mkldnn_layout_optimize config is not a dict") + if "heuristic" not in default_pass_config["mkldnn_layout_optimize"]: + raise RuntimeError("Heuristic not found in mkldnn_layout_optimize config") + use_mkl_heuristic = default_pass_config["mkldnn_layout_optimize"]["heuristic"] + + cur_tracer = tracer() + fx_graph = cur_tracer.trace(copy.deepcopy(model)) + fx.GraphModule(cur_tracer.root, fx_graph) + modules: dict[str, nn.Module] = dict(model.named_modules()) + + class MklSupport(Enum): + NO = 1 + YES = 2 + UNKNOWN = 3 + + # Inserts to_mkldnn and to_dense around every node we want to be a MKLDNN node. + # If the op is in `mkldnn_supported` then we always treat it as a MKLDNN node. + # However, if it's in `mkldnn_supported_unknown`, then we only treat it as + # a MKLDNN node if its inputs are MKLDNN nodes. + for node in list(fx_graph.nodes): + supports_mkldnn = MklSupport.NO + if node.op == "call_module": + cur_module = modules[node.target] + if type(cur_module) in mkldnn_supported: + supports_mkldnn = MklSupport.YES + sample_parameter = next(cur_module.parameters(), None) + if sample_parameter is not None: + assert sample_parameter.dtype == torch.float, ( + "this pass is only for torch.float modules" + ) + assert sample_parameter.device == torch.device("cpu"), ( + "this pass is only for CPU modules" + ) + elif node.op == "call_function": + if node.target in mkldnn_supported: + supports_mkldnn = MklSupport.YES + elif node.target in mkldnn_supported_unknown: + supports_mkldnn = MklSupport.UNKNOWN + + if supports_mkldnn != MklSupport.NO: + if supports_mkldnn == MklSupport.UNKNOWN: + if not any(arg.target == "to_dense" for arg in node.args): + continue + with fx_graph.inserting_before(node): + mkldnn_args = fx.map_arg( + node.args, lambda n: fx_graph.call_method("to_mkldnn", (n,)) + ) + + node.args = cast(tuple[fx.node.Argument], mkldnn_args) + + with fx_graph.inserting_after(node): + dense_x = fx_graph.create_node("call_method", "to_dense", (node,)) + node.replace_all_uses_with(dense_x) + dense_x.args = (node,) + + # Does pre-conversion of all modules into MKLDNN (when possible) + old_modules = modules_to_mkldnn(list(fx_graph.nodes), modules) + fx_graph.old_modules = old_modules # type: ignore[attr-defined] + + # optimizes all a -> to_dense -> to_mkldnn -> b patterns into a -> b + for node in fx_graph.nodes: + if node.op == "call_method" and node.target == "to_dense": + prv_node = node.args[0] + users = list(node.users) + for user in users: + if user.op == "call_method" and user.target == "to_mkldnn": + user.replace_all_uses_with(prv_node) + fx_graph.erase_node(user) + if len(node.users) == 0: + fx_graph.erase_node(node) + + num_nodes = len(fx_graph.nodes) + uf = UnionFind(num_nodes) + + def get_color(n): + if hasattr(n, "color"): # Current node is part of a MKL subgraph + return uf.find(n.color) + if hasattr(n, "start_color"): # Current node is input to MKL subgraph + return uf.find(n.start_color) + return None + + # This code is to find each MKLDNN subgraph. Each MKLDNN subgraph consists + # of input nodes (which are only `to_mkldnn` calls), output nodes + # (`to_dense` calls), and intermediate nodes, which are run entirely on + # MKLDNN layout tensors. + # + # Specifically, this code does a flood fill on a directed acyclic graph + # (DAG), starting from each possible "start node" (i.e: `to_mkldnn` nodes). + # If every node only had one input, this would be sufficient. However, in + # the case that a node has multiple inputs coming from different start + # nodes (i.e. colors), we need to join these 2 colors into 1. That's done + # using a Disjoint Set Union. + for cur_idx, node in enumerate(fx_graph.nodes): + if node.op == "call_method" and node.target == "to_mkldnn": + node.start_color = cur_idx + uf.make_set(cur_idx) + elif node.op == "call_method" and node.target == "to_dense": + assert get_color(node.args[0]) is not None + node.end_color = get_color(node.args[0]) + else: + cur_colors = [ + get_color(i) + for i in node.all_input_nodes + if isinstance(i, fx.Node) + if get_color(i) is not None + ] + + if len(cur_colors) == 0: + continue + assert not any(i is None for i in cur_colors) + cur_colors = sorted(cur_colors) + node.color = cur_colors[0] + for other_color in cur_colors[1:]: + uf.join(cur_colors[0], other_color) + + mkldnn_graphs: dict[int, MklSubgraph] = defaultdict(lambda: MklSubgraph(fx_graph)) + for node in fx_graph.nodes: + if hasattr(node, "color"): + mkldnn_graphs[uf.find(node.color)].nodes.append(node) + if hasattr(node, "start_color"): + mkldnn_graphs[uf.find(node.start_color)].start_nodes.append(node) + if hasattr(node, "end_color"): + mkldnn_graphs[uf.find(node.end_color)].end_nodes.append(node) + + # Now that we have all the subgraphs, we need to decide which MKLDNN + # subgraphs we actually want to keep in MKLDNN. + for graph in mkldnn_graphs.values(): + if not use_mkl_heuristic(graph): + for node in graph.start_nodes + graph.end_nodes: + prv = node.args[0] + node.replace_all_uses_with(prv) # type: ignore[arg-type] + fx_graph.erase_node(node) + reset_modules(graph.nodes, modules, old_modules) + + mkldnn_conversions = 0 + for node in fx_graph.nodes: + if node.target == "to_mkldnn" or node.target == "to_dense": + mkldnn_conversions += 1 + + logging.getLogger(__name__).info("mkldnn conversions: %s", mkldnn_conversions) + fx_graph.lint() + result = fx.GraphModule(model, fx_graph) + return result diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/partitioner_utils.py b/phivenv/Lib/site-packages/torch/fx/experimental/partitioner_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..aef455fa8e24d45fc51144c058ae084efab98d8d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/partitioner_utils.py @@ -0,0 +1,317 @@ +# mypy: allow-untyped-defs +from enum import Enum +from typing import NamedTuple + +from torch.fx.node import map_arg, Node + + +class Partition: + """Partition class contains all the information about an individual partition. + It also provides necessary methods for manipulation the partition. + """ + + def __init__(self, partition_id: int) -> None: + self.nodes: set[Node] = set() + self.partition_id = partition_id + self.parents: set[Partition] = set() + self.children: set[Partition] = set() + self.bfs_level: int = -1 + self.used_mem_bytes: int = 0 + self.logical_device_ids: list[int] = [] + + def __str__(self): + return str(self.partition_id) + + def recalculate_mem_size(self): + self.used_mem_bytes = 0 + for node in self.nodes: + self.used_mem_bytes += get_extra_size_of(node, self.nodes) + + def add_node(self, node): + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Add current node's input nodes if they are placeholder or constants + for n in input_nodes: + if n.op in {"placeholder", "get_attr"}: + self.nodes.add(n) + self.nodes.add(node) + self.recalculate_mem_size() + + def remove_node(self, node): + # Remove a node only if the node is in the partition + if node in self.nodes: + self.nodes.remove(node) + # Collect the node's input nodes + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Check if an input node is a placeholder or get_attr, + # and this input node is not used by some other nodes in this partition, + # the remove this input node + for input_node in input_nodes: + if all( + n not in self.nodes for n in input_node.users + ) and input_node.op in {"placeholder", "get_attr"}: + self.nodes.remove(input_node) + self.recalculate_mem_size() + + +class Device(NamedTuple): + name: str + available_mem_bytes: int + logical_id: int + + +class NodeLatency(NamedTuple): + # Latency due to the memory bandwidth + mem_latency_sec: float + # Latency due to the computation + computer_latency_sec: float + + +class PartitionLatency(NamedTuple): + # Sum of all nodes' memory latency on the critical path + mem_latency_sec: float + # Sum of all nodes' compute latency on the critical path + computer_latency_sec: float + # Latency of the critical path + overall_latency_sec: float + + +class PartitionMode(Enum): + size_based = 0 + sparse_nn = 1 + cost_aware = 2 + kl_based = 3 + aot_based = 4 + + +class PartitionerConfig(NamedTuple): + devices: list[Device] + mode: PartitionMode = PartitionMode.size_based + transfer_rate_bytes_per_sec: float = 0.0 + node_to_latency_mapping: dict[Node, NodeLatency] = {} + node_to_partition_mapping: dict[Node, int] = {} + partition_to_logical_device_mapping: dict[int, list[int]] = {} + # Saturate host by replicating partitions to the remaining idle devices. + saturate_host: bool = False + + +def get_extra_size_of(node: Node, nodes: set[Node]) -> int: + """Given a node and a set of nodes, + this function return the extra size that needed + if this node is included in this set. + """ + # Find all its input nodes + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # Calculate total size of related nodes + total_size_of_input_nodes = 0 + for n in input_nodes: + # Make sure this node hasn't been in this set yet + if n not in nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.output_size + else: + raise RuntimeError("node has no size_bytes attr") + # Don't forget the op node itself + size_bytes = getattr(node, "size_bytes", None) + if size_bytes: + total_size_of_input_nodes += size_bytes.total_size + else: + raise RuntimeError("node has no size_bytes attr") + return total_size_of_input_nodes + + +def get_latency_of_one_partition( + partition: Partition, node_to_latency_mapping: dict[Node, NodeLatency] +) -> PartitionLatency: + """Given a partition and its nodes' latency, return a PartitionLatency for this partition""" + + def get_top_nodes(partition: Partition) -> list[Node]: + """Given a partition, return a list of nodes on the top bfs level""" + top_nodes: list[Node] = [] + for node in partition.nodes: + # Skip placeholder and get_attr nodes + if node.op in {"placeholder", "get_attr"}: + continue + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + # If a node has no input nodes in this partition, + # or its input nodes in this partition are placeholders and get_attrs + # this node is on the top bfs level in this partition + if not any( + n in partition.nodes and n.op not in {"placeholder", "get_attr"} + for n in input_nodes + ): + top_nodes.append(node) + return top_nodes + + def dfs_helper(node: Node, partition_latency) -> PartitionLatency: + """Given a top node of a partition, this function returns + the latency of the critical path in the partition + """ + node_latency = node_to_latency_mapping[node] + # Calculate the current overall latency of the partition + overall_latency_sec = partition_latency.overall_latency_sec + max( + node_latency.computer_latency_sec, node_latency.mem_latency_sec + ) + # Update the mem latency of this path + mem_latency_sec = ( + partition_latency.mem_latency_sec + node_latency.mem_latency_sec + ) + # Update the compute latency of this path + computer_latency_sec = ( + partition_latency.computer_latency_sec + node_latency.computer_latency_sec + ) + # Get all users of this node that are in this partition + users = set(node.users).intersection(partition.nodes) + if users: + max_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + for n in users: + # Get new partition latency recursively + new_partition_latency = dfs_helper( + n, + PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ), + ) + if ( + new_partition_latency.overall_latency_sec + > max_latency.overall_latency_sec + ): + max_latency = new_partition_latency + return max_latency + # If there is no user, the node is at bottom of the partition + return PartitionLatency( + mem_latency_sec, computer_latency_sec, overall_latency_sec + ) + + # Main part starts + # Get all top level nodes of this partition + top_nodes = get_top_nodes(partition) + critical_path_latency = PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ) + # Go through all top nodes and find the largest latency (critical pass latency) + for node in top_nodes: + partition_latency = dfs_helper( + node, + PartitionLatency( + mem_latency_sec=0.0, computer_latency_sec=0.0, overall_latency_sec=0.0 + ), + ) + if ( + partition_latency.overall_latency_sec + > critical_path_latency.overall_latency_sec + ): + critical_path_latency = partition_latency + return critical_path_latency + + +def get_partition_to_latency_mapping( + partitions: list[Partition], node_to_latency_mapping: dict[Node, NodeLatency] +) -> dict[Partition, PartitionLatency]: + """Given all the partitions and node_to_latency_mapping dictionary, + return a mapping dictionary of each partition to its overall latency + """ + partition_to_latency_mapping: dict[Partition, PartitionLatency] = {} + # Go through each partition and get its latency + for partition in partitions: + partition_latency = get_latency_of_one_partition( + partition, node_to_latency_mapping + ) + partition_to_latency_mapping[partition] = partition_latency + return partition_to_latency_mapping + + +def get_comm_latency_between( + parent_partition: Partition, + child_partition: Partition, + transfer_rate_bytes_per_sec: float, +): + """Given two partitions (parent and child), + calculate the communication latency between the two. + """ + # If two partitions are on the same device, the comm latency is 0. + if ( + parent_partition.logical_device_ids != [] + and child_partition.logical_device_ids != [] + and parent_partition.logical_device_ids == child_partition.logical_device_ids + ): + return 0.0 + # Keep tracking the communication size between parent and child + comm_size = 0 + # Keep tracking all the counted node + visited_nodes = set() + # Go through all nodes in the child partition + # If a node has input nodes from the parent partition, + # the output size of those input nodes will be counted + # and added to comm_size + for node in child_partition.nodes: + input_nodes: dict[Node, None] = {} + map_arg(node.args, input_nodes.setdefault) + map_arg(node.kwargs, input_nodes.setdefault) + for n in input_nodes: + if n in parent_partition.nodes and n not in visited_nodes: + size_bytes = getattr(n, "size_bytes", None) + if size_bytes is not None: + comm_size += size_bytes.output_size + visited_nodes.add(n) + return comm_size / transfer_rate_bytes_per_sec + + +def get_latency_of_partitioned_graph( + partitions: list[Partition], + partition_to_latency_mapping: dict[Partition, PartitionLatency], + transfer_rate_bytes_per_sec: float, +): + """Given all partitions in a graph, find the critical path among all partitions + and return its latency as the latency of the whole graph + """ + + def dfs_helper(partition: Partition, latency_so_far_sec: float) -> float: + """This function helps to recursively get the latency of a path of partitions""" + # Update latency by adding current partition's latency + latency_so_far_sec += partition_to_latency_mapping[ + partition + ].overall_latency_sec + + if partition.children: + max_latency_sec = 0.0 + for child in partition.children: + # Calculate latency between + comm_latency_sec = get_comm_latency_between( + partition, child, transfer_rate_bytes_per_sec + ) + new_latency_sec = dfs_helper( + child, latency_so_far_sec + comm_latency_sec + ) + if new_latency_sec > max_latency_sec: + max_latency_sec = new_latency_sec + return max_latency_sec + return latency_so_far_sec + + def get_top_partitions(partitions: list[Partition]) -> list[Partition]: + """This function is to return all the partitions without parents + as the starting points of all the paths + """ + # If a partition has no parents, then it is a top partition + top_partitions = [ + partition for partition in partitions if len(partition.parents) == 0 + ] + return top_partitions + + top_partitions = get_top_partitions(partitions) + critical_path_latency_sec = 0.0 + for partition in top_partitions: + latency_sec = dfs_helper(partition, 0.0) + if latency_sec > critical_path_latency_sec: + critical_path_latency_sec = latency_sec + return critical_path_latency_sec diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/proxy_tensor.py b/phivenv/Lib/site-packages/torch/fx/experimental/proxy_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb7710f2fe50234803573fc0e9d61dd4a7ef7cc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/proxy_tensor.py @@ -0,0 +1,2434 @@ +# mypy: allow-untyped-decorators +# Copyright (c) Facebook, Inc. and its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from __future__ import annotations + +import functools +import inspect +import logging +import operator +import traceback +import typing +import typing_extensions +import weakref +from collections import defaultdict, OrderedDict +from collections.abc import Generator, Mapping, Sequence +from contextlib import _GeneratorContextManager, contextmanager, ExitStack, nullcontext +from dataclasses import dataclass +from typing import ( + Any, + Callable, + Optional, + overload, + Protocol, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import Concatenate, ParamSpec, Self, TypeVarTuple, Unpack +from weakref import WeakKeyDictionary + +import torch +import torch._ops +import torch.fx as fx +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree +from torch import SymBool, SymInt, Tensor +from torch._dispatch.python import enable_python_dispatcher +from torch._library.fake_class_registry import FakeScriptObject +from torch._logging import trace_structured +from torch._subclasses.fake_impls import fast_detach +from torch._subclasses.fake_tensor import ( + FakeTensor, + FakeTensorMode, + is_fake, + unset_fake_temporarily, +) +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx import GraphModule, Proxy, Tracer +from torch.fx.graph_module import _assign_attr +from torch.fx.node import ( + _side_effectful_need_to_be_preserved_pre_dispatch, + Argument, + Target, +) +from torch.fx.passes.shape_prop import _extract_tensor_metadata +from torch.nn import Module +from torch.overrides import TorchFunctionMode +from torch.utils._python_dispatch import ( + _disable_infra_mode, + _push_mode, + _unset_infra_mode, + TorchDispatchMode, +) +from torch.utils._stats import count +from torch.utils._thunk import Thunk +from torch.utils._traceback import CapturedTraceback +from torch.utils.weak import _WeakHashRef, WeakIdKeyDictionary, WeakTensorKeyDictionary + +from ._backward_state import BackwardState +from .sym_node import SymNode + + +if TYPE_CHECKING: + import types + from collections.abc import MutableMapping + + import sympy + + from torch._ops import OpOverload + from torch.fx._symbolic_trace import PHBase + from torch.types import IntLikeType + +__all__ = [ + "PythonKeyTracer", + "dispatch_trace", + "make_fx", + "DecompositionInterpreter", + "py_sym_types", + "get_innermost_proxy_mode", + "get_proxy_mode", + "handle_sym_dispatch", + "maybe_enable_thunkify", + "maybe_disable_thunkify", +] + +_ProxyTracer = Union["PythonKeyTracer", "_GraphAppendingTracerEx"] + +_AnyScriptObject = (torch.ScriptObject, FakeScriptObject) +_AnyScriptObjectType = Union[torch.ScriptObject, FakeScriptObject] + +aten = torch.ops.aten +prim = torch.ops.prim + +log = logging.getLogger(__name__) +not_implemented_log = torch._logging.getArtifactLogger(__name__, "not_implemented") + +CURRENT_DECOMPOSITION_TABLE: Mapping[OpOverload, Callable] = {} + +CONSTANT_NUMEL_LIMIT = 1 + +T = TypeVar("T") +U = TypeVar("U") +_P = ParamSpec("_P") +R = TypeVar("R") +_Ts = TypeVarTuple("_Ts") + +null_ctx_type = type(nullcontext) +# We currently convert all SymInt to proxies before we use them. +# This could plausibly be handled at the Dynamo level. +pytree.register_pytree_node( + torch.Size, + lambda xs: (list(xs), None), + lambda xs, _: tuple(xs), + flatten_with_keys_fn=lambda xs: ( + [(pytree.SequenceKey(i), x) for i, x in enumerate(xs)], + None, + ), + serialized_type_name="torch.Size", +) +# Ideally unflattening should not lose info, but we unflatten +# torch.Size to tuple (see above). This is necessary because the +# torch.Size constructor only accepts ints whereas our infra often +# transforms them to non-ints, e.g. symint proxies. Anyway, losing +# such info can cause pytree mapping or spec matching to fail, so +# work around this problem using the following dict as needed. +_pytree_subclasses_that_lose_info = {torch.Size: tuple} + + +def fake_signature(fn: Callable[_P, R], nargs: int) -> Callable[_P, R]: + """FX gets confused by varargs, de-confuse it""" + argnames = ",".join(f"arg{i}" for i in range(nargs)) + return eval(f"lambda {argnames}: fn({argnames})", {"fn": fn}) + + +@contextmanager +def decompose( + decomposition_table: Optional[Mapping[OpOverload, Callable]], +) -> Generator[Mapping[OpOverload, Callable], None, None]: + global CURRENT_DECOMPOSITION_TABLE + old_decomposition_table = CURRENT_DECOMPOSITION_TABLE + CURRENT_DECOMPOSITION_TABLE = decomposition_table or {} + try: + yield CURRENT_DECOMPOSITION_TABLE + finally: + CURRENT_DECOMPOSITION_TABLE = old_decomposition_table + + +# ensure we cannot collide with other properties +proxy_slot = object() + + +class _NoDefault: + pass + + +no_default = _NoDefault() + +from torch.types import py_sym_types, PySymType + + +class _HasMeta(Protocol): + meta: dict[str, PySymType] + + +def is_sym_node(node: _HasMeta) -> bool: + assert hasattr(node, "meta"), "All nodes traced with proxy_tensor should have meta" + return "val" in node.meta and isinstance(node.meta["val"], py_sym_types) + + +@overload +def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ... + + +@overload +def set_proxy_slot( + obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy +) -> None: ... + + +@overload +def set_proxy_slot( + obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType +) -> None: ... + + +def set_proxy_slot( + obj: Union[PySymType, _AnyScriptObjectType, Tensor], + tracer: _ProxyTracer, + proxy: object, +) -> None: + log.debug("set_proxy_slot %s (%s) %s", obj, id(obj), proxy) + if isinstance(obj, Tensor): + # We DO want to clobber proxies whenever we run an inplace operation + # on a tensor, and it affects the metadata on the proxy. + assert isinstance(proxy, _ProxyTensor) + tracer.tensor_tracker[obj] = proxy + elif isinstance(obj, (_AnyScriptObject)): + # We DO want to clobber proxies, with a similar rationale as for tensors. + assert isinstance(proxy, Proxy) + tracer.script_object_tracker[obj] = proxy + else: + # NB: Never clobber pre-existing proxy. Although the proxies + # are in principle equivalent, when we do graph partitioning + # we need there not to be spurious dependencies on tangent inputs. + # This works because primals get their SymInts set first, and + # THEN later we allocate tangent inputs. Make sure if a SymInt + # is derivable from a primal that we use that. + assert isinstance(obj, py_sym_types), type(obj) + if obj not in tracer.symnode_tracker: + tracer.symnode_tracker[obj] = typing.cast(_PySymProxyType, proxy) + + # WAR: python test/dynamo/test_subclasses.py + # TestNestedTensor.test_basic_autograd + # + # AOTAutograd doesn't pass the "outer sizes" as an actual argument + # to make_fx, but it is made use of internally in AOTAutograd's + # call to tensor unflatten. Because the outer sizes isn't passed + # as an argument, it is therefore untracked. However, it turns + # out you luck out, because *Dynamo* will manually add the outer + # sizes as an argument so you can fix up the proxy'ness. + # + # This is probably fixed in + # https://github.com/pytorch/pytorch/pull/125941/ + import sympy + + if isinstance(obj.node.expr, sympy.Symbol): + tracer.sympy_expr_tracker[obj.node.expr] = proxy + + +def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool: + assert isinstance(obj, (Tensor, SymNode)), type(obj) + return bool(get_proxy_slot(obj, tracer, False, lambda _: True)) + + +_PySymProxyType = Thunk[Proxy] + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, +) -> _ProxyTensor: ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, +) -> Union[_ProxyTensor, U]: ... + + +@overload +def get_proxy_slot( + obj: Tensor, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_ProxyTensor], R], +) -> Union[R, U]: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, +) -> Proxy: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, +) -> Union[Proxy, U]: ... + + +@overload +def get_proxy_slot( + obj: _AnyScriptObjectType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[Proxy], R], +) -> Union[R, U]: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, +) -> _PySymProxyType: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: T, +) -> Union[T, _PySymProxyType]: ... + + +@overload +def get_proxy_slot( + obj: PySymType, + tracer: _ProxyTracer, + default: U, + transform: Callable[[_PySymProxyType], R], +) -> Union[R, U]: ... + + +# the default argument is what to return if the slot is not set. +# the transform argument is handy if you need to extract a subfield from +# the successfully looked up result (but NOT the default.) +def get_proxy_slot( + obj: Union[Tensor, _AnyScriptObjectType, PySymType], + tracer: _ProxyTracer, + default: object = no_default, + transform: Callable = lambda x: x, +) -> object: + tracker: Any + if isinstance(obj, Tensor): + tracker = tracer.tensor_tracker + elif isinstance(obj, _AnyScriptObject): + tracker = tracer.script_object_tracker + else: + assert isinstance(obj, py_sym_types), type(obj) + tracker = tracer.symnode_tracker + + if obj not in tracker: + # Last ditch + if isinstance(obj, py_sym_types) and obj.node.expr in tracer.sympy_expr_tracker: + value = tracer.sympy_expr_tracker[obj.node.expr] + else: + if isinstance(default, _NoDefault): + raise RuntimeError( + f"{obj} ({id(obj)})is not tracked with proxy for {tracer}" + ) + return default + else: + value = tracker[obj] + res = transform(value) + return res + + +def snapshot_fake(val: Tensor, include_real: bool = False) -> Optional[Tensor]: + # val.detach() will also eventually call fast_detach(), + # but this saves us a full trip into __torch_dispatch__ + # (snapshot_fake is called a lot) + if isinstance(val, FakeTensor): + return fast_detach(val.fake_mode, val, include_real) + else: + return val.detach() + + +_ExtractValType = Optional[ + Union[ + PySymType, + _AnyScriptObjectType, + BackwardState, + list["_ExtractValType"], + tuple["_ExtractValType", ...], + dict[str, "_ExtractValType"], + Tensor, + int, + float, + bool, + ] +] + + +def extract_val(val: _ExtractValType, include_real: bool = False) -> _ExtractValType: + if is_fake(val): + return snapshot_fake(val, include_real=include_real) + elif isinstance(val, py_sym_types): + return val + elif isinstance(val, _AnyScriptObject): + return val + elif isinstance(val, BackwardState): + return val + elif isinstance(val, (list, tuple)): + return val.__class__([extract_val(x) for x in val]) + elif isinstance(val, dict): + return {k: extract_val(v) for k, v in val.items()} + elif isinstance(val, Tensor): + if not val.is_sparse: + # NB: Kinda hacky, but we should try to get val as the metadata + # everywhere + # TODO: This doesn't properly track storages. A more robust + # approach would be to maintain a per-trace FakeTensorMode and + # from_real_tensor to create fake values (don't forget to + # snapshot_fake) + from torch._guards import detect_fake_mode + + fake_tensor_mode = detect_fake_mode(val) + if not fake_tensor_mode: + fake_tensor_mode = FakeTensorMode(allow_fallback_kernels=True) + with fake_tensor_mode: + return torch.empty_strided( + val.shape, val.stride(), device=val.device, dtype=val.dtype + ) + else: + return None + elif isinstance(val, (int, float, bool)): + return val + elif val is None: + return None + + typing_extensions.assert_never(val) + + +@contextmanager +def _enable_thunkify( + tracer: _ProxyTracer, *, enable: bool = True +) -> Generator[None, None, None]: + """ + Enable thunkification inside the context manager. Thunkification prevents + SymNode computation from directly being traced into an FX graph; instead, + the compute is only added to the graph if it is actually used. This helps + us track SymNode compute when it is computed (since we need /something/ + to put in the tracker) even if it is unlikely to be used. + """ + old = tracer.enable_thunkify + tracer.enable_thunkify = enable + try: + yield + finally: + tracer.enable_thunkify = old + + +@contextmanager +def maybe_disable_thunkify() -> Generator[None, None, None]: + """Within a context, disable thunkification. See :func:`maybe_enable_thunkify` + for more details. This is helpful if you have a wrapper function which + you want to enable thunkification on, but in some segment on the inside (say, + the original user function), you want to disable thunkification as you know + it is not needed there. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer, enable=False): + yield + else: + yield + + +@contextmanager +def maybe_enable_thunkify() -> Generator[None, None, None]: + """Within this context manager, if you are doing make_fx tracing, we will thunkify + all SymNode compute and avoid tracing it into the graph unless it is actually needed. + You should prefer to avoid using this as much as possible, as lazy evaluation of + SymNode tracing can lead to long chains of thunks which will stack overflow + if you evaluate them. However, this is currently sometimes necessary as there + are buggy parts of PT2 which will fail with "s0 is not tracked with proxy" error + due to insufficient tracing of SymNode computation. + """ + proxy_mode = get_proxy_mode() + if proxy_mode is not None: + with _enable_thunkify(proxy_mode.tracer): + yield + else: + yield + + +# Note [invariants for node meta 'val'] +# What invariants do we have for the 'val' set on the FX node? It has accurate +# metadata... but only for metadata that exists "below" all other subsystems +# (most notably autograd, but also vmap, functorch transforms, etc). This means +# you can get the dtype, shape, stride, storage, but you CANNOT get requires_grad, +# grad_fn, _base (_base actually may be set due to recursive call to +# ADInplaceOrView, but you shouldn't rely on it.) +def set_meta(proxy: Proxy, val: _ExtractValType) -> Proxy: + proxy.node.meta["val"] = extract_val( + val, include_real=(proxy.node.op == "placeholder") + ) + + with _enable_thunkify(proxy.tracer): # type: ignore[arg-type] + # Best effort tensor_meta setting; prefer using val! + if is_fake(val): + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + elif isinstance(val, Tensor) and not val.is_sparse: + proxy.node.meta["tensor_meta"] = _extract_tensor_metadata(val) + return proxy + + +def thunkify( + tracer: _ProxyTracer, f: Callable[_P, R], *args: _P.args, **kwargs: _P.kwargs +) -> Thunk[R]: + """ + Delays computation of f until it's called again + Also caches the result + """ + if tracer.enable_thunkify: + return Thunk(functools.partial(f, *args, **kwargs)) + else: + r = f(*args, **kwargs) + return Thunk(lambda: r) + + +def track_tensor( + tensor: Tensor, proxy: Proxy, *, constant: Optional[Tensor], tracer: _ProxyTracer +) -> None: + def try_set_proxy_slot( + outer_s: IntLikeType, + proxy_callable: Callable[Concatenate[PySymType, _P], Proxy], + *args: _P.args, + **kwargs: _P.kwargs, + ) -> None: + assert callable(proxy_callable) + if isinstance(outer_s, SymInt): + with _enable_thunkify(tracer): + set_proxy_slot( + outer_s, + tracer, + thunkify(tracer, proxy_callable, outer_s, *args, **kwargs), + ) + + # The basic idea is that we need to associate each tensor/SymInt + # with a Proxy. How do we setup this association? We just store + # the proxy on the proxy slot of the object, keyed on the tracer + # (so that if we have multiple tracers at the same time, they + # don't clobber each other.) + for i, s in enumerate(tensor.shape): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_size.int, (proxy, i), {} + ), + x, + ), + i, + ) + + if not is_sparse_any(tensor): + for i, s in enumerate(tensor.stride()): + try_set_proxy_slot( + s, + lambda x, i: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_stride.int, (proxy, i), {} + ), + x, + ), + i, + ) + + try_set_proxy_slot( + tensor.numel(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", torch.ops.aten.sym_numel.default, (proxy,), {} + ), + x, + ), + ) + if not is_sparse_any(tensor): + try_set_proxy_slot( + tensor.storage_offset(), + lambda x: set_meta( + tracer.create_proxy( + "call_function", + torch.ops.aten.sym_storage_offset.default, + (proxy,), + {}, + ), + x, + ), + ) + set_proxy_slot(tensor, tracer, _ProxyTensor(proxy, constant)) + + +_NestedProxys = Union[ + Proxy, Sequence["_NestedProxys"], Mapping[object, "_NestedProxys"] +] +_NestedTensors = Union[ + Tensor, Sequence["_NestedTensors"], Mapping[object, "_NestedTensors"] +] + + +def track_tensor_tree( + inner_res: T, + proxy_res: _NestedProxys, + *, + constant: Optional[_NestedTensors], + tracer: _ProxyTracer, +) -> T: + # NB: We call set_unbacked_bindings only on the *topmost* call to + # track_tensor_tree, not recursive calls. This is because there must + # be only ONE unbacked_binding proxy call, and it should be the one + # where all of the unbacked SymInts actually first come into existence. + # If you call this again on the inner proxies for the tuple projections, + # you will have multiple unbacked_bindings for the same symbol, but + # they're not going to show up anywhere. + # + # I was briefly deceived into setting unbacked bindings recursively when + # working on https://github.com/pytorch/pytorch/pull/133585 because I + # observed that some extra unbacked bindings were needed to handle some + # higher order operator code. But actually it looks like this was + # just an unrelated bug that needed to be fixed separately. + _set_unbacked_bindings(inner_res, proxy_res) + + def wrap_with_proxy( + e: object, proxy: _NestedProxys, constant: Optional[_NestedTensors] + ) -> None: + if isinstance(e, Tensor): + assert isinstance(proxy, Proxy) + assert constant is None or isinstance(constant, Tensor) + track_tensor(e, proxy, tracer=tracer, constant=constant) + set_meta(proxy, e) + elif isinstance(e, py_sym_types): + assert isinstance(proxy, Proxy) + # NB: eagerly set meta here, so that the numbering is in order + set_meta(proxy, e) + set_proxy_slot(e, tracer, thunkify(tracer, lambda: proxy)) + elif isinstance(e, _AnyScriptObject): + assert isinstance(proxy, Proxy) + set_proxy_slot(e, tracer, proxy) + set_meta(proxy, e) + elif isinstance(e, (tuple, list)): + # example use case: allreduce_ returns ([tensor], work) + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + def get_constant( + c: Optional[_NestedTensors], idx: int + ) -> Optional[_NestedTensors]: + if c is None: + return None + else: + assert isinstance(c, (list, tuple)) + return c[idx] + + for idx, ee in enumerate(e): + # Use an indexer here - if proxy is a List then it will unwrap + # it. If it's a Proxy then it will proxy the getelem. + wrap_with_proxy(ee, proxy[idx], get_constant(constant, idx)) # type: ignore[index] + + elif isinstance(e, dict): + # example use case: triton_kernel_wrapper takes arguments as kwargs + + # In theory we could support const-prop when proxy-tensor-tracing + # operators that returns dicts of tensors, but we have no use case + # for it today (since the only op we currently trace that can + # return a dict is triton_kernel_wrapper_functional/mutation, + # which does not participate in const-prop) + assert constant is None + + if isinstance(proxy, fx.Proxy): + set_meta(proxy, e) + + for key, val in e.items(): + wrap_with_proxy(val, proxy[key], None) # type: ignore[index] + + elif isinstance(e, BackwardState): + assert isinstance(proxy, Proxy) + set_meta(proxy, e) + e.proxy = proxy + else: + # intentionally pass on primitives + pass + + wrap_with_proxy(inner_res, proxy_res, constant) + + return inner_res + + +@dataclass +class _ProxyTensor: + proxy: Proxy + constant: Optional[Tensor] + + +def fetch_sym_proxy( + tracer: _ProxyTracer, +) -> Callable[[PySymType], Union[bool, int, float, Proxy]]: + def inner(e: PySymType) -> Union[int, bool, float, Proxy]: + n = e.node + if n.constant is not None: + return n.constant + if e.node.expr.is_number: + if isinstance(e, SymBool): + return bool(e.node.expr) + elif isinstance(e, SymInt): + return int(e.node.expr) + return float(e.node.expr) + else: + assert isinstance(e, py_sym_types) + # NB: we REQUIRE all symints to be tracked + return get_proxy_slot(e, tracer).force() + + return inner + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: Tensor +) -> Union[_ProxyTensor, Tensor]: ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: _AnyScriptObjectType +) -> Union[Proxy, _AnyScriptObjectType]: ... + + +@overload +def fetch_object_proxy( + tracer: _ProxyTracer, t: PySymType +) -> Union[_PySymProxyType, PySymType]: ... + + +def fetch_object_proxy( + tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType] +) -> object: + return get_proxy_slot(t, tracer, t) + + +HANDLED_TYPES = (Tensor, torch.nn.Parameter, FakeTensor) + + +def _maybe_record_pointwise_barrier( + func: object, proxy_mode: ProxyTorchDispatchMode +) -> None: + """ + Records pointwise operators in user program (non decomposed) that were output in fp16/bf16 + """ + if proxy_mode.decomp_layers or not proxy_mode.emulate_precision_casts: + return + + if ( + not isinstance(func, torch._ops.OpOverload) + or torch.Tag.pointwise not in func.tags + ): + return + + last_node = next(iter(reversed(proxy_mode.tracer.graph.nodes))) + t = last_node.meta.get("val") + if not isinstance(t, torch.Tensor) or t.dtype not in ( + torch.bfloat16, + torch.float16, + ): + return + + last_node.meta["low_precision_pointwise_barrier"] = True + + +def proxy_call( + proxy_mode: ProxyTorchDispatchMode, + func: OpOverload, + pre_dispatch: bool, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + unrecognized_types: list[type] = [] + flat_args_kwargs, spec = pytree.tree_flatten((args, kwargs)) + + def can_handle_tensor(x: Tensor) -> bool: + r = type(x) in HANDLED_TYPES or has_proxy_slot(x, proxy_mode.tracer) + if proxy_mode._allow_fake_constant: + r = r or type(x) in (torch._subclasses.FakeTensor,) + if not r: + unrecognized_types.append(type(x)) + return r + + # If there are any tensor subclasses, we need to handle those tensor subclasses first + # TODO: we could use types to test this + if not all(can_handle_tensor(x) for x in flat_args_kwargs if isinstance(x, Tensor)): + not_implemented_log.debug( + "ProxyTensorMode tensors without proxy had unrecognized subclasses: %s", + unrecognized_types, + ) + return NotImplemented + + r = maybe_handle_decomp(proxy_mode, func, args, kwargs) + if r is not NotImplemented: + _maybe_record_pointwise_barrier(func, proxy_mode) + return r + + # For pre-autograd tracing, we do not want to run CompositeImplicit decomps. + if not pre_dispatch and func not in [ + torch.ops.aten.size.default, + torch.ops.aten.stride.default, + torch.ops.aten.storage_offset.default, + ]: + with proxy_mode: + r = func.decompose(*args, **kwargs) + if r is not NotImplemented: + return r + + if func is torch.ops.aten.is_nonzero.default: + with proxy_mode: + torch._check( + args[0].numel() == 1, # type: ignore[attr-defined] + lambda: "Boolean value of Tensor with more than one value is ambiguous", + ) + return (args[0] != 0).item() # type: ignore[attr-defined] + + tracer = proxy_mode.tracer + f_flat_args_kwargs = [ + ( + fetch_object_proxy(tracer, x) + if isinstance(x, (Tensor, _AnyScriptObject)) + else x + ) + for x in flat_args_kwargs + ] + + # If there are SymInts, we also should not consider this constant. + # However, fake tensor handling of SymInts is sufficiently broken that + # I couldn't write a test for this case + all_constant = ( + not any( + t.constant is None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + # TODO: maybe constant SymInts should also be allowed? Not sure if + # this can happen + and not any(isinstance(x, py_sym_types) for x in flat_args_kwargs) + ) + + if torch.Tag.data_dependent_output in func.tags: + # Check if all of the Tensor inputs are constants + if all_constant: + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + with unset_fake_temporarily(): + return func(*const_args, **const_kwargs) + # If any of the Tensor inputs are "real" (not FakeTensor), we may + # incorrectly burn in constants by allowing this access. Raise + # an error in this case + if proxy_mode._error_on_data_dependent_ops and pytree.tree_all_only( + Tensor, lambda t: not is_fake(t), (args, kwargs) + ): + raise RuntimeError( + f"It appears that you're trying to get value out of a tracing tensor with {func} - erroring out! " + "It's likely that this is caused by data-dependent control flow or similar. " + "It may be possible to trace this with dynamic shapes; try setting tracing_mode='symbolic' " + "in your make_fx call." + ) + + proxy_flat_args_kwargs = [ + e.proxy if isinstance(e, _ProxyTensor) else e for e in f_flat_args_kwargs + ] + proxy_flat_args_kwargs = [ + (fetch_sym_proxy(proxy_mode.tracer)(e) if isinstance(e, py_sym_types) else e) + for e in proxy_flat_args_kwargs + ] + proxy_args, proxy_kwargs = pytree.tree_unflatten(proxy_flat_args_kwargs, spec) + + # When we trace through a torch.tensor invocation, you never actually + # see a torch.ops.aten.tensor call. Instead, the way this function is + # implemented internally is that we allocate a plain tensor (this is + # *guaranteed* to be a plain tensor, we disable all modes when doing + # so), and then call at::lift_fresh on it (to give modes a chance to do + # their stuff). Furthermore, the tensor argument to lift_fresh is guaranteed + # to be freshly allocated, so we want lift_fresh to be a no-op (directly + # returning the input argument). + # + # Here is the basic problem: when we trace this sequence of executions + # into an FX graph, what happens to this call sequence? Traditionally, + # tensor constants get interned as buffers on the FX GraphModule. But + # this is dangerous. Consider: + # + # x = torch.tensor(1) + # x.add_(2) + # + # Naively, this traces into: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh(t) + # x.add_(2) + # + # If lift_fresh returns t directly, the subsequent add_ call will + # modify the tensor constant. Really, the problem is we've violated + # the invariant the argument to lift is fresh. So what we should + # preserve the invariant by replacing lift_fresh with lift_fresh_copy: + # + # t = self._tensor_constant0 # initialized to torch.tensor(1) + # x = torch.ops.aten.lift_fresh_copy(t) + # x.add_(2) + # + # This is what the overload modification does. + if func is torch.ops.aten.lift_fresh.default: + func = torch.ops.aten.lift_fresh_copy.default + + proxy_out = proxy_mode.tracer.create_proxy( + "call_function", + func, + proxy_args, + proxy_kwargs, + name=proxy_mode.tracer.graph._target_to_str(func.overloadpacket.__name__), + ) + + with _enable_thunkify(proxy_mode.tracer): + out = func(*args, **kwargs) + + # In some circumstances, we will be tracing in a situation where a tensor + # is *statically* known to be a constant (currently, this only happens if + # you run torch.tensor; deterministic factory functions like torch.arange + # don't get this treatment). When the tensor in question is small, it's + # helpful to due constant propagation in case we call item() (in which + # case we can return the constant value that is known, rather than give + # an error.) The logic here tests if constant propagation is possible + # (because all of the inputs are constant). If so, we disable fake tensor + # mode (if it is on) and do true compute on the constant. + # + # It's worth highlighting that we're making a policy decision here. + # There is a potential that the tensor is actually quite large, and we + # don't actually want to run the compute. The tensor being quite large + # is one of the reasons why factory functions don't get this treatment + # (since they can be quite large; if a parameter is initialized to a + # constant value it will be!) Similarly, there is also a potential + # to run an operator that blows up the size of a small tensor; we don't + # protect against this case, but we could force, e.g., only single + # element constant computation by testing the numel of the result before + # propagating const-ness. Similarly, we don't require the constant to + # live on CPU, but we could. + any_constant = any( + t.constant is not None + for t in f_flat_args_kwargs + if isinstance(t, _ProxyTensor) + ) + + constant = None + + def tensor_numel_in_limit(t: Tensor) -> bool: + return t.numel() <= CONSTANT_NUMEL_LIMIT + + # If this is a lift, the input tensor is guaranteed to be a + # constant, so we keep a copy of the original argument along so + # we can query it if we're asked to item() it at some later point + if ( + func is torch.ops.aten.lift_fresh_copy.default + and out.numel() <= CONSTANT_NUMEL_LIMIT + ): + with unset_fake_temporarily(): + assert isinstance(args[0], (Proxy, Tensor)), type(args[0]) + constant = args[0].clone() + elif ( + torch.Tag.nondeterministic_seeded not in func.tags + and all_constant + and any_constant + and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out) + ): + # NB: do NOT include factories as constants + with unset_fake_temporarily(): + const_flat_args_kwargs = [ + t.constant if isinstance(t, _ProxyTensor) else t + for t in f_flat_args_kwargs + ] + const_args, const_kwargs = pytree.tree_unflatten( + const_flat_args_kwargs, spec + ) + constant = func(*const_args, **const_kwargs) + else: + constant = None + + track_tensor_tree(out, proxy_out, constant=constant, tracer=tracer) + _maybe_record_pointwise_barrier(func, proxy_mode) + return out + + +class _SymNodeDict: + """ + Wrapper around a dictionary that will hash SymInts with their nodes + """ + + def __init__(self) -> None: + self.sym_node_dict: dict[PySymType, _PySymProxyType] = {} + + def __setitem__(self, key: PySymType, value: _PySymProxyType) -> None: + self.sym_node_dict[key.node] = value + + def __getitem__(self, key: PySymType) -> _PySymProxyType: + return self.sym_node_dict[key.node] + + def __contains__(self, key: PySymType) -> bool: + return key.node in self.sym_node_dict + + def get( + self, key: PySymType, default: Optional[_PySymProxyType] = None + ) -> _PySymProxyType: + # dict.get()'s annotation doesn't accept `None` when the value type + # isn't Optional. + return self.sym_node_dict.get(key.node, default) # type: ignore[arg-type, return-value] + + def __iter__(self) -> Any: + raise NotImplementedError + + def __len__(self) -> int: + return len(self.sym_node_dict) + + +class PythonKeyTracer(Tracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: _SymNodeDict + sympy_expr_tracker: dict[sympy.Symbol, object] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + torch_fn_counts: dict[OpOverload, int] + enable_thunkify: bool = False + stack_trace: bool = False + + def __init__(self) -> None: + super().__init__(autowrap_modules=()) # type: ignore[arg-type] + self.tensor_tracker = WeakTensorKeyDictionary() + self.symnode_tracker = _SymNodeDict() + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + self.sympy_expr_tracker = dict() + + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + self.enable_thunkify = False + + # In general, we don't want to make modules leaves. In principle, users of + # this tracer might want to override this in order to turn a couple specific + # modules into leaves in the traced graph. + def call_module( + self, + m: Module, + forward: Callable[..., Any], + args: tuple[Any, ...], + kwargs: dict[str, Any], + ) -> Any: + return forward(*args, **kwargs) + + # We don't want to turn getattr calls into proxies. So we just return the actual value. + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] + ) -> object: + return attr_val + + def create_arg(self, a: object) -> fx.node.Node: + if isinstance(a, torch.nn.Parameter): + for n, p in self.root.named_parameters(): + if a is p: + return self.create_node("get_attr", n, (), {}) + + qualname = self.get_fresh_qualname("_param_constant") + setattr(self.root, qualname, a) + + return self.create_node("get_attr", qualname, (), {}) + elif isinstance(a, py_sym_types): + assert a.node.constant is not None + return a.node.constant + return super().create_arg(a) # type: ignore[return-value] + + @overload + def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ... + + @overload + def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ... + + @overload + def unwrap_proxy( + self, e: _AnyScriptObjectType + ) -> Union[Proxy, _AnyScriptObjectType]: ... + + def unwrap_proxy(self, e: T) -> object: + if isinstance(e, Tensor): + return get_proxy_slot(e, self, e, lambda x: x.proxy) + elif isinstance(e, py_sym_types): + return get_proxy_slot(e, self, e, lambda e: e.force()) + elif isinstance(e, _AnyScriptObject): + return get_proxy_slot(e, self, e) + else: + return e + + def create_node( + self, + kind: str, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> torch.fx.Node: + node = super().create_node(kind, target, args, kwargs, name, type_expr) # type: ignore[arg-type] + + # stack_trace + if ( + self.stack_trace + and "stack_trace" not in node.meta + and node.op not in ["placeholder", "output"] + ): + user_frame_summary = CapturedTraceback.extract().summary() + if user_frame_summary: + # we retain frames from forward() calls, or ops + # located in torch/__init__.py (e.g. sym_int, sym_constrain_range, vmap) + stack_trace = [ + frame + for frame in user_frame_summary + if ( + frame.name == "forward" + or frame.filename.endswith("torch/__init__.py") + ) + ] + # filter out forward() frames from fx/_symbolic_trace.py, export/_trace.py + # this is hardcoded, but leads to a much cleaner stack trace + stack_trace = [ + frame + for frame in stack_trace + if not frame.filename.endswith( + ("fx/_symbolic_trace.py", "export/_trace.py") + ) + ] + if ( + stack_trace + ): # empty list for strict mode, dynamo should handle stack_trace + stack_trace = traceback.StackSummary.from_list(stack_trace) + node.meta["stack_trace"] = "".join(stack_trace.format()).strip() + + if kind == "get_attr": + assert isinstance(target, str) + attr = getattr(self.root, target) + if isinstance(attr, torch.Tensor): + with disable_proxy_modes_tracing(): + node.meta["val"] = extract_val(attr) + + def map_fn(v: Any) -> Optional[_ExtractValType]: + if not isinstance(v, torch.fx.Node) or "val" not in v.meta: + return None + val = v.meta["val"] + # other subclasses like FunctionalTensor error on `extract_val` + # "Attempting to use FunctionalTensor on its own." just store FakeTensors for now + if isinstance(val, torch.Tensor) and not isinstance(val, FakeTensor): + return None + return extract_val(v.meta["val"]) + + if _should_save_eager_input_vals(target, (args, kwargs)): + # NOTE "eager_input_vals" + # We save the original (args, kwargs) FakeTensor values for nodes + # that have exact stride requirements. This is useful downstream. + # We use this information inside Inductor to ensure that inputs to + # stride-sensitive operators have the correct strides. + arg_inp, kwarg_inp = torch.fx.node.map_aggregate((args, kwargs), map_fn) # type: ignore[misc, arg-type] + node.meta["eager_input_vals"] = (arg_inp, kwarg_inp) + + return node + + +def _should_save_eager_input_vals( + target: Any, + args_kwargs: Optional[tuple[tuple[Argument, ...], dict[str, Argument]]] = None, +) -> bool: + from torch._higher_order_ops.invoke_subgraph import InvokeSubgraphHOP + + if not callable(target): + return False + if isinstance( + target, + ( + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperFunctional, + torch._higher_order_ops.triton_kernel_wrap.TritonKernelWrapperMutation, + InvokeSubgraphHOP, + ), + ): + return True + if args_kwargs is not None and ( + target is torch.ops.higher_order.auto_functionalized + or target is torch.ops.higher_order.auto_functionalized_v2 + ): + args = args_kwargs[0] + assert isinstance( + args[0], (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ) + return _should_save_eager_input_vals(args[0], None) + if target is torch.ops.higher_order.with_effects: + # TODO: inductor lowering for with_effects needs to be updated to propagate + # the arg_kwarg_vals + return False + if isinstance(target, torch._ops.HigherOrderOperator): + if pytree.tree_any(_should_save_eager_input_vals, args_kwargs): + raise RuntimeError( + f"NYI: The HOP {target} has an input that is an OpOverload that " + f"needs exact strides. We probably need special logic to " + f"propagate the FakeTensor vals. Please file an issue." + ) + if isinstance(target, torch._ops.OpOverload): + from torch._library.utils import get_layout_constraint_tag + + return get_layout_constraint_tag(target) == torch._C.Tag.needs_exact_strides + return False + + +def _make_temp_remove_mode_context_manager( + mode_ty: type[TorchFunctionMode], +) -> Callable[[], _GeneratorContextManager[Optional[TorchFunctionMode]]]: + @contextmanager + def context_manager_fn() -> Generator[Optional[TorchFunctionMode], None, None]: + from torch.overrides import _len_torch_function_stack, _pop_mode, _push_mode + + temp_elements = [] + removed_mode = None + + while _len_torch_function_stack() > 0: + mode = _pop_mode() + if isinstance(mode, mode_ty): + removed_mode = mode + break + else: + temp_elements.append(mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + try: + yield removed_mode + + finally: + if removed_mode is not None: + count = len(temp_elements) + while count > 0: + mode = _pop_mode() + count -= 1 + + temp_elements.append(removed_mode) + + for mode in reversed(temp_elements): + _push_mode(mode) + + return context_manager_fn + + +@torch._disable_dynamo +def dispatch_trace( + root: Union[Module, Callable], + tracer: Tracer, + concrete_args: Optional[tuple[Any, ...]] = None, +) -> GraphModule: + graph = tracer.trace(root, concrete_args) # type: ignore[arg-type] + + # NB: be careful not to DCE .item() calls + def impure_pred(n: fx.Node) -> bool: + from .symbolic_shapes import is_accessor_node + + # Always defer to the built-in notion of impure + if n.is_impure(): + return True + + # Accessors always OK to DCE + if is_accessor_node(n): + return False + + # If the operator in question takes SymInt args to SymInt output, + # we assume it's pure and OK to DCE + if ( + isinstance(n.meta.get("val"), py_sym_types) + and + # NB: constant args ok + all( + isinstance(a.meta.get("val"), py_sym_types) + for a in n.args + if isinstance(a, fx.Node) + ) + ): + return False + + # No idea, just assume it's not OK + return True + + graph.eliminate_dead_code(impure_pred) + from torch._inductor.fx_passes.dedupe_symint_uses import dedupe_symints + + dedupe_symints(graph) + name = root.__class__.__name__ if isinstance(root, Module) else root.__name__ + return fx._lazy_graph_module._make_graph_module(tracer.root, graph, name) + + +def wrap_key( + f: Callable[[Unpack[_Ts]], R], + tensors: tuple[Unpack[_Ts]], + tracer: _ProxyTracer, + pre_dispatch: bool, +) -> Callable[_P, R]: + flat_tensors, _tensors_spec = pytree.tree_flatten(tensors) + + @functools.wraps(f) + def wrapped(*proxies: _P.args, **_unused: _P.kwargs) -> R: + flat_proxies, _proxies_spec = pytree.tree_flatten(proxies) + assert len(flat_proxies) == len(flat_tensors) + with disable_proxy_modes_tracing() as m: + assert isinstance(m, ProxyTorchDispatchMode) + track_tensor_tree(flat_tensors, flat_proxies, constant=None, tracer=tracer) + + def get_tensor_proxy_slot(t: Tensor) -> Union[Tensor, Proxy]: + return get_proxy_slot(t, tracer, t, lambda x: x.proxy) # type: ignore[attr-defined] + + out = f(*tensors) # type:ignore[call-arg] + out = pytree.tree_map_only(Tensor, get_tensor_proxy_slot, out) + out = pytree.tree_map_only( + _AnyScriptObject, lambda t: get_proxy_slot(t, tracer, t, lambda x: x), out + ) + + def get_sym_proxy_slot(t: PySymType) -> Proxy: + return get_proxy_slot(t, tracer).force() + + out = pytree.tree_map_only(py_sym_types, get_sym_proxy_slot, out) + return out + + return wrapped + + +# TODO: Make downstream users of this work with OperatorBase +ORIGINAL_ATEN: Optional[object] = None + + +@contextmanager +def set_original_aten_op(func: OpOverload) -> Generator[None, None, None]: + global ORIGINAL_ATEN + if ORIGINAL_ATEN is None and fx_traceback.has_preserved_node_meta(): + ORIGINAL_ATEN = func + fx_traceback.current_meta["original_aten"] = func + try: + yield + finally: + ORIGINAL_ATEN = None + fx_traceback.current_meta["original_aten"] = None + else: + yield + + +class TorchFunctionMetadataMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + + def __torch_function__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + self.tracer.torch_fn_metadata = func + self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1 + return func(*args, **kwargs) + + +_temp_remove_metadata_torch_function_mode = _make_temp_remove_mode_context_manager( + TorchFunctionMetadataMode +) + + +# This mode is **only** used for pre_dispatch tracing. +# In particular, we need to make sure that autograd/autocast API's +# that do not desugar into dispatcher operators stay in the graph. +class PreDispatchTorchFunctionMode(TorchFunctionMode): + def __init__(self, tracer: _ProxyTracer) -> None: + self.tracer = tracer + # The input to torch.amp.autocast_mode._exit_autocast graph node should be the + # enter_autocast node. So we have to save the enter autocast node here, and assign it + # to the exit_autocast call_function node. + self.enter_autocast_nodes: list[torch.fx.Node] = [] + + def __torch_function__( + self, + func: Union[OpOverload, Callable], + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + kwargs = kwargs or {} + if func in _side_effectful_need_to_be_preserved_pre_dispatch: + # It's for passing the export verifier which needs to verify the meta['val'] + # TODO(tmanlaibaatar): we should systematically couple it with expoert verifier, + # instead of hardcoding it here. + # T203648563 + if func == torch.amp.autocast_mode._exit_autocast: + enter_node = self.enter_autocast_nodes.pop() + args = (enter_node,) + node = self.tracer.create_node("call_function", func, args, {}) # type: ignore[arg-type] + if func == torch.amp.autocast_mode._enter_autocast: + self.enter_autocast_nodes.append(node) + if func in [ + torch._C._set_grad_enabled, + torch.amp.autocast_mode._enter_autocast, + torch.amp.autocast_mode._exit_autocast, + ]: + node.meta["val"] = None + return node + # Don't actually run the function! We just want to trace the calls + # into a graph. We don't actualy want to change global autograd state. + return func(*args, **kwargs) + + +_temp_remove_pre_dispatch_torch_function_mode = _make_temp_remove_mode_context_manager( + PreDispatchTorchFunctionMode +) + + +class ProxyTorchDispatchMode(TorchDispatchMode): + # Ensure this is read-only; this exists only for legacy reasons + @property + def enable_tracing(self) -> bool: + return True + + def __init__( + self, + tracer: _ProxyTracer, + tracing_mode: str, + pre_dispatch: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, + ) -> None: + dk = torch._C.DispatchKey.PreDispatch if pre_dispatch else None + super().__init__(dk) + self.tracer = tracer + self.tracing_mode = tracing_mode + self.pre_dispatch = pre_dispatch + self._allow_fake_constant = _allow_fake_constant + self._error_on_data_dependent_ops = _error_on_data_dependent_ops + # Indicates to our torch_dispatch dispatching infra that + # this is an "infra" mode with lower dispatching precedence. + self._mode_key = torch._C._TorchDispatchModeKey.PROXY + # Every time we enter a mode, we maintain a stack telling us what the previous + # ProxyTorchDispatchMode state was (if there was any). + # This lets us properly reset the state on exit. + self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = [] + self.decomp_layers: int = 0 + from torch._inductor import config + + self.emulate_precision_casts: bool = config.emulate_precision_casts + + @count + def __torch_dispatch__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...] = (), + kwargs: Optional[dict[str, object]] = None, + ) -> object: + with set_original_aten_op(func): + kwargs = kwargs or {} + + if func in (prim.device.default,): + return func(*args, **kwargs) + + return proxy_call(self, func, self.pre_dispatch, args, kwargs) + + def __enter__(self) -> Self: + # Stash and store the previous proxy mode (there may or may not be one) + maybe_prev_proxy_mode = _unset_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + self.enter_stack.append(maybe_prev_proxy_mode) + return super().__enter__() + + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[types.TracebackType], + ) -> Optional[bool]: + b = super().__exit__(exc_type, exc_value, traceback) + + # Re-enable the previous proxy mode, if there was one. + mb_previous_proxy_mode = self.enter_stack.pop() + if mb_previous_proxy_mode is not None: + _push_mode(mb_previous_proxy_mode) + + return b + + @classmethod + def is_infra_mode(cls) -> bool: + return True + + def _compute_proxy( + self, func: OpOverload, args: tuple[object, ...], out: PySymType + ) -> Proxy: + # Handle torch.sym_sum + n_args: tuple[object, ...] + if len(args) == 1 and isinstance(args[0], (list, tuple)): + n_args = ( + tuple( + ( + get_proxy_slot(a, self.tracer).force().node + if isinstance(a, py_sym_types) + else a + ) + for a in args[0] + ), + ) + else: + n_args = tuple( + ( + get_proxy_slot(a, self.tracer).force().node + if isinstance(a, py_sym_types) + else a + ) + for a in args + ) + + # func doesn't have a __torch_function__ that Proxy can interpose, so + # we gotta do it manually + n_out = self.tracer.create_node("call_function", func, n_args, {}) # type: ignore[arg-type] + p_out = fx.Proxy(n_out, self.tracer) + set_meta(p_out, out) + return p_out + + def __sym_dispatch__( + self, + func: OpOverload, + types: tuple[torch._C._TensorMeta, ...], + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + # Peephole optimize multiply by one + # NB: be careful not to trigger guards here! + if func == operator.mul: + if isinstance(args[1], int) and args[1] == 1: + return args[0] + elif isinstance(args[0], int) and args[0] == 1: + return args[1] + + # For speed, we assume there are no nested data structures + # (otherwise we could use tree_map) + # We also assume there are no keyword arguments. + assert not kwargs + out = func(*args, **kwargs) + + # If func returned a constant, we don't need to trace; we have + # determined that the result is constant (no matter if the inputs + # were symbolic) and it is no longer necessary to trace the + # computation. This could occur if func triggered some guards. + if isinstance(out, py_sym_types): + p_out_thunk = thunkify( + self.tracer, self._compute_proxy, func=func, args=args, out=out + ) + set_proxy_slot(out, self.tracer, p_out_thunk) + + return out + + +class _GraphAppendingTracerEx(fx.proxy.GraphAppendingTracer): + script_object_tracker: MutableMapping[_AnyScriptObjectType, Proxy] + symnode_tracker: MutableMapping[PySymType, _PySymProxyType] + tensor_tracker: MutableMapping[Tensor, _ProxyTensor] + sympy_expr_tracker: dict[sympy.Symbol, object] + torch_fn_metadata: Optional[OpOverload] + torch_fn_counts: dict[OpOverload, int] + enable_thunkify: bool = False + + def __init__(self, graph: fx.graph.Graph) -> None: + super().__init__(graph) + self.symnode_tracker = weakref.WeakKeyDictionary() + self.tensor_tracker = WeakTensorKeyDictionary() + self.sympy_expr_tracker = {} + self.script_object_tracker = WeakIdKeyDictionary( + dict=None, ref_type=_WeakHashRef + ) + # Stores the torch function that was called during tracing + self.torch_fn_metadata = None + # Stores the counts for every torch function called. This is to help + # distinguish between different calls to the same torch function. + self.torch_fn_counts = {} + + +# TODO: I'm not sure what the point of this class is; you can just +# make_fx through a regular Interpreter +class DecompositionInterpreter(fx.Interpreter): + def __init__( + self, + module: fx.GraphModule, + new_graph: fx.Graph, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + **kwargs: object, + ) -> None: + super().__init__(module, **kwargs) # type: ignore[arg-type] + self.new_graph = new_graph + self.tracer = _GraphAppendingTracerEx(self.new_graph) + # Blegh + self.decomposition_table = decomposition_table or {} + self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") + + def placeholder( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + # TODO handle case where the first character of target is '*' + return out + + def get_attr( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] + proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) + track_tensor_tree(out, proxy, constant=None, tracer=self.tracer) + return out + + # call_function, call_method, call_module get traced automatically by the outer mode. + + def output( + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> object: + out = super().output(target, args, kwargs) # type: ignore[arg-type] + + def get_proxy_node(x: _ProxyTensor) -> fx.node.Node: + return x.proxy.node + + def unwrap(e: Tensor) -> Union[Tensor, fx.Node]: + return get_proxy_slot(e, self.tracer, e, get_proxy_node) + + self.new_graph.output(pytree.tree_map(unwrap, out)) + return out + + def run(self, *args: object, **kwargs: object) -> object: + # Should enter the mode at least once for being able to restore it later + # See: https://github.com/pytorch/pytorch/pull/82549#discussion_r934782025 + with decompose(self.decomposition_table), self.mode: + return super().run(*args, **kwargs) # type: ignore[arg-type] + + +def wrapper_and_args_for_make_fx( + func: Callable[..., R], args: tuple[object, ...], kwargs: dict[str, object] +) -> tuple[Callable[[list[object]], R], list[object]]: + # make_fx doesn't support kwargs, so we need to do this flattening + # and then unflatten the args before calling func + flat_args, spec = pytree.tree_flatten((args, kwargs)) + + def wrapped(flat_args: list[object]) -> R: + fn_args, fn_kwargs = pytree.tree_unflatten(flat_args, spec) + return func(*fn_args, **fn_kwargs) + + return wrapped, flat_args + + +@contextmanager +def disable_autocast_cache() -> Generator[None, None, None]: + old_value = torch.is_autocast_cache_enabled() + torch.set_autocast_cache_enabled(False) + try: + yield + finally: + torch.set_autocast_cache_enabled(old_value) + + +class _ModuleNotInstalledAsSubmoduleError(NameError): + pass + + +# Base class for inline _ModuleStackTracer.__init__.AttrProxy +class _AttrProxy: + def reset_proxy_mapping(self, base: Module, path: str) -> None: + pass + + +class _ModuleStackTracer(PythonKeyTracer): + r"""Customized version of PythonKeyTracer that retains module stack + information in node.meta["nn_module_stack"]. + + FX symbolic trace actually does this already, but it relies on `self.root` + being the actual module being traced. Since make_fx traces a lambda of our + creation, things don't work properly. + + So for this version we hold onto a reference to the original module + (scope_root) and use that to match the path. Also when we see, + A + / \ + B C + \ / + D + we want to record the path as A.B.D by recording only one path. + See Note [Preserving the nn module stack metadata during export non-strict mode] # noqa: W605 + """ + + def __init__(self, scope_root: GraphModule) -> None: + super().__init__() + self.stack_trace = True + self.scope_root = scope_root + self.enable_attr_proxy = False + self.submodule_paths = {} + for name, m in self.scope_root.named_modules(remove_duplicate=False): + if m in self.submodule_paths: + log.info( + "Shared module found between %s and %s, AttrProxy is enabled.", + self.submodule_paths[m], + name, + ) + self.enable_attr_proxy = True + else: + self.submodule_paths[m] = name + + self.proxy_paths: WeakKeyDictionary[_AttrProxy, str] = WeakKeyDictionary() + self.attr_proxy_map: WeakKeyDictionary[Module, _AttrProxy] = WeakKeyDictionary() + self.proxy_modules: WeakKeyDictionary[_AttrProxy, Module] = WeakKeyDictionary() + self.counter = 0 + + self.module_id_cache = defaultdict(list) + for name, mod in self.scope_root.named_modules(remove_duplicate=False): + self.module_id_cache[id(mod)].append(name) + + # Build a wrapper around _AttrProxy to provide the tracer. We can't + # store it on _AttrProxy itself beceause we mimic the underlying class + # (including its attributes). + tracer = self + + class AttrProxy(_AttrProxy): + def __init__(self, base: Union[Module, _AttrProxy], path: str) -> None: + if isinstance(base, _AttrProxy): + base = base.get_base() # type: ignore[attr-defined] + + assert isinstance(base, Module) + # Class is modified to be a subclass of torch.nn.Module + # Warning: We blow away our own attributes here to mimic the base class + # - so don't expect `self.x` to do anything useful. + self.__class__ = type( + base.__class__.__name__, + (self.__class__, base.__class__), + {}, + ) + self.__dict__ = base.__dict__ + self.__class__.__module__ = base.__class__.__module__ + self.__class__.__qualname__ = base.__class__.__qualname__ + + # This overwrites any existing paths if `base` is an AttrProxy + tracer.proxy_paths[self] = path + tracer.proxy_modules[self] = base + + def __getattr__(self, name: str) -> AttrProxy: + assert isinstance(self, Module) + # Calling into torch.nn.Module.__getattr__ with super(), + # That __getattr__ is patched to be module_getattr_wrapper in _symbolic_trace.py. + # which then calls into _ModuleStackTracer.getattr + attr_val = super().__getattr__(name) # type: ignore[misc] + if not isinstance(attr_val, Module): + return attr_val + + return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name) + + def get_base(self) -> Module: + return tracer.proxy_modules[self] + + def __getitem__(self, idx: Union[int, slice]) -> AttrProxy: + if isinstance(idx, slice): + if isinstance(self, torch.nn.Sequential): + # Copied from nn/modules/container.py + res = torch.nn.Sequential( + OrderedDict(list(self._modules.items())[idx]) + ) + return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") + elif isinstance(self, torch.nn.ModuleList): + # Copied from nn/modules/container.py + res = torch.nn.ModuleList(list(self._modules.values())[idx]) + return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}") + + return super().__getitem__(idx) # type: ignore[misc] + + @property + def _modules(self) -> dict[str, AttrProxy]: + assert "_modules" in self.__dict__ + submodules = self.__dict__["_modules"] + assert isinstance(submodules, dict) + return { + key: ( + AttrProxy(value, tracer.proxy_paths[self] + "." + str(key)) # type: ignore[misc] + if value is not None + else value + ) + for key, value in submodules.items() + } + + self.proxy_type = AttrProxy + + def path_of_module(self, mod: Module) -> str: + """ + Use tracked access path during tracing instead of the default BFS behavior. + Still use all the possible module paths to verify the result. + """ + if mod is self.scope_root: + return "" + + if isinstance(mod, _AttrProxy): + return self.proxy_paths[mod] + + try: + return Tracer.path_of_module(self, mod) + except NameError as e: + raise _ModuleNotInstalledAsSubmoduleError from e + + def getattr( + self, attr: str, attr_val: object, parameter_proxy_cache: dict[str, Proxy] + ) -> object: + if ( + not isinstance(attr_val, Module) + or isinstance(attr_val, fx.GraphModule) + or not self.enable_attr_proxy + ): + return super().getattr(attr, attr_val, parameter_proxy_cache) + if isinstance(attr_val, _AttrProxy): + return attr_val + + # See NOTE [caching AttrProxy]. + if attr_val not in self.attr_proxy_map: + self.attr_proxy_map[attr_val] = self.proxy_type(attr_val, attr) + else: + self.attr_proxy_map[attr_val].reset_proxy_mapping(attr_val, attr) + return self.attr_proxy_map[attr_val] + + def trace( # type: ignore[override] + self, root: Union[Module, Callable], concrete_args: Optional[dict[str, object]] + ) -> fx.Graph: + res = super().trace(root, concrete_args) + + # Since we are making _AttrProxy mimic the original + # submodule, when someone registers a module directly + # to the tracer while tracing, the proxy object gets registered + # first. So we need to replace the proxy modules with the real ones + # This can happen during HOO tracing + proxy_module_names_to_be_replaced: list[tuple[str, _AttrProxy]] = [] + for name, module in self.root.named_modules(): + if module in self.proxy_modules: + proxy_module_names_to_be_replaced.append((name, module)) + + def _delete_proxy_attr(obj: Module, target: str) -> bool: + # Copied from fx/graph_module.py + # Customized it for proxy type + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + assert isinstance(obj, Module) + mod = obj + + # Get the parent module + for item in path: + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, (_AttrProxy, Module)): + return False + + if not hasattr(mod, target_submod): + return False + + # At least the leaf module should be proxy type. + if not isinstance(getattr(mod, target_submod), _AttrProxy): + return False + + delattr(mod, target_submod) + return True + + for proxy_module_name, proxy_module in proxy_module_names_to_be_replaced: + _delete_proxy_attr(self.root, proxy_module_name) + actual_module = self.proxy_modules[proxy_module] + _assign_attr(actual_module, self.root, proxy_module_name) + + return res + + def call_module( + self, + m: Module, + forward: Callable, + args: tuple[object, ...], + kwargs: dict[str, object], + ) -> None: + """PythonKeyTracer overrides call_module to avoid the scope handling, + but we actually want it. + """ + from torch._dynamo import OptimizedModule + + # FIXME (tmanlaibaatar) + # When we call torch.compile inside HOO, we will end up + # invoking a module that is not registered on the root. For + # now, we just inline them. But once we start supporting + # mark_strict in export, we do need to properly handle this. + # Right now, it doesn't matter because current non-strict + # use cases don't need to work with HOO. + if isinstance(m, (OptimizedModule, GraphModule)): + return forward(*args, **kwargs) + + try: + return Tracer.call_module(self, m, forward, args, kwargs) + except _ModuleNotInstalledAsSubmoduleError: + log.debug( + "Unable to find the path of the module %s. " + "This might be because the module was not properly registered " + "as a submodule, which is not good practice. We will trace " + "through the module without recording stack information.", + str(m), + ) + return forward(*args, **kwargs) + + def is_leaf_module(self, m: Module, module_qualified_name: str) -> bool: + return False + + def create_node(self, *args: object, **kwargs: object) -> fx.node.Node: + """ + Create node and add on metadata. + Add nn_module_stack here instead of TracerBase, + since calls to make_fx() might not want to record module stack metadata. + Add torch_fn by looking at torch_fn_metadata and torch_fn_counts. + Add stack_trace by filtering out forward() stack frames. + """ + node = super().create_node(*args, **kwargs) # type: ignore[arg-type] + + # nn_module_stack + if node.op not in ["placeholder", "output"]: + if "nn_module_stack" not in node.meta: + node.meta["nn_module_stack"] = self.module_stack + # convert nn_module_stack from Dict[key, (FQN, class)] -> Dict[str, Tuple[str, str]] + for key, (fqn, mod_cls) in node.meta["nn_module_stack"].items(): + if isinstance(mod_cls, type): + node.meta["nn_module_stack"][key] = ( + fqn, + mod_cls.__module__ + "." + mod_cls.__qualname__, + ) + + # torch_fn + if ( + node.op == "call_function" + and self.torch_fn_metadata is not None + and "torch_fn" not in node.meta + ): + node.meta["torch_fn"] = ( + f"{self.torch_fn_metadata.__name__}_{self.torch_fn_counts[self.torch_fn_metadata]}", + f"{self.torch_fn_metadata.__class__.__name__}.{self.torch_fn_metadata.__name__}", + ) + + return node + + +class _MakefxTracer: + def __init__( + self, + decomposition_table: Optional[Mapping[OpOverload, Callable]], + tracing_mode: str, + _allow_non_fake_inputs: bool, + pre_dispatch: bool, + record_module_stack: bool, + _allow_fake_constant: bool, + _error_on_data_dependent_ops: bool, + stack_trace: bool = False, + ) -> None: + # Configurations that are used to initialize the context managers and their states. + # Should not modify them during tracing. + self.decomposition_table: dict[OpOverload, Callable] = dict( + decomposition_table or {} + ) + self.decomposition_table.setdefault( + torch.ops.aten.sym_numel.default, torch._decomp.decompositions.sym_numel + ) + self.tracing_mode: str = tracing_mode + self._allow_non_fake_inputs: bool = _allow_non_fake_inputs + self.pre_dispatch: bool = pre_dispatch + self.record_module_stack: bool = record_module_stack + self._allow_fake_constant: bool = _allow_fake_constant + self._error_on_data_dependent_ops: bool = _error_on_data_dependent_ops + + # All context managers and their states should be initialized before tracing based on the inputs + # and configurations. After tracing, their states should be cleaned except for shape_env. + # Remember to specify how to initialize it from user inputs and from parent tracer whenever + # adding new modes in _MakefxTracer. + self.fake_tensor_mode: Optional[FakeTensorMode] = None + self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext() + self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = ( + nullcontext() + ) + self.fx_tracer: Optional[PythonKeyTracer] = None + self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext() + self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = ( + nullcontext() + ) + self.stack_trace = stack_trace + + def _checkpoint_modes(self) -> list[Any]: + return [ + self.fake_tensor_mode, + self.proxy_mode, + self.proxy_function_mode, + self.fx_tracer, + self.python_dispatcher_mode, + self.torch_fn_metadata_mode, + ] + + def _restore_modes( + self, + prev_fake_tensor_mode: Optional[FakeTensorMode], + prev_proxy_mode: Union[nullcontext, ProxyTorchDispatchMode], + prev_proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode], + prev_fx_tracer: Optional[PythonKeyTracer], + prev_python_dispatcher_mode: Union[nullcontext, Any], + prev_torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode], + ) -> None: + self.fake_tensor_mode = prev_fake_tensor_mode + self.proxy_mode = prev_proxy_mode + self.proxy_function_mode = prev_proxy_function_mode + self.fx_tracer = prev_fx_tracer + self.python_dispatcher_mode = prev_python_dispatcher_mode + self.torch_fn_metadata_mode = prev_torch_fn_metadata_mode + + @contextmanager + def _init_modes_from_inputs( + self, f: Callable, args: tuple[object, ...] + ) -> Generator[None, None, None]: + prev_modes = self._checkpoint_modes() + try: + # Avoid importing sympy at a module level + from .symbolic_shapes import ShapeEnv + + if hasattr(f, "_orig_mod") and self.record_module_stack: + scope_root = f._orig_mod + # _ModuleStackTracer always try to preserve stack trace + self.fx_tracer = _ModuleStackTracer(scope_root) + else: + self.fx_tracer = PythonKeyTracer() + self.fx_tracer.stack_trace = self.stack_trace + + if self.tracing_mode == "fake": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=True, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=ShapeEnv(), + static_shapes=True, + ) + self.fake_tensor_mode = fake_tensor_mode + elif self.tracing_mode == "symbolic": + import torch._dynamo + + fake_tensor_mode = torch._dynamo.utils.detect_fake_mode(args) + if fake_tensor_mode is None: + shape_env = ShapeEnv() + import torch._functorch.config as _config + + with _config.patch(fake_tensor_allow_unsafe_data_ptr_access=False): + fake_tensor_mode = FakeTensorMode( + allow_fallback_kernels=False, + allow_non_fake_inputs=self._allow_non_fake_inputs, + shape_env=shape_env, + ) + assert fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) + self.fake_tensor_mode = fake_tensor_mode + else: + if not self.tracing_mode == "real": + raise AssertionError( + f"Unexpected tracing type: {self.tracing_mode}" + ) + + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _construct_modes_with_fx_tracer(self, fx_tracer: _ProxyTracer) -> None: + self.proxy_mode = ProxyTorchDispatchMode( + fx_tracer, + self.tracing_mode, + pre_dispatch=self.pre_dispatch, + _allow_fake_constant=self._allow_fake_constant, + _error_on_data_dependent_ops=self._error_on_data_dependent_ops, + ) + + if self.pre_dispatch: + self.proxy_function_mode = PreDispatchTorchFunctionMode(fx_tracer) + + # pre-autograd tracing uses per-dispatch-key modes, + # which requires the python dispatcher + if self.tracing_mode == "symbolic" or self.pre_dispatch: + self.python_dispatcher_mode = enable_python_dispatcher() + + self.torch_fn_metadata_mode = TorchFunctionMetadataMode(fx_tracer) + + @contextmanager + def _init_modes_from_parent( + self, parent_tracer: _MakefxTracer + ) -> Generator[None, None, None]: + # By default, subtracer creates new modes based on parent tracer's config. + # However, there are cases where we want to share the same modes with parent tracer + # For example, fake_tensor_mode, we want the example value's fake_mode of parent graph and subgraphs to be the same. + prev_modes = self._checkpoint_modes() + try: + self.fake_tensor_mode = parent_tracer.fake_tensor_mode + + def _create_sub_fx_tracer(parent_tracer: _ProxyTracer) -> PythonKeyTracer: + if type(parent_tracer) == PythonKeyTracer: + return PythonKeyTracer() + elif type(parent_tracer) == _ModuleStackTracer: + return _ModuleStackTracer(parent_tracer.scope_root) + else: + raise RuntimeError( + f"Unexpected tracer type: {type(parent_tracer)}." + ) + + assert parent_tracer.fx_tracer is not None + self.fx_tracer = _create_sub_fx_tracer(parent_tracer.fx_tracer) + self._construct_modes_with_fx_tracer(self.fx_tracer) + yield + finally: + self._restore_modes(*prev_modes) + + def _trace_inner(self, f: Callable, *args: object) -> GraphModule: + # TODO: We need to explicitly import torch._dynamo before calling dispatch_trace, + # because dispatch_trace will introduce the lazy import of torch._dynamo, + # and some contexts set before calling dispatch_trace will cause problems with the import of torch._dynamo, + # such as some torch API(torch.ones and so on) in populate_builtin_to_tensor_fn_map() will be affected + # by the context set before dispatch_trace. + import torch._dynamo + + phs = pytree.tree_map(lambda _: torch.fx._symbolic_trace.PH, args) + + def _wrap_fake(args: T) -> T: + arg_count = 0 + + def inner_wrap_fake(x: object) -> object: + nonlocal arg_count + # TODO: it would be nice to line these up with the names + # FX will choose for the placeholders, but we don't + # actually know what the names will be at this point yet + # NB: the Source here is actually meaningless + from torch._dynamo.source import ConstantSource + + assert self.fake_tensor_mode is not None + source = ConstantSource(f"input{arg_count}") + if isinstance(x, Tensor): + arg_count += 1 + return self.fake_tensor_mode.from_tensor(x, source=source) + # NB: don't match on bools + elif type(x) is int and self.tracing_mode == "symbolic": + assert self.fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) + return self.fake_tensor_mode.shape_env.create_symintnode( + self.fake_tensor_mode.shape_env.create_symbol( + x, source, positive=None + ), + hint=x, + source=source, + ) + elif isinstance(x, torch.ScriptObject): + return torch._library.fake_class_registry.maybe_to_fake_obj( + self.fake_tensor_mode, x + ) + + assert not isinstance(x, FakeScriptObject), ( + f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + ) + return x + + wrap_fn_map = { + "real": lambda x: x, + "fake": inner_wrap_fake, + "symbolic": inner_wrap_fake, + } + return pytree.tree_map(wrap_fn_map[self.tracing_mode], args) + + def _wrap_func(f: Callable[_P, R], phs: Sequence[PHBase]) -> Callable[_P, R]: + if ( + not hasattr(inspect.unwrap(f), "__code__") + or inspect.unwrap(f).__code__.co_flags & inspect.CO_VARARGS + ): + # FX doesn't support varargs, so we gotta fake up a wrapper + # TODO: Would be nice to fix this at the source... + return fake_signature(f, len(phs)) + return f + + args = _wrap_fake(args) + func = _wrap_func(f, phs) + # We disable the autocast cache as the autocast cache causes type conversions on parameters to + # check a cache, which introduces untracked tensors into the graph + # + # We also disable tracing by any other tensor proxy-based tracers except the current. The + # purpose of `make_fx` is to produce graphmodules as a side effect; its internal execution is + # thus irrelevant to any external functional trace. + proxy_mode: ProxyTorchDispatchMode = typing.cast( + ProxyTorchDispatchMode, self.proxy_mode + ) + with ExitStack() as stack: + stack.enter_context(decompose(self.decomposition_table)) + if self.fake_tensor_mode: + stack.enter_context(self.fake_tensor_mode) + stack.enter_context(self.python_dispatcher_mode) + stack.enter_context(self.proxy_function_mode) + stack.enter_context(self.torch_fn_metadata_mode) + stack.enter_context(proxy_mode) + stack.enter_context(disable_autocast_cache()) + stack.enter_context(_set_make_fx_tracer(self)) + + assert self.fx_tracer is not None + try: + t = dispatch_trace( + wrap_key(func, args, self.fx_tracer, self.pre_dispatch), + tracer=self.fx_tracer, + concrete_args=tuple(phs), + ) + except Exception: + trace_structured( + "artifact", + metadata_fn=lambda: { + "name": "make_fx_fail_partial", + "encoding": "string", + }, + payload_fn=lambda: self.fx_tracer.graph.python_code( # type: ignore[union-attr] + root_module="self", + verbose=True, + include_stride=True, + include_device=True, + ).src, + ) + raise + + # TODO: kind of a bad way to do it, should maybe figure out a better way + if self.tracing_mode == "symbolic": + assert self.fake_tensor_mode is not None + t.shape_env = self.fake_tensor_mode.shape_env # type: ignore[assignment] + return t + + def trace(self, f: Callable, *args: object) -> fx.GraphModule: + with self._init_modes_from_inputs(f, args): + return self._trace_inner(f, *args) + + def trace_subgraph(self, f: Callable, *args: object) -> GraphModule: + # Create a new tracer based on parent's config + sub_tracer = _MakefxTracer( + self.decomposition_table, + "real", + self._allow_non_fake_inputs, + self.pre_dispatch, + self.record_module_stack, + self._allow_fake_constant, + self._error_on_data_dependent_ops, + ) + with sub_tracer._init_modes_from_parent(self): + return sub_tracer._trace_inner(f, *args) + + +_CURRENT_MAKE_FX_TRACER: Optional[_MakefxTracer] = None + + +@contextmanager +def _set_make_fx_tracer(tracer: _MakefxTracer) -> Generator[None, None, None]: + global _CURRENT_MAKE_FX_TRACER + prev_tracer = _CURRENT_MAKE_FX_TRACER + try: + _CURRENT_MAKE_FX_TRACER = tracer + yield + finally: + _CURRENT_MAKE_FX_TRACER = prev_tracer + + +def make_fx( + f: Callable, + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, + tracing_mode: str = "real", + _allow_non_fake_inputs: bool = False, + *, + pre_dispatch: bool = False, + record_module_stack: bool = False, + _allow_fake_constant: bool = False, + _error_on_data_dependent_ops: bool = True, + stack_trace: bool = False, +) -> Callable[..., GraphModule]: + """ + Given a function f, return a new function which when executed with valid + arguments to f, returns an FX GraphModule representing the set of operations that + were executed during the course of execution. + + If stack_trace is True, the stack_trace will be preserved on node.meta["stack_trace"] + """ + + assert tracing_mode in ["real", "fake", "symbolic"] + + from torch._inductor import config + + make_fx_tracer = _MakefxTracer( + decomposition_table, + tracing_mode, + _allow_non_fake_inputs, + pre_dispatch, + record_module_stack, + _allow_fake_constant, + _error_on_data_dependent_ops, + stack_trace=stack_trace or config.trace.enabled, + ) + + @functools.wraps(f) + def wrapped(*args: object) -> GraphModule: + return make_fx_tracer.trace(f, *args) + + return wrapped + + +def get_torch_dispatch_modes() -> list[TorchDispatchMode]: + return torch.utils._python_dispatch._get_current_dispatch_mode_stack() + + +# TODO: this is a legacy name, there is only ever one proxy mode as it's an +# infra mode +def get_innermost_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + return get_proxy_mode() + + +def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]: + """ + Current the currently active proxy tracing mode, or None if + we are not currently tracing. This includes pre-dispatch proxy + tracing. + """ + pre_dispatch_mode = torch._ops._get_dispatch_mode_pre_dispatch( + torch._C._TorchDispatchModeKey.PROXY + ) + mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) + assert pre_dispatch_mode is None or mode is None, ( + f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + ) + return pre_dispatch_mode or mode + + +def handle_sym_dispatch( + func: Callable[_P, R], + args: _P.args, # type: ignore[valid-type] # not allowed to use _P.args here + kwargs: _P.kwargs, # type: ignore[valid-type] # not allowed to use _P.kwargs here +) -> R: + """ + Call into the currently active proxy tracing mode to do a + SymInt/SymFloat/SymBool dispatch trace on a function that operates on + these arguments. + """ + mode = get_proxy_mode() + assert mode + # Have to do it manually, because we're not doing the normal torch + # dispatch machinery which disables it for us + with disable_proxy_modes_tracing(): + # TODO: properly compute types + types: list[type] = [] + return mode.__sym_dispatch__(func, types, args, kwargs) # type: ignore[arg-type, return-value] + + +@contextmanager +def disable_proxy_modes_tracing() -> Generator[ProxyTorchDispatchMode, None, None]: + return _disable_infra_mode(torch._C._TorchDispatchModeKey.PROXY) + + +def maybe_handle_decomp( + proxy_mode: ProxyTorchDispatchMode, + op: OpOverload, + args: tuple[object, ...], + kwargs: dict[str, object], +) -> object: + from torch._inductor.compiler_bisector import CompilerBisector + + if op in CURRENT_DECOMPOSITION_TABLE: + if CompilerBisector.disable_subsystem( + "aot_eager_decomp_partition", "decomposition", lambda: repr(op) + ): + return NotImplemented + + with proxy_mode: + proxy_mode.decomp_layers += 1 + out = CURRENT_DECOMPOSITION_TABLE[op](*args, **kwargs) + proxy_mode.decomp_layers -= 1 + return out + + return NotImplemented + + +def get_isolated_graphmodule( + func: Callable, + args: tuple[object, ...], + kwargs: dict[str, object], + tracing_mode: str = "real", + decomposition_table: Optional[Mapping[OpOverload, Callable]] = None, +) -> GraphModule: + """A helper function used to get the GraphModule for the given func. + + It's expected to be used in the ProxyTensor tracing context. + It detaches the args and kwargs from the current tracer so that the trace of + the current graph module can be created without any side-effects. + """ + wrapped, all_args = wrapper_and_args_for_make_fx(func, args, kwargs) + + with disable_proxy_modes_tracing(): + gm = make_fx( + wrapped, decomposition_table=decomposition_table, tracing_mode=tracing_mode + )(all_args) + return gm + + +def _set_unbacked_bindings(out: object, out_proxy: _NestedProxys) -> None: + """A helper function for setting up unbacked_bindings on the destination FX graph.""" + from .symbolic_shapes import compute_unbacked_bindings + + # Can't use detect_fake_mode here, + # + # python test/distributed/_tensor/test_dtensor_compile.py -k + # test_tp_compile_fullgraph_is_seq_parallel_False + # + # will fail. Very strange, it probably isn't right for them to be using + # two fake modes there... + fake_mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FAKE) + if fake_mode and fake_mode.shape_env: + if symbol_to_path := compute_unbacked_bindings(fake_mode.shape_env, out): + assert isinstance(out_proxy, Proxy), out_proxy + out_proxy.node.meta["unbacked_bindings"] = symbol_to_path diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/recording.py b/phivenv/Lib/site-packages/torch/fx/experimental/recording.py new file mode 100644 index 0000000000000000000000000000000000000000..aa9ba4393186fd88603fae8085ee67c75b023142 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/recording.py @@ -0,0 +1,529 @@ +# mypy: allow-untyped-defs +import functools +import inspect +import itertools +import logging +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import torch +import torch.utils._pytree as pytree + + +log = logging.getLogger(__name__) +trace_shape_events_log = torch._logging.getArtifactLogger( + __name__, "trace_shape_events" +) + + +__all__ = [ + "ShapeEnvEvent", + "record_shapeenv_event", + "replay_shape_env_events", + "FakeTensorMeta", + "shape_env_check_state_equal", + "NotEqualError", +] + +# [Note: Recording ShapeEnv Events] +# ================================= +# +# What is a ShapeEnv event? +# ------------------------- +# We consider a ShapeEnv event every function call (ShapeEnv method or +# independent function) that modifies the state of the ShapeEnv instance. +# Such calls are recorded alongside their positional and keyword arguments, +# so that it may be replayed over a different ShapeEnv instance. +# +# See [Note: ShapeEnv State Equality] for what is considered the state +# of a ShapeEnv instance. +# +# What is it for? +# --------------- +# ShapeEnv events recording is used for reconstructing the ShapeEnv in an +# arbitrary state in time. +# +# Being able to arbitrarily replay events like so is useful, mainly for +# translation validation bisection. i.e. if a ValidationException has been +# raised, find the earliest point in time where the translation validation +# fails. +# +# Besides that, it also allows us to inspect the given instance and, +# for example, check the guards that would actually be issued at that point. +# +# What kind of arguments can be stored in an event? +# ------------------------------------------------- +# There's no specific rule for what cannot be used as an argument. +# That said, pay special attention to the following cases: +# +# 1. Tensor inputs: there are some tests that check whether the inputs +# were garbage collected after execution. These will fail if there's +# an event that is holding a reference to those inputs. +# +# 2. ShapeEnv arguments: if there is an argument of ShapeEnv type, that +# will be automatically replaced by the new given ShapeEnv instance. +# +# 3. SymTypes arguments: they also hold references to ShapeEnv. So, +# whenever we see them, we create a new instance, replacing the +# ShapeEnv reference. +# +# 4. FX nodes: specifically, FX nodes from the FX graph for symbolic +# shapes. That argument must be replaced when replaying the event at +# ShapeEnvEvent.run, since it has to reference a node from the given +# instance, and not from the recorded instance. + + +# Event class for reconstructing ShapeEnv at arbitrary time. +# +# Represents a method call that mutates ShapeEnv in a way that affects the +# issued guards, when ShapeEnv.produce_guards is called. +@dataclass +class ShapeEnvEvent: + # ShapeEnv method. + f: Callable + + # Arguments and keyword arguments called with. + args: Optional[list[Any]] = None + kwargs: Optional[dict[str, Any]] = None + + # List of tracked_fakes at the time the method was called. + tracked_fakes: Optional[list[Any]] = None + + # Name of the captured event. + # Used for special handling of particular methods. + name: Optional[str] = None + + # Replay itself, but using shape_env as self. + def run(self, shape_env=None) -> Any: + from torch.fx.experimental.symbolic_shapes import ( + is_symbolic, + ShapeEnv, + SymTypes, + ) + + # Special handling for the constructor event. + if self.f is ShapeEnv: + assert shape_env is None and self.args is None and self.kwargs is not None + return ShapeEnv(**self.kwargs) + + assert shape_env is not None + args = list(self.args or []) + kwargs = dict(self.kwargs or {}) + + # Replace any argument of type ShapeEnv by the given one. + args, kwargs = pytree.tree_map_only( + ShapeEnv, lambda _: shape_env, (args, kwargs) + ) + + # Replace any argument of type SymTypes by a new instance, + # replacing its ShapeEnv reference. + args, kwargs = pytree.tree_map_only( + lambda x: isinstance(x, SymTypes) and is_symbolic(x), + lambda a: type(a)(a.node.with_shape_env(shape_env)), + (args, kwargs), + ) + + # Converts FX nodes using the mapping argument. + def maybe_convert_node(x: Any) -> Any: + if not isinstance(x, torch.fx.Node): + # Don't do anything to x if it's not an FX node. + return x + + # If, at some point, we created an FX node, it means that translation validation is on. + # It also means we are building an FX graph for symbolic shapes at shape_env.graph, and + # we are tracking node names at shape_env.name_to_node. + assert hasattr(shape_env, "name_to_node") + name_to_node = shape_env.name_to_node # type: ignore[attr-defined] + assert x.name in name_to_node + return name_to_node[x.name] + + # Replaces the value of an specific argument by the result of fn. + def replacearg(index: int, key: str, fn: Callable): + if index < len(args): + args[index] = fn(args[index]) + if key in kwargs: + kwargs[key] = fn(kwargs[key]) + + if self.is_create_fx_call_function(): + # ShapeEnv.create_fx_call_function: + # "args" parameter is a tuple of FX nodes from the FX graph of the old ShapeEnv. + # They must be replaced, since a "call_function" FX node with this tuple as argument + # will be added to the FX graph of the new shape_env. + replacearg( + index=2, + key="args", + fn=lambda args: tuple(maybe_convert_node(a) for a in args), + ) + if self.is_evaluate_expr() or self.is_defer_runtime_assert(): + # ShapeEnv.evaluate_expr and ShapeEnv.guard_or_defer_runtime_assert: + # "fx_node" parameter is an (optional) FX node that represents the evaluate expression. + # They must be replaced, since it will be part of a "call_function" FX node for + # torch._assert, which will be added to the FX graph of the new shape_env. + replacearg(index=3, key="fx_node", fn=maybe_convert_node) + + # Actually call the method with the converted arguments. + return self.f(*args, **kwargs) + + def __str__(self) -> str: + name = self.name if self.name is not None else self.f.__name__ + return f"event: {name} ({self.args}, {self.kwargs})" + + def is_create_fx_call_function(self) -> bool: + return self.name == "_create_fx_call_function" + + def is_evaluate_expr(self) -> bool: + return self.name == "evaluate_expr" + + def is_defer_runtime_assert(self) -> bool: + return self.name == "guard_or_defer_runtime_assert" + + +NEST = 0 + + +# Extracts a ShapeEnv instance inside args and kwargs. +# Specifically, it looks for: +# 1. ShapeEnv arguments +# 2. SymInt, SymFloat, or SymBool arguments +# If we find more than one object of any of the above types, we +# also check that the ShapeEnv instance is the same for all of them. +def _extract_shape_env_and_assert_equal(args, kwargs): + from torch.fx.experimental.symbolic_shapes import is_symbolic, ShapeEnv, SymTypes + + def assert_equal(old: Optional[ShapeEnv], new: ShapeEnv) -> ShapeEnv: + if old is not None: + assert old is new, "call with different ShapeEnv" + return new + + shape_env = None + for val in itertools.chain(args, kwargs.values()): + if isinstance(val, ShapeEnv): + shape_env = assert_equal(shape_env, val) + if isinstance(val, SymTypes) and is_symbolic(val): + shape_env = assert_equal(shape_env, val.node.shape_env) + + return shape_env + + +# Decorator for recording the given function as a replayable event. +# +# This decorator should be used at every function that mutates the state of +# ShapeEnv in some way that affects the resulting issued guards (i.e. when +# ShapeEnv.produce_guards is called). +# +# save_tracked_fakes: saves a snapshot of the TrackedFake list. +# This is used when calling ShapeEnv.produce_guards at arbitrary points in time. +# +# name: the name of the function being recorded. Normally (and by default) this +# is taken from the decorated function but can be set if you need to override +# it. +# +# When to save the list of TrackedFake? +# ===================================== +# We should save the list of TrackedFake whenever the translation validation +# bisection may actually stop and call the produce_guards method at the moment +# right after the recorded function was played. In other words, since the +# bisection bisects through torch._assert calls, we should save in all methods +# that adds a torch._assert call to the symbolic shapes FX graph. +# +# At the moment, there are 2 methods that save the list: +# - ShapeEnv.evaluate_expr +# - ShapeEnv.guard_or_defer_runtime_assert +def record_shapeenv_event( + *, save_tracked_fakes: bool = False, name: Optional[str] = None +) -> Callable: + def decorator(fn: Callable) -> Callable: + assert callable(fn) + args = inspect.getfullargspec(fn).args + assert args and args[0] == "self", ( + "record_shapeenv_event should only wrap methods on ShapeEnv; refactor your " + "code so that it calls into a method on ShapeEnv" + ) + nonlocal name + if name is None: + name = fn.__name__ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + assert isinstance(args[0], ShapeEnv) + + global NEST + + trace_shape_events_log.debug( + "%scall %s(*%r, **%r)", " " * NEST, name, args[1:], kwargs + ) + NEST += 1 + + def retlog(r): + trace_shape_events_log.debug("%s-> %s", " " * (NEST - 1), r) + return r + + shape_env = args[0] + + try: + if not shape_env.should_record_events or shape_env.is_recording: # type: ignore[has-type] + # If ShapeEnv is already recording an event, call the wrapped + # function directly. + # + # NB: here, we skip the check of whether all ShapeEnv instances + # are equal, in favor of a faster dispatch. + return retlog(fn(*args, **kwargs)) + + # Retrieve an instance of ShapeEnv. + # Assumption: the collection of args and kwargs may not reference + # different ShapeEnv instances. + self = _extract_shape_env_and_assert_equal(args, kwargs) + + # If we are calling this function without any ShapeEnv instance + # alive in its arguments, we don't record and call the original. + if self is None: + return retlog(fn(*args, **kwargs)) + + # Otherwise, start recording and call the function. + with self._recording(): + # Take a snapshot of the current tracked_fakes. + tracked_fakes = ( + self._snapshot_tracked_fakes() if save_tracked_fakes else None + ) + # Record the event for 'fn'. + event = ShapeEnvEvent( + fn, + list(args), + kwargs, + tracked_fakes, + name=name, + ) + # Play the event on this ShapeEnv. + # NB: It's important to put the event first, because running + # the event can trigger internal events that must be ordered + # after this event. However, if an exception happens, we do + # NOT want to have the event in the list, so pop it off from + # the record if an error happened + self.events.append(event) + try: + return retlog(event.run(self)) + except Exception: + self.events.pop() + raise + + except Exception: + if not shape_env.should_record_events or shape_env.is_recording: + # If ShapeEnv is disabled or already recording an event, re-raise the exception without logging. + raise + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) + raise + + finally: + NEST -= 1 + + return wrapper + + return decorator + + +# Replays the ShapeEnvEvents list. +# It assumes the first event is the constructor call. +# +# fn: transforms an old FX node into one corresponding to the newly created ShapeEnv. +def replay_shape_env_events(events): + from torch.fx.experimental.symbolic_shapes import ShapeEnv + + constructor_event = events[0] + assert constructor_event.f == ShapeEnv + + # Constructs the new ShapeEnv. + shape_env = constructor_event.run() + + for event in events[1:]: + try: + # Actually replays each event. + # We need to call create_mapping_fn every time, since the node list might + # change after each event is replayed. + event.run(shape_env) + except Exception: + log.error("failed when running event: %s", event) + raise + + return shape_env + + +# FakeTensor metadata. +# This is to be used in place of FakeTensor placeholders when calling +# ShapeEnv.produce_guards. +@dataclass +class FakeTensorMeta: + tensor_size: tuple[Union[int, torch.SymInt], ...] + tensor_stride: tuple[Union[int, torch.SymInt], ...] + tensor_storage_offset: Union[int, torch.SymInt] + is_nested: bool + + def size(self) -> tuple[Union[int, torch.SymInt], ...]: + return self.tensor_size + + def stride(self) -> tuple[Union[int, torch.SymInt], ...]: + return self.tensor_stride + + def storage_offset(self) -> Union[int, torch.SymInt]: + return self.tensor_storage_offset + + def dim(self) -> int: + return len(self.tensor_size) + + @staticmethod + def from_fake(fake) -> "FakeTensorMeta": + return FakeTensorMeta( + fake.size(), fake.stride(), fake.storage_offset(), fake.is_nested + ) + + +# [Note: ShapeEnv State Equality] +# =============================== +# +# What is considered ShapeEnv state? +# ---------------------------------- +# We consider to be the state of a ShapeEnv instance everything that +# is not in the inline tuple inside remove_nonstate_variables function. +# That is: the fields within ShapeEnv that modify the flow of execution +# of the program. +# +# So, for example: the replacements field might influence on how an +# expression is simplified. That, in turn, may result in a guard being +# statically known (i.e. not added). +# +# On the other hand, var_to_stack serves only changes what is printed +# in the screen, i.e. used only for debugging purposes. Therefore, we +# should not consider it when comparing states. +# +# What to do on NotEqualError? +# ---------------------------- +# Here are a few possible causes for getting a NotEqualError raised: +# +# 1. New field that does not belong in the ShapeEnv state. +# For example: log field of type ShapeEnvLoggerAdapter. Different +# ShapeEnv instances will always have different ShapeEnvLoggerAdapter +# instances, i.e. equality comparison would fail. +# Solution: add it to the inlined tuple inside remove_nonstate_variables +# function inside check_equal method. +# +# 2. New field that is not directly comparable across instances. +# For example: guards field of type List[ShapeGuard]. More specifically, +# the ShapeGuard type holds an expression and a stack information +# for debugging purposes. When replaying the even on a new ShapeEnv +# instance, the stack would be different, which would trigger this error. +# Solution: add a special case to the map_value function inside +# check_equal function. +# +# 3. Mutation of ShapeEnv on some not recorded function. +# If a mutation of the state of ShapeEnv happens inside a function +# that is not recorded (or that no caller in the stack is recorded), +# then, the replayed ShapeEnv won't catch that. +# Solution: decorate the function with record_shape_env_event. + + +# Checks whether the state of two ShapeEnv are equal w.r.t. the guards +# returned by ShapeEnv.produce_guards. +def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value): + # Collect and remove variables that don't necessarily represent the state + # of a ShapeEnv. Note: we copy the dictionary so that we don't modify the + # instance itself. + env1_vars = vars(env1).copy() + env2_vars = vars(env2).copy() + + for v in non_state_variable_names: + if v in env1_vars: + env1_vars.pop(v) + if v in env2_vars: + env2_vars.pop(v) + + # Function for transforming the mismatched values into string. + # Needed, since dict and set entries order might not be the same every time. + def value_to_str(value: Any) -> str: + if isinstance(value, dict): + return ( + "{" + + ", ".join(f"{k}: {value[k]}" for k in sorted(value.keys(), key=str)) + + "}" + ) + if isinstance(value, set): + return "{" + ", ".join(f"{v}" for v in sorted(value)) + "}" + return str(value) + + # Compares env1_vars with env2_vars. + # Here, we allow the value of each field to be mapped, so that we appropriately + # compare the two values. + def compare_vars( + map_value: Callable[[str, Any], Any], + ) -> list[tuple[str, str, str]]: + env1_set, env2_set = set(env1_vars), set(env2_vars) + + # First, compare the set of keys in each vars dictionary. + if env1_set != env2_set: + raise NotEqualError( + "field set mismatch:", + [ + ( + "found unique fields:", + str(sorted(env1_set - env2_set)), + str(sorted(env2_set - env1_set)), + ), + ], + ) + + # Then, sort the keys, and compare the mapped values of each key. + sorted_keys = list(env1_set) + sorted_keys.sort() + + mapped_dict = [ + (k, map_value(k, env1_vars[k]), map_value(k, env2_vars[k])) + for k in sorted_keys + ] + + # Return a list of tuples representing the fields that did not match + # alongside their respective mapped values. + return [ + (f"{k}: values don't match.", value_to_str(val1), value_to_str(val2)) + for k, val1, val2 in mapped_dict + if val1 != val2 + ] + + # Accumulate the mismatching fields. + errors = compare_vars(map_value) + + if len(errors) > 0: + raise NotEqualError("field values don't match:", errors) + + +class NotEqualError(Exception): + def __init__( + self, + msg: str, + mismatched: list[tuple[str, str, str]], + ) -> None: + details = "\n".join( + [ + "\n".join( + [ + f"==> {inner_msg}", + f" > Left: {str1}", + f" > Right: {str2}", + ] + ) + for inner_msg, str1, str2 in mismatched + ] + ) + + super().__init__( + f"""\ +ShapeEnv not equal: {msg} + +{details} +""" + ) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/refinement_types.py b/phivenv/Lib/site-packages/torch/fx/experimental/refinement_types.py new file mode 100644 index 0000000000000000000000000000000000000000..c78a1783084553a4a0649247938d26561889c038 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/refinement_types.py @@ -0,0 +1,16 @@ +class Equality: + def __init__(self, lhs: object, rhs: object): + self.lhs = lhs + self.rhs = rhs + + def __str__(self) -> str: + return f"{self.lhs} = {self.rhs}" + + def __repr__(self) -> str: + return f"{self.lhs} = {self.rhs}" + + def __eq__(self, other: object) -> bool: + if isinstance(other, Equality): + return self.lhs == other.lhs and self.rhs == other.rhs + else: + return False diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/rewriter.py b/phivenv/Lib/site-packages/torch/fx/experimental/rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..bee598694c0dfc392dc3ab5a4ec9db556402aff5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/rewriter.py @@ -0,0 +1,143 @@ +# mypy: allow-untyped-decorators +# mypy: allow-untyped-defs +import ast +import copy +import functools +import inspect +import textwrap +from types import FunctionType +from typing import Any, Callable, cast, Optional, Union + +import torch +from torch._sources import normalize_source_lines +from torch.fx._symbolic_trace import Tracer +from torch.fx.graph import Graph + + +class AST_Rewriter(ast.NodeTransformer): + """ + Take a FunctionType object representing a `forward` method, then + perform an AST rewrite to swap out nodes that are not symbolically + traceable with a callsite to the FX alternative. + + To support swapping out an AST node, define a new `visit` method on + that node. For more details, see: + https://docs.python.org/3/library/ast.html#ast.NodeTransformer + """ + + # This function checks for new keys added in the globals dict. TorchDynamo + # can insert new keys in the global dict and upset the check. Therefore, put + # a disable here. This function is an optimization pass and not really + # suitable for dynamo tracing anyways. + @torch._dynamo.disable + def rewrite(self, fn: FunctionType): + # Normalize the source lines + sourcelines, _ = inspect.getsourcelines(fn) + sourcelines = normalize_source_lines(sourcelines) + source = "".join(sourcelines) + normalized_str = textwrap.dedent(source) + + # Rewrite the original AST + source_ast = ast.parse(normalized_str) + dest_ast = ast.fix_missing_locations(self.visit(source_ast)) + + # Pull out the compiled function from the newly-created Module + code = compile(dest_ast, "", "exec") + globals_dict = copy.copy(fn.__globals__) + keys_before = set(globals_dict.keys()) + exec(code, globals_dict) + new_keys = list(set(globals_dict.keys()) - keys_before) + assert len(new_keys) == 1 + fn_compiled = globals_dict[new_keys[0]] + + # return the compiled function with the original globals + def change_func_globals(f, globals): + """Based on https://stackoverflow.com/a/13503277/2988730 (@unutbu)""" + # __globals__ is a private member of the function class + # so we have to copy the function, f, all of its member, except f.__globals__ + g = FunctionType( + f.__code__, + globals, + name=f.__name__, + argdefs=f.__defaults__, + closure=f.__closure__, + ) + g = functools.update_wrapper(g, f) + g.__kwdefaults__ = copy.copy(f.__kwdefaults__) # type:ignore[attr-defined] + return g + + # Return the correct FunctionType object + return change_func_globals(fn_compiled, globals=fn.__globals__) + + def visit_Assert(self, node): + """ + Swap out the Assert node (Python's `assert`) with a callsite to the + symbolically-traceable torch._assert function + """ + # Create the Call node + n = ast.parse("torch._assert()", mode="eval") + assert isinstance(n, ast.Expression) + call_node = n.body + assert isinstance(call_node, ast.Call) + msg = node.msg if node.msg else ast.Constant(value="", kind=None) + call_node.args = [node.test, msg] + + # Ensure that the new node conforms to the Python AST grammar + expr_wrapper = ast.Expr(value=call_node) + + # Return the new Call node to signify that we want to use it as + # a replacement for the original _assert node + return ast.copy_location(expr_wrapper, node) + + def visit_AnnAssign(self, node): + """ + Swap out Python's AnnAssign with an Assign node where the annotation function is called. + Example: + Original: + y: Tensor_Type(1,2,3, Dyn) = f2(x) + Output: + y = annotate(f2(x),Tensor_Type((1,2,3,Dyn))) + """ + return ast.Assign( + targets=[node.target], + value=ast.Call( + func=ast.Name(id="annotate", ctx=ast.Load()), + args=[node.value, node.annotation], + keywords=[], + ), + ) + + +class RewritingTracer(Tracer): + def trace( + self, + root: Union[torch.nn.Module, Callable], + concrete_args: Optional[dict[str, Any]] = None, + ) -> Graph: + return super().trace(_rewrite(root), concrete_args) + + +def _rewrite(fn: Union[torch.nn.Module, Callable]) -> Union[torch.nn.Module, Callable]: + if isinstance(fn, torch.nn.Module): + # Rewrite this module's `forward` as well as the `forward`s of + # all of this module's recursive descendents. Return the new, + # rewritten module hierarchy. + def rewrite_module(m: torch.nn.Module): + class RewrittenModule(torch.nn.Module): + def __init__(self, orig): + super().__init__() + for k, v in orig.__dict__.items(): + if isinstance(v, torch.nn.Module): + self.__dict__[k] = copy.copy(rewrite_module(v)) + else: + self.__dict__[k] = copy.copy(v) + + RewrittenModule.forward = AST_Rewriter().rewrite( + cast(FunctionType, m.forward) + ) + return RewrittenModule(m) + + return rewrite_module(fn) + else: + # Rewrite this single free function + return AST_Rewriter().rewrite(cast(FunctionType, fn)) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/schema_type_annotation.py b/phivenv/Lib/site-packages/torch/fx/experimental/schema_type_annotation.py new file mode 100644 index 0000000000000000000000000000000000000000..ad31316bae10a5ffbada88041d0fa1d7a73fe705 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/schema_type_annotation.py @@ -0,0 +1,145 @@ +# mypy: allow-untyped-defs +import inspect +from typing import Any, Optional + +import torch +import torch.fx +from torch._jit_internal import boolean_dispatched +from torch.fx import Transformer +from torch.fx.node import Argument, Target +from torch.fx.operator_schemas import _torchscript_type_to_python_type + + +class AnnotateTypesWithSchema(Transformer): + """ + Use Python function signatures to annotate types for `Nodes` within an FX graph. + This pulls out Python function signatures for: + + 1. Standard `torch.nn` Module calls + 2. `torch.nn.functional` calls + 3. Attribute fetches via `get_attr` + + Example usage: + + m = torchvision.models.resnet18() + + traced = torch.fx.symbolic_trace(m) + + traced = AnnotateTypesWithSchema(traced).transform() + + """ + + def __init__( + self, + module: torch.nn.Module, + annotate_functionals: bool = True, + annotate_modules: bool = True, + annotate_get_attrs: bool = True, + ): + super().__init__(module) + self.annotate_functionals = annotate_functionals + self.annotate_modules = annotate_modules + self.annotate_get_attrs = annotate_get_attrs + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + python_ret_type = None + if self.annotate_functionals and target.__module__ == "torch.nn.functional": + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched["if_true"], dispatched["if_false"] + # TODO: can we emit the union of these? What are the implications on TorchScript + # compilation? + if ( + inspect.signature(if_true).return_annotation + != inspect.signature(if_false).return_annotation + ): + return super().call_function(target, args, kwargs) + target_for_analysis = if_true + + python_ret_type = self._extract_python_return_type(target_for_analysis) + + return_proxy = super().call_function(target, args, kwargs) + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) + return return_proxy + + def call_module( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ): + python_ret_type = None + assert isinstance(target, str) + submod = self.fetch_attr(target) + if self.annotate_modules and hasattr(submod.__class__, "__name__"): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + python_ret_type = self._extract_python_return_type(submod.forward) + return_proxy = super().call_module(target, args, kwargs) + return_proxy.node.type = ( + return_proxy.node.type if return_proxy.node.type else python_ret_type + ) + return return_proxy + + def get_attr( + self, + target: torch.fx.node.Target, + args: tuple[Argument, ...], + kwargs: dict[str, Any], + ): + attr_proxy = super().get_attr(target, args, kwargs) + + if self.annotate_get_attrs: + module_itr = self.module + assert isinstance(target, str) + atoms = target.split(".") + for i, atom in enumerate(atoms): + if not hasattr(module_itr, atom): + raise RuntimeError( + f"Node referenced nonextent target {'.'.join(atoms[:i])}!" + ) + module_itr = getattr(module_itr, atom) + + maybe_inferred_ts_type = torch._C._jit_try_infer_type(module_itr) + if maybe_inferred_ts_type.success(): + python_type = _torchscript_type_to_python_type( + maybe_inferred_ts_type.type() + ) + attr_proxy.node.type = ( + python_type if not attr_proxy.node.type else attr_proxy.node.type + ) + + return attr_proxy + + def _extract_python_return_type(self, target: Target) -> Optional[Any]: + """ + Given a Python call target, try to extract the Python return annotation + if it is available, otherwise return None + + Args: + + target (Callable): Python callable to get return annotation for + + Returns: + + Optional[Any]: Return annotation from the `target`, or None if it was + not available. + """ + assert callable(target) + try: + sig = inspect.signature(target) + except (ValueError, TypeError): + return None + + return ( + sig.return_annotation + if sig.return_annotation is not inspect.Signature.empty + else None + ) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/sym_node.py b/phivenv/Lib/site-packages/torch/fx/experimental/sym_node.py new file mode 100644 index 0000000000000000000000000000000000000000..106ecaad18827933c63207c4eeee42a518e11f75 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/sym_node.py @@ -0,0 +1,1847 @@ +# mypy: allow-untyped-defs + +from __future__ import annotations + + +""" +This file does three things: +- Contains the definition of SymNode +- Installs all the magic methods into SymBool, SymFloat, SymFloat at import time +- Does not depend on sympy at import time + +As this file is imported from within torch/__init__.py we do not want it to depend on SymPy +to avoid having to load SymPy at import time, as doing so is *very* slow. +""" + + +import builtins +import functools +import inspect +import itertools +import logging +import math +import operator +import sys +from functools import lru_cache, update_wrapper +from typing import Optional, TYPE_CHECKING, Union + +import torch +import torch._logging.structured as structured + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import ( # noqa: F401 + sym_float, + sym_ite, + sym_max, + sym_min, + sym_not, + SymBool, + SymFloat, + SymInt, +) +from torch._logging import dtrace_structured + + +if TYPE_CHECKING: + from torch.fx.experimental.symbolic_shapes import ShapeEnv + +log = logging.getLogger(__name__) +sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node") + + +__all__ = ["SymNode", "method_to_operator", "magic_methods"] + + +from torch.types import py_sym_types as SymTypes + + +def _to_symtype(t): + if t is bool: + return SymBool + if t is int: + return SymInt + if t is float: + return SymFloat + return t + + +# TODO: An incomplete list +# 1. Set variables to be equal when we do equality +# 2. Specialize on 0/1 when we do subtraction +class SymNode: + """ + This is a type erased SymInt/SymFloat which we use to do actual operations. + End users don't touch this. Magic methods are NOT defined on this object. + """ + + # Note [optimized_summation]: indicates that SymNode is an Add expression of the form + # a + b + c + d... etc where all terms are unique symbols. This allows us to do some optimizations + # for common patterns see _optimized_add. + + # The unfortunate reason we have this here is because sympy sets __slots__ = () for add expression, + # so we cannot add the attribute directly to the sympy expression. Furthermore, we cannot use it as + # a weak dictionary key either! So instead, we attach the attribute here to the SymNode. + _optimized_summation: bool = False + + def __init__( + self, + expr, + shape_env, + pytype, + hint: Optional[Union[int, float, bool]], + constant=None, + fx_node=None, + optimized_summation=False, + ): + self._expr = expr + self.shape_env = shape_env + self.pytype = pytype + self._optimized_summation = optimized_summation + + # What's the difference between hint and constant? + # + # - A constant is known to be invariant across invocations of the model; + # it will always be this value. We only really know this when we + # encounter an honest-to-goodness literal (when wrapping it into + # a SymNode, we set constant.) Most of the time, constant is None + # + # - A hint is a *particular* value from the particular run we are + # tracing, but it may vary the next time around. It's useful to + # keep this around, as if we need a concrete value from a SymNode, + # we will return the hint and guard on the expression that produced + # it giving the same hint next time around. The hint is not + # guaranteed to be set either: if you have an unbacked SymNode, + # there won't be any hint; it was the result of some tensor-dependent + # computation, but we don't know what it actually is because we + # haven't actually run the tensor computation. + # + # If _hint is None, we will query maybe_evaluate_static(compute_hint=True) + # in hopes that we've learned enough about the unbacked symints to + # discharge the hint; otherwise, you're likely to just error out. + # + # (A previous version of this system had some optimizations to only + # recompute when it was possible we had learned enough about the + # unbacked symint that a hint was now possible, but as we added more + # potential refinements to unbacked symints this got harder to keep + # in sync, so we've deleted it for now.) + + def compute_hint(): + from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols + + # This occasionally gets exercised by, e.g., + # convert_shape_to_symint. It's just a nicety so you don't HAVE + # to have a correct hint on hand when making a SymNode. + # Don't attempt to compute for unbacked, this can be quite + # expensive. + if has_free_unbacked_symbols(self.expr): + return None + hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True) + if hint is not None: + hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint + return hint + + if hint is not None: + assert type(hint) is pytype or type(hint) is _to_symtype(pytype), ( + "Cannot create SymNode of type " + f"{pytype} with incompatible hint of type {type(hint)}" + ) + if self.shape_env and self.shape_env._translation_validation_enabled: + # This is technically not TV, but this assert is expensive so + # let's only do it when we're already doing expensive things + computed_hint = compute_hint() + assert hint == computed_hint, ( + f"{hint} != {computed_hint} (for {self.expr})" + ) + else: + hint = compute_hint() + self._hint = hint + self.constant: Optional[Union[int, float, bool]] = constant + + # Record the FX node of the current node if we are doing translation + # validation. They will be used for building the input assertions for + # the translation validation problem. + tx_validation_en = ( + self.shape_env and self.shape_env._translation_validation_enabled + ) + self.fx_node = tx_validation_en and fx_node + + def with_shape_env(self, shape_env: ShapeEnv) -> SymNode: + return SymNode( + self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node + ) + + def _value_eq(self, other: SymNode) -> bool: + # Purposely don't include the shape_env in the eq. + return ( + self._expr == other._expr + and self.pytype == other.pytype + and self._hint == other._hint + and self.constant == other.constant + and self.fx_node == other.fx_node + ) + + def _value_hash(self) -> int: + # Purposely don't include the shape_env in the hash. + return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node)) + + @property + def expr(self): + return self.shape_env.replace(self._expr) + + @property + def hint(self): + return self._hint + + def has_hint(self): + return self._hint is not None + + def require_hint(self, fallback=None): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if self._hint is None: + if fallback is not None: + # Say we have some expr like 2*u0 + s0 + # The hint will be None, since the expr contains at least 1 unbacked. + # We will: + # - replace every backed free symbol with its corresponding hint + # - replace every unbacked free symbol with the fallback + # - regenerate the expression with those symbol replacements + # Note: this is not really complete either, since right now + # this logic does not take into account any value ranges + # for the unbacked symints, we may need to beef it up at some point. + unbacked_symbols = free_unbacked_symbols(self.expr) + replacements = { + s: 4096 if s in unbacked_symbols else self.shape_env.var_to_val[s] + for s in self.expr.free_symbols + } + return self.expr.xreplace(replacements) + # NB: we expect this to raise + return self.shape_env.size_hint(self.expr) + return self._hint + + def maybe_as_int(self): + if self.expr.is_number: + return int(self.expr) + else: + return None + + # NB: This does conversions, not sure if this is good or not + def maybe_as_float(self): + import sympy + + if isinstance(self.expr, sympy.Float): + return float(self.expr) + else: + return None + + def maybe_as_bool(self): + import sympy + + if self.expr is sympy.true: + return True + elif self.expr is sympy.false: + return False + else: + return None + + def is_int(self): + return self.pytype is int + + def is_float(self): + return self.pytype is float + + def is_bool(self): + return self.pytype is bool + + def is_nested_int(self): + # Unbacked SymInts cannot be nested int today + return ( + self._hint is not None + and isinstance(self._hint, SymInt) + and self._hint.node.is_nested_int() + ) + + def wrap_int(self, num): + assert type(num) is int + import sympy + + return SymNode( + sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num + ) + + def wrap_float(self, num): + assert type(num) is float + import sympy + + return SymNode( + sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num + ) + + def wrap_bool(self, num): + assert type(num) is bool + import sympy + + return SymNode( + sympy.true if num else sympy.false, + self.shape_env, + bool, + num, + constant=num, + fx_node=num, + ) + + def clone(self): + return self + + def str(self): + return f"{self.expr}" + + def __str__(self): + return self.str() + + def __repr__(self): + rep = [ + f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}", + ] + if self._hint is not None: + rep.append(f"hint={self._hint}") + if self.constant is not None: + rep.append(f"constant={self.constant}") + if self.fx_node is not None: + rep.append(f"fx_node={self.fx_node}") + return ", ".join(rep) + ")" + + def _graph_repr(self) -> builtins.str: + # Representation used by GraphModule to create a pythonic version of a graph + return self.str() + + # These methods call the metaprogrammed methods, they're hand written + # here so we get good stack traces + def abs(self) -> SymNode: + return self._abs() # type: ignore[attr-defined] + + def pos(self) -> SymNode: + return self._pos() # type: ignore[attr-defined] + + def round(self, ndigits=None) -> SymNode: + return self._round(ndigits) # type: ignore[attr-defined] + + def trunc(self) -> SymNode: + return self._trunc() # type: ignore[attr-defined] + + def add(self, other) -> SymNode: + return self._add(other) # type: ignore[attr-defined] + + def sub(self, other) -> SymNode: + return self._sub(other) # type: ignore[attr-defined] + + def mul(self, other) -> SymNode: + return self._mul(other) # type: ignore[attr-defined] + + def mod(self, other) -> SymNode: + return self._mod(other) # type: ignore[attr-defined] + + def float_pow(self, other) -> SymNode: + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> SymNode: + return self._pow_by_natural(other) # type: ignore[attr-defined] + + def and_(self, other) -> SymNode: + return self._and_(other) # type: ignore[attr-defined] + + def or_(self, other) -> SymNode: + return self._or_(other) # type: ignore[attr-defined] + + def float_truediv(self, other) -> SymNode: + return self._float_truediv(other) # type: ignore[attr-defined] + + def int_truediv(self, other) -> SymNode: + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> SymNode: + return self._int_floordiv(other) # type: ignore[attr-defined] + + def lshift(self, other) -> SymNode: + return self._lshift(other) # type: ignore[attr-defined] + + def rshift(self, other) -> SymNode: + return self._rshift(other) # type: ignore[attr-defined] + + def sym_not(self) -> SymNode: # noqa: F811 + return self._sym_not() # type: ignore[attr-defined] + + def eq(self, other) -> SymNode: + return self._eq(other) # type: ignore[attr-defined] + + def ne(self, other) -> SymNode: + return self._ne(other) # type: ignore[attr-defined] + + def gt(self, other) -> SymNode: + return self._gt(other) # type: ignore[attr-defined] + + def lt(self, other) -> SymNode: + return self._lt(other) # type: ignore[attr-defined] + + def le(self, other) -> SymNode: + return self._le(other) # type: ignore[attr-defined] + + def ge(self, other) -> SymNode: + return self._ge(other) # type: ignore[attr-defined] + + def floor(self) -> SymNode: + return self._floor() # type: ignore[attr-defined] + + def is_integer(self) -> SymNode: + return self._is_integer() # type: ignore[attr-defined] + + def sym_float(self) -> SymNode: # noqa: F811 + return self._sym_float() # type: ignore[attr-defined] + + def sym_int(self) -> SymNode: + return self._sym_int() # type: ignore[attr-defined] + + def ceil(self) -> SymNode: + return self._ceil() # type: ignore[attr-defined] + + def neg(self) -> SymNode: + return self._neg() # type: ignore[attr-defined] + + def sym_min(self, other) -> SymNode: # noqa: F811 + return self._sym_min(other) # type: ignore[attr-defined] + + def sym_max(self, other) -> SymNode: # noqa: F811 + return self._sym_max(other) # type: ignore[attr-defined] + + def sym_ite(self, then_val, else_val) -> SymNode: + return self._sym_ite(then_val, else_val) # type: ignore[attr-defined] + + def is_contiguous(self, sizes, strides) -> SymNode: + return self._is_contiguous(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode: + return self._is_channels_last_contiguous_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode: + return self._is_channels_last_contiguous_3d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_2d(self, sizes, strides) -> SymNode: + return self._is_channels_last_strides_2d(sizes, strides) # type: ignore[attr-defined] + + def is_channels_last_strides_3d(self, sizes, strides) -> SymNode: + return self._is_channels_last_strides_3d(sizes, strides) # type: ignore[attr-defined] + + def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode: + return self._is_non_overlapping_and_dense_indicator(sizes, strides) # type: ignore[attr-defined] + + # Make C++ happy + def sym_or(self, other): + return self.or_(other) + + def sym_and(self, other): + return self.and_(other) + + # Integer bitwise ops + def bitwise_and(self, other): + return self._bitwise_and(other) # type: ignore[attr-defined] + + def bitwise_or(self, other): + return self._bitwise_or(other) # type: ignore[attr-defined] + + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> SymNode: + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + + def is_non_overlapping_and_dense(self, sizes, strides): + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq( + to_node(self, 1) + ) # type: ignore[attr-defined] + + def int_(self): + return self.guard_int("", 0) # NB: uses Python backtrace + + # This one is currently done by hand, but if we add other variadic + # functions consider factoring it out to be metaprogrammed too. Note that + # some load bearing logic is directly in torch.sym_sum + + def sym_sum(self, args) -> SymNode: + import sympy + + # Inner impl + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + if get_proxy_mode(): + return to_node( + self, + handle_sym_dispatch( + torch.sym_sum, + (tuple(wrap_node(a) for a in args),), + {}, + ), + ) + exprs = [a.expr for a in args] + out = sympy.Add(*exprs) + + size_hints = [] + out_hint = None + for a in args: + if a.hint is None: + break + size_hints.append(a.hint) + else: + out_hint = sum(size_hints) + + fx_node, _ = self.shape_env._create_fx_call_function( + torch.sym_sum, (tuple(a.fx_node for a in args),) + ) + + # NB: Only for integers! + return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node) + + def evaluate(self, size_oblivious=False): + return self.shape_env.evaluate_sym_node(self, size_oblivious) + + # You can manually trigger a guard with this function + def guard_int(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return int(r) + except Exception: + log.warning("Failed to convert to int: %s", r) + raise + + def guard_float(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return float(r) + except Exception: + log.warning("Failed to convert to float: %s", r) + raise + + def guard_bool(self, file, line): + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate() + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def expect_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols + + if ( + self.has_hint() + and not free_unbacked_symbols(self.expr) + and not self.shape_env.prefer_deferred_runtime_asserts_over_guards + ): + # OK to generate guards + return self.guard_bool(file, line) + # Generate a deferred runtime assert (this might actually end up doing + # a regular guard if we can!) + # TODO: file/line here is very important, because the assert has been + # deferred so you can't backtrace easily + return self.shape_env.guard_or_defer_runtime_assert( + self.expr, f"{file}:{line}", fx_node=self.fx_node + ) + + def expect_size(self, file, line): + from torch.fx.experimental.symbolic_shapes import _advise_is_size + + b = self.ge(self.wrap_int(0)) + # Generate a deferred runtime assert + r = b.expect_true(file, line) + # Refine compile time range, but only if it's unbacked. + # If you refine range for hinted variables, you can end up making + # improper deductions since compile time reasoning may be + # incompatible with runtime reasoning. + if r and not self.has_hint(): + _advise_is_size(SymInt(self)) + return r + + def statically_known_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import statically_known_true + + assert self.is_bool() + return statically_known_true(SymBool(self)) + + def guard_size_oblivious(self, file, line): + """ + Like guard_bool, but if we encounter unbacked symbols, if those symbols + are size-like, we will treat them as >= 2 for the purposes of the analysis. + + This CHANGES the runtime semantics, but all size-oblivious sites have been + audited to ensure that the runtime semantics don't change in a material way. + Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping + an unbacked one size, or a tensor reporting as non-contiguous even if it's + contiguous if it would have been reported contiguous due to being empty. + """ + # TODO: use the file/line for some useful diagnostic on why a + # guard occurred + r = self.evaluate(size_oblivious=True) + try: + return bool(r) + except Exception: + log.warning("Failed to convert to bool: %s", r) + raise + + def guard_or_false(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_false + + assert self.is_bool() + return guard_or_false(SymBool(self)) + + def guard_or_true(self, file, line): + from torch.fx.experimental.symbolic_shapes import guard_or_true + + assert self.is_bool() + return guard_or_true(SymBool(self)) + + def bool_(self): + return self.guard_bool("", 0) + + def is_symbolic(self): + return True + + def nested_int(self): + return None + + def is_constant(self): + return False + + +# TODO: this probably needs the sizes-strides eval functions +METHOD_TO_OPERATOR = { + "pos": operator.pos, + "abs": operator.abs, + "add": operator.add, + "and": operator.and_, + "bitwise_and": operator.and_, + "ceil": math.ceil, + "eq": operator.eq, + "floor": math.floor, + "trunc": math.trunc, + "int_floordiv": operator.floordiv, + "ge": operator.ge, + "gt": operator.gt, + "is_integer": lambda x: x.is_integer(), + "le": operator.le, + "lshift": operator.lshift, + "lt": operator.lt, + "mod": operator.mod, + "mul": operator.mul, + "ne": operator.ne, + "neg": operator.neg, + "or": operator.or_, + "bitwise_or": operator.or_, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, + "round": builtins.round, + "rshift": operator.rshift, + "sub": operator.sub, + "sym_float": sym_float, + "sym_ite": sym_ite, + "sym_max": sym_max, + "sym_min": sym_min, + "sym_not": sym_not, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, +} + +unary_magic_methods = { + "abs", + "sym_float", + "sym_int", + "ceil", + "floor", + "neg", + "sym_not", + "pos", + "trunc", +} + + +# Adding math ops: sqrt, cos, sin, ... +def _get_sym_node_fn(name): + def fn(self): + return getattr(self, f"_sym_{name}")() + + return fn + + +math_op_names = ( + "sqrt", + "cos", + "cosh", + "sin", + "sinh", + "tan", + "tanh", + "asin", + "acos", + "atan", + "log2", +) +for name in math_op_names: + sym_name = f"sym_{name}" + priv_sym_name = f"_{sym_name}" + setattr(SymNode, sym_name, _get_sym_node_fn(name)) + METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name) + unary_magic_methods.add(sym_name) + __all__.append(sym_name) + + +# Unary methods that are not magic methods +unary_nonmagic_methods = { + "is_integer", +} + +unary_methods = unary_magic_methods | unary_nonmagic_methods + +# Most methods are only registered on SymInt and SymFloat +# Some methods are only be registered on SymBool +only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"} +# Methods that implicitly convert SymBool into SymInt +bool_becomes_int_magic_methods = {"add", "sub", "mul"} +# Methods that are also on SymBool, in addition to on SymInt and SymFloat +also_bool_magic_methods = {"eq"} +bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods + +# Methods that are only for float +only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"} + + +magic_methods_on_operator_with_trailing_underscore = {"and", "or"} +# remap necessary because an op name can have a bitwise and boolean implementation +bitwise_ops = { + "bitwise_and": "and", + "bitwise_or": "or", +} + + +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} + +for name in math_op_names: + sym_name = f"sym_{name}" + always_float_magic_methods.add(sym_name) + + +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} +always_bool_magic_methods = { + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + "and", + "or", + "sym_not", + "is_non_overlapping_and_dense", + "is_integer", +} + +# Methods that have a `__foo__` as well as `__rfoo__` + + +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv + + return FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) + + +def _sympy_floordiv(a, b): + from torch.utils._sympy.functions import FloorDiv + + return FloorDiv(a, b) + + +def _sympy_mod(a, b): + from torch.utils._sympy.functions import Mod, PythonMod + + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) + + +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural + + return PowByNatural(a, b) + + +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) + + +def _sympy_and(a, b): + import sympy + + return sympy.And(a, b) + + +def _sympy_or(a, b): + import sympy + + return sympy.Or(a, b) + + +def _sympy_lshift(a, b): + from torch.utils._sympy.functions import LShift + + return LShift(a, b) + + +def _sympy_rshift(a, b): + from torch.utils._sympy.functions import RShift + + return RShift(a, b) + + +def _binary_search_insert_arg(ordered_args, new_arg): + """ + If new_arg is found in ordered_args None is returned, else the new + ordered_args with new_arg inserted + """ + if len(ordered_args) == 0: + return [new_arg] + + from sympy.core.basic import _args_sortkey as sort_key, Basic + + # Fast path when new_arg > ordered_args[-1]. + if sort_key(ordered_args[-1]) < sort_key(new_arg): + return ordered_args + [new_arg] + + # Fast path when new_arg < ordered_args[0]. + if sort_key(ordered_args[0]) > sort_key(new_arg): + return [new_arg] + ordered_args + + low, high = 0, len(ordered_args) - 1 + + while low <= high: + mid = (low + high) // 2 + compare_result = Basic.compare(ordered_args[mid], new_arg) + if compare_result == 0: + return None + elif compare_result < 0: + low = mid + 1 + else: + high = mid - 1 + + ordered_args.insert(low, new_arg) + return ordered_args + + +def _optimized_add( + lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False +): + """ + Custom optimization for Add used to optimize incremental binary summations of certain properties. The idea + is when we know the expression is a summation of unique symbols all we need to know is the correct order of symbols, + and no other optimizations are needed. We pass evaluate=false, with the correct order of args and save the following. + 1. Avoid running other optimizations when the Add is constructed. + 2. Manually figure out the order of the args for the new expression in log(n) comparisons instead of nLog(n) + (comparing terms is expensive and shows in the profiles). + The function returns a tuple of (1) a boolean that indicates whether the output is a summation of unique symbols, + (2) the result sympy expression. + """ + import sympy + from sympy.core.basic import _args_sortkey as sortkey + + def make_optimized(ordered_args): + assert ordered_args is not None + result = sympy.Add(*ordered_args, evaluate=False) + return (True, result) + + from torch.utils._sympy.functions import _is_symbols_binary_summation + + lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs) + rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs) + + if lhs_is_optimized_summation and rhs_is_optimized_summation: + # (a0+a1..) + (a2+a3..) => (a0+a1+a2+a3) + if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]): + return make_optimized(lhs._args + rhs._args) + # (a2+a3..) + (a0+a1..) => (a0+a1+a2+a3) + if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]): + return make_optimized(rhs._args + lhs._args) + + # (a1+a3) + (a0+a2) => (a0+a1+a2+a3) + if len(lhs._args) <= 2 and len(rhs._args) <= 2: + new_args = list(lhs._args) + for a in rhs._args: + new_args = _binary_search_insert_arg(new_args, a) + if new_args is None: + break + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + # (a0+a2) + a1 => (a0+a1+a2) + if lhs_is_optimized_summation and rhs.is_symbol: + new_args = _binary_search_insert_arg(list(lhs._args), rhs) + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + # a1 + (a0+a2)=> (a0+a1+a2) + if rhs_is_optimized_summation and lhs.is_symbol: + new_args = _binary_search_insert_arg(list(rhs._args), lhs) + # None means an element already exists. + if new_args is not None: + return make_optimized(new_args) + + result = sympy.Add(lhs, rhs) + return (_is_symbols_binary_summation(result), result) + + +def _bitwise_and(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_and + + return BitwiseFn_bitwise_and(a, b) + + +def _bitwise_or(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_or + + return BitwiseFn_bitwise_or(a, b) + + +reflectable_magic_methods = { + "add": _optimized_add, + "sub": operator.sub, + "mul": operator.mul, + "mod": _sympy_mod, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, + "and": _sympy_and, + "bitwise_and": _bitwise_and, + "or": _sympy_or, + "bitwise_or": _bitwise_or, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, + "lshift": _sympy_lshift, + "rshift": _sympy_rshift, +} + + +def _floor_ceil_helper(a, fn): + import sympy + + if isinstance(a, sympy.Mul): + aa = a.args + if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer: + coef = sympy.Integer(aa[0]) + if aa[0] == coef: # structural equality test + return coef * aa[1] + if ( + isinstance(a, sympy.Float) + and a == sympy.Integer(a) + or isinstance(a, sympy.Integer) + ): + return sympy.Integer(a) + return fn(a) + + +def _sympy_floor(a): + from torch.utils._sympy.functions import FloorToInt + + return FloorToInt(a) + + +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) +def _sympy_trunc(a): + from torch.utils._sympy.functions import TruncToInt + + return TruncToInt(a) + + +def _sympy_ceil(a): + from torch.utils._sympy.functions import CeilToInt + + return CeilToInt(a) + + +def _sympy_eq(a, b): + import sympy + + return sympy.Eq(a, b) + + +def _sympy_ne(a, b): + import sympy + + return sympy.Ne(a, b) + + +def _sympy_gt(a, b): + import sympy + + return sympy.Gt(a, b) + + +def _sympy_lt(a, b): + import sympy + + return sympy.Lt(a, b) + + +def _sympy_le(a, b): + import sympy + + return sympy.Le(a, b) + + +def _sympy_ge(a, b): + import sympy + + return sympy.Ge(a, b) + + +def _sympy_min(a, b): + from torch.utils._sympy.functions import Min + + return Min(a, b) + + +def _sympy_max(a, b): + from torch.utils._sympy.functions import Max + + return Max(a, b) + + +def _sympy_ite(a, t, f): + import sympy + + return sympy.Piecewise((t, a), (f, True)) + + +current_module = sys.modules[__name__] + + +def _get_sym_math_fn(name): + def fn(a): + import torch.utils._sympy.functions + + return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a) + + return fn + + +for name in math_op_names: + priv_sympy_name = f"_sympy_{name}" + fn = _get_sym_math_fn(name) + fn.__qualname__ = fn.__name__ = priv_sympy_name + setattr(current_module, priv_sympy_name, fn) + +del fn, name, priv_sympy_name # type: ignore[possibly-undefined] + + +def _sympy_abs(a): + import sympy + + return sympy.Abs(a) + + +def _sympy_round(number, ndigits=None): + from torch.utils._sympy.functions import RoundDecimal, RoundToInt + + if ndigits is None: + return RoundToInt(number) + else: + return RoundDecimal(number, ndigits) + + +def _sympy_sym_float(a): + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) + + +def _sympy_is_integer(a): + import sympy + + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) + + +magic_methods = { + **reflectable_magic_methods, + "sym_not": operator.invert, + "pos": operator.pos, + "eq": _sympy_eq, + "ne": _sympy_ne, + "gt": _sympy_gt, + "lt": _sympy_lt, + "le": _sympy_le, + "ge": _sympy_ge, + "floor": _sympy_floor, + "trunc": _sympy_trunc, + "sym_float": _sympy_sym_float, + "ceil": _sympy_ceil, + "neg": operator.neg, + "sym_min": _sympy_min, + "sym_max": _sympy_max, + "sym_ite": _sympy_ite, + "abs": _sympy_abs, + "round": _sympy_round, + "is_integer": _sympy_is_integer, +} + + +for name in math_op_names: + sym_name = f"sym_{name}" + magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}") + +del name, sym_name, math_op_names, current_module # type: ignore[possibly-undefined] + + +def sympy_is_contiguous(sizes, strides): + dim = len(sizes) + return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1))) + + +def sympy_is_contiguous_generic(sizes, strides, dim_order): + import sympy + + dim = len(sizes) + + if len(dim_order) != dim: + return sympy.false + + is_contiguous = sympy.true + z = sympy.S.One + # Contiguous if the strides make sense (or the dim is size 1) + for d in dim_order: + is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z) + z *= sizes[d] + # OR if any size is zero + for d in range(dim): + is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero) + return is_contiguous + + +# NB: There is a TODO in C++ to allow omitting the batch dim. If that +# happens you will need to refactor this + + +def sympy_is_channels_last_contiguous_2d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_contiguous_3d(sizes, strides): + return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def sympy_is_channels_last_strides_generic(sizes, strides, dim_order): + import sympy + + from torch.utils._sympy.functions import Max + + dim = len(sizes) + + if dim != len(dim_order): + return sympy.false + + m = sympy.S.Zero + r = sympy.true + + # special case for trivial C dimension. default to NCHW + r &= sympy.Ne(strides[1], 0) + + for d in dim_order: + r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m) + # Fallback to NCHW as default layout for ambiguous cases + # This is the flaw of implicit memory_format from strides. + # N111 tensor with identical strides for size 1 dimension; + # Two cases could lead us here: + # a. N111 contiguous Tensor ([N,1,1,1]@[1,1,1,1]) + # b. N11W contiguous Tensor sliced on the W-dimension. + # ([N,1,1,1]@[W,W,W,W]) + if d == 0: + r &= sympy.Ne(m, strides[1]) + # This is necessary to: + # 1. distinguish the memory_format of N1H1; + # [H, 1, 1, 1] channels_last stride + # [H, H, 1, 1] contiguous stride + # 2. permutation of 1C1W: + # [1, C, 1, H]@[HC, H, H, 1] transpose(1, 3) + # [1, H, 1, C]@[HC, 1, H, H] shouldn't be identified as + # channels_last + m = strides[d] * Max(sizes[d], 1) + + return r + + +def sympy_is_channels_last_strides_2d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0]) + + +def sympy_is_channels_last_strides_3d(sizes, strides): + return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0]) + + +def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides): + from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator + + return IsNonOverlappingAndDenseIndicator(*sizes, *strides) + + +sizes_strides_methods = { + # TODO: These could also be done with indicators, maybe it is better + # for reasoning to do it that way + "is_contiguous": sympy_is_contiguous, + "is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d, + "is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d, + "is_channels_last_strides_2d": sympy_is_channels_last_strides_2d, + "is_channels_last_strides_3d": sympy_is_channels_last_strides_3d, + "is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator, +} + + +def to_node(self, num): + if isinstance(num, SymTypes): + return num.node + elif type(num) is bool: + return self.wrap_bool(num) + elif type(num) is int: + return self.wrap_int(num) + elif type(num) is float: + return self.wrap_float(num) + else: + # NotImplemented is important so that Python tries the + # other magic method + return NotImplemented + + +def wrap_node(x): + # TODO: let C++ also take advantage of this + if isinstance(x, SymNode) and x.constant is not None: + return x.constant + if x.is_int(): + return SymInt(x) + elif x.is_float(): + return SymFloat(x) + elif x.is_bool(): + return SymBool(x) + else: + raise AssertionError(f"unrecognized return type {x}") + + +def method_to_operator(method): + return METHOD_TO_OPERATOR[method] + + +def _make_node_magic(method, func): + func = lru_cache(256)(func) + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"{method}_" + else: + method_attr = method + + def uninteresting_files() -> set[str]: + import torch + + mods = [ + torch._dynamo.eval_frame, + torch._dynamo.utils, + torch.fx.experimental.sym_node, + torch, + ] + import torch._dynamo.guards + + return ( + {inspect.getfile(m) for m in mods} + | torch._dynamo.guards.uninteresting_files() + | {""} + ) + + def capture_provenance(fn): + @functools.wraps(fn) + def wrapper(self, other=None): + if other is None: + result = fn(self) + else: + result = fn(self, other) + if torch._logging._internal.GET_DTRACE_STRUCTURED: + if other is not None: + arguments = [self, other] + else: + arguments = [self] + + def get_id(sym_node) -> Optional[int]: + # We don't want to return an ID if the input is a constant + import sympy + + if sym_node.constant is not None: + return None + elif id(sym_node) == id(result): + return None + elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)): + return None + elif sym_node.expr in (sympy.true, sympy.false): + return None + return id(sym_node) + + dtrace_structured( + "expression_created", + metadata_fn=lambda: { + "method": method, + "result": str(result), + "result_id": id(result), + "arguments": [str(a) for a in arguments], + "argument_ids": [ + get_id(i) for i in arguments if get_id(i) is not None + ], + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + }, + ) + + return result + + return wrapper + + @capture_provenance + def binary_magic_impl(self, other): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = method_to_operator(method) + + out_hint = None + if self.hint is not None and other.hint is not None: + out_hint = op(self.hint, other.hint) + + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) + ) + assert isinstance(other, SymNode) + optimized_summation = False + try: + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + elif method == "add": + # see Note [optimized_summation] + (optimized_summation, out) = func( + self.expr, + other.expr, + self._optimized_summation, + other._optimized_summation, + ) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) + raise + sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out) + pytype: type + # This is not strictly correct. In Python, a**b may return complex when + # a < 0 and b is a float: (-1)**2.1. Same for sympy.sqrt(-3.14). This + # returns a float while both arguments are ints: 2**(-1). Also, max and + # min do not type promote. To avoid having data-dependent control flow + # here, we just set the type to float if one of the args is a float. In + # case of a type mismatch, we assume that it will be detected during + # evaluation. + if method in always_float_magic_methods: + pytype = float + elif method in always_bool_magic_methods: + pytype = bool + elif self.pytype is float or other.pytype is float: + pytype = float + else: + pytype = self.pytype + + if ( + pytype is not None + and out_hint is not None + and not isinstance(out_hint, SymTypes) + ): + out_hint = pytype(out_hint) + + # Create a FX node that corresponds to the operation being applied to + # this node. + fx_node, _ = self.shape_env._create_fx_call_function( + op, (self.fx_node, other.fx_node) + ) + + result = SymNode( + out, + self.shape_env, + pytype, + out_hint, # type: ignore[arg-type] + fx_node=fx_node, + optimized_summation=optimized_summation, # see Note [optimized_summation] + ) + return result + + @capture_provenance + def unary_magic_impl(self): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = method_to_operator(method) + if get_proxy_mode(): + return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {})) + # TODO: consider constant prop here + expr = self.expr + if method == "floor" or method == "ceiling": + expr = self.shape_env._simplify_floor_div(expr) + + try: + out = func(expr) + except Exception: + log.warning("failed to eval %s(%s)", method, expr) + raise + sym_node_log.debug("%s %s -> %s", func, expr, out) + out_hint = None + if self.hint is not None: + out_hint = op(self.hint) + pytype: type + if method in always_int_magic_methods: + pytype = int + elif method in always_bool_magic_methods: + pytype = bool + elif method in always_float_magic_methods: + pytype = float + else: + pytype = self.pytype + + fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + if method in unary_methods: + setattr(SymNode, f"_{method_attr}", unary_magic_impl) + elif method == "sym_ite": + + def sym_ite_impl(pred_node, then_node, else_node): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + out_hint = then_node.hint if pred_node.hint else else_node.hint + if get_proxy_mode(): + return to_node( + pred_node, + handle_sym_dispatch( + sym_ite, + ( + wrap_node(pred_node), + wrap_node(then_node), + wrap_node(else_node), + ), + {}, + ), + ) + + try: + out = func(pred_node.expr, then_node.expr, else_node.expr) + except Exception: + log.warning( + "failed to eval %s(%s, %s, %s)", + method, + pred_node.expr, + then_node.expr, + else_node.expr, + ) + raise + + fx_node, _ = pred_node.shape_env._create_fx_call_function( + sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node) + ) + return SymNode( + out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node + ) + + setattr(SymNode, f"_{method_attr}", sym_ite_impl) + elif method == "round": + + def round_impl(self, ndigits=None): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = builtins.round + if get_proxy_mode(): + return to_node( + self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {}) + ) + + expr = self.expr + try: + out = func(expr, ndigits) + except Exception: + log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) + raise + + if ndigits is None: + pytype = int + else: + pytype = self.pytype + + out_hint = None + if self.hint is not None: + out_hint = op(self.hint, ndigits) + + # Internally, None is used as sentinel to indicate that a something is not a node on an FX graph. At the + # same time, there is no way to wrap a plain None into an FX node. Thus, there is no way to pass None here + # without triggering some asserts that check whether we are mixing FX nodes with untracked arguments. The + # hack down below works, because all round function down the line all take ndigits=None as default in their + # signature. + # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL + args = [self.fx_node] + if ndigits is not None: + args.append(ndigits) + fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args)) + return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node) + + setattr(SymNode, f"_{method_attr}", round_impl) + else: + setattr(SymNode, f"_{method_attr}", binary_magic_impl) + + +def _make_node_sizes_strides(method, func): + # NB: don't LRU cache, lots of arguments + + def sizes_strides_impl(self, sizes, strides): + from torch.fx.experimental.proxy_tensor import ( + get_proxy_mode, + handle_sym_dispatch, + ) + + op = getattr(sys.modules[__name__], method) + if get_proxy_mode(): + return to_node( + self, + handle_sym_dispatch( + op, + ([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]), + {}, + ), + ) + size_exprs = [s.expr for s in sizes] + stride_exprs = [s.expr for s in strides] + try: + out = func(size_exprs, stride_exprs) + except Exception: + log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs) + raise + # bool is never expandable + + size_hints = [] + out_hint = None + for s in sizes: + if s.hint is None: + break + size_hints.append(s.hint) + else: + stride_hints = [] + for s in strides: + if s.hint is None: + break + stride_hints.append(s.hint) + else: + out_hint = op(size_hints, stride_hints) + + # NB: This is the indicator function, not the actual bool! + pytype: type + if method.endswith("_indicator"): + pytype = int + else: + pytype = bool + return SymNode(out, self.shape_env, pytype, out_hint) + + setattr(SymNode, f"_{method}", sizes_strides_impl) + + # TODO: This is technically hotpath, but in the ideal end state + # guards on this will resolve at a higher level so you never + # spend time in this code + def sizes_strides_user(sizes, strides): + import sympy + + from torch.fx.experimental.symbolic_shapes import ( + eval_is_non_overlapping_and_dense, + ) + + for a in itertools.chain(sizes, strides): + if isinstance(a, SymInt): + return wrap_node( + getattr(a.node, method)( + [to_node(a.node, b) for b in sizes], + [to_node(a.node, b) for b in strides], + ) + ) + if method == "is_non_overlapping_and_dense_indicator": + return eval_is_non_overlapping_and_dense(sizes, strides) + else: + # TODO: this is an awful implementation + return bool( + func( + [sympy.sympify(a) for a in sizes], + [sympy.sympify(a) for a in strides], + ) + ) + + # Skip for is_non_overlapping_and_dense_indicator + if not hasattr(sys.modules[__name__], method): + setattr(sys.modules[__name__], method, sizes_strides_user) + + +for method, func in magic_methods.items(): + _make_node_magic(method, func) + +for method, func in sizes_strides_methods.items(): + _make_node_sizes_strides(method, func) + + +def _make_user_magic(method, user_type): + # User magic takes care of wrapping the other operand into a node, + # so that our internal logic can assume everything is nodes + + if method in magic_methods_on_operator_with_trailing_underscore: + method_attr = f"sym_{method}" + else: + method_attr = method + + def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]): + if isinstance(x, (int, float, bool)): + return x + if isinstance(x, SymBool): + return x.node.guard_bool("", 0) + raise AssertionError("expect to be called with constant SymBools") + + def is_constant(x): + if isinstance(x, (int, float, bool)): + return True + if isinstance(x, (SymInt, SymFloat, SymBool)): + return x.node.is_constant() + return False + + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + + if method in bool_becomes_int_magic_methods: + + def promote(x): + """Implements True+True=2, which works in python but not sympy""" + if isinstance(x, SymBool): + return SymInt(x.node.wrap_int(int(x))) + return x + + else: + + def promote(x): + return x + + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + + # Before and after performing the operation, check if any operands are constant. + # If so, extract out the constant values first. If `self` itself is a + # constant, then "redispatch" by calling back into the operator. Sometimes + # this means that operations involving SymBool return plain bools. + # Alternatively, we could also rewrap into constant Symbool (i.e. by + # implementing wrap_bool in ConstantSymNodeImpl), but we're not doing that + # today for no particular reason. + def unary_magic_impl(self): + self = promote(self) + if is_constant(self): + return (method_to_operator(method))(get_constant(self)) + return wrap_node(getattr(self.node, method_attr)()) + + def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + sym_node_log.debug("MAGIC %s %s %s", method, self, other) + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(self.node, method_attr)(other_node)) + return get_constant(ret) if is_constant(ret) else ret + + def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented + self = promote(self) + other = promote(other) + self, other = promote2(self, other) + if is_constant(self): + return (method_to_operator(method))(get_constant(self), other) + if is_constant(other): + other = get_constant(other) + other_node = to_node(self.node, other) + if other_node is NotImplemented: + return NotImplemented + ret = wrap_node(getattr(other_node, method_attr)(self.node)) + return get_constant(ret) if is_constant(ret) else ret + + if method in unary_magic_methods: + setattr(user_type, f"__{method}__", unary_magic_impl) + elif method in unary_nonmagic_methods: + orig = getattr(user_type, method) + setattr(user_type, method, update_wrapper(unary_magic_impl, orig)) + elif method == "sym_ite": + + def sym_ite_magic_impl(pred, then_val, else_val): + pred_node = pred.node + then_node = to_node(pred_node, then_val) + else_node = to_node(pred_node, else_val) + if then_node is NotImplemented or else_node is NotImplemented: + return NotImplemented + assert ( + isinstance(then_node, SymNode) + and isinstance(else_node, SymNode) + and then_node.pytype == else_node.pytype + ) + ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node)) + return get_constant(ret) if ret.node.is_constant() else ret + + setattr(user_type, f"__{method}__", sym_ite_magic_impl) + elif method == "round": + + def round_magic_impl(self, ndigits=None): + if is_constant(self): + return builtins.round(get_constant(self), ndigits) + + return wrap_node(getattr(self.node, method)(ndigits)) + + setattr(user_type, f"__{method}__", round_magic_impl) + else: + method_name = method + if method in bitwise_ops: + method_name = bitwise_ops[method] + setattr(user_type, f"__{method_name}__", binary_magic_impl) + if method in reflectable_magic_methods: + setattr(user_type, f"__r{method_name}__", rbinary_magic_impl) + + +for method, func in magic_methods.items(): # type: ignore[assignment] + if method in only_bool_magic_methods: + _make_user_magic(method, SymBool) + continue + if method in only_float_magic_methods: + _make_user_magic(method, SymFloat) + continue + if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: + _make_user_magic(method, SymBool) + _make_user_magic(method, SymInt) + if method not in bitwise_ops: + _make_user_magic(method, SymFloat) + +del method +del func diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/symbolic_shapes.py b/phivenv/Lib/site-packages/torch/fx/experimental/symbolic_shapes.py new file mode 100644 index 0000000000000000000000000000000000000000..8027ba96131a234ef67389b49ab088f83085af60 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/symbolic_shapes.py @@ -0,0 +1,8055 @@ +from __future__ import annotations + +import sympy +from sympy import S + +from torch._prims_common import BoolLike, FloatLike, IntLike + + +""" +``torch.fx.experimental.symbolic_shapes`` provides interfaces for interacting with +our symbolic shapes reasoning system that is used heavily in torch.compile. Although +this is not generally considered public API, when writing framework code in PyTorch +as well as extensions to PyTorch (e.g., in custom operator implementations), you may +need to make use of these APIs to setup dynamic shapes support appropriately. +""" + +import abc +import atexit +import collections +import dis +import functools +import hashlib +import inspect +import itertools +import logging +import math +import operator +import os +import re +import sys +import threading +import traceback +from collections import Counter, defaultdict +from collections.abc import Generator, Iterator, Mapping, Sequence +from contextlib import _GeneratorContextManager, contextmanager +from dataclasses import asdict, dataclass, field +from enum import Enum +from typing import ( + Any, + Callable, + cast, + Generic, + NamedTuple, + NoReturn, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) +from typing_extensions import deprecated, ParamSpec, TypeAlias, TypeGuard + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +import torch.utils._pytree as pytree + +# NB: The sym_* functions are used via getattr() and must be imported here. +from torch import SymBool, SymFloat, SymInt +from torch._guards import ShapeGuard, SLoc, Source, TracingContext +from torch._logging import dtrace_structured, LazyString, structured, trace_structured +from torch._subclasses.meta_utils import is_sparse_any +from torch._utils_internal import signpost_event +from torch.fx.experimental import _config as config +from torch.fx.experimental.recording import ( + FakeTensorMeta, + record_shapeenv_event, + replay_shape_env_events, + shape_env_check_state_equal, + ShapeEnvEvent, +) +from torch.fx.experimental.sym_node import SymNode, SymTypes +from torch.types import py_sym_types +from torch.utils._ordered_set import OrderedSet +from torch.utils._python_dispatch import is_traceable_wrapper_subclass +from torch.utils._sympy.functions import ( + Application, + CeilToInt, + CleanDiv, + FloorDiv, + FloorToInt, + IntTrueDiv, + IsNonOverlappingAndDenseIndicator, + Max, + Min, + Mod, + PythonMod, + TruncToInt, +) +from torch.utils._sympy.numbers import int_oo +from torch.utils._sympy.printers import CppPrinter, PythonPrinter +from torch.utils._sympy.singleton_int import SingletonInt +from torch.utils._sympy.solve import try_solve +from torch.utils._sympy.symbol import make_symbol, symbol_is_type, SymT +from torch.utils._sympy.value_ranges import ( + bound_sympy, + SymPyValueRangeAnalysis, + ValueRangeError, + ValueRanges, +) +from torch.utils._traceback import CapturedTraceback, format_frame + + +if TYPE_CHECKING: + import types + + from torch import Tensor + from torch._dynamo.source import TensorPropertySource + from torch._subclasses.fake_tensor import FakeTensor + from torch.types import BoolLikeType, FloatLikeType, IntLikeType + + +InputList = list +DimList = list + +log = logging.getLogger(__name__) + + +class GuardOnDataDependentSymNode(RuntimeError): + cond: sympy.Basic + + def __init__(self, cond: sympy.Basic, *args: Any) -> None: + super().__init__(*args) + self.cond = cond + + +class PendingUnbackedSymbolNotFound(RuntimeError): + pass + + +aten = torch._ops.ops.aten # type: ignore[has-type] + +__all__ = [ + "guard_or_false", + "guard_or_true", + "has_symbolic_sizes_strides", + "create_contiguous", + "ShapeEnv", + "is_concrete_int", + "is_concrete_float", + "is_concrete_bool", + "has_static_value", + "guard_int", + "guard_float", + "guard_scalar", + "canonicalize_bool_expr", + "hint_int", + "SYMPY_INTERP", + "free_symbols", + "is_symbol_binding_fx_node", + "is_nested_int", + "SHAPEENV_EVENT_KEY", + "CURRENT_NODE_KEY", + "has_free_symbols", + "has_free_unbacked_symbols", + "sym_and", + "sym_eq", + "sym_or", + "SymbolicContext", + "StatelessSymbolicContext", + "StatefulSymbolicContext", + "SubclassSymbolicContext", + "SymIntSymbolicContext", + "TrackedFake", + "statically_known_true", + "statically_known_false", + "guard_size_oblivious", + "check_consistent", + "compute_unbacked_bindings", + "ConvertIntKey", + "rebind_unbacked", + "resolve_unbacked_bindings", + "is_accessor_node", + "ValueRangesSLoc", + "SymIntEqByExpr", + "Specialization", +] + +# FX node metadata keys for symbolic shape FX graph. +SHAPEENV_EVENT_KEY = "shapeenv_event" +CURRENT_NODE_KEY = "current_node" + + +def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None: + log.debug( + "lru_cache_stats %s: %s", + wrapped_f.__name__, # type: ignore[attr-defined] + wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined] + ) + + +# Note about Sympy Expr/SympyBoolean/Basic typing: the Sympy hierarchy is +# +# Basic +# Expr +# SympyBoolean +# Relational +# +# Notably, Expr and SympyBoolean are not related. So use Basic when the +# expression could denote int, float OR bool, and otherwise use the more +# specific Expr for int/float and SympyBoolean for bool. +# +# In obscure Meta only situations, sympy.logic.boolalg doesn't exist at runtime. +# So make sure only type checker evaluates this alias. +# Xref: https://www.internalfb.com/diff/D53324783 +SympyBoolean: TypeAlias = "sympy.logic.boolalg.Boolean" + + +_T = TypeVar("_T") +_SympyT = TypeVar("_SympyT", sympy.Expr, SympyBoolean, sympy.Basic) + + +class SymIntEqByExpr: + """ + This is a wrapper around SymInt which has alternative semantics for + equality. Specifically, instead of erroring or guarding, we + instead will hash/compare equality based on the underlying sympy + expression; e.g., s0 and s1 will always compare as False. + + NB: This does NOT do fancy analysis that maybe_evaluate_static does; + we can only reason through equalities that occur because to expressions + canonicalize to the same expression via regular simplification. + """ + + val: Union[torch.SymInt, int] + + def __init__(self, val: Union[torch.SymInt, int]) -> None: + self.val = val + + def __repr__(self) -> str: + return repr(self.val) + + def _extract(self) -> sympy.Expr: + if isinstance(self.val, torch.SymInt): + return self.val.node.expr + else: + return sympy.Integer(self.val) + + def __eq__(self, other: object) -> bool: + assert isinstance(other, SymIntEqByExpr) + + # int equality fastpath + if type(self.val) is int and type(other.val) is int: + return self.val == other.val + + return self._extract() == other._extract() + + def __hash__(self) -> int: + return hash(self._extract()) + + +def _nested_int_aware_sort( + tup: tuple[IntLikeType, int], +) -> tuple[int, IntLikeType, int]: + return ( + # Order nested ints by their coefficients. + # 1 here to order nested ints after non-nested-ints. + (1, tup[0].node.nested_int_coeff(), tup[1]) + if is_nested_int(tup[0]) + else (0, *tup) + ) + + +# Wrapper on lru_cache that reports statistics at process end +def lru_cache( + maxsize: Optional[int], +) -> Callable[[Callable[..., _T]], functools._lru_cache_wrapper[_T]]: + def inner(f: Callable[..., _T]) -> functools._lru_cache_wrapper[_T]: + wrapped_f = functools.lru_cache(maxsize)(f) + old_cache_clear = wrapped_f.cache_clear + prev_hits = 0 + prev_misses = 0 + + # TODO: There's a ref-cycle here (wrapped_f -> cumulative_cache_info + # -> wrapped_f) but cannot be solved with weakref as wrapped_f is not + # weakref'able on some versions of Python + + def cumulative_cache_info() -> functools._CacheInfo: + cur = wrapped_f.cache_info() + return functools._CacheInfo( + prev_hits + cur.hits, + prev_misses + cur.misses, + cur.maxsize, + cur.currsize, + ) + + def new_cache_clear() -> None: + nonlocal prev_hits, prev_misses + cur = wrapped_f.cache_info() + prev_hits += cur.hits + prev_misses += cur.misses + old_cache_clear() + + wrapped_f.cache_clear = new_cache_clear # type: ignore[attr-defined, method-assign] + wrapped_f.cumulative_cache_info = cumulative_cache_info # type: ignore[attr-defined, method-assign] + if log.isEnabledFor(logging.DEBUG): + atexit.register(log_lru_cache_stats, wrapped_f) # type: ignore[arg-type] + return wrapped_f + + return inner + + +# These are modules that contain generic code for interacting with ShapeEnv +# which are unlikely to identify a particular interesting guard statement +@lru_cache(None) +def uninteresting_files() -> set[str]: + import torch._compile + import torch._dynamo.eval_frame + import torch._inductor.sizevars + import torch._library.custom_ops + import torch._library.fake_impl + import torch._logging + import torch._subclasses.fake_tensor + import torch._subclasses.meta_utils + + mods = [ + sys.modules[__name__], + torch.fx.experimental.recording, + torch.fx.experimental.sym_node, + torch.fx.interpreter, + torch, + torch._compile, + torch._dynamo.eval_frame, + torch._inductor.sizevars, + torch._library.custom_ops, + torch._library.fake_impl, + torch._subclasses.meta_utils, + torch._subclasses.fake_tensor, + torch._logging._internal, + torch._logging.structured, + ] + import torch._dynamo.guards + + return ( + {inspect.getfile(m) for m in mods} + | torch._dynamo.guards.uninteresting_files() + | {""} + ) + + +class ConstraintViolationError(RuntimeError): + pass + + +def has_symbolic_sizes_strides(elem: torch.Tensor) -> bool: + return elem._has_symbolic_sizes_strides + + +Int: TypeAlias = Union[torch.SymInt, int] + + +def create_contiguous(shape: Sequence[Int]) -> list[Int]: + strides: list[Int] = [1] + for dim in reversed(shape[:-1]): + strides.append(dim * strides[-1]) # type: ignore[operator] + return list(reversed(strides)) + + +def hint_int(a: Union[torch.SymInt, int], fallback: Optional[int] = None) -> int: + """ + Retrieve the hint for an int (based on the underlying real values as observed + at runtime). If no hint is available (e.g., because data dependent shapes), + if fallback is not None, use that instead (otherwise raise an error). + """ + if isinstance(a, torch.SymInt): + return a.node.require_hint(fallback) + assert type(a) is int, a + return a + + +Scalar: TypeAlias = Union[torch.SymInt, torch.SymFloat, torch.SymBool, int, float, bool] + + +def has_hint(a: Scalar) -> bool: + if isinstance(a, SymTypes): + return a.node.has_hint() + return True + + +def is_concrete_int(a: IntLikeType) -> bool: + """ + Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or int): Object to test if it int + """ + assert isinstance(a, (SymInt, int)) + + if isinstance(a, int): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Integer): + return True + + return False + + +def is_concrete_float(a: FloatLikeType) -> bool: + r"""Utility to check if underlying object + in SymInt is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymInt or float): Object to test if it float + """ + assert isinstance(a, (SymFloat, float)) + + if isinstance(a, float): + return True + + if isinstance(a.node.expr, sympy.core.numbers.Float): + return True + + return False + + +def is_concrete_bool(a: BoolLikeType) -> bool: + """ + Utility to check if underlying object + in SymBool is concrete value. Also returns + true if integer is passed in. + + Args: + a (SymBool or bool): Object to test if it bool + """ + assert isinstance(a, (SymBool, bool)) + + if isinstance(a, bool): + return True + + if isinstance( + a.node.expr, (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse) + ): + return True + + return False + + +def has_static_value(a: Union[SymBool, SymFloat, SymInt, bool, float, int]) -> bool: + """ + User-code friendly utility to check if a value is static or dynamic. + Returns true if given a constant, or a symbolic expression with a fixed value. + + Args: + a (Union[SymBool, SymFloat, SymInt, bool, float, int]): Object to test + """ + assert isinstance(a, BoolLike + FloatLike + IntLike) + if ( + isinstance(a, BoolLike) + and is_concrete_bool(a) # type: ignore[arg-type] + or isinstance(a, FloatLike) + and is_concrete_float(a) # type: ignore[arg-type] + or isinstance(a, IntLike) + and is_concrete_int(a) # type: ignore[arg-type] + ): + return True + + assert isinstance(a, py_sym_types) + return a.node.shape_env.bound_sympy(a.node.expr).is_singleton() # type: ignore[union-attr] + + +def guard_size_oblivious(expr: Union[torch.SymBool, bool]) -> bool: + """ + Perform a guard on a symbolic boolean expression in a size oblivious way. + This is typically used when a non-oblivious test would result in a guard + on a data dependent value of which we don't know the value of at compile time. + When a guard is tested this way, we may diverge in behavior from how regular + PyTorch semantics would treat it. For more information, see + https://github.com/pytorch/pytorch/pull/118579 + """ + if isinstance(expr, torch.SymBool): + return expr.node.guard_size_oblivious("", 0) + else: + assert isinstance(expr, bool), expr + return expr + + +def check_consistent(new: _T, old: _T) -> None: + """ + Test that two "meta" values (typically either Tensor or SymInt) have + the same values, e.g., after retracing. If we don't understand the + quantities in question, we'll just skip the consistency check. + """ + # TODO: do boolean equality test too, see + # https://github.com/pytorch/pytorch/issues/124110 + scalar_types = (torch.SymInt, torch.SymFloat, int, float) + + if isinstance(new, torch.Tensor): + assert isinstance(old, torch.Tensor) + torch._check( + old.dim() == new.dim(), lambda: f"{old.shape} != {new.shape} (old != new)" + ) + # Do this manually so that each individual test is irrefutable + # (TODO: should be a helper for this, maybe sym_eq? That + # gives us a compound expression and I'm not sure it + # simplifies right now) + for i, j in zip(old.shape, new.shape): + torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") + # NB: bool is subclass of int + elif isinstance(new, scalar_types) and not isinstance(new, bool): + assert isinstance(old, scalar_types) and not isinstance(old, bool), ( + f"{old} != {new}" + ) + torch._check(old == new, lambda: f"{old} != {new} (old != new)") + + +def resolve_unbacked_bindings( + shape_env: Optional[ShapeEnv], + bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]], +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: + """ + When we do fake tensor prop, we oftentimes will allocate new unbacked symints. + We then run proxy tensor mode, which populates node.meta["unbacked_bindings"] + with these new symints. To ensure consistency we use PropagateUnbackedSymInts + to rename unbacked bindings to their old ones. But all of the node metas are + still using the old bindings from before the renaming. This function helps to + post facto apply any renamings discovered in the PropogateUnbackedSymInts pass. + """ + if bindings is None: + return None + assert shape_env is not None + return {shape_env.unbacked_renamings.get(k, k): v for k, v in bindings.items()} + + +Result: TypeAlias = Union[torch.Tensor, tuple[torch.Tensor, ...]] + + +def rebind_unbacked( + shape_env: Optional[ShapeEnv], n: torch.fx.Node, result: Result +) -> None: + """ + Suppose we are retracing a pre-existing FX graph that previously had + fake tensor propagation (and therefore unbacked SymInts). When we retrace, + we re-propagate fake tensors, which results in new unbacked SymInts. + When this happens, we need to tell the shape environment about the equivalence + of the old and new unbacked SymInts. Pass us the old torch.fx.Node (which + has the old binding information) and the new result (which we can extract the + new unbacked SymInts out from). + """ + + # Inputs never need rebinding + if n.op == "placeholder": + return + + if bindings := resolve_unbacked_bindings( + shape_env, n.meta.get("unbacked_bindings") + ): + assert shape_env is not None + for raw_u0, path in bindings.items(): + u1 = pytree.key_get(result, path) + # Sometimes, things were previously unbacked bindings become constants. + # There are two situations this can happen. + # + # First, you might have a runtime assert that causes the + # constant-ification. In this case, the /binding/ itself will + # still be an unbacked symbol (because we will only force it + # to be a constant later in fake tensor propagation). In this + # case, u1 is a SymInt and we still do all our work as normal. + # + # But second, it might be that fake tensor propagation DIRECTLY + # converted the unbacked SymInt into a constant. This happens + # more rarely, but we have identified two situations it can + # validly occur: + # + # - If you have a tensor_version operator, these are initially + # allocated as unbacked SymInts, but after AOTAutograd they + # get forced specialized to specific values. In this case, + # there is no reason to do runtime asserts on them, this is + # just a hack to properly keep track of them to start. + # + # - If you have an item() call on a constant tensor, the result + # of the item() call is constant and we do not need runtime + # asserts on this symbol. In + # https://github.com/pytorch/pytorch/issues/140625 we have a + # case where in the initial trace of the program we are unable + # to determine that torch.tensor is constant, but then + # subsequent passes cause torch.tensor to become a constant and + # then the unbacked symbol goes poof. + # + # In all of these cases, it is no longer necessary to generate + # deferred runtime asserts, since other subsystems (e.g., the + # constant-ification pass) ensure that the quantity is now truly + # static and cannot change at runtime. So it's OK to discard + # in these situations. + # + # There is one more hazard (re + # https://github.com/pytorch/pytorch/issues/141248), the problem + # is that you can end up with "dangling" unbacked symbols that + # exist in the ShapeEnv but are never bound anywhere. You might + # like an invariant that unbacked symbols never get lost. But + # we do not have this invariant, so do not try to enforce it. + if isinstance(u1, int): + log.info( + "rebind_unbacked: discard %s %s %s -> %s", + n.target, + raw_u0, + path, + u1, + ) + continue + + # We only care about rebinding unbacked things + if u1.node.hint is not None: + continue + + raw_u1 = u1.node.expr + # Simplify SymBool binding + if ( + isinstance(raw_u1, sympy.Piecewise) + and len(raw_u1.args) == 2 + and ( + raw_u1_args0 := cast( + tuple[sympy.Basic, sympy.Basic], raw_u1.args[0] + ) + ) + and raw_u1_args0[0] == 1 + and isinstance(eq := raw_u1_args0[1], sympy.Eq) + and isinstance(new_raw_u1 := eq.lhs, sympy.Symbol) + and shape_env.var_to_range[new_raw_u1].issubset(ValueRanges(0, 1)) + and eq.rhs == 1 + and cast(tuple[sympy.Basic, sympy.Basic], raw_u1.args[1]) == (0, True) + ): + # This is what the pattern match above is testing + repacked = _sympy_cast_symbool_to_symint_guardless( + sympy.Eq(new_raw_u1, 1) + ) + assert repacked == raw_u1, f"{repacked} != {raw_u1}" + # Cancel the to_int(to_bool(x)). This is sound because x in + # [0, 1] + raw_u1 = new_raw_u1 + + if not isinstance(raw_u1, sympy.Symbol): + assert not raw_u1.free_symbols, ( + f"should have been constant, but got {raw_u1}" + ) + continue + + # The old and new could be the same if you improperly hit the memo + # while retracing. Make sure you updated FakeTensorMode.epoch + assert raw_u0 != raw_u1, f"{raw_u0} possible memo disaster" + # Reuse the OLD symbol name + shape_env._rename_unbacked_to(raw_u1, raw_u0) + + +# NB: You could try to expand this to cover more cases by simply +# detecting whenever you have an int output, but this is a bit +# dangerous in case someone adds a function that returns an int but is +# mutating. So manually whitelist for now. +def is_accessor_node(node: torch.fx.Node) -> bool: + """ + Helper function to determine if a node is trying to access + a symbolic integer such as size, stride, offset or item. Currently + primarily only used in a DCE pass to figure out purity. + """ + + # Dynamo only exercised condition + if ( + node.op == "call_method" + and isinstance(node.args[0], torch.fx.Node) + and isinstance(node.args[0].meta.get("example_value"), torch.Tensor) + and node.target in ["size", "stride", "storage_offset", "item"] + ): + return True + + if node.op == "call_function" and node.target in [ + torch.ops.aten.sym_size, + torch.ops.aten.sym_size.default, + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_stride.default, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_storage_offset, + torch.ops.aten.sym_storage_offset.default, + torch.ops.aten.sym_numel.default, + ]: + return True + + return False + + +def canonicalize_bool_expr(expr: _T) -> _T: + """ + Canonicalize a boolean expression by transforming it into a lt / le + inequality and moving all the non-constant terms to the rhs. + We canonicalize And / Ors / Not via cnf and then canonicalize their subexpr + recursively + nb. sympy.Rel.canonical is not good enough https://github.com/sympy/sympy/issues/25924 + + Args: + expr (sympy.Expr): Expression to canonicalize + """ + # Canonicalise an inequality by transforming it into a lt / le + # inequality and moving all the non-constant terms to the rhs + # We canonicalise And / Ors / Not via cnf + # nb. Relational.canonical in sympy is broken + # https://github.com/sympy/sympy/issues/25924 + + if not isinstance( + expr, (sympy.Rel, sympy.And, sympy.Or, sympy.Not, sympy.Eq, sympy.Ne) + ): + return expr + + if isinstance(expr, (sympy.And, sympy.Or, sympy.Not)): + expr = sympy.logic.boolalg.to_cnf(expr) + return _canonicalize_bool_expr_impl(expr) # type: ignore[arg-type, return-value] + + +def _sympy_from_args( + cls: type[Union[sympy.Add, sympy.Mul]], + args: list[sympy.Expr], + sort: bool = True, + is_commutative: Optional[bool] = None, +) -> sympy.Expr: + """ + Create a sympy expression from a list of arguments, optimizing for performance. + + This function creates a sympy Add or Mul expression from a list of arguments + while avoiding expensive operations like flattening. It handles sorting the + arguments appropriately based on the expression type. + + Args: + cls: The sympy class to create (Add or Mul) + args: List of sympy expressions to combine + sort: Whether to sort the arguments (default: True) + is_commutative: Whether the operation is commutative (default: None) + + Returns: + A sympy expression of type cls combining all arguments + + Raises: + ValueError: If cls is not sympy.Add or sympy.Mul + """ + + if not args: + return cls.identity # type: ignore[union-attr] + + # These args are already in canonical form, so we avoid calling + # Add(*args) to avoid expensive Add.flatten operation + if sort: + if cls is sympy.Add: + sort_fn = sympy.core.add._addsort + elif cls is sympy.Mul: + sort_fn = sympy.core.mul._mulsort + else: + raise ValueError(f"Unknown cls: {cls}") + + # we don't support non commutative with sort + assert is_commutative is True + if args[0].is_Number: + rest = args[1:] + sort_fn(rest) + return cls._from_args([args[0]] + rest, is_commutative=is_commutative) # type: ignore[attr-defined] + else: + args = args.copy() + sort_fn(args) + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] + else: + # if the args are already sorted, we create directly + return cls._from_args(args, is_commutative=is_commutative) # type: ignore[attr-defined] + + +def _canonicalize_bool_expr_impl(expr: SympyBoolean) -> SympyBoolean: + """ + After canonicalization, we are guaranteed to have eliminated Ge/Gt relations + (rewriting them to Le/Lt, respectively). + """ + if isinstance(expr, (sympy.And, sympy.Or)): + return type(expr)(*map(canonicalize_bool_expr, expr.args)) + + opposite = {sympy.Gt: sympy.Lt, sympy.Ge: sympy.Le} + t: Union[type[Any]] + if isinstance(expr, tuple(opposite.keys())): + rhs = expr.lhs - expr.rhs # type: ignore[attr-defined] + t = opposite[type(expr)] # type: ignore[index] + else: + assert isinstance(expr, (sympy.Lt, sympy.Le, sympy.Eq, sympy.Ne)) + rhs = expr.rhs - expr.lhs + t = type(expr) + + def is_neg(t: sympy.Expr) -> bool: + return (t.is_Number and t.is_negative) or ( + isinstance(t, sympy.Mul) and t.args[0].is_Number and t.args[0].is_negative + ) + + lhs = S.Zero + rhs = _reduce_to_lowest_terms(rhs) + if isinstance(rhs, sympy.Add): + pos = [] + neg = [] + for term in rhs.args: + if is_neg(term): + neg.append(-term) + else: + pos.append(term) + # these are already sorted + rhs = _sympy_from_args(sympy.Add, pos, sort=False, is_commutative=True) + # the terms were changed, so needs a sorting + lhs = _sympy_from_args(sympy.Add, neg, sort=True, is_commutative=True) + elif is_neg(rhs): + # lhs == 0 + lhs, rhs = -rhs, S.Zero + # We don't have to evaluate here because lhs, rhs came from a Boolean + # and it was already simplified + return t(lhs, rhs, evaluate=False) + + +def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr: + """ + Eliminates any integer factor from a given expression. + E.g., 6x + 4y reduces to 3x + 2y. + + Useful when an expression is == or != to 0. + """ + + def integer_coefficient(x: sympy.Expr) -> int: + if x.is_Integer: + return abs(int(x)) + elif x.is_Mul: + # If one of the args of a Mul is an Integer, it is the + # first arg. eg: args(2*x*3*y) == (6, x, y) + return abs(int(x.args[0])) if x.args[0].is_Integer else 1 # type: ignore[call-overload] + else: + return 1 + + def div_by_factor(x: sympy.Expr, factor: int) -> sympy.Expr: + if x.is_Integer: + return x / factor + elif x.is_Mul: + if x.args[0] != factor: + args = [x.args[0] / sympy.Integer(factor), *x.args[1:]] + else: + # Mul._from_args require a canonical list of args + # so we remove the first arg (x.args[0] / factor) if it was 1 + args = list(x.args[1:]) + return _sympy_from_args(sympy.Mul, args, is_commutative=x.is_commutative) + else: + raise AssertionError(f"illegal arg to div_by_factor: {x}") + + if expr.is_Add: + atoms = cast(Sequence[sympy.Expr], expr.args) + factor = functools.reduce(math.gcd, map(integer_coefficient, atoms)) + if factor == 1: + return expr + atoms = [div_by_factor(x, factor) for x in atoms] + return _sympy_from_args( + sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative + ) + elif expr.is_Integer: + return S.One + elif expr.is_Mul: + return div_by_factor(expr, integer_coefficient(expr)) + return expr + + +def is_nested_int(s: IntLikeType) -> TypeGuard[SymInt]: + return isinstance(s, torch.SymInt) and s.node.is_nested_int() + + +IterateExprsAtom: TypeAlias = Union[ + SymInt, SymFloat, SymBool, int, float, bool, sympy.Basic, torch.Tensor +] +IterateExprs: TypeAlias = Union[IterateExprsAtom, Sequence[IterateExprsAtom]] + + +def _iterate_exprs(val: IterateExprs) -> Iterator[sympy.Basic]: + """ + Recursively iterate through a value and yield all sympy expressions contained within it. + + This function traverses various data structures (tensors, lists, tuples, etc.) and extracts + any symbolic expressions they contain. It's used for operations like finding free symbols + in complex nested structures. + + Args: + val: The value to extract sympy expressions from. Can be a symbolic type (SymInt, SymFloat, SymBool), + a sympy expression, a primitive type (int, float, bool), a container (tuple, list), + a sparse tensor, a regular tensor, None, or a torch.Generator. + + Yields: + sympy.Basic: Each sympy expression found in the value. + + Raises: + AssertionError: If the value is of an unsupported type. + """ + if isinstance(val, SymTypes): + # This allow applies to the jagged layout NestedTensor case as + # nested ints are not symbolic + if is_symbolic(val): + yield val.node.expr + elif isinstance(val, sympy.Basic): + yield val + elif isinstance(val, (int, float, bool)): + pass + elif isinstance(val, (tuple, list)): + for s in val: + yield from _iterate_exprs(s) + elif is_sparse_any(val): + yield from _iterate_exprs(val.size()) + elif isinstance(val, torch.Tensor): + yield from _iterate_exprs(val.size()) + yield from _iterate_exprs(val.stride()) + yield from _iterate_exprs(val.storage_offset()) + elif val is None: + pass + # see Note: [Generator arguments in AOTDispatcher] + elif isinstance(val, torch.Generator): + pass + else: + raise AssertionError(f"cannot extract sympy expressions from {val} {type(val)}") + + +def free_symbols(val: IterateExprs) -> OrderedSet[sympy.Symbol]: + """ + Recursively collect all free symbols from a value. + + This function traverses various data structures (tensors, lists, tuples, etc.) and extracts + all sympy symbols contained within them. It's useful for finding all symbolic variables + that a complex nested structure depends on. + + Args: + val: The value to extract symbols from. Can be a symbolic type (SymInt, SymFloat, SymBool), + a container (tuple, list), a tensor, or None. + + Returns: + OrderedSet[sympy.Symbol]: An ordered set of all free symbols found in the value. + """ + if val is None: + return OrderedSet() + + itr = _iterate_exprs(val) + + # we need at least 1 to call union, so we hand code the identity + try: + first_expr = next(itr) + except StopIteration: + return OrderedSet() + + # TODO: Apparently, returning an OrderedSet here breaks + # python test/distributed/tensor/test_dtensor_compile.py TestDTensorCompile.test_dtensor_dynamic + return first_expr.free_symbols.union(*(e.free_symbols for e in itr)) # type: ignore[return-value] + + +def has_free_symbols(val: IterateExprs) -> bool: + """Faster version of bool(free_symbols(val))""" + return not all((e.is_number or e.is_Boolean) for e in _iterate_exprs(val)) + + +def has_free_unbacked_symbols(x: IterateExprs) -> bool: + """Faster version of bool(free_unbacked_symbols(val))""" + from sympy.core.traversal import iterargs + + for s in _iterate_exprs(x): + for arg in iterargs(s): + if arg.is_Symbol and symbol_is_type( + arg, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT) + ): + return True + return False + + +def free_unbacked_symbols(x: IterateExprs) -> OrderedSet[sympy.Symbol]: + """Like free_symbols, but filtered to only report unbacked symbols""" + + # NB: keep synced with is_unbacked_symint + return OrderedSet( + s + for s in free_symbols(x) + if symbol_is_type(s, (SymT.UNBACKED_INT, SymT.UNBACKED_FLOAT)) + ) + + +# WARNING: Don't use this on Dynamo produced graphs, they don't have meta +# setup! +def is_symbol_binding_fx_node(node: torch.fx.Node) -> Optional[sympy.Symbol]: + """ + Check if a given FX node is a symbol binding node. + + A symbol binding node is one that has a SymInt value in its meta that contains + a sympy Symbol expression, and is either a placeholder node or contains unbacked symbols. + + Args: + node (torch.fx.Node): The FX node to check + + Returns: + Optional[sympy.Symbol]: The sympy Symbol if the node is a symbol binding node, None otherwise + """ + if ( + "val" in node.meta + and isinstance(node.meta["val"], torch.SymInt) + and isinstance(node.meta["val"].node.expr, sympy.Symbol) + and ( + node.op == "placeholder" + or free_unbacked_symbols(node.meta["val"].node.expr) + ) + ): + return node.meta["val"].node.expr + return None + + +def find_symbol_binding_fx_nodes( + graph: torch.fx.Graph, +) -> dict[sympy.Symbol, torch.fx.Node]: + """ + Find all nodes in an FX graph that bind sympy Symbols. + + This function scans through all nodes in the given FX graph and identifies + nodes that bind sympy Symbols (typically placeholder nodes with SymInt values). + When multiple nodes bind the same symbol, only the first occurrence is kept. + + Args: + graph: The FX graph to search for symbol binding nodes + + Returns: + A dictionary mapping from sympy Symbols to their binding FX nodes + """ + r = {} + # NB: Prefer first occurrence of symbol + for node in graph.nodes: + if (s := is_symbol_binding_fx_node(node)) is not None and s not in r: + r[s] = node + return r + + +@dataclass(frozen=True) +class Specialization: + """ + This class is used in multi-graph compilation contexts where we generate + multiple specialized graphs and dispatch to the appropriate one at runtime. + This allows us to optimize the trade-off between performance and generality + by creating specialized versions for common patterns (e.g., x.shape[0] % 16 == 0) + while maintaining a general fallback. + """ + + source: TensorPropertySource + check_fn: Callable + + +# Analogous to ConvertIntSource +@dataclass(frozen=True) +class ConvertIntKey: + def __str__(self) -> str: + return ".cast_symbool_to_symint_guardless()" + + def get(self, b: bool) -> IntLikeType: + """Get the int value from bool""" + return cast_symbool_to_symint_guardless(b) + + +@dataclass(frozen=True) +class CallMethodKey: + name: str + + def __str__(self) -> str: + return f".{self.name}()" + + def get(self, o: Any) -> Any: + """Call the method on object""" + return getattr(o, self.name)() + + +@dataclass(frozen=True) +class InnerTensorKey: + inner_name: str + + def __str__(self) -> str: + return f".{self.inner_name}" + + def get(self, o: Any) -> Any: + """Get the inner tensor attribute""" + return getattr(o, self.inner_name) + + +@dataclass(frozen=True) +class DivideByKey: + divisor: IntLikeType + + def __str__(self) -> str: + return f".__floordiv__({self.divisor})" + + def get(self, o: int) -> int: + """Divide object by divisor""" + return o // self.divisor + + +def _free_unbacked_symbols_with_path( + a: object, + path: pytree.KeyPath, + real: Optional[object] = None, + shape_env: Optional[ShapeEnv] = None, + pending: Optional[set[sympy.Symbol]] = None, + simplify: bool = False, +) -> dict[sympy.Symbol, pytree.KeyPath]: + """ + Recursively traverses a structure to find unbacked symbols and their access paths. + + This function walks through tensors, lists, tuples, and symbolic values to locate + unbacked symbols that are in the pending set, and returns a mapping from those + symbols to their access paths in the structure. + + Args: + a: The object to traverse (tensor, list, tuple, SymInt, etc.) + path: The current path in the object tree + real: Optional real tensor corresponding to the fake tensor being traversed + shape_env: Optional ShapeEnv to register unbacked values with + pending: Set of unbacked symbols to look for (will be modified in-place) + simplify: Whether to use simplified expressions + + Returns: + A dictionary mapping unbacked symbols to their access paths + """ + go = functools.partial( + _free_unbacked_symbols_with_path, + shape_env=shape_env, + pending=pending, + simplify=simplify, + ) + + def expr(s: Union[SymInt, SymFloat, SymBool]) -> sympy.Expr: + if simplify: + return s.node.expr + # (When called from compute_unbacked_bindings) + # NB: Intentionally access _expr, not expr, do not want + # simplification! + return s.node._expr + + if pending is None: + pending = set() + r = {} + if isinstance(a, (tuple, list)): + # NB: real is apparently not always a tuple/list here + # python test/inductor/test_torchinductor.py CpuTests.test_index_propagation_nested_indirect_indexing_cpu + for i in range(len(a)): + r.update( + go( + a[i], + path + (pytree.SequenceKey(i),), + real=real[i] if real is not None else None, # type: ignore[index] + ) + ) + elif is_traceable_wrapper_subclass(a): + # TODO: Determine if this is correct + attrs, _ = a.__tensor_flatten__() + for attr in attrs: + sub = getattr(a, attr) + r.update(go(sub, path + (InnerTensorKey(attr),))) + elif isinstance(a, torch.Tensor): + from torch._subclasses.fake_tensor import FakeTensor + + assert isinstance(a, FakeTensor) + r.update( + go( + a.size(), + path + (CallMethodKey("size"),), + real=a.real_tensor.size() if a.real_tensor is not None else None, + ) + ) + if a.layout not in [ + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc, + ]: + r.update( + go( + a.stride(), + path + (CallMethodKey("stride"),), + real=a.real_tensor.stride() if a.real_tensor is not None else None, + ) + ) + r.update( + go( + a.storage_offset(), + path + (CallMethodKey("storage_offset"),), + real=( + a.real_tensor.storage_offset() + if a.real_tensor is not None + else None + ), + ) + ) + + elif ( + isinstance(a, (torch.SymInt, torch.SymFloat)) + and isinstance(s := expr(a), sympy.Symbol) + and s in pending + ): + r[s] = path + if shape_env and real is not None: + assert isinstance(real, (int, float)) + shape_env.set_unbacked_var_to_val(s, real) + pending.remove(s) + # When an unbacked SymInt is perfectly divisible by an integer + # constant, we replace it with the integer constant to improve + # reasoning capabilities. However, in synthetic examples, it is + # then possible that the factor never is explicitly allocated. + # Fortunately, we can compute it by division. + elif ( + isinstance(a, torch.SymInt) + and isinstance(s := expr(a), sympy.Mul) + and len(s.args) == 2 + and isinstance(lhs := s.args[0], (sympy.Integer, sympy.Symbol)) + and isinstance(rhs := s.args[1], sympy.Symbol) + # support exactly one unbacked for now + and ((rhs in pending) ^ (lhs in pending)) + # support constant coefficient or backed symbolic coefficient + and ( + isinstance(coeff := lhs if lhs not in pending else rhs, sympy.Integer) + or shape_env + and coeff in shape_env.var_to_val + ) + ): + + def _symint_wrap(s: sympy.Symbol) -> SymInt: + return shape_env.create_symintnode( # type: ignore[union-attr] + s, + hint=int(shape_env.var_to_val[s]), # type: ignore[union-attr] + source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr] + ) + + unbacked = lhs if lhs in pending else rhs + divisor: IntLikeType = ( + int(coeff) + if shape_env and isinstance(coeff, sympy.Integer) + else _symint_wrap(coeff) + ) + # TODO: DivideByKey needs to test divisibility at runtime! + r[unbacked] = path + (DivideByKey(divisor),) + if real is not None: + assert isinstance(real, int) + val = ( + real // int(coeff) + if isinstance(coeff, sympy.Integer) + else CleanDiv(real, coeff) + ) + if shape_env: + shape_env.set_unbacked_var_to_val(unbacked, val) + pending.remove(unbacked) + # The annoyance here arises from the fact that SymBool is + # allocated by allocating a SymInt and then testing if it's equal + # to one. So you have a complicated binding site logic for this. + elif ( + isinstance(a, torch.SymBool) + and isinstance(s := expr(a), sympy.Eq) + # This must match create_unbacked_symbool EXACTLY + and isinstance(s.lhs, sympy.Symbol) + and s.rhs == 1 + and s.lhs in pending + ): + r[s.lhs] = path + (ConvertIntKey(),) + if real is not None: + assert type(real) is bool + if shape_env: + shape_env.set_unbacked_var_to_val(s, int(real)) + pending.remove(s.lhs) + + return r + + +def compute_unbacked_bindings( + shape_env: Optional[ShapeEnv], + example_value: object, + old_example_value: Optional[object] = None, + peek: bool = False, +) -> Optional[dict[sympy.Symbol, pytree.KeyPath]]: + """ + After having run fake tensor propagation and producing example_value + result, traverse example_value looking for freshly bound unbacked + symbols and record their paths for later. It is an error if + we have allocated an unbacked SymInt but it cannot be found in + example_value. (NB: this means if you have a multi-output + function, you must call this on the tuple of tensor output, you + cannot wait!) + + The peek parameter lets you check out what the bindings are without + changing the affected list. This is primarily useful for ensuring + unbacked_var_to_val is promptly populated when propagate_real_tensors is on. + """ + if shape_env is None: + return None + + fs = shape_env.pending_fresh_unbacked_symbols + pending = set(fs) + if not pending: + return None + + if not peek: + log.info("compute_unbacked_bindings %s", fs) + fs.clear() + + symbol_to_path = _free_unbacked_symbols_with_path( + example_value, (), shape_env=shape_env, pending=pending, simplify=False + ) + if not peek and pending: + extra = ( + repr((example_value.stride(), example_value.storage_offset())) + if isinstance(example_value, torch.Tensor) + else "" + ) + raise PendingUnbackedSymbolNotFound( + f"Pending unbacked symbols {pending} not in returned outputs {example_value} {extra}.\n" + "Did you accidentally call new_dynamic_size() or item() more times " + "than you needed to in your fake implementation?\n" + "For more help, see https://docs.google.com/document/d/1RWrH-3wLEpzR9kCS6gGBNen_-Fs-8PVbWWFE5AcgeWE/edit" + ) + + # Why do we have to do some rebinding here? If the original FX node + # wasn't a binding site because you had a memo hit, but post + # translation you aren't a memo hit anymore, there's now a new binding + # site... but we know (because it's the same FX node) that the value + # is actually the same, they're just not obviously equal anymore. + # + # The logic here is written carefully, because unlike the + # bind_unbacked case, we are not guaranteed to have a symbol for + # old_sym. If we have a symbol, do regular rename unbacked to; but if + # we don't, we need to specially eliminate the fresh unbacked symbol + # (NB: we are /trusting/ that the memoization is correct, and that we + # don't need to generate a new runtime assert. This is load bearing, + # as repropagation can happen after we've frozen runtime asserts.) + if old_example_value is not None: + for keypath in symbol_to_path.values(): + old_sym = pytree.key_get(old_example_value, keypath) + new_sym = pytree.key_get(example_value, keypath) + if isinstance(new_sym, SymTypes) and isinstance( + new_s := new_sym.node.expr, sympy.Symbol + ): + if ( + isinstance(old_sym, SymTypes) + and (old_s := old_sym.node.expr) != new_s + ): + if isinstance(old_s, sympy.Symbol): + shape_env._rename_unbacked_to(new_s, old_s) + else: + shape_env._eliminate_unbacked(new_s, old_s) + elif not isinstance(old_sym, SymTypes): + shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym)) + + return symbol_to_path + + +# Note [guard_or_] +# The following two functions are common utilities used while defining unbacked semantics +# of various framework code. Those would be used in situations you prefer to guard and know +# the result of the expression over not guarding, but in case you hit a data dependent error +# you are ok with just returning true or false. +# +# When to use this? +# (1) If you can use a higher level combinator prefer using those instead, they are definitely safe (modulo short-circuiting). +# +# (2) It can be used if the program would behave equivalently if _guard_or returned true or false. +# Many inductor optimizations fall in this bracket for example. +# +# (3) Finally, it's even be OK if the program wouldn't behave equivalently, so long as the +# change is semantics preserving. It can be semantics preserving if the program errors in more +# cases than it did previously (but otherwise behaves identically), or if it changes some quantity +# in a way that doesn't matter (e.g., strides often fall in this bucket.) +# +# (4) Specialize for the general case and add a runtime assertion that would fail during +# runtime if the conditions for the general case are not satisfied. Examples for this are; +# assuming expand/reshape inputs are not -1. or assuming the non-broadcasting path. +# +def _guard_or(a: BoolLikeType, default: bool) -> bool: + """ + Try to guard a, if data dependent error encountered just return default. + """ + if not isinstance(a, SymBool): + assert isinstance(a, bool) + return a + + # if backed_size_oblivious is True we treat backed as unbacked here. + if torch.fx.experimental._config.backed_size_oblivious: + result = _static_eval_sym_bool(a) + return result if result is not None else default + + shape_env = getattr(a.node, "shape_env", None) + + # xla symnode path. + if shape_env is None: + return guard_bool(a) + + sym_node = a.node + r = sym_node.shape_env.evaluate_sym_node( + sym_node, size_oblivious=False, fallback_value=default + ) + return bool(r) + + +def guard_or_false(a: BoolLikeType) -> bool: + """ + Try to guard a, if data dependent error encountered just return false. + """ + return _guard_or(a, False) + + +def guard_or_true(a: BoolLikeType) -> bool: + """ + Try to guard a, if data dependent error encountered just return true. + """ + return _guard_or(a, True) + + +def _static_eval_sym_bool(x: SymBool) -> Optional[bool]: + assert isinstance(x, SymBool) + expr = x.node.expr + + try: + # Shape env access is inside the try on purpose. xla symnode does not + # have it on its attributes. + shape_env = x.node.shape_env + simplified = shape_env._maybe_evaluate_static(expr) + if simplified is not None: + return bool(simplified) + else: + return None + except Exception: + log.debug("Could not simplify %s", expr) + return None + + +def statically_known_false(x: BoolLikeType) -> bool: + """ + Returns True if x can be simplified to a constant and is False. + If x cannot be evaluated from static, we return False + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to False at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + """ + if not isinstance(x, SymBool): + assert isinstance(x, bool) + return not x + + result = _static_eval_sym_bool(x) + if result is None: + return False + + return not result + + +def statically_known_true(x: BoolLikeType) -> bool: + """ + Returns True if x can be simplified to a constant and is true. + + .. note:: + This function doesn't introduce new guards, so the expression may end + up evaluating to true at runtime even if this function returns False. + + Args: + x (bool, SymBool): The expression to try statically evaluating + """ + if not isinstance(x, SymBool): + assert isinstance(x, bool) + return x + + result = _static_eval_sym_bool(x) + if result is None: + return False + + return result + + +def sym_and(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: + """ + and, but for symbolic expressions, without bool casting. + """ + if len(others) == 0: + return x + for y in others: + x = operator.and_(x, y) + return x + + +def sym_eq(x: _T, y: _T) -> BoolLikeType: + """ + Like ==, but when run on list/tuple, it will recursively test equality + and use sym_and to join the results together, without guarding. + """ + if isinstance(x, (tuple, list)) and isinstance(y, (list, tuple)): + if len(x) != len(y): + return False + return functools.reduce(operator.and_, map(sym_eq, x, y), True) + elif isinstance(x, (int, torch.SymInt)) and isinstance(y, (int, torch.SymInt)): + return x == y + else: + raise AssertionError(f"unexpected sym_eq between {type(x)} {type(y)}") + + +def sym_or(x: BoolLikeType, *others: BoolLikeType) -> BoolLikeType: + """ + or, but for symbolic expressions, without bool casting. + """ + if len(others) == 0: + return x + for y in others: + x = operator.or_(x, y) + return x + + +def guard_scalar( + a: Union[SymBool, SymInt, SymFloat, int, bool, float], +) -> Union[bool, int, float]: + """ + Guard a scalar value, which can be a symbolic or concrete boolean, integer, or float. + + This function dispatches to the appropriate guard function based on the type of the input. + + Args: + a: A symbolic or concrete scalar value (bool, int, or float) + + Returns: + The concrete value after guarding + + Raises: + AssertionError: If the input is not a recognized scalar type + """ + if isinstance(a, (SymBool, bool)): + return guard_bool(a) + elif isinstance(a, (SymInt, int)): + return guard_int(a) + elif isinstance(a, (SymFloat, float)): + return guard_float(a) + else: + raise AssertionError(f"unrecognized scalar {a}") + + +def _advise_is_size(a: SymInt) -> None: + """ + Don't use this directly; use torch._check_is_size instead. + + This is a softer version of _constrain_range_for_size (with min=0, + max=Inf). Instead of forcibly constraining a variable (and erroring if we + failed to constrain it), it will simply advise us that a size is + constrained in some way. We will always defer a runtime assert for this + constraint if we cannot prove it at compile-time, but we we only + *sometimes* learn useful extra information at compile-time with this + information. This is in contrast to constrain_range_for_size, where if + you don't call that on a fresh unbacked symint, chances are we will choke. + + TODO: Make Dynamo handle this appropriately if this is seen in Dynamo-ed + code. Right now this is only really used in code with AOTAutograd trace + through, so it is not a big problem that this isn't supported, but in + principle all of this code should be Dynamo'able too. + + TODO: I didn't support min/max because I didn't have a use case where this + actually helped. In principle we can support it, it just makes the + implementation below more complicated. + """ + + # This must always succeed, because the sole allowed caller _check_is_size + # was responsible for expect_true'ing this + # This assert triggers expensive sym compute, do not do it until its cheap. + # assert a >= 0 + + # NB: it's important not to constrain range for size for *hinted* SymInts, + # because it is not only unsound, it will immediately trip our asserts + # that hints have to be consistent with static analysis! If you somehow + # have an unbounded SymInt that later constrains to 1, this will be + # inconsistent with the range + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env.is_unbacked_symint(a.node.expr) + ): + _constrain_range_for_size(a) + + +def _advise_is_bounded(a: SymInt, upper_bound: IntLikeType) -> None: + if ( + isinstance(a, SymInt) + and isinstance(a.node, SymNode) + and isinstance(a.node.expr, sympy.Symbol) + and a.node.shape_env.is_unbacked_symint(a.node.expr) + and isinstance(upper_bound, int) # TODO: relax + ): + a.node.shape_env._constrain_is_bounded(a.node.expr, upper_bound) + + +def _constrain_range_for_size( + a: SymInt, min: Optional[int] = None, max: Optional[int] = None +) -> None: + """ + This function is NOT INTENDED to be used by itself. + """ + + if isinstance(a, (SymFloat, SymBool)): + raise ValueError("Constraining SymFloat/SymBool is nyi") + + assert isinstance(a, SymInt), "can only constrain range for SymInt" + assert isinstance(a.node.expr, sympy.Symbol), f"constraining non-Symbols NYI: {a}" + + a.node.shape_env._constrain_range_for_size(a.node.expr, min, max) + + +# inclusive both ways +def constrain_range( + a: SymInt, *, min: Optional[int], max: Optional[int] = None +) -> None: + """ + Applies a constraint that the passed in SymInt must lie between min-max + inclusive-inclusive, WITHOUT introducing a guard on the SymInt (meaning + that it can be used on unbacked SymInts). If min/max are None, we assume + that the dimension is unbounded in that direction. Repeated application + of constrain_range intersects the ranges. This is a fairly low level API + that doesn't have a lot of safety guarantees (TODO: provide higher level + APIs). + + Currently, we use this API in the following circumstance: when we allocate + an unbacked SymInt, denoting an integer quantity which is data dependent, + we ordinarily do not know anything about what values it may take. This + means that any sort of guard on it will immediately fail. However, in + many cases, we know something about the unbacked SymInt: for example, we + know that nonzero(x).size(0) must be >= 0. We use constrain_range to + narrow the possible range, declaring that negative symbols are impossible. + This permits to definitely answer True to queries like 'nnz >= 0', even if + we don't know what the actual (hinted) value of 'nnz' is. In fact, we + actually use constrain_range to unsoundly discharge common guards: for an + unbacked SymInt produced by nonzero, we will also assume that it is not + equal to 0/1 (even though these are perfectly possible values at runtime), + because we generally expect graphs that are valid for N=2 to also be valid + for N=1. + """ + if min is None: + min = -int_oo + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + if isinstance(a, int): + if not (min <= a <= max): + raise ValueError(f"Invalid value {a} for range [{min}:{max}]") + return + + a.node.shape_env._constrain_range(a.node.expr, min, max) + + +def constrain_unify(a: torch.SymInt, b: torch.SymInt) -> None: + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + return + else: + shape_env = b.node.shape_env + else: + shape_env = a.node.shape_env + + shape_env._constrain_unify(a, b) + + +# Assume that a boolean is true for the purposes of subsequent symbolic +# reasoning. This will keep track of corresponding runtime checks to verify +# that the result is upheld: either as a regular guard, or as a special set +# of asserts which are triggered when an unbacked SymInt is allocated. +# +# DO NOT use this function for these cases: +# +# - This is inappropriate for "branching" conditions (where both +# true and false result in valid programs). We will always assume +# the condition evaluates true, and so it will never be possible +# to trace the false condition when you use it. For true branching +# on unbacked SymInts, you must use torch.cond; if you incorrectly +# use expect_true in this case, you will make the false branch +# unreachable (as we will simply assume that only the true branch +# is ever exercised). +# +# - This is inappropriate for situations where you know some other system +# invariant guarantees that this property holds, since you don't +# really need to insert a runtime check in that case. Use something +# like constrain_range in that case. +# +# This API has a hitch. To avoid having to reimplement error reporting +# capabilities, this function CAN return False. The invariant is that +# the surrounding code must raise an error when this function returns +# False. This is quite low level, so we recommend using other functions +# like check() which enforce this in a more intuitive way. +# +# By the way, this name is a nod to the __builtin_expect macro, +# which is used similarly (but unlike __builtin_expect, you MUST fail +# in the unlikely branch.) (I think expect is a good name; in recent +# versions of C++, this is replaced with [[likely]], which is weaker +# and not accurate for this function!) +def expect_true(a: BoolLikeType, skip: int = 0) -> bool: + if isinstance(a, SymBool): + # TODO: check perf implications of this + frame = inspect.currentframe() + for _ in range(skip + 1): # always run this loop at least once + if frame is None: + break + frame = frame.f_back + return a.node.expect_true( + frame.f_code.co_filename if frame else "", frame.f_lineno if frame else 0 + ) + assert type(a) is bool, a + return a + + +def guard_bool(a: BoolLikeType) -> bool: + if isinstance(a, SymBool): + return a.node.guard_bool("", 0) # NB: uses Python backtrace + assert type(a) is bool, a + return a + + +def guard_int(a: IntLikeType) -> int: + if isinstance(a, SymInt): + return a.node.guard_int("", 0) # NB: uses Python backtrace + assert type(a) is int, a + return a + + +def guard_float(a: FloatLikeType) -> float: + if isinstance(a, SymFloat): + return a.node.guard_float("", 0) # NB: uses Python backtrace + assert isinstance(a, float), a + return a + + +# Given a GraphModule, return all the FakeTensors for all the placeholders +def fx_placeholder_vals(gm: torch.fx.GraphModule) -> list[object]: + return [n.meta["val"] for n in gm.graph.nodes if n.op == "placeholder"] + + +def fx_placeholder_targets(gm: torch.fx.GraphModule) -> list[str]: + return [n.target for n in gm.graph.nodes if n.op == "placeholder"] + + +# Given a GraphModule and arguments to run it with, evaluate that the guards +# for its associated ShapeEnv are satisfied by the passed arguments. This +# WILL check for duck sizing. +def eval_guards( + gm: torch.fx.GraphModule, *args: Tensor, ignore_static: bool = True +) -> bool: + return gm.shape_env.evaluate_guards_for_args( # type: ignore[operator, union-attr] + fx_placeholder_vals(gm), args, ignore_static=ignore_static + ) + + +def bind_symbols(gm: torch.fx.GraphModule, *args: Tensor) -> dict[sympy.Symbol, int]: + return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) # type: ignore[operator, union-attr] + + +class DimDynamic(Enum): + """ + Controls how to perform symbol allocation for a dimension. It is always + sound to default this to DYNAMIC, but the policies DUCK and STATIC can + result in better trace-time and compile-time performance, as they reduce + the number of allocated symbols and generally make your graph more static. + + NB: If we notice you've applied a constraint to the dimension, we will + force it to DYNAMIC for simplicity. + + DimDynamic is controlled by a variety of higher level UX features. + Currently: + + - In eager mode, the default policy is DUCK. + - The default is changed to STATIC with assume_static_by_default. + - An individual dim is marked DYNAMIC if you mark_dynamic_dim. + - In export mode, the default policy is STATIC. + - An individual dim is marked DYNAMIC if you specify it in + dynamic_shapes passed to export. + """ + + # Treat the dimension symbolically + DYNAMIC = 0 + # Treat the dimension symbolically, but if its hint matches another + # dynamic dimension, unify the two symbols ("duck sizing") + DUCK = 1 + # Treat the dimension statically based on its hint + STATIC = 2 + # Treat the dimension as a size-like unbacked + SIZE_LIKE_UNBACKED = 3 + # Infer the strides from stride. If size is static, strides will be static as well. + INFER_STRIDE = 4 + # Like SIZE_LIKE_UNBACKED, but there's a hint + OBLIVIOUS_SIZE = 5 + + +# NB: These constraints affect both clients and backends: given some +# constraint C, the client must pass inputs that satisfy the constraint, +# while a backend must not introduce guards BEYOND this constraint. +# For clarity, we document the implications on both sides for both the client +# and the backend. +# +# NB: These constraints are on a *single* dimension. In principle, we could +# also have multi-dimension constraints, but our guess is that this is not +# actually useful and so we are not supporting it right now. +# +# NB: Strict constraints are typically only suitable for export, as in eager +# a backend like inductor may validly introduce extra, discretionary guards +# to improve performance of code. A StrictMinMaxConstraint would be brittle +# under future optimizations performed by inductor; we don't guarantee +# eager code with StrictMinMaxConstraint will keep working in the future! + + +@dataclass(frozen=True) +class Constraint: + warn_only: bool + + +@dataclass(frozen=True) +class StrictMinMaxConstraint(Constraint): + """ + For clients: the size at this dimension must be within 'vr' (which + specifies a lower and upper bound, inclusive-inclusive) AND it + must be non-negative and should not be 0 or 1 (but see NB below). + + For backends: there must not be any guards on this dimension which + are not implied by the given lower and upper bound. Regardless of + the lower bound, the backend can assume the size is non-negative + and that it is not 0 or 1. + + An unbounded StrictMinMaxConstraint can be thought of as a strict version + of "RelaxedUnspecConstraint". + + NB: Export will often unsoundly assume that a graph works for 0/1, even + though at trace time we assumed size is not 0 or 1. The idea is that + if we produce a graph that works for a range of values, it will be OK + for N=0/1 too. + """ + + vr: ValueRanges + + def render(self, source: Source) -> str: + """Format the constrain equation""" + # TODO: better printing for -oo and oo + return f"{self.vr.lower} <= {source.name()} <= {self.vr.upper}" + + +@dataclass(frozen=True) +class RelaxedUnspecConstraint(Constraint): + """ + For clients: no explicit constraint; constraint is whatever is implicitly + inferred by guards from tracing. + + For backends: there must exist at least TWO possible values for the + size at this dimension which satisfy the guards for this dimension. + + In other words, this constraint helps us distinguish between "we don't + care if this dimension specializes or not" versus "this dimension must be + unspecialized." However, this constraint doesn't say very much about what + specialization is permitted; for example, if we guard on a size being + even, this would still be acceptable under an unspec constraint. This + makes RelaxedUnspecConstraint useful for eager mode, where your backend compiler + may add constraints to otherwise dynamic dimensions; we can't assert that + there are NO guards as this is brittle because compilers should be able to + add extra constraints. If you want to assert that there are no guards, + use StrictMinMaxConstraint with an unbounded ValueRanges. + """ + + def render(self, source: Source) -> str: + return f"RelaxedUnspecConstraint({source.name()})" + + +# NB: None here indicates the client constraint is whatever is implicitly +# inferred by guards from tracing, and that a backend can add whatever guards +# it wants (including fully specializing the value). +DimConstraint = Union[StrictMinMaxConstraint, RelaxedUnspecConstraint, None] + + +@dataclass(frozen=True) +class EqualityConstraint(Constraint): + """ + Represent and decide various kinds of equality constraints between input sources. + + A "source pair" is a pair of input sources for dynamic dimensions that + are specified equal. We represent `source_pairs` in a union-find forest + so that we can efficiently check whether two such sources are transitively equal. + + A "derived equality" relates an input source to an expression over a root. + The root can be another input source, corresponding to some dynamic dimension, + or a phantom symbol that does not directly represent any dynamic dimension. We + represent `derived_equalities` involving input sources in a transitively-closed map + so that we can efficiently check whether an input source is transitively equal to + a given expression over another input source. + (NOTE: In contrast, it is easy to decide whether an input source is transitively equal + to a given expression over a phantom symbol; such expressions are already in canonical + form and so the problem reduces to symbolic expression equality.) + """ + + source_pairs: list[tuple[Source, Source]] + derived_equalities: list[ + tuple[Source, Union[Source, sympy.Symbol], Callable[[sympy.Expr], sympy.Expr]] + ] + phantom_symbols: list[sympy.Symbol] + relaxed_sources: set[Source] + + _parents: dict[Source, Source] = field(init=False) + _defs: dict[Source, sympy.Expr] = field(init=False) + + def __post_init__(self) -> None: + """ + Pre-processing to answer queries `is_equal` and `is_derived` below. + + Example: Suppose we are given: + source_pairs [a = b, b = c] + derived_equalities [d = c + 1, e = d - 1] + We first construct a union find with source_pairs: + _parents = {a: a, b: a, c: a} + Then we compute canonical symbolic expressions, recursively applying derived_equalities + until we bottom out: + _defs = {d: c + 1, e: (c + 1) - 1 aka c} + """ + + # self._parents is a map from input sources to input sources where, conceptually, + # these are directed edges in a union-find forest + _parents: dict[Source, Source] = {} + object.__setattr__(self, "_parents", _parents) + # self._defs is a map from input sources to "canonical" symbolic expressions, + # i.e., unary expressions with symbols that corresponds to regular Dims (i.e., + # not derived Dims) + _defs: dict[Source, sympy.Expr] = {} + object.__setattr__(self, "_defs", _defs) + + for source1, source2 in self.source_pairs: + # preprocess into a union-find forest + self._union(self._find(source1), self._find(source2)) + for source, root, fn in self.derived_equalities: + # preprocess into a transitively-closed map + # NOTE(avik): we reuse the union-find forest for canonicalizing input sources + if isinstance(root, sympy.Symbol): + self._defs[self._find(source)] = fn(root) + else: + self._defs[self._find(source)] = fn(self._rewrite(root)) + + def _find(self, source: Source) -> Source: + # chase edges to find the root of this equivalence class + if source in self._parents: + return self._find(self._parents[source]) + else: + return source + + def _union(self, root1: Source, root2: Source) -> None: + # merge two equivalence classes by adding an edge from one root to the other + if root1 != root2: + self._parents[root1] = root2 + + def _rewrite(self, src: Source) -> sympy.Expr: + # always represent the given source by the root of its equivalence class + src = self._find(src) + if src in self._defs: + # simply look up the definition if it exists + # NOTE(avik): This works because definitions are always transitively-closed; + # otherwise we would have to do recursive rewriting. + return self._defs[src] + else: + # otherwise, create a symbol representing the source + return sympy.Symbol(src.name()) + + def is_equal(self, source1: Source, source2: Source) -> bool: + return ( + # check whether source1 and source2 have the same root + # or are relaxed + (src1 := self._find(source1)) in self.relaxed_sources + or (src2 := self._find(source2)) in self.relaxed_sources + or src1 == src2 + # check whether source1 is derived equal to source2 + or self.is_derived(source1, source2, lambda x: x) + ) + + def is_derived( + self, src: Source, symbol_src: Source, fn: Callable[[sympy.Expr], sympy.Expr] + ) -> bool: + # check whether both src and symbol_src have the same definition + return self._rewrite(src) == fn(self._rewrite(symbol_src)) + + +def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]: + assert isinstance(symbolic_context, SymbolicContext), ( + "Invalid symbolic_context object" + ) + assert type(symbolic_context) is not SymbolicContext, ( + "Illegal usage of symbolic_context ABC" + ) + return True + + +def _is_supported_equivalence(expr: sympy.Expr) -> bool: + # Currently supported Dim ops are linear expressions with integer coefficients. + # So check that expr only contains +, *, ints, and a single occurrence of a symbol. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(expr, (sympy.Add, sympy.Mul)): + if len(expr.args) > 2: + return False + lhs, rhs = expr.args + return (_is_supported_equivalence(lhs) and isinstance(rhs, sympy.Integer)) or ( + isinstance(lhs, sympy.Integer) and _is_supported_equivalence(rhs) + ) + return isinstance(expr, sympy.Symbol) + + +def _has_uninterpretable_sympy_function(expr: sympy.Basic) -> bool: + """ + Add functions that our sympy interpreter can't reify into FX nodes + """ + return expr.has( + torch.utils._sympy.functions.ToFloat, + torch.utils._sympy.functions.TruncToInt, + torch.utils._sympy.functions.CeilToInt, + ) + + +@dataclass(frozen=True) +class SymbolicContext: + """ + Data structure specifying how we should create symbols in + ``create_symbolic_sizes_strides_storage_offset``; e.g., should + they be static or dynamic. + + This is an abstract base class because we are probably going to add + another version of this that says "use exactly these SymInts, don't + allocate fresh symbols." + """ + + +@dataclass(frozen=True) +class SymIntSymbolicContext(SymbolicContext): + """ + Data structure specifying any constraints on a SymInt input + """ + + constraint: DimConstraint + + +_P1 = ParamSpec("_P1") +_T1 = TypeVar("_T1") + + +@dataclass(frozen=True) +class StatelessSymbolicContext(Generic[_P1, _T1], SymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by ``DimDynamic`` and ``DimConstraint``. + This will cause fresh symbols to be allocated + """ + + dynamic_sizes: DimList[DimDynamic] + dynamic_strides: DimList[DimDynamic] = None # type: ignore[assignment] + constraint_sizes: DimList[DimConstraint] = None # type: ignore[assignment] + constraint_strides: DimList[DimConstraint] = None # type: ignore[assignment] + specialize_on: Optional[list[list[Callable[_P1, _T1]]]] = None + # If the tensor is a view, this should be populated for the base. It contains + # information on how to allocate symbols when recursively fakeifying the base + # during view fake-ification. + view_base_context: Optional[SymbolicContext] = None + # TODO: add storage offset and stride symbolic_context + + def __post_init__(self) -> None: + if self.specialize_on is None: + object.__setattr__( + self, + "specialize_on", + [[]] * len(self.dynamic_sizes), + ) + if self.dynamic_strides is None: + object.__setattr__( + self, + "dynamic_strides", + [DimDynamic.INFER_STRIDE] * len(self.dynamic_sizes), + ) + if self.constraint_sizes is None: + object.__setattr__( + self, "constraint_sizes", [None] * len(self.dynamic_sizes) + ) + if self.constraint_strides is None: + object.__setattr__( + self, "constraint_strides", [None] * len(self.dynamic_sizes) + ) + assert all( + stride in (DimDynamic.INFER_STRIDE, DimDynamic.DYNAMIC, DimDynamic.DUCK) + for stride in self.dynamic_strides + ) + + +# note [Tensor Fakification and Symbol Caching] +# +# As of the time of this note, dynamo creates a fresh fake tensor mode for backends. +# The reason we do this is because there are certain classes of operations, namely, +# metadata mutations, that change tensor size, stride, etc. This means that the fake tensor +# state at the end of a dynamo trace is different than the fake tensor state at the beginning +# of a trace. Backends like aot_autograd need a fresh fake tensor to correctly track metadata mutation, +# view relationships, etc. +# +# As we create a new fake mode, we also lose the memoization that comes with it. Rather than +# transfer the memoization cache, we instead transfer the shape env. However, with this +# comes nuance - as dynamo is selective in how it makes symbolic shapes. Due to strategies in +# automatic dynamic and constraints, the policy for which dims are dynamic is nuanced and varies across +# recompilations. +# +# In order to preserve the symbolic decisions made during dynamo tensor fakification, we pass +# a StatefulSymbolicContext at creation time. This object is tracked, per tensor, on the TracingContext. +# The lifecycle of this object should match the lifecycle of the original dynamo tracked tensor, and it is +# safe to reuse this object as many times as necessary to create a fake tensor. Fake tensors +# created with new fake modes should produce the same exact symbols as the original, providing the same shape_env +# is used. +# TODO(voz): Shape env validation +@dataclass(frozen=True) +class StatefulSymbolicContext(StatelessSymbolicContext): + """ + Create symbols in ``create_symbolic_sizes_strides_storage_offset`` via + a symbolic_context determination as given by a cache of Source:Symbol. A cache hit + will reuse a stored symbol, and a cache miss will write to this cache. + + This behaves like StatelessSymbolicContext, except the cache supersedes the + other values - dynamic_sizes and constraint_sizes will not be read if we cache + hit. + + It is the cache owner's responsibility to maintain the lifecycle of the cache + with respect to different shape_envs, clearing, etc. + """ + + tensor_source: Source = None # type: ignore[assignment] + # Why is this keyed on int first? + # That integer is actually the id of the shape_env. This cache short-circuits symbol + # creation, and we must store it per shape env. Now, while tracing invariants are a single + # shape env per tracing context, and every new frame gets a new shape_env. So where would we have + # multiple shape envs? The answer lies in recording. When we are replaying, replay_shape_env_events + # is invoked, and creates a new shape_env. Replaying events against this new shape_env will + # cause it to fail with unknown symbols, as the symbols cached here will skip creation, and never + # get recorded in var_to_val, etc. + # TODO(voz): consider a weakref to the shape_env here + shape_env_to_source_to_symbol_cache: dict[int, dict[str, sympy.Expr]] = None # type: ignore[assignment] + + def __post_init__(self) -> None: + super().__post_init__() + # The None default is annoying, but required because of dataclass limitations + assert self.tensor_source is not None + if not self.shape_env_to_source_to_symbol_cache: + object.__setattr__(self, "shape_env_to_source_to_symbol_cache", {}) + + +@dataclass(frozen=True) +class SubclassSymbolicContext(StatefulSymbolicContext): + """ + The correct symbolic context for a given inner tensor of a traceable tensor subclass + may differ from that of the outer symbolic context. This structure allows for this + flexibility, with inner symbolic contexts mapped via attr -> symbolic context. + """ + + inner_contexts: dict[str, SymbolicContext] = None # type: ignore[assignment] + + def __post_init__(self) -> None: + super().__post_init__() + if self.inner_contexts is None: + self.inner_contexts = {} + + +@dataclass +class TrackedFake: + """ + Tracks the sources of all fake tensors we wrap in Dynamo. + Used by shape guard computation. + """ + + fake: Union[FakeTensor, SymInt] + source: Source + symbolic_context: Optional[SymbolicContext] + + def __hash__(self) -> int: + return hash((self.fake, self.source.name())) + + def __eq__(self, other: object) -> bool: + if isinstance(other, TrackedFake): + return self.fake is other.fake and self.source.name() == other.source.name() + return False + + +def is_symbolic( + val: Union[int, SymInt, float, SymFloat, bool, SymBool], +) -> TypeGuard[Union[SymInt, SymFloat, SymBool]]: + if isinstance(val, (int, float, bool)): + return False + return val.node.is_symbolic() + + +IndicatorTypes = (IsNonOverlappingAndDenseIndicator,) + + +def _expandsums(args: list[sympy.Expr]) -> tuple[sympy.Expr, bool]: + """ + Expand products of sums into sums of products. + + This function takes a list of sympy expressions and separates them into + additive expressions (those with is_Add=True) and other expressions. + It then computes the distributive product, expanding (a+b)*(c+d) into a*c + a*d + b*c + b*d. + + Args: + args: A list of sympy expressions to expand + + Returns: + A tuple containing: + - The expanded expression as a sympy.Expr + - A boolean indicating whether expansion occurred (True if multiple additive + expressions were present or if there was at least one additive and one other expression) + """ + adds, other = [], [] + for arg in args: + if arg.is_Add: + adds.append(arg) + else: + other.append(arg) + + result = [sympy.Mul(*other)] + for add in adds: + result = [a * b for a, b in itertools.product(result, add.args)] + + result = sympy.Add(*result) + return result, len(adds) > 1 or (len(adds) > 0 and len(other) > 0) + + +def _fast_expand(expr: _SympyT) -> _SympyT: + """ + A faster implementation of sympy's expand function for common cases. + + This function expands expressions like (a+b)^n or (a+b)*(c+d) into sums of products, + but avoids the expensive checks and features of sympy's full expand implementation. + It only recreates objects when necessary to avoid expensive operations. + + Args: + expr: A sympy expression to expand + + Returns: + The expanded expression + """ + + # The expand algorithm in sympy is slow due to all the features is supports + # For eg: e^(-x)*(x-1)/(x+1) is expanded to (x-1)/(e^x + e^x*x) if x is + # positive and (e^(-x)*x-e^(-x))/(x+1) if x is negative. We do not implement + # such features here to avoid expensive checks. We also make sure that we + # only re-create the objects if any of the args changed to avoid expensive + # checks when re-creating objects. + new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type] + if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)): + return _fast_expand(expr.func(*new_args)) + + if expr.is_Pow: + base: sympy.Expr + exp: sympy.Expr + base, exp = expr.args # type: ignore[assignment] + if exp.is_Integer and base.is_Add: + if exp > 1: + return sympy.expand_multinomial(expr, deep=False) + elif exp < 0: + return S.One / sympy.expand_multinomial(S.One / expr, deep=False) + elif expr.is_Mul: + num: list[sympy.Expr] = [] + den: list[sympy.Expr] = [] + for arg in expr.args: + if arg.is_Pow and arg.args[1] == -1: + den.append(S.One / arg) # type: ignore[operator, arg-type] + else: + num.append(arg) # type: ignore[arg-type] + + num, num_changed = _expandsums(num) + den, den_changed = _expandsums(den) + if num_changed or den_changed: + return num / den + + return expr + + +@lru_cache(256) +def safe_expand(r: _SympyT) -> _SympyT: + """ + Expand the given symbolic expression by recursively rewriting product of + sums into sum of products (with the product being either a multiplication or + exponentiation). + + NOTE: using this on an intermediate expression may prevent simplification + down the line, e.g., if we eagerly expand `(a + b)^2` into `a^2 + 2ab + b^2`, + we won't be able to simplify `(a^2 + 2ab + b^2) / (a + b)` as easily. + """ + if hasattr(r, "expand"): + try: + return _fast_expand(r) + except RecursionError: + log.warning("RecursionError in _fast_expand(%s)", r) + return r + else: + return r + + +class _SymbolInfo(NamedTuple): + k: sympy.Symbol + vr: Optional[ValueRanges] + val: Optional[sympy.Integer] + is_size_like: bool + + +@lru_cache(None) +def _maybe_evaluate_static_worker( + expr: _SympyT, + # NB: this is a tuple to ensure it can be LRU cached + symbol_info: tuple[_SymbolInfo, ...], + unbacked_only: bool, + size_oblivious: bool, +) -> Optional[_SympyT]: + """ + This variant of ShapeEnv._maybe_evaluate_static has no dependence on + ShapeEnv and thus can be cached indefinitely. It does the "heavy" lifting + for static evaluation, including nontrivial reliance on Sympy simplification + that occurs when we reallocate the symbols + """ + + # Simplify making use of value range lower bound + new_shape_env = {} + new_range_env = {} + for idx, sinfo in enumerate(symbol_info): + k, vr, val, is_size_like = sinfo + if isinstance(val, SingletonInt): + # Skip var_ranges logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + assert vr is not None + if size_oblivious and is_size_like: + lower = max(2, vr.lower) + # Clamping size-oblivious to some quantity below sys.maxsize + # helps us determine that f(u0) != sys.maxsize, which is a + # test that is looking for sys.maxsize as a sentinel, but you + # don't really want to worry about it for unbacked SymInts. + # This is similar to the flavor where size oblivious omits + # 0/1, it changes semantics but in a benign way. + upper = min(2**48, vr.upper) + # Excluding the very upper bound can be helpful + if upper > lower: + upper = upper - 1 + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= upper: + vr = ValueRanges(lower, upper) + else: + lower = vr.lower + # Don't do anything if we don't have a nontrivial lower bound + # Also don't do anything if we asked only to simplify unbacked + # SymInt + if lower is -int_oo or (unbacked_only and val is not None) or not vr.is_int: + new_range_env[k] = vr + continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # + # Positive means >= 1 + # Positive - 1 means >= 0 + # Positive + lower - 1 means >= lower + # The new symbol 's' is "too low", so when we substitute it in + # we have to increase it by offset (and conversely, the new + # variables have to have their value range bounds adjusted as + # well) + s = sympy.Symbol(f"evaluate_static_shape_{idx}", positive=True, integer=True) + + # Note: + # Offset might be a fraction(e.g. aten.split.Tensor), but shapes are always integers. + # Sympy might give unexepected results when comparing an integer with a non-integer + # Therefore, we cast offset to int here. + # For example: + # shape_0 = sympy.Symbol("shape_0", positive=True, integer=True) + # expr = sympy.Eq(shape_0 - 1/3, 4) + # expr.xreplace({}) # False + offset = int(lower - 1) + new_shape_env[k] = s + offset + new_range_env[s] = SymPyValueRangeAnalysis.add(vr, -offset) + + # TODO: remove this try catch (esp for unbacked_only) + try: + new_expr = expr.xreplace(new_shape_env) + except RecursionError: + log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env) + return None + + # We need to canonicalize, as after expand we may have something like `a + b = a` and + # sympy will not simplify the a. The two appeareances of the a will then make value ranges + # analysis give lose bounds + new_expr = canonicalize_bool_expr(safe_expand(new_expr)) + if new_expr.is_number: + return new_expr + + # Check if the range can solve it statically + out = bound_sympy(new_expr, new_range_env) + if out.is_singleton(): + return out.lower + + return new_expr if unbacked_only else None + + +def error() -> NoReturn: + raise AssertionError("shouldn't be hit") + + +# TODO: Deduplicate this with torch/_prims_common/__init__.py +def eval_is_non_overlapping_and_dense( + sizes: Sequence[int], strides: Sequence[int] +) -> int: + return int(guard_bool(_eval_is_non_overlapping_and_dense(sizes, strides))) + + +def _eval_is_non_overlapping_and_dense( + sizes: Sequence[int], strides: Sequence[int] +) -> bool: + """ + Evaluates whether a tensor with the given sizes and strides is non-overlapping and dense. + + A tensor is non-overlapping if there's no memory location that belongs to more than one element. + A tensor is dense if all elements are stored in memory without gaps. + + Args: + sizes: Sequence of dimension sizes for the tensor + strides: Sequence of strides for the tensor + + Returns: + True if the tensor is non-overlapping and dense, False otherwise + """ + dim = len(sizes) + + # Short-circuits for tensors of rank one, which are + # non-overlapping and "dense" if their stride is one + # or it is a 0/1 element tensor + if dim == 1: + return strides[0] == 1 or sizes[0] < 2 + + # Checks that there exists a permutation of the strides s.t. the tensor would be contiguous + # Sorts (length, stride) pairs by stride + lengths_and_strides = sorted(zip(sizes, strides), key=operator.itemgetter(1)) + + # Unlike the C++ code, we don't move the 0/1 size dimensions to the + # end. So we have to keep going for this code. + expected_stride = 1 + for length, stride in lengths_and_strides: + if length == 1: + continue + + if stride != expected_stride: + return False + + expected_stride *= length + + return True + + +def _sympy_cast_symbool_to_symint_guardless(x: SympyBoolean) -> sympy.Expr: + return sympy.Piecewise((1, x), (0, True)) + + +def cast_symbool_to_symint_guardless( + symbool: Union[bool, torch.SymBool], +) -> Union[int, torch.SymInt]: + """ + Converts a SymBool or bool to a SymInt or int without introducing guards. + + This function maps True to 1 and False to 0, preserving the symbolic nature + of the input when it's a SymBool. Unlike regular casting which might introduce + guards, this function performs the conversion without adding any guards. + + Args: + symbool: A boolean value, either a concrete bool or symbolic SymBool + + Returns: + The corresponding integer value (1 for True, 0 for False) as either + a concrete int or symbolic SymInt + """ + if isinstance(symbool, bool): + return 1 if symbool else 0 + int_sym = _sympy_cast_symbool_to_symint_guardless(symbool.node.expr) + return symbool.node.shape_env.create_symintnode( + int_sym, hint=int(symbool.node.require_hint()) if has_hint(symbool) else None + ) + + +SYMPY_INTERP = { + "IsNonOverlappingAndDenseIndicator": eval_is_non_overlapping_and_dense, + "cast_symbool_to_symint_guardless": cast_symbool_to_symint_guardless, + "math": math, + "torch": torch, +} + + +def _lru_cache( + fn: Callable[..., _T], maxsize: Optional[int] = None +) -> functools._lru_cache_wrapper[_T]: + """ + Wrapper around lru_cache that clears when new info about shapes has been + updated. + + Use lru_cache if the output is always the same, regardless of the + constraints we know now (i.e. evaluate_expr) + + Use _lru_cache otherwise. + + Also note that this depends on _update_version_counter being called on the + shape environment whenever the constraints are updated, otherwise the cache + will not be cleared. + """ + fn_cache = lru_cache(maxsize)(fn) + prior_version = 0 + + if config.validate_shape_env_version_key: + prior_key = None + + @functools.wraps(fn) + def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: + nonlocal prior_version, prior_key + if prior_key is None: + prior_key = self._get_key() + + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + prior_key = self._get_key() + else: + assert prior_key == self._get_key(), ( + "ShapeEnv cache key changed without version being updated!" + ) + + return fn_cache(self, *args, **kwargs) + + else: + + @functools.wraps(fn) + def wrapper(self: ShapeEnv, *args: Any, **kwargs: Any) -> _T: # type: ignore[misc] + nonlocal prior_version + if prior_version != self._version_counter: + fn_cache.cache_clear() + prior_version = self._version_counter + + return fn_cache(self, *args, **kwargs) + + wrapper.cache_clear = fn_cache.cache_clear # type: ignore[attr-defined] + wrapper.cache_info = fn_cache.cache_info # type: ignore[attr-defined] + return wrapper # type: ignore[return-value] + + +@dataclass(frozen=True) +class RuntimeAssert: + """ + This is pretty similar to ShapeGuard but it also comes with a message, + and is exclusively used for things that MUST be true (unlike guards, + which can evaluate False, in which case you just choose not to use + a particular specialization) + """ + + expr: SympyBoolean + msg: str = field(repr=False) + stack: CapturedTraceback = field(repr=False) + + +# Used for printing SymExprs in compile_fx +class SymExprPrinter(PythonPrinter): + def _print_Float(self, expr: sympy.Float) -> str: + return str(float(expr)) + + +class _ShapeGuardPrinter(abc.ABC): + """ + Abstract base class for printers that convert symbolic expressions to string representations. + + This class provides common functionality for printing symbolic expressions with + special handling for symbols that represent tensor shapes, strides, etc. + Subclasses implement specific formatting for different output languages. + + Args: + symbol_to_source: Mapping from sympy symbols to their source objects + source_ref: Function to convert a source to its string representation + var_to_sources: Mapping from sympy symbols to their source objects (for error reporting) + """ + + def __init__( + self, + symbol_to_source: Mapping[sympy.Symbol, list[Source]], + source_ref: Callable[[Source], str], + var_to_sources: Mapping[sympy.Symbol, list[Source]], + ) -> None: + self.symbol_to_source = symbol_to_source + self.source_ref = source_ref + self.var_to_sources = var_to_sources + super().__init__() + + def _print_Float(self, expr: sympy.Float) -> str: + """Convert a sympy Float to a Python float string representation.""" + return str(float(expr)) + + def _print_Symbol(self, expr: sympy.Symbol) -> str: + """ + Convert a sympy Symbol to its source representation. + + This method looks up the symbol in symbol_to_source mapping and returns + the string representation of its first source. + + Args: + expr: The sympy Symbol to convert + + Returns: + String representation of the symbol's source + + Raises: + AssertionError: If the symbol is not found in symbol_to_source + """ + assert isinstance(expr, sympy.Symbol), str(type(expr)) + + def repr_symbol_to_source() -> str: + return repr( + { + symbol: [s.name() for s in sources] + for symbol, sources in self.symbol_to_source.items() + } + ) + + assert self.symbol_to_source.get(expr), ( + f"{expr} (could be from {[s.name() for s in self.var_to_sources[expr]]}) " + f"not in {repr_symbol_to_source()}. If this assert is failing, it could be " + "due to the issue described in https://github.com/pytorch/pytorch/pull/90665" + ) + return self.print_source(self.symbol_to_source[expr][0]) + + @abc.abstractmethod + def print_source(self, source: Source) -> str: + """ + Convert a source object to its string representation. + + Args: + source: The source object to convert + + Returns: + String representation of the source + """ + ... + + @abc.abstractmethod + def doprint(self, expr: sympy.Expr) -> str: + """ + Convert a sympy expression to its string representation. + + Args: + expr: The sympy expression to convert + + Returns: + String representation of the expression + """ + ... + + +class ShapeGuardPythonPrinter(_ShapeGuardPrinter, PythonPrinter): + """ + Python printer for shape guards that extends the base ShapeGuardPrinter. + + This class provides functionality to print symbolic expressions as Python code, + with caching to improve performance when printing the same expressions multiple times. + It handles printing of sources and expressions according to Python syntax. + + Args: + *args: Arguments passed to the parent classes. + """ + + def __init__(self, *args: Any) -> None: + super().__init__(*args) + self._print_cache: dict[sympy.Expr, str] = {} + + def print_source(self, source: Source) -> str: + """ + Convert a source object to its string representation using the source_ref function. + + Args: + source: The source object to convert + + Returns: + String representation of the source + """ + return self.source_ref(source) + + def doprint(self, expr: sympy.Expr) -> str: + """ + Convert a sympy expression to its Python string representation with caching. + + This method first checks if the expression is already in the cache. + If found, it returns the cached result; otherwise, it delegates to + PythonPrinter's doprint method and caches the result. + + Args: + expr: The sympy expression to convert + + Returns: + String representation of the expression in Python syntax + """ + val = self._print_cache.get(expr, None) + if val is not None: + return val + else: + res = PythonPrinter.doprint(self, expr) + self._print_cache[expr] = res + return res + + +@deprecated( + "`torch.fx.experimental.symbolic_shapes.ShapeGuardPrinter` is deprecated, " + "please use `torch.fx.experimental.symbolic_shapes.ShapeGuardPythonPrinter` instead.", + category=FutureWarning, +) +class ShapeGuardPrinter(ShapeGuardPythonPrinter): + pass + + +class _ShapeGuardCppPrinter(_ShapeGuardPrinter, CppPrinter): + def __init__(self, *args: Any) -> None: + self.all_symbols: set[str] = set() + self.source_to_symbol: dict[Source, sympy.Symbol] = {} + super().__init__(*args) + + def print_source(self, source: Source) -> str: + if source in self.source_to_symbol: + return self.source_to_symbol[source].name + + source_name = source.name() + mangled_name = re.sub("[^0-9a-zA-Z_]+", "_", source_name) + old_mangled_name = mangled_name + count = 0 + while mangled_name in self.all_symbols: + mangled_name = f"{old_mangled_name}_{count}" + count += 1 + self.source_to_symbol[source] = sympy.Symbol(mangled_name) + self.all_symbols.add(mangled_name) + return mangled_name + + def doprint(self, expr: sympy.Expr) -> str: + return CppPrinter.doprint(self, expr) + + +# A dataclass for storing shape guards +@dataclass(frozen=True) +class _ShapeGuardsHelper: + exprs: list[str] + + +# A dataclass for storing C++ expressions and helper variables +@dataclass(frozen=True) +class _CppShapeGuardsHelper(_ShapeGuardsHelper): + source_to_symbol: dict[Source, sympy.Symbol] + + +class LoggingShapeGuardPrinter(ShapeGuardPythonPrinter): + def __init__(self, var_to_sources: Mapping[sympy.Symbol, list[Source]]): + super().__init__(var_to_sources, lambda n: n.name(), var_to_sources) + + +class DynamicDimConstraintPrinter(PythonPrinter): + """ + Printer for dynamic dim constraints. + - Instead of symbol s_k it prints its source t.size()[i] + - Instead of Eq(_, _), Mod(_, _), etc. it prints _ == _, _ % _, etc. + + We use this to suggest code for specifying dynamic dim constraints. + """ + + def __init__( + self, + symbol_to_source: dict[sympy.Symbol, list[Source]], + source_name_to_debug_name: Mapping[str, str], + ): + super().__init__() + self.symbol_to_source = symbol_to_source + self.source_name_to_debug_name = source_name_to_debug_name + + def _print_Symbol(self, expr: sympy.Symbol) -> str: + assert isinstance(expr, sympy.Symbol), str(type(expr)) + assert self.symbol_to_source.get(expr), ( + f"Unknown symbol {expr} created by constraints solver" + ) + return self.symbol_to_source[expr][0].name() + + +class DimConstraints: + """ + Custom solver for a system of constraints on symbolic dimensions. + Solutions are "static" values or simplified "dynamic" constraints. + """ + + def __init__( + self, + symbol_to_source: dict[sympy.Symbol, list[Source]], + var_to_val: Mapping[sympy.Symbol, sympy.Integer], + marked_dynamic: set[sympy.Symbol], + source_name_to_debug_name: Mapping[str, str], + ) -> None: + # We try to solve systems of inequalities with 1 free variable. + self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = ( + defaultdict(set) + ) + # Among them, we prioritize solving for a free variable that has equalities. + # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() + # and removing a symbol from the former => removing it from the latter. + self._symbols_with_equalities: set[sympy.Symbol] = set() + # A solution of a free variable with equalities becomes a substitution. + # We use these substitutions to simplify other constraints. + # NOTE: removing a symbol from _symbols_with_equalities => adding it to _substitutions. + self._substitutions: dict[sympy.Symbol, sympy.Integer] = {} + + # In general, constraints may have // and % operations. + # Of course, // can be expressed in terms of / and %. + # Our inequality solver can handle / but not %. So we need to transform them away. + # We do so by using the values of variables as hints to evaluate %. + # For soundness we record additional congruence guards and solve them separately. + self._var_to_val: Mapping[sympy.Symbol, sympy.Integer] = var_to_val + self._congruences: defaultdict[sympy.Symbol, set[sympy.Expr]] = defaultdict(set) + + # We do not try to (directly) solve inequalities with > 1 free variables. + # NOTE: free variables in these inequalities cannot also be in _substitutions. + self._multivariate_inequalities: set[SympyBoolean] = set() + + # We park external equalities between free variables here. + self._symbolic_equivalences: list[tuple[Source, sympy.Expr]] = [] + + # Solutions come in two forms: + # - (static) specializations + # - (dynamic) inequalities / congruences + self._static_results: set[str] = set() + self._dynamic_results: set[str] = set() + + # printer for solutions + self._dcp = DynamicDimConstraintPrinter( + symbol_to_source, source_name_to_debug_name + ) + + # inconsistencies found on substituting with concrete values / static solutions + self._inconsistencies: list[str] = [] + + # symbols that are marked dynamic + self._marked_dynamic = marked_dynamic + + # track supported sympy functions and subtract from list of all sympy functions + self._supported_sympy_functions: set[sympy.Function] = { + Application, + Mod, + PythonMod, + FloorDiv, + } + self._enumerate_sympy_functions() + + def rewrite_with_congruences(self, s: sympy.Symbol, expr: _SympyT) -> _SympyT: + """ + Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k. + This leaves rational operators (in particular of the form b / d) that our inequality solver can handle. + We solve the added congruences separately (using our congruence solver, see below). + """ + + def mod_handler(*args: sympy.Expr) -> sympy.Expr: + # Suppose that we have an expression of the form b % d with free variable s. + # Using the value of s as a "hint," we can evaluate b % d to a value k. + # Then we can rewrite b % d to k while adding the guard b % d == k. + + # NOTE(avik): This abstraction is provably sound but, in general, incomplete. It is complete IFF + # the original expression always evaluates to a constant value (i.e., it does not vary with s). + # In other words, + # - solutions of s with the rewritten expression are guaranteed to also be solutions of s with + # the original expression; + # - while it may be possible to find solutions of s with the original expression that are not + # solutions with the rewritten expression, in that case the original expression cannot evaluate + # to the same value for all solutions of s. + # + # Should we be worried about this incompleteness? No, because of the following reasons: + # 1. It unblocks dramatic simplification that would not be otherwise possible with current tech + # (i.e., "don't let perfect be the enemy of the good"). + # 2. We already have a tradition of using hints to add guards in the compiler for making progress. + # 3. We have not yet seen a counterexample arise in practice! In particular, any congruence guards + # we generate (or simplify to) seem to be of the form b % d == k where k is a constant. + # + # Here's a theoretical counterexample: 3*s % (s + 1) == s - 2, that is satisfied by all s >= 2. + # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we + # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! + base, divisor = args + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( + self._var_to_val + ) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + return mod_reduced + + def floor_div_handler(*args: sympy.Expr) -> sympy.Expr: + # Suppose that we have an expression of the form b // d with free variable s. + # Using the value of s, we can evaluate b % d to a value k. + # Then we can rewrite b // d to (b - k) / d, while adding the guard b % d == k. + + # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d + # and eliminating b % d as above. + base, divisor = args + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) + mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( + self._var_to_val + ) + congruence = (base - mod_reduced) % divisor + if congruence != 0: + self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha + return (base - mod_reduced) / divisor + + if expr.has(Mod): + expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) + if expr.has(FloorDiv): + expr = expr.replace(FloorDiv, floor_div_handler) + return expr + + def _enumerate_sympy_functions(self) -> None: + module = torch.utils._sympy.functions + all_functions = set() + for attr in dir(module): + if isinstance(func := getattr(module, attr), sympy.FunctionClass): + all_functions.add(func) + self._unsupported_sympy_functions = all_functions.difference( + self._supported_sympy_functions + ) + + def _has_unsupported_sympy_function(self, expr: sympy.Basic) -> bool: + """ + Tracks list of sympy.Functions the export solver doesn't know how to handle. + """ + return expr.has(*self._unsupported_sympy_functions) + + def add(self, expr: SympyBoolean) -> bool: + """Add an expression to the set of constraints. + + Return whether the expression is a trivial constraint (i.e., an obvious tautology). + """ + if expr == sympy.true: + return True + orig_expr = expr + orig_reduced = orig_expr.xreplace(self._var_to_val) + # TODO(avik): https://github.com/pytorch/pytorch/issues/101093 + # It is possible that `expr` will fail the consistency check because of + # precision errors. Specifically, on substituting its free symbols with + # their concrete values, we might end up comparing floats. Until we have + # a fix for this issue, we delay raising such failures. See solve(). + if orig_reduced == sympy.false: + self._inconsistencies.append(f"{orig_expr} is inconsistent!") + if isinstance( + expr, (sympy.Ne, sympy.Or, sympy.And) + ) or self._has_unsupported_sympy_function(expr): + # we're not going to do anything useful with these, so drop them + return False + free_symbols = expr.free_symbols + assert free_symbols, f"Did not expect constraint with no free variables: {expr}" + if len(free_symbols) > 1: + # multivariate: record and move on + self._multivariate_inequalities.add(expr) + else: + # univariate: can solve these immediately + s = next(iter(free_symbols)) + # eliminate // and % (see documentation of `rewrite_with_congruences` above) + old_n_congruences = len(self._congruences[s]) + expr = self.rewrite_with_congruences(s, expr) + new_n_congruences = len(self._congruences[s]) + if expr == sympy.true: + return old_n_congruences == new_n_congruences + reduced = expr.xreplace(self._var_to_val) + if reduced == sympy.false: + self._inconsistencies.append( + f"{expr}, obtained by rewriting {orig_expr} with congruences, " + "is inconsistent!" + ) + if isinstance(expr, sympy.Eq): + # special status for symbols that have equalities (see `solve` below) + self._symbols_with_equalities.add(s) + self._univariate_inequalities[s].add(expr) + return False + + def add_equality(self, source: Source, expr: sympy.Expr) -> None: + """Add an equality constraint""" + if expr.is_number: + # specialization, right here + self._static_results.add(f"{source.name()} == {expr}") + else: + # these will resolve to either specializations or dynamic equality constraints + self._symbolic_equivalences.append((source, expr)) + + def _reduce_congruences(self) -> dict[sympy.Symbol, set[sympy.Expr]]: + reduced_congruences: dict[sympy.Symbol, set[sympy.Expr]] = {} + for s, congruences in self._congruences.items(): + remainder_modulus_pairs = [] + congruences_to_check = set() + for congruence in congruences: + base, divisor = congruence.args + # We are given a congruence of the form base % divisor == 0 with a free variable s. So: + # - we transform this into an equation of the form base = divisor * tmp; + # - we solve this equation for s to get a linear solution with free variable tmp. + tmp = sympy.Symbol("reduce_congruences_tmp", integer=True) + symbol, solution = sympy.solve_linear(base - divisor * tmp, symbols=[s]) + # See https://docs.sympy.org/latest/modules/solvers/solvers.html#sympy.solvers.solvers.solve_linear + # for how to interpret the results. + if s == symbol: + # This means the solution is of the form s = modulus*tmp + remainder. + modulus, remainder = sympy.polys.polytools.div(solution, tmp) + if isinstance(modulus, sympy.Integer) and isinstance( + remainder, sympy.Integer + ): + # Make sure 0 <= remainder <= modulus. + remainder = remainder % modulus + remainder_modulus_pairs.append((remainder, modulus)) + continue + # This means that we did not get a unique solution to the equation. + # No problem, we will check it. + congruences_to_check.add(congruence) + # Finally we solve for a congruence s such that s = r_i mod m_i for each (r_i, m_i). + # The solution will be a congruence of the form s = r mod m. + # NOTE(avik): Since the given m_i may not be pairwise coprime, we can't just use CRT. + if remainder_modulus_pairs: + remainder, modulus = sympy.ntheory.modular.solve_congruence( + *remainder_modulus_pairs + ) + reduced_congruences[s] = {(s - remainder) % modulus} + substitution = { + s: modulus * sympy.Symbol("tmp", integer=True) + remainder + } + reduced_congruences[s].update( + congruence + for congruence in congruences_to_check + if not sympy.checksol(congruence, substitution) + ) + else: + reduced_congruences[s] = congruences_to_check + + return reduced_congruences + + def _raise_inconsistencies(self) -> None: + if self._inconsistencies: + msg = "\n".join(self._inconsistencies) + self._inconsistencies.clear() + raise ValueError(f"The following inconsistencies were found:\n{msg}") + + def solve(self) -> None: + """Solve the system of constraint equations to find simplified constraints""" + self._raise_inconsistencies() + # as long as there are symbols with equalities, solve for them + # NOTE(avik): this is guaranteed to terminate (#iterations <= #symbols) + while self._symbols_with_equalities: + s = self._symbols_with_equalities.pop() + exprs = self._univariate_inequalities.pop(s) + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + if isinstance(solution, sympy.And): + solution = next( + (arg for arg in solution.args if isinstance(arg, sympy.Eq)), + solution, + ) + assert isinstance(solution, sympy.Eq), ( + f"Expected an equality constraint for {s}, got {solution}" + ) + symbol, val = solution.args + assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" + # because this is univariate, the solution is a specialization + self._static_results.add( + f"{self._dcp.symbol_to_source[s][0].name()} == {val}" + ) + # add this as a substitution to simplify other constraints + self._substitutions[s] = val # type: ignore[assignment] + + # simplify multivariate inequalities: some of them will now become univariate! + multivariate_inequalities = self._multivariate_inequalities + self._multivariate_inequalities = set() + for expr in multivariate_inequalities: + self.add(expr.xreplace({s: self._substitutions[s]})) + self._raise_inconsistencies() + + # solve linear congruences + # NOTE(avik): We do not need to solve them for symbols that have already been specialized. + reduced_congruences = self._reduce_congruences() + for s, congruences in reduced_congruences.items(): + for congruence in congruences: + # any congruence that cannot be checked becomes a dynamic constraint as well + if s not in self._substitutions or not sympy.checksol( + congruence, {s: self._substitutions[s]} + ): + if self._is_supported_congruence(congruence): + base, divisor = congruence.args + tmp_name = "_" + str( + self._dcp.source_name_to_debug_name.get( + self._dcp.symbol_to_source[s][0].name(), + self._dcp.symbol_to_source[s][0].name(), + ) + ) + tmp = sympy.Symbol(tmp_name, integer=True) + from torch._dynamo.source import ConstantSource + + self._dcp.symbol_to_source[tmp] = [ConstantSource(tmp_name)] + r = try_solve(sympy.Eq(base, divisor * tmp), s) + assert r is not None + self._dynamic_results.add(self._dcp.doprint(sympy.Eq(s, r[1]))) + + # remaining symbols have only pure inequalities (no equalities) + for s, exprs in self._univariate_inequalities.items(): + try: + solution = sympy.solvers.inequalities.reduce_inequalities(exprs, s) + # because this is univariate, the solution is a dynamic (range) constraint + if isinstance(solution, sympy.Or): + solution = next( + iter( + arg + for arg in solution.args + if arg.xreplace(self._var_to_val) + ) + ) + if isinstance(solution, sympy.And): + for arg in solution.args: + self._dynamic_results.add(self._dcp.doprint(arg)) + else: + self._dynamic_results.add(self._dcp.doprint(solution)) + except (NotImplementedError, AssertionError) as e: + log.warning("Failed to reduce inequalities: %s", e) + for expr2 in exprs: + self._dynamic_results.add(self._dcp.doprint(expr2)) + + # simplify symbolic equivalences: some of them will now become specializations! + symbolic_equivalences = self._symbolic_equivalences + self._symbolic_equivalences = [] + for source, expr3 in symbolic_equivalences: + self.add_equality(source, expr3.xreplace(self._substitutions)) + + # remaining symbolic equivalences become dynamic equality constraints + for source, expr3 in self._symbolic_equivalences: + self._dynamic_results.add(f"{source.name()} == {self._dcp.doprint(expr3)}") + + @classmethod + def _is_supported_congruence(cls, congruence: sympy.Expr) -> bool: + base, divisor = congruence.args + # Congruences that can be currently expressed with supported Dim ops are + # of the form (x + a) % b == 0, where x is a Dim and a and b are constants. + # This allows us to derive x as b*y - a for some Dim y. + # (See also documentation of dynamic_shapes._DerivedDim.) + if isinstance(base, sympy.Add): + lhs, rhs = base.args + cond = ( + isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Integer) + ) or (isinstance(lhs, sympy.Integer) and isinstance(rhs, sympy.Symbol)) + else: + cond = isinstance(base, sympy.Symbol) + cond = cond and isinstance(divisor, sympy.Integer) + return cond + + def forced_specializations(self) -> dict[str, sympy.Expr]: + """Returns a dictionary of the names of symbols to their specialized value""" + + def debug_name(src: Source) -> str: + name = src.name() + if self._dcp.source_name_to_debug_name: + return f"{self._dcp.source_name_to_debug_name[name]} = {name}" + else: + return name + + return { + debug_name(self._dcp.symbol_to_source[s][0]): val + for s, val in self._substitutions.items() + if s in self._marked_dynamic + } + + def _is_derived_dim( + self, dim: object + ) -> TypeGuard[torch.export.dynamic_shapes._DerivedDim]: + return isinstance(dim, torch.export.dynamic_shapes._DerivedDim) + + def _is_dim(self, dim: object) -> TypeGuard[torch.export.dynamic_shapes.Dim]: + return isinstance(dim, torch.export.dynamic_shapes.Dim) and not isinstance( + dim, torch.export.dynamic_shapes._DerivedDim + ) + + def _process_derived_dim_roots( + self, + results: dict[str, dict[str, Any]], + name_to_dim: dict[str, Any], + ) -> None: + """ + Here we resolve 2 concerns with derived dims suggested fixes: 1) newly introduced roots, + and 2) root swapping. + + 1) Newly introduced roots appear with modulo guards, e.g. Mod(dx, 2) = 0 suggests + dx is a derived dim equal to 2 * _dx, introducing a new root _dx. Currently the final + suggested fixes handle this correctly, but we can get intermediate results that look like + {"dy": {"eq": "dx + 1"}, "dx": {"eq": "2 * _dx + 1, "min": 3, "max": 15}} + and this routine prettifies this by unifying to a single root, and making each suggestion + either a derived dim or min/max range, not both. + + 2) With suggested fixes for derived dims, roots can be swapped, + e.g. dx, dx - 1 -> dy + 1, dy. Here we don't want to print out the attached name, + since this leads to messages like "dx - 1 = Dim("dx - 1", ...)". + Instead we evaluate the new root value, and remove results for its derivations. + + First we find all the original roots (specified in dynamic_shapes), that are found in the + values of results (i.e. used for computing suggesting fix values). These original roots + (suppose `dx`) are either specialized, unchanged, refined, or swapped + (expressed as a derived dim). If any of the first 3 cases happen, we suggest `dx`'s value + in results, and remove suggestions for derivations of `dx`, assuming the derived relation + is valid. If swapped, we find the new root, and use the fix to evaluate `dx`'s new value, + and then do the same with `dx`'s derivations. + + Assuming the originally specified derived relations are correct is valid, because: + 1) if the relations are plain wrong (e.g. input shape = (6, 4) with spec (dx, dx - 1)) + produce_guards() will catch this and crash before hand. + 2) if the relations are numerically correct but do not match the emitted guard, + for example: + + def forward(self, x, y): + return x.reshape([-1]) + y # guard: s0 * 2 = s1 + inputs = (torch.randn(6, 2), torch.randn(12)) + dx = Dim("dx", min=2, max=32) + dynamic_shapes={"x": (dx, 2), "y": (dx + 6, )} # this matches values but not op + + then this leads to 2 linear equations, and a) produce_guards() is able to solve for + the unique solution of dx = 6 and specialize, and b) the export constraint solver will + raise an issue due to range constraints (a unique solution means not all values in a + range satisfy a guard) and also force specializations. + """ + from torch.export.dynamic_shapes import Dim + + def _check_same_range(c: Mapping[str, int], dim: object) -> bool: + # returns True if c & dim are both min/max ranges with same values + return ( + self._is_dim(dim) + and ("min" in c or "max" in c) + and ( + (dim.min < 2 and c.get("min", 2) == 2) or dim.min == c.get("min", 2) # type: ignore[attr-defined] + ) # let pass if analysis min = 2 and specified min = 0/1 + and dim.max == c.get("max", int_oo) # type: ignore[attr-defined] + ) + + # 1) newly introduced roots + # this part we handle adding newly introduced roots + # these arise from guards like "x.shape[0] % 3 == 0" + # leading to suggested fixes like "dx = 3*_dx" + # extract _dx, and find appropriate min/max values + # + # before, we have something like: + # {"dx": {"eq": 3*_dx+1, "min": 4, "max": 10}, "dy": dx+1, "dz": dx+2} + # we want instead: + # {"_dx": {"min": 1, "max": 4}, "dx": 3*_dx+1, "dy": 3*_dx+2, "dz": 3*_dx+3} + introduced_roots: dict[str, str] = {} # map new root -> old root + for k, c in list(results.items()): + if "eq" in c and isinstance(c["eq"], sympy.Expr): # derived dim + root = next(iter(c["eq"].free_symbols)) + if str(root) not in name_to_dim: + introduced_roots[str(root)] = k + # calculate necessary min & max + modulus, remainder = sympy.polys.polytools.div(c["eq"], root) + c_min = c.get("min", 2) + min_ = math.ceil((c_min - remainder) / modulus) + c_max = c.get("max", int_oo) + max_ = math.floor((c_max - remainder) / modulus) + # create result & dim + results[str(root)] = {"min": min_, "max": max_} + name_to_dim[str(root)] = Dim(str(root), min=min_, max=max_) + # remove old root min/max bounds + c.pop("min", None) + c.pop("max", None) + + # alter derivations that depend on old root, to unify to new root + # e.g. dx=3*_dx+1, dy=dx+1 -> dy=3*_dx+2 + for old_root in introduced_roots.values(): + for k, c in list(results.items()): + if ( + "eq" in c + and isinstance(c["eq"], sympy.Expr) + and str(symbol := next(iter(c["eq"].free_symbols))) == old_root + ): # derived dim with root = old_root + new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1 + new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1 + c["eq"] = new_expr + + # 2) root swapping + # collect all the original roots that are used for calculating values of suggested fixes + # this consists of: + # 1) {"dx": {"min": ..., "max": ...}} -> dx: refined root dim + # 2) {"dy": "dx + 1"} -> dx: root for suggested fix + modified_roots: set[str] = set() + for k, c in results.items(): + if k not in name_to_dim: # _dynamo.export() may handle source directly + continue + if self._is_dim(name_to_dim[k]) and ("min" in c or "max" in c): # case 1) + modified_roots.add(k) + elif "eq" in c and isinstance(c["eq"], sympy.Expr): # case 2) + root = next(iter(c["eq"].free_symbols)) + assert root is not None + modified_roots.add(str(root)) + + # exclude newly introduced roots, we've already processed these + modified_roots = modified_roots.difference(introduced_roots) + + # evaluate the new value for each root + # this is now either 1) unchanged, 2) refined with a new range, + # or 3) specialized to a concrete value + modified_root_values: dict[str, dict[str, Any]] = {} + for mroot in modified_roots: + swapped_root = True + if mroot in results: + c = results[mroot] + if ("min" in c or "max" in c) or isinstance( # range + c["eq"], int + ): # specialized + # here, the original root is a root Dim or concrete value in results. + # if it is a derived dim, it is swapped, and we handle that below. + if not _check_same_range( + c, name_to_dim[mroot] + ): # ignore if unchanged + modified_root_values[mroot] = c + swapped_root = False + + if swapped_root: + # if the original root has been swapped in results, that means the new root + # is a range (if it had specialized, the original root would have too). + # find this new root, and solve for the original root's range. + for k, c in results.items(): + if k not in name_to_dim: + continue + dim = name_to_dim[k] + if ( + dim.__class__.__name__ == "_DerivedDim" + and dim.root.__name__ == mroot + ): + # only look for min/max root, otherwise root would have specialized + if "min" in c or "max" in c: + expr = sympy.sympify(k) + s = next(iter(expr.free_symbols)) + result = { + "min": try_solve(sympy.Eq(expr, c["min"]), s)[1], # type: ignore[arg-type, index] + "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index] + } + if not _check_same_range( + result, + name_to_dim[mroot], # type: ignore[index, arg-type] + ): # ignore if unchanged + modified_root_values[mroot] = result # type: ignore[index] + break + + # filter out results where the key is a derived dim (e.g. {"dx - 1" : 4}) + # we only want to suggest fixes for the root, to avoid derived names. + # also, remove anything in modified_roots, since we either add new modified values after this, + # or have decided they are unchanged. + for k in list(results.keys()): + if k not in name_to_dim: + continue + if self._is_derived_dim(name_to_dim[k]) or k in modified_roots: + del results[k] + + # update results with modified root values + # now results has the following properties: + # - only contains original roots as keys + # - each root is now either specialized, refined, or derived from another original root + results.update(modified_root_values) + + def prettify_results( + self, + original_signature: inspect.Signature, + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any]], + constraint_violation_error: object, + forced_specializations: dict[str, str], + ) -> str: + """Format a message for constraint violation erros""" + from torch.export.dynamic_shapes import _get_dim_name_mapping + + if not self._dcp.source_name_to_debug_name: + # nothing to do + return "" + + def transform(s: str, inverse: bool = False) -> str: + for k, v in self._dcp.source_name_to_debug_name.items(): + s = s.replace(k, v) if not inverse else s.replace(v, k) + return s + + results: defaultdict[str, dict[str, Any]] = defaultdict(dict) + if dynamic_shapes is None: + dynamic_shapes = {} + + def flip(op: str) -> str: + if op == "<=": + return ">=" + if op == ">=": + return "<=" + if op == "<": + return ">" + if op == ">": + return "<" + assert op == "==" + return op + + def relation_with_digit(expr: str, op: str, digit: int) -> None: + if op == "<=": + results[expr]["max"] = digit + elif op == "<": + results[expr]["max"] = digit - 1 + elif op == ">=": + results[expr]["min"] = digit + elif op == ">": + results[expr]["min"] = digit + 1 + else: + assert op == "==" + results[expr]["eq"] = digit + + # retrieve dynamic shapes + name_to_dim = _get_dim_name_mapping(dynamic_shapes) + + for s in self._static_results.union(self._dynamic_results): + t = transform(s) + if t == s: + continue + left, op, right = re.split(r"( == | <= | >= | < | > )", t) + op = op.strip() + if op == "==" and left == right: + continue + if right.isdigit(): + relation_with_digit(left, op, int(right)) + elif left.isdigit(): + relation_with_digit(right, flip(op), int(left)) + else: + assert op == "==", t + try: + results[left]["eq"] = sympy.sympify(right) + except TypeError: # rhs source is not linked to Dim name + pass + + # order forced specializations based on name + forced_specializations = { + k: forced_specializations[k] + for k in sorted( + forced_specializations.keys(), + key=lambda x: x.split(" = ")[1], + ) + } + + buf = "" + if forced_specializations: + debug_names = set() + for k in forced_specializations: + dim = name_to_dim[k.split(" = ")[0]] + if self._is_derived_dim(dim): + debug_names.add(dim.root.__name__) # type: ignore[attr-defined] + else: + debug_names.add(dim.__name__) + + buf += ( + f"Specializations unexpectedly required ({', '.join(sorted(debug_names))})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + ) + for s, val in forced_specializations.items(): + buf += f" - solving the guards generated for {s} resulted in a specialized value of {val}.\n" + + self._process_derived_dim_roots(results, name_to_dim) + + dims = [] + others = [] + + # order results by source name + results2 = { + k: results[k] + for k in sorted( + results.keys(), + key=lambda x: transform(x, inverse=True), + ) + } + for k, c in results2.items(): + if "eq" in c: + other = c["eq"] + if isinstance(other, int): + others.append(f"{k} = {other}") + elif _is_supported_equivalence(other): + others.append(f"{k} = {other}") + else: + min_ = c.get("min", None) + if min_ == 2: + min_ = None + max_ = c.get("max", None) + if min_ is not None and max_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_}, max={max_})") + elif min_ is not None: + dims.append(f"{k} = Dim('{k}', min={min_})") + elif max_ is not None: + dims.append(f"{k} = Dim('{k}', max={max_})") + else: + dims.append(f"{k} = Dim('{k}')") + + # results2 will get filtered out if no new suggestions, + # this can happen if guards are too complex. + # in that case don't suggest fix + if dims or others: + buf += "\nSuggested fixes:\n " + buf += "\n ".join(dims + others) + + return buf + + +TLS = threading.local() + + +@dataclass(frozen=True) +class ShapeEnvSettings: + """ + Encapsulates all shape env settings that could potentially affect + FakeTensor dispatch. Used when creating dispatch cache keys. + """ + + allow_scalar_outputs: bool + allow_dynamic_output_shape_ops: bool + assume_static_by_default: bool + specialize_zero_one: bool + duck_shape: bool + prefer_deferred_runtime_asserts_over_guards: bool + allow_complex_guards_as_runtime_asserts: bool + trace_asserts: bool + + +@dataclass +class ValueRangesSLoc: + """ + Locations of the guards that triggered lower and upper bound. + """ + + lower: SLoc + upper: SLoc + + +@contextmanager +def _suppress_guards(shape_env: ShapeEnv) -> Iterator[None]: + shape_env._suppress_guards_enter() + try: + yield + finally: + shape_env._suppress_guards_exit() + + +@dataclass +class _FrameLocalResult: + loc: Optional[str] = None + locals: dict[str, Any] = field(default_factory=dict) + symbols: dict[str, str] = field(default_factory=dict) + + +class ShapeEnv: + # This is a wrapper over the actual __init__ function. + # + # Where to add a new constructor parameter to ShapeEnv? + # ===================================================== + # This __init__ function should be used only for parameters related to event recording. + # These are parameters that we don't wish to pass down the road to new ShapeEnv instances + # created from replaying events. + # + # If you wish to add a parameter to the constructor of ShapeEnv, unrelated to event + # recording, do so in the _init function. + def __init__( + self, + *, + should_record_events: Optional[bool] = None, + tracked_fakes: Optional[list[Any]] = None, + **kwargs: Any, + ) -> None: + self._init(**kwargs) + + # Disable event recording when replaying. + kwargs["should_record_events"] = False + + from torch.fx.experimental.validator import translation_validation_enabled + + self._translation_validation_enabled = translation_validation_enabled() + + # If not specified, enable event recording if both: + # - Translation validation is on + # - Translation validation bisection is not disabled + self.should_record_events = ( + should_record_events + if should_record_events is not None + else ( + self._translation_validation_enabled + and not config.translation_validation_no_bisect + ) + ) + + # Enable event recording check if both: + # - It should record events + # - The recording check is enabled + self.check_recorded_events = ( + self.should_record_events and config.check_shape_env_recorded_events + ) + + # This will make sure we only record the top-level function call. + self.is_recording = False + # Keep track of the list of tracked fakes. + self.tracked_fakes = tracked_fakes + # List of events for reconstructing ShapeEnv at arbitrary points in time. + self.events: list[ShapeEnvEvent] = ( + [ShapeEnvEvent(ShapeEnv, kwargs=kwargs)] + if self.should_record_events + else [] + ) + + # FakeTensor per-ShapeEnv operation cache. This is used for caching + # operations that contain symbolic shapes which have guards on the + # ShapeEnv (so are ShapeEnv-dependent). + # + # NOTE: It's important that SymNodes in this cache have their ShapeEnv + # stripped otherwise you end up with cycles which can only be cleaned + # with the GC. + self.fake_tensor_cache: dict[ + torch._subclasses.fake_tensor._DispatchCacheKey, + torch._subclasses.fake_tensor._DispatchCacheEntry, + ] = {} + + # Pro-tip: if you add new field to ShapeEnv, this affects some accept + # tests. Accept their output with: + # + # EXPECTTEST_ACCEPT=1 python test/dynamo/test_dynamic_shapes.py -k test_shape_env_equal + # + def _init( + self, + *, + allow_scalar_outputs: bool = True, + allow_dynamic_output_shape_ops: bool = True, + # NB: These are legacy configuration that help us make good choices + # when the constraint/dynamic dims are not explicitly passed to us. + # Ideally we will fix all call sites to be explicit and not have + # implicit choices, but this apparently was pretty involved. + assume_static_by_default: bool = False, + # Note - On 0/1 specialization + # + # The following options affect decisions we make about eager + # specialization. Disabling them will increase trace time (as we do + # more symbolic reasoning) and can also harm the quality of generated + # code (because inductor may not be able to specialize for bounds + # being equal--although if we later respecialize because of a guard, + # your code may be just as good as it was before.) + # + # When True, eagerly specialize input sizes which have 0/1. + specialize_zero_one: bool = True, + # When True, assume input sizes which have the same size are + # symbolically equal. + duck_shape: Optional[bool] = None, + # For debugging + co_fields: Optional[dict[str, str]] = None, + # When True, whenever safe, we will generate a deferred runtime assert + # instead of a guard whenever we know that an expression must be True, + # otherwise it would be an error, even for backed SymInts (where we + # could ostensibly unconditionally generate guards). This is useful + # for export, where preventing "error checking" sizes from showing up + # in guards is helpful, since these guards in some sense are overly + # pedantic. See also https://github.com/pytorch/pytorch/issues/121749 + prefer_deferred_runtime_asserts_over_guards: bool = False, + # When True, does not emit or raise constraint violation errors on + # implicit guards generated by ops, and defers to runtime assertions + # in the graph instead. For export. + allow_complex_guards_as_runtime_asserts: bool = False, + # XXX Add any new settings that could affect FakeTensor evaluation + # to: torch._subclasses.fake_tensor._ShapeEnvSettings + trace_asserts: bool = False, + ) -> None: + if duck_shape is None: + duck_shape = config.use_duck_shape + + self.settings = ShapeEnvSettings( + # Not directly used by ShapeEnv; indirectly used by FakeTensor + allow_scalar_outputs=allow_scalar_outputs, + allow_dynamic_output_shape_ops=allow_dynamic_output_shape_ops, + # End + assume_static_by_default=assume_static_by_default, + specialize_zero_one=specialize_zero_one, + duck_shape=duck_shape, + prefer_deferred_runtime_asserts_over_guards=prefer_deferred_runtime_asserts_over_guards, + allow_complex_guards_as_runtime_asserts=allow_complex_guards_as_runtime_asserts, + trace_asserts=trace_asserts, + ) + + self.guards: list[ShapeGuard] = [] + self.axioms: dict[sympy.Expr, sympy.Expr] = {} + + # A set of ids that have already been allocated. This is used + # for when we allocate symbol ids using the hash of the source + # names to ensure we don't have collisions via linear probing + self.unique_ids: set[int] = set() + # Maps symbolic ints to their original concrete values + # Currently populated from tensors + self.var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Like var_to_val, but only set when propagate_real_tensors is on. + # Used as last resort to avoid GuardOnDataDependent error + self.unbacked_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Like above, but used exclusively for OBLIVIOUS_SIZE. These + # potentially could be put together but I am not sure, writing out + # the logic individually before abstracting. + self.oblivious_var_to_val: dict[sympy.Symbol, sympy.Integer] = {} + # Maps symbolic ints to their min/max range. These ranges + # are conservative: the int MUST fall in the range, but the + # range may contain ints which may not actually appear in + # practice + self.var_to_range: dict[sympy.Symbol, ValueRanges] = {} + self.var_to_range_sloc: dict[sympy.Symbol, ValueRangesSLoc] = {} + self.source_name_to_debug_name: dict[str, str] = {} + self.var_to_sources: dict[sympy.Symbol, list[Source]] = {} + self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {} + # Maps a source to the *original* symbol that was assigned to it + self.source_to_var: dict[str, sympy.Symbol] = {} + # Maps from sympy ints to expressions representing them + # Populated from equality guards (i.e. a.shape[0] == b.shape[0]) + self.replacements: dict[sympy.Symbol, sympy.Expr] = {} + # The sloc of the guard that triggered this replacement to be added + self.replacements_slocs: dict[sympy.Symbol, SLoc] = {} + self.unbacked_renamings: dict[sympy.Symbol, sympy.Symbol] = {} + # Set holds a % b expressions that evaluate to 0. + self.divisible: set[sympy.Expr] = set() + # Set that holds "size-like" symbols. When we perform + # "size-oblivious" tests, these can be assumed to be >= 2. + self.size_like: set[sympy.Symbol] = set() + # Duck-shaping says that if two input tensors have the same size, + # they get assigned the same symbolic variable + self.val_to_var: dict[int, sympy.Symbol] = {} + self.unbacked_symfloat_counter = itertools.count() + self.unbacked_symint_counter = itertools.count() + # Similar to guards, but these MUST evaluate to true and can + # only be evaluated at runtime midway through (i.e., they always + # involve unbacked symints) + # + # For efficiency reasons, we index in the following way. Suppose you have + # a runtime assert i0 + i1 <= s1. We pick the most recently allocated + # symbol in the source expression and add the assert to the list for + # that symbol e.g., {i1: [i0 + i1 <= s1]}. + # + # We access the runtime asserts in two situations: + # + # - When we are guarding on an expression, we will attempt to + # statically evaluate it, in case the unbacked SymInts can + # simplify away. If we have a runtime assert, we may be able + # to discharge the guard entirely. We only need to attempt + # runtime asserts that mention freevars of the expression in + # question. + # + # - When we are performing codegen (in Inductor for eager, or + # when finalizing the export FX graph), we need to know what + # extra runtime asserts to insert. Whenever an unbacked + # SymInt comes into scope, all runtime asserts involving it + # become eligible for insertion (so long as all of their other + # free unbacked symbols are also in scope). We technically + # can handle any choice of key by kicking inexpressible asserts + # to the next unbacked symbol to wait on, but if we choose the + # latest key, an assert will only show up at the moment when + # we can actually codegen it. + self.deferred_runtime_asserts: dict[ + Optional[sympy.Symbol], list[RuntimeAssert] + ] = {} + # This exists so we can efficiently invalidate the cache (it's used as + # part of the cache key); otherwise we'd have to iterate through + # deferred_runtime_asserts to compute its length + self.num_deferred_runtime_asserts = 0 + self.log = log + self.log.info("create_env") + self.frozen = False + self.runtime_asserts_frozen = False + self.dim_constraints: Optional[DimConstraints] = None + self.counter: Counter[str] = collections.Counter() + # Mapping from sympy.Symbol to the number of guards which mention this + # symbol + self.symbol_guard_counter: Counter[sympy.Symbol] = collections.Counter() + # A selection of important fields on co_field; solely used for + # signpost_event + self.co_fields = co_fields if co_fields else {} + + # Whenever we allocate a fresh unbacked Symbol, we add it to this + # pending list. Unbacked symbol allocation can occur at unpredictable + # points during meta tensor propagation, but at some point, we + # have to know what the binding site for an unbacked symbol is, and + # this is computed when we actually place the node in the graph. The + # important thing is that we always actually handle every unaccounted + # for unbacked symbol, so this list helps us keep track of them and + # then make sure they are all accounted for. + # + # We could potentially give rise to errors earlier by lexically + # scoping when we do propagation, and only allowing unbacked symbols + # to be allocated at this point in time. However this is inconvenient + # to do in Dynamo, because fake tensor propagation is far from when we + # analyze binding sites (set_example_value), so we do it in a more + # mutatey way. + # + # NB: fresh unbacked symbols NEVER get substitutions applied to them, + # they are binding sites! + self.pending_fresh_unbacked_symbols: list[sympy.Symbol] = [] + + # Version counter used to invalidate cached values + self._prev_cache_key = self._get_key() + self._version_counter = 0 + + # Each time divisible is changed this should be set to True, this is set in _update_version_counter. + self._resimplify_floor_div_axioms = True + + # Cache for FX nodes. + # Maps an already built node a tuple of: + # 1. node's target + # 2. list of arguments + # This drastically reduces the size of the FX graph, avoiding + # duplicated nodes. + self.fx_node_cache: dict[tuple[Callable, tuple[Any, ...]], torch.fx.Node] = {} + self.source_to_symbol: dict[str, sympy.Symbol] = {} + + # Suppose you want to replace an unbacked symbol with another + # unbacked symbol. This is error prone because you can cause + # references to unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. + # + # To control for this, we track the order unbacked symbols + # were allocated, and only allow substitutions if they respect + # the dependency from this order; an unbacked symbol can only + # be substituted with unbacked symbols that come before it in the + # order. + # + # This also imposes an ordering on the unbacked symbol binding + # sites themselves: you are not allowed to reorder unbacked symbol + # bindings. At the moment, this is not tracked, but we potentially + # could track this at the IR level using a higher order operator + # with something like effect token tracking. + self.unbacked_alloc_order: dict[sympy.Symbol, int] = {} + + self.user_specialization_stacks: dict[Source, traceback.StackSummary] = {} + self.framework_specialization_stacks: dict[Source, traceback.StackSummary] = {} + + self.trace_asserts = trace_asserts + + self.specializations: OrderedSet[Specialization] = OrderedSet() + + from torch.fx.experimental.validator import translation_validation_enabled + + self._translation_validation_enabled = translation_validation_enabled() + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import TranslationValidator + + self.validator = TranslationValidator() + self.graph = torch.fx.Graph() + # Create an output graph and start inserting before that. + # This is needed when 'deepcopy'-ing this object. + self.graph.inserting_before(self.graph.output(None)) + + # Mapping of each node name to the node itself. + # + # This is useful for matching an FX node from a recorded ShapeEnv.graph + # to the FX node of the ShapeEnv we are running the event on. + # + # Whenever you add a node to self.graph, you must add a mapping to this + # variable. Otherwise, the built FX graph on the replayed ShapeEnv will + # not be valid. + self.name_to_node: dict[str, torch.fx.Node] = {} + + @property + def allow_scalar_outputs(self) -> bool: + return self.settings.allow_scalar_outputs + + @property + def allow_dynamic_output_shape_ops(self) -> bool: + return self.settings.allow_dynamic_output_shape_ops + + @property + def assume_static_by_default(self) -> bool: + return self.settings.assume_static_by_default + + @property + def specialize_zero_one(self) -> bool: + return self.settings.specialize_zero_one + + @property + def duck_shape(self) -> bool: + return self.settings.duck_shape + + @property + def prefer_deferred_runtime_asserts_over_guards(self) -> bool: + return self.settings.prefer_deferred_runtime_asserts_over_guards + + @property + def allow_complex_guards_as_runtime_asserts(self) -> bool: + return self.settings.allow_complex_guards_as_runtime_asserts + + @contextmanager + def patch_source_specialization( + self, source: Source, check_fn: Callable[[sympy.Symbol], sympy.Expr] + ) -> Iterator[None]: + """ + Temporarily add symbol-level axioms to the ShapeEnv. This is useful when you want to "fork" + and have parallel universes of ShapeEnvs. For example, we use this when doing multi-graph + compile so we can support various graphs with varying levels of specializations. + + This context manager allows for temporarily adding constraints to the shape environment + based on a specialization function applied to a symbol associated with a source. + + Args: + source: The source of the symbol to specialize + check_fn: A function that takes a sympy Symbol and returns a sympy expression + representing a constraint/specialization to be applied + """ + name = source.name() + sym = self.source_to_var[name] + expr = check_fn(SymInt(SymNode(sym, self, int, None))).node._expr + new_axioms = dict(self.get_implications(self.simplify(expr))) + added_replacements = {} + + for axiom in new_axioms: + if ( + isinstance(axiom, sympy.Eq) + and isinstance(axiom.lhs, sympy.Symbol) + and isinstance(axiom.rhs, sympy.Integer) + and axiom.lhs not in self.replacements + ): + self.replacements[axiom.lhs] = axiom.rhs + added_replacements[axiom.lhs] = axiom.rhs + self.axioms.update(new_axioms) + + # We need to freeze the ShapeEnv becuase any additional modification of + # the ShapeEnv will cause unsoundness for subsequent specialization calls. + self.frozen = True + try: + yield + finally: + for k in new_axioms: + self.axioms.pop(k, None) + for k in added_replacements: + self.replacements.pop(k, None) + self.frozen = False + + def check_equal(self, other: ShapeEnv) -> None: + """Compare another ShapeEnv for equivalence""" + # ShapeEnv fields that are not relevant for the outcome of + # ShapeEnv.produce_guards call: + # - Debugging variables + # - Translation validation related variables + # - Events recording related variables + non_state_variable_names = ( + "counter", + "log", + "var_to_stack", + "fx_node_cache", + "graph", + "validator", + "check_recorded_events", + "should_record_events", + "is_recording", + "tracked_fakes", + "events", + "source_name_to_debug_name", + "_prev_cache_key", + "_version_counter", + "dim_constraints", + # source locations are OK to diverge + "var_to_range_sloc", + "replacements_slocs", + "_resimplify_floor_div_axioms", + "_expr_sym_node_id", + "user_specialization_stacks", + "framework_specialization_stacks", + ) + + # Mapping of the value of each to-be-compared field into the values that + # should actually be compared. + # + # You should modify this if, for example, the field that holds state and + # debugging information. e.g. ShapeGuard holds the actual guard (sympy.Expr) + # and the stack when it was added to the set of guards. In order to compare + # it, we throw away the stack information. + def map_value(key: str, value: Any) -> Any: + if key in ("unbacked_symfloat_counter", "unbacked_symint_counter"): + from copy import copy + + # For itertools.count(), we compare the next integer returned + # by the count iterators. Not that we need to copy the iterator + # first. Otherwise we are mutating the object. + return next(copy(value)) + elif key == "guards": + # Transform the list of ShapeGuard into a list of expressions. + return [g.expr for g in value] + elif key == "deferred_runtime_asserts": + # Transform the list of RuntimeAsserts into a list of expressions. + return {s: [ra.expr for ra in ras] for s, ras in value.items()} + elif key == "name_to_node": + # Compare just the set of keys is the same. + return set(value.keys()) + elif key in ( + "symbol_guard_counter", + "pending_fresh_unbacked_symbols", + "fake_tensor_cache", + ): + # Skip this for comparisons + return None + return value + + shape_env_check_state_equal(self, other, non_state_variable_names, map_value) + + def _snapshot_tracked_fakes(self) -> Optional[list[Any]]: + if self.tracked_fakes is None: + return None + + from torch._dynamo.variables.builder import TrackedFake + + def maybe_transform_fake(fake: TrackedFake) -> TrackedFake: + inner_fake = ( + fake.fake + if isinstance(fake.fake, (torch.SymInt, torch.SymFloat)) + else FakeTensorMeta.from_fake(fake.fake) + ) + # Even though TrackedFake accepts either a Union[SymInt, FakeTensor], here we give it a + # FakeTensorMeta for two reasons: + # 1. this is all the information we need when recording ShapeEnvEvents. + # 2. it works even if each TrackedFake changes its metadata. + return TrackedFake(inner_fake, fake.source, fake.symbolic_context) # type: ignore[arg-type] + + return [maybe_transform_fake(fake) for fake in self.tracked_fakes] + + def _last_event_index(self) -> int: + return len(self.events) - 1 + + @contextmanager + def _recording(self) -> Iterator[None]: + self.is_recording = True + try: + yield + finally: + self.is_recording = False + + @record_shapeenv_event() + def _eliminate_unbacked(self, orig_s: sympy.Symbol, new_s: sympy.Expr) -> None: + self._set_replacement(orig_s, new_s, "eliminate_unbacked") + + @record_shapeenv_event() + def set_unbacked_var_to_val(self, k: sympy.Symbol, v: int) -> None: + """Used only when propagate_real_tensors; registers a value for an + unbacked symbol, which can be used last resort to resolve hints.""" + log.info("set_unbacked_var_to_val %s = %s", k, v) + self.unbacked_var_to_val[k] = sympy.sympify(v) + + # Unlike set_replacement, this records a shapeenv event + @record_shapeenv_event() + def _rename_unbacked_to(self, orig_s: sympy.Symbol, new_s: sympy.Symbol) -> None: + assert isinstance(orig_s, sympy.Symbol), orig_s + assert isinstance(new_s, sympy.Symbol), new_s + assert free_unbacked_symbols(new_s), new_s + assert free_unbacked_symbols(orig_s), orig_s + dest = self.replacements.get(orig_s) + if dest is not None: + assert not free_unbacked_symbols(dest), f"{orig_s} -> {dest}" + self._set_replacement(orig_s, new_s, "rename_unbacked_to") + self.unbacked_renamings[orig_s] = new_s + if dest is not None: + self._set_replacement(new_s, dest, "rename_unbacked_to_dest") + + @record_shapeenv_event() + def _constrain_is_bounded(self, a: sympy.Symbol, upper_bound: int) -> None: + # TODO: Do something nontrivial when upper_bound is expression + pass + + @record_shapeenv_event() + def _constrain_range_for_size( + self, a: sympy.Symbol, min: Optional[int] = None, max: Optional[int] = None + ) -> None: + if min is None: + min = 0 + if max is None: + max = int_oo + + if max < min: + raise ValueError( + "Maximum value to constrain_as_size can't be less than the specified min value, " + "received min={min} and max={max}" + ) + + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + self.size_like.add(a) + + @record_shapeenv_event() + def _constrain_range(self, a: sympy.Expr, min: int, max: int) -> None: + if isinstance(a, sympy.Integer): + if not (min <= int(a) <= max): + raise ValueRangeError(f"Invalid value {int(a)} for range [{min}:{max}]") + return + + # TODO: Shouldn't we install a guard if the symbol is backed? Or is the + # semantics that this is an "unchecked" assert (but it this actually + # something useful? Might be better to restrict only for unbacked + # SymInt). + if isinstance(a, sympy.Symbol): + self.constrain_symbol_range( + a, + compiler_min=min, + compiler_max=max, + ) + + @record_shapeenv_event() + def _constrain_unify(self, a: SymInt, b: SymInt) -> None: + """ + Given two SymInts, constrain them so that they must be equal. NB: + this will not work with SymInts that represent nontrivial expressions + (yet!) + """ + # TODO: this does not install a deferred runtime assert yet + + # TODO: Maybe dedupe this with _maybe_guard_rel? + # Update Feb 2024: this is extra important to do, this doesn't handle + # unbacked replacements properly nor does it generate deferred runtime + # asserts + if not isinstance(a, SymInt): + if not isinstance(b, SymInt): + assert a == b + else: + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) + assert b.node.shape_env is self + self.replacements[b.node.expr] = sympy.Integer(a) + else: + # TODO: Actually, we can support this as long as one of them is a symbol. + # NB: We can't actually do "unification" as our operators are not + # injective + assert isinstance(a.node.expr, sympy.Symbol), "constraining non-Symbols NYI" + assert a.node.shape_env is self + if not isinstance(b, SymInt): + self.replacements[a.node.expr] = sympy.Integer(b) + else: + assert a.node.shape_env is b.node.shape_env + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) + new_var = self._find(a.node.expr) + self.replacements[b.node.expr] = new_var + + def _ignore_fresh_unbacked_symbols_tls(self) -> bool: + return getattr(TLS, "ignore_fresh_unbacked_symbols", False) + + @record_shapeenv_event() + def _ignore_fresh_unbacked_symbols_set(self, b: bool) -> bool: + prev = self._ignore_fresh_unbacked_symbols_tls() + TLS.ignore_fresh_unbacked_symbols = b + return prev + + @contextmanager + def ignore_fresh_unbacked_symbols(self) -> Iterator[None]: + """ + Indicates that the newly allocated unbacked SymInts are being + discarded + """ + prev = self._ignore_fresh_unbacked_symbols_set(True) + try: + yield + finally: + self._ignore_fresh_unbacked_symbols_set(prev) + + @record_shapeenv_event() + def freeze(self) -> None: + """Freeze this ShapeEnv to stop accumulating guards + + A frozen ShapeEnv will ignore any further guards generated on it and + only emit a warning which may lead to accuracy problems. + """ + self.frozen = True + + @record_shapeenv_event() + def freeze_runtime_asserts(self) -> None: + """Freeze this ShapeEnv to stop adding deferred runtime asserts. + + We will error if you try to install a new runtime assert when it is + frozen. This would indicate a lowering violation, or perhaps something + we know statically is already True but we are checking it again in a way + that is not clearly dischargeable. + """ + # self.prefer_deferred_runtime_asserts_over_guards = False + self.runtime_asserts_frozen = True + + def _create_symbol_for_source(self, source: Source) -> Optional[sympy.Symbol]: + if not self._translation_validation_enabled: + return None + srcname = source.name() + if source not in self.source_to_symbol: + self.source_to_symbol[srcname] = sympy.Symbol(srcname, integer=True) + return self.source_to_symbol[srcname] + + def _add_z3var(self, symbol: sympy.Symbol, type: type) -> None: + if self._translation_validation_enabled: + self.validator.add_var(symbol, type) + + def _add_target_expr(self, expr: SympyBoolean) -> None: + if self._translation_validation_enabled: + self.validator.add_target_expr(expr) + + def _add_assertion(self, expr: SympyBoolean) -> None: + if self._translation_validation_enabled: + self.validator.add_assertion(expr) + + def _check_translation_validate(self) -> None: + if self._translation_validation_enabled: + self.validator.validate() + + @record_shapeenv_event() + def _create_fx_call_function( + self, + op: Callable, + args: tuple, + ) -> tuple[Optional[torch.fx.Node], bool]: + # Cache this tuple in order to avoid duplicated nodes. + node_key = (op, args) + # Flags whether the returned node was cached or not. + fresh = False + + if self._translation_validation_enabled and node_key not in self.fx_node_cache: + # Presence of None in the arguments implies that we should ignore this operation. + if any(a is None for a in args): + # We check if we are not mixing SymNode that should not be ignored + # (fx_node is not None) with those that should (fx_node is None). + assert all(not isinstance(a, torch.fx.Node) for a in args) + return None, fresh + + fresh = True + + # If translation validation is enabled, all arguments must have its + # own FX node. + assert all(a is not None for a in args), ( + f"missing arg in FX graph ({op.__name__}): {args}" + ) + node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) + self.name_to_node[node.name] = node + + return self.fx_node_cache.get(node_key, None), fresh + + def _create_fx_placeholder_and_z3var( + self, + symbol: sympy.Symbol, + type: type, + ) -> Optional[torch.fx.Node]: + if not self._translation_validation_enabled: + return None + + node_key = (self.graph.placeholder, (symbol,)) + + # Check if we haven't added this symbol already. + # If so, skip the placeholder creation, as it + # generates invalid Python code. + if node_key not in self.fx_node_cache: + # Add a Z3 variable according to 'type'. + self._add_z3var(symbol, type) + # Create the FX placeholder out of a mangled name. + mangled_name = re.sub( + r"[^a-zA-Z0-9]", "_", re.sub(r"[()]", "", symbol.name) + ) + node = self.fx_node_cache[node_key] = self.graph.placeholder(mangled_name) + self.name_to_node[node.name] = node + # Attach the 'symbol' to the placeholder so that we can retrieve + # the Z3 variable later. + node.meta["symbol"] = symbol + + return self.fx_node_cache[node_key] + + def _remove_fx_node(self, node: Optional[torch.fx.Node]) -> None: + if self._translation_validation_enabled and node is not None: + self.name_to_node.pop(node.name) + self.graph.erase_node(node) + + def _add_fx_node_metadata(self, node: torch.fx.Node) -> None: + from torch._dynamo.utils import get_current_node + + if self.should_record_events: + node.meta[SHAPEENV_EVENT_KEY] = self._last_event_index() + node.meta[CURRENT_NODE_KEY] = get_current_node() + + @staticmethod + def _suppress_guards_tls() -> bool: + return getattr(TLS, "suppress_guards", False) + + @record_shapeenv_event() + def _suppress_guards_enter(self) -> None: + if not hasattr(TLS, "suppress_guards_stack"): + TLS.suppress_guards_stack = [] + old = self._suppress_guards_tls() + TLS.suppress_guards_stack.append(old) + TLS.suppress_guards = True + + @record_shapeenv_event() + def _suppress_guards_exit(self) -> None: + old = ( + TLS.suppress_guards_stack.pop() + if len(TLS.suppress_guards_stack) > 0 + else False + ) + TLS.suppress_guards = old + + def suppress_guards(self) -> _GeneratorContextManager[None]: + """Context manager to ignore all guards generated inside""" + return _suppress_guards(self) + + def _get_key(self) -> tuple[int, int, int, int]: + """ + Defines the current "state" of the guards we've accumulated in this ShapeEnv. + Determines when we need to invalidate our cache + """ + return ( + len(self.replacements), + len(self.divisible), + self.num_deferred_runtime_asserts, + len(self.unbacked_var_to_val), + ) + + def _update_version_counter(self) -> None: + # if the change to shape env effects self.divisible set + # _resimplify_floor_div_axioms. + # This is used to trigger a resimplication of FloorDiv to CleanDivs + # in implication inside the function resimplify_floor_div. + if len(self.divisible) != self._prev_cache_key[1]: + self._resimplify_floor_div_axioms = True + + # The shape environment is queried orders of magnitude more often than + # it is changed, so we summarise the cache key into a linearly + # increasing version counter which is cheaper to check in _lru_cache + + # Only update version counter if the state actually changed + cur_key = self._get_key() + + if self._prev_cache_key != cur_key: + self._prev_cache_key = cur_key + self._version_counter += 1 + + def _produce_dyn_sizes( + self, + ex_size: Sequence[IntLikeType], + source: Source, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + return self._produce_dyn_sizes_from_int_tuple( + tuple(ex_size), source, symbolic_context + ) + + def _produce_dyn_sizes_from_int_tuple( + self, + tensor_size: Sequence[IntLikeType], + source: Source, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + assert all(not is_symbolic(val) for val in tensor_size), ( + f"Expect size to be a plain tuple of ints but got {tensor_size}" + ) + from torch._dynamo.source import TensorProperty, TensorPropertySource + + _assert_symbol_context(symbolic_context) + dynamic_dims = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + constraint_dims = symbolic_context.constraint_sizes # type: ignore[attr-defined] + size = [] + for i, val in enumerate(tensor_size): + sym = self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.SIZE, i), + dynamic_dims[i], + constraint_dims[i], + do_not_specialize_zero_one=config.backed_size_oblivious, + symbolic_context=symbolic_context, + ) + if ( + isinstance(symbolic_context, StatelessSymbolicContext) + and symbolic_context.specialize_on + ): + for specialization in symbolic_context.specialize_on[i]: + self.specializations.add( + Specialization( + TensorPropertySource(source, TensorProperty.SIZE, i), + specialization, + ) + ) + if ( + config.backed_size_oblivious + and isinstance(sym, sympy.Symbol) # could be static + and symbol_is_type(sym, SymT.SIZE) + ): + self.size_like.add(sym) + size.append(sym) + return size + + def create_symbolic_sizes_strides_storage_offset( + self, + ex: torch.Tensor, + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: + """ + Returns a list of symbolic sizes and strides for the given tensor. + We try our best to express stride in terms of the sizes, so as to not + introduce new symbolic variables. + """ + + ex_size = tuple( + self._maybe_specialize_sym_int_with_hint(sz) for sz in ex.size() + ) + ex_stride = tuple( + self._maybe_specialize_sym_int_with_hint(sd) for sd in ex.stride() + ) + ex_storage_offset = self._maybe_specialize_sym_int_with_hint( + ex.storage_offset() + ) + + return self._create_symbolic_sizes_strides_storage_offset( + ex_size, + ex_stride, + ex_storage_offset, + [_is_dim_dynamic(ex, i) for i in range(ex.dim())], + source, + symbolic_context=symbolic_context, + ) + + # Dynamo may want to wrap FakeTensors with SymInt sizes up e.g. make_fx(opt_f(), tracing_mode="symbolic"). + # We create symbols in shape_env using the backed hints behind SymInt. + + # Case 1: when SymInt is backed, dynamo can proceed with FakeTensors that have concrete shape. + # produce_guards will trigger specializations on the outer stuff + + # Case 2: when the SymInt is unbacked, we will throw an data dependent error in require_hint(). + # + # It's probably good for now but it's important to note that this approach has implications for + # the original shape_env when checking guards in different order. + + # Example: + # --------- + # Consider a function "opt_f" as shown below: + + # @torch.compile() + # def opt_f(x: bool, y: Tensor): + # if x == True: + # return y + torch.randn([4]) + # else: + # return y + # Depending on the sequence of calls, we might install two different sets of guards: + + # 1. opt_f(False, y): + # - "x == False" (always works for any size y) + + # 2. opt_f(True, y): + # - Triggers recompilation and results in guards like: + # - "x == True and y.size(0) == 4" + # - (or "y.size(0) == 4 and x == True") + + # The order of checking the guards matters. In this specific example: + # If True branch guard check precedes False branch and for True branch, y.size(0) check precedes x == True, + # we may have an unnessary shape speciliazation for y. + def _maybe_specialize_sym_int_with_hint( + self, maybe_sym: IntLikeType + ) -> IntLikeType: + assert isinstance(maybe_sym, (int, torch.SymInt)) + if is_symbolic(maybe_sym): + assert maybe_sym.node.shape_env is not self, ( + "expect the symbol is created from an shape env other than current one." + ) + return maybe_sym.node.require_hint() + return maybe_sym + + @record_shapeenv_event() + def _create_symbolic_sizes_strides_storage_offset( + self, + # NB: SymInt is allowed here due to nested int, normally you don't + # actually pass true symbolic sizes to this function + ex_size: Sequence[IntLikeType], + ex_stride: Sequence[IntLikeType], + ex_storage_offset: IntLikeType, + is_dim_dynamic: Sequence[bool], + source: Source, + *, + symbolic_context: Optional[SymbolicContext] = None, + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: + dim = len(ex_size) + + # Reimplement the legacy behavior + if symbolic_context is None: + constraint_sizes: list[DimConstraint] = [None] * dim + constraint_strides: list[DimConstraint] = [None] * dim + dynamic_dims = [] + dynamic_strides = [] + for i in range(dim): + # NB: This is encapsulation breaking! Legacy behavior was + # bad. + if is_dim_dynamic[i]: + r = DimDynamic.DYNAMIC + elif self.assume_static_by_default: + r = DimDynamic.STATIC + else: + r = DimDynamic.DUCK + dynamic_dims.append(r) + dynamic_strides.append(r) + dynamic_dims = [DimDynamic.DUCK] * dim + dynamic_strides = [DimDynamic.INFER_STRIDE] * dim + # symbolic_context is None - set one + symbolic_context = StatelessSymbolicContext( + dynamic_sizes=dynamic_dims, + dynamic_strides=dynamic_strides, + constraint_sizes=constraint_sizes, + constraint_strides=constraint_strides, + ) + # We got a StatelessSymbolicContext + _assert_symbol_context(symbolic_context) + constraint_sizes = symbolic_context.constraint_sizes # type: ignore[attr-defined] + constraint_strides = symbolic_context.constraint_strides # type: ignore[attr-defined] + dynamic_sizes = symbolic_context.dynamic_sizes # type: ignore[attr-defined] + dynamic_strides = symbolic_context.dynamic_strides # type: ignore[attr-defined] + + # TODO: make this configurable from outside symbolic_context; we made a symbolic_context + # decision here where if all sizes are static, we are going to + # specialize all of the inner strides/offset too. We don't have to + # do this, and arguably we should ALWAYS allow for dynamic offset, + # this is cheap. + # TODO: This should be DYNAMIC, using DUCK for BC + dynamic_offset = ( + DimDynamic.STATIC + if all(r == DimDynamic.STATIC for r in dynamic_sizes) + else DimDynamic.DUCK + ) + are_sizes_static = all(r == DimDynamic.STATIC for r in dynamic_sizes) + + assert len(dynamic_sizes) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(dynamic_strides) == dim, f"{len(dynamic_sizes)} != {dim}" + assert len(constraint_sizes) == dim + assert len(constraint_strides) == dim + + from torch._dynamo.source import TensorProperty, TensorPropertySource + + size: list[sympy.Expr] = self._produce_dyn_sizes_from_int_tuple( + ex_size, source, symbolic_context + ) + stride = self._compute_symbolic_stride( + source, + size, + ex_size, + ex_stride, + dynamic_strides, + constraint_strides, + are_sizes_static, + symbolic_context, + ) + + sym_sizes = [ + self.create_symintnode( + sym, + hint=hint, + source=TensorPropertySource(source, TensorProperty.SIZE, i), + ) + for i, (sym, hint) in enumerate(zip(size, ex_size)) + ] + sym_stride = [] + for i, stride_expr in enumerate(stride): + # NB: Don't duck size the stride; instead use the expression + # we computed + assert stride_expr is not None + sym_stride.append( + self.create_symintnode( + stride_expr, + hint=ex_stride[i], + source=TensorPropertySource(source, TensorProperty.STRIDE, i), + ) + ) + sym_storage_offset = self.create_symintnode( + self.create_symbol( + ex_storage_offset, + TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + dynamic_dim=dynamic_offset, + constraint_dim=None, + symbolic_context=symbolic_context, + ), + hint=ex_storage_offset, + source=TensorPropertySource(source, TensorProperty.STORAGE_OFFSET), + ) + return tuple(sym_sizes), tuple(sym_stride), sym_storage_offset + + def _compute_symbolic_stride( + self, + source: Source, + size: Sequence[sympy.Expr], + ex_size: Sequence[IntLikeType], + ex_stride: Sequence[IntLikeType], + dynamic_strides: Sequence[DimDynamic], + constraint_strides: Sequence[ + Optional[Union[StrictMinMaxConstraint, RelaxedUnspecConstraint]] + ], + are_sizes_static: bool, + symbolic_context: SymbolicContext, + ) -> list[sympy.Expr]: + from torch._dynamo.source import TensorProperty, TensorPropertySource + + stride: list[Optional[sympy.Expr]] = [None] * len(size) + candidates: dict[IntLikeType, sympy.Expr] = {} + + # iterate over unbound strides in val ascending order with + # index descending as a tie breaker since for cases like + # [(1, 1), (1, 0)], we want to fill in the right most + # stride first. + val_list = [(val, -i) for i, val in enumerate(ex_stride)] + val_list.sort(key=_nested_int_aware_sort) + + for val, neg_i in val_list: + i = -neg_i + contiguous_stride = ( + i != len(ex_stride) - 1 + and ex_stride[i] == ex_size[i + 1] * ex_stride[i + 1] + ) + if val in (0, 1) and not contiguous_stride: + out_stride = sympy.Integer(val) + else: + dynamic_stride = dynamic_strides[i] + if dynamic_stride == DimDynamic.INFER_STRIDE and val in candidates: + # Set stride to a candidate only for DimDynamic.INFER_STRIDE + out_stride = candidates[val] + else: + # Set INFER_STRIDE to STATIC or DUCK depending on sizes + dyn_stride = dynamic_stride + if dynamic_stride == DimDynamic.INFER_STRIDE: + dyn_stride = ( + DimDynamic.STATIC if are_sizes_static else DimDynamic.DUCK + ) + out_stride = self.create_symbol( + val, + TensorPropertySource(source, TensorProperty.STRIDE, i), + dynamic_dim=dyn_stride, + constraint_dim=constraint_strides[i], + symbolic_context=symbolic_context, + ) + stride[i] = out_stride + candidates[ex_size[i] * val] = size[i] * out_stride + + assert all(x is not None for x in stride) + return stride + + @record_shapeenv_event() + def create_symintnode( + self, + sym: sympy.Expr, + *, + hint: Optional[int], + source: Optional[Source] = None, + ) -> IntLikeType: + """Create a SymInt value from a symbolic expression + + If you know what the current hint value of the SymInt to be created + is, pass it into hint. Otherwise, pass None and we will make our best + guess + + """ + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + out: IntLikeType + if isinstance(sym, sympy.Integer): + if hint is not None: + assert int(sym) == hint + out = int(sym) + else: + # How can this occur? When we mark_unbacked, we end up with a real + # tensor that has hints for all sizes, but we MUST NOT create a + # SymNode with a hint, because we're hiding the hint from our eyes + # with the unbacked Symbol. And in fact, the hint compute may be + # inconsistent with size oblivious tests. + if free_unbacked_symbols(sym): + hint = None + out = SymInt(SymNode(sym, self, int, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_symfloatnode( + self, + sym: sympy.Expr, + *, + hint: Optional[int], + source: Optional[Source] = None, + ) -> FloatLikeType: + """Create a SymFloat value from a symbolic expression""" + if self._translation_validation_enabled and source is not None: + # Create a new symbol for this source. + symbol = self._create_symbol_for_source(source) + assert symbol is not None + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + # Add an equality assertion for the newly created symbol and 'sym'. + self._add_assertion(sympy.Eq(symbol, sym)) + else: + fx_node = None + + out: FloatLikeType + if isinstance(sym, sympy.Float): + if hint is not None: + assert float(sym) == hint + out = float(sym) + else: + # You could give this the same treatment as SymInt above if + # you supported mark_unbacked on a float, but it's a kind of + # strange thing to do though because floats don't get 0/1 + # specialization anyway + if free_unbacked_symbols(sym): + assert hint is None, sym + out = SymFloat(SymNode(sym, self, float, hint, fx_node=fx_node)) + return out + + @record_shapeenv_event() + def create_unspecified_symint_and_symbol( + self, value: int, source: Source, dynamic_dim: DimDynamic + ) -> IntLikeType: + """Create a SymInt wrapping a new unspecified symbol""" + return self.create_symintnode( + self.create_unspecified_symbol( + value, + source=source, + dynamic_dim=dynamic_dim, + ), + hint=value, + source=source, + ) + + def create_symboolnode(self, sym: sympy.Expr) -> SymBool: + """Create a SymBool object from a sympy boolean expression""" + # This function is only being used in serialization, so we do not track it + # for validation. + return SymBool(SymNode(sym, self, bool, None)) + + def _log_create_unbacked_symbol( + self, + prefix: str, + symbol: sympy.Symbol, + vr: ValueRanges, + source: Optional[Source] = None, + sym_node: Optional[SymNode] = None, + ) -> None: + is_debug = config.extended_debug_create_symbol is not None and str( + symbol + ) in config.extended_debug_create_symbol.split(",") + sloc: Union[str, SLoc] + if source is None: + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + else: + sloc, maybe_extra_debug = source.name(), "" + log.info( + "%s %s [%s, %s] %s%s", + prefix, + symbol, + vr.lower, + vr.upper, + sloc, + maybe_extra_debug, + stack_info=is_debug, + ) + trace_structured( + "create_unbacked_symbol", + metadata_fn=lambda: { + "symbol": str(symbol), + "node_id": id(sym_node), + "vr": f"[{vr.lower}, {vr.upper}]", + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(), + }, + ) + + @record_shapeenv_event() + def create_unbacked_symfloat(self) -> SymFloat: + """Create a symbolic float without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_FLOAT, next(self.unbacked_symfloat_counter) + ) + self.counter["create_unbacked_symbol"] += 1 + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float + sloc = self._get_sloc() + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, float) + + sym_node = SymNode(symbol, self, float, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symfloat", symbol, vr, sym_node=sym_node + ) + + return SymFloat(sym_node) + + @record_shapeenv_event() + def create_unbacked_symint(self, source: Optional[Source] = None) -> SymInt: + """Create a symbolic integer without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True + ) + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int + sloc = self._get_sloc() + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, int) + + sym_node = SymNode(symbol, self, int, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symint", symbol, vr, source, sym_node=sym_node + ) + + return SymInt(sym_node) + + def is_unbacked_symint(self, symbol: sympy.Symbol) -> bool: + """Check if a sympy symbol matches the naming convention for unbacked symbols""" + return symbol_is_type(symbol, SymT.UNBACKED_INT) + + @record_shapeenv_event() + def create_unbacked_symbool(self) -> SymBool: + """Create a symbolic boolean without a hint value""" + symbol: sympy.Symbol = make_symbol( + SymT.UNBACKED_INT, next(self.unbacked_symint_counter), integer=True + ) + if not self._ignore_fresh_unbacked_symbols_tls(): + self.pending_fresh_unbacked_symbols.append(symbol) + self.counter["create_unbacked_symbol"] += 1 + self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) + vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int + sloc = self._get_sloc("default value range for unbacked SymBool") + self.var_to_range_sloc[symbol] = ValueRangesSLoc(sloc, sloc) + + # Create a new FX placeholder and Z3 variable for 'symbol'. + fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) + + sym_node = SymNode(sympy.Eq(symbol, 1), self, bool, None, fx_node=fx_node) + self._log_create_unbacked_symbol( + "create_unbacked_symbool", symbol, vr, sym_node=sym_node + ) + + return SymBool(sym_node) + + @record_shapeenv_event() + def create_unspecified_symbol( + self, + val: Union[int, SymInt, float, SymFloat], + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + symbolic_context: Optional[StatelessSymbolicContext] = None, + ) -> sympy.Expr: + """ + Create a symbol with an unspecified value + + Compared to standard symbols we do not assume the value is positive, + nor do we specialze on zero or one values. + """ + # 'positive' is None for unspecified symbols, since we can't + # assume that it will be neither positive nor negative. + + # We don't want to specialize zero one val for unspecified symbol + # so that we can always get a new symbol despite val. + return self.create_symbol( + val, + source, + dynamic_dim, + constraint_dim, + positive=None, + do_not_specialize_zero_one=True, + symbolic_context=symbolic_context, + ) + + @record_shapeenv_event() + def create_symbol( + self, + val: int, + source: Source, + dynamic_dim: DimDynamic = DimDynamic.DUCK, + constraint_dim: DimConstraint = None, # NB: includes None + positive: Optional[bool] = True, + do_not_specialize_zero_one: bool = False, + symbolic_context: Optional[StatelessSymbolicContext] = None, + ) -> sympy.Expr: + """Create a new symbol which is tracked by this ShapeEnv""" + # check if constraint_dim is actually static integer + if ( + isinstance(constraint_dim, StrictMinMaxConstraint) + and constraint_dim.vr.lower == constraint_dim.vr.upper + ): + dynamic_dim = DimDynamic.STATIC + if constraint_dim.vr.lower != val: + raise ConstraintViolationError( + f"Static shape constraint of {constraint_dim.vr.lower} does not match input size of {val}, " + f"for {source.name()}" + ) + if symbolic_context: + from torch._dynamo.source import TensorPropertySource + + assert isinstance(source, TensorPropertySource) + # TODO: storage_offset handling? + assert source.idx is not None + symbolic_context.dynamic_sizes[source.idx] = dynamic_dim + symbolic_context.constraint_sizes[source.idx] = None + constraint_dim = None + + # see note [Tensor Fakification and Symbol Caching] + source_name = source.name() + if ( + isinstance(symbolic_context, StatefulSymbolicContext) + and id(self) not in symbolic_context.shape_env_to_source_to_symbol_cache + ): + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] = {} + + if ( + isinstance(symbolic_context, StatefulSymbolicContext) + and source_name + and ( + source_name + in symbolic_context.shape_env_to_source_to_symbol_cache[id(self)] + ) + ): + return symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] + + if dynamic_dim in (DimDynamic.SIZE_LIKE_UNBACKED, DimDynamic.OBLIVIOUS_SIZE): + out = self.create_unbacked_symint(source).node.expr + self._constrain_range_for_size(out) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = out + if dynamic_dim is DimDynamic.OBLIVIOUS_SIZE: + self.oblivious_var_to_val[out] = val + return out + + if do_not_specialize_zero_one: + specialize_zero_one = False + else: + specialize_zero_one = self.specialize_zero_one + + assert isinstance(source, Source), f"{type(source)} {source}" + assert not (positive and val < 0), f"positive set for negative value: {val}" + # It's always sound to allocate a symbol as DYNAMIC. If the user + # constrained the symbol, force the symbolic_context to DYNAMIC, because our + # constraint code will do weird stuff if, e.g., it's duck shaped + if constraint_dim is not None: + dynamic_dim = DimDynamic.DYNAMIC + + if dynamic_dim is DimDynamic.STATIC: + out = sympy.Integer(val) + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = out + return out + + elif dynamic_dim is DimDynamic.DUCK: + # duck_shape can be used to globally turn off duck shaping, even + # if it was requested + duck = self.duck_shape + elif dynamic_dim is DimDynamic.DYNAMIC: + duck = False + else: + raise AssertionError(f"unhandled dynamic_dim {dynamic_dim}") + + sloc = self._get_sloc() + + if val in (0, 1) and specialize_zero_one: + if val == 0: + return sympy.S.Zero + else: + return sympy.S.One + elif not duck or val not in self.val_to_var: + # If we're not duck shaping, we always create a new symbol + # Even if we're duck shaping, if we haven't seen this particular + # value before, we also create a new symbol + symbol_id = self._generate_unique_id(source.name()) + if type(val) is int or is_nested_int(val): + sympy_expr = make_symbol( + SymT.SIZE, symbol_id, positive=positive, integer=True + ) + else: + sympy_expr = make_symbol( + SymT.FLOAT, symbol_id, positive=positive, real=True + ) + self.source_to_var[source_name] = sympy_expr + # We always associate vars to vals + if isinstance(val, int): + self.var_to_val[sympy_expr] = sympy.Integer(val) + elif isinstance(val, float): + self.var_to_val[sympy_expr] = sympy.Float(val) + else: + # Only used for jagged layout nested tensors + self.var_to_val[sympy_expr] = SingletonInt( + val.node.nested_int(), coeff=val.node.nested_int_coeff() + ) + + # Do the appending later, because we always want to populate this + self.var_to_sources[sympy_expr] = [] + # Create a Z3 variable for the new symbol. + self._add_z3var(sympy_expr, int) + + if duck: + # Make sure to reuse this symbol for subsequent duck shaping + self.val_to_var[val] = sympy_expr + + if isinstance(val, int): + if positive: + # Add assertions for the newly created symbols + self._add_assertion(sympy_expr > 1) + + # Apply default range, which assumes not zero-one + self.var_to_range[sympy_expr] = self._default_value_range( + do_not_specialize_zero_one + ) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc( + self._get_sloc( + "user code shown is first use of this value--the guard itself is not " + "due user code but due to 0/1 specialization in the framework; to " + "avoid specialization try torch._dynamo.mark_unbacked(tensor, dim)" + if self.specialize_zero_one + else None + ), + sloc, + ) + else: + self.var_to_range[sympy_expr] = ( + self._default_unspecified_value_range() + ) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) + + # Small performance optimization: if we have a min-max constraint, + # we can proactively narrow to that range + if isinstance(constraint_dim, StrictMinMaxConstraint): + assert not duck + self._update_var_to_range( + sympy_expr, constraint_dim.vr, is_constraint=True + ) + + vr = self.var_to_range[sympy_expr] + assert vr.is_int + + if val not in vr: + raise ConstraintViolationError( + f"{val} not in range [{vr.lower}, {vr.upper}]" + ) + + range_str = f"[{vr.lower}, {vr.upper}]" + elif isinstance(val, float): + self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) + self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) + range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float + else: + # Skip var_range logic for SingletonInt + # Only used for jagged layout nested tensors + range_str = "" + + r = sympy_expr + + is_debug = config.extended_debug_create_symbol is not None and str( + sympy_expr + ) in config.extended_debug_create_symbol.split(",") + maybe_more_info = "" + if not is_debug and os.getenv("TORCHDYNAMO_EXTENDED_ADVICE", "1") not in ( + "0", + "", + ): + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{sympy_expr}" ' + "or to suppress this message run with " + 'TORCHDYNAMO_EXTENDED_ADVICE="0"' + ) + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + self.log.info( + "create_symbol %s = %s for %s %s %s%s%s", + sympy_expr, + val, + source.name(), + range_str, + sloc, + maybe_more_info, + maybe_extra_debug, + stack_info=is_debug, + ) + trace_structured( + "create_symbol", + metadata_fn=lambda: { + "symbol": str(sympy_expr), + "val": repr(val), + "vr": range_str, + "source": source.name(), + "user_stack": structured.from_traceback( + TracingContext.extract_stack() + ), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + + self.counter["create_symbol"] += 1 + else: + # This implements duck-shaping: input sizes that match are assigned + # the same symint + r = self.val_to_var[val] + self.source_to_var[source_name] = r + self.log.debug("create_symbol %s duck sized %s", r, source.name()) + + if isinstance(r, sympy.Symbol): + r_sources = self.var_to_sources[r] + r_sources.append(source) + if not source.is_ephemeral() and r_sources[0].is_ephemeral(): + # prefer non-ephemeral source first since it may be guarded on later + r_sources[0], r_sources[-1] = r_sources[-1], r_sources[0] + + # This ensures we get zeros in symbol_guard_counts, which makes + # some queries simpler (since we will accumulate mass on 0 this + # way) + self.symbol_guard_counter[r] = 0 + + if isinstance(symbolic_context, StatefulSymbolicContext) and source_name: + symbolic_context.shape_env_to_source_to_symbol_cache[id(self)][ + source_name + ] = r + return r + + def add_var_to_val(self, expr: sympy.Symbol, val: int) -> None: + """Adds a new symbol to the symbolic environment.""" + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) + assert expr not in self.var_to_val, f"{expr} already exists" + self.var_to_val[expr] = sympy.Integer(val) + + def _debug_name(self, source: Source) -> str: + src_name = source.name() + return self.source_name_to_debug_name.get(src_name, src_name) + + def _render_range_for_constraint_violation( + self, source: Source, c: Union[StrictMinMaxConstraint, RelaxedUnspecConstraint] + ) -> str: + if isinstance(c, StrictMinMaxConstraint): + lower, upper = c.vr.lower, c.vr.upper + default = self._default_value_range() + if lower <= default.lower: + lower = None + if upper >= default.upper: + upper = None + c_render = ( + f"{self._debug_name(source)} = {source.name()} in the specified range" + ) + if lower is not None and upper is not None: + c_render += f" {lower} <= {self._debug_name(source)} <= {upper}" + elif lower is None and upper is not None: + c_render += f" {self._debug_name(source)} <= {upper}" + elif lower is not None and upper is None: + c_render += f" {lower} <= {self._debug_name(source)}" + return c_render + return c.render(source) + + def produce_guards(self, *args: Any, **kwargs: Any) -> list[str]: + """ + Like produce_guards_verbose, but only returns the non-verbose python guard expressions + (no verbose guards produced.) + """ + return self.produce_guards_verbose(*args, **kwargs, langs=("python",))[0].exprs + + def produce_guards_verbose( + self, + placeholders: Sequence[FakeTensor], + sources: Sequence[Source], + source_ref: Callable[[Source], str] = lambda n: n.name(), + *, + guards: Optional[list[ShapeGuard]] = None, + input_contexts: Optional[DimList[SymbolicContext]] = None, + # Encodes user-specified input shape equations of the form s = s' and s = fn(s'). + # (See docs on EqualityConstraint for details of the encoding.) + equalities_inputs: Optional[EqualityConstraint] = None, + _simplified: bool = False, + # Indicates if we should produce guards for known static values. + ignore_static: bool = True, + langs: tuple[str, ...] = ("python", "verbose_python"), + ) -> list[_ShapeGuardsHelper]: + """ + Generates a list of guards strings which, when evaluated in a context that + defines tensors for all the sources, returns True or False depending + on if the guards in the list evaluated to True or not. Primarily used by Dynamo, + but this is also helpful for manual testing of guards (see + evaluate_guards_for_args) + + For convenience in testing, a source is allowed to be a str, + in which case we will assume it is a LocalSource + + simplified lets you omit duck sizing, equality and 0/1 guards. + This is useful for testing when you don't care about the boilerplate + guards, and it may be helpful for user output too (be careful though; + some equality guards are nontrivial! It would be nice to get simplified + output to print them too). It's private because it's not + intended for normal use + + Returns guards in python and python with verbose comments (verbose) by + default. + """ + self.log.info("produce_guards") + + # Check if we get to the same ShapeEnv state by replaying the recorded events. + # This will create a new ShapeEnv instance, and call all recorded function + # calls on this new instance. Finally, it will check whether this new instance + # has equal state. + # + # It's important that we do it in the begining of this function, since it modifies + # self.dim_constraints through its execution. Changes that happen in this method + # aren't interesting, since this is the function call we wish to reproduce at the + # end. If we wish to simply reproduce ShapeEnv instances even after this call, + # this method should also be recorded. + if self.check_recorded_events: + shape_env = replay_shape_env_events(self.events) + self.check_equal(shape_env) + + assert len(placeholders) == len(sources), ( + f"len({placeholders}) != len({sources})" + ) + Tensorlike = (torch.Tensor, FakeTensorMeta) + + def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: + return StatelessSymbolicContext( + # Ignored; only the constraints part is relevant below. + dynamic_sizes=[DimDynamic.DYNAMIC] * t.dim(), + dynamic_strides=[DimDynamic.INFER_STRIDE] * t.dim(), + constraint_sizes=[None] * t.dim(), + constraint_strides=[None] * t.dim(), + ) + + # Expand optional inputs, or verify invariants are upheld + if input_contexts is None: + input_contexts = [ + _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None + for t in placeholders + ] + else: + assert len(input_contexts) == len(placeholders) + for i, (t, context) in enumerate(zip(placeholders, input_contexts)): + if isinstance(t, Tensorlike): + if context is None: + input_contexts[i] = _create_no_constraints_context(t) + else: + assert isinstance(t, (SymInt, int, SymFloat, float)) + assert not isinstance(context, list) + + # It took a lot of sweat to figure out the algorithm here. Let's + # explain how it works. + # + # The ShapeEnv lifecycle looks something like this: + # + # - For each input, you either generate a fresh Sympy symbol (s0) to + # represent its value (a binding site), or you reuse some + # preexisting symbol or expression, skipping the symbol allocation + # (e.g., duck sizing to a preexisting symbol, or expressing a + # stride as a multiplication of a separate stride and size.) + # Naively, you might expect to bind a fresh Sympy symbol for + # every input, but this is fairly wasteful as most of these + # symbols immediately simplify away, and if you don't eagerly + # specialize, e.g., 0/1 symbols, you end up with very complicated + # expressions that are not optimizable in practice. + # + # - You perform some compute on these symbols, occasionally + # introducing guards on boolean expressions on these symbols. + # In particular, whenever we guard on equality (_maybe_guard_rel), + # we can simplify shapes; e.g., when s0 == s1 * 2, we can now + # replace all occurrences of s0 with s1 * 2. Sometimes, a + # boolean expression evaluation doesn't introduce a guard, as + # the guard is already entailed by the simplifications we have + # applied. + # + # - In the end, you have a bunch of replacements (saying how to + # simplify shapes) and a bunch of guards (all the equality guards + # are trivial, because they're covered by the replacements). + # + # From the ShapeEnv, we must generate a Python expression that, when + # evaluated on a set of inputs, tells us whether or not these boolean + # expressions would have evaluated in the same way. However, + # we cannot easily compute this, as we elide recording boolean + # expressions when we think they are vacuously true. Thus, we seek + # an approximation: we must generate an expression, if true, would have + # produced an "equivalent" ShapeEnv, which would answer guard + # expressions in the same way. + # + # Our notion of equivalence is a bit subtle. For example, consider + # the ShapeEnv created from an input of size (5, 4) versus (4, 4) + # (no other guards.) Duck sizing would generate (s0, s1) in the first + # case but (s0, s0) in the second. We do NOT assume that size + # variables are disjoint; so in fact a graph that assumes the input + # could be (s0, s1) subsumes (s0, s0) (setting s0 == s1), but not + # vice versa. However, consider an analogous case (1,) versus (2,). + # Duck sizing generates (1,) and (s0,); the (s0,) graph does NOT + # subsume the (1,) graph because we assume that any size variables + # is NOT 0/1 (and make simplifications according to this; e.g., if + # we queried s0 == 0, we would immediately return False without + # returning a guard.) + # + # So, it is perhaps easier to flip things on their head: the guard + # expressions we generate here say what simplifications are valid, + # and what are not. Below, we explain each of the guard expressions + # we generate + + # TODO: Make this more efficient by binding all the size/stride/offsets + # to locals before performing tests on them. + + from torch._dynamo.source import TensorProperty, TensorPropertySource + + # Actual codegen must be delayed as we don't necessarily know what + # the symbol mapping is + input_guards = [] + + symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict( + list + ) + symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = ( + collections.defaultdict(set) + ) + constraint_violations: list[tuple[bool, str, Callable[[], str]]] = [] + + printers: list[_ShapeGuardPrinter] = [] + py_printer = ShapeGuardPythonPrinter( + symbol_to_source, source_ref, self.var_to_sources + ) + for lang in langs: + if lang in ["python", "verbose_python"]: + printers.append(py_printer) + elif lang == "cpp": + printers.append( + _ShapeGuardCppPrinter( + symbol_to_source, source_ref, self.var_to_sources + ) + ) + else: + raise NotImplementedError(f"Unknown lang: {lang}") + + def record_constraint_violation( + warn_only: bool, + debug_name: str, + msg: str, + hint: Optional[Callable[[], str]] = None, + ) -> None: + constraint_violations.append( + (warn_only, debug_name, lambda: f"{msg}{hint()}" if hint else msg) + ) + + def is_dim(src: object) -> TypeGuard[TensorPropertySource]: + return ( + isinstance(src, TensorPropertySource) + and src.prop is TensorProperty.SIZE + ) + + if equalities_inputs: + source_index = {} + for i, src in enumerate(sources): + source_index[src.name()] = i + + def get_expression(tensor_dim_src: Source) -> sympy.Expr: + fake = placeholders[source_index[tensor_dim_src.base.name()]] # type: ignore[attr-defined] + assert tensor_dim_src.idx is not None # type: ignore[attr-defined] + symint = fake.shape[tensor_dim_src.idx] # type: ignore[attr-defined] + if isinstance(symint, torch.SymInt): + return symint.node.expr + else: + assert type(symint) is int, f"Expected int, got {type(symint)}" + return sympy.Integer(symint) + + for src1, src2 in equalities_inputs.source_pairs: + expr1, expr2 = get_expression(src1), get_expression(src2) # type: ignore[] + # Check whether given input shape values satisfy a specified equation s = s'. + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2)) + if not concrete_val: + raise ConstraintViolationError( + f"{src1.name()} = {expr1 if isinstance(expr1, int) else expr1.xreplace(self.var_to_val)}" + " is not equal to " + f"{src2.name()} = {expr2 if isinstance(expr2, int) else expr2.xreplace(self.var_to_val)}" + ) + + for srcEq, root, fn in equalities_inputs.derived_equalities: + expr1 = get_expression(srcEq) + # recall that root is either a phantom symbol or an input source + expr2, debug_name = ( + (root, self.var_to_sources[root][0].name()) + if isinstance(root, sympy.Symbol) + else (get_expression(root), self._debug_name(root)) + ) + expr2_ = fn(expr2) + # Check whether given input shape values satisfy a specified equation s = fn(s'). + # - Raise when the equation was violated by the given input shape values. + # - Otherwise issue a guard to constrain them. + concrete_val = self.evaluate_expr(sympy.Eq(expr1, expr2_)) + if not concrete_val: + raise ConstraintViolationError( + f"Expected input {srcEq.name()} to be equal to " + f"{fn(sympy.Symbol(debug_name))}, " + f"where {debug_name} = {expr2.xreplace(self.var_to_val)}, " + f"but got {expr1.xreplace(self.var_to_val)}" + ) + + for phantom_symbol in equalities_inputs.phantom_symbols: + # we created additional phantom symbols that are not input shape dimensions + symbol_to_source[phantom_symbol].extend( + self.var_to_sources[phantom_symbol] + ) + + # How do we know what the value of s0 is? Fresh variables can only be + # bound by inputs, so there MUST be some other input which binds the + # variable. If there is no such input, this is an error in our + # system. We record where all symbols come from, to help you diagnose + # why those symbols didn't occur. + # + # In fact, generally speaking it is only possible for the "outermost" + # user of a ShapeEnv to evaluate the guards, because some inputs may + # not be available to inner levels. For example, Dynamo can guard on + # tensors that never actually become graph arguments (they are + # pruned). In this case, only Dynamo knows about these arguments. + def track_symint( + source: Source, val: IntLikeType, constraint: DimConstraint = None + ) -> None: + log.debug("track_symint %s %s %s", LazyString(source.name), val, constraint) + assert not isinstance(val, SymInt) or is_symbolic(val) + + if isinstance(val, SymInt) and val.node.maybe_as_int() is not None: + val = val.node.maybe_as_int() + + if isinstance(val, SymInt): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + if constraint is not None and not isinstance( + constraint, RelaxedUnspecConstraint + ): + symbol_to_constraints[s].add(constraint) + else: + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + # try inferring the ranges of the expr s + sym_vrs = { + x: self.var_to_range.get(x, None) for x in s.free_symbols + } + if any(vr is None for vr in sym_vrs.values()): + # some of the free symbols in s don't have ranges + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + if s.is_number: + i = int(s) + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if i not in (0, 1): + constraint_violated = True + if constraint_violated: + assert constraint is not None + + def hint(s: sympy.Expr) -> str: + sexpr = py_printer.doprint(s) + return f"{sexpr}." + + var_with_range = self._render_range_for_constraint_violation( + source, constraint + ) + msg = ( + f"Not all values of {var_with_range} are valid because " + f"{self._debug_name(source)} was inferred to be equal to " + ) + record_constraint_violation( + constraint.warn_only, + self._debug_name(source), + msg, + hint=functools.partial(hint, s), + ) + + input_guards.append((source, s)) + else: + s = sympy.Integer(val) + input_guards.append((source, s)) + constraint_violated = False + if isinstance(constraint, StrictMinMaxConstraint): + if not ( + s == constraint.vr.lower == constraint.vr.upper + ): # allow static constraints + constraint_violated = True + elif isinstance(constraint, RelaxedUnspecConstraint): + # Don't complain about 0/1 specialization, we + # expect to have to compile in this case anyway + if val not in (0, 1): + constraint_violated = True + if constraint_violated: + assert constraint is not None + var_with_range = self._render_range_for_constraint_violation( + source, constraint + ) + user_stack = self.user_specialization_stacks.get(source, None) + framework_stack = self.framework_specialization_stacks.get( + source, None + ) + msg = ( + f"You marked {self._debug_name(source)} as dynamic but your code " + f"specialized it to be a constant ({val}). If you're using mark_dynamic, " + f"either remove it or use maybe_mark_dynamic. If you're using Dim.DYNAMIC, " + f"replace it with either Dim.STATIC or Dim.AUTO." + + ( + "\n\nFramework stack:\n" + "".join(framework_stack.format()) + if framework_stack + else "" + ) + + ( + "\n\nUser stack:\n" + "".join(user_stack.format()) + if user_stack + else "" + ) + ) + record_constraint_violation( + constraint.warn_only, self._debug_name(source), msg + ) + + def track_symfloat(source: Source, val: FloatLikeType) -> None: + log.debug("track_symfloat %s %s", LazyString(source.name), val) + assert not isinstance(val, SymFloat) or is_symbolic(val) + + if isinstance(val, SymFloat) and val.node.maybe_as_float() is not None: + val = val.node.maybe_as_float() + + if isinstance(val, SymFloat): + s = val.node.expr + if isinstance(s, sympy.Symbol): + symbol_to_source[s].append(source) + input_guards.append((source, s)) + else: + s = sympy.Float(val) + input_guards.append((source, s)) + + for t, source, context in zip(placeholders, sources, input_contexts): + if isinstance(source, str): + from torch._dynamo.source import LocalSource + + source = LocalSource(source) + assert isinstance(source, Source) + if t is None: + continue + if isinstance(t, (SymInt, int)): + constraint = ( + None if context is None else getattr(context, "constraint", None) + ) + track_symint(source, t, constraint) + continue + elif isinstance(t, (SymFloat, float)): + track_symfloat(source, t) + continue + assert isinstance(t, Tensorlike) + if is_traceable_wrapper_subclass(t): + from torch._dynamo.source import AttrSource + + assert isinstance(context, SubclassSymbolicContext) + + # For subclasses, we need to track symints on BOTH the outer + # and inner tensors. + # TODO: type this better + sources_tensors_constraints: list[tuple[Source, Any, Any, Any]] = [ + (source, t, context.constraint_sizes, context.constraint_strides) + ] + attrs, _ = t.__tensor_flatten__() + for attr in attrs: + inner_t = getattr(t, attr) + inner_context = context.inner_contexts[attr] + sources_tensors_constraints.append( + ( + AttrSource(source, attr), + inner_t, + inner_context.constraint_sizes, # type: ignore[attr-defined] + inner_context.constraint_strides, # type: ignore[attr-defined] + ) + ) + else: + sources_tensors_constraints = [ + (source, t, context.constraint_sizes, context.constraint_strides) # type: ignore[attr-defined] + ] + + for ( + src, + curr_t, + constraint_size, + constraint_stride, + ) in sources_tensors_constraints: + if is_sparse_any(curr_t): + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource( + src, TensorProperty.SIZE, i + ) + track_symint(property_source, ss, constraint_size[i]) + else: + for i, ss in enumerate(curr_t.size()): + property_source = TensorPropertySource( + src, TensorProperty.SIZE, i + ) + track_symint(property_source, ss, constraint_size[i]) + for i, ss in enumerate(curr_t.stride()): + property_source = TensorPropertySource( + src, TensorProperty.STRIDE, i + ) + track_symint(property_source, ss, constraint_stride[i]) + track_symint( + TensorPropertySource(src, TensorProperty.STORAGE_OFFSET), + curr_t.storage_offset(), + ) + + # 1. Every input must equal the final simplified symbolic expression + # stored on the placeholder. Given a placeholder (s0*2, s1), + # if we have an input (2, 3), we must show s0*2 == 2 and s1 == 3. + # This does a lot of work: it covers duck sizing and equality guards. + all_exprs: list[list[str]] = [[] for _ in langs] + self.dim_constraints = DimConstraints( + symbol_to_source, + self.var_to_val, + set(symbol_to_constraints.keys()), + self.source_name_to_debug_name, + ) + + if not _simplified: + for source, expr in input_guards: + srcname = source.name() + if self._translation_validation_enabled: + # Ignore sources that were not turned into SymInts. + if srcname in self.source_to_symbol: + self._add_target_expr( + sympy.Eq(self.source_to_symbol[srcname], expr) + ) + + # Small optimization + if ( + isinstance(expr, sympy.Symbol) + and symbol_to_source.get(expr) + and source == symbol_to_source[expr][0] + ): + continue + + # This logic excludes static values found on tensors from guarding, because + # dynamo's check_tensor_fn does that (see guards.cpp). + # However, for non tensor sources, we still need to guard here. + if ignore_static and isinstance(source, TensorPropertySource): + if expr.is_number: + self.log.debug( + "Skipping guard %s", f"{source_ref(source)} == {expr}" + ) + continue + + if is_dim(source): + self.dim_constraints.add_equality(source, expr) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + res = f"{printer.print_source(source)} == {printer.doprint(expr)}" + + if lang == "verbose_python": + if (s0 := self.source_to_var.get(srcname)) is not None: + if source != self.var_to_sources[s0][0]: + res = ( + f"{res} # duck sizing added this equality because these " + f"variables had the same size {self.var_to_val[s0]} " + "(to avoid this specialization, set torch.fx.experimental._config.use_duck_shape = False)" + ) + elif (sloc := self.replacements_slocs.get(s0)) is not None: + res = f"{res} # {sloc}" + else: + res = f"{res} # (unknown var {s0}, please file a bug)" + else: + res = f"{res} # (unknown source {srcname}, please file a bug)" + exprs.append(res) + + if ( + isinstance(source, TensorPropertySource) + and source.prop is TensorProperty.SIZE + and equalities_inputs + and len(expr.free_symbols) == 1 + ): + symbol = next(iter(expr.free_symbols)) + if ( + isinstance(expr, sympy.Symbol) + and expr in symbol_to_constraints + and not equalities_inputs.is_equal( + source, symbol_to_source[expr][0] + ) + ): + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} and " + f"{self._debug_name(symbol_to_source[expr][0])} = {symbol_to_source[expr][0].name()} " + "must always be equal." + ) + record_constraint_violation( + equalities_inputs.warn_only, self._debug_name(source), msg + ) + + if ( + not isinstance(expr, sympy.Symbol) + and symbol in symbol_to_constraints + and not equalities_inputs.is_derived( + source, + symbol_to_source[symbol][0], + lambda x: expr.xreplace({symbol: x}), + ) + ): + src = symbol_to_source[symbol][0] + msg = ( + f"The values of {self._debug_name(source)} = {source.name()} must always be related to " + f"the values of {self._debug_name(src)} = {src.name()} by " + f"{self._debug_name(source)} = {expr.xreplace({symbol: sympy.sympify(self._debug_name(src))})}." + ) + record_constraint_violation( + equalities_inputs.warn_only, self._debug_name(source), msg + ) + + # NB: Not necessary to report constraint violations here: + # constraints are guaranteed to be on symbols (we've already + # caught constants and non-atomic expressions), so we only + # have relational constraints, but we don't support those + # at the moment + + # 2. Every guard must evaluate to True (but remember many guards + # like s0 == s1*2 because trivial due to simplification) + issued = set() + + def issue_guard(guard: ShapeGuard) -> None: + expr = self.simplify(guard.expr) + + # Avoid re-issueing the same guard. + if expr in issued: + return + + issued.add(expr) + + try: + is_trivial = False + if any( + is_dim(source) + for s in expr.free_symbols + for source in symbol_to_source[s] + ): + assert self.dim_constraints is not None + is_trivial = self.dim_constraints.add(expr) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + guard_expr = printer.doprint(expr) + if lang == "verbose_python": + guard_expr = f"{guard_expr} # {guard.sloc}" + exprs.append(guard_expr) + + self._add_target_expr(expr) + # A non-relational constraint on a single sizevar can violate + # a constraint + if not is_trivial and len(expr.free_symbols) == 1: + symbol = next(iter(expr.free_symbols)) + source = symbol_to_source[symbol][0] + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + var_with_range = ( + self._render_range_for_constraint_violation(source, c) + ) + msg = ( + f"Not all values of {var_with_range} " + f"satisfy the generated guard {py_printer.doprint(expr)}." + ) + record_constraint_violation( + c.warn_only, self._debug_name(source), msg + ) + elif isinstance(c, RelaxedUnspecConstraint): + # This is fine, we allow guards here as long as it + # didn't constrain it to one value (we don't + # actually know this; this depends on our + # ValueRanges reasoning capability) + pass + else: + raise AssertionError(f"unrecognized constraint {c}") + except Exception: + self.log.warning("Failing guard allocated at %s", guard.sloc) + raise + + # First, issue all guards. + # This removes all the checks that follow from bounds + # We could simply emit those and also the bounds 2 <= size when necessary + for guard in guards if guards is not None else self.guards: + if ( + self._maybe_evaluate_static( + guard.expr, axioms=(), size_oblivious=guard.size_oblivious + ) + is not None + ): + continue + issue_guard(guard) + + # Because there are guards that export's constraint solver can suggest good fixes for, that we may have + # deferred as runtime asserts, and that produce_guards() alone won't do anything with (e.g. divisiblity guards), + # we want to send runtime asserts to export's constraint solver too. These will still stay in the graph as asserts, + # but export's constraint solver can decide whether to do anything with them (i.e. raise an error and provide + # suggested fixes, or decide it's out of scope and leave as a runtime assert in the graph). + for ra in self.deferred_runtime_asserts.get(None, []): + if self._maybe_evaluate_static(ra.expr, axioms=()) is not None: + continue + expr = self.simplify(ra.expr) + self.dim_constraints.add(expr) + + # 3. Every symbol must be within its value range (this handles 0/1 + # specialization too). + for symbol, sources in symbol_to_source.items(): + r = self.var_to_range.get(symbol) + if r is None: + continue + vr_sloc = self.var_to_range_sloc[symbol] + + assert sources + bounds = [] + rf = source_ref(sources[0]) + verbose_expr = "" + if r.lower not in (-sympy.oo, -int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Ge(symbol, r.lower)) + # Only print lower bound in simplified mode if it is not the + # default + if not _simplified or r.lower != self._default_value_range().lower: + bounds.append(sympy.Le(r.lower, symbol, evaluate=False)) + verbose_expr = f"{r.lower} <= {rf} # {vr_sloc.lower}" + if r.upper not in (sympy.oo, int_oo): + if any(is_dim(source) for source in sources): + self.dim_constraints.add(sympy.Le(symbol, r.upper)) + # nontrivial upper bound is always interesting + bounds.append(sympy.Le(symbol, r.upper, evaluate=False)) + if verbose_expr: + verbose_expr = f"{r.lower} <= {rf} <= {r.upper} # {vr_sloc.lower} and {vr_sloc.upper}" + else: + verbose_expr = f"{rf} <= {r.upper} # {vr_sloc.upper}" + if bounds: + bound = sympy.And(*bounds, evaluate=False) + + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "verbose_python": + exprs.append(verbose_expr) + else: + exprs.append(printer.doprint(bound)) + # NB: verbose_exprs are done above + + # Check constraints + constraints = symbol_to_constraints[symbol] + for c in constraints: + if isinstance(c, StrictMinMaxConstraint): + # TODO: With int_oo, I think this condition is a noop + # now + if not (c.vr & self._default_value_range()).issubset(r): + source = sources[0] + + expr = sympy.And( + sympy.Le(r.lower, symbol), sympy.Le(symbol, r.upper) + ) + guard_expr = py_printer.doprint(expr) + var_with_range = ( + self._render_range_for_constraint_violation(source, c) + ) + msg = f"Not all values of {var_with_range} satisfy the generated guard {guard_expr}" + record_constraint_violation( + c.warn_only, + self._debug_name(source), + msg, + ) + # We NaN specialize, which means similar to 0/1 specialization we + # should assume that the float is NOT nan. This is load bearing + # if you have something like an equality guard, nan will play + # merry hell with the reasoning. + if symbol_is_type(symbol, SymT.FLOAT): + res = f"not math.isnan({py_printer.print_source(sources[0])})" + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "verbose_python": + exprs.append( + f"{res} # implicit guard for float input due to NaN specialization in the framework" + ) + elif lang == "python": + exprs.append(res) + elif lang == "cpp": + exprs.append(f"~std::isnan({printer.print_source(sources[0])})") + else: + raise NotImplementedError(f"Unimplemented for lang: {lang}") + + if constraint_violations: + warn_msgs: list[str] = [] + error_msgs: list[str] = [] + debug_names = set() + for warn_only, debug_name, msg_cb in constraint_violations: + if warn_only: + str_msg = f" {len(warn_msgs) + 1}. {msg_cb()}" + warn_msgs.append(str_msg) + else: + str_msg = f" - {msg_cb()}" + error_msgs.append(str_msg) + debug_names.add(debug_name) + if len(error_msgs) > 0: + debug_names_str = ", ".join(sorted(debug_names)) + err = "\n".join(error_msgs) + raise ConstraintViolationError( + f"Constraints violated ({debug_names_str})! " + 'For more information, run with TORCH_LOGS="+dynamic".\n' + f"{err}" + ) + elif len(warn_msgs) > 0: + log.debug("%s Warning only constraints violated", len(warn_msgs)) + + signpost_event( + "dynamic", + "produce_guards", + { + **self.co_fields, + **self.counter, + "num_guards": len(all_exprs[0]), + "free_symbols": sum(1 for v in symbol_to_source.values() if v), + # The keys are meaningless from an aggregate perspective, so + # don't include them. Biggest first. + "symbol_guard_counts": sorted( + self.symbol_guard_counter.values(), reverse=True + ), + }, + ) + + if self._translation_validation_enabled: + from torch.fx.experimental.validator import PopulateValidator + + # Add all deferred runtime assertions; these are not technically + # handled by produce_guards but we need to put them in the target + # set + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + self._add_target_expr(ra.expr) + + # Add value range bound guards for all symbols with no trivial bounds. + # Reason: '_maybe_evaluate_static' may eliminate guards based on the + # refined value ranges. + for sym, vr in self.var_to_range.items(): + if vr.lower not in (-sympy.oo, -int_oo): + self._add_target_expr(sympy.Le(vr.lower, sym)) + if vr.upper not in (sympy.oo, int_oo): + self._add_target_expr(sympy.Le(sym, vr.upper)) + + # Before validating, populate the input of the validator with the + # built FX graph. + with fx_traceback.preserve_node_meta(): + PopulateValidator(self.graph, self.validator).run() + + # Only run translation validation when we are not passing custom guards + if guards is None: + self._check_translation_validate() + + helpers: list[_ShapeGuardsHelper] = [] + for exprs, printer, lang in zip(all_exprs, printers, langs): + if lang == "cpp": + assert isinstance(printer, _ShapeGuardCppPrinter) + helpers.append(_CppShapeGuardsHelper(exprs, printer.source_to_symbol)) + else: + helpers.append(_ShapeGuardsHelper(exprs)) + return helpers + + def produce_guards_expression( + self, + placeholders: Sequence[Union[SymInt, FakeTensor]], + *, + guards: Optional[list[ShapeGuard]] = None, + ignore_static: bool = True, + ) -> Optional[str]: + """ + Expected to be used with evaluate_guards_expression(). Produces the guards + for the given placeholders and returns a string expression to be evaluated + by evaluate_guards_expression given concrete values for the placeholders. + """ + from torch._dynamo.source import LocalSource + + arg_names = [f"t{i}" for i in range(len(placeholders))] + produced_guards = self.produce_guards( + placeholders, + [LocalSource(a) for a in arg_names], + guards=guards, + ignore_static=ignore_static, + ) + if produced_guards: + return " and ".join(produced_guards) + return None + + def evaluate_symexpr(self, code: str) -> Union[int, float, bool]: + """ + To be used by compile_fx to evaluate symexprs + """ + args = {str(e): val for e, val in self.var_to_val.items()} + return eval(code, SYMPY_INTERP, args) + + def deserialize_symexpr(self, code: str) -> Union[SymInt, SymFloat, SymBool]: + """ + To be used by compile_fx to deserialize symexprs + """ + args = { + str(e): SymInt(SymNode(e, self, int, int(val), fx_node=None)) + for e, val in self.var_to_val.items() + } + return eval(code, SYMPY_INTERP, args) + + def evaluate_guards_expression(self, code: str, args: Sequence[object]) -> bool: + """ + Expected to be used with produce_guards_expression(). Evaluates an expression + generated by produce_guards_expression for the given concrete args. + """ + arg_names = [f"t{i}" for i in range(len(args))] + return eval(code, SYMPY_INTERP, {"L": dict(zip(arg_names, args))}) + + def evaluate_guards_for_args( + self, + placeholders: Sequence[FakeTensor], + args: Sequence[Tensor], + *, + ignore_static: bool = True, + ) -> bool: + """Generate guards for a graph's placeholder values and evaluate the guards with args""" + code = self.produce_guards_expression(placeholders, ignore_static=ignore_static) + if code: + return self.evaluate_guards_expression(code, args) + return True + + def get_pruned_guards(self, symints: Sequence[torch.SymInt]) -> list[ShapeGuard]: + """ + Get a list of guards, but pruned so it only provides guards that + reference symints from the passed in input + """ + symints = { + s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol) + } + guards = [ + g for g in self.guards if all(s in symints for s in g.expr.free_symbols) + ] + return guards + + def bind_symbols( + self, placeholders: Sequence[FakeTensor], args: Sequence[Tensor] + ) -> dict[sympy.Symbol, int]: + """ + Given a paired list of placeholders (fake tensors with + symbolic sizes) and concrete arguments (regular tensors + with real sizes), returns a dictionary mapping each + symbol to its real value. So for example, if you + have a placeholder with size (s0, s1), binding + (2, 4) to it will give you {s0: 2, s1: 4}. This is + not guaranteed to bind ALL symbols in the ShapeEnv; + we can't bind a symbol if it doesn't occur in any placeholder, + and symbols that already have replacements won't get bindings. + + This is a little duplicative with evaluate_guards but + it's different enough that it seemed cleanest to make + another copy. This assumes the guards are already checked, + though if it's cheap we'll check for shenanigans + """ + bindings: dict[sympy.Symbol, int] = {} + + def bind_symint(arg: object, val: object) -> None: + if isinstance(val, SymInt): + assert isinstance(arg, int) + s = val.node.expr + + if isinstance(s, sympy.Symbol): + if s in bindings: + assert bindings[s] == arg, f"{bindings[s]} != {arg}" + else: + bindings[s] = arg + elif isinstance(-s, sympy.Symbol): + if -s in bindings: + assert bindings[-s] == -arg, f"{bindings[-s]} != {-arg}" + else: + bindings[-s] = -arg + + for t, arg in zip(placeholders, args): + if t is None: + continue + if isinstance(t, SymInt): + bind_symint(arg, t) + continue + assert isinstance(t, torch.Tensor) + for i, s in enumerate(t.size()): + bind_symint(arg.size(i), s) + for i, s in enumerate(t.stride()): + bind_symint(arg.stride(i), s) + bind_symint(arg.storage_offset(), t.storage_offset()) + + return bindings + + def get_nontrivial_guards(self) -> list[SympyBoolean]: + """Returns a list of guard expressions that aren't statically known (i.e. not trivial)""" + return [ + self.simplify(guard.expr) + for guard in self.guards + if self._maybe_evaluate_static( + guard.expr, axioms=(), size_oblivious=guard.size_oblivious + ) + is None + ] + + def format_guards(self, verbose: bool = False) -> str: + """Format this shape env's guard expressions with optional traceback info if verbose""" + + return "\n".join( + f" - {guard.expr}{' ' + str(guard.sloc) if verbose else ''}" + for guard in self.guards + ) + + def bound_sympy( + self, expr: sympy.Expr, size_oblivious: bool = False + ) -> ValueRanges: + """Given a sympy expression, computes a ValueRanges bound for what values it can be""" + # TODO: maybe it's guaranteed x in is var_to_range? + var_to_range = {x: self.var_to_range.get(x, None) for x in expr.free_symbols} + if size_oblivious: + # Clamp values of size-like variables + # NB: discarding the old upper bound in intentional, per + # https://github.com/pytorch/pytorch/pull/123675 + for x in self.size_like & var_to_range.keys(): + if var_to_range[x] is not None: + # NB: do NOT set upper to 2 ** 48, we're using this solely + # to determine if we can do size-like replacement, the + # upper bound is irrelevant here + var_to_range[x] = ValueRanges(2, int_oo) + return bound_sympy(expr, var_to_range) # type: ignore[arg-type] + + @_lru_cache + def get_axioms( + self, + symbols: Optional[tuple[sympy.Symbol]] = None, + compute_hint: bool = False, + ) -> tuple[SympyBoolean, ...]: + """ + Given the symbols in an expression, it returns all the runtime asserts that have those symbols + concatenated with all the guards. + If symbols is None, it returns all the runtime asserts (and all the guards) + """ + if symbols is None: + runtime_asserts = ( + r.expr for rs in self.deferred_runtime_asserts.values() for r in rs + ) + else: + runtime_asserts = ( + r.expr + for s in symbols + if s not in self.var_to_val + for r in self.deferred_runtime_asserts.get(s, ()) + ) + guards: Iterator[SympyBoolean] = (g.expr for g in self.guards) + axioms: Iterator[SympyBoolean] = itertools.chain(guards, runtime_asserts) + if compute_hint: + axioms = ( + canonicalize_bool_expr(a.xreplace(self.var_to_val)) for a in axioms + ) + return tuple(dict.fromkeys(axioms).keys()) + + @lru_cache(None) + def get_implications( + self, e: SympyBoolean + ) -> tuple[tuple[SympyBoolean, sympy.logic.boolalg.BooleanAtom], ...]: + """Given a expression, it returns a list of predicates that follow from it""" + equiv: dict[SympyBoolean, sympy.logic.boolalg.BooleanAtom] = {} + + def add_expr(expr: SympyBoolean) -> None: + expr = canonicalize_bool_expr(expr) + if isinstance(expr, (sympy.Eq, sympy.Ne)): + # No need to canonicalize + # TODO We could further canonicalize Eq ordering the lhs and rhs somehow + # With this, we could remove the need for the commutativity part + opposite = sympy.Eq if isinstance(expr, sympy.Ne) else sympy.Ne + # Commutativity of == and != + equiv[type(expr)(expr.lhs, expr.rhs, evaluate=False)] = sympy.true + equiv[type(expr)(expr.rhs, expr.lhs, evaluate=False)] = sympy.true + equiv[opposite(expr.lhs, expr.rhs, evaluate=False)] = sympy.false + equiv[opposite(expr.rhs, expr.lhs, evaluate=False)] = sympy.false + else: + # Expr and negation + equiv[expr] = sympy.true + # we do not pass evaluate=False like others on purpose here! + # we want not(a=b and not ~(a Optional[sympy.Basic]: + """ + Tries to evaluate expr without introducing guards + + If unbacked_only == True, then we only do substitutions on + unbacked SymInts (leaving regular hinted integers alone). This could + result in an expression that still contains backed SymInts, which you + could then potentially guard on. + + Use compute_hint == True if you are trying to compute a non-binding + hint for the particular hint values of backed and unbacked SymInts, + e.g., if s0 happens to be 3 this run, compute_hint will subsitute s0 with 3. + """ + + # axioms with compute hint NYE + assert not compute_hint or not axioms + expr = self.simplify(expr, size_oblivious) + + if compute_hint: + expr = expr.xreplace(self.var_to_val).xreplace(self.unbacked_var_to_val) + + expr = canonicalize_bool_expr(expr) + + def resimplify_floor_div(axioms: dict[sympy.Expr, sympy.Expr]) -> None: + if not self._resimplify_floor_div_axioms: + return + self._resimplify_floor_div_axioms = False + new_items = {} + for k, v in axioms.items(): + # A FloorDiv in implications could have became CleanDiv at this point, due to new facts + # to the shapeEnv. This handles such issue but its not ideal. This is the only expression + # simplification that depends on the global state of shape env. + # TODO try to get rid of CleanDiv since it breaks the invariant thats simplifications of sympy + # expressions only depend on the expression itself. + if k.has(FloorDiv): + new_items.update({self.simplify(k): v}) + axioms.update(new_items) + + # Pattern matching + if axioms is None: + resimplify_floor_div(self.axioms) + subst = self.axioms + else: + subst = {} + for e in axioms: + if e.free_symbols.issubset(expr.free_symbols): + subst.update(dict(self.get_implications(self.simplify(e)))) + + resimplify_floor_div(subst) + + expr = expr.xreplace(subst) + # TODO: compute hint might have gotten broken here + + fs = expr.free_symbols + + if not fs and (expr.is_number or expr.is_Boolean): + return expr + + if var_to_range is None: + var_ranges = self.var_to_range + else: + var_ranges = dict(var_to_range) + + symbol_info = tuple( + _SymbolInfo( + s, + var_ranges.get(s), + self.var_to_val.get(s), + s in self.size_like, + ) + for s in sorted(fs, key=str) # TODO: speed up sort? + ) + + r = _maybe_evaluate_static_worker( + expr, symbol_info, unbacked_only, size_oblivious + ) + return r + + @_lru_cache + def replace(self, expr: _SympyT) -> _SympyT: + """ + Apply symbol replacements to any symbols in the given expression. + """ + replacements = {} + for s in expr.free_symbols: + r = self._find(s) + + # Micro-optimization: only do replacements if r and s are different + # Otherwise, xreplace is not a no-op and will trigger expensive + # assumption queries if expr has a relational node. + if not r.is_Symbol or r != s: + replacements[s] = r + if replacements: + return safe_expand(expr.xreplace(replacements)) + else: + return expr + + @_lru_cache + def _update_divisible(self) -> None: + new_divisible = set() + for k in self.divisible: + res = self.replace(k) + if not res.is_number: + new_divisible.add(k) + + self.divisible = new_divisible + self._update_version_counter() + + @_lru_cache + def simplify(self, expr: _SympyT, size_oblivious: bool = False) -> _SympyT: + """Use known constraints and replacements to simplify the given expr""" + expr = safe_expand(expr) + expr = self.replace(expr) + + if size_oblivious and (expr.has(Max) or expr.has(Min)): # type: ignore[has-type] + min_max_replacements = {} + for atom in (*expr.atoms(Max), *expr.atoms(Min)): # type: ignore[has-type] + if len(atom.args) > 2: + continue + a, b = atom.args + if b == 1 or b == 0: + a, b = b, a + if a == 1 or a == 0: + vr = self.bound_sympy(b, size_oblivious=True) + if vr.lower >= a: + min_max_replacements[atom] = b if atom.func is Max else a + elif vr.upper <= a: + min_max_replacements[atom] = a if atom.func is Max else b + if min_max_replacements: + expr = expr.xreplace(min_max_replacements) + + if expr.has(TruncToInt): + trunc_replacements = {} + for atom in expr.atoms(TruncToInt): + if isinstance(atom.args[0], IntTrueDiv): + base, divisor = atom.args[0].args + if base % divisor == 0: + trunc_replacements[atom] = base // divisor + if trunc_replacements: + expr = expr.xreplace(trunc_replacements) + + # TODO it would seem that this pass is not necessary given the + # below replacement of // with /, but for nested FloorDivs + # the non-recursive replacement doesn't work, and + # recursive makes it hard to look up divisibility, + # because existing divisibility info has FloorDiv in it, not / + # for now just do a separate pass to catch common nested case + if expr.has(FloorDiv): + self._update_divisible() + div_replacements = {} + for atom in expr.atoms(FloorDiv): + base, divisor = atom.args + if isinstance(divisor, FloorDiv): + base1, divisor1 = divisor.args + if ( + self.replace(Mod(base, divisor)) in self.divisible + and base == base1 + and self.replace(Mod(base1, divisor1)) in self.divisible + ): + div_replacements[atom] = divisor1 + if div_replacements: + expr = expr.xreplace(div_replacements) + expr = safe_expand(expr) + if expr.has(FloorDiv): + div_replacements = {} + pows = expr.atoms(sympy.Pow) + rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) + for fd in expr.atoms(FloorDiv): + base, divisor = fd.args + if self.replace(Mod(base, divisor)) in self.divisible: + div_replacements[fd] = CleanDiv(base, divisor) + if div_replacements: + new_expr = expr.xreplace(div_replacements) + new_expr = safe_expand(new_expr) + new_pows = new_expr.atoms(sympy.Pow) + new_rationals = new_expr.atoms(sympy.Rational).difference( + new_expr.atoms(sympy.Integer) + ) + # divisions simplified away + if new_pows.issubset(pows) and new_rationals.issubset(rationals): + expr = new_expr + return expr + + # TODO: overload for allow_none literal + @lru_cache(256) + def size_hint( + self, expr: sympy.Basic, *, allow_none: bool = False + ) -> Optional[sympy.Basic]: + """ + Gets a size hint for a given expression from the underlying shapes we had. + Does not introduce a guard, so only use this when you can guarantee that + your code is still valid for arbitrary shapes (such as optimization decisions) + """ + result_expr = safe_expand(expr).xreplace(self.var_to_val) + if not result_expr.is_number: + from torch.utils._sympy.singleton_int import SingletonInt + + if isinstance(result_expr, SingletonInt): + return None + r = self._maybe_evaluate_static(result_expr, compute_hint=True) + if r is not None: + return r + if allow_none: + return None + + if self.oblivious_var_to_val: + # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + correct_hint = result_expr.xreplace(self.oblivious_var_to_val) + counterfactual_hint = result_expr.xreplace( + {k: max(v, 2) for k, v in self.oblivious_var_to_val.items()} + ) + if ( + not correct_hint.free_symbols + and not counterfactual_hint.free_symbols + ): + if correct_hint == counterfactual_hint: + log.info("oblivious_size hit %s -> %s", expr, correct_hint) + return correct_hint + else: + log.info( + "oblivious_size counterfactual failed %s -> %s != %s", + expr, + correct_hint, + counterfactual_hint, + ) + else: + log.info( + "oblivious_size miss %s -> %s (counterfactual: %s)", + expr, + correct_hint, + counterfactual_hint, + ) + + if self.unbacked_var_to_val: + unsound_expr = result_expr.xreplace(self.unbacked_var_to_val) + if not unsound_expr.free_symbols: + log.warning( + "propagate_real_tensors size_hint(%s) -> %s", expr, unsound_expr + ) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(expr), + "result": repr(unsound_expr), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + self.guard_or_defer_runtime_assert( + sympy.Eq(result_expr, unsound_expr), + f"propagate_real_tensors: {result_expr} == {unsound_expr}", + ) + return unsound_expr + + raise self._make_data_dependent_error(result_expr, expr) + return result_expr + + # NB: keep in sync with size_hint + @lru_cache(256) + def has_hint(self, expr: sympy.Expr) -> bool: + result_expr = safe_expand(expr).xreplace(self.var_to_val) + return ( + result_expr.is_number + or self._maybe_evaluate_static(result_expr) is not None + ) + + def _make_data_dependent_error( + self, + expr: sympy.Basic, + unhinted_expr: sympy.Basic, + *, + size_oblivious_result: Optional[sympy.Basic] = None, + expr_sym_node_id: Optional[int] = None, + ) -> GuardOnDataDependentSymNode: + # TODO: in a Dynamo context, having user code, and having the + # name of the local, will be much better + size_like_symbols = [] + for s in expr.free_symbols: + stacktrace = "".join(self.var_to_stack[s].format()) + self.log.debug( + "Data dependent variable '%s' allocated at:\n%s", s, stacktrace + ) + if s in self.size_like: + size_like_symbols.append(s) + size_oblivious_result_msg = "" + if size_oblivious_result is not None: + size_oblivious_result_msg = ( + f"ATTENTION: guard_size_oblivious would fix the error, evaluating expression to {size_oblivious_result}.\n" + "Maybe you need to add guard_size_oblivious to framework code, see doc below for more guidance.\n\n" + ) + sloc, maybe_extra_debug = self._get_stack_summary(True) + if expr.is_integer: # type: ignore[attr-defined] + desc = ( + "Could not extract specialized integer from data-dependent expression" + ) + else: + desc = "Could not guard on data-dependent expression" + msg = ( + f"{desc} {expr} (unhinted: {unhinted_expr}). " + f"(Size-like symbols: {', '.join(map(str, size_like_symbols)) or 'none'})\n\n" + f"{size_oblivious_result_msg}" + f"Caused by: {sloc}\n" + 'For more information, run with TORCH_LOGS="dynamic"\n' + "For extended logs when we create symbols, also add " + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n' + "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" + "For more debugging help, see " + "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" + + maybe_extra_debug + # TODO: Help text about how to use our runtime tests to fix this + # problem + ) + + dtrace_structured( + "guard_on_data_dependent_error", + metadata_fn=lambda: { + "expr": repr(expr), + "unhinted_expr": repr(unhinted_expr), + "expr_id": self._expr_sym_node_id, + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + return GuardOnDataDependentSymNode(expr, msg) + + def _update_var_to_range( + self, + symbol: sympy.Symbol, + vr: ValueRanges, + vr_sloc: Optional[ValueRangesSLoc] = None, + *, + is_constraint: bool = False, + ) -> None: + lower, upper = vr.lower, vr.upper + + # If we have a size-like unbacked SymInt, refuse to refine the range to be + # less than two. This is because when we intersect this range + # with [2, inf] for size oblivious tests, the range would be + # unsatisfiable. In other words, once you have a size-like + # unbacked SymInt, we can never learn that it is exactly zero or one, + # because we would now give inconsistent results for all size + # oblivous tests! + if upper < 2 and symbol in self.size_like: + vr = ValueRanges(lower, 2) + + # Updates the range and the guards corresponding to each bound of the symbol. + if symbol not in self.var_to_range: + self.log.debug("_update_var_to_range %s = %s (new)", symbol, vr) + self.var_to_range[symbol] = vr + if vr_sloc is None: + sloc = self._get_sloc() + vr_sloc = ValueRangesSLoc(sloc, sloc) + self.var_to_range_sloc[symbol] = vr_sloc + else: + old = self.var_to_range[symbol] + new = old & vr + if new != old: + if vr_sloc is None: + sloc = self._get_sloc() + vr_sloc = ValueRangesSLoc(sloc, sloc) + if new.lower != old.lower: + self.var_to_range_sloc[symbol].lower = vr_sloc.lower + if new.upper != old.upper: + self.var_to_range_sloc[symbol].upper = vr_sloc.upper + self.var_to_range[symbol] = new + self.log.debug("_update_var_to_range %s = %s (update)", symbol, new) + + if (v := self.var_to_val.get(symbol)) is not None: + r = self.var_to_range[symbol] + if v not in r: + # For constraint failure, delay this for later + # TODO: Rework all of this, the constraint logic is very + # duplicative with regular reasoning + if not is_constraint: + assert v in r, f"{v} not in {r}" + + def _set_replacement(self, a: sympy.Symbol, tgt: sympy.Expr, msg: str) -> None: + """ + Adds or updates a replacement for a symbol. + Use this instead of `self.replacements[a] = tgt`. + """ + if tgt == self.replacements.get(a, None): + return + + if a in tgt.free_symbols: + return + + # Precondition: a == tgt + assert isinstance(a, sympy.Symbol) + + if ( + self.allow_complex_guards_as_runtime_asserts + and not _is_supported_equivalence(tgt) + ): + return # continuing leads to placeholder shapes having complex expressions that we can't resolve + + # Handles nested tensor symbolic variables which don't have + # var_to_range bounds + tgt_bound = None + if a in self.var_to_range: + src_bound = self.var_to_range[a] + + # First, refine the value range of a based on the computed value range + # of tgt. This is always OK to do, even if we decide not to do the + # substitution in the end. This might be a no-op, if a already has + # a tighter bound + tgt_bound = self.bound_sympy(tgt) + self._update_var_to_range(a, tgt_bound) + + # Next, check if we can update the range of free symbols in tgt + # based on the range in a. But only do it if: + # - the source bound non-trivially improves over what we get out of + # the existing bounds. + # - the replacement is univariate and we can invert the tgt expression + if not tgt_bound.issubset(src_bound) and len(tgt.free_symbols) == 1: + b = next(iter(tgt.free_symbols)) + # Try to invert the equality + r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) + if r is not None: + self.log.debug( + "set_replacement: solve for %s in %s == %s gives %s", + b, + a, + tgt, + r, + ) + # The solution here can be non-integral, for example, if + # we have s0 = 2*s1, then s1 = s0/2. What we would like + # to do is calculated the bounds in arbitrary precision, + # and then requantize the bound to integers when we are + # done. + rat_b_bound = self.bound_sympy(r[1]) + b_bound = ValueRanges( + CeilToInt(rat_b_bound.lower), FloorToInt(rat_b_bound.upper) + ) + self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) + tgt_bound = self.bound_sympy(tgt) + assert tgt_bound.issubset(src_bound), ( + f"{tgt_bound=} not a subset of {src_bound=}" + ) + + # TODO: Should we propagate size-like-ness? + # + # Pros: if u0 is size-like, intuitively u0 == u1 should cause u1 + # to become size-like. + # + # Cons: if u0 is size-like, what about u0 - 1 == u1? You CAN'T + # propagate in this case, because what if u0 == 0, then u1 is negative + # and clearly isn't a size. So, at minimum, any f(x) whose value + # range isn't [0, inf] given x in [0, inf] cannot propagate + # size-like-ness. But there are many situations where you could + # imagine u1 is going to be size-like and actually you just didn't + # have a refined enough value range on u0. Since even innocuous + # looking arithmetic operations can destroy size-like-ness, it's + # best to not propagate it at all and force the user to annotate it + # as necessary. + # + # Compromise: we preserve size-like-ness only for exact equality + # and nothing else. + if a in self.size_like and isinstance(tgt, sympy.Symbol): + self.size_like.add(tgt) + elif isinstance(tgt, sympy.Symbol) and tgt in self.size_like: + self.size_like.add(a) + + # Now, decide if we will do the substitution. + # + # - If the source has a non-trivial range, only substitute if + # we preserve this range. Note that we may have propagated + # the src_range to free variables in tgt when tgt is univariate + # and we could find an inverse, which helps us achieve this. + # This ensures we never "forget" about user defined ranges, + # even if they end up being defined on composite formulas + # like s0 + s1. + # + # - If the variable is unbacked, only substitute if the substitution + # would preserve the bounds also under size-like-ness conditions. + + if not tgt_bound.issubset(src_bound): + self.log.debug( + "skipped set_replacement %s = %s (%s) [%s not subset of %s]", + a, + tgt, + msg, + tgt_bound, + src_bound, + ) + return + elif a in self.size_like: + tgt_bound_so = self.bound_sympy(tgt, size_oblivious=True) + src_bound_so = self.bound_sympy(a, size_oblivious=True) + if not tgt_bound_so.issubset(src_bound_so): + self.log.debug( + "skipped set_replacement %s = %s (%s) " + "[%s not subset of %s (size-oblivious conditions)]", + a, + tgt, + msg, + tgt_bound_so, + src_bound_so, + ) + return + + if isinstance(tgt, (sympy.Integer, sympy.Float)): + # specializing to a constant, which is likely unexpected (unless + # you specified dynamic=True) + + user_tb = TracingContext.extract_stack() + trace_structured( + "symbolic_shape_specialization", + metadata_fn=lambda: { + "symbol": repr(a), + "sources": [s.name() for s in self.var_to_sources.get(a, [])], + "value": repr(tgt), + "reason": msg, + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + "user_stack": ( + structured.from_traceback(user_tb) if user_tb else None + ), + }, + ) + + for source in self.var_to_sources.get(a, []): + if user_tb: + self.user_specialization_stacks[source] = user_tb + self.framework_specialization_stacks[source] = ( + CapturedTraceback.extract(cpp=True) + ) + + if config.print_specializations: + self.log.warning( + "Specializing %s to %s", self.var_to_sources[a][0].name(), tgt + ) + self.log.debug("SPECIALIZATION", stack_info=True) + log.info("set_replacement %s = %s (%s) %s", a, tgt, msg, tgt_bound) + self.replacements[a] = tgt + # NB: the replacement may get refined, but the user will find the + # FIRST one most useful (TODO: Maybe we could consider tracking all of + # them) + if a not in self.replacements_slocs: + self.replacements_slocs[a] = self._get_sloc() + self._update_version_counter() + + # When specializing 'a == tgt', the equality should be also conveyed to + # Z3, in case an expression uses 'a'. + self._add_target_expr(sympy.Eq(a, tgt, evaluate=False)) + + def _add_divisible(self, expr: sympy.Expr) -> None: + self.divisible.add(expr) + self._update_version_counter() + + @_lru_cache + @record_shapeenv_event() + def _find(self, a: sympy.Symbol) -> sympy.Expr: + """ + Implements a DSU-like algorithm to find the variable that represents a + Also handles transitive non-identity replacements. + + a: b + c + c: d + """ + if a not in self.replacements: + return a + res = self.replacements[a] + cur_replace = {s: self._find(s) for s in res.free_symbols} + replaced, changed = self.replacements[a]._xreplace(cur_replace) + if changed: + self._set_replacement(a, replaced, "find") + return self.replacements[a] + + @lru_cache(256) + def _maybe_guard_rel(self, expr: sympy.Expr) -> None: + """ + The relational guard is guarded to be true. Use this information to + simplify shapes (i.e. a == b or a % 5 == 0) + """ + if isinstance(expr, sympy.And): + for arg in expr.args: + self._maybe_guard_rel(arg) + return + elif not isinstance(expr, sympy.Rel): + log.warning( + "_maybe_guard_rel() was called on non-relation expression %s", expr + ) + return + + # A good example of what goes wrong if you don't do this is + # python test/functorch/test_aotdispatch.py -k + # test_aot_autograd_symbolic_module_exhaustive_nn_LazyConv3d_cpu_float32 + if isinstance(expr, sympy.Ne): + return + + free = list(expr.free_symbols) + + assert len(free) > 0, ( + f"The expression should not be static by this point: {expr}" + ) + # In case of really gnarly expression, we don't blow up + if len(free) > 5: + return + + # Prioritize unbacked symints for solving by ordering them last. + # Prefer to simplify out lexicographically higher symbols (i.e. simplify out s4 over s3). + # (NB: this unfortunately isn't strictly equivalent to simplifying out newer symbols) + # Prefer to simplify out symbols with ephemeral sources. + def _smart_symbol_sort(x: sympy.Symbol) -> tuple[int, int, str]: + has_only_ephemeral_sources = x in self.var_to_sources and all( + s.is_ephemeral() for s in self.var_to_sources[x] + ) + # NB: size_hint is int, not sympy.Expr, do not use int_oo here + hint_size = self.size_hint(x, allow_none=True) + if hint_size is None: + size = sys.maxsize + elif symbol_is_type(x, SymT.SIZE): + assert isinstance(hint_size, sympy.Expr) + size = int(hint_size) + else: + size = sys.maxsize + name = x.name + # 1 puts ephemeral sourced symbols first when sorting in reverse + return (1 if has_only_ephemeral_sources else 0, size, name) + + free = sorted(free, key=_smart_symbol_sort, reverse=True) # type: ignore[attr-defined] + lhs = expr.lhs + rhs = expr.rhs + + self._refine_ranges(expr) + + # The rest of this stuff is for equality only + if not isinstance(expr, sympy.Eq): + return + + if not expr.has(Mod): + try: + floor_div_atoms = lhs.atoms(FloorDiv).union(rhs.atoms(FloorDiv)) + if len(floor_div_atoms) > 0 and any( + a.divisor != 1 for a in floor_div_atoms + ): + raise NotImplementedError + + # Never replace unbacked symbols with other unbacked symbols. + # This is error prone because you can cause references to + # unbacked symbols to time travel backwards. E.g., + # + # u1 = x.item() + # ... use of u1 ... + # u2 = y.item() + # u3 = z.item() + # torch._check(u1 == u2 + u3) + # + # If you replace u1 with u2 + u3, then the use of u1 now + # references u2 and u3 prior to them actually being bound at + # runtime. It's pretty inconvenient to setup control + # dependencies for substitutions, so ban it entirely. + def trivial_solve(lhs: sympy.Expr, rhs: sympy.Expr) -> bool: + if isinstance(lhs, sympy.Symbol): + if free_unbacked_symbols(lhs) and not free_unbacked_symbols( + rhs + ): + return True + if symbol_is_type(lhs, SymT.FLOAT): + return True + # TODO: Maybe trivial solutions for int should also be + # done? + return False + + # short-circuit when no solving is needed + if trivial_solve(lhs, rhs): + self._set_replacement(lhs, self._find(rhs), "trivial_lhs") + elif trivial_solve(rhs, lhs): + self._set_replacement(rhs, self._find(lhs), "trivial_rhs") + else: + r = try_solve(expr, free[0], floordiv_inequality=False) + if r is not None and all( + t.is_integer for t in sympy.preorder_traversal(r[1]) + ): + new_var = self._find(r[1]) + ok = len(free_unbacked_symbols(new_var)) == 0 + if ok: + self._set_replacement(free[0], new_var, "solve") + except NotImplementedError: + pass + if expr.has(Mod): + mod_expr = next(iter(expr.atoms(Mod))) + try: + r = try_solve(expr, mod_expr, floordiv_inequality=False) + if r is not None and r[1] == 0: + self._add_divisible(mod_expr) + # This is a little bit of extra logic to make things like + # torch.empty(i0, q).view(c, -1, q) work out + p, q = mod_expr.args + if ( + isinstance(q, sympy.Number) + and isinstance(p, sympy.Mul) + and len(p.args) == 2 + ): + c, i0 = p.args + # Given Mod(c * i0, q) == 0 + if ( + isinstance(c, sympy.Number) + and isinstance(i0, sympy.Symbol) + and self.is_unbacked_symint(i0) + ): + # We have Mod(i0, q / c) == 0, which means we can + # rewrite i0 as (q / gcd(q, c)) * i1 + d = q / sympy.gcd(q, c) # TODO: CleanDiv? + i1 = self.create_unbacked_symint().node.expr + # Propagate the value ranges. It doesn't really + # matter if we use truediv or floordiv, because we + # have established divisibility. + self._update_var_to_range( + i1, + SymPyValueRangeAnalysis.floordiv( + self.var_to_range[i0], ValueRanges.wrap(d) + ), + ) + # Propagate hints (real tensor tracing) + if i0 in self.unbacked_var_to_val: + self.set_unbacked_var_to_val( + i1, self.unbacked_var_to_val[i0] // d + ) + # Propagate size-like-ness + if i0 in self.size_like: + self.size_like.add(i1) + self._set_replacement(i0, d * i1, "divisibility") + + except NotImplementedError: + pass + return + + # See: Note - On 0/1 specialization + def _default_value_range( + self, do_not_specialize_zero_one: bool = False + ) -> ValueRanges: + lower = 0 if (do_not_specialize_zero_one or not self.specialize_zero_one) else 2 + return ValueRanges(lower, int_oo) + + def _default_unspecified_value_range(self) -> ValueRanges: + return ValueRanges.unknown_int() + + @_lru_cache + def _simplify_floor_div(self, expr: sympy.Expr) -> sympy.Expr: + floor_divs = tuple(expr.atoms(FloorDiv)) + # we expect floor_divs to be exact, + # and thus add the guards for the exact floordivs, + # even if tracing doesn't require them otherwise + for fd in reversed(floor_divs): + base, divisor = fd.args + mod_expr = Mod(base, divisor) + eq_expr = sympy.Eq(mod_expr, 0) + # add necessary mod guards + self.evaluate_expr(eq_expr) + return self.simplify(expr) + + # We're about to add a guard/runtime assert, check if the ShapeEnv is frozen + # and if so issue a warning + def _check_frozen(self, expr: sympy.Basic, concrete_val: sympy.Basic) -> None: + if self.frozen: + self.counter["ignored_backward_guard"] += 1 + signpost_event( + "dynamic", + "evaluate_expr_frozen", + { + **self.co_fields, + "ignored_guard": f"{expr} == {concrete_val}", + # no version = original state (this signpost is expected) + # version 2 = dynamic backwards is eagerly compiled + "version": 2, + }, + ) + log.info( + "Ignored guard %s == %s, this could result in accuracy problems", + expr, + concrete_val, + # only print stack trace when debug mode is on (e.g. TORCH_LOGS="dynamic") + stack_info=True if log.getEffectiveLevel() < logging.WARNING else False, + ) + + def _get_user_frame(self) -> Optional[types.FrameType]: + frame = inspect.currentframe() + while frame is not None: + if frame.f_code.co_filename not in uninteresting_files(): + return frame + frame = frame.f_back + return frame + + def _get_stack_summary( + self, is_debug: bool = False, framework_loc: Optional[str] = None + ) -> tuple[SLoc, str]: + floc: Optional[Union[str, traceback.FrameSummary]] = framework_loc + if floc is None: + frame = self._get_user_frame() + try: + if frame is not None: + floc = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + finally: + del frame + + # NB: this stack is truncated, but it's fine because the main + # stack_info will give you the rest of the info you need + maybe_user_loc = None + user_tb = TracingContext.extract_stack() + if user_tb: + idx = len(user_tb) - 1 + while idx > 0 and user_tb[idx].filename in uninteresting_files(): + idx -= 1 + maybe_user_loc = format_frame(user_tb[idx], line=True) + + maybe_extra_debug = "" + if is_debug and user_tb: + maybe_extra_debug = ( + "\nUser Stack (most recent call last):\n" + + " (snipped, see stack below for prefix)\n" + + "".join(traceback.format_list(user_tb)) + ) + if is_debug and config.extended_debug_cpp: + cpp_stack = CapturedTraceback.extract(cpp=True) + maybe_extra_debug += "\nC++ stack trace:\n" + "".join(cpp_stack.format()) + elif is_debug: + maybe_extra_debug += ( + "\nFor C++ stack trace, run with TORCHDYNAMO_EXTENDED_DEBUG_CPP=1" + ) + + return SLoc(floc, maybe_user_loc), maybe_extra_debug + + # Pass in framework_loc to override the framework location info + def _get_sloc(self, framework_loc: Optional[str] = None) -> SLoc: + sloc, _ = self._get_stack_summary(framework_loc=framework_loc) + return sloc + + def _generate_unique_id(self, source_name: str) -> int: + attempt = int(hashlib.sha256(source_name.encode()).hexdigest(), 16) % 100 + while attempt in self.unique_ids: + attempt += 1 + self.unique_ids.add(attempt) + return attempt + + def _find_frame_locals(self) -> _FrameLocalResult: + """ + Given the current user code frame, finds the relevant lines of code, + values of symbolic locals, and free symbols involved. + """ + frame_locals: dict[str, Any] = {} + frame_symbols: dict[str, str] = {} + + if ( + frame := _find_user_code_frame() + ) is None or frame.f_code.co_filename == "": + return _FrameLocalResult() + + # find bytecode instructions relevant to the frame + instructions = list(dis.Bytecode(frame.f_code)) + co_lines, offset = inspect.getsourcelines(frame.f_code) + start, end, cur = None, None, None + for i, instr in enumerate(instructions): + if instr.starts_line is not None: + cur = instr.starts_line + if cur != frame.f_lineno: + continue + if start is None: + start = end = i + else: + end = i + + if start is None or end is None: # no instructions found + return _FrameLocalResult() + + # track involved locals and free symbols + def go(x: Any) -> Optional[str]: + if isinstance(x, torch.Tensor): + for y in x.size(): + go(y) + for y in x.stride(): + go(y) + go(x.storage_offset()) + return ( + f"Tensor(shape: {x.size()}, " + f"stride: {x.stride()}, " + f"storage_offset: {x.storage_offset()})" + ) + elif isinstance(x, (SymBool, SymInt, SymFloat)): + for s in x.node.expr.free_symbols: + if str(s) in frame_symbols: # type: ignore[operator] + continue + if s in self.var_to_sources: + frame_symbols[str(s)] = self.var_to_sources[s][0].name() # type: ignore[assignment] + return str(x) + return None + + # go through instructions, seeing linenos & involved locals + last_lineno = frame.f_lineno + for instr in instructions[start : end + 1]: + if (lineno := instr.starts_line) is not None: + last_lineno = max(last_lineno, lineno) + if isinstance(instr.argval, str) and instr.argval in frame.f_locals: + flat_locals = pytree.tree_flatten(frame.f_locals[instr.argval])[0] + frame_locals[instr.argval] = [ + go(flat_local) for flat_local in flat_locals + ] + + # store LOC + locs = co_lines[frame.f_lineno - offset : last_lineno + 1 - offset] + if not locs: + return _FrameLocalResult() + + indent = len(locs[0]) - len(locs[0].lstrip()) + frame_loc = "".join([loc[indent:] for loc in locs]).strip() # type: ignore[assignment] + return _FrameLocalResult( + loc=frame_loc, locals=frame_locals, symbols=frame_symbols + ) + + def _log_guard(self, prefix: str, g: SympyBoolean, forcing_spec: bool) -> None: + dtrace_structured( + "guard_added", + metadata_fn=lambda: { + "expr": str(g), + "prefix": prefix, + "expr_node_id": self._expr_sym_node_id, + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + "symbol_to_sources": { + str(v): k + for k, v in self.source_to_var.items() + if v in g.free_symbols + }, + "frame_locals": asdict(self._find_frame_locals()), + }, + ) + trace_structured( + "guard_added_fast", + metadata_fn=lambda: { + "expr": str(g), + "user_stack": structured.from_traceback(TracingContext.extract_stack()), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + if self.log.isEnabledFor(logging.INFO): + str_g = str(g) + is_debug = ( + config.extended_debug_guard_added is not None + and str_g == config.extended_debug_guard_added + ) + sloc, maybe_extra_debug = self._get_stack_summary(is_debug) + maybe_more_info = "" + if not is_debug: + maybe_more_info = ( + ", for more info run with " + f'TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="{str_g}"' + ) + self.log.info( + "%s %s [guard added] %s%s%s", + prefix if not forcing_spec else f"{prefix} (forcing_spec)", + str_g, + sloc, + maybe_more_info, + maybe_extra_debug, + stack_info=is_debug, + ) + + # A local variable to evaluate_expr stored in the class to avoid + # using it for the lru_cache that is on top of it since it does + # not effect the results. When needed its read directly. + _expr_sym_node_id: Optional[int] = None + + def evaluate_sym_node( + self, + sym_node: SymNode, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + ) -> sympy.Basic: + """ + Given a a SymNode, evaluates sym_node.expr, adding guards if necessary. + """ + + self._expr_sym_node_id = id(sym_node) + return self.evaluate_expr( + sym_node.expr, + sym_node.hint, + sym_node.fx_node, + size_oblivious, + fallback_value=fallback_value, + ) + + def _is_python_assert(self) -> bool: + # Check if this boolean is used in an assertion, bytecode pattern for + # assertions is pretty stable for Python 3.7--3.13, ported with minimal + # changes from torch/fx/proxy.py + # Bytecode pattern for `assert` statements: + # TO_BOOL / COMPARE_OP # Only for Python >= 3.13 + # POP_JUMP_IF_TRUE + # LOAD_ASSERTION_ERROR + # RAISE_VARARGS + frame = self._get_user_frame() + assert frame is not None + + insts = list(dis.get_instructions(frame.f_code)) + if sys.version_info >= (3, 11): + # For Python >= 3.11, instructions can be 2-4 bytes long. + from bisect import bisect_left + + cur = bisect_left(insts, frame.f_lasti, key=lambda x: x.offset) + else: + # For Python <= 3.10, instructions are always 2 bytes. + cur = frame.f_lasti // 2 + + if sys.version_info >= (3, 13): + if insts[cur].opname in ("TO_BOOL", "COMPARE_OP"): + # Peek 1 instruction further. + cur += 1 + inst = insts[cur] + + if inst.opname == "POP_JUMP_IF_TRUE" and inst.arg is not None: + first = insts[cur + 1] + + starts_with_assert = ( + first.opname == "LOAD_GLOBAL" + and first.argval == "AssertionError" + or first.opname == "LOAD_ASSERTION_ERROR" + ) + if starts_with_assert and insts[cur + 2].opname == "RAISE_VARARGS": + return True + return False + + def _log_real_tensor_propagation( + self, orig_expr: sympy.Basic, unsound_result: sympy.Basic + ) -> None: + log.warning( + "propagate_real_tensors evaluate_expr(%s) -> %s", + orig_expr, + unsound_result, + ) + trace_structured( + "propagate_real_tensors", + metadata_fn=lambda: { + "expr": repr(orig_expr), + "result": repr(unsound_result), + "stack": structured.from_traceback( + CapturedTraceback.extract(skip=1).summary() + ), + }, + ) + dtrace_structured( + "propagate_real_tensors_provenance", + metadata_fn=lambda: { + "expr": repr(orig_expr), + "result": repr(unsound_result), + "expr_node_id": self._expr_sym_node_id, + "user_stack": structured.get_user_stack(3), + "stack": structured.get_framework_stack(3), + "symbol_to_sources": { + str(v): k + for k, v in self.source_to_var.items() + if v in orig_expr.free_symbols + }, + "frame_locals": asdict(self._find_frame_locals()), + }, + ) + + def evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]] = None, + fx_node: Optional[torch.fx.Node] = None, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + *, + forcing_spec: bool = False, + ) -> sympy.Basic: + """ + Given an expression, evaluates it, adding guards if necessary + When fallback_value is not None the function return fallback_value instead of failing with data dependent error. + """ + + # Add extra state that evaluate_expr() depends on. + suppress_guards_tls = ShapeEnv._suppress_guards_tls() + return self._inner_evaluate_expr( + orig_expr, + hint, + fx_node, + size_oblivious, + forcing_spec, + suppress_guards_tls, + fallback_value, + ) + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True, name="evaluate_expr") + def _inner_evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[int, bool, float]], + fx_node: Optional[torch.fx.Node], + size_oblivious: bool, + forcing_spec: bool, + _suppress_guards_tls: bool, + fallback_value: Optional[bool] = None, + ) -> sympy.Basic: + try: + return self._evaluate_expr( + orig_expr, + hint, + fx_node, + size_oblivious, + fallback_value, + forcing_spec=forcing_spec, + ) + except Exception as e: + if isinstance(e, GuardOnDataDependentSymNode): + pass + else: + self.log.warning( + "failed during evaluate_expr(%s, hint=%s, size_oblivious=%s, forcing_spec=%s", + orig_expr, + hint, + size_oblivious, + forcing_spec, + ) + raise + + def _log_suppressed_dde(self, a: SymBool, assumed_value: bool) -> None: + sloc, extra = self._get_stack_summary(True) + log.info( + "could not evaluate %s due to data dependency, it was assumed to be %s with no runtime assertions %s %s", + a, + assumed_value, + sloc, + extra, + ) + + def _evaluate_expr( + self, + orig_expr: sympy.Basic, + hint: Optional[Union[bool, int, float]] = None, + fx_node: Optional[torch.fx.Node] = None, + size_oblivious: bool = False, + fallback_value: Optional[bool] = None, + *, + forcing_spec: bool = False, + ) -> sympy.Basic: + # TODO: split conjunctions and evaluate them separately + + if isinstance( + orig_expr, + (sympy.logic.boolalg.BooleanTrue, sympy.logic.boolalg.BooleanFalse), + ): + return orig_expr + + # Don't track this one. (Because this cache is inside this function the + # cache only lasts for the invocation of this function call) + @functools.cache + def compute_concrete_val() -> sympy.Basic: + if hint is None: + # This is only ever called for expressions WITHOUT unbacked + # symbols + r = self.size_hint(orig_expr) + assert r is not None + return r + else: + return sympy.sympify(hint) + + concrete_val: Optional[sympy.Basic] + + # Check if: + # 1. 'translation_validation' is set + # 2. the corresponding 'fx_node' is not 'None' + # 3. the guard should not be suppressed + # 4. the guard doesn't contain backed symfloat symbols + # since z3 can't handle floats + # 5. fallback_value is none. + # If all of the above check, we create an FX node representing the + # actual expression to be guarded. + node = None + fresh = False + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + and not size_oblivious + and not any(symbol_is_type(s, SymT.FLOAT) for s in orig_expr.free_symbols) + and fallback_value is None + ): + # TODO: does this even worked with unbacked :think: + concrete_val = compute_concrete_val() + if concrete_val is sympy.true: + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + elif concrete_val is sympy.false: + neg, _ = self._create_fx_call_function(operator.not_, (fx_node,)) + node, fresh = self._create_fx_call_function(torch._assert, (neg,)) + else: + eql, _ = self._create_fx_call_function( + operator.eq, (fx_node, concrete_val) + ) + node, fresh = self._create_fx_call_function(torch._assert, (eql,)) + + assert node is not None + # If this is a fresh node, we have to remember the event index that + # corresponds to this assertion node. + # Reason: so that, given an assertion node, we can replay the ShapeEnv + # events until the point where this assertion node was freshly created. + if fresh: + self._add_fx_node_metadata(node) + + # After creating the FX node corresponding to orig_expr, we must make sure that + # no error will be raised until the end of this function. + # + # Reason: the translation validation may become invalid otherwise. + # + # If an error is raised before the end of this function, we remove the FX node + # inserted, and re-raise the error. + guard = None + + try: + if orig_expr.is_number: + self.log.debug("eval %s [trivial]", orig_expr) + if hint is not None: + if isinstance(hint, bool): + assert orig_expr == hint, f"{orig_expr} != {hint}" + else: + assert sympy.Eq(orig_expr, hint), f"{orig_expr} != {hint}" + return orig_expr + + expr = orig_expr + + static_expr = self._maybe_evaluate_static( + expr, size_oblivious=size_oblivious + ) + if static_expr is not None: + self.log.debug( + "eval %s == %s [statically known]", + ( + f"size_oblivious({orig_expr})" + if size_oblivious + else size_oblivious + ), + static_expr, + ) + if ( + not size_oblivious + and config.backed_size_oblivious + and hint is not None + ): + # TODO: maybe reconcile this with use of counterfactual hints + # in unbacked case + assert static_expr == hint, f"{static_expr} != {hint}" + return static_expr + + transmute_into_runtime_assert = False + + concrete_val = None + if not (expr.free_symbols <= self.var_to_val.keys()): + # TODO: dedupe this with _maybe_evaluate_static + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None + if not (new_expr.free_symbols <= self.var_to_val.keys()): + ok = False + + # fallback_value is set when guard_or_true or guard_or_false are used. + if not ok and fallback_value is not None: + self._log_suppressed_dde(orig_expr, fallback_value) + return fallback_value + + # oblivious_var_to_val will be defined iff we have sizes with DimDynamic.OBLIVIOUS_SIZE type. + # See https://github.com/pytorch/pytorch/issues/137100#issuecomment-2495778113 + if ( + self.oblivious_var_to_val + and not ( + correct_hint := orig_expr.xreplace( + self.oblivious_var_to_val + ) + ).free_symbols + and not ( + counterfactual_hint := orig_expr.xreplace( + { + k: max(2, v) + for k, v in self.oblivious_var_to_val.items() + } + ) + ).free_symbols + and correct_hint == counterfactual_hint + ): + # TODO: better logging + log.info( + "oblivious_size %s -> %s (passed counterfactual)", + orig_expr, + correct_hint, + ) + concrete_val = correct_hint + # NB: do NOT transmute into runtime assert + ok = True + + # unbacked_var_to_val is not None iff propagate_real_tensors is on. + # if propagate_real_tensors is on, we check the example values to generate (unsound_result) + # and if they pass we add a runtime assertions and continue. + if ( + not ok + and self.unbacked_var_to_val + and not ( + unsound_result := orig_expr.xreplace( + self.unbacked_var_to_val + ).xreplace(self.var_to_val) + ).free_symbols + ): + self._log_real_tensor_propagation(orig_expr, unsound_result) + transmute_into_runtime_assert = True + concrete_val = unsound_result + ok = True + + # Check if this is coming from a python assert statement, if so, convert it to a runtime assertion + # instead of failing. + if not ok and self.trace_asserts and self._is_python_assert(): + concrete_val = sympy.true + transmute_into_runtime_assert = True + ok = True + + if not ok: + size_oblivious_result = None + # compute size_oblivious_result to suggest it as a fix for the user if it works. + if not size_oblivious: + size_oblivious_result = self._maybe_evaluate_static( + expr, size_oblivious=True + ) + raise self._make_data_dependent_error( + expr.xreplace(self.var_to_val), + expr, + size_oblivious_result=size_oblivious_result, + expr_sym_node_id=self._expr_sym_node_id, + ) + else: + expr = new_expr + + if concrete_val is None: + concrete_val = compute_concrete_val() + self._check_frozen(expr, concrete_val) + + if ( + config.inject_EVALUATE_EXPR_flip_equality_TESTING_ONLY + and isinstance(hint, bool) + and isinstance(expr, (sympy.Eq, sympy.Ne)) + ): + expr = sympy.Not(expr) + + # Turn this into a boolean expression, no longer need to consult + # concrete_val + if concrete_val is sympy.true: + g = cast(SympyBoolean, expr) + elif concrete_val is sympy.false: + g = sympy.Not(expr) + else: + g = sympy.Eq(expr, concrete_val) # type: ignore[arg-type] + + if transmute_into_runtime_assert: + self.guard_or_defer_runtime_assert( + g, f"propagate_real_tensors: {orig_expr} == {concrete_val}" + ) + return concrete_val + + if not self._suppress_guards_tls(): + self._log_guard("eval", g, forcing_spec=forcing_spec) + + # TODO: If we successfully eliminate a symbol via equality, it + # is not actually necessary to save a guard for the equality, + # as we will implicitly generate a guard when we match that + # input against the symbol. Probably the easiest way to + # implement this is to have maybe_guard_rel return a bool + # saying if it "subsumed" the guard (and therefore the guard + # is no longer necessary) + self._maybe_guard_rel(g) + + if not self.allow_complex_guards_as_runtime_asserts: + # at this point, we've evaluated the concrete expr value, and have + # flipped/negated the guard if necessary. Now we know what to guard + # or defer to runtime assert on. + guard = ShapeGuard( + g, self._get_sloc(), size_oblivious=size_oblivious + ) + self.guards.append(guard) + self.axioms.update(dict(self.get_implications(self.simplify(g)))) + else: + # it's fine to defer simple guards here without checking, + # the _maybe_guard_rel() call above will set replacements if possible, + # and so the result here will be statically known + self.guard_or_defer_runtime_assert(g, f"evaluate_expr: {orig_expr}") + else: + self._log_guard("eval [guard suppressed]", g, forcing_spec=forcing_spec) + + except Exception: + if fresh: + self._remove_fx_node(node) + raise + + if not self._suppress_guards_tls(): + if guard is not None: # we might have deferred this to runtime assert + for s in g.free_symbols: + self.symbol_guard_counter[s] += 1 + # Forcing_spec to avoid infinite recursion + if ( + not forcing_spec + and config.symbol_guard_limit_before_specialize is not None + and self.symbol_guard_counter[s] + > config.symbol_guard_limit_before_specialize + ): + # Force specialization + self.log.info( + "symbol_guard_limit_before_specialize=%s exceeded on %s", + config.symbol_guard_limit_before_specialize, + s, + ) + self.evaluate_expr(s, forcing_spec=True) + + return concrete_val + + def cleanup(self) -> None: + """ + Break reference cycles. + + This destroys the stacks. If you really want to keep them, we + just need some way to break references on code objects. + """ + for s in self.var_to_stack.values(): + s.cleanup() + for ras in self.deferred_runtime_asserts.values(): + for ra in ras: + ra.stack.cleanup() + + @lru_cache(256) + @record_shapeenv_event(save_tracked_fakes=True) + def guard_or_defer_runtime_assert( + self, orig_expr: SympyBoolean, msg: str, fx_node: Optional[torch.fx.Node] = None + ) -> bool: + """ + Adds a guard that orig_expr is True if we can or fall back to adding an assert + that is checked at runtime. + + Args: + orig_expr (sympy.Expr): Boolean expression to assert is true + msg (str): Message to display on assertion failure + fx_node (Optional, torch.fx.Node): node in ``self.graph`` corresponding + to the expression, if applicable + """ + expr = orig_expr + + # TODO: split conjunctions and evaluate them separately + + static_expr = self._maybe_evaluate_static(expr) + if static_expr is not None: + self.log.debug( + "runtime_assert %s == %s [statically known]", orig_expr, static_expr + ) + # TODO: assert bool(static_expr) + return bool(static_expr) + + # Attempt to eliminate the unbacked SymInt + new_expr = self._maybe_evaluate_static(expr, unbacked_only=True) + assert new_expr is not None + if ( + not self.prefer_deferred_runtime_asserts_over_guards + and new_expr.free_symbols <= self.var_to_val.keys() + ): + # Do a normal guard + return self.evaluate_expr(new_expr, fx_node=fx_node) + # NB: Don't use new_expr as expr; it could contain gunk like shape0 + # which we don't want to guard on + + if ( + self._translation_validation_enabled + and fx_node is not None + and not self._suppress_guards_tls() + ): + node, fresh = self._create_fx_call_function(torch._assert, (fx_node,)) + assert node is not None + if fresh: + self._add_fx_node_metadata(node) + + if not self._suppress_guards_tls(): + self._log_guard("runtime_assert", orig_expr, forcing_spec=False) + # If you're here because of this assert, read Note [Backwards runtime asserts] + # in torch/_inductor/graph.py + if self.runtime_asserts_frozen: + log.debug("runtime_asserts_frozen but then got %s", expr) + self._check_frozen(expr, sympy.true) + # eliminate symbols on equality tests / refine ranges + self._maybe_guard_rel(expr) + + # canonicalise to remove equations that are trivially equal + orig_expr = expr + expr = canonicalize_bool_expr(expr) + stack = CapturedTraceback.extract(skip=1) + ra = RuntimeAssert(expr, msg, stack) + # TODO: Do this in a way that is less janky than int(s.name[1:]) + cands = sorted( + (s for s in expr.free_symbols if symbol_is_type(s, SymT.UNBACKED_INT)), + key=lambda s: int(s.name[1:]), + ) + # Is None when prefer_deferred_runtime_asserts_over_guards=True + # and the guard in question has no unbacked SymInts in front + ix = cands[-1] if cands else None + self.deferred_runtime_asserts.setdefault(ix, []).append(ra) + self.axioms.update(dict(self.get_implications(self.simplify(expr)))) + self.num_deferred_runtime_asserts += 1 + self._update_version_counter() + else: + self._log_guard( + "runtime_assert [guard suppressed]", orig_expr, forcing_spec=False + ) + + return True + + # Refines the ranges of the variables present in 'guard'. + # + # This function tries to refine the range of the variables inside + # 'guard' by reasoning about it. Specifically, when 'guard' is a + # 'sympy.Relational' operation. + # + # It does mainly 3 things: + # 1. Tries to isolate a variable in the left-hand side + # 2. Compute the value range of the right-hand side + # 3. Update the value range of the variable, if better + def _refine_ranges(self, expr: SympyBoolean) -> None: + expr = self.simplify(expr) + + for symbol in expr.free_symbols: + assert isinstance(symbol, sympy.Symbol) + + if isinstance(self.var_to_val.get(symbol, None), SingletonInt): + # Skip var_to_range logic for SingletonInt which is only used + # for jagged layout NestedTensors today + continue + + r = try_solve(expr, symbol) + + if r is None or not (symbol.is_integer and r[1].is_integer): + # Range refinement only supports integer symbols for now. + # There are lots of SymPy bugs when it comes to comparing + # reals and integers, so we skip that for now. + continue + + r_expr, rhs = r + vr = self.var_to_range[symbol] + lower, upper = vr.lower, vr.upper + + rhs_vr = bound_sympy(rhs, self.var_to_range) + + # Let's suppose that we have a preexisting range for x [0, 100]. + # Now, we issue a guard x > y, where the range for y is [50, 150]. + # Then, lower = 0, rhs_vr.lower = 50 and therefore refinement can happen, + # refining x to [51, 100], since x must be greater than y, but the lowest + # y could be is 50. + # + # sympy.Eq may update both lower and upper bounds. + # sympy.G{t,e} may update the lower bound, only. + # sympy.L{t,e} may update the upper bound, only. + if lower < rhs_vr.lower and isinstance( + r_expr, (sympy.Eq, sympy.Ge, sympy.Gt) + ): + # Strictly greater relations allow us to refine a bit more, since + # x < y implies that the lower bound for x is: y + 1. + lower = rhs_vr.lower + int(isinstance(r_expr, sympy.Gt)) + if upper > rhs_vr.upper and isinstance( + r_expr, (sympy.Eq, sympy.Le, sympy.Lt) + ): + upper = rhs_vr.upper - int(isinstance(r_expr, sympy.Lt)) + + # Do nothing if the new value range is no better than what we already have. + if vr == ValueRanges(lower, upper): + continue + + # Updates the range and the guards corresponding to each bound of the symbol. + self._update_var_to_range(symbol, ValueRanges(lower, upper)) + # If the range is refined to singleton, set replacement + if self.var_to_range[symbol].is_singleton(): + self._set_replacement( + symbol, + self.var_to_range[symbol].lower, + "range_refined_to_singleton", + ) + + # Clears the cache, since this update can change the result. + self._maybe_evaluate_static.cache_clear() + + @lru_cache(maxsize=None) + @record_shapeenv_event() + def constrain_symbol_range( + self, s: sympy.Symbol, compiler_min: int, compiler_max: int + ) -> None: + upd_vr = ValueRanges(compiler_min, compiler_max) + old_vr = self.var_to_range.get(s, ValueRanges.unknown()) + self._update_var_to_range(s, upd_vr) + if (new_vr := self.var_to_range[s]) != old_vr: + log.info( + "constrain_symbol_range %s [%s, %s]", s, new_vr.lower, new_vr.upper + ) + + +def _is_int(expr: object) -> bool: + return isinstance(expr, SymInt) and expr.node.expr.is_number + + +# WARNING: This is legacy, DO NOT USE +def _is_dim_dynamic(t: torch.Tensor, d: int) -> bool: + return hasattr(t, "_dynamo_dynamic_indices") and d in t._dynamo_dynamic_indices + + +class PropagateUnbackedSymInts(torch.fx.Interpreter): + def run_node(self, n: torch.fx.Node) -> Result: + """ + Run an FX node, propagating unbacked Symbol bindings to the new fake tensor + """ + from torch._guards import detect_fake_mode + + result = super().run_node(n) + rebind_unbacked(detect_fake_mode().shape_env, n, result) + return result + + +def _find_user_code_frame() -> Optional[types.FrameType]: + frame = inspect.currentframe() + while frame is not None: + if not frame.f_code.co_filename.startswith( + os.path.dirname(inspect.getfile(torch)) + os.path.sep + ): + break + frame = frame.f_back + return frame + + +def _blame_user_code(e: Exception, frame: types.FrameType) -> None: + frame_summary = traceback.FrameSummary( + frame.f_code.co_filename, + frame.f_lineno, + frame.f_code.co_name, + ) + msg = e.args[0] + msg += "\n\nThe following call raised this error:\n" + "".join( + traceback.StackSummary.from_list([frame_summary]).format() + ) + e.args = (msg,) + + +class _PythonMsgPrinter(PythonPrinter): + """ + Util printer that replaces sympy symbols with their source-level names + and renders sympy relational operators (e.g., Eq, Ne, Ge, Le) inline + (i.e., as ==, !=, >, <). + """ + + def __init__(self, src_map: dict[str, list[str]]) -> None: + super().__init__() + self.src_map = src_map + + def _print_Symbol(self, sym: sympy.Symbol) -> str: + return self.src_map[sym.name][0] + + +def _is_non_negative_check(cond: sympy.Basic) -> Optional[str]: + """ + Check if a condition (SymPy expression) is checking for non-negative values (>= 0). + Returns the variable name if it's a non-negative check (>= 0), None otherwise. + """ + if isinstance(cond, sympy.Rel): + if cond.rel_op == ">=" and cond.rhs == 0: + return str(cond.lhs) + return None + + +def _suggest_torch_checks( + e: GuardOnDataDependentSymNode, src_map: defaultdict[str, list[str]] +) -> None: + """ + Enhances a GuardOnDataDependentSymNode error with suggested fixes using torch._check. + + This function analyzes the condition that caused the data-dependent error and generates + user-friendly suggestions for fixing it by adding appropriate torch._check calls. + It handles special cases like non-negative checks with specific recommendations. + + Args: + e: The GuardOnDataDependentSymNode error to enhance with suggestions + src_map: A mapping from symbol names to their corresponding source-level variable names + + Returns: + None. Modifies the error message in-place by updating e.args[0]. + """ + # extract the unresolved condition on unbacked symints in the error + cond = e.cond + diff = ", ".join(s.name for s in cond.free_symbols if s.name not in src_map) + if diff: + log.warning("Unable to find user code corresponding to {%s}", diff) + return + printer = _PythonMsgPrinter(src_map) + msg = e.args[0] + msg += "\nTo fix the error, insert one of the following checks before this call:" + + not_cond_str = printer.doprint(sympy.Not(cond)) + var_name = _is_non_negative_check(cond) + + # suggested fixes to resolve `cond` are to tell the compiler to assume + # either `cond` or its negation (the user will need to select which) + suggested_fixes = [] + + if var_name: + suggested_fixes = [ + f"You can add either: torch._check_is_size({var_name}) or torch._check({var_name}>=0)" + f" Note: torch._check_is_size({var_name}) could prevent data dependent errors that" + + " happen in a guard_size_oblivious(..) context by opting into guard_size_oblivious reasoning." + + " See documentation on guard_size_oblivious for more details:" + + " https://pytorch.org/docs/stable/generated/torch.fx.experimental.symbolic_shapes.guard_size_oblivious.html", + f"torch._check({not_cond_str})", + ] + else: + suggested_fixes = [ + f"torch._check({printer.doprint(cond)})", + f"torch._check({not_cond_str})", + ] + + for i, fix in enumerate(suggested_fixes): + msg += f"\n {i + 1}. {fix}" + src_mapped = ", ".join( + f"`{s}` with {' or '.join(src_map[s])}" + for s in sorted(s.name for s in cond.free_symbols) + ) + msg += f"\n\n(These suggested fixes were derived by replacing {src_mapped} in {cond} and its negation.)" + e.args = (msg,) + + +def _suggest_fixes_for_data_dependent_error_non_strict( + e: GuardOnDataDependentSymNode, +) -> None: + """ + Given a raised data-dependent error, add the following to the error message: + 1. the closest user code location that raised the error; + 2. suggested fixes for the error in terms of live variables at that location. + """ + + # walk the stack up from the data-dependent error until a non-torch frame is found + frame = _find_user_code_frame() + if frame is not None: + # add frame info to error message + _blame_user_code(e, frame) + + # map symbol names reachable via frame locals to their source-level names + src_map = defaultdict(list) + for var, val in frame.f_locals.items(): + try: + tree_leaves_with_path = pytree.tree_leaves_with_path(val) + except ValueError: + log.warning( + "pytree.tree_leaves_with_path failed for value of type {%s} in local variable {%s}", + type(val), + var, + ) + continue + # figure out how to access any symbol inside `val` through `var` + for path, leaf in tree_leaves_with_path: + name = var + pytree.keystr(path) + if isinstance(leaf, torch.SymInt): + src_map[str(leaf.node.expr)].append(name) + elif isinstance(leaf, torch.Tensor): + for i, dim in enumerate(leaf.shape): + if isinstance(dim, torch.SymInt): + src_map[str(dim.node.expr)].append(f"{name}.shape[{i}]") + + # add suggested torch.check()s based on `src_map` to the error message + # replacing unbacked symints in the unresolved condition in the error + if isinstance(e.cond, sympy.logic.boolalg.Boolean): + _suggest_torch_checks(e, src_map) + + +@contextmanager +def _remove_effect_token_unbacked_bindings( + node: torch.fx.Node, +) -> Generator[None, None, None]: + """ + Temporarily modifies unbacked_bindings in a node's metadata by removing the first element + of each path, which corresponds to an effect token. + + This is used when processing nodes that have effect tokens as the first element in their + unbacked_bindings paths. The context manager ensures that the original bindings are + restored after the operation is complete. + + Args: + node: The FX node whose unbacked_bindings will be temporarily modified + + Yields: + None + """ + old_bindings = node.meta.get("unbacked_bindings", {}) + + # Remove the extra layer for effect token + new_bindings = {k: path[1:] if path else path for k, path in old_bindings.items()} + + node.meta["unbacked_bindings"] = new_bindings + + try: + yield + finally: + node.meta["unbacked_bindings"] = old_bindings + + +# This helper function is used in passes that insert runtime assertions in the graph. +# When accessing expressions representing input placeholders, we do not apply replacements +# since those inputs should be seen by assertions that use them to be inserted. The only replacement +# that we apply is unbacked renaming. +def _get_placeholder_expr(sym_node: SymNode) -> sympy.Expr: + shape_env = sym_node.shape_env + result = sym_node._expr + if result in shape_env.unbacked_renamings: + return shape_env.unbacked_renamings[result] + return result diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__init__.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..24ba562de17ddd0d683237f4d7e5447b2e668e74 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__init__.py @@ -0,0 +1,4 @@ +# mypy: disable-error-code=attr-defined +from .core import reify, unify # noqa: F403 +from .more import unifiable # noqa: F403 +from .variable import isvar, Var, var, variables, vars # noqa: F403 diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45bb7ce2865ae90e69666e58f91a49ef5cc80bc0 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6a366dcc44ccc9924f94badd006ce8eaa7200fc Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71c0ef7273c0051fb121b0663074af6b50ba907a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/dispatch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf1986286099a60d20dda9fe2c28dff023f37011 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/match.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c6a8e3057202677c5ff0463e1744792cd24f5d77 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/more.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45348239d5c2fa3d355516d0715eea14fa95fb7b Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/unification_tools.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8422cec4faa91e570151a59cff4cf574bc552d3 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a55d59d1e4a66aa0c0a2ca5e749148a1923847a2 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/__pycache__/variable.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/core.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/core.py new file mode 100644 index 0000000000000000000000000000000000000000..86dcf68a38d94a110e52c9682230546f6c152bc3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/core.py @@ -0,0 +1,141 @@ +# mypy: allow-untyped-defs +from collections.abc import Iterator # type: ignore[import] +from functools import partial + +from .dispatch import dispatch +from .unification_tools import assoc # type: ignore[import] +from .utils import transitive_get as walk +from .variable import isvar + + +__all__ = ["reify", "unify"] + +############### +# Reification # +############### + + +@dispatch(Iterator, dict) +def _reify(t, s): + return map(partial(reify, s=s), t) + # return (reify(arg, s) for arg in t) + + +_reify + + +@dispatch(tuple, dict) # type: ignore[no-redef] +def _reify(t, s): + return tuple(reify(iter(t), s)) + + +_reify + + +@dispatch(list, dict) # type: ignore[no-redef] +def _reify(t, s): + return list(reify(iter(t), s)) + + +_reify + + +@dispatch(dict, dict) # type: ignore[no-redef] +def _reify(d, s): + return {k: reify(v, s) for k, v in d.items()} + + +_reify + + +@dispatch(object, dict) # type: ignore[no-redef] +def _reify(o, s): + return o # catch all, just return the object + + +def reify(e, s): + """Replace variables of expression with substitution + >>> # xdoctest: +SKIP + >>> x, y = var(), var() + >>> e = (1, x, (3, y)) + >>> s = {x: 2, y: 4} + >>> reify(e, s) + (1, 2, (3, 4)) + >>> e = {1: x, 3: (y, 5)} + >>> reify(e, s) + {1: 2, 3: (4, 5)} + """ + if isvar(e): + return reify(s[e], s) if e in s else e + return _reify(e, s) + + +############### +# Unification # +############### + +seq = tuple, list, Iterator + + +@dispatch(seq, seq, dict) +def _unify(u, v, s): + if len(u) != len(v): + return False + for uu, vv in zip(u, v): # avoiding recursion + s = unify(uu, vv, s) + if s is False: + return False + return s + + +# +# @dispatch((set, frozenset), (set, frozenset), dict) +# def _unify(u, v, s): +# i = u & v +# u = u - i +# v = v - i +# return _unify(sorted(u), sorted(v), s) +# +# +# @dispatch(dict, dict, dict) +# def _unify(u, v, s): +# if len(u) != len(v): +# return False +# for key, uval in iteritems(u): +# if key not in v: +# return False +# s = unify(uval, v[key], s) +# if s is False: +# return False +# return s +# +# +# @dispatch(object, object, dict) +# def _unify(u, v, s): +# return False # catch all + + +@dispatch(object, object, dict) +def unify(u, v, s): # no check at the moment + """Find substitution so that u == v while satisfying s + >>> x = var("x") + >>> unify((1, x), (1, 2), {}) + {~x: 2} + """ + u = walk(u, s) + v = walk(v, s) + if u == v: + return s + if isvar(u): + return assoc(s, u, v) + if isvar(v): + return assoc(s, v, u) + return _unify(u, v, s) + + +unify + + +@dispatch(object, object) # type: ignore[no-redef] +def unify(u, v): + return unify(u, v, {}) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/dispatch.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..0cb435e7e352e314911287d2e99ee2cd038b18f0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/dispatch.py @@ -0,0 +1,8 @@ +from functools import partial + +from .multipledispatch import dispatch # type: ignore[import] + + +namespace = {} # type: ignore[var-annotated] + +dispatch = partial(dispatch, namespace=namespace) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/match.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/match.py new file mode 100644 index 0000000000000000000000000000000000000000..89c554df327792e3b4cabf3b50e295c7f4a98560 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/match.py @@ -0,0 +1,129 @@ +# mypy: allow-untyped-defs +from .core import reify, unify # type: ignore[attr-defined] +from .unification_tools import first, groupby # type: ignore[import] +from .utils import _toposort, freeze +from .variable import isvar + + +class Dispatcher: + def __init__(self, name): + self.name = name + self.funcs = {} + self.ordering = [] + + def add(self, signature, func): + self.funcs[freeze(signature)] = func + self.ordering = ordering(self.funcs) + + def __call__(self, *args, **kwargs): + func, _ = self.resolve(args) + return func(*args, **kwargs) + + def resolve(self, args): + n = len(args) + for signature in self.ordering: + if len(signature) != n: + continue + s = unify(freeze(args), signature) + if s is not False: + result = self.funcs[signature] + return result, s + raise NotImplementedError( + "No match found. \nKnown matches: " + + str(self.ordering) + + "\nInput: " + + str(args) + ) + + def register(self, *signature): + def _(func): + self.add(signature, func) + return self + + return _ + + +class VarDispatcher(Dispatcher): + """A dispatcher that calls functions with variable names + >>> # xdoctest: +SKIP + >>> d = VarDispatcher("d") + >>> x = var("x") + >>> @d.register("inc", x) + ... def f(x): + ... return x + 1 + >>> @d.register("double", x) + ... def f(x): + ... return x * 2 + >>> d("inc", 10) + 11 + >>> d("double", 10) + 20 + """ + + def __call__(self, *args, **kwargs): + func, s = self.resolve(args) + d = {k.token: v for k, v in s.items()} + return func(**d) + + +global_namespace = {} # type: ignore[var-annotated] + + +def match(*signature, **kwargs): + namespace = kwargs.get("namespace", global_namespace) + dispatcher = kwargs.get("Dispatcher", Dispatcher) + + def _(func): + name = func.__name__ + + if name not in namespace: + namespace[name] = dispatcher(name) + d = namespace[name] + + d.add(signature, func) + + return d + + return _ + + +def supercedes(a, b): + """``a`` is a more specific match than ``b``""" + if isvar(b) and not isvar(a): + return True + s = unify(a, b) + if s is False: + return False + s = {k: v for k, v in s.items() if not isvar(k) or not isvar(v)} + if reify(a, s) == a: + return True + if reify(b, s) == b: + return False + + +# Taken from multipledispatch +def edge(a, b, tie_breaker=hash): + """A should be checked before B + Tie broken by tie_breaker, defaults to ``hash`` + """ + if supercedes(a, b): + if supercedes(b, a): + return tie_breaker(a) > tie_breaker(b) + else: + return True + return False + + +# Taken from multipledispatch +def ordering(signatures): + """A sane ordering of signatures to check, first to last + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(first, edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[attr-defined, assignment] + return _toposort(edges) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/more.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/more.py new file mode 100644 index 0000000000000000000000000000000000000000..29c6068ad33192ee698769bd573972225dfc7961 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/more.py @@ -0,0 +1,122 @@ +# mypy: allow-untyped-defs +from .core import reify, unify # type: ignore[attr-defined] +from .dispatch import dispatch + + +def unifiable(cls): + """Register standard unify and reify operations on class + This uses the type and __dict__ or __slots__ attributes to define the + nature of the term + See Also: + >>> # xdoctest: +SKIP + >>> class A(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + >>> unifiable(A) + + >>> x = var("x") + >>> a = A(1, 2) + >>> b = A(1, x) + >>> unify(a, b, {}) + {~x: 2} + """ + _unify.add((cls, cls, dict), unify_object) + _reify.add((cls, dict), reify_object) + + return cls + + +######### +# Reify # +######### + + +def reify_object(o, s): + """Reify a Python object with a substitution + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... + ... def __str__(self): + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") + >>> f = Foo(1, x) + >>> print(f) + Foo(1, ~x) + >>> print(reify_object(f, {x: 2})) + Foo(1, 2) + """ + if hasattr(o, "__slots__"): + return _reify_object_slots(o, s) + else: + return _reify_object_dict(o, s) + + +def _reify_object_dict(o, s): + obj = object.__new__(type(o)) + d = reify(o.__dict__, s) + if d == o.__dict__: + return o + obj.__dict__.update(d) + return obj + + +def _reify_object_slots(o, s): + attrs = [getattr(o, attr) for attr in o.__slots__] + new_attrs = reify(attrs, s) + if attrs == new_attrs: + return o + else: + newobj = object.__new__(type(o)) + for slot, attr in zip(o.__slots__, new_attrs): + setattr(newobj, slot, attr) + return newobj + + +@dispatch(slice, dict) +def _reify(o, s): + """Reify a Python ``slice`` object""" + return slice(*reify((o.start, o.stop, o.step), s)) + + +######### +# Unify # +######### + + +def unify_object(u, v, s): + """Unify two Python objects + Unifies their type and ``__dict__`` attributes + >>> # xdoctest: +SKIP + >>> class Foo(object): + ... def __init__(self, a, b): + ... self.a = a + ... self.b = b + ... + ... def __str__(self): + ... return "Foo(%s, %s)" % (str(self.a), str(self.b)) + >>> x = var("x") + >>> f = Foo(1, x) + >>> g = Foo(1, 2) + >>> unify_object(f, g, {}) + {~x: 2} + """ + if type(u) != type(v): + return False + if hasattr(u, "__slots__"): + return unify( + [getattr(u, slot) for slot in u.__slots__], + [getattr(v, slot) for slot in v.__slots__], + s, + ) + else: + return unify(u.__dict__, v.__dict__, s) + + +@dispatch(slice, slice, dict) +def _unify(u, v, s): + """Unify a Python ``slice`` object""" + return unify((u.start, u.stop, u.step), (v.start, v.stop, v.step), s) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dc3b25cc77d1fcd6b9f74eb182c77eaffb953676 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__init__.py @@ -0,0 +1,7 @@ +from .core import dispatch +from .dispatcher import ( + Dispatcher, + halt_ordering, + MDNotImplementedError, + restart_ordering, +) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ae41b762326b2cd78c1d4c2d9513ac7519136ce Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5932c7c6e67ce7ab7f43a9c58c411dd6c9342a7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/conflict.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda6970c86bb595ff355e05cfe34de6de8dcd2b1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/core.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e24511802ab15bfa7b8f40b623b6cd59ae7dc372 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/dispatcher.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..28b974895173d4a4a70193609585d74b21e3da30 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..98ba11900db55bff94e9e4cc750da0c8f0c406cf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/__pycache__/variadic.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py new file mode 100644 index 0000000000000000000000000000000000000000..36306aac1452e212cd985d87b1fc669797e670bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/conflict.py @@ -0,0 +1,139 @@ +# mypy: allow-untyped-defs +import operator + +from .utils import _toposort, groupby +from .variadic import isvariadic + + +__all__ = [ + "AmbiguityWarning", + "supercedes", + "consistent", + "ambiguous", + "ambiguities", + "super_signature", + "edge", + "ordering", +] + + +class AmbiguityWarning(Warning): + pass + + +def supercedes(a, b): + """A is consistent and strictly more specific than B""" + if len(a) < len(b): + # only case is if a is empty and b is variadic + return not a and len(b) == 1 and isvariadic(b[-1]) + elif len(a) == len(b): + return all(map(issubclass, a, b)) + else: + # len(a) > len(b) + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not (isvariadic(cur_a) or isvariadic(cur_b)): + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + assert p1 == len(a) - 1 + return p2 == len(b) - 1 and issubclass(cur_a, cur_b) + elif isvariadic(cur_b): + assert p2 == len(b) - 1 + if not issubclass(cur_a, cur_b): + return False + p1 += 1 + return p2 == len(b) - 1 and p1 == len(a) + + +def consistent(a, b): + """It is possible for an argument list to satisfy both A and B""" + + # Need to check for empty args + if not a: + return not b or isvariadic(b[0]) + if not b: + return not a or isvariadic(a[0]) + + # Non-empty args check for mutual subclasses + if len(a) == len(b): + return all(issubclass(aa, bb) or issubclass(bb, aa) for aa, bb in zip(a, b)) + else: + p1 = 0 + p2 = 0 + while p1 < len(a) and p2 < len(b): + cur_a = a[p1] + cur_b = b[p2] + if not issubclass(cur_b, cur_a) and not issubclass(cur_a, cur_b): + return False + if not (isvariadic(cur_a) or isvariadic(cur_b)): + p1 += 1 + p2 += 1 + elif isvariadic(cur_a): + p2 += 1 + elif isvariadic(cur_b): + p1 += 1 + # We only need to check for variadic ends + # Variadic types are guaranteed to be the last element + return ( + isvariadic(cur_a) # type: ignore[possibly-undefined] + and p2 == len(b) + or isvariadic(cur_b) # type: ignore[possibly-undefined] + and p1 == len(a) + ) + + +def ambiguous(a, b): + """A is consistent with B but neither is strictly more specific""" + return consistent(a, b) and not (supercedes(a, b) or supercedes(b, a)) + + +def ambiguities(signatures): + """All signature pairs such that A is ambiguous with B""" + signatures = list(map(tuple, signatures)) + return { + (a, b) + for a in signatures + for b in signatures + if hash(a) < hash(b) + and ambiguous(a, b) + and not any(supercedes(c, a) and supercedes(c, b) for c in signatures) + } + + +def super_signature(signatures): + """A signature that would break ambiguities""" + n = len(signatures[0]) + assert all(len(s) == n for s in signatures) + + return [max((type.mro(sig[i]) for sig in signatures), key=len)[0] for i in range(n)] + + +def edge(a, b, tie_breaker=hash): + """A should be checked before B + Tie broken by tie_breaker, defaults to ``hash`` + """ + # A either supercedes B and B does not supercede A or if B does then call + # tie_breaker + return supercedes(a, b) and ( + not supercedes(b, a) or tie_breaker(a) > tie_breaker(b) + ) + + +def ordering(signatures): + """A sane ordering of signatures to check, first to last + Topological sort of edges as given by ``edge`` and ``supercedes`` + """ + signatures = list(map(tuple, signatures)) + edges = [(a, b) for a in signatures for b in signatures if edge(a, b)] + edges = groupby(operator.itemgetter(0), edges) + for s in signatures: + if s not in edges: + edges[s] = [] + edges = {k: [b for a, b in v] for k, v in edges.items()} # type: ignore[assignment, attr-defined] + return _toposort(edges) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/core.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/core.py new file mode 100644 index 0000000000000000000000000000000000000000..75a364ac4a554af7aab6f64b796493ef69ee6efe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/core.py @@ -0,0 +1,84 @@ +# mypy: allow-untyped-defs +import inspect + +from .dispatcher import Dispatcher, MethodDispatcher + + +global_namespace = {} # type: ignore[var-annotated] + +__all__ = ["dispatch", "ismethod"] + + +def dispatch(*types, **kwargs): + """Dispatch function on the types of the inputs + Supports dispatch on all non-keyword arguments. + Collects implementations based on the function name. Ignores namespaces. + If ambiguous type signatures occur a warning is raised when the function is + defined suggesting the additional method to break the ambiguity. + + Example: + >>> # xdoctest: +SKIP + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + >>> # xdoctest: +SKIP + >>> f(3) + 4 + >>> f(3.0) + 2.0 + >>> # Specify an isolated namespace with the namespace keyword argument + >>> my_namespace = {} + >>> @dispatch(int, namespace=my_namespace) + ... def foo(x): + ... return x + 1 + >>> # Dispatch on instance methods within classes + >>> class MyClass(object): + ... @dispatch(list) + ... def __init__(self, data): + ... self.data = data + ... + ... @dispatch(int) + ... def __init__(self, datum): + ... self.data = [datum] + >>> MyClass([1, 2, 3]).data + [1, 2, 3] + >>> MyClass(3).data + [3] + """ + namespace = kwargs.get("namespace", global_namespace) + + types = tuple(types) + + def _df(func): + name = func.__name__ + + if ismethod(func): + dispatcher = inspect.currentframe().f_back.f_locals.get( # type: ignore[union-attr] + name, # type: ignore[union-attr] + MethodDispatcher(name), + ) + else: + if name not in namespace: + namespace[name] = Dispatcher(name) + dispatcher = namespace[name] + + dispatcher.add(types, func) + return dispatcher + + return _df + + +def ismethod(func): + """Is func a method? + Note that this has to work as the method is defined but before the class is + defined. At this stage methods look like functions. + """ + if hasattr(inspect, "signature"): + signature = inspect.signature(func) + return signature.parameters.get("self", None) is not None + else: + spec = inspect.getfullargspec(func) # type: ignore[union-attr, assignment] + return spec and spec.args and spec.args[0] == "self" diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py new file mode 100644 index 0000000000000000000000000000000000000000..392bdebeabc43987c32682467d6166e7967cd946 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/dispatcher.py @@ -0,0 +1,453 @@ +# mypy: allow-untyped-defs +import inspect +import itertools as itl +from typing_extensions import deprecated +from warnings import warn + +from .conflict import ambiguities, AmbiguityWarning, ordering, super_signature +from .utils import expand_tuples +from .variadic import isvariadic, Variadic + + +__all__ = [ + "MDNotImplementedError", + "ambiguity_warn", + "halt_ordering", + "restart_ordering", + "variadic_signature_matches_iter", + "variadic_signature_matches", + "Dispatcher", + "source", + "MethodDispatcher", + "str_signature", + "warning_text", +] + + +class MDNotImplementedError(NotImplementedError): + """A NotImplementedError for multiple dispatch""" + + +def ambiguity_warn(dispatcher, ambiguities): + """Raise warning when ambiguity is detected + Parameters + ---------- + dispatcher : Dispatcher + The dispatcher on which the ambiguity was detected + ambiguities : set + Set of type signature pairs that are ambiguous within this dispatcher + See Also: + Dispatcher.add + warning_text + """ + warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning) + + +@deprecated( + "`halt_ordering` is deprecated, you can safely remove this call.", + category=FutureWarning, +) +def halt_ordering(): + """Deprecated interface to temporarily disable ordering.""" + + +@deprecated( + "`restart_ordering` is deprecated, if you would like to eagerly order the dispatchers, " + "you should call the `reorder()` method on each dispatcher.", + category=FutureWarning, +) +def restart_ordering(on_ambiguity=ambiguity_warn): + """Deprecated interface to temporarily resume ordering.""" + + +def variadic_signature_matches_iter(types, full_signature): + """Check if a set of input types matches a variadic signature. + Notes + ----- + The algorithm is as follows: + Initialize the current signature to the first in the sequence + For each type in `types`: + If the current signature is variadic + If the type matches the signature + yield True + Else + Try to get the next signature + If no signatures are left we can't possibly have a match + so yield False + Else + yield True if the type matches the current signature + Get the next signature + """ + sigiter = iter(full_signature) + sig = next(sigiter) + for typ in types: + matches = issubclass(typ, sig) + yield matches + if not isvariadic(sig): + # we're not matching a variadic argument, so move to the next + # element in the signature + sig = next(sigiter) + else: + try: + sig = next(sigiter) + except StopIteration: + assert isvariadic(sig) + yield True + else: + # We have signature items left over, so all of our arguments + # haven't matched + yield False + + +def variadic_signature_matches(types, full_signature): + # No arguments always matches a variadic signature + assert full_signature + return all(variadic_signature_matches_iter(types, full_signature)) + + +class Dispatcher: + """Dispatch methods based on type signature + Use ``dispatch`` to add implementations + Examples + -------- + >>> # xdoctest: +SKIP("bad import name") + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def f(x): + ... return x + 1 + >>> @dispatch(float) + ... def f(x): + ... return x - 1 + >>> f(3) + 4 + >>> f(3.0) + 2.0 + """ + + __slots__ = "__name__", "name", "funcs", "_ordering", "_cache", "doc" + + def __init__(self, name, doc=None): + self.name = self.__name__ = name + self.funcs = {} + self.doc = doc + + self._cache = {} + + def register(self, *types, **kwargs): + """register dispatcher with new implementation + >>> # xdoctest: +SKIP + >>> f = Dispatcher("f") + >>> @f.register(int) + ... def inc(x): + ... return x + 1 + >>> @f.register(float) + ... def dec(x): + ... return x - 1 + >>> @f.register(list) + ... @f.register(tuple) + ... def reverse(x): + ... return x[::-1] + >>> f(1) + 2 + >>> f(1.0) + 0.0 + >>> f([1, 2, 3]) + [3, 2, 1] + """ + + def _df(func): + self.add(types, func, **kwargs) # type: ignore[call-arg] + return func + + return _df + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return sig.parameters.values() + + @classmethod + def get_func_annotations(cls, func): + """get annotations of function positional parameters""" + params = cls.get_func_params(func) + if params: + Parameter = inspect.Parameter + + params = ( + param + for param in params + if param.kind + in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + ) + + annotations = tuple(param.annotation for param in params) + + if all(ann is not Parameter.empty for ann in annotations): + return annotations + + def add(self, signature, func): + """Add new types/method pair to dispatcher + >>> # xdoctest: +SKIP + >>> D = Dispatcher("add") + >>> D.add((int, int), lambda x, y: x + y) + >>> D.add((float, float), lambda x, y: x + y) + >>> D(1, 2) + 3 + >>> D(1, 2.0) + Traceback (most recent call last): + ... + NotImplementedError: Could not find signature for add: + >>> # When ``add`` detects a warning it calls the ``on_ambiguity`` callback + >>> # with a dispatcher/itself, and a set of ambiguous type signature pairs + >>> # as inputs. See ``ambiguity_warn`` for an example. + """ + # Handle annotations + if not signature: + annotations = self.get_func_annotations(func) + if annotations: + signature = annotations + + # Handle union types + if any(isinstance(typ, tuple) for typ in signature): + for typs in expand_tuples(signature): + self.add(typs, func) + return + + new_signature = [] + + for index, typ in enumerate(signature, start=1): + if not isinstance(typ, (type, list)): + str_sig = ", ".join( + c.__name__ if isinstance(c, type) else str(c) for c in signature + ) + raise TypeError( + f"Tried to dispatch on non-type: {typ}\n" + f"In signature: <{str_sig}>\n" + f"In function: {self.name}" + ) + + # handle variadic signatures + if isinstance(typ, list): + if index != len(signature): + raise TypeError("Variadic signature must be the last element") + + if len(typ) != 1: + raise TypeError( + "Variadic signature must contain exactly one element. " + "To use a variadic union type place the desired types " + "inside of a tuple, e.g., [(int, str)]" + ) + new_signature.append(Variadic[typ[0]]) + else: + new_signature.append(typ) + + self.funcs[tuple(new_signature)] = func + self._cache.clear() + + try: + del self._ordering + except AttributeError: + pass + + @property + def ordering(self): + try: + return self._ordering + except AttributeError: + return self.reorder() + + def reorder(self, on_ambiguity=ambiguity_warn): + self._ordering = od = ordering(self.funcs) + amb = ambiguities(self.funcs) + if amb: + on_ambiguity(self, amb) + return od + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + try: + func = self._cache[types] + except KeyError as e: + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) from e + self._cache[types] = func + try: + return func(*args, **kwargs) + + except MDNotImplementedError as e: + funcs = self.dispatch_iter(*types) + next(funcs) # burn first + for func in funcs: + try: + return func(*args, **kwargs) + except MDNotImplementedError: + pass + + raise NotImplementedError( + "Matching functions for " + f"{self.name}: <{str_signature(types)}> found, but none completed successfully", + ) from e + + def __str__(self): + return f"" + + __repr__ = __str__ + + def dispatch(self, *types): + """Determine appropriate implementation for this type signature + This method is internal. Users should call this object as a function. + Implementation resolution occurs within the ``__call__`` method. + >>> # xdoctest: +SKIP + >>> from multipledispatch import dispatch + >>> @dispatch(int) + ... def inc(x): + ... return x + 1 + >>> implementation = inc.dispatch(int) + >>> implementation(3) + 4 + >>> print(inc.dispatch(float)) + None + See Also: + ``multipledispatch.conflict`` - module to determine resolution order + """ + + if types in self.funcs: + return self.funcs[types] + + try: + return next(self.dispatch_iter(*types)) + except StopIteration: + return None + + def dispatch_iter(self, *types): + n = len(types) + for signature in self.ordering: + if len(signature) == n and all(map(issubclass, types, signature)): + result = self.funcs[signature] + yield result + elif len(signature) and isvariadic(signature[-1]): + if variadic_signature_matches(types, signature): + result = self.funcs[signature] + yield result + + @deprecated( + "`resolve()` is deprecated, use `dispatch(*types)`", category=FutureWarning + ) + def resolve(self, types): + """Determine appropriate implementation for this type signature + .. deprecated:: 0.4.4 + Use ``dispatch(*types)`` instead + """ + return self.dispatch(*types) + + def __getstate__(self): + return {"name": self.name, "funcs": self.funcs} + + def __setstate__(self, d): + self.name = d["name"] + self.funcs = d["funcs"] + self._ordering = ordering(self.funcs) + self._cache = {} + + @property + def __doc__(self): # type: ignore[override] + docs = [f"Multiply dispatched method: {self.name}"] + + if self.doc: + docs.append(self.doc) + + other = [] + for sig in self.ordering[::-1]: + func = self.funcs[sig] + if func.__doc__: + s = f"Inputs: <{str_signature(sig)}>\n" + s += "-" * len(s) + "\n" + s += func.__doc__.strip() + docs.append(s) + else: + other.append(str_signature(sig)) + + if other: + docs.append("Other signatures:\n " + "\n ".join(other)) + + return "\n\n".join(docs) + + def _help(self, *args): + return self.dispatch(*map(type, args)).__doc__ + + def help(self, *args, **kwargs): + """Print docstring for the function corresponding to inputs""" + print(self._help(*args)) + + def _source(self, *args): + func = self.dispatch(*map(type, args)) + if not func: + raise TypeError("No function found") + return source(func) + + def source(self, *args, **kwargs): + """Print source code for the function corresponding to inputs""" + print(self._source(*args)) + + +def source(func): + s = f"File: {inspect.getsourcefile(func)}\n\n" + s = s + inspect.getsource(func) + return s + + +class MethodDispatcher(Dispatcher): + """Dispatch methods based on type signature + See Also: + Dispatcher + """ + + __slots__ = ("obj", "cls") + + @classmethod + def get_func_params(cls, func): + if hasattr(inspect, "signature"): + sig = inspect.signature(func) + return itl.islice(sig.parameters.values(), 1, None) + + def __get__(self, instance, owner): + self.obj = instance + self.cls = owner + return self + + def __call__(self, *args, **kwargs): + types = tuple([type(arg) for arg in args]) + func = self.dispatch(*types) + if not func: + raise NotImplementedError( + f"Could not find signature for {self.name}: <{str_signature(types)}>" + ) + return func(self.obj, *args, **kwargs) + + +def str_signature(sig): + """String representation of type signature + >>> str_signature((int, float)) + 'int, float' + """ + return ", ".join(cls.__name__ for cls in sig) + + +def warning_text(name, amb): + """The text for ambiguity warnings""" + text = f"\nAmbiguities exist in dispatched function {name}\n\n" + text += "The following signatures may result in ambiguous behavior:\n" + for pair in amb: + text += "\t" + ", ".join("[" + str_signature(s) + "]" for s in pair) + "\n" + text += "\n\nConsider making the following additions:\n\n" + text += "\n\n".join( + [ + "@dispatch(" + str_signature(super_signature(s)) + f")\ndef {name}(...)" + for s in amb + ] + ) + return text diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f407ff45e29391f80e73b5226dc7cbe32e50160b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/utils.py @@ -0,0 +1,127 @@ +# mypy: allow-untyped-defs +from collections import OrderedDict + + +__all__ = ["raises", "expand_tuples", "reverse_dict", "groupby", "typename"] + + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +def expand_tuples(L): + """ + >>> expand_tuples([1, (2, 3)]) + [(1, 2), (1, 3)] + >>> expand_tuples([1, 2]) + [(1, 2)] + """ + if not L: + return [()] + elif not isinstance(L[0], tuple): + rest = expand_tuples(L[1:]) + return [(L[0],) + t for t in rest] + else: + rest = expand_tuples(L[1:]) + return [(item,) + t for t in rest for item in L[0]] + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> _toposort({1: (2, 3), 2: (3,)}) + [1, 2, 3] + >>> # Closely follows the wikipedia page [2] + >>> # [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + >>> # Communications of the ACM + >>> # [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = OrderedDict((k, set(val)) for k, val in incoming_edges.items()) + S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges) + L = [] + + while S: + n, _ = S.popitem() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S[m] = None + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = OrderedDict() # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key,) + return result + + +# Taken from toolz +# Avoids licensing issues because this version was authored by Matthew Rocklin +def groupby(func, seq): + """Group a collection by a key function + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + See Also: + ``countby`` + """ + + d = OrderedDict() # type: ignore[var-annotated] + for item in seq: + key = func(item) + if key not in d: + d[key] = [] + d[key].append(item) + return d + + +def typename(type): + """Get the name of `type`. + Parameters + ---------- + type : Union[Type, Tuple[Type]] + Returns + ------- + str + The name of `type` or a tuple of the names of the types in `type`. + Examples + -------- + >>> typename(int) + 'int' + >>> typename((int, float)) + '(int, float)' + """ + try: + return type.__name__ + except AttributeError: + if len(type) == 1: + return typename(*type) + return f"({', '.join(map(typename, type))})" diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py new file mode 100644 index 0000000000000000000000000000000000000000..fd332c44de52d6e4dc73bee2539088aa5655afb3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/multipledispatch/variadic.py @@ -0,0 +1,96 @@ +# mypy: allow-untyped-defs +from .utils import typename + + +__all__ = ["VariadicSignatureType", "isvariadic", "VariadicSignatureMeta", "Variadic"] + + +class VariadicSignatureType(type): + # checking if subclass is a subclass of self + def __subclasscheck__(cls, subclass): + other_type = subclass.variadic_type if isvariadic(subclass) else (subclass,) + return subclass is cls or all( + issubclass(other, cls.variadic_type) # type: ignore[attr-defined] + for other in other_type + ) + + def __eq__(cls, other): + """ + Return True if other has the same variadic type + Parameters + ---------- + other : object (type) + The object (type) to check + Returns + ------- + bool + Whether or not `other` is equal to `self` + """ + return isvariadic(other) and set(cls.variadic_type) == set(other.variadic_type) # type: ignore[attr-defined] + + def __hash__(cls): + return hash((type(cls), frozenset(cls.variadic_type))) # type: ignore[attr-defined] + + +def isvariadic(obj): + """Check whether the type `obj` is variadic. + Parameters + ---------- + obj : type + The type to check + Returns + ------- + bool + Whether or not `obj` is variadic + Examples + -------- + >>> # xdoctest: +SKIP + >>> isvariadic(int) + False + >>> isvariadic(Variadic[int]) + True + """ + return isinstance(obj, VariadicSignatureType) + + +class VariadicSignatureMeta(type): + """A metaclass that overrides ``__getitem__`` on the class. This is used to + generate a new type for Variadic signatures. See the Variadic class for + examples of how this behaves. + """ + + def __getitem__(cls, variadic_type): + if not (isinstance(variadic_type, (type, tuple)) or type(variadic_type)): + raise ValueError( + "Variadic types must be type or tuple of types" + " (Variadic[int] or Variadic[(int, float)]" + ) + + if not isinstance(variadic_type, tuple): + variadic_type = (variadic_type,) + return VariadicSignatureType( + f"Variadic[{typename(variadic_type)}]", + (), + dict(variadic_type=variadic_type, __slots__=()), + ) + + +class Variadic(metaclass=VariadicSignatureMeta): + """A class whose getitem method can be used to generate a new type + representing a specific variadic signature. + Examples + -------- + >>> # xdoctest: +SKIP + >>> Variadic[int] # any number of int arguments + + >>> Variadic[(int, str)] # any number of one of int or str arguments + + >>> issubclass(int, Variadic[int]) + True + >>> issubclass(int, Variadic[(int, str)]) + True + >>> issubclass(str, Variadic[(int, str)]) + True + >>> issubclass(float, Variadic[(int, str)]) + False + """ diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/unification_tools.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/unification_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..41e0fb4cc9e5d428f3b2961450ca163a7003cb4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/unification_tools.py @@ -0,0 +1,419 @@ +# mypy: allow-untyped-defs +import collections +import operator +from collections.abc import Mapping +from functools import reduce + + +__all__ = [ + "merge", + "merge_with", + "valmap", + "keymap", + "itemmap", + "valfilter", + "keyfilter", + "itemfilter", + "assoc", + "dissoc", + "assoc_in", + "update_in", + "get_in", +] + + +def _get_factory(f, kwargs): + factory = kwargs.pop("factory", dict) + if kwargs: + raise TypeError( + f"{f.__name__}() got an unexpected keyword argument '{kwargs.popitem()[0]}'" + ) + return factory + + +def merge(*dicts, **kwargs): + """Merge a collection of dictionaries + + >>> merge({1: "one"}, {2: "two"}) + {1: 'one', 2: 'two'} + + Later dictionaries have precedence + + >>> merge({1: 2, 3: 4}, {3: 3, 4: 4}) + {1: 2, 3: 3, 4: 4} + + See Also: + merge_with + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge, kwargs) + + rv = factory() + for d in dicts: + rv.update(d) + return rv + + +def merge_with(func, *dicts, **kwargs): + """Merge dictionaries and apply function to combined values + + A key may occur in more than one dict, and all values mapped from the key + will be passed to the function as a list, such as func([val1, val2, ...]). + + >>> merge_with(sum, {1: 1, 2: 2}, {1: 10, 2: 20}) + {1: 11, 2: 22} + + >>> merge_with(first, {1: 1, 2: 2}, {2: 20, 3: 30}) # doctest: +SKIP + {1: 1, 2: 2, 3: 30} + + See Also: + merge + """ + if len(dicts) == 1 and not isinstance(dicts[0], Mapping): + dicts = dicts[0] + factory = _get_factory(merge_with, kwargs) + + result = factory() + for d in dicts: + for k, v in d.items(): + if k not in result: + result[k] = [v] + else: + result[k].append(v) + return valmap(func, result, factory) + + +def valmap(func, d, factory=dict): + """Apply function to values of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> valmap(sum, bills) # doctest: +SKIP + {'Alice': 65, 'Bob': 45} + + See Also: + keymap + itemmap + """ + rv = factory() + rv.update(zip(d.keys(), map(func, d.values()))) + return rv + + +def keymap(func, d, factory=dict): + """Apply function to keys of dictionary + + >>> bills = {"Alice": [20, 15, 30], "Bob": [10, 35]} + >>> keymap(str.lower, bills) # doctest: +SKIP + {'alice': [20, 15, 30], 'bob': [10, 35]} + + See Also: + valmap + itemmap + """ + rv = factory() + rv.update(zip(map(func, d.keys()), d.values())) + return rv + + +def itemmap(func, d, factory=dict): + """Apply function to items of dictionary + + >>> accountids = {"Alice": 10, "Bob": 20} + >>> itemmap(reversed, accountids) # doctest: +SKIP + {10: "Alice", 20: "Bob"} + + See Also: + keymap + valmap + """ + rv = factory() + rv.update(map(func, d.items())) + return rv + + +def valfilter(predicate, d, factory=dict): + """Filter items in dictionary by value + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> valfilter(iseven, d) + {1: 2, 3: 4} + + See Also: + keyfilter + itemfilter + valmap + """ + rv = factory() + for k, v in d.items(): + if predicate(v): + rv[k] = v + return rv + + +def keyfilter(predicate, d, factory=dict): + """Filter items in dictionary by key + + >>> iseven = lambda x: x % 2 == 0 + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> keyfilter(iseven, d) + {2: 3, 4: 5} + + See Also: + valfilter + itemfilter + keymap + """ + rv = factory() + for k, v in d.items(): + if predicate(k): + rv[k] = v + return rv + + +def itemfilter(predicate, d, factory=dict): + """Filter items in dictionary by item + + >>> def isvalid(item): + ... k, v = item + ... return k % 2 == 0 and v < 4 + + >>> d = {1: 2, 2: 3, 3: 4, 4: 5} + >>> itemfilter(isvalid, d) + {2: 3} + + See Also: + keyfilter + valfilter + itemmap + """ + rv = factory() + for item in d.items(): + if predicate(item): + k, v = item + rv[k] = v + return rv + + +def assoc(d, key, value, factory=dict): + """Return a new dict with new key value pair + + New dict has d[key] set to value. Does not modify the initial dictionary. + + >>> assoc({"x": 1}, "x", 2) + {'x': 2} + >>> assoc({"x": 1}, "y", 3) # doctest: +SKIP + {'x': 1, 'y': 3} + """ + d2 = factory() + d2.update(d) + d2[key] = value + return d2 + + +def dissoc(d, *keys, **kwargs): + """Return a new dict with the given key(s) removed. + + New dict has d[key] deleted for each supplied key. + Does not modify the initial dictionary. + + >>> dissoc({"x": 1, "y": 2}, "y") + {'x': 1} + >>> dissoc({"x": 1, "y": 2}, "y", "x") + {} + >>> dissoc({"x": 1}, "y") # Ignores missing keys + {'x': 1} + """ + factory = _get_factory(dissoc, kwargs) + d2 = factory() + + if len(keys) < len(d) * 0.6: + d2.update(d) + for key in keys: + if key in d2: + del d2[key] + else: + remaining = set(d) + remaining.difference_update(keys) + for k in remaining: + d2[k] = d[k] + return d2 + + +def assoc_in(d, keys, value, factory=dict): + """Return a new dict with new, potentially nested, key value pair + + >>> purchase = { + ... "name": "Alice", + ... "order": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> assoc_in(purchase, ["order", "costs"], [0.25, 1.00]) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'order': {'costs': [0.25, 1.00], 'items': ['Apple', 'Orange']}} + """ + return update_in(d, keys, lambda x: value, value, factory) + + +def update_in(d, keys, func, default=None, factory=dict): + """Update value in a (potentially) nested dictionary + + inputs: + d - dictionary on which to operate + keys - list or tuple giving the location of the value to be changed in d + func - function to operate on that value + + If keys == [k0,..,kX] and d[k0]..[kX] == v, update_in returns a copy of the + original dictionary with v replaced by func(v), but does not mutate the + original dictionary. + + If k0 is not a key in d, update_in creates nested dictionaries to the depth + specified by the keys, with the innermost value set to func(default). + + >>> inc = lambda x: x + 1 + >>> update_in({"a": 0}, ["a"], inc) + {'a': 1} + + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> update_in(transaction, ["purchase", "costs"], sum) # doctest: +SKIP + {'credit card': '5555-1234-1234-1234', + 'name': 'Alice', + 'purchase': {'costs': 1.75, 'items': ['Apple', 'Orange']}} + + >>> # updating a value when k0 is not in d + >>> update_in({}, [1, 2, 3], str, default="bar") + {1: {2: {3: 'bar'}}} + >>> update_in({1: "foo"}, [2, 3, 4], inc, 0) + {1: 'foo', 2: {3: {4: 1}}} + """ + ks = iter(keys) + k = next(ks) + + rv = inner = factory() + rv.update(d) + + for key in ks: + if k in d: + d = d[k] + dtemp = factory() + dtemp.update(d) + else: + d = dtemp = factory() + + inner[k] = inner = dtemp + k = key + + if k in d: + inner[k] = func(d[k]) + else: + inner[k] = func(default) + return rv + + +def get_in(keys, coll, default=None, no_default=False): + """Returns coll[i0][i1]...[iX] where [i0, i1, ..., iX]==keys. + + If coll[i0][i1]...[iX] cannot be found, returns ``default``, unless + ``no_default`` is specified, then it raises KeyError or IndexError. + + ``get_in`` is a generalization of ``operator.getitem`` for nested data + structures such as dictionaries and lists. + + >>> transaction = { + ... "name": "Alice", + ... "purchase": {"items": ["Apple", "Orange"], "costs": [0.50, 1.25]}, + ... "credit card": "5555-1234-1234-1234", + ... } + >>> get_in(["purchase", "items", 0], transaction) + 'Apple' + >>> get_in(["name"], transaction) + 'Alice' + >>> get_in(["purchase", "total"], transaction) + >>> get_in(["purchase", "items", "apple"], transaction) + >>> get_in(["purchase", "items", 10], transaction) + >>> get_in(["purchase", "total"], transaction, 0) + 0 + >>> get_in(["y"], {}, no_default=True) + Traceback (most recent call last): + ... + KeyError: 'y' + + See Also: + itertoolz.get + operator.getitem + """ + try: + return reduce(operator.getitem, keys, coll) + except (KeyError, IndexError, TypeError): + if no_default: + raise + return default + + +def getter(index): + if isinstance(index, list): + if len(index) == 1: + index = index[0] + return lambda x: (x[index],) + elif index: + return operator.itemgetter(*index) + else: + return lambda x: () + else: + return operator.itemgetter(index) + + +def groupby(key, seq): + """Group a collection by a key function + + >>> names = ["Alice", "Bob", "Charlie", "Dan", "Edith", "Frank"] + >>> groupby(len, names) # doctest: +SKIP + {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']} + + >>> iseven = lambda x: x % 2 == 0 + >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP + {False: [1, 3, 5, 7], True: [2, 4, 6, 8]} + + Non-callable keys imply grouping on a member. + + >>> groupby( + ... "gender", + ... [ + ... {"name": "Alice", "gender": "F"}, + ... {"name": "Bob", "gender": "M"}, + ... {"name": "Charlie", "gender": "M"}, + ... ], + ... ) # doctest:+SKIP + {'F': [{'gender': 'F', 'name': 'Alice'}], + 'M': [{'gender': 'M', 'name': 'Bob'}, + {'gender': 'M', 'name': 'Charlie'}]} + + Not to be confused with ``itertools.groupby`` + + See Also: + countby + """ + if not callable(key): + key = getter(key) + d = collections.defaultdict(lambda: [].append) # type: ignore[var-annotated] + for item in seq: + d[key(item)](item) + rv = {} + for k, v in d.items(): + rv[k] = v.__self__ # type: ignore[var-annotated, attr-defined] + return rv + + +def first(seq): + """The first element in a sequence + + >>> first("ABC") + 'A' + """ + return next(iter(seq)) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/utils.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4ec4ab6daf4835f0bbbf7ce6486e0460daae2e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/utils.py @@ -0,0 +1,108 @@ +# mypy: allow-untyped-defs +__all__ = ["hashable", "transitive_get", "raises", "reverse_dict", "xfail", "freeze"] + + +def hashable(x): + try: + hash(x) + return True + except TypeError: + return False + + +def transitive_get(key, d): + """Transitive dict.get + >>> d = {1: 2, 2: 3, 3: 4} + >>> d.get(1) + 2 + >>> transitive_get(1, d) + 4 + """ + while hashable(key) and key in d: + key = d[key] + return key + + +def raises(err, lamda): + try: + lamda() + return False + except err: + return True + + +# Taken from theano/theano/gof/sched.py +# Avoids licensing issues because this was written by Matthew Rocklin +def _toposort(edges): + """Topological sort algorithm by Kahn [1] - O(nodes + vertices) + inputs: + edges - a dict of the form {a: {b, c}} where b and c depend on a + outputs: + L - an ordered list of nodes that satisfy the dependencies of edges + >>> # xdoctest: +SKIP + >>> _toposort({1: (2, 3), 2: (3,)}) + [1, 2, 3] + Closely follows the wikipedia page [2] + [1] Kahn, Arthur B. (1962), "Topological sorting of large networks", + Communications of the ACM + [2] http://en.wikipedia.org/wiki/Toposort#Algorithms + """ + incoming_edges = reverse_dict(edges) + incoming_edges = {k: set(val) for k, val in incoming_edges.items()} + S = {v for v in edges if v not in incoming_edges} + L = [] + + while S: + n = S.pop() + L.append(n) + for m in edges.get(n, ()): + assert n in incoming_edges[m] + incoming_edges[m].remove(n) + if not incoming_edges[m]: + S.add(m) + if any(incoming_edges.get(v, None) for v in edges): + raise ValueError("Input has cycles") + return L + + +def reverse_dict(d): + """Reverses direction of dependence dict + >>> d = {"a": (1, 2), "b": (2, 3), "c": ()} + >>> reverse_dict(d) # doctest: +SKIP + {1: ('a',), 2: ('a', 'b'), 3: ('b',)} + :note: dict order are not deterministic. As we iterate on the + input dict, it make the output of this function depend on the + dict order. So this function output order should be considered + as undeterministic. + """ + result = {} # type: ignore[var-annotated] + for key in d: + for val in d[key]: + result[val] = result.get(val, ()) + (key,) + return result + + +def xfail(func): + try: + func() + raise Exception("XFailed test passed") # pragma:nocover # noqa: TRY002 + except Exception: + pass + + +def freeze(d): + """Freeze container to hashable form + >>> freeze(1) + 1 + >>> freeze([1, 2]) + (1, 2) + >>> freeze({1: 2}) # doctest: +SKIP + frozenset([(1, 2)]) + """ + if isinstance(d, dict): + return frozenset(map(freeze, d.items())) + if isinstance(d, set): + return frozenset(map(freeze, d)) + if isinstance(d, (tuple, list)): + return tuple(map(freeze, d)) + return d diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unification/variable.py b/phivenv/Lib/site-packages/torch/fx/experimental/unification/variable.py new file mode 100644 index 0000000000000000000000000000000000000000..8fc15a7b926b40909c7da66bba05995f208af415 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unification/variable.py @@ -0,0 +1,90 @@ +# mypy: allow-untyped-defs +from contextlib import contextmanager + +from .dispatch import dispatch +from .utils import hashable + + +_global_logic_variables = set() # type: ignore[var-annotated] +_glv = _global_logic_variables + + +class Var: + """Logic Variable""" + + _id = 1 + + def __new__(cls, *token): + if len(token) == 0: + token = f"_{Var._id}" # type: ignore[assignment] + Var._id += 1 + elif len(token) == 1: + token = token[0] + + obj = object.__new__(cls) + obj.token = token # type: ignore[attr-defined] + return obj + + def __str__(self): + return "~" + str(self.token) # type: ignore[attr-defined] + + __repr__ = __str__ + + def __eq__(self, other): + return type(self) == type(other) and self.token == other.token # type: ignore[attr-defined] + + def __hash__(self): + return hash((type(self), self.token)) # type: ignore[attr-defined] + + +def var(): + return lambda *args: Var(*args) + + +def vars(): + return lambda n: [var() for i in range(n)] + + +@dispatch(Var) +def isvar(v): + return True + + +isvar + + +@dispatch(object) # type: ignore[no-redef] +def isvar(o): + return not not _glv and hashable(o) and o in _glv + + +@contextmanager +def variables(*variables): + """ + Context manager for logic variables + + Example: + >>> # xdoctest: +SKIP("undefined vars") + >>> from __future__ import with_statement + >>> with variables(1): + ... print(isvar(1)) + True + >>> print(isvar(1)) + False + >>> # Normal approach + >>> from unification import unify + >>> x = var("x") + >>> unify(x, 1) + {~x: 1} + >>> # Context Manager approach + >>> with variables("x"): + ... print(unify("x", 1)) + {'x': 1} + """ + old_global_logic_variables = _global_logic_variables.copy() + _global_logic_variables.update(set(variables)) + try: + yield + finally: + _global_logic_variables.clear() + _global_logic_variables.update(old_global_logic_variables) diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/unify_refinements.py b/phivenv/Lib/site-packages/torch/fx/experimental/unify_refinements.py new file mode 100644 index 0000000000000000000000000000000000000000..88e433a0a3c5050689318ca13b6585108825c1de --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/unify_refinements.py @@ -0,0 +1,124 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.graph_gradual_typechecker import Refine +from torch.fx.experimental.unification import unify, Var # type: ignore[attr-defined] +from torch.fx.tensor_type import TensorType + + +def infer_symbolic_types_single_pass(traced): + """ + Calls our symbolic inferencer once. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + +def infer_symbolic_types(traced): + """ + Calls our symbolic inferencer twice. + This is useful when one pass is not enough + to infer all the information such as the case + for braodcasting. + """ + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r = Refine(traced) + r.refine() + mgu = unify_eq(r.constraints) + substitute_all_types(traced.graph, mgu) + + r.symbolic_relations() + + +def convert_eq(list_of_eq): + """ + Convert equality constraints in the right format + to be used by unification library. + """ + lhs = [] + rhs = [] + for eq in list_of_eq: + lhs.append(eq.lhs) + rhs.append(eq.rhs) + return tuple(lhs), tuple(rhs) + + +def unify_eq(list_of_eq): + """ + Apply unification to a set of + equality constraints + """ + lhs, rhs = convert_eq(list_of_eq) + return unify(lhs, rhs) + + +def substitute_solution_one_type(mapping, t): + """ + Apply the most general unifier to a type + """ + if isinstance(t, Var): + if t in mapping.keys(): + return mapping[t] + else: + return t + + elif isinstance(t, TensorType): + new_type = [] + for typ in t.__args__: + if typ in mapping.keys(): + new_type.append(mapping[typ]) + else: + new_type.append(typ) + return TensorType(tuple(new_type)) + + elif isinstance(t, list): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return new_type + + elif isinstance(t, tuple): + new_type = [] + for typ in t: + new_type.append(substitute_solution_one_type(mapping, typ)) + return tuple(new_type) + + else: + return t + + +def substitute_all_types(graph, mapping): + """ + Apply the most general unifier to all types in a graph + till reaching a fixed point. If the input and output graph + are the same, we converge. + """ + flag = True + while flag: + flag = False + for k in mapping: + old_mapping_val = mapping[k] + if mapping[k] in mapping.keys(): + new_key = mapping[k] + mapping[k] = mapping[new_key] + if old_mapping_val != mapping[k]: + flag = True + + for n in graph.nodes: + n.type = substitute_solution_one_type(mapping, n.type) + + +def check_for_type_equality(g1, g2): + """ + A check equality to be used in fixed points. + We do not use graph equality but instead type + equality. + """ + for n, m in zip(g1.nodes, g2.nodes): + if n.type != m.type: + return False + return True diff --git a/phivenv/Lib/site-packages/torch/fx/experimental/validator.py b/phivenv/Lib/site-packages/torch/fx/experimental/validator.py new file mode 100644 index 0000000000000000000000000000000000000000..cd24e8cc2abd7448fb1733ea448b7719de3a9144 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/experimental/validator.py @@ -0,0 +1,869 @@ +# mypy: allow-untyped-defs +import builtins +import functools +import logging +import math +import operator +from dataclasses import dataclass +from typing import Any, Callable, Optional, Union + +import sympy + +import torch +import torch.fx +import torch.fx.traceback as fx_traceback +from torch._dynamo.exc import TorchDynamoException +from torch._dynamo.utils import dynamo_timed +from torch.fx.node import Argument, Target +from torch.utils._sympy.interp import sympy_interp + + +log = logging.getLogger(__name__) + +try: + import z3 # type: ignore[import] + + # Translation Validation for Dynamo guards + # ======================================== + # + # Checks whether optimizations applied to the collected guards are + # valid. In other words, whether the guard function we actually run + # does not have false positives (unsound). + # + # In order to do so, we build the guards using 2 different information + # attached to each 'SymNode': + # 1. SymPy expressions + # 2. FX nodes + # + # SymPy expressions have implicit optimizations baked within itself, + # which may have a few bugs. On the other hand, we build the FX graph + # manually, with no optimizations enabled. This gives us access to + # the "ground truth". + # + # We then convert into Z3 expressions both the SymPy expressions + # (see [Note: SympyToZ3]) that reach 'ShapeEnv.produce_guards' function + # and the FX nodes (see [Note: PopulateValidator]) that go through + # 'ShapeEnv.evaluate_expr' function. Finally, we run the validation. + # (see [Note: TranslationValidator]) + # Better Z3 to string implementation (for a small fraction of Z3). + # + # Here are the things we clean before showing the Z3 expression: + # - Rename a few ops (e.g. "Distinct" ==> "!=") + # + # - Ignore ToInt and ToReal operations: + # usually they don't really matter + # + # - Transform (ToInt (/ ...)) into (idiv ...): + # this is the pattern for floor division + # + # - Collect a chain of the same operations into one + def z3str(e: z3.ExprRef) -> str: + assert z3.is_expr(e), f"unsupported expression type: {e}" + + def get_args_str(e: z3.ExprRef) -> list[str]: + return [z3str(e.arg(i)) for i in range(e.num_args())] + + # First, we simplify the given expression. + # This is done using rewriting rules, so shouldn't take long. + e = z3.simplify(e) + + # Only support function applications. + # Even Z3 "variables" are, in fact, function applications. + if not z3.is_app(e): + raise ValueError(f"can't print Z3 expression: {e}") + + if z3.is_int_value(e) or z3.is_rational_value(e): + return e.as_string() # type: ignore[attr-defined] + + decl = e.decl() + kind = decl.kind() + op = str(decl) + args = get_args_str(e) + + if kind == z3.Z3_OP_POWER: + op = "pow" + + elif kind in (z3.Z3_OP_ADD, z3.Z3_OP_MUL): + # Collect the arguments of chains of ADD and MUL. + # This is safe, since they are associative. + + def collect_str_args(e): + if not (z3.is_app(e) and e.decl().kind() == kind): + return [z3str(e)] + else: + return [ + x + for i in range(e.num_args()) + for x in collect_str_args(e.arg(i)) + ] + + args = collect_str_args(e) + + elif kind == z3.Z3_OP_NOT: + # Revert some conversions that z3.simplify applies: + # - a != b ==> (Not (== a b)) ==> (!= a b) + # - a < b ==> (Not (<= b a)) ==> (> b a) + # - a > b ==> (Not (<= a b)) ==> (> a b) + + assert e.num_args() == 1 + arg = e.arg(0) + + assert z3.is_app(arg) + argkind = arg.decl().kind() + + logic_inverse = { + z3.Z3_OP_EQ: "!=", + z3.Z3_OP_LE: ">", + z3.Z3_OP_GE: "<", + } + + if argkind in logic_inverse: + op = logic_inverse[argkind] + args = get_args_str(arg) + + elif kind in (z3.Z3_OP_TO_INT, z3.Z3_OP_TO_REAL): + assert e.num_args() == 1 + argstr = z3str(e.arg(0)) + + # Check if it's the floor division pattern. + if argstr.startswith("(/"): + return "(idiv" + argstr[2:] + + # Otherwise, just ignore it. + return argstr + + elif kind == z3.Z3_OP_UNINTERPRETED: + assert e.num_args() == 0 + return str(decl) + + string = op + " " + " ".join(args) + return f"({string.rstrip()})" + + # We need to convert to/from BitVec in order to use z3 bitwise ops. + # We assume that integers are 64 bit. + # If all args are boolean, then use the boolean bitwise op implementation instead, if provided. + def _bitwise_op(bitwise_func, bool_func): + @functools.wraps(bitwise_func) + def wrapper(self, *args): + if bool_func is not None and all( + isinstance(arg, z3.BoolRef) for arg in args + ): + return bool_func(*args) + + wrapped_args = tuple(z3.Int2BV(a, 64) for a in args) + return z3.BV2Int(bitwise_func(*wrapped_args)) + + return wrapper + + # Implementation of Python semantics as Z3 expressions. + # + # Z3 Real-Int theory has operators with semantics that differ that of + # Python. Therefore, in order to get it right, we need to implement + # the (Python) semantics we are relying on in Z3. + @dataclass + class _Z3Ops: + # Validator used for adding assertions as needed. + # e.g. div(a, b) requires b != 0. + validator: "TranslationValidator" + + # The 2 functions below are used for conditionally casting between + # integer and reals. + # + # Returns a real expression from 'x'. + @staticmethod + def to_real(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_real() else z3.ToReal(x) + + # Returns an integer expression from 'x'. + @staticmethod + def to_int(x: z3.ArithRef) -> z3.ArithRef: + return x if x.is_int() else z3.ToInt(x) + + def sym_sum(self, args: z3.ArithRef) -> z3.ArithRef: + return sum(args) + + # Implements Python division semantics. + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + self.validator.add_assertion(denominator != 0) # type: ignore[arg-type] + return _Z3Ops.to_real(numerator) / _Z3Ops.to_real(denominator) + + def floor(self, number: z3.ArithRef) -> z3.ArithRef: + # Z3 ToInt function rounds a real number towards negative infinity. + return _Z3Ops.to_int(number) + + # Python semantics for 'FloorDiv' states that before applying the floor + # function, the operands are converted to their common type. + def floordiv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + cast_result_to_real = numerator.is_real() or denominator.is_real() + result = _Z3Ops.to_int(self.div(numerator, denominator)) + # Since the 'result' is already an integer, we just have to check + # whether we should cast it to real. + return _Z3Ops.to_real(result) if cast_result_to_real else result + + def ceil(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value] + + def trunc(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value] + + def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a > b, a, b) # type: ignore[return-value] + + def min(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef: + return z3.If(a < b, a, b) # type: ignore[return-value] + + # Python semantics for 'Mod' is defined as: p % q = p - floordiv(p, q) * q + # It should work with both integer and reals. + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return p - self.floordiv(p, q) * q + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + # Z3 can't handle complex numbers very well. + self.validator.add_assertion(z3.Or(base != 0, exp > 0)) # type: ignore[arg-type] + return base**exp + + def sqrt(self, number: z3.ArithRef) -> z3.ArithRef: + # Square-root: + # 1. Only work with reals + number = _Z3Ops.to_real(number) + # 2. The number should be positive or zero. + # Otherwise, Z3 returns 'unknown'. + self.validator.add_assertion(number >= 0) + return number**0.5 + + def abs(self, number: z3.ArithRef) -> z3.ArithRef: + return z3.Abs(number) + + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: + # Pythons builtin 'round' implements the 'round half to even' strategy + # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even + # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to + # floating point numbers, which is different from real numbers that we are dealing with here. + # Instead, we implement 'round half to even' in terms of 'round half up' (floor(x + 0.5)) and + # 'round half down' (ceil(x - 0.5)). + # Assuming 'round half up' is the default case, we need to correct ..., -3.5, -1.5, 0.5, 2.5, 4.5, ... + # to round down, i.e. use the 'round half down' strategy + return z3.If( + self.mod(number, z3.IntVal(2)) == 0.5, + self.ceil(number - 0.5), + self.floor(number + 0.5), + ) + + bitwise_and = _bitwise_op(operator.and_, z3.And) + bitwise_or = _bitwise_op(operator.or_, z3.Or) + lshift = _bitwise_op(operator.lshift, None) + rshift = _bitwise_op(operator.rshift, None) + + # Lifts a callable to be used in Z3. + # + # This function replaces the given 'op' by a function that: + # + # 1. Lifts the arguments into Z3 (i.e. make them inhabitants of Z3) + # + # 2. Calls an operation that corresponds to 'op', but works with Z3 + # inhabitants (left as is if it works as is) + def z3op(op: Callable, validator: "TranslationValidator") -> Callable: + # Operations that have booleans as their argument. + # This is needed because the argument of some FX nodes were + # literal integers, instead of booleans. So, whenever this flag + # is set, we also convert ints to booleans. + boolean_ops = {operator.not_} + as_bool = op in boolean_ops + + # Lifts the function into 'z3.ExprRef' domain. + def lift(func): + def wrap(a) -> z3.ExprRef: + if isinstance(a, (z3.ArithRef, z3.BoolRef)): + return a + # Convert it into a Z3 value, if it is some of the supported + # types below. + if isinstance(a, bool) or (as_bool and isinstance(a, int)): + return z3.BoolVal(bool(a)) + if isinstance(a, (int, sympy.Integer)): + return z3.IntVal(int(a)) + if isinstance(a, (float, sympy.Float)): + return z3.RealVal(float(a)) + raise ValueError(f"can't lift type: {type(a)}") + + @functools.wraps(func) + def wrapper(*args): + # Lifts the arguments into a list of Z3 inhabitants. + if len(args) == 1 and isinstance(args[0], (list, tuple)): + wrapped_args = (tuple(wrap(a) for a in args[0]),) + else: + wrapped_args = tuple(wrap(a) for a in args) + # Run the function on the Z3 expressions. + return func(*wrapped_args) + + return wrapper + + ops = _Z3Ops(validator) + replacement_map = { + # Operator module. + operator.not_: lift(z3.Not), + operator.and_: lift(ops.bitwise_and), + operator.or_: lift(ops.bitwise_or), + operator.lshift: lift(ops.lshift), + operator.rshift: lift(ops.rshift), + operator.floordiv: lift(ops.floordiv), + operator.truediv: lift(ops.div), + operator.mod: lift(ops.mod), + operator.abs: lift(ops.abs), + builtins.round: lift(ops.round_to_int), + # Math module. + math.ceil: lift(ops.ceil), + math.floor: lift(ops.floor), + math.trunc: lift(ops.trunc), + # Torch module. + torch.sym_float: lift(ops.to_real), + torch.sym_max: lift(ops.max), + torch.sym_min: lift(ops.min), + torch.sym_sum: lift(ops.sym_sum), + torch.sym_ite: lift(lambda b, t, f: t if b else f), + torch._sym_sqrt: lift(ops.sqrt), # type: ignore[attr-defined] + # Not lifted because we only use this function as a + # marker for adding the expression as validator input. + torch._assert: torch._assert, + } + return replacement_map[op] if op in replacement_map else lift(op) + + # Processes an FX graph, populating the given validator. + # + # [Note: PopulateValidator] + # This class walks through each node in the FX graph, translating + # them into the Z3 world. + # + # Then, whenever it finds an 'torch._assert' call_function operation, + # it adds the Z3 expression corresponding to the argument as validator + # input. + class PopulateValidator(torch.fx.Interpreter): + def __init__(self, graph: torch.fx.Graph, validator: "TranslationValidator"): + # Reference to the translation validator. + self.validator = validator + + # Build the graph module and call `Interpreter` constructor. + module = torch.fx.GraphModule(root={}, graph=graph) + super().__init__(module, garbage_collect_values=True) + + def placeholder( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + symbol = fx_traceback.get_current_meta()["symbol"] + return self.validator.z3var(symbol) + + def call_function( + self, target: Target, args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + if target != torch._assert: + # Lift and runs the node target function + return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] + # Adds the Z3 expression corresponding to the first argument + # as a validator input. + assert len(args) == 1, ( + f"expected 1 argument on assertion. Got: {len(args)} " + ) + self.validator.add_source_expr(args[0]) # type: ignore[arg-type] + + # Translates SymPy expressions into Z3 expressions. + # + # [Note: SympyToZ3] + # At the time of the translation, all free variables present in the + # SymPy expression being translated must be already mapped to a Z3 + # integer variable. + class SympyToZ3: + OPERATOR_HANDLES = {"add", "mul", "eq", "ne", "lt", "gt", "le", "ge"} + + def __init__( + self, + validator: "TranslationValidator", + ) -> None: + self._validator = validator + self._ops = _Z3Ops(self._validator) + + def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision + if dtype is torch.int64: + return z3.IntVal(int(value)) + if dtype is torch.double: + return z3.RealVal(float(value)) + if dtype is torch.bool: + return z3.BoolVal(bool(value)) + raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def truediv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + + def floordiv( + self, numerator: z3.ArithRef, denominator: z3.ArithRef + ) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def div(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.floordiv(numerator, denominator) + + def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: + return self._ops.mod(p, q) + + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) + + def __getattr__(self, name: str) -> Any: + REPLACEMENT = { + "and_": z3.And, + "or_": z3.Or, + "not_": z3.Not, + "bitwise_and": self._ops.bitwise_and, + "bitwise_or": self._ops.bitwise_or, + "lshift": self._ops.lshift, + "rshift": self._ops.rshift, + "floor": self._ops.floor, + "ceil": self._ops.ceil, + "minimum": self._ops.min, + "maximum": self._ops.max, + } + + if name in REPLACEMENT: + return REPLACEMENT[name] + if name in self.OPERATOR_HANDLES: + return getattr(operator, name) + raise AttributeError(f"unhandled operator: {name}") + + def run(self, expr: sympy.Basic) -> z3.ExprRef: + return sympy_interp(self, self._validator.symbols, expr) # type: ignore[arg-type] + + # Dynamo guards translation validator. + # + # [Note: TranslationValidator] + # Verifies whether the guards issued by 'ShapeEnv.produce_guards' are sound. + # That is: whether those (target) guards only yield TRUE whenever the original, + # unoptimized, (source) guards yield TRUE. + # + # More concretely, given 'source' and 'target' guard expressions, we wish to + # check whether the following expression holds: + # + # Not(And(source)) AND And(target) + # + # i.e. whether there is an assignment of the free variables where the opposite + # happens: target is TRUE, but source is FALSE. + class TranslationValidator: + def __init__(self) -> None: + log.debug("new instance") + + # Mapping of SymPy symbols to Z3 variables. + self.symbols: dict[sympy.Symbol, z3.ExprRef] = {} + + # Set of source Z3 expressions. + # They represent the generated guards without any kind of + # simplification or transformation. + self._source_exprs: set[z3.BoolRef] = set() + + # Set of target Z3 expressions. + # They represent the actual checked guards at runtime. They might + # be simplified or transformed versions of the source guards. + self._target_exprs: set[z3.BoolRef] = set() + + # Set of Z3 expressions representing assertions over both the + # source and target expressions. + self._assertions: set[z3.BoolRef] = set() + + # Retrieves the corresponding Z3 variable. + def z3var(self, symbol: sympy.Symbol) -> z3.ExprRef: + assert symbol in self.symbols, f"Z3 variable not found for: {symbol}" + return self.symbols[symbol] + + # Create a variable in Z3 of 'type' for 'symbol', if it doesn't already exists. + def add_var(self, symbol: sympy.Symbol, type: type) -> z3.ExprRef: + if symbol in self.symbols: + return self.symbols[symbol] + + log.debug("new variable: %s (%s)", symbol.name, type.__name__) + + if type is int: + var = z3.Int(symbol.name) + + # If 'symbol' is positive (SymPy assumption), we have to + # convey it to Z3 as well. + if symbol.is_positive: # type: ignore[attr-defined] + self._target_exprs.add(var > 0) + elif type is float: + var = z3.Real(symbol.name) + elif type is bool: + var = z3.Bool(symbol.name) + else: + raise RuntimeError(f"unsupported type for Z3 variable: {type}") + + self.symbols[symbol] = var + return var + + # Checks whether all symbols were already added. + def _check_freesymbols(self, e: sympy.Basic) -> None: + for s in e.free_symbols: + assert isinstance(s, sympy.Symbol) + # Call 'z3var' just to check whether there's already a + # Z3 variable corresponding to 's'. + self.z3var(s) + + def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: + z3expr = SympyToZ3(self).run(e) + assert isinstance(z3expr, z3.BoolRef), ( + f"expected boolean expression. Got: {z3expr}" + ) + return z3expr + + def add_source_expr(self, e: z3.BoolRef) -> None: + if e not in self._source_exprs: + log.debug("add source guard: %s", z3str(e)) + self._source_exprs.add(e) + + def add_target_expr(self, e: "sympy.logic.boolalg.Boolean") -> None: + self._check_freesymbols(e) + z3expr = self.to_z3_boolean_expr(e) + if e not in self._target_exprs: + log.debug("add target guard: %s", z3str(z3expr)) + self._target_exprs.add(z3expr) + + def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: + if isinstance(e, sympy.Basic): + self._check_freesymbols(e) + ref = self.to_z3_boolean_expr(e) + else: + ref = e + assert isinstance(ref, z3.BoolRef) + if ref not in self._assertions: + log.debug("add assertion: %s", z3str(ref)) + self._assertions.add(ref) + + def validate(self) -> None: + with dynamo_timed("TranslationValidator.validate"): + return self._validate() + + def _validate(self) -> None: + if len(self._source_exprs) == 0 or len(self._target_exprs) == 0: + # If there are no source/target expressions, there's nothing we really + # wish to prove. So, we just return. + return None + + # Here, we use "QF_NRA" logic for the solver: + # "Quantifier-free Non-linear Real Arithmetic". + # + # Most of the guards expressions have: + # 1. arithmetic between integer and reals + # 2. no quantifiers + # 3. potentially non-linear. + # + # Although there's also "QF_NIRA" (mixed integer-real arithmetic), + # "QF_NRA" seems to work better on 'dynamo/test_dynamic_shapes.py'. + solver = z3.SolverFor("QF_NRA") + # Set a timeout for finding a solution. + solver.set(timeout=translation_validation_timeout()) + + # Add all the assertions to the solver. + for assertion in self._assertions: + solver.add(assertion) + + # "Is there any case where it's TRUE for the target expressions, + # but FALSE for the source expressions?" + solver.add(z3.Not(z3.And(*self._source_exprs))) + solver.add(*self._target_exprs) + + log.debug("translation validation: start") + r = solver.check() + if r == z3.sat: + # Target expressions are unsound. + # Log the found model and the source expressions that failed. + model = solver.model() + raise ValidationException( + model, + self._assertions, + self._target_exprs, + failed_source_exprs=[ + inp for inp in self._source_exprs if not model.evaluate(inp) + ], + ) + else: + if r == z3.unknown: + # Could not find a solution. It didn't fail, but it also + # didn't succeed. Canceling the validation execution (keyboard + # interrupt) also gets to this branch. + log.warning( + "translation validation: could not validate: got z3.unknown" + ) + else: + # Target expressions are sound. + assert r == z3.unsat + log.debug("translation validation: success") + +except ImportError: + _HAS_Z3 = False + + __all__ = [ + "translation_validation_enabled", + "translation_validation_timeout", + "ValidationException", + "BisectValidationException", + ] + +else: + _HAS_Z3 = True + + __all__ = [ + "z3str", + "z3op", + "PopulateValidator", + "SympyToZ3", + "TranslationValidator", + "translation_validation_enabled", + "translation_validation_timeout", + "ValidationException", + "BisectValidationException", + ] + +from torch.fx.experimental import _config as config + + +def translation_validation_enabled() -> bool: + # Checks everytime this function is called, in case the Dynamo + # option is set, but Z3 is not installed. + _assert_z3_installed_if_tv_set() + return _HAS_Z3 and config.translation_validation + + +def translation_validation_timeout() -> int: + return config.translation_validation_timeout + + +def _assert_z3_installed_if_tv_set(): + assert _HAS_Z3 or not config.translation_validation, ( + "translation validation requires Z3 package. Please, either install " + "z3-solver or disable translation validation." + ) + + +class ValidationException(TorchDynamoException): + def __init__(self, model, assertions, target_exprs, failed_source_exprs): + assert _HAS_Z3 + + def symbolstr(sym) -> str: + return f"{sym}: {model[sym]}" + + def joinlines(xs) -> str: + return "\n".join(f" ==> {x}" for x in xs) + + model_str = joinlines(sorted(map(symbolstr, model))) + assertions_str = joinlines(sorted(map(z3str, assertions))) + target_exprs_str = joinlines(sorted(map(z3str, target_exprs))) + failed_source_exprs_str = joinlines(sorted(map(z3str, failed_source_exprs))) + + self.msg = "translation validation failed." + self.details = f"""\ +Model: +{model_str} + +Assertions: +{assertions_str} + +Target Expressions: +{target_exprs_str} + +Failed Source Expressions: +{failed_source_exprs_str}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +class BisectValidationException(TorchDynamoException): + def __init__(self, validation_exc, expr, failed_action, traced_node): + self.msg = f"translation validation failed when {failed_action}: {expr}" + self.details = f"""\ +Failure occurred while running node: + {traced_node.format_node()} + +{validation_exc.details}""" + + def __str__(self): + return f"{self.msg}\n\n{self.details}" + + +# Checks when this module is loaded. +_assert_z3_installed_if_tv_set() + + +# Translation validation bisection. +# +# Bisect into the torch._assert nodes recorded in the shape_env FX graph, and raise +# the earliest ValidationException. +# +# As guards are added by ShapeEnv.evaluate_expr calls, some simplification errors +# might be silently happening. This function tries to nail down exactly at which +# point things went wrong from a validation perspective. +def bisect(shape_env): + from torch.fx.experimental.recording import ( + FakeTensorMeta, + replay_shape_env_events, + ShapeEnvEvent, + ) + from torch.fx.experimental.symbolic_shapes import ( + CURRENT_NODE_KEY, + ShapeEnv, + SHAPEENV_EVENT_KEY, + ) + + events = shape_env.events + + # Retrieves the ShapeEnvEvent associated with node. + def get_node_event(node: torch.fx.Node) -> ShapeEnvEvent: + assert SHAPEENV_EVENT_KEY in node.meta + return events[node.meta[SHAPEENV_EVENT_KEY]] + + # Creates a new instance of fake, but updating every symbolic value's ShapeEnv + # reference to the one given as argument. + # + # This is needed so as not to simplify a symbolic expression using a ShapeEnv + # "from the future", where it may have a different set of replacements. + def new_with_shape_env(shape_env: ShapeEnv, fake) -> Any: + if isinstance(fake, int): + return fake + if isinstance(fake, torch.SymInt): + return torch.SymInt(fake.node.with_shape_env(shape_env)) + if isinstance(fake, torch.SymFloat): + return torch.SymFloat(fake.node.with_shape_env(shape_env)) + assert isinstance(fake, FakeTensorMeta) + return FakeTensorMeta( + tuple(new_with_shape_env(shape_env, s) for s in fake.size()), + tuple(new_with_shape_env(shape_env, s) for s in fake.stride()), + new_with_shape_env(shape_env, fake.storage_offset()), + fake.is_nested, + ) + + # Checks whether the given shape_env fails when produce_guards is called. + def check_shapeenv_fails( + shape_env: ShapeEnv, tracked_fakes: Optional[list[Any]] + ) -> Optional[ValidationException]: + assert tracked_fakes is not None + try: + # This produce_guards call is a best-effort replication, since we + # don't populate EqualityConstraint list. Reason: we would also have + # to save OutputGraph.tracked_fakes_id_to_source. + shape_env.produce_guards( + [new_with_shape_env(shape_env, a.fake) for a in tracked_fakes], + [a.source for a in tracked_fakes], + input_contexts=[a.symbolic_context for a in tracked_fakes], + ) + return None + except ValidationException as e: + return e + + # Checks whether the ShapeEnv reconstructed by replaying the events until + # node is created fails when produce_guards is called. + def check_node_fails(node: torch.fx.Node) -> Optional[ValidationException]: + number = node.meta[SHAPEENV_EVENT_KEY] + # Reconstruct shape_env until the event at event_number. + shape_env = replay_shape_env_events(events[: number + 1]) + shape_env.graph.lint() + return check_shapeenv_fails(shape_env, events[number].tracked_fakes) + + last_exception = check_shapeenv_fails( + shape_env, shape_env._snapshot_tracked_fakes() + ) + + if not last_exception: + # We don't actually fail due to a produce_guards call. + # Stop and don't bisect. + log.info("translation validation succeeded: no errors found.") + return + + if not shape_env.should_record_events or config.translation_validation_no_bisect: + # Bisection is off. + # Return the last ValidationException we got. + raise last_exception + + # Cache the raised exception (if any) at each bisection point. + exception = {} + + # Bisection happens on the assertion nodes of the recorded FX graph for + # dynamic shapes. + assert_nodes = [ + node for node in shape_env.graph.nodes if node.target == torch._assert + ] + + # Preparing the indices for binary search. + # The overall invariants are + # - for all i < left, assert_node[i] doesn't fail + # - for all i >= right, assert_node[i] fails + # - `right in exception` always holds + # - `left <= right` always holds + left, mid, right = 0, 0, len(assert_nodes) - 1 + exception[right] = check_node_fails(assert_nodes[right]) + + while left < right: + mid = (left + right) // 2 + + node = assert_nodes[mid] + log.debug("bisecting at %s: %s", mid, get_node_event(node)) + + # Check whether the new shape_env raises a ValidationException or not. + exception[mid] = check_node_fails(node) + + if exception[mid]: + right = mid + else: + left = mid + 1 + + assert left in exception and isinstance(exception[left], ValidationException) + + node = assert_nodes[left] + event = get_node_event(node) + + if event.is_evaluate_expr(): + failed_action = "evaluating" + else: + assert event.is_defer_runtime_assert(), f"unexpected event type: {event}" + failed_action = "adding runtime assert" + + args = event.args + assert args is not None + assert len(args) >= 2, ( + f"bisecting expects {event.name} to have at least 2 positional arguments. " + f"Got: {len(args)}" + ) + assert isinstance(args[1], sympy.Basic), ( + f"bisecting expects {event.name} to have a SymPy expression as its second argument. " + f"Got: {type(args[1])}" + ) + + raise BisectValidationException( + exception[left], + expr=args[1], + failed_action=failed_action, + traced_node=node.meta[CURRENT_NODE_KEY], + ) diff --git a/phivenv/Lib/site-packages/torch/fx/graph.py b/phivenv/Lib/site-packages/torch/fx/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..85fec494b64bbf3b27e3e69df2c33e5c0c7e4f92 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/graph.py @@ -0,0 +1,2003 @@ +# mypy: allow-untyped-defs +import builtins +import contextlib +import copy +import enum +import functools +import inspect +import keyword +import math +import os +import re +import typing +import warnings +from collections import defaultdict +from collections.abc import Iterable, Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from typing import Any, Callable, Literal, NamedTuple, Optional, TYPE_CHECKING + +import torch +import torch.utils._pytree as pytree +from torch._C import _fx_map_arg as map_arg, _NodeIter +from torch.utils._dtype_abbrs import dtype_abbrs + +from . import _pytree as fx_pytree +from ._compatibility import compatibility +from .immutable_collections import immutable_dict +from .node import _get_qualified_name, _type_repr, Argument, Node, Target + + +__all__ = ["PythonCode", "CodeGen", "Graph"] + +if TYPE_CHECKING: + from ._symbolic_trace import Tracer # noqa: F401 + from .graph_module import GraphModule # noqa: F401 + + +# Mapping of builtins to their `typing` equivalent. +# (PEP585: See D68459095 test plan) +_origin_type_map = { + list: typing.List, # noqa: UP006 + dict: typing.Dict, # noqa: UP006 + set: typing.Set, # noqa: UP006 + frozenset: typing.FrozenSet, # noqa: UP006 + tuple: typing.Tuple, # noqa: UP006 +} + +_legal_ops = dict.fromkeys( + ["call_function", "call_method", "get_attr", "call_module", "placeholder", "output"] +) + + +# Signature for functions thattransforms the body (`list[str]`) of the +# generated code +TransformCodeFunc = Callable[[list[str]], list[str]] + + +class _CustomBuiltin(NamedTuple): + """Additional objs that we add to every graph's globals. + + The repr() for some standard library objects is not valid Python code without + an import. For common objects of this sort, we bundle them in the globals of + every FX graph. + """ + + # How to import this object from the standard library. + import_str: str + # The actual object, produced from that import string. + obj: Any + + +# Combined dict of disallowed variable names so we can check with one lookup +_illegal_names = {k: object() for k in keyword.kwlist} +_illegal_names.update(builtins.__dict__) # can't shadow a builtin name + +_custom_builtins: dict[str, _CustomBuiltin] = {} + + +def _register_custom_builtin(name: str, import_str: str, obj: Any): + _custom_builtins[name] = _CustomBuiltin(import_str, obj) + _illegal_names[name] = obj + + +_register_custom_builtin("inf", "from math import inf", math.inf) +_register_custom_builtin("nan", "from math import nan", math.nan) +_register_custom_builtin("NoneType", "NoneType = type(None)", type(None)) +_register_custom_builtin("torch", "import torch", torch) +_register_custom_builtin("device", "from torch import device", torch.device) +_register_custom_builtin("fx_pytree", "import torch.fx._pytree as fx_pytree", fx_pytree) +_register_custom_builtin("pytree", "import torch.utils._pytree as pytree", pytree) + + +def _is_magic(x: str) -> bool: + return x.startswith("__") and x.endswith("__") + + +def _snake_case(s: str) -> str: + """ + Transforms the given string ``s`` to a Python-style variable name + + Examples: + ``mod.snake_case`` -> ``mod.snake_case`` + ``mod.pascalCase``-> ``mod.pascal_case`` + ``mod.ALL_CAPS`` -> ``mod.all_caps`` + """ + return _snake_case_sub(s).lower() + + +# Replace occurrences where a lowercase letter is followed by an uppercase letter +_snake_case_sub = functools.partial(re.compile(r"(?<=[a-z])([A-Z])").sub, r"_\1") + +# Find chars that can't be in a Python identifier +_illegal_char_regex = re.compile("[^0-9a-zA-Z_]+") + +# Combined check for variable names: +# 1) Checks name is not empty +# 2) Checks first character is not a digit +# 3) Checks name has no illegal characters (_illegal_char_regex) +# 3) Splits off the number suffix (if present) +_name_regex = re.compile(r"^([a-zA-Z_][0-9a-zA-Z_]*?)(?:_(\d+))?$") + +# starts with torch but does not start with torch._dynamo. or torch._inductor. +_torch_but_not_dynamo = re.compile( + r"^torch(?:\.(?!_dynamo\.|_inductor\.)[^.]+)*$" +).fullmatch + + +def _is_from_torch(obj: Any) -> bool: + module_name = getattr(obj, "__module__", None) + if module_name is not None: + return _torch_but_not_dynamo(module_name) is not None + + name = getattr(obj, "__name__", None) + # exclude torch because torch.torch.torch.torch works. idk mang + if name is not None and name != "torch": + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is obj: + return True + + return False + + +class _Namespace: + """A context for associating names uniquely with objects. + + The following invariants are enforced: + - Each object gets a single name. + - Each name is unique within a given namespace. + - Names generated do not shadow builtins, unless the object is indeed that builtin. + """ + + def __init__(self): + self._obj_to_name: dict[Any, str] = {} + self._used_names: set[str] = set() + self._base_count: dict[str, int] = {} + + def create_name(self, candidate: str, obj: Optional[Any]) -> str: + """Create a unique name. + + Arguments: + candidate: used as the basis for the unique name, relevant to the user. + obj: If not None, an object that will be associated with the unique name. + """ + if obj is not None and obj in self._obj_to_name: + return self._obj_to_name[obj] + + # optimistically check if candidate is already a valid name + match = _name_regex.match(candidate) + if match is None: + # delete all characters that are illegal in a Python identifier + candidate = _illegal_char_regex.sub("_", candidate) + + if not candidate: + candidate = "_unnamed" + + if candidate[0].isdigit(): + candidate = f"_{candidate}" + + match = _name_regex.match(candidate) + assert match is not None + + base, num = match.group(1, 2) + if num is None or candidate in self._used_names: + num = self._base_count.get(candidate, 0) + if _illegal_names.get(candidate, obj) is not obj: + num += 1 + candidate = f"{base}_{num}" + # assume illegal names don't end in _\d so no need to check again + else: + num = int(num) + + while candidate in self._used_names: + num += 1 + candidate = f"{base}_{num}" + + self._used_names.add(candidate) + self._base_count[base] = num + if obj is not None: + self._obj_to_name[obj] = candidate + return candidate + + def associate_name_with_obj(self, name: str, obj: Any): + """Associate a unique name with an object. + + Neither `name` nor `obj` should be associated already. + """ + maybe_existing = self._obj_to_name.setdefault(obj, name) + assert maybe_existing is name, "obj is already associated" + + def _rename_object(self, obj: Any, name: str): + assert obj in self._obj_to_name + self._obj_to_name[obj] = name + self._used_names.add(name) + + +@compatibility(is_backward_compatible=True) +@dataclass +class PythonCode: + """ + Represents all the information necessary to exec or save a graph as Python code. + """ + + # Python source code for the forward function definition. + src: str + # Values in global scope during execution of `src_def`. + globals: dict[str, Any] + # Optional mapping from the forward function's line number to + # node index. + _lineno_map: Optional[dict[int, Optional[int]]] + + +def _format_target(base: str, target: str) -> str: + elems = target.split(".") + r = base + for e in elems: + if not e.isidentifier(): + r = f'getattr({r}, "{e}")' + else: + r = f"{r}.{e}" + return r + + +class _InsertPoint: + def __init__(self, graph, new_insert): + self.graph = graph + self.orig_insert, graph._insert = graph._insert, new_insert + + def __enter__(self): + pass + + def __exit__(self, type, value, tb): + self.graph._insert = self.orig_insert + + +class _node_list: + def __init__(self, graph: "Graph", direction: Literal["_prev", "_next"] = "_next"): + assert direction in ("_next", "_prev") + self.graph = graph + self.direction = direction + + def __len__(self): + return self.graph._len + + def __iter__(self): + return _NodeIter(self.graph._root, self.direction == "_prev") + + def __reversed__(self): + return _node_list(self.graph, "_next" if self.direction == "_prev" else "_prev") + + +class _PyTreeInfo(NamedTuple): + """ + Contains extra info stored when we're using Pytrees + """ + + orig_args: list[str] + in_spec: pytree.TreeSpec + out_spec: Optional[pytree.TreeSpec] + + +@dataclass(frozen=True) +class _ParsedStackTrace: + """ + Represents the top-most frame of a parsed stack trace + """ + + file: str + lineno: str + name: str + code: str + + def get_summary_str(self): + return f"File: {self.file}:{self.lineno} in {self.name}, code: {self.code}" + + +# get File:lineno code from stack_trace +def _parse_stack_trace(stack_trace: str): + if stack_trace is None: + return None + pattern = re.compile(r"^File \"(.+)\", line (\d+), in (.+)$") + lines = stack_trace.strip().split("\n") + # stacktrace should have innermost frame last, so we + # iterate backwards to find the first line that starts + # with 'File ' + for idx in range(len(lines) - 2, -1, -1): + line = lines[idx].strip() + matches = pattern.match(line) + if matches: + file = matches.group(1) + lineno = matches.group(2) + name = matches.group(3) + # next line should be the code + code = lines[idx + 1].strip() + return _ParsedStackTrace(file, lineno, name, code) + return None + + +@compatibility(is_backward_compatible=False) +class CodeGen: + # This is an override hook so we can customize the SymNode printer. + _sym_repr: Callable[["torch.types.PySymType"], str] = lambda x: repr(x) + + def __init__(self): + self._body_transformer: Optional[TransformCodeFunc] = None + self._func_name: str = "forward" + + def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str: + """ + Given the free variables and a return annotation, generates the beginning of the FX function. + By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'` + """ + # If the original function didn't have self as its first argument, we + # would have added it. + if len(free_vars) == 0 or free_vars[0] != "self": + free_vars.insert(0, "self") + return ( + f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:" + ) + + def generate_output(self, output_args: Argument) -> str: + """ + Given the output arguments, generates the return statement of the FX function. + Note: The returned statement should not be indented. + """ + return f"return {repr(output_args)}" + + def process_inputs(self, *args: Any) -> Any: + """ + Transforms the inputs so that the graph can take them as arguments, as + non-default codegen may result in the inputs to the function being + different from the inputs to the graph. + + If the graph was directly runnable, this invariant should hold true + `f.graph.process_outputs(f.graph(*f.graph.process_inputs(*inputs))) == f(*inputs)` + """ + return args + + def process_outputs(self, outputs: Any) -> Any: + """ + Transforms the outputs of the graph to be identical to the codegen. + + See ``process_inputs`` for more details. + """ + return outputs + + def additional_globals(self) -> list[tuple[str, Any]]: + """ + If your codegen uses extra global values, add tuples of (identifier,reference to the value) here. + For example, return ['List', typing.List] if you need ``List`` in the global context. + """ + return [] + + def _gen_python_code( + self, + nodes, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, + ) -> PythonCode: + free_vars: list[str] = [] + body: list[str] = [] + globals_: dict[str, Any] = {} + wrapped_fns: dict[str, None] = {} + + # Wrap string in list to pass by reference + maybe_return_annotation: list[str] = [""] + include_stride = include_stride or ( + os.environ.get("FX_GRAPH_SHOW_STRIDE", "0") == "1" + ) + include_device = include_device or ( + os.environ.get("FX_GRAPH_SHOW_DEVICE", "0") == "1" + ) + + def add_global(name_hint: str, obj: Any): + """Add an obj to be tracked as a global. + + We call this for names that reference objects external to the + Graph, like functions or types. + + Returns: the global name that should be used to reference 'obj' in generated source. + """ + if ( + _is_from_torch(obj) and obj != torch.device + ): # to support registering torch.device + # HACK: workaround for how torch custom ops are registered. We + # can't import them like normal modules so they must retain their + # fully qualified name. + return _get_qualified_name(obj) + + # normalize the name hint to get a proper identifier + global_name = namespace.create_name(name_hint, obj) + + if global_name in globals_: + assert globals_[global_name] == obj + return global_name + globals_[global_name] = obj + return global_name + + # Pre-fill the globals table with registered builtins. + for name, (_, obj) in _custom_builtins.items(): + add_global(name, obj) + + def type_repr(o: Any): + if o == (): + # Empty tuple is used for empty tuple type annotation Tuple[()] + return "()" + + typename = _type_repr(o) + + if origin_type := getattr(o, "__origin__", None): + # list[...], typing.List[...], TensorType[...] + + if isinstance(o, typing._GenericAlias): # type: ignore[attr-defined] + # This is a generic pre-PEP585 type, e.g. typing.List[torch.Tensor] + origin_type = _origin_type_map.get(origin_type, origin_type) + + origin_typename = add_global(_type_repr(origin_type), origin_type) + + if hasattr(o, "__args__"): + # Assign global names for each of the inner type variables. + args = [type_repr(arg) for arg in o.__args__] + + if len(args) == 0: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python < 3.9 + return origin_typename + + return f"{origin_typename}[{','.join(args)}]" + else: + # Bare type, such as `typing.Tuple` with no subscript + # This code-path used in Python 3.9+ + return origin_typename + + # Common case: this is a regular module name like 'foo.bar.baz' + return add_global(typename, o) + + if colored: + red = _color_fns["red"] + dim_green = _color_fns["dim_green"] + dim = _color_fns["dim"] + dim_blue = _color_fns["dim_blue"] + blue = _color_fns["blue"] + else: + red = _identity + dim_green = _identity + dim = _identity + dim_blue = _identity + blue = _identity + + def _get_repr(arg: Any) -> str: + if isinstance(arg, Node): # first because common + return repr(arg) + elif isinstance(arg, tuple) and hasattr(arg, "_fields"): + # Handle NamedTuples (if it has `_fields`) via add_global. + qualified_name = _get_qualified_name(type(arg)) + global_name = add_global(qualified_name, type(arg)) + return f"{global_name}{repr(tuple(arg))}" + elif isinstance( + arg, (torch._ops.OpOverload, torch._ops.HigherOrderOperator) + ): + qualified_name = _get_qualified_name(arg) + global_name = add_global(qualified_name, arg) + return f"{global_name}" + elif isinstance(arg, enum.Enum): + cls = arg.__class__ + clsname = add_global(cls.__name__, cls) + return f"{clsname}.{arg.name}" + elif isinstance(arg, torch.Tensor): + size = list(arg.size()) + dtype = str(arg.dtype).split(".")[-1] + return f"torch.Tensor(size={size}, dtype={dtype})" + elif isinstance(arg, tuple): + if len(arg) == 1: + return f"({_get_repr(arg[0])},)" + else: + return "(" + ", ".join(_get_repr(a) for a in arg) + ")" + elif isinstance(arg, list): + return "[" + ", ".join(_get_repr(a) for a in arg) + "]" + elif isinstance(arg, slice): + return f"slice({_get_repr(arg.start)}, {_get_repr(arg.stop)}, {_get_repr(arg.step)})" + else: + return blue(repr(arg)) + + def _format_args( + args: tuple[Argument, ...], kwargs: dict[str, Argument] + ) -> str: + res = [_get_repr(a) for a in args] + res.extend([f"{k} = {_get_repr(v)}" for k, v in kwargs.items()]) + return ", ".join(res) + + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: dict[Node, Node] = {} + user_to_last_uses: dict[Node, list[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(nodes): + for input_node in node._input_nodes: + register_last_uses(input_node, node) + + def delete_unused_values(user: Node): + """ + Delete values after their last use. This ensures that values that are + not used in the remainder of the code are freed and the memory usage + of the code is optimal. + """ + if user.op == "placeholder": + return + if user.op == "output": + body.append("\n") + return + nodes_to_delete = user_to_last_uses.get(user, []) + + if len(user.users.keys()) == 0: + # This node is not used by any others. however it's also not + # removed by DCE since side-effect. We want to free it's outputs + # right after its execution done to save memory. + nodes_to_delete.append(user) + + if len(nodes_to_delete): + to_delete_str = " = ".join( + [repr(n) for n in nodes_to_delete] + ["None"] + ) + body.append(f"; {dim(to_delete_str)}\n") + else: + body.append("\n") + + prev_stacktrace = None + + def append_stacktrace_summary(node: Node): + """ + Append a summary of the stacktrace to the generated code. This is + useful for debugging. + """ + nonlocal prev_stacktrace + + if node.op not in {"placeholder", "output"}: + stack_trace = node.stack_trace + if stack_trace: + if stack_trace != prev_stacktrace: + prev_stacktrace = stack_trace + if parsed_stack_trace := _parse_stack_trace(stack_trace): + summary_str = parsed_stack_trace.get_summary_str() + else: + summary_str = "" + body.append(f"\n {dim(f'# {summary_str}')}\n") + elif prev_stacktrace != "": + prev_stacktrace = "" + no_stacktrace_msg = "# No stacktrace found for following nodes" + body.append(f"\n{dim(no_stacktrace_msg)}\n") + + def stringify_shape(shape: Iterable) -> str: + return f"[{', '.join([str(x) for x in shape])}]" + + def emit_node(node: Node): + maybe_type_annotation = ( + "" if node.type is None else f" : {type_repr(node.type)}" + ) + + if verbose: + # override annotation with more detailed information + from torch.fx.experimental.proxy_tensor import py_sym_types + from torch.fx.passes.shape_prop import TensorMetadata + + meta_val = node.meta.get( + "val", + node.meta.get("tensor_meta", node.meta.get("example_value", None)), + ) + # use string as annotation, to make it valid python code + if isinstance(meta_val, torch.Tensor) and meta_val.layout not in ( + torch.sparse_csc, + torch.sparse_csr, + ): + stride_annotation = ( + f"{stringify_shape(meta_val.stride())}" + if include_stride + else "" + ) + device_annotation = f"{meta_val.device}" if include_device else "" + maybe_type_annotation = ( + f': "{red(dtype_abbrs[meta_val.dtype])}{blue(stringify_shape(meta_val.shape))}' + f'{dim_blue(stride_annotation)}{dim_green(device_annotation)}"' + ) + elif isinstance(meta_val, py_sym_types): + val_str = CodeGen._sym_repr(meta_val) + maybe_type_annotation = f': "Sym({val_str})"' + elif isinstance(meta_val, TensorMetadata): + maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"' + + if node.op == "placeholder": + assert isinstance(node.target, str) + maybe_default_arg = ( + "" if not node.args else f" = {_get_repr(node.args[0])}" + ) + free_vars.append( + f"{node.target}{maybe_type_annotation}{maybe_default_arg}" + ) + raw_name = node.target.replace("*", "") + if raw_name != repr(node): + body.append(f"{repr(node)} = {raw_name}\n") + return + elif node.op == "call_method": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.target)}" + f"({_format_args(node.args[1:], node.kwargs)})" + ) + return + elif node.op == "call_function": + assert callable(node.target) + # pretty print operators + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in magic_methods + ): + assert isinstance(node.args, tuple) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{magic_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}" + ) + return + + # pretty print inplace operators; required for jit.script to work properly + # not currently supported in normal FX graphs, but generated by torchdynamo + if ( + getattr(node.target, "__module__", "") == "_operator" + and node.target.__name__ in inplace_methods + ): + body.append( + f"{inplace_methods[node.target.__name__].format(*(_get_repr(a) for a in node.args))}; " + f"{repr(node)}{maybe_type_annotation} = {_get_repr(node.args[0])}" + ) + return + + qualified_name = _get_qualified_name(node.target) + global_name = add_global(qualified_name, node.target) + # special case for getattr: node.args could be 2-argument or 3-argument + # 2-argument: attribute access; 3-argument: fall through to attrib function call with default value + if ( + global_name == "getattr" + and isinstance(node.args, tuple) + and isinstance(node.args[1], str) + and node.args[1].isidentifier() + and len(node.args) == 2 + ): + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(_get_repr(node.args[0]), node.args[1])}" + ) + return + body.append( + f"{repr(node)}{maybe_type_annotation} = {global_name}({_format_args(node.args, node.kwargs)})" + ) + if node.meta.get("is_wrapped", False): + wrapped_fns.setdefault(global_name) + return + elif node.op == "call_module": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = " + f"{_format_target(root_module, node.target)}({_format_args(node.args, node.kwargs)})" + ) + return + elif node.op == "get_attr": + assert isinstance(node.target, str) + body.append( + f"{repr(node)}{maybe_type_annotation} = {_format_target(root_module, node.target)}" + ) + return + elif node.op == "output": + if node.type is not None: + maybe_return_annotation[0] = f" -> {type_repr(node.type)}" + body.append(self.generate_output(node.args[0])) + return + raise NotImplementedError(f"node: {node.op} {node.target}") + + for i, node in enumerate(nodes): + # NOTE: emit_node does not emit a string with newline. It depends + # on delete_unused_values to append one + if verbose: + append_stacktrace_summary(node) + # emit a counter comment to keep track of + # node index, which will be deleted later + # after going through _body_transformer + body.append(f"# COUNTER: {i}\n") + emit_node(node) + delete_unused_values(node) + + if len(body) == 0: + # If the Graph has no non-placeholder nodes, no lines for the body + # have been emitted. To continue to have valid Python code, emit a + # single pass statement + body.append("pass\n") + + if len(wrapped_fns) > 0: + wrap_name = add_global("wrap", torch.fx.wrap) + wrap_stmts = "\n".join([f'{wrap_name}("{name}")' for name in wrapped_fns]) + else: + wrap_stmts = "" + + if self._body_transformer: + body = self._body_transformer(body) + + for name, value in self.additional_globals(): + add_global(name, value) + + prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0]) + + # remove counter and generate lineno to node index mapping + lineno_map: dict[int, Optional[int]] = {} + prologue_len = prologue.count("\n") + 1 + new_lines: list[str] = [] + cur_idx = None + for line in "".join(body).split("\n"): + counter = _counter_regexp.search(line) + if counter is not None: + cur_idx = int(counter.group(1)) + else: + lineno_map[len(new_lines) + prologue_len] = cur_idx + new_lines.append(line) + + code = "\n".join(new_lines).lstrip("\n") + code = "\n".join(" " + line for line in code.split("\n")) + + fn_code = f""" +{wrap_stmts} + +{prologue} +{code}""" + return PythonCode(fn_code, globals_, _lineno_map=lineno_map) + + +# Ideally, we'd like to refactor all of the pytree logic into this codegen +# class. Unfortunately, there are 3 areas we currently need extra logic in FX. +# 1. In the initial symbolic trace, the pytree logic is tied up with `concrete_args`. +# 2. In the FX graph, we need to access 2 attributes - in_spec and out_spec. +# Since we can't access .graph within the FX forward, we need to copy the attribute to the module. +# 3. We currently can't register the pytree imports with `add_global` - not sure why. +class _PyTreeCodeGen(CodeGen): + def __init__(self, pytree_info: _PyTreeInfo): + super().__init__() + self.pytree_info: _PyTreeInfo = pytree_info + + def process_inputs(self, *inputs: Any) -> Any: + flat_args = pytree.arg_tree_leaves(*inputs) + return flat_args + + def process_outputs(self, out: Any) -> Any: + if self.pytree_info is None or self.pytree_info.out_spec is None: + return out + if not isinstance(out, (list, tuple)): + out = [out] + assert self.pytree_info.out_spec is not None + return pytree.tree_unflatten(out, self.pytree_info.out_spec) + + def gen_fn_def(self, free_vars, maybe_return_annotation): + # Given a user function/model: + # myargs = [myargs0, myargs1] + # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...} + # def forward(self, mypos, *myargs, mykey=None, **mykwargs): + # + # The generated code flattens all keywords into positional arguments for `forward()` + # e.g forward(self, mypos, myargs0, myargs1, mykey, mykwargs0, mykwargs1): + # + # Within `forward`, `tree_flatten_spec``still parses args and kwargs separately + # e.g. tree_flatten_spec(([mypos, myargs0, myargs1], + # {'mykey':mykey, 'mykwargs0':mykwargs0, 'mykwargs1':mykwargs1}), + # self._in_spec) + # + # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec + # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec) + if self.pytree_info is None: + return super().gen_fn_def(free_vars, maybe_return_annotation) + + fn_args = self.pytree_info.orig_args + has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False + if has_orig_self: + free_vars.insert(0, "self") + fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation) + + if len(free_vars) > 0: # pytree has placeholders in it + # when kwargs is present, in_spec is tuple(args, kwargs) + has_args_kwargs_tuple = ( + self.pytree_info.in_spec.type == tuple + and self.pytree_info.in_spec.num_children == 2 + and self.pytree_info.in_spec.children_specs[0].type == tuple + and self.pytree_info.in_spec.children_specs[1].type == dict + ) + fn_kwargs = "{}" + fn_signature = f"[{', '.join(fn_args)}], self._in_spec" + if has_args_kwargs_tuple: + count_args = self.pytree_info.in_spec.children_specs[0].num_children + fn_args = self.pytree_info.orig_args[:count_args] + fn_kwargs = ( + "{" + + ", ".join( + f"'{k}':{v}" + for k, v in zip( + self.pytree_info.in_spec.children_specs[1].context, + self.pytree_info.orig_args[count_args:], + ) + ) + + "}" + ) + fn_signature = f"([{', '.join(fn_args)}], {fn_kwargs}), self._in_spec" + + # in Python, `var1: annotation1, var2: annotation2 = function_call()` is invalid. + # we need to split it to two lines: + # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon) + # one for code: `var1, var2, = function_call()` + without_annotation = [x.split(":")[0] for x in free_vars] + has_annotation = [x + "; " for x in free_vars if ":" in x] + if len(has_annotation) > 0: + fn_definition += "\n " + "".join(has_annotation) + "\n" + fn_definition += f""" + {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + return fn_definition + + def generate_output(self, output_args): + if self.pytree_info and self.pytree_info.out_spec: + return f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)" + else: + return super().generate_output(output_args) + + +class _FindNodesLookupTable: + """ + Side table for the graph for the purpose of doing fast queries + """ + + def __init__(self): + self.table: dict[tuple[str, Optional[Target]], dict[Node, None]] = defaultdict( + dict + ) + + def _key(self, node) -> tuple[str, Optional[Target]]: + return (node.op, node.target if node.op == "call_function" else None) + + def __contains__(self, node) -> bool: + return node in self.table[self._key(node)] + + def insert(self, node: Node) -> None: + self.table[self._key(node)][node] = None + + def remove(self, node: Node) -> None: + self.table[self._key(node)].pop(node) + + def find_nodes(self, *, op: str, target: Optional["Target"] = None): + if op == "call_function": + assert target is not None + return [*self.table[(op, target)].keys()] + + if target is None: + return [*self.table[(op, None)].keys()] + + # op is call_method, get_attr, call_module + return [node for node in self.table[(op, None)].keys() if node.target == target] + + +@compatibility(is_backward_compatible=True) +class Graph: + """ + ``Graph`` is the main data structure used in the FX Intermediate Representation. + It consists of a series of ``Node`` s, each representing callsites (or other + syntactic constructs). The list of ``Node`` s, taken together, constitute a + valid Python function. + + For example, the following code + + .. code-block:: python + + import torch + import torch.fx + + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x): + return torch.topk( + torch.sum(self.linear(x + self.linear.weight).relu(), dim=-1), 3 + ) + + + m = MyModule() + gm = torch.fx.symbolic_trace(m) + + Will produce the following Graph:: + + print(gm.graph) + + .. code-block:: text + + graph(x): + %linear_weight : [num_users=1] = self.linear.weight + %add_1 : [num_users=1] = call_function[target=operator.add](args = (%x, %linear_weight), kwargs = {}) + %linear_1 : [num_users=1] = call_module[target=linear](args = (%add_1,), kwargs = {}) + %relu_1 : [num_users=1] = call_method[target=relu](args = (%linear_1,), kwargs = {}) + %sum_1 : [num_users=1] = call_function[target=torch.sum](args = (%relu_1,), kwargs = {dim: -1}) + %topk_1 : [num_users=1] = call_function[target=torch.topk](args = (%sum_1, 3), kwargs = {}) + return topk_1 + + For the semantics of operations represented in the ``Graph``, please see :class:`Node`. + """ + + @compatibility(is_backward_compatible=True) + def __init__( + self, + owning_module: Optional["GraphModule"] = None, + tracer_cls: Optional[type["Tracer"]] = None, + tracer_extras: Optional[dict[str, Any]] = None, + ): + """ + Construct an empty Graph. + """ + self._root: Node = Node(self, "", "root", "", (), {}) + self._used_names: dict[str, int] = {} # base name -> number + self._insert = self._root.prepend + self._len = 0 + self._graph_namespace = _Namespace() + self._owning_module = owning_module + self._tracer_cls = tracer_cls + self._tracer_extras = tracer_extras + self._codegen = CodeGen() + self._co_fields: dict[str, Any] = {} + self._find_nodes_lookup_table = _FindNodesLookupTable() + + @property + def owning_module(self): + return self._owning_module + + @owning_module.setter + def owning_module(self, mod: Optional["GraphModule"]): + self._owning_module = mod + + @property + def nodes(self) -> _node_list: + """ + Get the list of Nodes that constitute this Graph. + + Note that this ``Node`` list representation is a doubly-linked list. Mutations + during iteration (e.g. delete a Node, add a Node) are safe. + + Returns: + + A doubly-linked list of Nodes. Note that ``reversed`` can be called on + this list to switch iteration order. + """ + return _node_list(self) + + @compatibility(is_backward_compatible=False) + def output_node(self) -> Node: + output_node = next(iter(reversed(self.nodes))) + assert output_node.op == "output" + return output_node + + @compatibility(is_backward_compatible=False) + def find_nodes( + self, *, op: str, target: Optional["Target"] = None, sort: bool = True + ): + """ + Allows for fast query of nodes + + Args: + + op (str): the name of the operation + + target (Optional[Target]): the target of the node. For call_function, + the target is required. For other ops, the target is optional. + + sort (bool): whether to return nodes in the order they appear on + on the graph. + + Returns: + + Iteratable of nodes with the requested op and target. + """ + node_list = self._find_nodes_lookup_table.find_nodes(op=op, target=target) + if sort: + return sorted(node_list) + return node_list + + @compatibility(is_backward_compatible=True) + def graph_copy( + self, g: "Graph", val_map: dict[Node, Node], return_output_node=False + ) -> "Optional[Argument]": + """ + Copy all nodes from a given graph into ``self``. + + Args: + + g (Graph): The source graph from which to copy Nodes. + + val_map (Dict[Node, Node]): a dictionary that will be populated with a mapping + from nodes in ``g`` to nodes in ``self``. Note that ``val_map`` can be passed + in with values in it already to override copying of certain values. + + Returns: + + The value in ``self`` that is now equivalent to the output value in ``g``, + if ``g`` had an ``output`` node. ``None`` otherwise. + """ + for node in g.nodes: + if node in val_map: + continue + if node.op == "output": + rv = map_arg(node.args[0], lambda n: val_map[n]) + return rv if not return_output_node else (rv, node) + val_map[node] = self.node_copy(node, lambda n: val_map[n]) + return None + + def __deepcopy__(self, memo=None) -> "Graph": + """ + Explicitly implement __deepcopy__ to prevent excessive recursion depth + from the default implementation. This uses graph_copy to copy the nodes + in an iterative way, rather than recursive. It also populates the + memoization table to prevent unnecessary copies (e.g. references to + nodes or other parts of the Graph from a custom GraphModule implementation. + """ + memo = memo if memo else {} + g = Graph(tracer_cls=self._tracer_cls) + output_vals = g.graph_copy(self, val_map=memo, return_output_node=True) + g._codegen = copy.deepcopy(self._codegen) + if output_vals is not None: + assert isinstance(output_vals, tuple) + output_val, old_output_node = output_vals + new_output_node = g.output( + output_val, type_expr=getattr(old_output_node, "type", None) + ) + new_output_node.meta = copy.copy(old_output_node.meta) + return g + + @compatibility(is_backward_compatible=True) + def create_node( + self, + op: str, + target: "Target", + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Create a ``Node`` and add it to the ``Graph`` at the current insert-point. + Note that the current insert-point can be set via :meth:`Graph.inserting_before` + and :meth:`Graph.inserting_after`. + + Args: + op (str): the opcode for this Node. One of 'call_function', 'call_method', 'get_attr', + 'call_module', 'placeholder', or 'output'. The semantics of these opcodes are + described in the ``Graph`` docstring. + + args (Optional[Tuple[Argument, ...]]): is a tuple of arguments to this node. + + kwargs (Optional[Dict[str, Argument]]): the kwargs of this Node + + name (Optional[str]): an optional string name for the ``Node``. + This will influence the name of the value assigned to in the + Python generated code. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted node. + """ + # `target in _legal_ops` is checked in Node.__init__ + if not args: + args = () + else: + assert isinstance(args, tuple), "args must be a tuple" + if not kwargs: + kwargs = immutable_dict() + else: + assert isinstance(kwargs, dict), "kwargs must be a dict" + + candidate = name if name is not None else self._target_to_str(target) + name = self._graph_namespace.create_name(candidate, None) + n = Node(self, name, op, target, args, kwargs, type_expr) + + if ( + self.owning_module is not None + and getattr(self.owning_module, "_create_node_hooks", None) is not None + ): + for f in self.owning_module._create_node_hooks: + f(n) + + self._graph_namespace.associate_name_with_obj(name, n) + + self._insert(n) + self._find_nodes_lookup_table.insert(n) + self._len += 1 + return n + + @compatibility(is_backward_compatible=False) + def process_inputs(self, *args): + """ + Processes args so that they can be passed to the FX graph. + """ + return self._codegen.process_inputs(*args) + + @compatibility(is_backward_compatible=False) + def process_outputs(self, out): + return self._codegen.process_outputs(out) + + @compatibility(is_backward_compatible=True) + def erase_node(self, to_erase: Node) -> None: + """ + Erases a ``Node`` from the ``Graph``. Throws an exception if + there are still users of that node in the ``Graph``. + + Args: + + to_erase (Node): The ``Node`` to erase from the ``Graph``. + """ + if len(to_erase.users) > 0: + raise RuntimeError( + f"Tried to erase Node {to_erase} but it still had {len(to_erase.users)} " + f"users in the graph: {to_erase.users}!" + ) + if to_erase.graph != self: + raise RuntimeError(f"Attempting to remove {to_erase} from wrong graph!") + if to_erase._erased: + warnings.warn(f"erase_node({to_erase}) on an already erased node") + return + + if ( + self.owning_module is not None + and getattr(self.owning_module, "_erase_node_hooks", None) is not None + ): + for f in self.owning_module._erase_node_hooks: + f(to_erase) + + self._find_nodes_lookup_table.remove(to_erase) + to_erase._remove_from_list() + to_erase._erased = True # iterators may retain handles to erased nodes + self._len -= 1 + + # Null out this Node's argument nodes so that the Nodes referred to + # can update their ``users`` accordingly + to_erase._update_args_kwargs( + map_arg(to_erase._args, lambda n: None), + map_arg(to_erase._kwargs, lambda n: None), + ) + + @compatibility(is_backward_compatible=True) + def inserting_before(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_before(n): + ... # inserting before node n + ... # insert point restored to what it was previously + g.inserting_before(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert before + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_after(self._root) + assert n.graph == self, "Node to insert before is not in graph." + return _InsertPoint(self, n.prepend) + + @compatibility(is_backward_compatible=True) + def inserting_after(self, n: Optional[Node] = None): + """Set the point at which create_node and companion methods will insert into the graph. + When used within a 'with' statement, this will temporary set the insert point and + then restore it when the with statement exits:: + + with g.inserting_after(n): + ... # inserting after node n + ... # insert point restored to what it was previously + g.inserting_after(n) # set the insert point permanently + + Args: + + n (Optional[Node]): The node before which to insert. If None this will insert after + the beginning of the entire graph. + + Returns: + A resource manager that will restore the insert point on ``__exit__``. + """ + if n is None: + return self.inserting_before(self._root) + assert n.graph == self, "Node to insert after is not in graph." + return _InsertPoint(self, n.append) + + @compatibility(is_backward_compatible=True) + def placeholder( + self, + name: str, + type_expr: Optional[Any] = None, + default_value: Any = inspect.Signature.empty, + ) -> Node: + """ + Insert a ``placeholder`` node into the Graph. A ``placeholder`` represents + a function input. + + Args: + + name (str): A name for the input value. This corresponds to the name + of the positional argument to the function this ``Graph`` represents. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. This is needed in some + cases for proper code generation (e.g. when the function is used + subsequently in TorchScript compilation). + + default_value (Any): The default value this function argument should take + on. NOTE: to allow for `None` as a default value, `inspect.Signature.empty` + should be passed as this argument to specify that the parameter does _not_ + have a default value. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + args = () if default_value is inspect.Signature.empty else (default_value,) + return self.create_node("placeholder", name, args=args, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def get_attr(self, qualified_name: str, type_expr: Optional[Any] = None) -> Node: + """ + Insert a ``get_attr`` node into the Graph. A ``get_attr`` ``Node`` represents the + fetch of an attribute from the ``Module`` hierarchy. + + Args: + + qualified_name (str): the fully-qualified name of the attribute to be retrieved. + For example, if the traced Module has a submodule named ``foo``, which has a + submodule named ``bar``, which has an attribute named ``baz``, the qualified + name ``foo.bar.baz`` should be passed as ``qualified_name``. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + + Returns: + + The newly-created and inserted ``get_attr`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + + def _get_attr_reference_exists( + mod: torch.nn.Module, qualified_name: str + ) -> bool: + module_path, _, name = qualified_name.rpartition(".") + + try: + submod: torch.nn.Module = mod.get_submodule(module_path) + except AttributeError: + warnings.warn(f"Failed to fetch module {module_path}!") + return False + + if not hasattr(submod, name): + return False + + res = getattr(submod, name) + + if ( + not isinstance(res, torch.nn.Module) + and not isinstance(res, torch.nn.Parameter) + and name not in submod._buffers + ): + return False + + return True + + if self.owning_module and not _get_attr_reference_exists( + self.owning_module, qualified_name + ): + warnings.warn( + "Attempted to insert a get_attr Node with no " + "underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule, " + "GraphModule.add_parameter to add the " + "necessary Parameter, or " + "nn.Module.register_buffer to add the " + "necessary buffer", + stacklevel=2, + ) + return self.create_node("get_attr", qualified_name, type_expr=type_expr) + + @compatibility(is_backward_compatible=True) + def call_module( + self, + module_name: str, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Insert a ``call_module`` ``Node`` into the ``Graph``. A ``call_module`` node + represents a call to the forward() function of a ``Module`` in the ``Module`` + hierarchy. + + Args: + + module_name (str): The qualified name of the ``Module`` in the ``Module`` + hierarchy to be called. For example, if the traced ``Module`` has a + submodule named ``foo``, which has a submodule named ``bar``, the + qualified name ``foo.bar`` should be passed as ``module_name`` to + call that module. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this should *not* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly-created and inserted ``call_module`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + if self.owning_module and self.owning_module.get_submodule(module_name) is None: + warnings.warn( + "Attempted to insert a call_module Node with " + "no underlying reference in the owning " + "GraphModule! Call " + "GraphModule.add_submodule to add the " + "necessary submodule" + ) + return self.create_node( + "call_module", module_name, args, kwargs, type_expr=type_expr + ) + + @compatibility(is_backward_compatible=True) + def call_method( + self, + method_name: str, + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Insert a ``call_method`` ``Node`` into the ``Graph``. A ``call_method`` node + represents a call to a given method on the 0th element of ``args``. + + Args: + + method_name (str): The name of the method to apply to the self argument. + For example, if args[0] is a ``Node`` representing a ``Tensor``, + then to call ``relu()`` on that ``Tensor``, pass ``relu`` to ``method_name``. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called method. Note that this *should* include a ``self`` argument. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called method + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + Returns: + + The newly created and inserted ``call_method`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node( + "call_method", method_name, args, kwargs, type_expr=type_expr + ) + + @compatibility(is_backward_compatible=True) + def call_function( + self, + the_function: Callable[..., Any], + args: Optional[tuple["Argument", ...]] = None, + kwargs: Optional[dict[str, "Argument"]] = None, + type_expr: Optional[Any] = None, + name: Optional[str] = None, + ) -> Node: + """ + Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node + represents a call to a Python callable, specified by ``the_function``. + + Args: + + the_function (Callable[..., Any]): The function to be called. Can be any PyTorch + operator, Python function, or member of the ``builtins`` or ``operator`` + namespaces. + + args (Optional[Tuple[Argument, ...]]): The positional arguments to be passed + to the called function. + + kwargs (Optional[Dict[str, Argument]]): The keyword arguments to be passed + to the called function + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + name (Optional[str]): The name of the node. If not specified, set to None + + Returns: + + The newly created and inserted ``call_function`` node. + + .. note:: + The same insertion point and type expression rules apply for this method + as :meth:`Graph.create_node`. + """ + return self.create_node( + "call_function", the_function, args, kwargs, name=name, type_expr=type_expr + ) + + @compatibility(is_backward_compatible=True) + def node_copy( + self, node: Node, arg_transform: Callable[[Node], "Argument"] = lambda x: x + ) -> Node: + """ + Copy a node from one graph into another. ``arg_transform`` needs to transform arguments from + the graph of node to the graph of self. Example:: + + # Copying all the nodes in `g` into `new_graph` + g: torch.fx.Graph = ... + new_graph = torch.fx.graph() + value_remap = {} + for node in g.nodes: + value_remap[node] = new_graph.node_copy(node, lambda n: value_remap[n]) + + Args: + + node (Node): The node to copy into ``self``. + + arg_transform (Callable[[Node], Argument]): A function that transforms + ``Node`` arguments in node's ``args`` and ``kwargs`` into the + equivalent argument in ``self``. In the simplest case, this should + retrieve a value out of a table mapping Nodes in the original + graph to ``self``. + """ + args = map_arg(node.args, arg_transform) + kwargs = map_arg(node.kwargs, arg_transform) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + result_node = self.create_node( + node.op, node.target, args, kwargs, node.name, node.type + ) + result_node.meta = copy.copy(node.meta) + return result_node + + @compatibility(is_backward_compatible=True) + def output(self, result: "Argument", type_expr: Optional[Any] = None): + """ + Insert an ``output`` ``Node`` into the ``Graph``. An ``output`` node represents + a ``return`` statement in Python code. ``result`` is the value that should + be returned. + + Args: + + result (Argument): The value to be returned. + + type_expr (Optional[Any]): an optional type annotation representing the + Python type the output of this node will have. + + .. note:: + + The same insertion point and type expression rules apply for this method + as ``Graph.create_node``. + """ + return self.create_node( + op="output", target="output", args=(result,), type_expr=type_expr + ) + + def _target_to_str(self, target: Target) -> str: + if callable(target): + op = target.__name__ + else: + assert isinstance(target, str) + op = target + if _is_magic(op): + op = op[2:-2] + op = _snake_case(op) + return op + + @compatibility(is_backward_compatible=True) + def python_code( + self, + root_module: str, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, + ) -> PythonCode: + """ + Turn this ``Graph`` into valid Python code. + + Args: + + root_module (str): The name of the root module on which to look-up + qualified name targets. This is usually 'self'. + + Returns: + + A PythonCode object, consisting of two fields: + src: the Python source code representing the object + globals: a dictionary of global names in `src` -> the objects that they reference. + """ + # NOTE: [Graph Namespaces] + # + # There are two types of symbols in generated Python source code: + # locals and globals. + # Locals are locally defined by the output of a node in the Graph. + # Globals are references to external objects, like functions or types. + # + # When generating Python code, we need to make sure to name things + # appropriately. In particular: + # - All names should be unique, to avoid weird shadowing bugs. + # - These names need to be consistent, e.g. a object should always be + # referenced by the same name. + # + # To do this, we create a new namespace just for this source. All names + # that get printed must come from this namespace. + # + # Why can't we re-use node.name? Because it was generated within the + # namespace `self._graph_namespace`. In order to provide uniqueness + # over both locals (node.name) *and* globals, we create a completely + # new namespace to put all identifiers in. + namespace = _Namespace() + + # Override Node's repr to generate a valid name within our namespace. + # Since repr() is designed to produce a valid Python expression, it + # makes sense to re-use it. This way, it's easy to print something like + # Tuple[Node, Node] by simply calling repr() on it. Node's __repr__ is + # implemented cooperatively to allow this. + def node_repr(n: Node): + return namespace.create_name(n.name, n) + + @contextmanager + def override_node_repr(graph: Graph): + orig_repr_fns = {} + for node in graph.nodes: + orig_repr_fns[node] = node._repr_fn + node._repr_fn = node_repr + try: + yield None + finally: + # restore the original repr functions + for node in graph.nodes: + node._repr_fn = orig_repr_fns[node] + + with override_node_repr(self): + return self._python_code( + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + + def _python_code( + self, + root_module: str, + namespace: _Namespace, + *, + verbose: bool = False, + include_stride: bool = False, + include_device: bool = False, + colored: bool = False, + ) -> PythonCode: + return self._codegen._gen_python_code( + self.nodes, + root_module, + namespace, + verbose=verbose, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + + def __str__(self) -> str: + """ + Return a human-readable (not machine-readable) string representation + of this Graph + """ + placeholder_names: list[str] = [] + # This is a one-element array just so ``format_node`` can modify the closed + # over value + maybe_return_typename: list[str] = [""] + + node_strs = [node.format_node(placeholder_names) for node in self.nodes] + param_str = ", ".join(placeholder_names) + s = f"graph({param_str}){maybe_return_typename[0]}:" + for node_str in node_strs: + if node_str: + s += "\n " + node_str + return s + + @compatibility(is_backward_compatible=True) + def print_tabular(self): + """ + Prints the intermediate representation of the graph in tabular + format. Note that this API requires the ``tabulate`` module to be + installed. + """ + try: + from tabulate import tabulate + except ImportError: + print( + "`print_tabular` relies on the library `tabulate`, " + "which could not be found on this machine. Run `pip " + "install tabulate` to install the library." + ) + raise + + node_specs = [[n.op, n.name, n.target, n.args, n.kwargs] for n in self.nodes] + print( + tabulate(node_specs, headers=["opcode", "name", "target", "args", "kwargs"]) + ) + + @compatibility(is_backward_compatible=True) + def lint(self): + """ + Runs various checks on this Graph to make sure it is well-formed. In + particular: + - Checks Nodes have correct ownership (owned by this graph) + - Checks Nodes appear in topological order + - If this Graph has an owning GraphModule, checks that targets + exist in that GraphModule + """ + + # Check topo order + def check_arg(arg: Node, n: Optional[Node] = None) -> None: + context_str = f" of Node '{n}' " if n else " " + if arg.graph is not self: + raise RuntimeError( + f"Argument '{arg}'{context_str}does not belong to this Graph, " + f"but was used as an argument! If you are copying nodes from another graph, make " + f"sure to use ``arg_transform`` on node_copy() to remap values\n{self}" + ) + if arg not in seen_values: + raise RuntimeError( + f"Argument '{arg}'{context_str}was used before it has been " + f"defined! Please check that Nodes in the graph are topologically ordered\n{self}" + ) + + seen_names: set[str] = set() + seen_values: set[Node] = set() + for node in self.nodes: + if node.op not in _legal_ops: + raise RuntimeError(f"Node {node} had unknown opcode {node.op}!") + if node.graph is not self: + raise RuntimeError(f"Node '{node}' does not belong to this Graph!") + if node not in self._find_nodes_lookup_table: + raise RuntimeError(f"Node '{node}' is not added to the side table") + for arg in node._input_nodes: + check_arg(arg, node) + seen_values.add(node) + + if node.name in seen_names: + raise RuntimeError(f"Node redefined name {node.name}!") + seen_names.add(node.name) + + # Check targets are legit + if self.owning_module: + for node in self.nodes: + if node.op == "call_function": + if not callable(node.target): + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a Callable is expected" + ) + else: + if not isinstance(node.target, str): + raise ValueError( + f"Node {node} target {node.target} has type {torch.typename(node.target)} but " + "a str is expected" + ) + if node.op in ["get_attr", "call_module"]: + target_atoms = node.target.split(".") + m_itr = self.owning_module + for i, atom in enumerate(target_atoms): + new_m_itr = getattr(m_itr, atom, None) + seen_qualname = ".".join(target_atoms[:i]) + if new_m_itr is None: + raise RuntimeError( + f"Node {node} target {node.target} references nonexistent attribute " + f"{atom} of {seen_qualname}" + ) + if node.op == "call_module" and not isinstance( + new_m_itr, torch.nn.Module + ): + raise RuntimeError( + f"Node {node} target {node.target} {atom} of {seen_qualname} does " + "not reference an nn.Module" + ) + + m_itr = new_m_itr + + @compatibility(is_backward_compatible=True) + def eliminate_dead_code( + self, is_impure_node: Optional[Callable[[Node], bool]] = None + ) -> bool: + """ + Remove all dead code from the graph, based on each node's number of + users, and whether the nodes have any side effects. The graph must be + topologically sorted before calling. + + Args: + is_impure_node (Optional[Callable[[Node], bool]]): A function that returns + whether a node is impure. If this is None, then the default behavior is to + use Node.is_impure. + + Returns: + bool: Whether the graph was changed as a result of the pass. + + Example: + + Before dead code is eliminated, `a` from `a = x + 1` below has no users + and thus can be eliminated from the graph without having an effect. + + .. code-block:: python + + def forward(self, x): + a = x + 1 + return x + self.attr_1 + + After dead code is eliminated, `a = x + 1` has been removed, and the rest + of `forward` remains. + + .. code-block:: python + + def forward(self, x): + return x + self.attr_1 + + .. warning:: + + Dead code elimination has some heuristics to avoid removing + side-effectful nodes (see Node.is_impure) but in general coverage + is very bad, so you should assume that this method is not sound + to call unless you know that your FX graph consists entirely + of functional operations or you supply your own custom + function for detecting side-effectful nodes. + """ + from torch.utils._ordered_set import OrderedSet + + # Lint the graph first to make sure its topologically sorted, otherwise + # DCE below will not behave as expected. + self.lint() + + impure_random = True + if torch._guards.TracingContext.try_get(): + impure_random = torch._inductor.config.fallback_random + + def has_side_effect(node): + if is_impure_node is not None: + return is_impure_node(node) + return node.is_impure(impure_random) + + # Reverse iterate so that when we remove a node, any nodes used as an + # input to that node have an updated user count that no longer reflects + # the removed node. + changed = False + for node in reversed(self.nodes): + if not has_side_effect(node) and len(node.users) == 0: + self.erase_node(node) + changed = True + + # Call DCE on the subgraphs + if self.owning_module is not None: + subgraph_names = OrderedSet( + x.target for x in self.find_nodes(op="get_attr") + ) + for child_name, child_module in self.owning_module.named_children(): + # Sometimes an owning_module can have unused children. Skip them + # by checking them from get_attr node targets. + if child_name in subgraph_names and isinstance( + child_module, torch.fx.GraphModule + ): + changed |= child_module.graph.eliminate_dead_code() + child_module.recompile() + + return changed + + @compatibility(is_backward_compatible=False) + def set_codegen(self, codegen: CodeGen): + self._codegen = codegen + + @compatibility(is_backward_compatible=False) + def on_generate_code( + self, + make_transformer: Callable[[Optional[TransformCodeFunc]], TransformCodeFunc], + ): + """Register a transformer function when python code is generated + + Args: + make_transformer (Callable[[Optional[TransformCodeFunc]], TransformCodeFunc]): + a function that returns a code transformer to be registered. + This function is called by `on_generate_code` to obtain the + code transformer. + + This function is also given as its input the currently + registered code transformer (or None if nothing is registered), + in case it is not desirable to overwrite it. This is useful to + chain code transformers together. + + Returns: + a context manager that when used in a `with` statement, to automatically + restore the previously registered code transformer. + + Example: + + .. code-block:: python + + + gm: fx.GraphModule = ... + + + # This is a code transformer we want to register. This code + # transformer prepends a pdb import and trace statement at the very + # beginning of the generated torch.fx code to allow for manual + # debugging with the PDB library. + def insert_pdb(body): + return ["import pdb; pdb.set_trace()\\n", *body] + + + # Registers `insert_pdb`, and overwrites the current registered + # code transformer (given by `_` to the lambda): + gm.graph.on_generate_code(lambda _: insert_pdb) + + # Or alternatively, registers a code transformer which first + # runs `body` through existing registered transformer, then + # through `insert_pdb`: + gm.graph.on_generate_code( + lambda current_trans: ( + lambda body: insert_pdb( + current_trans(body) if current_trans else body + ) + ) + ) + + gm.recompile() + gm(*inputs) # drops into pdb + + + This function can also be used as a context manager, with the benefit to + automatically restores the previously registered code transformer: + + .. code-block:: python + + # ... continue from previous example + + with gm.graph.on_generate_code(lambda _: insert_pdb): + # do more stuff with `gm`... + gm.recompile() + gm(*inputs) # drops into pdb + + # now previous code transformer is restored (but `gm`'s code with pdb + # remains - that means you can run `gm` with pdb here too, until you + # run next `recompile()`). + """ + on_gen_code_old = self._codegen._body_transformer + self._codegen._body_transformer = make_transformer(on_gen_code_old) + + @contextlib.contextmanager + def on_generate_code_context_manager(): + try: + yield + finally: + self._codegen._body_transformer = on_gen_code_old + + return on_generate_code_context_manager() + + +@contextmanager +def _override_sym_repr( + override: Callable[["torch.types.PySymType"], str], +) -> Iterator[None]: + tmp = CodeGen._sym_repr + try: + CodeGen._sym_repr = override + yield + finally: + CodeGen._sym_repr = tmp + + +def _identity(x): + return x + + +def _make_color_fn(code): + def f(s): + reset = "\033[0m" + return f"{code}{s}{reset}" + + return f + + +_color_codes = { + "yellow": "\033[33m", + "cyan": "\033[36m", + "green": "\033[32m", + "blue": "\033[34m", + "red": "\033[31m", + "dim": "\033[2m", + "dim_blue": "\033[2m\033[34m", + "dim_green": "\033[2m\033[32m", +} +_color_fns = {k: _make_color_fn(v) for k, v in _color_codes.items()} +_counter_regexp = re.compile(r"# COUNTER: (\d+)") + + +reflectable_magic_methods = { + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "div": "{} / {}", + "mod": "{} % {}", + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "getitem": "{}[{}]", + "matmul": "{} @ {}", +} + +magic_methods = { + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "pos": "+{}", + "neg": "-{}", + "invert": "~{}", + **reflectable_magic_methods, +} + +inplace_methods = { + "iadd": "{} += {}", + "iand": "{} &= {}", + "ifloordiv": "{} //= {}", + "ilshift": "{} <<= {}", + "imod": "{} %= {}", + "imul": "{} *= {}", + "imatmul": "{} @= {}", + "ior": "{} |= {}", + "ipow": "{} **= {}", + "irshift": "{} >>= {}", + "isub": "{} -= {}", + "itruediv": "{} /= {}", + "ixor": "{} ^= {}", + "setitem": "{}[{}] = {}", +} diff --git a/phivenv/Lib/site-packages/torch/fx/graph_module.py b/phivenv/Lib/site-packages/torch/fx/graph_module.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a363a67fee41584efc72d46a717d0218120b9f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/graph_module.py @@ -0,0 +1,1090 @@ +# mypy: allow-untyped-defs +import contextlib +import copy +import itertools +import linecache +import os +import sys +import traceback +import warnings +from pathlib import Path +from typing import Any, Callable, Optional, Union + +import torch +import torch.nn as nn +import torch.overrides +from torch.nn.modules.module import _addindent +from torch.package import Importer, PackageExporter, PackageImporter, sys_importer + +from ._compatibility import compatibility +from .graph import ( + _custom_builtins, + _is_from_torch, + _override_sym_repr, + _PyTreeCodeGen, + Graph, + PythonCode, +) + + +__all__ = [ + "reduce_graph_module", + "reduce_package_graph_module", + "reduce_deploy_graph_module", + "GraphModule", +] + +_USER_PRESERVED_ATTRIBUTES_KEY = "_user_preserved_attributes" + + +# Normal exec loses the source code, however we can work with +# the linecache module to recover it. +# Using _exec_with_source will add it to our local cache +# and then tools like TorchScript will be able to get source info. +class _EvalCacheLoader: + def __init__(self): + self.eval_cache = {} + self.next_id = 0 + + def cache(self, src: str, globals: dict[str, Any], co_fields=None): + """Store the source in a private cache, and add a lazy entry in linecache + that allows the source to be retrieved by 'filename'. + + Args: + src (str): The module source to cache + globals (dict): The module globals + + Returns: + str: The cache key (and dummy filename) generated for src. + """ + + key = self._get_key() + if co_fields: + key += f" from {co_fields['co_filename']}:{co_fields['co_firstlineno']} in {co_fields['co_name']}" + self.eval_cache[key] = src + + # Don't mutate globals so that this loader is only used + # to populate linecache, and doesn't interact with other modules + # that might check `__loader__` + globals_copy = globals.copy() + globals_copy["__file__"] = key + globals_copy["__name__"] = key + globals_copy["__loader__"] = self + linecache.lazycache(key, globals_copy) + + return key + + # Part of the loader protocol (PEP 302) + # linecache will use this method when trying to find source code + def get_source(self, module_name) -> Optional[str]: + if module_name in self.eval_cache: + return self.eval_cache[module_name] + return None + + def _get_key(self): + key = f".{self.next_id}" + self.next_id += 1 + return key + + +_loader = _EvalCacheLoader() + + +def _exec_with_source(src: str, globals: dict[str, Any], co_fields=None): + key = _loader.cache(src, globals, co_fields) + exec(compile(src, key, "exec"), globals) + + +def _forward_from_src(src: str, globals: dict[str, Any], co_fields=None): + return _method_from_src( + method_name="forward", src=src, globals=globals, co_fields=co_fields + ) + + +def _method_from_src( + method_name: str, src: str, globals: dict[str, Any], co_fields=None +) -> Callable: + # avoid mutating the passed in dict + globals_copy = globals.copy() + _exec_with_source(src, globals_copy, co_fields) + fn = globals_copy[method_name] + del globals_copy[method_name] + return fn + + +def _format_import_statement(name: str, obj: Any, importer: Importer) -> str: + if name in _custom_builtins: + return _custom_builtins[name].import_str + if _is_from_torch(name): + return "import torch" + module_name, attr_name = importer.get_name(obj) + return f"from {module_name} import {attr_name} as {name}" + + +def _format_import_block(globals: dict[str, Any], importer: Importer): + import_strs: set[str] = { + _format_import_statement(name, obj, importer) for name, obj in globals.items() + } + # Sort the imports so we have a stable import block that allows us to + # hash the graph module and get a consistent key for use in a cache. + return "\n".join(sorted(import_strs)) + + +@compatibility(is_backward_compatible=True) +def reduce_graph_module(body: dict[Any, Any], import_block: str) -> torch.nn.Module: + # BC: attribute name was changed from `code` to `_code` to facilitate + # making `code` into a property and adding a docstring to it + fn_src = body.get("_code") or body["code"] + forward = _forward_from_src(import_block + fn_src, {}) + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_package_graph_module( + importer: PackageImporter, body: dict[Any, Any], generated_module_name: str +) -> torch.nn.Module: + forward = importer.import_module(generated_module_name).forward + return _deserialize_graph_module(forward, body) + + +@compatibility(is_backward_compatible=True) +def reduce_deploy_graph_module( + importer: PackageImporter, body: dict[Any, Any], import_block: str +) -> torch.nn.Module: + ns = {} + ns["__builtins__"] = importer.patched_builtins + fn_src = body.get("_code") + assert fn_src is not None + forward = _forward_from_src(import_block + fn_src, ns) + return _deserialize_graph_module(forward, body) + + +# We create a dummy class here because symbolic_trace pulls the forward() +# function off of the class, rather than the instance. This class is used +# in _deserialize_graph_module() below. +class _CodeOnlyModule(torch.nn.Module): + def __init__(self, body): + super().__init__() + self.__dict__ = body + + +def _deserialize_graph_module( + forward, body: dict[Any, Any], graph_module_cls=None +) -> torch.nn.Module: + """ + Deserialize a GraphModule given the dictionary of the original module, + using the code to reconstruct the graph. We delete the actual graph before + saving the dictionary so that changes to the in-memory graph format do not + get serialized. + """ + + # Try to retrieve the forward source in a backward-compatible way + _CodeOnlyModule.forward = forward + + tracer_cls = body.get("_tracer_cls") + if tracer_cls is None: + from ._symbolic_trace import Tracer + + tracer_cls = Tracer + + graphmodule_cls_name = body.get("_graphmodule_cls_name", "GraphModule") + + # This is a workaround for a mypy linter issue related to + # passing base class as an argument - https://github.com/python/mypy/issues/5865. + cls_tracer: Any = tracer_cls + + class KeepModules(cls_tracer): + # we shouldn't trace into any of the submodules, + # because they were not traced in the original GraphModule + def is_leaf_module(self, _: torch.nn.Module, __: str) -> bool: + return True + + com = _CodeOnlyModule(body) + + tracer_extras = body.get("_tracer_extras", {}) + graph = KeepModules().trace(com, **tracer_extras) + + # Recover node.meta["stack_trace"] after re-tracing + node_meta_stack_trace = body.get("_graphmodule_graph_node_meta_stack_trace", None) + if node_meta_stack_trace is not None: + del body["_graphmodule_graph_node_meta_stack_trace"] + for node in graph.nodes: + if node_meta_stack_trace.get(node.name, None) is not None: + node.meta["stack_trace"] = node_meta_stack_trace[node.name] + + # Manually set Tracer class on the reconstructed Graph, to avoid + # referencing the private local subclass KeepModules. + graph._tracer_cls = tracer_cls + from ._lazy_graph_module import _make_graph_module + + gm = _make_graph_module( + com, graph, class_name=graphmodule_cls_name, graph_module_cls=graph_module_cls + ) + + # The GraphModule constructor only retains attributes referenced by the graph. + # In this case, our goal is return a GraphModule as close to identical as the one + # put into the package. If any additional attributes were present in body, + # we should keep them. + for k, v in body.items(): + if not hasattr(gm, k): + setattr(gm, k, v) + return gm + + +# copy an attribute value with qualified name 'target' from 'from_module' to 'to_module' +# This installs empty Modules where none exist yet if they are subpaths of target +def _copy_attr(from_module: torch.nn.Module, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + f = getattr(from_module, item) + t = getattr(to_module, item, None) + if f is t: + # we have already installed one of its parents + # (e.g. target = root.linear.weight, but we have already installed root.linear) + # once we install a parent, we no longer need to copy the children + # since all the needed properties will already be present + return + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + from_module, to_module = f, t + + orig = getattr(from_module, field) + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(orig, torch.Tensor) and not isinstance(orig, torch.nn.Parameter): + to_module.register_buffer(field, orig) + else: + setattr(to_module, field, orig) + + +# Assign attribute 'from_obj' to the qualified name 'target' on 'to_module +# This installs empty Modules where none exist yet if they are subpaths of target +def _assign_attr(from_obj: Any, to_module: torch.nn.Module, target: str): + *prefix, field = target.split(".") + for item in prefix: + t = getattr(to_module, item, None) + + if t is None: + t = torch.nn.Module() + setattr(to_module, item, t) + to_module = t + + # If it is a tensor and not a parameter attribute of a module, it should be a named buffer. + # So, we register it as a named buffer in the target module. + if isinstance(from_obj, torch.Tensor) and not isinstance( + from_obj, torch.nn.Parameter + ): + to_module.register_buffer(field, from_obj) + else: + setattr(to_module, field, from_obj) + + +# Recursively look up target from a graph module. +def _get_attr(model: torch.nn.Module, attr_name: str): + return _get_attr_via_attr_list(model, attr_name.split(".")) + + +def _del_attr(model: torch.nn.Module, attr_name: str): + attr_names = attr_name.split(".") + t = _get_attr_via_attr_list(model, attr_names[:-1]) + return delattr(t, attr_names[-1]) + + +def _get_attr_via_attr_list(model: torch.nn.Module, attr_list: list[str]): + if len(attr_list) == 0: + return model + *prefix, field = attr_list + t = model + for item in prefix: + t = getattr(t, item, None) # type: ignore[assignment] + assert t is not None + + return getattr(t, field) + + +def _has_attr(model: torch.nn.Module, attr_name: str): + *prefix, field = attr_name.split(".") + t = model + for item in prefix: + t = hasattr(t, item) # type: ignore[assignment] + if t is False: + return False + + return hasattr(t, field) + + +def _print_readable( + module, + module_name, + print_output=True, + include_stride=False, + include_device=False, + colored=False, +): + graph = module.graph + assert graph is not None and isinstance(graph, torch.fx.Graph), ( + "print_readable must be used on a module with a graph" + ) + + verbose_python_code = graph.python_code( + root_module="self", + verbose=True, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + module_code = verbose_python_code.src + module_code = module_code.lstrip("\n") + module_code = f"class {module_name}(torch.nn.Module):\n" + module_code + module_code = _addindent(module_code, 4) + + submodule_code_list = [""] + for submodule_name, submodule in module.named_children(): + if hasattr(submodule, "graph"): + submodule_code_list.append( + _print_readable( + submodule, + submodule_name, + print_output=False, + include_stride=include_stride, + include_device=include_device, + colored=colored, + ) + ) + submodule_code = "\n".join(submodule_code_list) + submodule_code = _addindent(submodule_code, 4) + + output = module_code + submodule_code + if print_output: + print(module_code + submodule_code) + return output + + +class _WrappedCall: + def __init__(self, cls, cls_call): + self.cls = cls + self.cls_call = cls_call + + # Previously, if an error occurred when valid + # symbolically-traced code was run with an invalid input, the + # user would see the source of the error as coming from + # `File "`, where N is some number. We use + # this function to generate a more informative error message. We + # return the traceback itself, a message explaining that the + # error occurred in a traced Module's generated forward + # function, and five lines of context surrounding the faulty + # line + @staticmethod + def _generate_error_message(frame_summary: traceback.FrameSummary) -> str: + # auxiliary variables (for readability) + err_lineno = frame_summary.lineno + assert err_lineno is not None + line = frame_summary.line + assert line is not None + err_line_len = len(line) + all_src_lines = linecache.getlines(frame_summary.filename) + + # constituent substrings of the error message + tb_repr = torch._dynamo.disable( + traceback.format_exc, + reason="do not trace into traceback.format_exc when generating error message", + )() + custom_msg = ( + "Call using an FX-traced Module, " + f"line {err_lineno} of the traced Module's " + "generated forward function:" + ) + before_err = "".join(all_src_lines[err_lineno - 2 : err_lineno]) + marker = "~" * err_line_len + "~~~ <--- HERE" + err_and_after_err = "\n".join(all_src_lines[err_lineno : err_lineno + 2]) + + # joined message + return "\n".join([tb_repr, custom_msg, before_err, marker, err_and_after_err]) + + def __call__(self, obj, *args, **kwargs): + try: + if self.cls_call is not None: + return self.cls_call(obj, *args, **kwargs) + else: + return super(self.cls, obj).__call__(*args, **kwargs) # type: ignore[misc] + except Exception as e: + assert e.__traceback__ + topmost_framesummary: traceback.FrameSummary = ( + traceback.StackSummary.extract(traceback.walk_tb(e.__traceback__))[-1] + ) + if "eval_with_key" in topmost_framesummary.filename: + print( + _WrappedCall._generate_error_message(topmost_framesummary), + file=sys.stderr, + ) + raise e.with_traceback(None) # noqa: B904 + else: + raise e + + +@compatibility(is_backward_compatible=True) +class GraphModule(torch.nn.Module): + """ + GraphModule is an nn.Module generated from an fx.Graph. Graphmodule has a + ``graph`` attribute, as well as ``code`` and ``forward`` attributes generated + from that ``graph``. + + .. warning:: + + When ``graph`` is reassigned, ``code`` and ``forward`` will be automatically + regenerated. However, if you edit the contents of the ``graph`` without reassigning + the ``graph`` attribute itself, you must call ``recompile()`` to update the generated + code. + """ + + def __new__(cls: "type[GraphModule]", *args, **kwargs): + # each instance of a graph module needs its own forward method + # so create a new singleton class for each instance. + # it is a subclass of the user-defined class, the only difference + # is an extra layer to install the forward method + + # address issue described at https://github.com/pytorch/pytorch/issues/63883 + # in other words, traverse class hierarchy to fix the redundant class definition problem + for t in cls.__mro__: + c = t.__qualname__.split(".")[-1] + if c != "GraphModuleImpl": + cls = t + break + + class GraphModuleImpl(cls): # type: ignore[misc, valid-type] + pass + + return super().__new__(GraphModuleImpl) + + @compatibility(is_backward_compatible=True) + def __init__( + self, + root: Union[torch.nn.Module, dict[str, Any]], + graph: Graph, + class_name: str = "GraphModule", + ): + """ + Construct a GraphModule. + + Args: + + root (Union[torch.nn.Module, Dict[str, Any]): + ``root`` can either be an nn.Module instance or a Dict mapping strings to any attribute type. + In the case that ``root`` is a Module, any references to Module-based objects (via qualified + name) in the Graph's Nodes' ``target`` field will be copied over from the respective place + within ``root``'s Module hierarchy into the GraphModule's module hierarchy. + In the case that ``root`` is a dict, the qualified name found in a Node's ``target`` will be + looked up directly in the dict's keys. The object mapped to by the Dict will be copied + over into the appropriate place within the GraphModule's module hierarchy. + + graph (Graph): ``graph`` contains the nodes this GraphModule should use for code generation + + class_name (str): ``name`` denotes the name of this GraphModule for debugging purposes. If it's unset, all + error messages will report as originating from ``GraphModule``. It may be helpful to set this + to ``root``'s original name or a name that makes sense within the context of your transform. + """ + super().__init__() + self.__class__.__name__ = class_name + if isinstance(root, torch.nn.Module): + if hasattr(root, "training"): + self.training = root.training + + # When we pickle/unpickle graph module, we don't want to drop any module or attributes. + if isinstance(root, _CodeOnlyModule): + for k, _ in root.named_children(): + _copy_attr(root, self, k) + + for k, _ in root.named_buffers(): + _copy_attr(root, self, k) + + for k, _ in root.named_parameters(): + _copy_attr(root, self, k) + + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + _copy_attr(root, self, node.target) + elif isinstance(root, dict): + targets_to_copy = [] + for node in graph.nodes: + if node.op in ["get_attr", "call_module"]: + assert isinstance(node.target, str) + if node.target not in root: + raise RuntimeError( + "Node " + + str(node) + + " referenced target " + + node.target + + " but that target was not provided in ``root``!" + ) + targets_to_copy.append(node.target) + # Sort targets in ascending order of the # of atoms. + # This will ensure that less deeply nested attributes are assigned + # before more deeply nested attributes. For example, foo.bar + # will be assigned before foo.bar.baz. Otherwise, we might assign + # the user-provided ``foo.bar`` and wipe out the previously-assigned + # ``foo.bar.baz`` + targets_to_copy.sort(key=lambda t: t.count(".")) + for target_to_copy in targets_to_copy: + _assign_attr(root[target_to_copy], self, target_to_copy) + else: + raise RuntimeError("Unsupported type " + str(root) + " passed for root!") + + self.graph = graph + + # Store the Tracer class responsible for creating a Graph separately as part of the + # GraphModule state, except when the Tracer is defined in a local namespace. + # Locally defined Tracers are not pickleable. This is needed because torch.package will + # serialize a GraphModule without retaining the Graph, and needs to use the correct Tracer + # to re-create the Graph during deserialization. + self._tracer_cls = None + if ( + self.graph._tracer_cls + and "" not in self.graph._tracer_cls.__qualname__ + ): + self._tracer_cls = self.graph._tracer_cls + + self._tracer_extras = {} + if self.graph._tracer_extras: + self._tracer_extras = self.graph._tracer_extras + + # Dictionary to store metadata + self.meta: dict[str, Any] = {} + self._replace_hooks: list[Callable] = [] + self._create_node_hooks: list[Callable] = [] + self._erase_node_hooks: list[Callable] = [] + # Used to remove hooks from deepcopied graph modules within a context manager. + self._deepcopy_hooks: list[Callable] = [] + + # TorchScript breaks trying to compile the graph setter because of the + # continued string literal. Issue here: https://github.com/pytorch/pytorch/issues/44842 + # + # Shouldn't be an issue since these methods shouldn't be used in TorchScript anyway + __jit_unused_properties__ = ["graph"] + + @property + def graph(self) -> Graph: + """ + Return the ``Graph`` underlying this ``GraphModule`` + """ + return self._graph + + @graph.setter + def graph(self, g: Graph) -> None: + """ + Set the underlying ``Graph`` for this ``GraphModule``. This will internally + recompile the ``GraphModule`` so that the generated ``forward()`` function + corresponds to ``g`` + """ + assert isinstance(g, Graph), f"Expected a Graph instance, but got {type(g)}" + self._graph = g + g.owning_module = self + self.recompile() + + @compatibility(is_backward_compatible=False) + def to_folder(self, folder: Union[str, os.PathLike], module_name: str = "FxModule"): + """Dumps out module to ``folder`` with ``module_name`` so that it can be + imported with ``from import `` + + Args: + + folder (Union[str, os.PathLike]): The folder to write the code out to + + module_name (str): Top-level name to use for the ``Module`` while + writing out the code + """ + folder = Path(folder) + Path(folder).mkdir(exist_ok=True) + torch.save(self.state_dict(), folder / "state_dict.pt") + tab = " " * 4 + custom_builtins = "\n".join([v.import_str for v in _custom_builtins.values()]) + model_str = f""" +import torch +{custom_builtins} + +from torch.nn import * +class {module_name}(torch.nn.Module): + def __init__(self): + super().__init__() +""" + + def _gen_model_repr(module_name: str, module: torch.nn.Module) -> Optional[str]: + safe_reprs = [ + nn.Linear, + nn.Conv1d, + nn.Conv2d, + nn.Conv3d, + nn.BatchNorm1d, + nn.BatchNorm2d, + nn.BatchNorm3d, + ] + if type(module) in safe_reprs: + return f"{module.__repr__()}" + else: + return None + + blobified_modules = [] + for module_name, module in self.named_children(): + module_str = _gen_model_repr(module_name, module) + if module_str is None: + module_file = folder / f"{module_name}.pt" + torch.save(module, module_file) + blobified_modules.append(module_name) + module_repr = module.__repr__().replace("\r", " ").replace("\n", " ") + # weights_only=False as this is legacy code that saves the model + module_str = ( + f"torch.load(r'{module_file}', weights_only=False) # {module_repr}" + ) + model_str += f"{tab * 2}self.{module_name} = {module_str}\n" + + for buffer_name, buffer in self._buffers.items(): + if buffer is None: + continue + model_str += f"{tab * 2}self.register_buffer('{buffer_name}', torch.empty({list(buffer.shape)}, dtype={buffer.dtype}))\n" # noqa: B950 + + for param_name, param in self._parameters.items(): + if param is None: + continue + model_str += f"{tab * 2}self.{param_name} = torch.nn.Parameter(torch.empty({list(param.shape)}, dtype={param.dtype}))\n" # noqa: B950 + + model_str += ( + f"{tab * 2}self.load_state_dict(torch.load(r'{folder}/state_dict.pt'))\n" + ) + model_str += f"{_addindent(self.code, 4)}\n" + + module_file = folder / "module.py" + module_file.write_text(model_str) + + init_file = folder / "__init__.py" + init_file.write_text("from .module import *") + + if len(blobified_modules) > 0: + warnings.warn( + "Was not able to save the following children modules as reprs -" + f"saved as pickled files instead: {blobified_modules}" + ) + + @compatibility(is_backward_compatible=True) + def add_submodule(self, target: str, m: torch.nn.Module) -> bool: + """ + Adds the given submodule to ``self``. + + This installs empty Modules where none exist yet if they are + subpaths of ``target``. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + m: The submodule itself; the actual object we want to + install in the current Module + + Return: + bool: Whether or not the submodule could be inserted. For + this method to return True, each object in the chain + denoted by ``target`` must either a) not exist yet, + or b) reference an ``nn.Module`` (not a parameter or + other attribute) + """ + *prefix, field = target.split(".") + mod: torch.nn.Module = self + + for item in prefix: + submod = getattr(mod, item, None) + + if submod is None: + submod = torch.nn.Module() + setattr(mod, item, submod) + + if not isinstance(submod, torch.nn.Module): + return False + + mod = submod + + mod.add_module(field, m) + return True + + @compatibility(is_backward_compatible=True) + def delete_submodule(self, target: str) -> bool: + """ + Deletes the given submodule from ``self``. + + The module will not be deleted if ``target`` is not a valid + target. + + Args: + target: The fully-qualified string name of the new submodule + (See example in ``nn.Module.get_submodule`` for how to + specify a fully-qualified string.) + + Returns: + bool: Whether or not the target string referenced a + submodule we want to delete. A return value of ``False`` + means that the ``target`` was not a valid reference to + a submodule. + """ + atoms = target.split(".") + path, target_submod = atoms[:-1], atoms[-1] + mod: torch.nn.Module = self + + # Get the parent module + for item in path: + if not hasattr(mod, item): + return False + + mod = getattr(mod, item) + + if not isinstance(mod, torch.nn.Module): + return False + + if not hasattr(mod, target_submod): + return False + + if not isinstance(getattr(mod, target_submod), torch.nn.Module): + return False + + delattr(mod, target_submod) + return True + + @compatibility(is_backward_compatible=True) + def delete_all_unused_submodules(self) -> None: + """ + Deletes all unused submodules from ``self``. + + A Module is considered "used" if any one of the following is + true: + 1. It has children that are used + 2. Its forward is called directly via a ``call_module`` node + 3. It has a non-Module attribute that is used from a + ``get_attr`` node + + This method can be called to clean up an ``nn.Module`` without + manually calling ``delete_submodule`` on each unused submodule. + """ + used: list[str] = [] + + for node in self.graph.nodes: + if node.op == "call_module" or node.op == "get_attr": + # A list of strings representing the different parts + # of the path. For example, `foo.bar.baz` gives us + # ["foo", "bar", "baz"] + fullpath = node.target.split(".") + + # If we're looking at multiple parts of a path, join + # join them with a dot. Otherwise, return that single + # element without doing anything to it. + def join_fn(x: str, y: str) -> str: + return ".".join([x, y] if y else [x]) + + # Progressively collect all the names of intermediate + # modules. For example, if we have the target + # `foo.bar.baz`, we'll add `foo`, `foo.bar`, and + # `foo.bar.baz` to the list. + used.extend(itertools.accumulate(fullpath, join_fn)) + + # For a `call_module` node, also register all recursive submodules + # as used + if node.op == "call_module": + try: + submod = self.get_submodule(node.target) + + for submod_name, _ in submod.named_modules(): + if submod_name != "": + used.append(".".join([node.target, submod_name])) + except AttributeError: + # Node referenced nonexistent submodule, don't need to + # worry about GCing anything + pass + + to_delete = [name for name, _ in self.named_modules() if name not in used] + + for name in to_delete: + self.delete_submodule(name) + + @property + def code(self) -> str: + """ + Return the Python code generated from the ``Graph`` underlying this + ``GraphModule``. + """ + if not hasattr(self, "_code"): + raise RuntimeError( + "Code has not been generated! Please report a bug to PyTorch" + ) + return self._code + + @compatibility(is_backward_compatible=True) + def recompile(self) -> PythonCode: + """ + Recompile this GraphModule from its ``graph`` attribute. This should be + called after editing the contained ``graph``, otherwise the generated + code of this ``GraphModule`` will be out of date. + """ + if isinstance(self._graph._codegen, _PyTreeCodeGen): + self._in_spec = self._graph._codegen.pytree_info.in_spec + self._out_spec = self._graph._codegen.pytree_info.out_spec + python_code = self._graph.python_code(root_module="self") + self._code = python_code.src + self._lineno_map = python_code._lineno_map + + cls = type(self) + co_fields = self._graph._co_fields if hasattr(self._graph, "_co_fields") else {} + cls.forward = _forward_from_src(self._code, python_code.globals, co_fields) + + # Determine whether this class explicitly defines a __call__ implementation + # to wrap. If it does, save it in order to have wrapped_call invoke it. + # If it does not, wrapped_call can use a dynamic call to super() instead. + # In most cases, super().__call__ should be torch.nn.Module.__call__. + # We do not want to hold a reference to Module.__call__ here; doing so will + # bypass patching of torch.nn.Module.__call__ done while symbolic tracing. + cls_call = cls.__call__ if "__call__" in vars(cls) else None + + if "_wrapped_call" not in vars(cls): + cls._wrapped_call = _WrappedCall(cls, cls_call) # type: ignore[attr-defined] + + def call_wrapped(self, *args, **kwargs): + return self._wrapped_call(self, *args, **kwargs) + + cls.__call__ = call_wrapped # type: ignore[method-assign] + + return python_code + + # Passing Tracer as argument allows subclasses extending fx.GraphModule + # define their own Tracer (extending fx.Tracer). + def __reduce_deploy__(self, importer: Importer): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, importer) + return (reduce_deploy_graph_module, (dict_without_graph, import_block)) + + def __reduce_package__(self, exporter: PackageExporter): + dict_without_graph = self.__dict__.copy() + dict_without_graph["_graphmodule_cls_name"] = self.__class__.__name__ + del dict_without_graph["_graph"] + + # Store node.meta["stack_trace"] so we can recover them after re-tracing during deserialization + node_meta_stack_trace = { + node.name: node.meta["stack_trace"] + for node in self.graph.nodes + if "stack_trace" in node.meta + } + dict_without_graph["_graphmodule_graph_node_meta_stack_trace"] = ( + node_meta_stack_trace + ) + + generated_module_name = f"fx-generated._{exporter.get_unique_id()}" + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, exporter.importer) + module_code = import_block + self.code + exporter.save_source_string(generated_module_name, module_code) + return ( + reduce_package_graph_module, + (dict_without_graph, generated_module_name), + ) + + def __reduce__(self): + """ + Serialization of GraphModule. We serialize only the generated code, not + the underlying ``Graph``. This is because ``Graph`` does not have on-disk + backward-compatibility guarantees, whereas Python source code does. + On the deserialization side, we symbolically trace through the generated + code to regenerate the underlying ``Graph`` + """ + dict_without_graph = self.__dict__.copy() + + python_code = self.recompile() + import_block = _format_import_block(python_code.globals, sys_importer) + del dict_without_graph["_graph"] + return (reduce_graph_module, (dict_without_graph, import_block)) + + def _deepcopy_init(self): + return GraphModule.__init__ + + # because __reduce__ is defined for serialization, + # we need to define deepcopy otherwise it will call __reduce__ + # and cause symbolic tracing to occur every time we try to copy the object + def __deepcopy__(self, memo): + res = type(self).__new__(type(self)) + memo[id(self)] = res + fake_mod = _CodeOnlyModule(copy.deepcopy(self.__dict__, memo)) + self._deepcopy_init()(res, fake_mod, fake_mod.__dict__["_graph"]) + # hooks are lost during `GraphModule.__init__`, so we need to copy over + # them explicitly, note right now we are only copying state_dict related + # hooks, to reduce bc-related issues, we can copy forward/backward related + # hooks in the future as well if needed + extra_preserved_attrs = [ + "_state_dict_hooks", + "_load_state_dict_pre_hooks", + "_load_state_dict_post_hooks", + "_replace_hooks", + "_create_node_hooks", + "_erase_node_hooks", + "_deepcopy_hooks", + ] + for attr in extra_preserved_attrs: + if attr in self.__dict__: + setattr(res, attr, copy.deepcopy(self.__dict__[attr], memo)) + res.meta = copy.deepcopy(getattr(self, "meta", {}), memo) + if _USER_PRESERVED_ATTRIBUTES_KEY in res.meta: + for attr_name, attr in res.meta[_USER_PRESERVED_ATTRIBUTES_KEY].items(): + setattr(res, attr_name, attr) + if hasattr(self, "_deepcopy_hooks"): + for hook in self._deepcopy_hooks: + hook(res) + return res + + def __copy__(self): + from ._lazy_graph_module import _make_graph_module + + res = _make_graph_module(self, self.graph) + res.meta = getattr(self, "meta", {}) + return res + + @compatibility(is_backward_compatible=False) + def print_readable( + self, + print_output=True, + include_stride=False, + include_device=False, + colored=False, + *, + # If `fast_sympy_print` is True then we use a sympy printer which is faster + # but may result in less-readable output. + fast_sympy_print: bool = False, + ): + """ + Return the Python code generated for current GraphModule and its children GraphModules + """ + ctx_mgr = contextlib.ExitStack() + with ctx_mgr: + if fast_sympy_print: + from torch._inductor.utils import sympy_str + + def fast_repr(expr: torch.types.PySymType) -> str: + return sympy_str(expr.node.expr) + + ctx_mgr.enter_context(_override_sym_repr(fast_repr)) + + r = _print_readable( + self, + self._get_name(), + print_output, + include_stride, + include_device, + colored, + ) + return r + + def __str__(self) -> str: + orig_str = super().__str__() + print_readable_reminder = ( + "# To see more debug info, please use `graph_module.print_readable()`" + ) + return "\n".join([orig_str, self._code, print_readable_reminder]) + + def _replicate_for_data_parallel(self): + new_gm = self.__copy__() + new_gm._is_replica = True + return new_gm + + @contextlib.contextmanager + def _set_replace_hook(self, f): + """ + Takes a callable which will be called everytime when we replace a node + to a new node, or change the node's name. Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "Replace hook must be a callable." + self._register_replace_node_hook(f) + try: + yield + finally: + self._unregister_replace_node_hook(f) + + def _register_replace_node_hook(self, f): + """ + Takes a callable which will be called everytime when we replace a node + to a new node, or change the node's name. Callable takes three arguments: + the old node we're changing, and NAME of the new node, followed by the + user node which consumes the old node to be replaced. + """ + assert callable(f), "create_node hook must be a callable." + self._replace_hooks.append(f) + + def _unregister_replace_node_hook(self, f): + """ + Takes a callable which was previously registered to be called everytime when we replace a node. + This function will unregister that callable so it is no longer invoked on node replacement. + """ + assert callable(f), "create_node hook must be a callable." + self._replace_hooks.remove(f) + + def _register_create_node_hook(self, f): + """ + Takes a callable which will be called after we create a new node. The + callable takes the newly created node as input and returns None. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.append(f) + + def _unregister_create_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we create a node. + This function will unregister that callable so it is no longer invoked on node creation. + """ + assert callable(f), "create_node hook must be a callable." + self._create_node_hooks.remove(f) + + def _register_erase_node_hook(self, f): + """ + Takes a callable which will be called after we erase a node. The + callable takes the node that is being erased as input and returns None. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.append(f) + + def _unregister_erase_node_hook(self, f): + """ + Takes a callable which was previously registered to be called after we erase a node. + This function will unregister that callable so it is no longer invoked on node erasure. + """ + assert callable(f), "erase_node hook must be a callable." + self._erase_node_hooks.remove(f) + + def _register_deepcopy_hook(self, f): + """ + Takes a callable which will be called when we deepcopy this graph module. The + callable takes the resulting deepcopied graph module. + """ + assert callable(f), "deepcopy hook must be a callable." + self._deepcopy_hooks.append(f) + + def _unregister_deepcopy_hook(self, f): + """ + Takes a callable which was previously registered to be called after deepcopy. + This function will unregister that callable so it is no longer invoked on deepcopy. + """ + assert callable(f), "deepcopy hook must be a callable." + self._deepcopy_hooks.remove(f) + + +# workarounds for issues in __torch_function__ + +# WAR for __torch_function__ not handling tensor lists, +# fix is in https://github.com/pytorch/pytorch/pull/34725 +# orig_cat = torch.cat +# def patched_cat(*args, **kwargs): +# tensors = args[0] +# for t in tensors: +# if isinstance(t, Proxy): +# return t.__torch_function__(patched_cat, (), args, kwargs) +# return orig_cat(*args, **kwargs) +# patched_cat.__module__ = 'torch' +# patched_cat.__name__ = 'cat' +# torch.cat = patched_cat diff --git a/phivenv/Lib/site-packages/torch/fx/immutable_collections.py b/phivenv/Lib/site-packages/torch/fx/immutable_collections.py new file mode 100644 index 0000000000000000000000000000000000000000..4f4bd21180361942ed397667d739dc269b28bfd9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/immutable_collections.py @@ -0,0 +1,122 @@ +from collections.abc import Iterable +from typing import Any, NoReturn, TypeVar +from typing_extensions import Self + +from torch.utils._pytree import ( + _dict_flatten, + _dict_flatten_with_keys, + _dict_unflatten, + _list_flatten, + _list_flatten_with_keys, + _list_unflatten, + Context, + register_pytree_node, +) + +from ._compatibility import compatibility + + +__all__ = ["immutable_list", "immutable_dict"] + + +_help_mutation = """ +If you are attempting to modify the kwargs or args of a torch.fx.Node object, +instead create a new copy of it and assign the copy to the node: + + new_args = ... # copy and mutate args + node.args = new_args +""".strip() + + +_T = TypeVar("_T") +_KT = TypeVar("_KT") +_VT = TypeVar("_VT") + + +def _no_mutation(self: Any, *args: Any, **kwargs: Any) -> NoReturn: + raise TypeError( + f"{type(self).__name__!r} object does not support mutation. {_help_mutation}", + ) + + +@compatibility(is_backward_compatible=True) +class immutable_list(list[_T]): + """An immutable version of :class:`list`.""" + + __delitem__ = _no_mutation + __iadd__ = _no_mutation + __imul__ = _no_mutation + __setitem__ = _no_mutation + append = _no_mutation + clear = _no_mutation + extend = _no_mutation + insert = _no_mutation + pop = _no_mutation + remove = _no_mutation + reverse = _no_mutation + sort = _no_mutation + + def __hash__(self) -> int: # type: ignore[override] + return hash(tuple(self)) + + def __reduce__(self) -> tuple[type[Self], tuple[tuple[_T, ...]]]: + return (type(self), (tuple(self),)) + + +@compatibility(is_backward_compatible=True) +class immutable_dict(dict[_KT, _VT]): + """An immutable version of :class:`dict`.""" + + __delitem__ = _no_mutation + __ior__ = _no_mutation + __setitem__ = _no_mutation + clear = _no_mutation + pop = _no_mutation + popitem = _no_mutation + setdefault = _no_mutation + update = _no_mutation # type: ignore[assignment] + + def __hash__(self) -> int: # type: ignore[override] + return hash(frozenset(self.items())) + + def __reduce__(self) -> tuple[type[Self], tuple[tuple[tuple[_KT, _VT], ...]]]: + return (type(self), (tuple(self.items()),)) + + +# Register immutable collections for PyTree operations +def _immutable_list_flatten(d: immutable_list[_T]) -> tuple[list[_T], Context]: + return _list_flatten(d) + + +def _immutable_list_unflatten( + values: Iterable[_T], + context: Context, +) -> immutable_list[_T]: + return immutable_list(_list_unflatten(values, context)) + + +def _immutable_dict_flatten(d: immutable_dict[Any, _VT]) -> tuple[list[_VT], Context]: + return _dict_flatten(d) + + +def _immutable_dict_unflatten( + values: Iterable[_VT], + context: Context, +) -> immutable_dict[Any, _VT]: + return immutable_dict(_dict_unflatten(values, context)) + + +register_pytree_node( + immutable_list, + _immutable_list_flatten, + _immutable_list_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_list", + flatten_with_keys_fn=_list_flatten_with_keys, +) +register_pytree_node( + immutable_dict, + _immutable_dict_flatten, + _immutable_dict_unflatten, + serialized_type_name="torch.fx.immutable_collections.immutable_dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) diff --git a/phivenv/Lib/site-packages/torch/fx/interpreter.py b/phivenv/Lib/site-packages/torch/fx/interpreter.py new file mode 100644 index 0000000000000000000000000000000000000000..931101a159a5085263c8be4db7fdfebcdf624341 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/interpreter.py @@ -0,0 +1,603 @@ +# mypy: allow-untyped-defs +import inspect +from contextlib import contextmanager +from typing import Any, Optional, TYPE_CHECKING, Union + +import torch +import torch.fx.traceback as fx_traceback +from torch.hub import tqdm + +from . import config +from ._compatibility import compatibility +from ._lazy_graph_module import _make_graph_module +from ._symbolic_trace import Tracer +from .graph import Graph +from .graph_module import GraphModule +from .node import Argument, map_aggregate, map_arg, Node, Target +from .proxy import Proxy + + +if TYPE_CHECKING: + from collections.abc import Iterator + + +__all__ = ["Interpreter", "Transformer"] + + +@compatibility(is_backward_compatible=True) +class Interpreter: + """ + An Interpreter executes an FX graph Node-by-Node. This pattern + can be useful for many things, including writing code + transformations as well as analysis passes. + + Methods in the Interpreter class can be overridden to customize + the behavior of execution. The map of overrideable methods + in terms of call hierarchy:: + + run() + +-- run_node + +-- placeholder() + +-- get_attr() + +-- call_function() + +-- call_method() + +-- call_module() + +-- output() + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass Interpreter like so:: + + class NegSigmSwapInterpreter(Interpreter): + def call_function( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(target, args, kwargs) + + def call_method(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + if target == "neg": + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(target, args, kwargs) + + + def fn(x): + return torch.sigmoid(x).neg() + + + gm = torch.fx.symbolic_trace(fn) + input = torch.randn(3, 4) + result = NegSigmSwapInterpreter(gm).run(input) + torch.testing.assert_close(result, torch.neg(input).sigmoid()) + + Args: + module (torch.nn.Module): The module to be executed + garbage_collect_values (bool): Whether to delete values after their last + use within the Module's execution. This ensures optimal memory usage during + execution. This can be disabled to, for example, examine all of the intermediate + values in the execution by looking at the ``Interpreter.env`` attribute. + graph (Optional[Graph]): If passed, the interpreter will execute this + graph instead of `module.graph`, using the provided `module` + argument to satisfy any requests for state. + """ + + @compatibility(is_backward_compatible=True) + def __init__( + self, + module: torch.nn.Module, + garbage_collect_values: bool = True, + graph: Optional[Graph] = None, + ): + self.module = module + self.submodules = dict(self.module.named_modules()) + if graph is not None: + self.graph = graph + else: + self.graph = self.module.graph # type: ignore[assignment] + self.env: dict[Node, Any] = {} + self.name = "Interpreter" + self.garbage_collect_values = garbage_collect_values + self.extra_traceback = True + + if self.garbage_collect_values: + # Run through reverse nodes and record the first instance of a use + # of a given node. This represents the *last* use of the node in the + # execution order of the program, which we will use to free unused + # values + node_to_last_use: dict[Node, Node] = {} + self.user_to_last_uses: dict[Node, list[Node]] = {} + + def register_last_uses(n: Node, user: Node): + if n not in node_to_last_use: + node_to_last_use[n] = user + self.user_to_last_uses.setdefault(user, []).append(n) + + for node in reversed(self.graph.nodes): + for n in node._input_nodes: + register_last_uses(n, node) + + @compatibility(is_backward_compatible=True) + def run( + self, + *args, + initial_env: Optional[dict[Node, Any]] = None, + enable_io_processing: bool = True, + ) -> Any: + """ + Run `module` via interpretation and return the result. + + Args: + *args: The arguments to the Module to run, in positional order + initial_env (Optional[Dict[Node, Any]]): An optional starting environment for execution. + This is a dict mapping `Node` to any value. This can be used, for example, to + pre-populate results for certain `Nodes` so as to do only partial evaluation within + the interpreter. + enable_io_processing (bool): If true, we process the inputs and outputs with graph's process_inputs and + process_outputs function first before using them. + + Returns: + Any: The value returned from executing the Module + """ + self.env = initial_env if initial_env is not None else {} + + # Positional function args are consumed left-to-right by + # `placeholder` nodes. Use an iterator to keep track of + # position and extract those values. + if enable_io_processing: + args = self.graph.process_inputs(*args) + self.args_iter: Iterator[Any] = iter(args) + pbar = tqdm( + total=len(self.graph.nodes), + desc=f"{self.name}: {str(list(self.graph.nodes)) if config.verbose_progress else ''}", + initial=0, + position=0, + leave=True, + disable=config.disable_progress, + delay=0, + ) + + for node in self.graph.nodes: + pbar.update(1) + if node in self.env: + # Short circuit if we have this value. This could + # be used, for example, for partial evaluation + # where the caller has pre-populated `env` with + # values for a subset of the program. + continue + + try: + self.env[node] = self.run_node(node) + except Exception as e: + if self.extra_traceback: + msg = f"While executing {node.format_node()}" + msg = f"{e.args[0]}\n\n{msg}" if e.args else str(msg) + if ( + isinstance(self.module, GraphModule) + and self.module.graph is not None + and isinstance(self.module.graph, torch.fx.Graph) + ): + msg += f"\nGraphModule: {self.module.print_readable(print_output=False, include_stride=True)}\n" + msg += f"\nOriginal traceback:\n{node.stack_trace}" + e.args = (msg,) + e.args[1:] + if isinstance(e, KeyError): + raise RuntimeError(*e.args) from e + raise + + if self.garbage_collect_values: + for to_delete in self.user_to_last_uses.get(node, []): + del self.env[to_delete] + + if node.op == "output": + output_val = self.env[node] + return ( + self.graph.process_outputs(output_val) + if enable_io_processing + else output_val + ) + + @compatibility(is_backward_compatible=True) + def boxed_run(self, args_list): + """ + Run `module` via interpretation and return the result. This uses the "boxed" + calling convention, where you pass a list of arguments, which will be cleared + by the interpreter. This ensures that input tensors are promptly deallocated. + """ + args_iter = iter(args_list) + env = {} + for n in self.graph.nodes: + if n.op == "placeholder": + env[n] = next(args_iter) + args_list.clear() + return self.run(initial_env=env) + + @contextmanager + def _set_current_node(self, node): + with fx_traceback.set_current_meta( + node, f"Interpreter_{self.__class__.__name__}" + ): + yield + + @compatibility(is_backward_compatible=True) + def run_node(self, n: Node) -> Any: + """ + Run a specific node ``n`` and return the result. + Calls into placeholder, get_attr, call_function, + call_method, call_module, or output depending + on ``node.op`` + + Args: + n (Node): The Node to execute + + Returns: + Any: The result of executing ``n`` + """ + with self._set_current_node(n): + args, kwargs = self.fetch_args_kwargs_from_env(n) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + return getattr(self, n.op)(n.target, args, kwargs) + + # Main Node running APIs + @compatibility(is_backward_compatible=True) + def placeholder( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``placeholder`` node. Note that this is stateful: + ``Interpreter`` maintains an internal iterator over + arguments passed to ``run`` and this method returns + next() on that iterator. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Returns: + Any: The argument value that was retrieved. + """ + assert isinstance(target, str) + if target.startswith("*"): + # For a starred parameter e.g. `*args`, retrieve all + # remaining values from the args list. + return list(self.args_iter) + else: + try: + return next(self.args_iter) + except StopIteration as si: + if len(args) > 0: + return args[0] + else: + raise RuntimeError( + f"Expected positional argument for parameter {target}, but one was not passed in!" + ) from si + + @compatibility(is_backward_compatible=True) + def get_attr( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``get_attr`` node. Will retrieve an attribute + value from the ``Module`` hierarchy of ``self.module``. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The value of the attribute that was retrieved + """ + assert isinstance(target, str) + return self.fetch_attr(target) + + @compatibility(is_backward_compatible=True) + def call_function( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``call_function`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the function invocation + """ + assert not isinstance(target, str) + + # Execute the function and return the result + return target(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_method( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``call_method`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the method invocation + """ + # args[0] is the `self` object for this method call + self_obj, *args_tail = args + + # Execute the method and return the result + assert isinstance(target, str) + return getattr(self_obj, target)(*args_tail, **kwargs) + + @compatibility(is_backward_compatible=True) + def call_module( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute a ``call_module`` node and return the result. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return + Any: The value returned by the module invocation + """ + # Retrieve executed args and kwargs values from the environment + + # Execute the method and return the result + assert isinstance(target, str) + submod = self.fetch_attr(target) + + return submod(*args, **kwargs) + + @compatibility(is_backward_compatible=True) + def output( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + """ + Execute an ``output`` node. This really just retrieves + the value referenced by the ``output`` node and returns it. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + + Return: + Any: The return value referenced by the output node + """ + return args[0] + + # Helper methods + @compatibility(is_backward_compatible=True) + def fetch_attr(self, target: str): + """ + Fetch an attribute from the ``Module`` hierarchy of ``self.module``. + + Args: + target (str): The fully-qualified name of the attribute to fetch + + Return: + Any: The value of the attribute. + """ + target_atoms = target.split(".") + attr_itr = self.module + for i, atom in enumerate(target_atoms): + if not hasattr(attr_itr, atom): + raise RuntimeError( + f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}" + ) + attr_itr = getattr(attr_itr, atom) + return attr_itr + + @compatibility(is_backward_compatible=True) + def fetch_args_kwargs_from_env(self, n: Node) -> tuple[tuple, dict]: + """ + Fetch the concrete values of ``args`` and ``kwargs`` of node ``n`` + from the current execution environment. + + Args: + n (Node): The node for which ``args`` and ``kwargs`` should be fetched. + + Return: + Tuple[Tuple, Dict]: ``args`` and ``kwargs`` with concrete values for ``n``. + """ + args = self.map_nodes_to_values(n.args, n) + assert isinstance(args, tuple) + kwargs = self.map_nodes_to_values(n.kwargs, n) + assert isinstance(kwargs, dict) + return args, kwargs + + @compatibility(is_backward_compatible=True) + def map_nodes_to_values(self, args: Argument, n: Node) -> Argument: + """ + Recursively descend through ``args`` and look up the concrete value + for each ``Node`` in the current execution environment. + + Args: + args (Argument): Data structure within which to look up concrete values + + n (Node): Node to which ``args`` belongs. This is only used for error reporting. + """ + + def load_arg(n_arg: Node) -> Any: + if n_arg not in self.env: + raise RuntimeError( + f"Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() " + f"to diagnose such issues" + ) + return self.env[n_arg] + + return map_arg(args, load_arg) + + +@compatibility(is_backward_compatible=True) +class Transformer(Interpreter): + """ + ``Transformer`` is a special type of interpreter that produces a + new ``Module``. It exposes a ``transform()`` method that returns + the transformed ``Module``. ``Transformer`` does not require + arguments to run, as ``Interpreter`` does. ``Transformer`` works + entirely symbolically. + + Example: + + Suppose we want to swap all instances of ``torch.neg`` with + ``torch.sigmoid`` and vice versa (including their ``Tensor`` + method equivalents). We could subclass ``Transformer`` like so:: + + class NegSigmSwapXformer(Transformer): + def call_function( + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + ) -> Any: + if target == torch.sigmoid: + return torch.neg(*args, **kwargs) + return super().call_function(target, args, kwargs) + + def call_method( + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], + ) -> Any: + if target == "neg": + call_self, *args_tail = args + return call_self.sigmoid(*args_tail, **kwargs) + return super().call_method(target, args, kwargs) + + + def fn(x): + return torch.sigmoid(x).neg() + + + gm = torch.fx.symbolic_trace(fn) + + transformed: torch.nn.Module = NegSigmSwapXformer(gm).transform() + input = torch.randn(3, 4) + torch.testing.assert_close(transformed(input), torch.neg(input).sigmoid()) + + Args: + module (GraphModule): The ``Module`` to be transformed. + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, module): + super().__init__(module) + self.new_graph = Graph() + self.new_graph.set_codegen(module.graph._codegen) + + class TransformerTracer(Tracer): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.tensor_attrs: dict[torch.Tensor, str] = {} # type: ignore[assignment] + + def is_leaf_module(self, _, __) -> bool: + return True + + self.tracer = TransformerTracer(self.new_graph) + self.tracer.root = module + + @compatibility(is_backward_compatible=True) + def placeholder( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Proxy: + """ + Execute a ``placeholder`` node. In ``Transformer``, this is + overridden to insert a new ``placeholder`` into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + default_value = next(iter(args)) if args else inspect.Signature.empty + return Proxy( + self.new_graph.placeholder(target, default_value=default_value), self.tracer + ) + + @compatibility(is_backward_compatible=True) + def get_attr( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Proxy: + """ + Execute a ``get_attr`` node. In ``Transformer``, this is + overridden to insert a new ``get_attr`` node into the output + graph. + + Args: + target (Target): The call target for this node. See + `Node `__ for + details on semantics + args (Tuple): Tuple of positional args for this invocation + kwargs (Dict): Dict of keyword arguments for this invocation + """ + assert isinstance(target, str) + return self.tracer.create_proxy("get_attr", target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_module( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + # Override so that the leaf module policy from `self.tracer` is respected. + assert isinstance(target, str) + submod = self.fetch_attr(target) + return self.tracer.call_module(submod, submod.forward, args, kwargs) + + @compatibility(is_backward_compatible=True) + def call_function( + self, target: "Target", args: tuple[Argument, ...], kwargs: dict[str, Any] + ) -> Any: + # Override so that functions that were wrapped are still wrapped. + return self.tracer.create_proxy("call_function", target, args, kwargs) + + @compatibility(is_backward_compatible=True) + def transform(self) -> GraphModule: + """ + Transform ``self.module`` and return the transformed + ``GraphModule``. + """ + with fx_traceback.preserve_node_meta(): + result = super().run(enable_io_processing=False) + if result is not None: + + def strip_proxy(a: Union[Argument, Proxy]) -> Any: + return a.node if isinstance(a, Proxy) else a + + new_output_node = self.new_graph.output(map_aggregate(result, strip_proxy)) + # also preserve the metadata from the old output node, if it exists + old_output_node = list(self.graph.nodes)[-1] + assert old_output_node.op == "output" + for k, v in old_output_node.meta.items(): + new_output_node.meta[k] = v + + return _make_graph_module(self.module, self.new_graph) diff --git a/phivenv/Lib/site-packages/torch/fx/node.py b/phivenv/Lib/site-packages/torch/fx/node.py new file mode 100644 index 0000000000000000000000000000000000000000..9de92b57433a2614cf1b8972078a8611379c602b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/node.py @@ -0,0 +1,888 @@ +# Nodes represent a definition of a value in our graph of operators. +import builtins +import inspect +import logging +import operator +import types +from collections.abc import Mapping, Sequence +from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union +from typing_extensions import ParamSpec + +import torch +from torch._C import _fx_map_aggregate, _fx_map_arg, _NodeBase +from torch.fx.operator_schemas import ( + ArgsKwargsPair, + normalize_function, + normalize_module, +) + +from .._ops import ops as _ops +from ._compatibility import compatibility + + +if TYPE_CHECKING: + from .graph import Graph + +__all__ = ["Node", "map_arg", "map_aggregate", "has_side_effect"] + +log = logging.getLogger(__name__) + +BaseArgumentTypes = Union[ + str, + int, + float, + bool, + complex, + torch.dtype, + torch.Tensor, + torch.device, + torch.memory_format, + torch.layout, + torch._ops.OpOverload, + torch.SymInt, + torch.SymBool, + torch.SymFloat, +] +base_types = BaseArgumentTypes.__args__ # type: ignore[attr-defined] + +Target = Union[Callable[..., Any], str] + +Argument = Optional[ + Union[ + tuple["Argument", ...], + Sequence["Argument"], + Mapping[str, "Argument"], + slice, # Slice[Argument, Argument, Argument], but slice is not a templated type in typing + range, + "Node", + BaseArgumentTypes, + ] +] +ArgumentT = TypeVar("ArgumentT", bound=Argument) +_P = ParamSpec("_P") +_R = TypeVar("_R") + +_legal_ops = dict.fromkeys( + [ + "placeholder", + "call_method", + "call_module", + "call_function", + "get_attr", + "output", + "root", + ] +) + +# Dynamo is unable to trace global set[Callable].__contains__. +# See https://github.com/pytorch/pytorch/issues/145761. Since we only have +# a handful of ops so switch to list of callables. +_side_effectful_need_to_be_preserved_pre_dispatch: list[Callable[..., Any]] = [ + torch._C._set_grad_enabled, + torch.amp._enter_autocast, + torch.amp._exit_autocast, +] + +# TODO: Either refactor this into 2 functions 1 dce for functional graphs and 1 dce for all graphs, +# or add logic to correctly mark all inplace ops as side effectful. +_side_effectful_functions: set[Callable[..., Any]] = { + torch._assert, + torch._assert_async, + _ops.aten._assert_async.msg, + _ops.aten._assert_scalar.default, + _ops.aten._assert_tensor_metadata.default, + _ops.aten.sym_constrain_range.default, + _ops.aten.sym_constrain_range_for_size.default, + _ops.profiler._record_function_enter, + _ops.profiler._record_function_enter_new, + _ops.profiler._record_function_exit, + _ops.inductor.accumulate_grad_.default, + operator.setitem, + *_side_effectful_need_to_be_preserved_pre_dispatch, +} + +if hasattr(_ops.inductor, "resize_storage_bytes_"): + _side_effectful_functions.add(_ops.inductor.resize_storage_bytes_.default) + + +@compatibility(is_backward_compatible=False) +def has_side_effect(fn: Callable[_P, _R]) -> Callable[_P, _R]: + _side_effectful_functions.add(fn) + return fn + + +# this is fixed on master, WAR for 1.5 +def _find_module_of_method(orig_method: Callable[..., Any]) -> str: + name = orig_method.__name__ + module = orig_method.__module__ + if module is not None: + return module + for guess in [torch, torch.nn.functional]: + if getattr(guess, name, None) is orig_method: + return guess.__name__ + raise RuntimeError(f"cannot find module for {orig_method}") + + +# Borrowed from CPython typing module +# https://github.com/python/cpython/blob/f90dc36c15d7fee0efaf6d39e97be0bdf2683e93/Lib/typing.py#L156 +def _type_repr(obj: object) -> str: + """Return the repr() of an object, special-casing types (internal helper). + If obj is a type, we return a shorter version than the default + type.__repr__, based on the module and qualified name, which is + typically enough to uniquely identify a type. For everything + else, we fall back on repr(obj). + """ + # Extension: If we don't ignore GenericAlias then `list[int]` will print + # simply "list". + if isinstance(obj, type) and not isinstance(obj, types.GenericAlias): + if obj.__module__ == "builtins": + return obj.__qualname__ + return f"{obj.__module__}.{obj.__qualname__}" + if obj is ...: + return "..." + if isinstance(obj, types.FunctionType): + return obj.__name__ + return repr(obj) + + +def _get_qualified_name(func: Callable[..., Any]) -> str: + # things like getattr just appear in builtins + if getattr(builtins, func.__name__, None) is func: + return func.__name__ + # torch.Tensor.{fn} + if isinstance( + func, (types.MethodDescriptorType, types.WrapperDescriptorType) + ) and func is getattr(torch.Tensor, func.__name__, None): + return f"torch.Tensor.{func.__name__}" + name = func.__name__ + if name == "": + # For lambdas, try to get their defining name in the module + try: + name = inspect.getsource(func).split("=")[0].strip() + except Exception as e: + raise RuntimeError("Unable to represent lambda") from e + module = _find_module_of_method(func) + module = module.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module + # Fixup segment_reduce mismatch + if module == "torch" and name == "segment_reduce": + name = "_" + name + return f"{module}.{name}" + + +def _format_arg(arg: object, max_list_len: float = float("inf")) -> str: + if hasattr(arg, "_custom_fx_repr_fn"): + return arg._custom_fx_repr_fn() + elif isinstance(arg, list): + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + return f"[{items}{maybe_len}]" + elif isinstance(arg, tuple): + items = ", ".join( + _format_arg(a) for idx, a in enumerate(arg) if idx < max_list_len + ) + maybe_len = ( + "" if len(arg) < max_list_len + 1 else f", ...[total_len={len(arg)}]" + ) + maybe_comma = "," if len(arg) == 1 else "" + return f"({items}{maybe_comma}{maybe_len})" + elif isinstance(arg, dict): + items_str = ", ".join(f"{k}: {_format_arg(v)}" for k, v in arg.items()) + return f"{{{items_str}}}" + + if isinstance(arg, Node): + return "%" + str(arg) + else: + return str(arg) + + +@compatibility(is_backward_compatible=True) +class Node(_NodeBase): + """ + ``Node`` is the data structure that represents individual operations within + a ``Graph``. For the most part, Nodes represent callsites to various entities, + such as operators, methods, and Modules (some exceptions include nodes that + specify function inputs and outputs). Each ``Node`` has a function specified + by its ``op`` property. The ``Node`` semantics for each value of ``op`` are as follows: + + - ``placeholder`` represents a function input. The ``name`` attribute specifies the name this value will take on. + ``target`` is similarly the name of the argument. ``args`` holds either: 1) nothing, or 2) a single argument + denoting the default parameter of the function input. ``kwargs`` is don't-care. Placeholders correspond to + the function parameters (e.g. ``x``) in the graph printout. + - ``get_attr`` retrieves a parameter from the module hierarchy. ``name`` is similarly the name the result of the + fetch is assigned to. ``target`` is the fully-qualified name of the parameter's position in the module hierarchy. + ``args`` and ``kwargs`` are don't-care + - ``call_function`` applies a free function to some values. ``name`` is similarly the name of the value to assign + to. ``target`` is the function to be applied. ``args`` and ``kwargs`` represent the arguments to the function, + following the Python calling convention + - ``call_module`` applies a module in the module hierarchy's ``forward()`` method to given arguments. ``name`` is + as previous. ``target`` is the fully-qualified name of the module in the module hierarchy to call. + ``args`` and ``kwargs`` represent the arguments to invoke the module on, *excluding the self argument*. + - ``call_method`` calls a method on a value. ``name`` is as similar. ``target`` is the string name of the method + to apply to the ``self`` argument. ``args`` and ``kwargs`` represent the arguments to invoke the module on, + *including the self argument* + - ``output`` contains the output of the traced function in its ``args[0]`` attribute. This corresponds to the "return" statement + in the Graph printout. + """ + + _args: tuple["Argument", ...] + _kwargs: dict[str, "Argument"] + graph: "Graph" + # unique name of value being created + name: str + # the kind of operation = placeholder|call_method|call_module|call_function|get_attr + op: str + # for method/module/function, the name of the method/module/function/attr + # being invoked, e.g add, layer1, or torch.add + target: "Target" + # All `Node`-valued inputs. Key is the Node, value is don't-care. + # The public API for this is `all_input_nodes`, this private attribute + # should not be accessed directly. + _input_nodes: dict["Node", None] + # All of the nodes that use the value produced by this Node + # Note one user may correspond to several uses, e.g. the node fo ``x + x`` + # would appear once here, but represents two uses. + # Is a dict to act as an "ordered set". Keys are significant, value dont-care + users: dict["Node", None] + # Type expression representing the output value of this node. + # This should contain the same class of Type objects that would appear + # as type annotations for function inputs/outputs. + # + # For placeholder nodes, this value will be used to type-annotate the + # generated function parameters. + # For the return node, this value will be used to type-annotate the + # generated function return type. (Note this is a special case. ``return`` + # does not produce a value, it's more of a notation. Thus, this value + # describes the type of args[0] in the ``return`` node. + type: Optional[Any] + _sort_key: Any + # If set, use this fn to print this node + _repr_fn: Optional[Callable[["Node"], str]] + # Dictionary to store metadata passes need to do their + # transformations. This metadata is preserved across node copies + meta: dict[str, Any] + + @compatibility(is_backward_compatible=True) + def __init__( + self, + graph: "Graph", + name: str, + op: str, + target: "Target", + args: tuple["Argument", ...], + kwargs: dict[str, "Argument"], + return_type: Optional[Any] = None, + ) -> None: + """ + Instantiate an instance of ``Node``. Note: most often, you want to use the + Graph APIs, i.e. ``Graph.call_module``, ``Graph.call_method``, etc. rather + than instantiating a ``Node`` directly. + + Args: + graph (Graph): The ``Graph`` to which this ``Node`` should belong. + + name (str): The name to which the output of this ``Node`` should be assigned + + op (str): The opcode for this ``Node``. Can be one of 'placeholder', + 'call_method', 'call_module', 'call_function', 'get_attr', + 'output' + + target ('Target'): The target this op should call. See the broader + ``Node`` docstring for more details. + + args (Tuple['Argument']): The args to be passed to ``target`` + + kwargs (Dict[str, 'Argument']): The kwargs to be passed to ``target`` + + return_type (Optional[Any]): The python type expression representing the + type of the output of this node. This field can be used for + annotation of values in the generated code or for other types + of analyses. + """ + if op == "call_function": + if not callable(target): + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a Callable is expected" + ) + else: + assert op in _legal_ops + if not isinstance(target, str): + raise ValueError( + f"Node [graph = {graph}, name = '{name}'] target {target} has type {torch.typename(target)} " + "but a str is expected" + ) + super().__init__(graph, name, op, target, return_type) + self._update_args_kwargs(args, kwargs) + + def __getstate__(self) -> dict[str, Any]: + return { + **self.__dict__, + "graph": self.graph, + "name": self.name, + "op": self.op, + "target": self.target, + "type": self.target, + "_sort_key": self._sort_key, + "_args": self._args, + "_kwargs": self._kwargs, + "_erased": self._erased, + "_prev": self._prev, + "_next": self._next, + "_input_nodes": self._input_nodes, + "users": self.users, + "_repr_fn": self._repr_fn, + "meta": self.meta, + } + + def __setstate__(self, state: dict[str, Any]) -> None: + for k, v in state.items(): + setattr(self, k, v) + + @property + def next(self) -> "Node": + """ + Returns the next ``Node`` in the linked list of Nodes. + + Returns: + + The next ``Node`` in the linked list of Nodes. + """ + return self._next + + @property + def prev(self) -> "Node": + """ + Returns the previous ``Node`` in the linked list of Nodes. + + Returns: + + The previous ``Node`` in the linked list of Nodes. + """ + return self._prev + + @compatibility(is_backward_compatible=True) + def prepend(self, x: "Node") -> None: + """ + Insert x before this node in the list of nodes in the graph. Example:: + + Before: p -> self + bx -> x -> ax + After: p -> x -> self + bx -> ax + + Args: + x (Node): The node to put before this node. Must be a member of the same graph. + """ + assert self.graph == x.graph, "Attempting to move a Node into a different Graph" + if self == x: + log.debug( + "Trying to prepend a node to itself. This behavior has no effect on the graph." + ) + return + x._remove_from_list() + p = self._prev + p._next, x._prev = x, p + x._next, self._prev = self, x + + # compute x._sort_key + psk = x._prev._sort_key + nsk = x._next._sort_key + if len(psk) > len(nsk): + idx: int + *prefix, idx = psk[: len(nsk) + 1] + x._sort_key = (*prefix, idx + 1) + elif len(psk) < len(nsk): + *prefix, idx = nsk[: len(psk) + 1] + x._sort_key = (*prefix, idx - 1) + else: # same length, increase length by 1 + x._sort_key = (*psk, 0) + + def __gt__(self, other: "Node") -> bool: + return self._sort_key > other._sort_key + + def __lt__(self, other: "Node") -> bool: + return self._sort_key < other._sort_key + + def __ge__(self, other: "Node") -> bool: + return self > other or self == other + + def __le__(self, other: "Node") -> bool: + return self < other or self == other + + @compatibility(is_backward_compatible=True) + def append(self, x: "Node") -> None: + """ + Insert ``x`` after this node in the list of nodes in the graph. + Equivalent to ``self.next.prepend(x)`` + + Args: + x (Node): The node to put after this node. Must be a member of the same graph. + """ + self._next.prepend(x) + + def _remove_from_list(self) -> None: + p, n = self._prev, self._next + p._next, n._prev = n, p + + @property + def args(self) -> tuple[Argument, ...]: + """ + The tuple of arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._args + + @args.setter + def args(self, a: tuple[Argument, ...]) -> None: + """ + Set the tuple of arguments to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `_update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.args = new_args` + self._update_args_kwargs(a, self._kwargs) + + @property + def kwargs(self) -> dict[str, Argument]: + """ + The dict of keyword arguments to this ``Node``. The interpretation of arguments + depends on the node's opcode. See the :class:`Node` docstring for more + information. + + Assignment to this property is allowed. All accounting of uses and users + is updated automatically on assignment. + """ + return self._kwargs + + @kwargs.setter + def kwargs(self, k: dict[str, Argument]) -> None: + """ + Set the dict of kwargs to this Node. The interpretation of arguments + depends on the node's opcode. See the ``fx.Graph`` docstring for more + information. + """ + # DO NOT CALL `_update_args_kwargs` directly. The correct way to + # set `args` is via direct assignment, i.e. `node.kwargs = new_kwargs` + self._update_args_kwargs(self._args, k) + + @property + def all_input_nodes(self) -> list["Node"]: + """ + Return all Nodes that are inputs to this Node. This is equivalent to + iterating over ``args`` and ``kwargs`` and only collecting the values that + are Nodes. + + Returns: + + List of ``Nodes`` that appear in the ``args`` and ``kwargs`` of this + ``Node``, in that order. + """ + return list(self._input_nodes.keys()) + + @compatibility(is_backward_compatible=True) + def update_arg(self, idx: int, arg: Argument) -> None: + """ + Update an existing positional argument to contain the new value + ``arg``. After calling, ``self.args[idx] == arg``. + + Args: + + idx (int): The index into ``self.args`` of the element to update + arg (Argument): The new argument value to write into ``args`` + """ + args = list(self.args) + args[idx] = arg + self.args = tuple(args) + + @compatibility(is_backward_compatible=True) + def insert_arg(self, idx: int, arg: Argument) -> None: + """ + Insert an positional argument to the argument list with given index. + + Args: + + idx (int): The index of the element in ``self.args`` to be inserted before. + arg (Argument): The new argument value to insert into ``args`` + """ + assert 0 <= idx <= len(self.args), ( + "insert_args index must be between 0 and len(self.args)" + ) + args_left = self.args[:idx] + args_right = self.args[idx:] + + self._args = args_left + (arg,) + args_right + + _new_input_nodes: dict[Node, None] = {} + _fx_map_arg(arg, _new_input_nodes.setdefault) + + for new_use in _new_input_nodes.keys(): + if new_use not in self._input_nodes: + self._input_nodes.setdefault(new_use) + new_use.users.setdefault(self) + + @compatibility(is_backward_compatible=True) + def update_kwarg(self, key: str, arg: Argument) -> None: + """ + Update an existing keyword argument to contain the new value + ``arg``. After calling, ``self.kwargs[key] == arg``. + + Args: + + key (str): The key in ``self.kwargs`` of the element to update + arg (Argument): The new argument value to write into ``kwargs`` + """ + self.kwargs = {**self.kwargs, key: arg} + + @property + def stack_trace(self) -> Optional[str]: + """ + Return the Python stack trace that was recorded during tracing, if any. + When traced with fx.Tracer, this property is usually populated by + `Tracer.create_proxy`. To record stack traces during tracing for debug purposes, + set `record_stack_traces = True` on the `Tracer` instance. + When traced with dynamo, this property will be populated by default by + `OutputGraph.create_proxy`. + + stack_trace would have the innermost frame at the end of the string. + """ + return self.meta.get("stack_trace", None) + + @stack_trace.setter + def stack_trace(self, trace: Optional[str]) -> None: + self.meta["stack_trace"] = trace + + def __repr__(self) -> str: + if self._repr_fn: + return self._repr_fn(self) + return self.name + + @staticmethod + def _pretty_print_target(target: object) -> str: + """ + Make target printouts more user-friendly. + 1) builtins will be printed as `builtins.xyz` + 2) operators will be printed as `operator.xyz` + 3) other callables will be printed with qualified name, e.g. torch.add + """ + if isinstance(target, str): + return target + if hasattr(target, "__module__"): + name = getattr(target, "__name__", None) + if name is None: + # Just to be defensive, if we don't have `__name__`, get the + # qualname. Not sure if this happens for any members of `operator` + # or `builtins`. This fallback path is not as good, since e.g. + # things in `operator` have `_operator` as their __module__. + # TODO: THIS IS BROKEN: _get_qualified_name calls `__name__` + return _get_qualified_name(target) # type: ignore[arg-type] + if target.__module__ == "builtins": + return f"builtins.{name}" + elif target.__module__ == "_operator": + return f"operator.{name}" + return _get_qualified_name(target) # type: ignore[arg-type] + + @compatibility(is_backward_compatible=True) + def format_node( + self, + placeholder_names: Optional[list[str]] = None, + maybe_return_typename: Optional[list[str]] = None, + ) -> Optional[str]: + """ + Return a descriptive string representation of ``self``. + + This method can be used with no arguments as a debugging + utility. + + This function is also used internally in the ``__str__`` method + of ``Graph``. Together, the strings in ``placeholder_names`` + and ``maybe_return_typename`` make up the signature of the + autogenerated ``forward`` function in this Graph's surrounding + GraphModule. ``placeholder_names`` and ``maybe_return_typename`` + should not be used otherwise. + + Args: + placeholder_names: A list that will store formatted strings + representing the placeholders in the generated + ``forward`` function. Internal use only. + maybe_return_typename: A single-element list that will store + a formatted string representing the output of the + generated ``forward`` function. Internal use only. + + Returns: + str: If 1) we're using ``format_node`` as an internal helper + in the ``__str__`` method of ``Graph``, and 2) ``self`` + is a placeholder Node, return ``None``. Otherwise, + return a descriptive string representation of the + current Node. + """ + if self.op == "placeholder": + assert isinstance(self.target, str) + arg_str = self.target + arg_str += arg_str + f": {_type_repr(self.type)}" if self.type else "" + if placeholder_names: + placeholder_names.append(arg_str) + return None + maybe_typename = f"{_type_repr(self.type)} " if self.type else "" + default_val = "(default=" + str(self.args[0]) + ")" if self.args else "" + return f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = {self.op}[target={self.target}]{default_val}" + elif self.op == "get_attr": + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}]" + ) + elif self.op == "output": + if self.type and maybe_return_typename: + maybe_return_typename[0] = f" -> {_type_repr(self.type)}" + return f"return {self.args[0]}" + else: + maybe_typename = ( + f"{_type_repr(self.type)} " if self.type is not None else "" + ) + return ( + f"%{self.name} : {maybe_typename}[num_users={len(self.users)}] = " + f"{self.op}[target={self._pretty_print_target(self.target)}](" + f"args = {_format_arg(self.args)}, kwargs = {_format_arg(self.kwargs)})" + ) + + @compatibility(is_backward_compatible=True) + def replace_all_uses_with( + self, + replace_with: "Node", + delete_user_cb: Callable[["Node"], bool] = lambda user: True, + *, + propagate_meta: bool = False, + ) -> list["Node"]: + """ + Replace all uses of ``self`` in the Graph with the Node ``replace_with``. + + Args: + + replace_with (Node): The node to replace all uses of ``self`` with. + delete_user_cb (Callable): Callback that is called to determine + whether a given user of the self node should be removed. + propagate_meta (bool): Whether or not to copy all properties + on the .meta field of the original node onto the replacement node. + For safety, this is only valid to do if the replacement node + doesn't already have an existing .meta field. + + Returns: + + The list of Nodes on which this change was made. + """ + if propagate_meta: + assert len(replace_with.meta) == 0, ( + "Called node.replace_all_uses_with(replace_with, propagate_meta=True), " + "but replace_with already has .meta keys" + ) + for k, v in self.meta.items(): + replace_with.meta[k] = v + to_process = list(self.users) + skipped = [] + m = self.graph.owning_module + for use_node in to_process: + if not delete_user_cb(use_node): + skipped.append(use_node) + continue + + def maybe_replace_node(n: Node) -> Node: + if n == self: + return replace_with + else: + return n + + if getattr(m, "_replace_hooks", None): + for replace_hook in m._replace_hooks: + replace_hook(old=self, new=replace_with.name, user=use_node) + + new_args = _fx_map_arg(use_node.args, maybe_replace_node) + new_kwargs = _fx_map_arg(use_node.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + use_node._update_args_kwargs(new_args, new_kwargs) + + assert len(self.users) - len(skipped) == 0 + return [n for n in to_process if n not in skipped] + + @compatibility(is_backward_compatible=False) + def is_impure(self, impure_random: bool = True) -> bool: + """ + Returns whether this op is impure, i.e. if its op is a placeholder or + output, or if a call_function or call_module which is impure. + + Args: + impure_random (bool): Whether to treat rand op as impure. + + Returns: + + bool: If the op is impure or not. + """ + if self.op in {"placeholder", "output"}: + return True + + if self.op == "call_function": + schema = getattr(self.target, "_schema", None) + if schema is not None and schema.is_mutable: + # impure since it mutates inputs + return True + + if impure_random: + if getattr(self.target, "_nondeterministic_seeded", False): + # impure since it mutates RNG state + return True + + return self.target in _side_effectful_functions + + # Check if an impure module. + if self.op == "call_module": + assert self.graph.owning_module is not None, ( + "self.graph.owning_module not set for purity check" + ) + target_mod = self.graph.owning_module.get_submodule(self.target) + assert target_mod is not None, ( + f"Did not find expected submodule target {self.target}" + ) + return getattr(target_mod, "_is_impure", False) + + return False + + @compatibility(is_backward_compatible=False) + def normalized_arguments( + self, + root: torch.nn.Module, + arg_types: Optional[tuple[Any]] = None, + kwarg_types: Optional[dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, + ) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to Python targets. This means that + `args/kwargs` will be matched up to the module/functional's + signature and return exclusively kwargs in positional order + if `normalize_to_only_use_kwargs` is true. + Also populates default values. Does not support positional-only + parameters or varargs parameters. + + Supports module calls. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + root (torch.nn.Module): Module upon which to resolve module targets. + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns NamedTuple ArgsKwargsPair, or `None` if not successful. + """ + if self.op == "call_function": + assert callable(self.target) + return normalize_function( + self.target, + self.args, # type: ignore[arg-type] + self.kwargs, + arg_types, + kwarg_types, + normalize_to_only_use_kwargs=normalize_to_only_use_kwargs, + ) + elif self.op == "call_module": + assert isinstance(self.target, str) + return normalize_module( + root, + self.target, + self.args, # type: ignore[arg-type] + self.kwargs, + normalize_to_only_use_kwargs=normalize_to_only_use_kwargs, + ) + + return None + + @compatibility(is_backward_compatible=True) + def replace_input_with(self, old_input: "Node", new_input: "Node") -> None: + """ + Loop through input nodes of ``self``, and replace all instances of + ``old_input`` with ``new_input``. + + Args: + + old_input (Node): The old input node to be replaced. + new_input (Node): The new input node to replace ``old_input``. + """ + + def maybe_replace_node(n: Node) -> Node: + return new_input if n == old_input else n + + m = self.graph.owning_module + if getattr(m, "_replace_hooks", None): + for replace_hook in m._replace_hooks: + replace_hook(old=old_input, new=new_input.name, user=self) + + new_args = _fx_map_arg(self.args, maybe_replace_node) + new_kwargs = _fx_map_arg(self.kwargs, maybe_replace_node) + assert isinstance(new_args, tuple) + assert isinstance(new_kwargs, dict) + self._update_args_kwargs(new_args, new_kwargs) + + def _rename(self, candidate: str) -> None: + if candidate == self.name: + return + name = self.graph._graph_namespace.create_name(candidate, None) + self.name = name + self.graph._graph_namespace._rename_object(self, name) + + def __setattr__(self, name: str, value: Any) -> None: + if name == "name" and hasattr(self, "name"): + m = self.graph.owning_module + if getattr(m, "_replace_hooks", None): + assert isinstance(value, str) + for user in self.users: + for replace_hook in m._replace_hooks: + replace_hook(old=self, new=value, user=user) + update = False + if ( + hasattr(self, name) + and hasattr(self.graph, "_find_nodes_lookup_table") + and self in self.graph._find_nodes_lookup_table + ): + update = True + self.graph._find_nodes_lookup_table.remove(self) + object.__setattr__(self, name, value) + if update: + self.graph._find_nodes_lookup_table.insert(self) + + +@compatibility(is_backward_compatible=True) +def map_arg(a: ArgumentT, fn: Callable[[Node], Argument]) -> ArgumentT: + """ + Apply fn recursively to each Node appearing in arg. + + arg may be a list, tuple, slice, or dict with string keys: the return value will + have the same type and structure. + """ + assert callable(fn), "torch.fx.map_arg(a, fn): fn must be a callable" + return _fx_map_arg(a, fn) + + +@compatibility(is_backward_compatible=True) +def map_aggregate(a: ArgumentT, fn: Callable[[Argument], Argument]) -> ArgumentT: + """ + Apply fn recursively to each object appearing in arg. + + arg may be a list, tuple, slice, or dict with string keys: the return value will + have the same type and structure. + """ + return _fx_map_aggregate(a, fn) diff --git a/phivenv/Lib/site-packages/torch/fx/operator_schemas.py b/phivenv/Lib/site-packages/torch/fx/operator_schemas.py new file mode 100644 index 0000000000000000000000000000000000000000..bd45fbf7f2d8b0ade4b8cf5d54b8165bfd50a0a0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/operator_schemas.py @@ -0,0 +1,566 @@ +# mypy: allow-untyped-defs +import enum +import inspect +import numbers +import types +import typing +import warnings +from typing import Any, Callable, cast, NamedTuple, Optional, TYPE_CHECKING + +import torch +from torch._jit_internal import boolean_dispatched +from torch._ops import OpOverload, OpOverloadPacket + +from ._compatibility import compatibility + + +if TYPE_CHECKING: + from .node import Argument + +__all__ = [ + "ArgsKwargsPair", + "check_for_mutable_operation", + "get_signature_for_torch_op", + "create_type_hint", + "type_matches", + "normalize_function", + "normalize_module", +] + + +@compatibility(is_backward_compatible=False) +class ArgsKwargsPair(NamedTuple): + """ + Simple named tuple for wrapping args/kwargs pairs. + """ + + args: tuple[Any, ...] + kwargs: dict[str, Any] + + +_manual_overrides: dict[Callable, list[inspect.Signature]] = {} + + +def _nonzero_schemas(): + signatures = [] + + def nonzero(self): + pass + + signatures.append(inspect.signature(nonzero)) + + def nonzero(self, *, as_tuple: bool): # type: ignore[no-redef] + pass + + signatures.append(inspect.signature(nonzero)) + + return signatures + + +_manual_overrides[torch.nonzero] = _nonzero_schemas() + + +class _FakeGlobalNamespace: + def __getattr__(self, name): + if name == "torch": + return torch + raise RuntimeError("Expected a torch namespace lookup") + + +_type_eval_globals = { + "Tensor": torch.Tensor, + "Device": torch.device, + "Layout": torch.layout, + "number": numbers.Number, + "Future": torch.jit.Future, + "AnyEnumType": enum.Enum, + "QScheme": torch.qscheme, + "__torch__": _FakeGlobalNamespace(), + "NoneType": type(None), + "Storage": torch.UntypedStorage, + "t": typing.TypeVar("t"), +} +for k in dir(typing): + _type_eval_globals[k] = getattr(typing, k) + + +def _torchscript_type_to_python_type(ts_type: "torch._C.JitType") -> Any: + """ + Convert a TorchScript type to a Python type (including subtypes) via + eval'ing the annotation_str. _type_eval_globals sets up expressions + like "List" and "Future" to map to actual types (typing.List and jit.Future) + """ + return eval(ts_type.annotation_str, _type_eval_globals) + + +def _torchscript_schema_to_signature_impl( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: + from inspect import Parameter + + parameters: list[Parameter] = [] + for arg in ts_schema.arguments: + arg_type = _torchscript_type_to_python_type(arg.type) + default = arg.default_value if arg.has_default_value() else Parameter.empty + # TODO: Figure out if this is safe. It seems like when generating the type signatures for + # PythonArgParser, we emit signatures with `input` instead of `self` as the first tensor + # argument name. Downstream, if someone converts that positional argument to a keyword + # argument, the name mismatch will break things, so here we're going to normalize the + # name to "input" + name = arg.name if arg.name != "self" else "input" + kind = ( + Parameter.KEYWORD_ONLY + if arg.kwarg_only + else Parameter.POSITIONAL_OR_KEYWORD + ) + # "from" is a keyword therefore it must be a POSITIONAL_ONLY argument + if name == "from": + assert kind == Parameter.POSITIONAL_OR_KEYWORD + # ParameterKind type is internal implementation detail to inspec package + # which makes it hard to do type annotation + kind = Parameter.POSITIONAL_ONLY # type: ignore[assignment] + # This renders all previous arguments to positional only + for idx, p in enumerate(parameters): + assert p.kind == Parameter.POSITIONAL_OR_KEYWORD + parameters[idx] = Parameter( + name=p.name, + kind=Parameter.POSITIONAL_ONLY, + default=p.default, + annotation=p.annotation, + ) + parameters.append( + Parameter(name=name, kind=kind, default=default, annotation=arg_type) + ) + return_types = [ + _torchscript_type_to_python_type(ret.type) for ret in ts_schema.returns + ] + if len(return_types) == 0: + return_type = None + elif len(return_types) == 1: + return_type = return_types[0] + else: + return_type = tuple(return_types) + + return inspect.Signature(parameters, return_annotation=return_type) + + +_SCHEMA_TO_SIGNATURE_CACHE: dict[tuple[str, str], inspect.Signature] = {} + + +def _torchscript_schema_to_signature( + ts_schema: torch._C.FunctionSchema, +) -> inspect.Signature: + # Cached as it's called in the hot path of FakeTensor dispatch + cache_key = ts_schema.name, ts_schema.overload_name + cache_val = _SCHEMA_TO_SIGNATURE_CACHE.get(cache_key) + if cache_val is not None: + return cache_val + + res = _torchscript_schema_to_signature_impl(ts_schema) + _SCHEMA_TO_SIGNATURE_CACHE[cache_key] = res + return res + + +@compatibility(is_backward_compatible=False) +def check_for_mutable_operation( + target: Callable, args: tuple["Argument", ...], kwargs: dict[str, "Argument"] +): + signatures, schemas = get_signature_for_torch_op(target, return_schemas=True) + + if signatures and schemas: + matched_schemas = [] + + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature, schema in zip(signatures, schemas): + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append((candidate_signature, schema)) + except TypeError: + continue + + def throw_if_mutable(schema): + if schema.is_mutable: + raise RuntimeError( + f"Tried to trace mutable operation {schema}. FX only supports functional " + f"code, so operations that mutate operands in-place (e.g. via `out` arguments) " + f"are not supported" + ) + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot check for mutation + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + _, schema_to_check = matched_schemas[0] + throw_if_mutable(schema_to_check) + else: + # Ambiguous schema match. Since mutability checking is best effort, + # do nothing. + pass + + +@compatibility(is_backward_compatible=False) +def get_signature_for_torch_op(op: Callable, return_schemas: bool = False): + """ + Given an operator on the `torch` namespace, return a list of `inspect.Signature` + objects corresponding to the overloads of that op.. May return `None` if a signature + could not be retrieved. + + Args: + op (Callable): An operator on the `torch` namespace to look up a signature for + + Returns: + Optional[List[inspect.Signature]]: A list of signatures for the overloads of this + operator, or None if the operator signatures could not be retrieved. If + return_schemas=True, returns a tuple containing the optional Python signatures + and the optional TorchScript Function signature + """ + if isinstance(op, OpOverload): + schemas = [op._schema] + elif isinstance(op, OpOverloadPacket): + schemas = [getattr(op, overload)._schema for overload in op.overloads()] + else: + override = _manual_overrides.get(op) + if override: + return (override, None) if return_schemas else None + + aten_fn = torch.jit._builtins._find_builtin(op) + + if aten_fn is None: + return (None, None) if return_schemas else None + schemas = torch._C._jit_get_schemas_for_operator(aten_fn) + + signatures = [_torchscript_schema_to_signature(schema) for schema in schemas] + return (signatures, schemas) if return_schemas else signatures + + +@compatibility(is_backward_compatible=False) +def create_type_hint(x): + """ + Produces a type hint for the given argument. + + The :func:`create_type_hint` looks for a type hint compatible with the input argument `x`. + + If `x` is a `list` or `tuple`, it looks for an object in the list whose type is a superclass + of the rest, and uses that as `base_type` for the `List` or `Tuple` to be returned. + If no such object is found, it defaults to `List[Any]`. + + If `x` is neither a `list` nor a `tuple`, it returns `x`. + """ + try: + if isinstance(x, (list, tuple)): + # todo(chilli): Figure out the right way for mypy to handle this + if isinstance(x, list): + + def ret_type(x): + return list[x] # type: ignore[valid-type] + + else: + + def ret_type(x): + return tuple[x, ...] # type: ignore[valid-type] + + if len(x) == 0: + return ret_type(Any) + base_type = x[0] + for t in x: + if issubclass(t, base_type): + continue + elif issubclass(base_type, t): + base_type = t + else: + return ret_type(Any) + return ret_type(base_type) + except Exception: + # We tried to create a type hint for list but failed. + warnings.warn( + f"We were not able to successfully create type hint from the type {x}" + ) + return x + + +@compatibility(is_backward_compatible=False) +def type_matches(signature_type: Any, argument_type: Any): + sig_origin_type = getattr(signature_type, "__origin__", signature_type) + + if signature_type is argument_type: + return True + + # Union types in signature. Given type needs to match one of the + # contained types in the Union + if sig_origin_type is typing.Union and signature_type != argument_type: + sig_contained = signature_type.__args__ + return any(type_matches(c, argument_type) for c in sig_contained) + + if getattr(signature_type, "__origin__", None) is list: + sig_el_type = signature_type.__args__[0] + + # int can be promoted to list[int] + if argument_type is int and sig_el_type is int: + return True + + if not inspect.isclass(sig_el_type): + warnings.warn( + f"Does not support nested parametric types, got {signature_type}. Please file a bug." + ) + return False + if getattr(argument_type, "__origin__", None) is list: + return issubclass(argument_type.__args__[0], sig_el_type) + + def is_homogeneous_tuple(t): + if getattr(t, "__origin__", None) is not tuple: + return False + contained = t.__args__ + if t.__args__ == ((),): # Tuple[()].__args__ == ((),) for some reason + return True + return all((c is Ellipsis) or issubclass(c, sig_el_type) for c in contained) + + # Tuple[T] is accepted for List[T] parameters + return is_homogeneous_tuple(argument_type) + + # Dtype is an int in schemas + if signature_type is int and argument_type is torch.dtype: + return True + + if signature_type is numbers.Number and argument_type in {int, float}: + return True + if inspect.isclass(argument_type) and inspect.isclass(signature_type): + return issubclass(argument_type, signature_type) + + return False + + +@compatibility(is_backward_compatible=False) +def normalize_function( + target: Callable, + args: tuple[Any, ...], + kwargs: Optional[dict[str, Any]] = None, + arg_types: Optional[tuple[Any]] = None, + kwarg_types: Optional[dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch functions. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). Does not support modules. + + May require `arg_types` and `kwarg_types` in order to disambiguate overloads. + + Args: + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + arg_types (Optional[Tuple[Any]]): Tuple of arg types for the args + kwarg_types (Optional[Dict[str, Any]]): Dict of arg types for the kwargs + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + if kwargs is None: + kwargs = {} + new_args_and_kwargs = None + if ( + not isinstance(target, types.BuiltinFunctionType) + and not (isinstance(target, (OpOverloadPacket, OpOverload))) + and hasattr(target, "_op") + ): + # ExecuTorch's EdgeOpOverload are a wrapper around PyTorch's OpOverload, + # so we can unwrap it here to get its schema + # Can't import EdgeOpOverload directly because of a circular dependency, + # so checking for "_op" existing is the next best thing. + target = target._op + + # Repeat the condition after checking for the inner _op field. + if not isinstance(target, types.BuiltinFunctionType) and not ( + isinstance(target, (OpOverloadPacket, OpOverload)) + ): + target_for_analysis = target + if target in boolean_dispatched: + # HACK: `boolean_dispatch` as used in `torch.nn.functional` makes it so that we have + # a 2-way dispatch based on a boolean value. Here we check that the `true` and `false` + # branches of the dispatch have exactly the same signature. If they do, use the `true` + # branch signature for analysis. Otherwise, leave this un-normalized + assert not isinstance(target, str) + dispatched = boolean_dispatched[target] + if_true, if_false = dispatched["if_true"], dispatched["if_false"] + if ( + inspect.signature(if_true).parameters + != inspect.signature(if_false).parameters + ): + return None + target_for_analysis = if_true + + assert callable(target_for_analysis) + sig = inspect.signature(inspect.unwrap(target_for_analysis)) + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) + else: + assert callable(target) + torch_op_schemas = get_signature_for_torch_op(target) + matched_schemas = [] + if torch_op_schemas: + # Iterate through all of the schema until we find one that matches + # If one matches, populate `new_args_and_kwargs` with the new args/kwargs + # values. If none matches, `new_args_and_kwargs` will be None + for candidate_signature in torch_op_schemas: + try: + candidate_signature.bind(*args, **kwargs) + matched_schemas.append(candidate_signature) + except TypeError: + continue + + if len(matched_schemas) == 0: + # Did not match any schema. Cannot normalize + pass + elif len(matched_schemas) == 1: + # Matched exactly one schema, unambiguous + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + matched_schemas[0], args, kwargs, normalize_to_only_use_kwargs + ) + else: + if arg_types is not None or kwarg_types is not None: + arg_types = arg_types if arg_types else cast(tuple[Any], ()) + kwarg_types = kwarg_types if kwarg_types else {} + for candidate_signature in torch_op_schemas: + sig_matches = True + try: + bound_types = candidate_signature.bind( + *arg_types, **kwarg_types + ) + for arg_name, arg_type in bound_types.arguments.items(): + param = candidate_signature.parameters[arg_name] + sig_matches = sig_matches and type_matches( + param.annotation, arg_type + ) + except TypeError: + sig_matches = False + if sig_matches: + new_args_and_kwargs = ( + _args_kwargs_to_normalized_args_kwargs( + candidate_signature, + args, + kwargs, + normalize_to_only_use_kwargs, + ) + ) + break + else: + # Matched more than one schema. In this situation, the caller must provide the types of + # the arguments of the overload they expect. + schema_printouts = "\n".join( + str(schema) for schema in matched_schemas + ) + raise RuntimeError( + f"Tried to normalize arguments to {torch.typename(target)} but " + f"the schema match was ambiguous! Please provide argument types to " + f"the normalize_arguments() call. Available schemas:\n{schema_printouts}" + ) + + return new_args_and_kwargs + + +@compatibility(is_backward_compatible=False) +def normalize_module( + root: torch.nn.Module, + target: str, + args: tuple[Any], + kwargs: Optional[dict[str, Any]] = None, + normalize_to_only_use_kwargs: bool = False, +) -> Optional[ArgsKwargsPair]: + """ + Returns normalized arguments to PyTorch modules. This means that + `args/kwargs` will be matched up to the functional's + signature and return exclusively kwargs in positional order if + `normalize_to_only_use_kwargs` is True. + Also populates default values. Does not support positional-only + parameters or varargs parameters (*args, **kwargs). + + Args: + root (nn.Module): root module upon which we query modules + target (Callable): Function that we are normalizing + args (Tuple[Any]): Tuple of args to the function + kwargs (Optional[Dict[str, Any]]): Dict of kwargs to the function + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Returns normalized_args_and_kwargs, or `None` if not successful. + """ + try: + submod = root.get_submodule(target) + except AttributeError as e: + raise RuntimeError( + f"Tried to normalize node with target {target} but root did not " + f"have that target!" + ) from e + if hasattr(submod.__class__, "__name__"): + classname = submod.__class__.__name__ + if getattr(torch.nn, classname, None) == submod.__class__: + sig = inspect.signature(inspect.unwrap(submod.forward)) + if kwargs is None: + kwargs = {} + new_args_and_kwargs = _args_kwargs_to_normalized_args_kwargs( + sig, args, kwargs, normalize_to_only_use_kwargs + ) + return new_args_and_kwargs + return None + + +def _args_kwargs_to_normalized_args_kwargs( + sig: inspect.Signature, + args: tuple[Any, ...], + kwargs: dict[str, Any], + normalize_to_only_use_kwargs: bool, +) -> Optional[ArgsKwargsPair]: + """ + Given a call target, args, and kwargs, return the arguments normalized into + an ArgsKwargsPair, or None if the type signature is not supported by + this normalization. + + Args: + + sig (inspect.Signature): Signature object for the target + args (Tuple): Arguments that appear at the callsite for `target` + kwargs (Dict): Keyword arguments that appear at the callsite for `target` + normalize_to_only_use_kwargs (bool): Whether to normalize to only use kwargs. + + Returns: + + Optional[ArgsKwargsPair]: Normalized args and kwargs for `target`, or `None` if + this target is not supported. + """ + + # Don't currently support positional-only + # or varargs (*args, **kwargs) signatures + supported_parameter_types = { + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + } + if any(p.kind not in supported_parameter_types for p in sig.parameters.values()): + # Add an exception for one signature, which is common for random/uniform, i.e.: + # Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None + # `from` is Python keyword and as such functions with that signature should have + # positional-only args, but at the same time they could be dispatched as kwargs + if list(sig.parameters.keys()) != ["input", "from", "to", "generator"]: + return None + + bound_args = sig.bind(*args, **kwargs) + bound_args.apply_defaults() + + new_kwargs: dict[str, Any] = {} + new_args: list[Any] = [] + for i, param in enumerate(sig.parameters): + if not normalize_to_only_use_kwargs and i < len(args): + new_args.append(bound_args.arguments[param]) + else: + new_kwargs[param] = bound_args.arguments[param] + + return ArgsKwargsPair(tuple(new_args), new_kwargs) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__init__.py b/phivenv/Lib/site-packages/torch/fx/passes/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..49f17907825e064a855c4001e023a0b2749a7c73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/__init__.py @@ -0,0 +1,14 @@ +from . import ( + graph_drawer, + graph_manipulation, + net_min_base, + operator_support, + param_fetch, + reinplace, + runtime_assert, + shape_prop, + split_module, + split_utils, + splitter_base, + tools_common, +) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d837b0497264132f0fd1d1fecf18ec2e2fe8055 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5edc12a8bb9e42063940f846f5a81165b1b1d59d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/_tensorify_python_scalars.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4658eb21ded20835833112bd3b034c292b54312c Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/annotate_getitem_nodes.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab63792d6f3147ab6d471a0753a58eb393afd672 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/fake_tensor_prop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e04101c72ffd6ceaf00b418281949536d3839be Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_drawer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf27494eb8278cfb9cfc223c8ae9520f51bd3acf Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_manipulation.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67210a4d5e40289ffbd4665f3ae22130d1446a04 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/graph_transform_observer.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4025ae54ae56195f8f8d0c152a544695ba0602f9 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/net_min_base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af23c97938db71ad4602a478f96b02601e3a8775 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/operator_support.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59f235455880ee79f6051ce41b7a0fece4d92776 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/param_fetch.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d164f622fd2c63fecaf0f196484c6c34595c81a8 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/pass_manager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef64811ac85cce09effd9cd71be3506e7637d18e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/reinplace.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3b34504945527382db56b92b3eb46a2438f3e5ed Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/runtime_assert.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e21230b5424f7ecadb4cca30b87d471440f219ca Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/shape_prop.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/split_module.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/split_module.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c75c04330c7b953a71b47aa2ee2e4255477a3cb Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/split_module.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..743c1ea28a0716325c9f261459a0d7387387a9d4 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/split_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d93330d22ed5bb6a6796924bdc3d1a7b0a8d517 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/splitter_base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4b1760daed19de31cb655a2b9075380545af29d Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/__pycache__/tools_common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/_tensorify_python_scalars.py b/phivenv/Lib/site-packages/torch/fx/passes/_tensorify_python_scalars.py new file mode 100644 index 0000000000000000000000000000000000000000..29faefaf1aa722fbb4e7457876e0fcfeab67d0a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/_tensorify_python_scalars.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +import logging +import os +from typing import Any, Union + +from sympy import Integer, Number, Symbol +from sympy.logic.boolalg import BooleanAtom + +import torch +import torch.fx as fx +from torch._dynamo.exc import TensorifyScalarRestartAnalysis +from torch._dynamo.symbolic_convert import TensorifyState +from torch._dynamo.utils import get_metrics_context +from torch._prims_common import get_computation_dtype +from torch._subclasses import fake_tensor # noqa: TCH001 +from torch._subclasses.fake_tensor import FakeTensor +from torch._utils_internal import justknobs_check +from torch.fx._utils import lazy_format_graph_code +from torch.fx.experimental.symbolic_shapes import ( # noqa: TCH001 + guard_scalar, + has_free_symbols, + ShapeEnv, +) +from torch.fx.graph_module import GraphModule # noqa: TCH001 + +# TODO: refactor +from torch.fx.passes.runtime_assert import _get_sym_val +from torch.fx.proxy import MetaProxy +from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp +from torch.utils._sympy.reference import TensorReferenceAnalysis +from torch.utils._sympy.symbol import symbol_is_type, SymT + + +__all__: list[str] = [] + +log = logging.getLogger(__name__) +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose") + +# The general shape of this transformation is to look for Tensor operations +# that take a backed SymFloat as an argument, and then redo them as tensor +# compute (with ints and tensors as inputs). For example, add(Tensor, Scalar) +# can be translated into add(Tensor, Tensor). Because Dynamo has already +# arranged for floats to be Tensor inputs to the graph, for typical float +# compute you can entirely translate the Python float operations into Tensor +# operations with only Tensor inputs. +# +# This pass is also responsible for doing CSE on the fly as we do this, since +# you don't want to keep recomputing the same quantity over and over again if +# it's used multiple times. +# +# This pass runs on the JOINT graph produced by AOT Autograd, prior to partitioning. +# The primary goal of this pass is to eliminate floats by replacing TensorScalar +# operations with TensorTensor operations and then Dead Code Elimination (DCE) of +# the item calls, which effectively removes the floats. +# +# This needs to happen before partitioning because it influences partitioning decisions, +# specifically by ensuring that we don't need to save floats across partitions. +# Additionally, there is a separate pass that changes which device computations +# occur on. That pass must be run after this one, but still before partitioning. +# +# HISTORY NOTE: Originally, I wanted to formulate this pass as pushing item() +# calls down, transforming float compute into int compute as we went. If you +# manage to eliminate all float compute, this ends up being equivalent, but +# there is a critical difference when some floats cannot be eliminated: when +# we call item() on them, what should it's SymFloat be? Ideally, it would +# be the same backed SymFloat we had before. But without symbolic expresssion +# propogation on tensor quantities, repropagating would instead give you an +# unbacked SymFloat. Maybe it is a good idea to implement symbolic propagation +# on 0d scalar tensors, but I decided to go for something simpler to start. +# +# The boring stuff: +# +# * What operators can I Tensor-ify? (Anything with a Scalar argument) +# * How do I Tensor-ify a SymFloat sympy expression (Sympy -> Op Handler -> Tensor) +# +# TODO: make sure this runs before CPU->CUDA pass for cudagraph friendliness + + +SUPPORTED_OPS = { + torch.ops.aten.mul.Tensor: torch.ops.aten.mul.Tensor, + torch.ops.aten.add.Tensor: torch.ops.aten.add.Tensor, + torch.ops.aten.sub.Tensor: torch.ops.aten.sub.Tensor, + torch.ops.aten.div.Tensor: torch.ops.aten.div.Tensor, + torch.ops.aten.gt.Scalar: torch.ops.aten.gt.Tensor, + torch.ops.aten.lt.Scalar: torch.ops.aten.lt.Tensor, + torch.ops.aten.ge.Scalar: torch.ops.aten.ge.Tensor, + torch.ops.aten.le.Scalar: torch.ops.aten.le.Tensor, + torch.ops.aten.eq.Scalar: torch.ops.aten.eq.Tensor, + torch.ops.aten.ne.Scalar: torch.ops.aten.ne.Tensor, +} + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def tensorify_python_scalars( + gm: GraphModule, shape_env: ShapeEnv, fake_mode: fake_tensor.FakeTensorMode +) -> None: + """ + Converts Python scalar operations into Tensor operations within the graph. This pass looks for + Tensor operations that involve SymFloat arguments and transforms them into equivalent operations + that use only Tensor inputs. + + Args: + gm: The FX graph module representing the computation graph. + shape_env: The shape environment responsible for symbolic shape tracking and propagation + during graph transformations. + + Returns: + None + """ + import sympy + + knob = True + if (env := os.getenv("TENSORIFY_PYTHON_SCALARS")) is not None: + if env in ("0", "FALSE"): + knob = False + else: + knob = justknobs_check("pytorch/compiler:tensorify_python_scalars") + if not knob: + return None + + graph = gm.graph + tracer = fx.proxy.GraphAppendingTracer(graph) + expr_to_sym_proxy: dict[sympy.Expr, MetaProxy] = {} + expr_to_tensor_proxy: dict[sympy.Expr, MetaProxy] = {} + tensorified_symbols: set[sympy.Symbol] = set() + should_restart = False + + first_non_placeholder = None + placeholders = set() + for node in graph.nodes: + if node.op != "placeholder": + first_non_placeholder = node + break + else: + placeholders.add(node) + + Analysis = TensorReferenceAnalysis + + def _sympy_interp(expr: sympy.Expr) -> MetaProxy: + # sympy_interp() with hash consing, and special handling for + # generating constants correctly + + # hash cons + if isinstance(expr, Symbol) and expr not in expr_to_tensor_proxy: + # This is guaranteed to be populated by invariant established by + # insert_deferred_runtime_asserts + expr_to_tensor_proxy[expr] = torch.ops.aten.scalar_tensor.default( + expr_to_sym_proxy[expr] + ) + + # cache constants, why not + if isinstance(expr, (Integer, Number, BooleanAtom)): + dtype = None + c: Union[bool, int, float] + if isinstance(expr, BooleanAtom): + dtype = torch.bool + c = bool(expr) + elif isinstance(expr, sympy.Integer): + dtype = torch.int64 + c = int(expr) + elif isinstance(expr, sympy.Number): + dtype = torch.float64 + c = float(expr) + + node = graph.call_function( + torch.ops.aten.scalar_tensor.default, (c,), {"dtype": dtype} + ) + with fake_mode: + node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype) + expr_to_tensor_proxy[expr] = MetaProxy( + node, + tracer=tracer, + fake_mode=fake_mode, + ) + + if expr in expr_to_tensor_proxy: + return expr_to_tensor_proxy[expr] + + # don't cache + if isinstance(expr, Symbol): + return sympy_interp(Analysis, expr_to_tensor_proxy, expr) # type: ignore[arg-type] + + # hash cons on arguments, run expr handler + expr_to_tensor_proxy[expr] = _run_sympy_handler( + Analysis, + [_sympy_interp(arg) for arg in expr.args], # type: ignore[arg-type] + expr, + ) + + return expr_to_tensor_proxy[expr] + + failed_tensorify_ops: set[str] = set() + nodes = list(graph.nodes) + for i, node in enumerate(nodes[:-1]): + with graph.inserting_before( + nodes[i + 1] if node not in placeholders else first_non_placeholder + ): + # Look for tensor.item() calls on placeholders + if ( + node is not None + and node.op == "call_function" + and node.target is torch.ops.aten._local_scalar_dense.default + ): + dtype = node.args[0].meta["val"].dtype + if dtype != torch.float64: + continue + + assert isinstance(node.args[0], fx.Node), node.args[0] + + s = node.meta["val"].node.expr + expr_to_tensor_proxy[s] = MetaProxy( + node.args[0], tracer=tracer, fake_mode=fake_mode + ) + expr_to_sym_proxy[s] = MetaProxy( + node, tracer=tracer, fake_mode=fake_mode + ) + elif (sym_expr := _get_sym_val(node)) is not None: + if sym_expr not in expr_to_sym_proxy and not isinstance( + sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) + ): + expr_to_sym_proxy[sym_expr] = MetaProxy( + node, tracer=tracer, fake_mode=fake_mode + ) + + # Specialize all dimensions that contain symfloats. Here's + # an example test that requires this: + # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950 + val = node.meta.get("val") + if isinstance(val, FakeTensor): + for dim in val.shape: + if isinstance(dim, torch.SymInt): + for s in dim.node.expr.free_symbols: + name = str(s) + if symbol_is_type( + s, SymT.FLOAT + ) and not TensorifyState.should_specialize(name): + # In principle, we could support float input that + # is used to do size compute. The problem is that + # we don't actually want to tensorify the compute + # in this case, which means we need codegen support for + # all symfloats. + TensorifyState.specialize(name) + should_restart = True + + # Look for functions to convert + if node.op == "call_function" and ( + replacement_op := SUPPORTED_OPS.get(node.target) + ): + args: list[Any] = [] + transform = False + compute_dtype = get_computation_dtype(node.meta["val"].dtype) + + for a in node.args: + if ( + isinstance(a, fx.Node) + and "val" in a.meta + and isinstance(zf := a.meta["val"], torch.SymFloat) + ): + transform = True + try: + proxy = _sympy_interp(zf.node.expr) + except NotImplementedError: + transform = False + break + + # We use _expr instead of expr b/c we want the symbol not the replacement + tensorified_symbols.add(a.meta["val"].node._expr) + + # The upcasting is irrelevant when the compute dtype is bool. This happens + # in cases where we are tensorifying a comparison operator such as + # torch.ops.aten.gt.Tensor + if ( + compute_dtype != torch.bool + and proxy.node.meta["val"].dtype != compute_dtype + ): + proxy = torch.ops.prims.convert_element_type.default( + proxy, compute_dtype + ) + + args.append(proxy) + elif isinstance(a, fx.Node): + args.append(MetaProxy(a, tracer=tracer, fake_mode=fake_mode)) + else: + args.append(a) + + if transform: + replacement_proxy = replacement_op(*args) + + if compute_dtype != node.meta["val"].dtype: + replacement_proxy = ( + torch.ops.prims.convert_element_type.default( + replacement_proxy, + node.meta["val"].dtype, + ) + ) + + node.replace_all_uses_with(replacement_proxy.node) + graph.erase_node(node) + + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.set( + "tensorify_float_success", True, overwrite=True + ) + else: + for a in node.args: + if ( + isinstance(a, fx.Node) + and "val" in a.meta + and isinstance(zf := a.meta["val"], torch.SymFloat) + ): + failed_tensorify_ops.update(str(node.target)) + log.info("Failed to tensorify %s", str(node.target)) + + # Now do one more pass that specializes all symfloats we didn't manage + # to tensorify away. + for node in reversed(graph.nodes): + if node.op == "output" or node.op == "placeholder": + continue + + with graph.inserting_before(node): + if len(node.users) == 0 and not node.is_impure(): + graph.erase_node(node) + continue + + if isinstance( + (val := node.meta.get("val")), + (torch.SymFloat, torch.SymInt, torch.SymBool), + ): + if has_free_symbols(val.node.expr) and all( + symbol_is_type(s, SymT.FLOAT) for s in val.node.expr.free_symbols + ): + # If all symbols are backed symfloats, we can just specialize the whole node + # and get more precise guards. eg. + # + # zf = a.item() + # zf2 = zf // 2 + # op(.. zf2 ..) + # + # It's better to guard on zf // 2 == 2.0 than zf == 5.0 + + node.replace_all_uses_with(guard_scalar(val)) + graph.erase_node(node) + + # Sometimes by the time we get to tensorify, there have already been + # specializations, eg. in python_arg_parser.h. In these cases, + # placeholder nodes no longer have a reference to their original + # symfloat and thus we need to deduce specializations have happend + # via shape_env.replacements. NB: there's an important invariant here + # that symfloats keep consistent names across restarts. + for k, v in shape_env.var_to_val.items(): + if symbol_is_type(k, SymT.FLOAT) and isinstance(v, sympy.core.numbers.Float): + name = str(k) + if ( + not TensorifyState.should_specialize(name) + and k not in tensorified_symbols + ): + TensorifyState.specialize(name) + should_restart = True + + if should_restart: + # Sledgehammer time. Restart dynamo analysis, keeping track of which input sources + # are no longer needed and should be specialized. Restarting analysis is necessary + # because we need to instruct Dynamo to NOT make these as inputs. + metrics_context = get_metrics_context() + if metrics_context.in_progress(): + metrics_context.set( + "tensorify_float_failure", failed_tensorify_ops, overwrite=True + ) + metrics_context.set("tensorify_float_success", True, overwrite=True) + raise TensorifyScalarRestartAnalysis + + graph_code_log.debug( + "%s", lazy_format_graph_code("tensorify_python_scalars", gm, colored=True) + ) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/annotate_getitem_nodes.py b/phivenv/Lib/site-packages/torch/fx/passes/annotate_getitem_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..9a2a6090c05cc08920b396be38118a1bef63bcd3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/annotate_getitem_nodes.py @@ -0,0 +1,59 @@ +import operator + +import torch + + +def annotate_getitem_nodes(graph: torch.fx.Graph) -> None: + """ + Annotate the type of getitem nodes, inferred from the type of sequence node. + If sequence node is not annotated with a type, do nothing. + Currently support getitem nodes from tuple, list, and NamedTuple sequence node. + + This is helpful since annotations on local names within function are lost during FX transforms. + Adding back known type annotation for getitem nodes to improve jit scriptability. + + Args: + graph (Graph): The graph to be annotated + """ + for node in graph.nodes: + if node.target == operator.getitem: + sequence_node, index_node = node.args + if not sequence_node.type: + continue + # container types + if hasattr(sequence_node.type, "_name"): + parameterized_types = sequence_node.type.__args__ + if sequence_node.type._name == "Tuple": + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) + ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + elif sequence_node.type._name == "List": + assert len(parameterized_types) == 1 + node.type = parameterized_types[0] + # Generic Alias Type + elif hasattr(sequence_node.type, "__origin__"): + parameterized_types = sequence_node.type.__args__ + if sequence_node.type.__origin__ is tuple: + if len(parameterized_types) == 2 and isinstance( + parameterized_types[1], type(...) + ): + node.type = parameterized_types[0] + else: + assert len(parameterized_types) > index_node + node_type = parameterized_types[index_node] + node.type = node_type + elif sequence_node.type.__origin__ is list: + assert len(parameterized_types) == 1 + node.type = parameterized_types[0] + # NamedTuple type + elif hasattr(sequence_node.type, "__annotations__"): + if sequence_node.type == torch.Tensor: + continue + sequence_node_field_types = sequence_node.type.__annotations__ + field_name = sequence_node.type._fields[index_node] + node.type = sequence_node_field_types[field_name] diff --git a/phivenv/Lib/site-packages/torch/fx/passes/backends/__init__.py b/phivenv/Lib/site-packages/torch/fx/passes/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516fda26db8ba052f1048b933bb25dda72c90b5e Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/backends/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d280dba1c866c844f943ab9b7caf536284d80d47 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/backends/__pycache__/cudagraphs.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/backends/cudagraphs.py b/phivenv/Lib/site-packages/torch/fx/passes/backends/cudagraphs.py new file mode 100644 index 0000000000000000000000000000000000000000..dae93547cda5107d4f3ee0a83854cd3188e7e3fe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/backends/cudagraphs.py @@ -0,0 +1,61 @@ +# mypy: allow-untyped-defs +import operator + +import torch +from torch.fx.passes.fake_tensor_prop import FakeTensorProp +from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner +from torch.fx.passes.operator_support import OperatorSupport +from torch.fx.passes.tools_common import CALLABLE_NODE_OPS +from torch.utils import _pytree as pytree + + +class CudaGraphsSupport(OperatorSupport): + # TODO: why is submodules passed here + def is_node_supported(self, submodules, node: torch.fx.Node) -> bool: + if node.op not in CALLABLE_NODE_OPS: + return False + + if node.target in [torch.ops.aten.embedding_dense_backward.default]: + return False + + if node.target in [operator.getitem]: + return True + + found_not_cuda = False + + def meta_fk(meta): + return meta["val"] if "val" in meta else meta["fake_result"] + + def find_not_cuda(t): + nonlocal found_not_cuda + if isinstance(t, torch.Tensor) and t.device.type != "cuda": + found_not_cuda = True + + for n in node.all_input_nodes: + pytree.tree_map_(find_not_cuda, meta_fk(n.meta)) + + pytree.tree_map_(find_not_cuda, meta_fk(node.meta)) + + # NB: factory function is accounted for because the result would be + # cpu or cuda + + return not found_not_cuda + + +def partition_cudagraphs(gm, inputs): + """ + Partition an FX graph into sub-GraphModules that can be validly run under + CUDA graphs. For a subgraph to be runnable under CUDA, all of the operations + must involve CUDA tensors only/ + """ + + FakeTensorProp(gm).propagate(*inputs) + supported_ops = CudaGraphsSupport() + # TODO: single node partition may be wrong due to the pessimization + # from copying in and out the data. Check in benchmarks, perhaps + partitioner = CapabilityBasedPartitioner( + gm, supported_ops, allows_single_node_partition=True + ) + partitions = partitioner.propose_partitions() + fused_graph = partitioner.fuse_partitions(partitions) + return fused_graph diff --git a/phivenv/Lib/site-packages/torch/fx/passes/dialect/__init__.py b/phivenv/Lib/site-packages/torch/fx/passes/dialect/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2e7204abadffbb2aa57f0758fbbacab429e9b00 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/dialect/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__init__.py b/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90a4e59bdf58498cb5755988d4c27edecfdce440 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19ae205ade55d0d0eb78e907a3e0fde02e5c4a53 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/__pycache__/cse_pass.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/cse_pass.py b/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/cse_pass.py new file mode 100644 index 0000000000000000000000000000000000000000..7d08ba258e97cd46d5c1d74c9b2de11c4997ba44 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/dialect/common/cse_pass.py @@ -0,0 +1,155 @@ +# mypy: allow-untyped-defs +from typing import Any + +import torch +from torch.fx import Graph, GraphModule, Node +from torch.fx.passes.infra.pass_base import PassBase, PassResult +from torch.utils._pytree import tree_flatten + + +aten = torch.ops.aten + + +# stateful ops are banned from CSE +rand_ops = { + aten.dropout, + aten._fused_dropout, + aten._standard_gamma, + aten.bernoulli, + aten.multinomial, + aten.native_dropout, + aten.normal, + aten.poisson, + aten.binomial, + aten.rrelu, + aten.rand_like, + aten.rand, + aten.randint, + aten.randn, + aten.randperm, +} # noqa: E501,B950 + +inplace_ops = { + aten.add_, + aten.sub_, + aten.mul_, + aten.div_, + aten.pow_, + aten.lerp_, + aten.relu_, + aten.sigmoid_, + aten.tanh_, +} # noqa: E501 + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +def get_CSE_banned_ops(): + return rand_ops.union(inplace_ops) + + +@torch.fx._compatibility.compatibility(is_backward_compatible=False) +class CSEPass(PassBase): + def __init__(self, banned_ops=None): + """ + This version of CSE Pass aims to be dialect agnostic, and it's implemented purely based on the connectivity between fx.Node. + + For functional dialects, user would only need to specify the random ops in ban list. + + Warning: CSE Pass cannot be safely applied on a FX graph in non-functional dialects. + If your dialect contains stateful operators, please customized the banned_ops. + + """ + if banned_ops is None: + banned_ops = set() + self.banned_ops = banned_ops + super().__init__() + + def call(self, graph_module: GraphModule) -> PassResult: + """ + Return a new copy of torch.fx.GraphModule with CSE applied to the input graph + + Example usage: + + from torch.fx.experimental.proxy_tensor import make_fx + def f(a): + b = a * a + c = a * a + return b+c + + p = CSEPass() + traced_graph = make_fx(f)(torch.tensor(1)) + print(traced_graph) + result = p(traced_graph) + print(result.graph_module) + """ + + def get_aten_target(node): + if hasattr(node.target, "overloadpacket"): + return node.target.overloadpacket + return node.target + + modified = False + new_graph = Graph() + env: dict[ + Node, Node + ] = {} # map from node in the old graph to node in the new graph + hash_env: dict[ + tuple[torch._ops.OpOverload, int], Node + ] = {} # map from hash to a node in the new graph + token_map: dict[ + tuple[torch._ops.OpOverload, int], dict[str, Any] + ] = {} # map from hash to token + for n in graph_module.graph.nodes: + # The placeholder, output, and get_attr nodes are copied to the new graph without change + # do not CSE away random operations + if ( + n.op == "placeholder" + or n.op == "output" + or n.op == "get_attr" + or get_aten_target(n) in self.banned_ops + ): + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + else: # n.op == 'call_function', should never see n.op == 'call_module' or 'call_method' + # substitute args and kwargs members to their mapping in env if exists + # specs can be used to reconstruct nested list/dictionaries + def substitute(arg_list): + arg_list, spec = tree_flatten(arg_list) + for i in range(len(arg_list)): + v = arg_list[i] + if isinstance(v, Node) and v in env: + arg_list[i] = env[v] + return tuple(arg_list), spec + + args, args_spec = substitute(n.args) + kwargs, kwargs_spec = substitute(n.kwargs) + + # each token corresponds to a unique node + # nodes with the same token can be substituted + token = { + "target": n.target, + "args": args, + "args_spec": args_spec, + "kwargs": kwargs, + "kwargs_spec": kwargs_spec, + } + + # hash substituted args to a number, do not hash specs because specs are not hashable + hash_arg = hash((args, kwargs)) + hash_val = (n.target, hash_arg) + + # check if a node has a substitute and can be eliminated + hash_val_in_hash_env = hash_val in hash_env + if hash_val_in_hash_env and token_map[hash_val] == token: + modified = True # substitution happens and the graph is modified + env[n] = hash_env[hash_val] + continue + + new_node = new_graph.node_copy(n, lambda x: env[x]) + env[n] = new_node + if not hash_val_in_hash_env: + hash_env[hash_val] = new_node + token_map[hash_val] = token + + csed_gm = GraphModule(graph_module, new_graph) + return PassResult(csed_gm, modified) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/fake_tensor_prop.py b/phivenv/Lib/site-packages/torch/fx/passes/fake_tensor_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..eeff6650022369cac3e2c4ea811d3ff1d0abe830 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/fake_tensor_prop.py @@ -0,0 +1,109 @@ +# mypy: allow-untyped-defs +from typing import Optional + +import torch.fx +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch.fx.experimental.proxy_tensor import py_sym_types, snapshot_fake +from torch.fx.node import map_aggregate +from torch.utils._ordered_set import OrderedSet + + +__all__ = ["FakeTensorProp"] + + +@compatibility(is_backward_compatible=False) +class FakeTensorProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and record a fake tensor representing + the metadata for the node. Unlike ShapeProp, (1) this propagation + is cheap--it does the propagation with meta tensors which do not actually + store data, and (2) the fake tensors have much more fine grained information, + e.g., they have accurate alias information that can be consulted by looking + at the storages. + + Args: + module (GraphModule): The module to be executed + mode (Optional[FakeTensorMode]): The dispatch mode used to execute computation indicated by each FX Node. + """ + + def __init__( + self, module: torch.fx.GraphModule, mode: Optional[FakeTensorMode] = None + ): + super().__init__(module) + if mode is None: + mode = FakeTensorMode() + self._mode = mode + mode.epoch += 1 + mode.reset_nt_tensor_id_counter() + self.seen_subgraphs: OrderedSet[str] = OrderedSet() + + def run_node(self, n: Node): + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + ) + + if ( + n.op == "call_function" + and n.target is torch.ops.higher_order.invoke_subgraph + and n.args[1] not in self.seen_subgraphs + ): + # Prevent redundant fake tensor prop for invoke_subgraphs. Note that + # there is also fake tensor caching for the entire subgraph. This + # happens the next time we call `run_node` for the same subgraph, + # which goes through super.run_node and caches the fake tensor prop. + # Therefore, we are propagating fake tensor through the subgraphs + # twice. + assert isinstance(n.args[1], str) + assert ( + isinstance(n.args[0], torch.fx.Node) + and n.args[0].op == "get_attr" + and isinstance(n.args[0].target, str) + ) + self.seen_subgraphs.add(n.args[1]) + operands = n.args[2:] + example_inputs = [] + for operand in operands: + assert isinstance(operand, torch.fx.Node) and "val" in operand.meta + example_inputs.append(operand.meta["val"]) + return FakeTensorProp( + getattr(self.module, n.args[0].target), mode=self._mode + ).propagate(*example_inputs) + + result = super().run_node(n) + rebind_unbacked(self._mode.shape_env, n, result) + + def extract_val(obj): + if isinstance(obj, FakeTensor): + return snapshot_fake(obj) + elif isinstance(obj, torch.Tensor): + # TODO: How is it possible that we get a non fake tensor? We + # should be running under the mode... + return snapshot_fake(self._mode.from_tensor(obj, static_shapes=True)) + elif isinstance(obj, py_sym_types): + return obj + else: + return None + + meta = map_aggregate(result, extract_val) + if meta is not None: + n.meta["val"] = meta + if (shape_env := self._mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, result) + ): + n.meta["unbacked_bindings"] = symbol_to_path + + return result + + def propagate(self, *args): + fake_args = [ + self._mode.from_tensor(a) if isinstance(a, torch.Tensor) else a + for a in args + ] + return self.propagate_dont_convert_inputs(*fake_args) + + def propagate_dont_convert_inputs(self, *args): + with self._mode: + return super().run(*args) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/graph_drawer.py b/phivenv/Lib/site-packages/torch/fx/passes/graph_drawer.py new file mode 100644 index 0000000000000000000000000000000000000000..260c0138dba4f7885187653223f23894db68640b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/graph_drawer.py @@ -0,0 +1,501 @@ +# mypy: allow-untyped-defs + +import hashlib +from itertools import chain +from types import ModuleType +from typing import Any, Optional, TYPE_CHECKING + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import _parse_stack_trace +from torch.fx.node import _format_arg, _get_qualified_name +from torch.fx.operator_schemas import normalize_function +from torch.fx.passes.shape_prop import TensorMetadata + + +if TYPE_CHECKING: + import pydot + + HAS_PYDOT = True +else: + pydot: Optional[ModuleType] + try: + import pydot + + HAS_PYDOT = True + except ModuleNotFoundError: + HAS_PYDOT = False + pydot = None + + +__all__ = ["FxGraphDrawer"] + +_COLOR_MAP = { + "placeholder": '"AliceBlue"', + "call_module": "LemonChiffon1", + "get_param": "Yellow2", + "get_attr": "LightGrey", + "output": "PowderBlue", +} + +_HASH_COLOR_MAP = [ + "CadetBlue1", + "Coral", + "DarkOliveGreen1", + "DarkSeaGreen1", + "GhostWhite", + "Khaki1", + "LavenderBlush1", + "LightSkyBlue", + "MistyRose1", + "MistyRose2", + "PaleTurquoise2", + "PeachPuff1", + "Salmon", + "Thistle1", + "Thistle3", + "Wheat1", +] + +_WEIGHT_TEMPLATE = { + "fillcolor": "Salmon", + "style": '"filled,rounded"', + "fontcolor": "#000000", +} + +if HAS_PYDOT: + + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + """ + Visualize a torch.fx.Graph with graphviz + Basic usage: + g = FxGraphDrawer(symbolic_traced, "resnet18") + g.get_dot_graph().write_svg("a.svg") + """ + + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + normalize_args: bool = False, + ): + self._name = name + self.dot_graph_shape = ( + dot_graph_shape if dot_graph_shape is not None else "record" + ) + self.normalize_args = normalize_args + _WEIGHT_TEMPLATE["shape"] = self.dot_graph_shape + + self._dot_graphs = { + name: self._to_dot( + graph_module, + name, + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, + ) + } + + for node in graph_module.graph.nodes: + if node.op != "call_module": + continue + + leaf_node = self._get_leaf_node(graph_module, node) + + if not isinstance(leaf_node, torch.fx.GraphModule): + continue + + self._dot_graphs[f"{name}_{node.target}"] = self._to_dot( + leaf_node, + f"{name}_{node.target}", + ignore_getattr, + ignore_parameters_and_buffers, + skip_node_names_in_args, + parse_stack_trace, + ) + + def get_dot_graph(self, submod_name=None) -> pydot.Dot: + """ + Visualize a torch.fx.Graph with graphviz + Example: + >>> # xdoctest: +REQUIRES(module:pydot) + >>> # xdoctest: +REQUIRES(module:ubelt) + >>> # define module + >>> class MyModule(torch.nn.Module): + >>> def __init__(self) -> None: + >>> super().__init__() + >>> self.linear = torch.nn.Linear(4, 5) + >>> def forward(self, x): + >>> return self.linear(x).clamp(min=0.0, max=1.0) + >>> module = MyModule() + >>> # trace the module + >>> symbolic_traced = torch.fx.symbolic_trace(module) + >>> # setup output file + >>> import ubelt as ub + >>> dpath = ub.Path.appdir("torch/tests/FxGraphDrawer").ensuredir() + >>> fpath = dpath / "linear.svg" + >>> # draw the graph + >>> g = FxGraphDrawer(symbolic_traced, "linear") + >>> g.get_dot_graph().write_svg(fpath) + """ + if submod_name is None: + return self.get_main_dot_graph() + else: + return self.get_submod_dot_graph(submod_name) + + def get_main_dot_graph(self) -> pydot.Dot: + return self._dot_graphs[self._name] + + def get_submod_dot_graph(self, submod_name) -> pydot.Dot: + return self._dot_graphs[f"{self._name}_{submod_name}"] + + def get_all_dot_graphs(self) -> dict[str, pydot.Dot]: + return self._dot_graphs + + def _get_node_style(self, node: torch.fx.Node) -> dict[str, str]: + template = { + "shape": self.dot_graph_shape, + "fillcolor": "#CAFFE3", + "style": '"filled,rounded"', + "fontcolor": "#000000", + } + if node.op in _COLOR_MAP: + template["fillcolor"] = _COLOR_MAP[node.op] + else: + # Use a random color for each node; based on its name so it's stable. + target_name = node._pretty_print_target(node.target) + target_hash = int( + hashlib.md5( + target_name.encode(), usedforsecurity=False + ).hexdigest()[:8], + 16, + ) + template["fillcolor"] = _HASH_COLOR_MAP[ + target_hash % len(_HASH_COLOR_MAP) + ] + return template + + def _get_leaf_node( + self, module: torch.nn.Module, node: torch.fx.Node + ) -> torch.nn.Module: + py_obj = module + assert isinstance(node.target, str) + atoms = node.target.split(".") + for atom in atoms: + if not hasattr(py_obj, atom): + raise RuntimeError( + str(py_obj) + " does not have attribute " + atom + "!" + ) + py_obj = getattr(py_obj, atom) + return py_obj + + def _typename(self, target: Any) -> str: + if isinstance(target, torch.nn.Module): + ret = torch.typename(target) + elif isinstance(target, str): + ret = target + else: + ret = _get_qualified_name(target) + + # Escape "{" and "}" to prevent dot files like: + # https://gist.github.com/SungMinCho/1a017aab662c75d805c5954d62c5aabc + # which triggers `Error: bad label format (...)` from dot + return ret.replace("{", r"\{").replace("}", r"\}") + + # shorten path to avoid drawing long boxes + # for full path = '/home/weif/pytorch/test.py' + # return short path = 'pytorch/test.py' + def _shorten_file_name( + self, + full_file_name: str, + truncate_to_last_n: int = 2, + ): + splits = full_file_name.split("/") + if len(splits) >= truncate_to_last_n: + return "/".join(splits[-truncate_to_last_n:]) + return full_file_name + + def _get_node_label( + self, + module: torch.fx.GraphModule, + node: torch.fx.Node, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> str: + def _get_str_for_args_kwargs(arg): + if isinstance(arg, tuple): + prefix, suffix = r"|args=(\l", r",\n)\l" + arg_strs_list = [_format_arg(a, max_list_len=8) for a in arg] + elif isinstance(arg, dict): + prefix, suffix = r"|kwargs={\l", r",\n}\l" + arg_strs_list = [ + f"{k}: {_format_arg(v, max_list_len=8)}" for k, v in arg.items() + ] + else: # Fall back to nothing in unexpected case. + return "" + + # Strip out node names if requested. + if skip_node_names_in_args: + arg_strs_list = [a for a in arg_strs_list if "%" not in a] + if len(arg_strs_list) == 0: + return "" + arg_strs = prefix + r",\n".join(arg_strs_list) + suffix + if len(arg_strs_list) == 1: + arg_strs = arg_strs.replace(r"\l", "").replace(r"\n", "") + return arg_strs.replace("{", r"\{").replace("}", r"\}") + + label = "{" + f"name=%{node.name}|op_code={node.op}\n" + + if node.op == "call_module": + leaf_module = self._get_leaf_node(module, node) + label += r"\n" + self._typename(leaf_module) + r"\n|" + extra = "" + if hasattr(leaf_module, "__constants__"): + extra = r"\n".join( + [ + f"{c}: {getattr(leaf_module, c)}" + for c in leaf_module.__constants__ # type: ignore[union-attr] + ] # type: ignore[union-attr] + ) + label += extra + r"\n" + else: + label += f"|target={self._typename(node.target)}" + r"\n" + if self.normalize_args: + try: + args, kwargs = normalize_function( # type: ignore[misc] + node.target, # type: ignore[arg-type] + node.args, # type: ignore[arg-type] + node.kwargs, + normalize_to_only_use_kwargs=True, + ) + except Exception: + # Fallback to not normalizing if there's an exception. + # Some functions need overloads specified to normalize. + args, kwargs = node.args, node.kwargs + else: + args, kwargs = node.args, node.kwargs + if len(args) > 0: + label += _get_str_for_args_kwargs(args) + if len(kwargs) > 0: + label += _get_str_for_args_kwargs(kwargs) + label += f"|num_users={len(node.users)}" + r"\n" + + tensor_meta = node.meta.get("tensor_meta") + label += self._tensor_meta_to_label(tensor_meta) + + # for original fx graph + # print buf=buf0, n_origin=6 + buf_meta = node.meta.get("buf_meta", None) + if buf_meta is not None: + label += f"|buf={buf_meta.name}" + r"\n" + label += f"|n_origin={buf_meta.n_origin}" + r"\n" + + # for original fx graph + # print file:lineno code + if parse_stack_trace and node.stack_trace is not None: + parsed_stack_trace = _parse_stack_trace(node.stack_trace) + fname = self._shorten_file_name(parsed_stack_trace.file) + label += ( + f"|file={fname}:{parsed_stack_trace.lineno} {parsed_stack_trace.code}" + + r"\n" + ) + + return label + "}" + + def _tensor_meta_to_label(self, tm) -> str: + if tm is None: + return "" + elif isinstance(tm, TensorMetadata): + return self._stringify_tensor_meta(tm) + elif isinstance(tm, list): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + elif isinstance(tm, dict): + result = "" + for v in tm.values(): + result += self._tensor_meta_to_label(v) + return result + elif isinstance(tm, tuple): + result = "" + for item in tm: + result += self._tensor_meta_to_label(item) + return result + else: + raise RuntimeError(f"Unsupported tensor meta type {type(tm)}") + + def _stringify_tensor_meta(self, tm: TensorMetadata) -> str: + result = "" + if not hasattr(tm, "dtype"): + print("tm", tm) + result += "|" + "dtype" + "=" + str(tm.dtype) + r"\n" + result += "|" + "shape" + "=" + str(tuple(tm.shape)) + r"\n" + result += "|" + "requires_grad" + "=" + str(tm.requires_grad) + r"\n" + result += "|" + "stride" + "=" + str(tm.stride) + r"\n" + if tm.is_quantized: + assert tm.qparams is not None + assert "qscheme" in tm.qparams + qscheme = tm.qparams["qscheme"] + if qscheme in { + torch.per_tensor_affine, + torch.per_tensor_symmetric, + }: + result += "|" + "q_scale" + "=" + str(tm.qparams["scale"]) + r"\n" + result += ( + "|" + + "q_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_symmetric, + torch.per_channel_affine_float_qparams, + }: + result += ( + "|" + + "q_per_channel_scale" + + "=" + + str(tm.qparams["scale"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_zero_point" + + "=" + + str(tm.qparams["zero_point"]) + + r"\n" + ) + result += ( + "|" + + "q_per_channel_axis" + + "=" + + str(tm.qparams["axis"]) + + r"\n" + ) + else: + raise RuntimeError(f"Unsupported qscheme: {qscheme}") + result += "|" + "qscheme" + "=" + str(tm.qparams["qscheme"]) + r"\n" + return result + + def _get_tensor_label(self, t: torch.Tensor) -> str: + return str(t.dtype) + str(list(t.shape)) + r"\n" + + # when parse_stack_trace=True + # print file:lineno code + def _to_dot( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool, + ignore_parameters_and_buffers: bool, + skip_node_names_in_args: bool, + parse_stack_trace: bool, + ) -> pydot.Dot: + """ + Actual interface to visualize a fx.Graph. Note that it takes in the GraphModule instead of the Graph. + If ignore_parameters_and_buffers is True, the parameters and buffers + created with the module will not be added as nodes and edges. + """ + + # "TB" means top-to-bottom rank direction in layout + dot_graph = pydot.Dot(name, rankdir="TB") + + buf_name_to_subgraph = {} + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + style = self._get_node_style(node) + dot_node = pydot.Node( + node.name, + label=self._get_node_label( + graph_module, node, skip_node_names_in_args, parse_stack_trace + ), + **style, # type: ignore[arg-type] + ) + + current_graph = dot_graph + + buf_meta = node.meta.get("buf_meta", None) + if buf_meta is not None and buf_meta.n_origin > 1: + buf_name = buf_meta.name + if buf_name not in buf_name_to_subgraph: + buf_name_to_subgraph[buf_name] = pydot.Cluster( + buf_name, label=buf_name + ) + current_graph = buf_name_to_subgraph.get(buf_name) # type: ignore[assignment] + + current_graph.add_node(dot_node) + + def get_module_params_or_buffers(): + for pname, ptensor in chain( + leaf_module.named_parameters(), leaf_module.named_buffers() + ): + pname1 = node.name + "." + pname + label1 = ( + pname1 + "|op_code=get_" + "parameter" + if isinstance(ptensor, torch.nn.Parameter) + else "buffer" + r"\l" + ) + dot_w_node = pydot.Node( + pname1, + label="{" + label1 + self._get_tensor_label(ptensor) + "}", + **_WEIGHT_TEMPLATE, # type: ignore[arg-type] + ) + dot_graph.add_node(dot_w_node) + dot_graph.add_edge(pydot.Edge(pname1, node.name)) + + if node.op == "call_module": + leaf_module = self._get_leaf_node(graph_module, node) + + if not ignore_parameters_and_buffers and not isinstance( + leaf_module, torch.fx.GraphModule + ): + get_module_params_or_buffers() + + for subgraph in buf_name_to_subgraph.values(): + subgraph.set("color", "royalblue") + subgraph.set("penwidth", "2") + dot_graph.add_subgraph(subgraph) # type: ignore[arg-type] + + for node in graph_module.graph.nodes: + if ignore_getattr and node.op == "get_attr": + continue + + for user in node.users: + dot_graph.add_edge(pydot.Edge(node.name, user.name)) + + return dot_graph + +else: + if not TYPE_CHECKING: + + @compatibility(is_backward_compatible=False) + class FxGraphDrawer: + def __init__( + self, + graph_module: torch.fx.GraphModule, + name: str, + ignore_getattr: bool = False, + ignore_parameters_and_buffers: bool = False, + skip_node_names_in_args: bool = True, + parse_stack_trace: bool = False, + dot_graph_shape: Optional[str] = None, + normalize_args: bool = False, + ): + raise RuntimeError( + "FXGraphDrawer requires the pydot package to be installed. Please install " + "pydot through your favorite Python package manager." + ) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/graph_manipulation.py b/phivenv/Lib/site-packages/torch/fx/passes/graph_manipulation.py new file mode 100644 index 0000000000000000000000000000000000000000..edb73fcf5f9e15c0b77e10d3b45bf9a5932cc5bf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/graph_manipulation.py @@ -0,0 +1,113 @@ +# mypy: allow-untyped-defs +from typing import Any, NamedTuple, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import map_arg, Node, Target +from torch.fx.passes.shape_prop import ShapeProp + + +__all__ = [ + "replace_target_nodes_with", + "size_bytes", + "get_size_of_all_nodes", + "get_tensor_meta", + "get_size_of_node", +] + + +@compatibility(is_backward_compatible=False) +def replace_target_nodes_with( + fx_module: GraphModule, + old_op: str, + old_target: Target, + new_op: str, + new_target: Target, +): + """Modifies all nodes in fx_module.graph.nodes which match the specified op code and target, + and updates them to match the new op code and target""" + new_graph = Graph() + val_map: dict[Node, Node] = {} + for node in fx_module.graph.nodes: + if node.op == old_op and node.target == old_target: + args = map_arg(node.args, lambda n: val_map[n]) + kwargs = map_arg(node.kwargs, lambda n: val_map[n]) + assert isinstance(args, tuple) + assert isinstance(kwargs, dict) + val_map[node] = new_graph.create_node( + new_op, new_target, args, kwargs, node.name + ) + else: + val_map[node] = new_graph.node_copy(node, lambda n: val_map[n]) + fx_module.graph = new_graph + + +@compatibility(is_backward_compatible=False) +class size_bytes(NamedTuple): + output_size: int + total_size: int + + +@compatibility(is_backward_compatible=False) +def get_size_of_all_nodes( + fx_module: GraphModule, args: Optional[list[torch.Tensor]] = None +) -> None: + """Given a fx graph module, update each node with its total size (weights + bias + output) + and its output_size(output). For a non-module node, the total size is the output size. + return total size""" + if args is not None: + # Mark shape and dtype for each node (node.shape and node.dtype) + ShapeProp(fx_module).propagate(*args) + # Calculate the total size of the whole fx graph + for node in fx_module.graph.nodes: + if node.op == "output": + break + node.size_bytes = get_size_of_node(fx_module, node) + return + + +@compatibility(is_backward_compatible=False) +def get_tensor_meta(node: Node) -> Any: + tensor_meta = node.meta.get("tensor_meta") + + if not tensor_meta: + raise RuntimeError( + f"Node {node} has no tensor metadata associated with it! " + f"Check that shape propagation has run." + ) + + return tensor_meta + + +@compatibility(is_backward_compatible=False) +def get_size_of_node(fx_module: GraphModule, node: Node) -> size_bytes: + """Given a node with node.dtype and node.shape, return its total size and its output size. + total_size = weights + bias + output_size + """ + # Total num of elements + total_num_of_elems = 0 + # For a module, conside all parameters + if node.op == "call_module": + submodule_dict = dict(fx_module.named_modules()) + submodule = submodule_dict[node.target] + parameters = submodule.named_parameters() + # Parameters are named tuples + for _name, p in parameters: + total_num_of_elems += p.numel() + # Don't forget the output size + # node.shape is the shape of this node's output + tensor_meta = get_tensor_meta(node) + output_elem = tensor_meta.shape.numel() + total_num_of_elems += output_elem + # Assume for now if it's quantized then it's qint8 or quint8 + if tensor_meta.is_quantized: + size_per_elem_bytes = torch._empty_affine_quantized( + [], dtype=tensor_meta.dtype + ).element_size() + else: + size_per_elem_bytes = torch.tensor([], dtype=tensor_meta.dtype).element_size() + total_size = size_per_elem_bytes * total_num_of_elems + output_size = size_per_elem_bytes * output_elem + return size_bytes(output_size, total_size) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/graph_transform_observer.py b/phivenv/Lib/site-packages/torch/fx/passes/graph_transform_observer.py new file mode 100644 index 0000000000000000000000000000000000000000..5698a8ea139169665a4ca0ec40615da547d72709 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/graph_transform_observer.py @@ -0,0 +1,219 @@ +# mypy: allow-untyped-defs +import os +from typing import Callable, Optional, TypeVar + +from torch.fx import Graph, Node +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.traceback import NodeSource, NodeSourceAction + + +T = TypeVar("T") + + +from .graph_drawer import FxGraphDrawer + + +__all__ = ["GraphTransformObserver"] + + +@compatibility(is_backward_compatible=False) +class GraphTransformObserver: + __pass_count = 0 + + def __init__( + self, + gm: GraphModule, + passname: str, + subsystem: Optional[str] = None, + log_url: Optional[str] = None, + ): + """ + log_url is inferred to be torch._inductor.config.trace.log_url_for_graph_xform unless otherwise specified + """ + from torch._inductor.config import trace + + self.gm = gm + self.passname = passname + self.subsystem = subsystem + + if log_url is None: + log_url = trace.log_url_for_graph_xform + + self.log_url = log_url + + self.active = trace.enabled or self.log_url is not None + + if self.active: + self.erased_nodes: set[str] = set() + self.created_nodes: set[str] = set() + self.name_to_node: dict[str, Node] = {} + # record graph modules deepcopied from self.gm, so we can remove hoooks on them when exiting the context + self.copied_gms: list[GraphModule] = [] + + self._node_creation_hook = self.get_node_creation_hook() + self._node_erase_hook = self.get_node_erase_hook() + self._node_replace_hook = self.get_node_replace_hook() + self._deepcopy_hook = self.get_deepcopy_hook() + + # If log_url is None, we don't log anything + if self.log_url is None: + return + GraphTransformObserver.__pass_count += 1 + + self.input_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + + @classmethod + def get_current_pass_count(cls): + return cls.__pass_count + + def apply_gm_pass(self, pass_fn: Callable[[GraphModule], T]) -> Optional[T]: + with self: + if not self._check_disable_pass(): + return pass_fn(self.gm) + + return None + + def apply_graph_pass(self, pass_fn: Callable[[Graph], T]) -> Optional[T]: + with self: + if not self._check_disable_pass(): + return pass_fn(self.gm.graph) + + return None + + def _check_disable_pass(self): + if self.subsystem is None: + return False + + debug_info = lambda: self.passname # noqa: E731 + from torch._inductor.compiler_bisector import CompilerBisector + + return CompilerBisector.disable_subsystem( + "inductor", self.subsystem, debug_info + ) + + def __enter__(self): + if not self.active: + return self + self.gm._register_create_node_hook(self._node_creation_hook) + self.gm._register_erase_node_hook(self._node_erase_hook) + self.gm._register_replace_node_hook(self._node_replace_hook) + self.gm._register_deepcopy_hook(self._deepcopy_hook) + + self.erased_nodes.clear() + self.created_nodes.clear() + self.name_to_node.clear() + self.copied_gms.clear() + + for node in self.gm.graph.nodes: + self.name_to_node[node.name] = node + + return self + + def __exit__(self, type, value, tb): + if not self.active: + return + for gm in self.copied_gms + [self.gm]: + gm._unregister_create_node_hook(self._node_creation_hook) + gm._unregister_erase_node_hook(self._node_erase_hook) + gm._unregister_replace_node_hook(self._node_replace_hook) + gm._unregister_deepcopy_hook(self._deepcopy_hook) + + if self.log_url is None: + return + + if len(self.created_nodes) > 0 or len(self.erased_nodes) > 0: + for e in self.input_dot_graph.get_node_list(): + if e.get_name() in self.erased_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + assert self.log_url is not None + self.input_dot_graph.write( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_input_graph.dot", + ) + ) + + output_dot_graph = FxGraphDrawer( + self.gm, + self.passname, + ignore_getattr=True, + ignore_parameters_and_buffers=True, + ).get_dot_graph() + for e in output_dot_graph.get_node_list(): + if e.get_name() in self.created_nodes: + e.obj_dict["attributes"]["fillcolor"] = "yellow" + else: + e.obj_dict["attributes"]["fillcolor"] = "grey" + output_dot_graph.write( + os.path.join( + self.log_url, + f"pass_{GraphTransformObserver.__pass_count}_{self.passname}_output_graph.dot", + ) + ) + + def get_node_creation_hook(self): + # We have to return a function instead of using a class method directly + # to avoid max recursion issue when deepcopy a graph module within the context manager. + def on_node_creation(node): + self.created_nodes.add(node.name) + self.name_to_node[node.name] = node + source = NodeSource(None, self.passname, NodeSourceAction.CREATE) + if "from_node" not in node.meta: + node.meta["from_node"] = [source] + else: + node.meta["from_node"].append(source) + + return on_node_creation + + def get_node_erase_hook(self): + def on_node_erase(node): + self.erased_nodes.add(node.name) + self.name_to_node.pop(node.name, None) + + return on_node_erase + + def get_node_replace_hook(self): + def on_node_replace(old: Node, new: str, user: Node): + # Update node meta when replacing old node with new node + new_node = self.name_to_node.get(new, None) + + if not new_node: + return + + assert isinstance(new_node, Node) + + action = [NodeSourceAction.REPLACE] + if new_node.name in self.created_nodes: + action.append(NodeSourceAction.CREATE) + + def created_this_pass(source): + return source.pass_name == self.passname and source.action == [ + NodeSourceAction.CREATE + ] + + # remove redundant source added on node creation + new_from_node = new_node.meta.get("from_node", []) + new_from_node = [ + source for source in new_from_node if not created_this_pass(source) + ] + + # add new source + new_node_source = NodeSource(old, self.passname, action) + new_from_node.append(new_node_source) + new_node.meta["from_node"] = new_from_node + + return on_node_replace + + def get_deepcopy_hook(self): + def on_deepcopy(gm): + self.copied_gms.append(gm) + + return on_deepcopy diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/__init__.py b/phivenv/Lib/site-packages/torch/fx/passes/infra/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..331403972c6da66374e49d6ae067d6b0b20afb2c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/infra/__init__.py @@ -0,0 +1 @@ +from . import pass_manager diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c7ec5302dccfd35d034bed232227669db19d4b1 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc4fe53e042fc54f041d6b3d27149e5a33fcb629 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/partitioner.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09953e3c4f8cfdc129e5f4dcc52b1149e9e6bd14 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_base.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48f134f571e1c09825c88f5f92708c0f26488293 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/infra/__pycache__/pass_manager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/partitioner.py b/phivenv/Lib/site-packages/torch/fx/passes/infra/partitioner.py new file mode 100644 index 0000000000000000000000000000000000000000..342a9897afe61b6c5e3c75d666c7c64b0b0983f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/infra/partitioner.py @@ -0,0 +1,376 @@ +# mypy: allow-untyped-defs +import collections +import itertools +import logging +import operator +from collections.abc import Iterable, Sequence +from typing import Optional + +from torch.fx.graph_module import GraphModule +from torch.fx.node import _get_qualified_name, Node +from torch.fx.passes.operator_support import OperatorSupportBase +from torch.fx.passes.utils.fuser_utils import fuse_by_partitions + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + + +class Partition: + def __init__( + self, id: Optional[int] = None, nodes: Optional[Iterable[Node]] = None + ): + self.id = id + self.nodes = dict.fromkeys(nodes) if nodes is not None else {} + + def __repr__(self) -> str: + return str(self.nodes) + + def add_node(self, node: Node): + self.nodes.update({node: None}) + + def remove_node(self, node: Node): + del self.nodes[node] + + def size(self): + return len(self.nodes) + + +class _DependencyViewer: + def __init__(self, graph_module: GraphModule): + self.downstreams = collections.defaultdict(set) + + for node in reversed(graph_module.graph.nodes): + for output_node in node.users: + # add output_node and output_node's downstream dependency + self.downstreams[node].add(output_node) + self.downstreams[node].update(self.downstreams[output_node]) + + def downstreams_of(self, node: Node) -> set[Node]: + return self.downstreams[node] + + +class CapabilityBasedPartitioner: + def __init__( + self, + graph_module: GraphModule, + operator_support: OperatorSupportBase, + allows_single_node_partition: bool = False, + non_compute_ops: Optional[Sequence[str]] = None, + allowed_single_node_partition_ops: Optional[Sequence[str]] = None, + ) -> None: + self.graph_module = graph_module + self.operator_support = operator_support + self.allows_single_node_partition = allows_single_node_partition + self.non_compute_ops = non_compute_ops if non_compute_ops is not None else [] + self.allowed_single_node_partition_ops = ( + allowed_single_node_partition_ops + if allowed_single_node_partition_ops is not None + else [] + ) + self.dependency_viewer = _DependencyViewer(graph_module) + + def _is_node_supported(self, node: Node) -> bool: + return self.operator_support.is_node_supported( + dict(self.graph_module.named_modules()), node + ) + + def propose_partitions(self) -> list[Partition]: + # partition_map is a mapping from partition id to a set of partition id's. + # The value set contains all the partition ids that can be reached by doing a + # DFS starting from the partition id in the key. + partition_map: dict[int, set] = collections.defaultdict(set) + + # assumptions: nodes in candidate list is sorted in topological order + assignment: dict[Node, int] = {} # mapping from node to partition_id + partitions_by_id: dict[ + int, Partition + ] = {} # mapping from partition_id to partition + nodes_order: dict[ + Node, int + ] = {} # mapping from nodes to reversed topological order + partitions_order: dict[ + int, int + ] = {} # mapping from partition_id to minimum topo order of nodes in partition + partition_users: dict[ + int, set + ] = {} # mapping from partition_id to partition users + new_partition_id = itertools.count() + + # try to merge partition other_id into partition self_id + # merge only happens if the end graph doesn't contain cyclic dependency + # returns `True` when merge happens, `False` otherwise. + def maybe_merge_partition(self_id: int, other_id: int): + # merged_nodes is the union of nodes in two partition to-be-merged + self_nodes = partitions_by_id[self_id].nodes + other_nodes = partitions_by_id[other_id].nodes + + def dfs_iter_find_cycle(all_user_nodes: set[Node]): + for user_node in all_user_nodes: + visited_partition_ids = set() + + for path_node in self.dependency_viewer.downstreams_of(user_node): + # If any of the nodes in the dfs path of this node are in the merged_nodes + # list then there is a cycle in the graph. + if path_node in self_nodes or path_node in other_nodes: + return True + + # If any of the nodes in the dfs path of this node are in the assignment + # map then we have to make sure that the partitions that these nodes belong + # to do not form a cycle with the current partitions being merged. This means + # iterating through all the nodes in all the parititons that are traversed in + # the dfs path and checking if they are in the merged_nodes list. + if path_node in assignment: + partition_id = assignment[path_node] + # If the partition id has already been visited then we know that it doesn't + # form a cycle with the current partitions being merged. + if partition_id in visited_partition_ids: + continue + p_map = partition_map[partition_id] + if self_id in p_map or other_id in p_map: + return True + + visited_partition_ids.add(partition_id) + + return False + + # find new partition users if merge. + all_user_nodes = partition_users[self_id] | partition_users[other_id] + all_user_nodes.difference_update(other_nodes, self_nodes) + + # check if merge would create cyclic dependency. + if dfs_iter_find_cycle(all_user_nodes): + # return false indicating cyclic dependency found and + # merge is aborted + return self_id, False + + # merge the smaller partition into the larger. + merge_id, removed_id = self_id, other_id + if len(self_nodes) < len(other_nodes): + merge_id, removed_id = removed_id, merge_id + # no cyclic dependency found, move forward with the merge + # updating partition nodes + partitions_by_id[merge_id].nodes.update(partitions_by_id[removed_id].nodes) + # updating assignment map + for node in partitions_by_id[removed_id].nodes: + assignment[node] = merge_id + # delete other partition + del partitions_by_id[removed_id] + + partitions_order[merge_id] = min( + partitions_order[merge_id], partitions_order[removed_id] + ) + del partitions_order[removed_id] + + partition_map[merge_id] = partition_map[merge_id].union( + partition_map[removed_id] + ) + del partition_map[removed_id] + + partition_users[merge_id] = all_user_nodes + del partition_users[removed_id] + + return merge_id, True + + def merge_single_node(node: Node, id: Optional[int]): + def _update_partition_map(node: Node, id: int): + # Iterate through all the users of this node and update the partition map to indicate + # that there is a path from the partition id of this node to the target partition id. + for user_node in node.users: + target_id = assignment.get(user_node, None) + if target_id is not None: + partition_map[id].add(target_id) + partition_map[id].update(partition_map[target_id]) + + if node in assignment: + partitions_by_id[assignment[node]].remove_node(node) + + if id is None: + assignment.pop(node) + elif id not in partitions_by_id: + assignment[node] = id + partitions_by_id[id] = Partition(id=id, nodes=[node]) + partition_users[id] = set(node.users) + _update_partition_map(node, id) + else: + assignment[node] = id + partitions_by_id[id].add_node(node) + + logger.debug("Proposing partitions...") + + for node in reversed(self.graph_module.graph.nodes): + # use Dict as an ordered set to ensure deterministic partitioning result, don't care value + merge_candidates: dict[int, None] = {} + + # Note a limited horizontal fusion is enabled: + # when `node` is not supported, the code below attempts to fuse consumer of `node`. + # + # I don't see a need to add a knob to disable horizontal fusion yet, we can short-cut + # the fusion by adding an `else` block here to skip horizontal fusion. + if self._is_node_supported(node) and node not in assignment: + partition_id = next(new_partition_id) + nodes_order[node] = partition_id + partitions_order[partition_id] = partition_id + merge_single_node(node, partition_id) + merge_candidates[partition_id] = None + + # merge all possible partitions + for partition_id, _ in sorted( + partitions_order.items(), key=operator.itemgetter(1) + ): + merge_candidates[partition_id] = None + + merge_candidates_list = list(merge_candidates.keys()) + if len(merge_candidates_list) > 1: + self_id = merge_candidates_list[0] + for other_id in merge_candidates_list[1:]: + # note: merge partitions if it doesn't create cyclic dependency + # in the graph, otherwise, this is a no-op + self_id, _ = maybe_merge_partition(self_id, other_id) + + # post processing to re-assign "getitem" nodes into upstream partition + logger.debug("Reassigning getitem nodes to its producer node's partition...") + nodes_reassignment: dict[Node, int] = {} + for node in self.graph_module.graph.nodes: + is_tuple_output = True + for user in node.users: + if ( + user.op != "call_function" + or _get_qualified_name(user.target) != "_operator.getitem" + ): # type: ignore[arg-type] + is_tuple_output = False + break + + # node has tuple outputs, re-assign all following getitem node into node's partition + if is_tuple_output: + id = assignment.get(node, None) # type: ignore[arg-type] + for user in node.users: + if assignment.get(user, None) != id: # type: ignore[arg-type] + nodes_reassignment[user] = id # type: ignore[assignment] + for node, id in nodes_reassignment.items(): + merge_single_node(node, id) + + # filter out single node partitions + if not self.allows_single_node_partition: + logger.debug("Filtering out single node partitions...") + default_non_compute_ops = {"torch.ops.aten.view", "_operator.getitem"} + non_compute_ops = default_non_compute_ops.union(set(self.non_compute_ops)) + partitions_to_remove: list[int] = [] + for id, partition in partitions_by_id.items(): + compute_node_count = 0 + for node in partition.nodes: + if node.op == "call_function": + assert callable(node.target) + if _get_qualified_name(node.target) not in non_compute_ops: + compute_node_count += 1 + if ( + _get_qualified_name(node.target) + in self.allowed_single_node_partition_ops + ): + compute_node_count += 1 + if compute_node_count <= 1: + partitions_to_remove.append(id) + for id in partitions_to_remove: + del partitions_by_id[id] + + logger.debug("Partitions proposed:") + for id, partition in partitions_by_id.items(): + logger.debug( + "partition #%s: %s", id, [node.name for node in partition.nodes] + ) + + return [ + partition for partition in partitions_by_id.values() if partition.size() > 0 + ] + + def fuse_partitions( + self, partitions: list[Partition], prefix: str = "fused_" + ) -> GraphModule: + logger.debug("Fusing partitions...") + # fuse_by_partitions expects partitions in List[Dict[Node, None]]: [ {node0 : None}, {node1 : None} ] + return fuse_by_partitions( + self.graph_module, + [partition.nodes for partition in partitions], + prefix=prefix, + ) + + # remove non-compute-ops that sits at the boundary of a partition. + def remove_bookend_non_compute_ops(self, partitions: list[Partition]): + non_compute_ops = set(self.non_compute_ops) + + def is_non_compute_node(node: Node): + return ( + node.op == "call_function" + and _get_qualified_name(node.target) in non_compute_ops # type: ignore[arg-type] + ) + + # cache transparent nodes + transparent_input_nodes: dict[Node, bool] = {} + transparent_output_nodes: dict[Node, bool] = {} + + def is_transparent_input_node( + node: Node, partition: set[Node], removed_nodes: set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): + return True + if node in transparent_input_nodes: + return transparent_input_nodes[node] + if is_non_compute_node(node): + for input_n in node.all_input_nodes: + if not is_transparent_input_node(input_n, partition, removed_nodes): + transparent_input_nodes[node] = False + return False + transparent_input_nodes[node] = True + return True + transparent_input_nodes[node] = False + return False + + def is_transparent_output_node( + node: Node, partition: set[Node], removed_nodes: set[Node] + ): + if ( + node.op == "placeholder" + or (node not in partition) + or (node in removed_nodes) + ): + return True + if node in transparent_output_nodes: + return transparent_output_nodes[node] + if is_non_compute_node(node): + for output_n in node.users: + if not is_transparent_output_node( + output_n, partition, removed_nodes + ): + transparent_output_nodes[node] = False + return False + transparent_output_nodes[node] = True + return True + transparent_output_nodes[node] = False + return False + + for partition in partitions: + # Note it's ok to use `set` here, since we are only query if a node + # has been removed. We are NEVER going to iterate on nodes inside + # the set. + remove_node: set[Node] = set() + for node in partition.nodes: + if is_non_compute_node(node) and ( + is_transparent_input_node(node, set(partition.nodes), remove_node) + or is_transparent_output_node( + node, set(partition.nodes), remove_node + ) + ): + remove_node.add(node) + + if len(remove_node) != 0: + for node in remove_node: + partition.nodes.pop(node, None) + + def partition_and_fuse(self, prefix: str = "fused_") -> GraphModule: + partitions = self.propose_partitions() + fused_gm = self.fuse_partitions(partitions, prefix=prefix) + return fused_gm diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/pass_base.py b/phivenv/Lib/site-packages/torch/fx/passes/infra/pass_base.py new file mode 100644 index 0000000000000000000000000000000000000000..01bf4fd97e35547ad923697fee95e97b668bc5a9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/infra/pass_base.py @@ -0,0 +1,78 @@ +# mypy: allow-untyped-defs +import abc +from collections import namedtuple +from typing import Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +__all__ = ["PassResult", "PassBase"] + + +@compatibility(is_backward_compatible=False) +class PassResult(namedtuple("PassResult", ["graph_module", "modified"])): + """ + Result of a pass: + graph_module: The modified graph module + modified: A flag for if the pass has modified the graph module + """ + + __slots__ = () + + def __new__(cls, graph_module, modified): + return super().__new__(cls, graph_module, modified) + + +@compatibility(is_backward_compatible=False) +class PassBase(abc.ABC): + """ + Base interface for implementing passes. + + It is required to implement the `call` function so that we can directly + pass instances of the Pass directly to the PassManager and call them as a + function. + + We can directly pass an instance of a class implementing this interface into + the PassManager's `passes` attribute. + """ + + def __call__(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + Runs the precondition check, the pass itself, and the postcondition check. + """ + + self.requires(graph_module) + res = self.call(graph_module) + self.ensures(graph_module) + return res + + @abc.abstractmethod + def call(self, graph_module: GraphModule) -> Optional[PassResult]: + """ + The pass that is run through the given graph module. To implement a + pass, it is required to implement this function. + + Args: + graph_module: The graph module we will run a pass on + """ + + def requires(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called before the pass is run and will check that + the given graph module contains the preconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ + + def ensures(self, graph_module: GraphModule) -> None: # noqa: B027 + """ + This function will be called after the pass is run and will check that + the given graph module contains the postconditions needed to run the + pass. It is not required to implement this function. + + Args: + graph_module: The graph module we will run checks on + """ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/infra/pass_manager.py b/phivenv/Lib/site-packages/torch/fx/passes/infra/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..7f13d71396c477fdbb3a563109e8513e64019714 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/infra/pass_manager.py @@ -0,0 +1,309 @@ +# mypy: allow-untyped-defs +import inspect +import logging +from functools import wraps +from queue import Queue +from typing import Callable + +import torch.nn as nn +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule +from torch.fx.passes.infra.pass_base import PassResult + + +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + +__all__ = ["pass_result_wrapper", "this_before_that_pass_constraint", "PassManager"] + + +@compatibility(is_backward_compatible=False) +def pass_result_wrapper(fn: Callable) -> Callable: + """ + Wrapper for passes which currently do not return a PassResult. + This wrapper makes them return a PassResult containing the modified object + and True for the "modified" flag. + + Args: + fn (Callable[Module, Any]) + + Returns: + wrapped_fn (Callable[Module, PassResult]) + """ + if fn is None: + return None + + @wraps(fn) + def wrapped_fn(gm): + res = fn(gm) + if res is None: + return PassResult(gm, True) + if isinstance(res, PassResult): + return res + elif isinstance(res, nn.Module): + return PassResult(res, True) + + if not inspect.isfunction(fn): + wrapped_fn.__name__ = type(fn).__name__ + + return wrapped_fn + + +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: list[Callable] +) -> None: + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + + +def _topological_sort_passes( + passes: list[Callable], constraints: list[Callable] +) -> list[Callable]: + """ + Args + passes: Passes that we are ordering + constraints: Constraints applied on these passes + + Returns + A sorted list of callables and a boolean of if a circular dependency + existed + """ + if len(constraints) == 0: + return passes + + # Contruct a graph mapping nodes to a list of their users + graph: dict[Callable, list[Callable]] = {p: [] for p in passes} + indegree_map: dict[Callable, int] = dict.fromkeys(passes, 0) + candidates: Queue = Queue() + for a in passes: + for b in passes: + if a == b: + continue + + for constraint in constraints: + if not constraint(a, b): + graph[b].append(a) + indegree_map[a] += 1 + + if indegree_map[a] == 0: + candidates.put(a) + + visited: dict[Callable, bool] = dict.fromkeys(passes, False) + sorted_passes: list[Callable] = [] + + while not candidates.empty(): + p = candidates.get() + sorted_passes.append(p) + visited[p] = True + + for n in graph[p]: + if not visited[n]: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + # Check if there are unvisited nodes (aka cycles in the graph) + cycle_passes = list(filter(lambda p: indegree_map[p] != 0, indegree_map.keys())) + if len(cycle_passes) != 0: + error = ( + f"Circular dependency detected within the following passes: {cycle_passes}" + ) + raise RuntimeError(error) + + return sorted_passes + + +@compatibility(is_backward_compatible=False) +def this_before_that_pass_constraint(this: Callable, that: Callable) -> Callable: + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [pass_b, pass_a] + + constraints = [this_before_that_pass_constraint(pass_a, pass_b)] + ``` + + Args: + this (Callable): pass which should occur first + that (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + return a != that or b != this + + return depends_on + + +@compatibility(is_backward_compatible=False) +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): List of passes. A pass is a + callable which modifies an object and returns a PassResult + constraint (Optional[List[Callable]]): List of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + steps (int): Max number of times we run the passes (default = 1). + run_checks_after_each_pass (bool): Whether to run checks and linting + after each pass + suppress_check_failures (bool): Whether to raise errors when running + checks + """ + + passes: list[Callable[[nn.Module], PassResult]] + constraints: list[Callable[[Callable, Callable], bool]] + _validated: bool = False + steps: int = 1 + + def __init__( + self, + passes=None, + constraints=None, + steps=None, + run_checks_after_each_pass: bool = False, + suppress_check_failures: bool = False, + ): + self.passes = passes or [] + self.constraints = constraints or [] + if steps: + self.steps = steps + + self.run_checks_after_each_pass = run_checks_after_each_pass + self.suppress_check_failures = suppress_check_failures + + def add_pass(self, _pass: Callable): + """ + Adds a pass into the current list of passes. + """ + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint: Callable): + """ + Adds a constraint into the current list of constraints. + """ + self.constraints.append(constraint) + self._validated = False + + def validate_constraints(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def solve_constraints(self): + """ + Finds a valid traversal order based on the given constraints and orders + the passes based on this order. + + If a circular dependency exists between the constraints and steps = 1, + then we will raise an error because if steps != 1 this means that we + will re-run the passes, allowing for circular dependencies. + """ + self.passes = _topological_sort_passes(self.passes, self.constraints) + self._validated = True + + def add_checks(self, check: Callable) -> None: + """ + Adds a function which takes runs various checks on a given graph module. + This function is run before and after each pass if the + `run_checks_after_each_pass` flag is enabled. + """ + sig = inspect.signature(check) + + if len(list(sig.parameters.values())) != 1: + raise TypeError( + "PassManager check function should only take in one variable, a module" + ) + + setattr(self, "check", check) # noqa: B010 + + def check(self, module: nn.Module) -> None: + pass + + def __call__(self, module: nn.Module) -> PassResult: + """ + Runs a list of passes in the order based on `self.passes` on the given + graph module. Each time a pass is run, checks and linting will be run on + the graph module if `run_checks_after_each_pass` is set. + + If the module is a graph module, we will run the list of passes until + the graph stops changing, or until `steps` number of times. + """ + # Order the passes based on the constraints + if not self._validated: + self.solve_constraints() + + # Check graph invariants + self.check(module) + + # Run the set of passes `steps` number of times or until the graph stops + # changing + overall_modified = False + for _ in range(self.steps): + modified = False + + # Run the set of passes on the graph module + for i, fn in enumerate(self.passes): + fn_name = fn.__name__ if inspect.isfunction(fn) else type(fn).__name__ + logger.debug("Running pass '%s'", fn_name) + + try: + res = fn(module) + + if not isinstance(res, PassResult) and not hasattr( + res, "graph_module" + ): + raise TypeError( + f"The result of the pass {fn_name} should be type PassResult." + + "Please wrap it with pass_result_wrapper()" + ) + module = res.graph_module + modified = modified or res.modified + + if isinstance(module, GraphModule): + logger.debug("Graph after pass '%s': %s", fn_name, module.graph) + module.recompile() + + # Check graph invariants + if self.run_checks_after_each_pass: + self.check(module) + + except Exception as e: + prev_pass_names = [ + p.__name__ if inspect.isfunction(p) else type(p).__name__ + for p in self.passes[:i] + ] + msg = f"An error occurred when running the '{fn_name}' pass after the following passes: {prev_pass_names}" + raise Exception(msg) from e # noqa: TRY002 + + # If the graph no longer changes, then we can stop running these passes + overall_modified = overall_modified or modified + if not modified: + break + + return PassResult(module, overall_modified) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/net_min_base.py b/phivenv/Lib/site-packages/torch/fx/passes/net_min_base.py new file mode 100644 index 0000000000000000000000000000000000000000..75738989cd54527ff22929b9ed9628e1c9a2a66c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/net_min_base.py @@ -0,0 +1,978 @@ +# mypy: allow-untyped-defs +import logging +from dataclasses import dataclass +from typing import Any, Callable, cast, Optional + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg + +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + Names, + NodeList, + NodeSet, + TensorOrTensors, + Tensors, +) + + +__all__ = [ + "FxNetMinimizerBadModuleError", + "FxNetMinimizerRunFuncError", + "FxNetMinimizerResultMismatchError", +] + +_LOGGER = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerBadModuleError(Exception): + """ + Raised if failed to split out a minimize module + """ + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerRunFuncError(Exception): + """ + Raised if error occurs during run_a or run_b functions + """ + + +@compatibility(is_backward_compatible=False) +class FxNetMinimizerResultMismatchError(Exception): + """ + Raised if comparing function thinks the results are mismatching. + """ + + +@dataclass +class _MinimizerSettingBase: + """ + Args: + `accumulate_error`: Instead of using a's input for both converted module to verify + , use the previous outputs of each converted module as input to accumulate the + errors. + + `traverse_method`: "sequential" or "binary" or "accumulate" + Determine the way of traverse the nodes in FX module. + + `find_all`: Minimizer will go through the entire model and return all problematic nodes. + + `return_intermediate`: If true, when using `run_nodes()` function to run the + model, intermediate results of all the ops will be returned as output. + + `all_outputs`: If true, when using `_run_and_compare()` function, + all the output nodes in the subgraph will be used for comparison. + """ + + accumulate_error: bool = False + traverse_method: str = "sequential" + find_all: bool = False + return_intermediate: bool = False + all_outputs: bool = False + + def __str__(self): + settings_str = "FX Minimizer Settings:\n" + + for k, v in vars(self).items(): + settings_str += f"\t{k}: {v}\n" + + return settings_str + + +class _MinimizerBase: + """ + This class is used to automatically find problematic nodes in a model. It takes a FX + graphmodule and generate some submodules while traverse the graph. Then two functions + `run_a` and `run_b` will be used to run the same submodule and a function `compare_fn` + will be used to compare the results. + + Currently we provides two ways to traverse the graph and generate submodules. + 1. Sequential traversal: this will traverse the graph node by node and generate + one submodule with one sigle node. + 2. Binary searching: this will do a binary search style traversal on the graph. + + For internal Users, a guide can be found here https://fb.quip.com/HDtuAgiKGfkP. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Tensors, + compare_fn: Callable[ + [TensorOrTensors, TensorOrTensors, Names], tuple[float, bool] + ], + settings: _MinimizerSettingBase, + module_exporter: Optional[ + Callable[[Tensors, torch.fx.GraphModule, str], None] + ] = None, + exclusion_fn: Optional[Callable[[NodeList, int, int], None]] = None, + ): + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + self.sample_input = sample_input + self.compare_fn = compare_fn + self.module_exporter = module_exporter + self.settings = settings + self.exclusion_fn = exclusion_fn + + # Stores outputs of run_a function + self.a_outputs: dict[str, Any] = {} + + # Stores outputs of run_b function + self.b_outputs: dict[str, Any] = {} + + # Stores the results of compare_fn + self.results: dict[Any, Any] = {} + + # Stores the report for the runs + self.reports: list[list[str]] = [] + + # Current iteration + self.iteration: int = 0 + + callable_nodes = { + node for node in self.module.graph.nodes if node.op in CALLABLE_NODE_OPS + } + self.run_shape_prop() + self.fusions = FxNetAccFusionsFinder(self.module, callable_nodes)() + + # Check if number of input in sample_input matches the number of placeholders + placeholders = [ + node.name for node in self.module.graph.nodes if node.op == "placeholder" + ] + assert len(placeholders) == len(self.sample_input) + + # Store sample_input + for i, name in enumerate(placeholders): + self.a_outputs[name] = sample_input[i] + self.b_outputs[name] = sample_input[i] + + def run_shape_prop(self) -> None: + """ + Helper function to run shape propagation on module. Can be overridden by + subclasses for custom shape propagation logic. + """ + ShapeProp(self.module).propagate(*self.sample_input) + + def run_a( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_b(). + """ + raise RuntimeError("run_a() is not implemented.") + + def run_b( + self, mod: torch.fx.GraphModule, inputs: Tensors, report_idx: int = -1 + ) -> TensorOrTensors: + """ + Run `mod` with `inputs` and generate output. The output will be compared with + output of run_a(). + """ + raise RuntimeError("run_b() is not implemented.") + + def _store_outputs( + self, + a_result: TensorOrTensors, + b_result: TensorOrTensors, + submodule: torch.fx.GraphModule, + ): + """ + Store the outputs of self.run_a() and self.run_b() into self.a_outputs and + self.b_outputs, so that we can use them when execute preceding nodes that + use those outputs as inputs. + + Args: + a_result: Output of self.run_a(). Could be a tensor or tensors. + b_result: Output of self.run_b(). Could be a tensor or tensors. + submodule: The module that generates a_result and b_result. + """ + output_node = next( + node for node in submodule.graph.nodes if node.op == "output" + ) + + # Only one output + if isinstance(output_node.args[0], torch.fx.Node): + self.a_outputs[output_node.args[0].name] = a_result + self.b_outputs[output_node.args[0].name] = b_result + # Multiple outputs + else: + for i, arg in enumerate(output_node.args[0]): + self.a_outputs[arg.name] = a_result[i] + self.b_outputs[arg.name] = b_result[i] + + def _get_submod_inputs( + self, main_module: torch.fx.GraphModule, submod_path: str + ) -> tuple[Tensors, Tensors]: + """ + Try get submodule inputs from stored outputs. If not found then use + torch_glow.get_submod_inputs to get the inputs. + + If accumulate_error is False, use a_input for run_a() and run_b() + otherwise use a_input for run_a and b_input for run_b. + + Args: + main_module: Top-levlel fx module. + submod_path: Path to the submodule we want to run and compare results. + + Returns: + a_input: List of tensor(s) that will be used by run_a() as submodule inputs. + b_input: List of tensor(s) that will be used by run_b() as submodule inputs. + """ + a_input = [] + b_input = [] + submodule = getattr(main_module, submod_path) + placeholders = [ + node.name for node in submodule.graph.nodes if node.op == "placeholder" + ] + + # If all placeholder can be found in stored outputs, use stored + # outputs as inputs. Otherwise, use `torch_glow.get_submod_inputs` + # to get the inputs. + if set(placeholders) <= self.a_outputs.keys(): + for name in placeholders: + a_input.append(self.a_outputs[name]) + b_input.append(self.b_outputs[name]) + else: + if self.settings.accumulate_error: + print(f"Can't find previous stored outputs named {placeholders}!") + + def get_inputs(self: torch.nn.Module, inputs: Any): + nonlocal a_input + a_input = inputs + + # Use forward hook to get the inputs to the submodule + handle = submodule.register_forward_pre_hook(get_inputs) + main_module(*self.sample_input) + handle.remove() + + b_input = a_input + + if not self.settings.accumulate_error: + return a_input, a_input + + return a_input, b_input + + def _tag_nodes(self, selected_nodes: NodeSet): + """ + Tag selected nodes with tag "minimize". Nodes with the same tags will + be split to the same submodule afterwards. + + Args: + selected_nodes: Nodes that we want to minimize. We will tag those nodes + with "minimize", all preceding nodes with "main_0" and all following + nodes with "main_1". + """ + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node in selected_nodes: + node.tag = "minimize" + elif any( + n.tag in {"minimize", "main_1"} + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS + ): + node.tag = "main_1" + else: + node.tag = "main_0" + + def _build_submodule(self, nodes: NodeSet) -> tuple[torch.fx.GraphModule, str]: + """ + Split self.module so that one submodule consists of `nodes` and only `nodes`. + + Args: + nodes: Nodes that we want to include in the minimize submodule. + + Returns: + split_module (torch.fx.GraphModule): the module after split. + submodule_name (str): the name of the submodule that consists of `nodes`. + """ + # Color provided nodes + self._tag_nodes(nodes) + + # Split module based on coloring + split_module = split_by_tags(self.module, ["main_0", "minimize", "main_1"]) + + # Find submodule containing colored nodes + submodule_name: str = "" + for child_name, _ in split_module.named_children(): # type: ignore[union-attr] + # Skip submodules we're not interested in at the moment + if "minimize" not in child_name: + continue + + if submodule_name == "": + submodule_name = child_name + else: + raise FxNetMinimizerBadModuleError( + f"Expected only one minimize submodule with nodes {nodes}" + ) + + if submodule_name == "": + raise FxNetMinimizerBadModuleError( + f"Minimize submodule was not found with nodes {nodes}" + ) + + return split_module, submodule_name # type: ignore[return-value] + + def _run_and_compare( + self, + split_module: torch.fx.GraphModule, + submod_name: str, + output_names: Names, + report_idx: int = -1, + ): + """ + Run the submodule in `split_module` that has name `submod_name` + using `self.run_a` and `self.run_b` and compare their results. + + Args: + split_module: Main module that contains the minimize submodule. + submod_name: Name of the minimize submodule. + output_names: Names of the node we want to output. If None, we + will use the original output. + """ + submodule = getattr(split_module, submod_name) + a_input, b_input = self._get_submod_inputs(split_module, submod_name) + + if len(self.reports) == 0: + self.reports.append([]) + self.iteration = 1 + + report = self.reports[report_idx if report_idx >= 0 else self.iteration - 1] + report.append("Run and compare ...") + + if output_names and not self.settings.all_outputs: + output_nodes: NodeList = [] + for node in submodule.graph.nodes: + if node.op == "output": + submodule.graph.erase_node(node) + + if node.name in output_names: + output_nodes.append(node) + + submodule.graph.output( + output_nodes[0] if len(output_nodes) == 1 else tuple(output_nodes) + ) + submodule.graph.lint() + submodule.recompile() + + # Use name of args in output node as key to store comparison result + for node in submodule.graph.nodes: + if node.op == "output": + result_key = map_arg(node.args, lambda x: x.name) + + try: + a_result = self.run_a(submodule, a_input, report_idx) + b_result = self.run_b(submodule, b_input, report_idx) + self._store_outputs(a_result, b_result, submodule) + except Exception as e: + report.append(f"Exception raised when running {submod_name}: {e}") + raise FxNetMinimizerRunFuncError( # noqa: B904 + f"Exception raised when running {submod_name}: {e}" + ) + + # Compare results + names: Names = output_names + if output_names is None: + names = [str(v) for v in result_key] # type: ignore[possibly-undefined] + + numeric_result, bool_result = self.compare_fn(a_result, b_result, names) + + self.results[result_key] = numeric_result # type: ignore[possibly-undefined] + report.append(f"Numerical accuracy = {numeric_result}") + if not bool_result: + report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] + if self.module_exporter: + if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + result_key = result_key[-1] + # If the result is still a tuple (happens in non-sequential mode), + # we only use the first element as name. + if isinstance(result_key, tuple): # type: ignore[possibly-undefined] + result_key = str(result_key[0]) + # pyre-ignore[29]: not a function + self.module_exporter( + a_input, + submodule, + result_key + "_cpu", + ) + # pyre-ignore[29]: not a function + self.module_exporter( + b_input, + submodule, + result_key + "_acc", + ) + raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined] + + def _binary_search_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: + """ + Recursive binary search implementation. + """ + culprits: NodeSet = set() + nodes: NodeList = all_nodes[start_idx:end_idx] + + report: list[str] = [] + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, start_idx, end_idx) + if len(nodes) == 0: + report = ["All nodes are excluded by user"] + self.reports.append(report) + return culprits + + first_node_name = nodes[0].name + output_node_name = nodes[-1].name + self.iteration += 1 + self.reports.append(report) + report.append(f"Binary search iteration {self.iteration}") + report.append( + f"From node index {start_idx}:{first_node_name} to {end_idx - 1}:{output_node_name}. " + f"Size of the interested node list is {len(nodes)}" + ) + cur_nodes: NodeSet = set(nodes) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [output_node_name]) + + except (FxNetMinimizerRunFuncError, FxNetMinimizerResultMismatchError): + if len(nodes) == 1: + report.append( + f"This is the last node in the sub-module. " + f"Search in the current branch is successful with culprit = {cur_nodes}." + ) + self.print_report(report) + return cur_nodes + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + mid = len(nodes) // 2 + culprits = self._binary_search_impl(all_nodes, start_idx, start_idx + mid) + + if len(culprits) != 0 and not self.settings.find_all: + return culprits + + culprits = self._binary_search_impl(all_nodes, start_idx + mid, end_idx) + + if len(culprits) == 0: + report.append( + f"Further split and lowering found no errors. " + f"Unable to minimize the submodule with list of nodes: {nodes}" + ) + self.print_report(report) + + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + def _binary_traverse(self, nodes: NodeList) -> NodeSet: + """ + Binary search on `nodes` for culprit. + """ + return self._binary_search_impl(nodes, 0, len(nodes)) + + def _sequential_traverse(self, nodes: NodeList) -> NodeSet: + """ + Traverse `nodes` one by one and determine if any of them is a culprit. + """ + culprits: NodeSet = set() + + for node in nodes: + report: list[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Sequential traverse iteration {self.iteration}.") + report.append(f"Visit node: {node.name}") + + _LOGGER.info("Visit node: %s", node.name) + node_list: NodeList = [node] + if self.exclusion_fn is not None: + self.exclusion_fn(node_list, -1, -1) + if len(node_list) == 0: + report.append(f"User exclusion : {node.name}") + self.print_report(report) + if not self.settings.find_all: + return culprits + else: + continue + + cur_nodes: NodeSet = {node} + + if node in self.fusions: + cur_nodes = self.fusions[node] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [node.name]) + self.print_report(report) + except FxNetMinimizerResultMismatchError: + culprits.add(node) + report.append(f"Found culprit from numeric error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + except FxNetMinimizerRunFuncError: + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {node}") + self.print_report(report) + if not self.settings.find_all: + return culprits + + return culprits + + def _block_traverse_impl( + self, nodes: NodeList, start_idx: int, end_idx: int, find_last_node: bool + ) -> Optional[int]: + """ + Recursive block search implementation. + find_last_node: If True, search for the last node which result in numerics difference + if False: find first node in sorted node list + """ + report: list[str] = [] + + mid = (start_idx + end_idx) // 2 + cur_nodes_list: NodeList = nodes[: mid + 1] if find_last_node else nodes[mid:] + + if self.exclusion_fn: + self.exclusion_fn(cur_nodes_list, -1, -1) + + cur_nodes = set(cur_nodes_list) + + first_node_name = cur_nodes_list[0].name + last_node_name = cur_nodes_list[-1].name + target_node_name = last_node_name if find_last_node else first_node_name + + self.iteration += 1 + self.reports.append(report) + report.extend( + [ + "=" * 30, + f"Block search iteration {self.iteration}", + ] + ) + report.extend( + [ + f"Search for {'last' if find_last_node else 'first'} node in culprits", + f"From node index {start_idx}:{nodes[start_idx].name} to {end_idx}:{nodes[end_idx].name}. ", + f"Subgraph constructed by {first_node_name} to {last_node_name}", + f"Targeting node: {target_node_name}", + f"Size of the interested node list is {end_idx - start_idx + 1}", + ] + ) + report_idx = len(self.reports) - 1 + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare( + split_module, submod_name, [last_node_name], report_idx + ) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + report.append( + f"Culprits found from node {first_node_name} to {last_node_name}." + ) + + if start_idx == mid == end_idx: + report.extend( + [ + "This is the last node in the sub-module. ", + "Search in the current branch is successful with node :", + f"{start_idx}, node name: {nodes[start_idx].name}.", + ] + ) + self.print_report(report) + return start_idx + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + if find_last_node: + return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) + else: + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) + else: + report.append( + f"Culprits not found from node start to {mid}:{nodes[mid].name}." + ) + + if start_idx == mid == end_idx: + # We did not find anything if the pointers have not moved + if (start_idx == 0 and not find_last_node) or ( + start_idx == len(nodes) - 1 and find_last_node + ): + report.append( + f"At {'last' if find_last_node else 'first'} node, no culprits found." + ) + self.print_report(report) + return None + + # Otherwise, we have converged on the border between discrepancy and valid + return start_idx + (1 if find_last_node else -1) + + report.append( + "Proceed to split and lower the halves of the current " + "sub-module individually." + ) + self.print_report(report) + + if find_last_node: + return self._block_traverse_impl( + nodes, mid + 1, end_idx, find_last_node + ) + else: + return self._block_traverse_impl(nodes, start_idx, mid, find_last_node) + + def _block_traverse( + self, nodes: NodeList, find_last_node: Optional[bool] + ) -> NodeSet: + """ + Traverse topologically sorted node list + Find minimium block (start_idx, end_idx) which contains the culprit + 1st pass: search for end_idx by finding the last node in culprit block + where Numerical accuracy (0, end_idx) > threshold + 2nd pass: search for start_idx by finding the first node in culprit block + where Numerical accuracy (start_idx, end_idx) < threshold + Form minimum block by (start_idx - 1, end_idx) + """ + culprits: NodeSet = set() + first_node_name = nodes[0].name + last_node_name = nodes[-1].name + last_node_report = [f"Block search from {first_node_name} to {last_node_name}"] + last_node_report.append("*" * 50) + self.reports.append(last_node_report) + + start_idx = 0 + end_idx = len(nodes) - 1 + + final_start_idx: Optional[int] = start_idx + final_end_idx: Optional[int] = end_idx + + run_both = True if find_last_node is None else False + + # step 1: find (0, end_idx) of culprit block + if run_both or find_last_node: + last_node_report.append("Start searching for last node in culprit") + self.print_report(last_node_report) + final_end_idx = self._block_traverse_impl(nodes, start_idx, end_idx, True) + + if final_end_idx is None: + last_node_report.append("No culprits found") + self.print_report(last_node_report) + return culprits + + last_node_report.extend( + [ + "Finish Pass 1", + f"Find end_idx = {final_end_idx}:{nodes[final_end_idx].name}", + ] + ) + self.print_report(last_node_report) + + # step 2: reduce culprit block to (start_idx, end_idx) + if run_both or not find_last_node: + first_node_report = ["Start searching for first node in culprit"] + self.print_report(first_node_report) + final_start_idx = self._block_traverse_impl( + nodes[0 : end_idx + 1], start_idx, final_end_idx or end_idx, False + ) + + if final_start_idx is None: + last_node_report.append("No culprits found") + self.print_report(last_node_report) + return culprits + + first_node_report.append("*" * 50) + self.reports.append(first_node_report) + first_node_report.extend( + [ + "Finish Pass 2", + f"Find start_idx = {final_start_idx}:{nodes[final_start_idx].name}", + ] + ) + self.print_report(first_node_report) + + # step 3: form module with minimum culprits. These indexes are guaranteed to exist + range_start, range_end = cast(int, final_start_idx), cast(int, final_end_idx) + culprits.update(nodes[range_start : range_end + 1]) + result_report = [ + f"Finish searching, found minimum block ({nodes[range_start]},{nodes[range_end]})" + ] + self.reports.append(result_report) + self.print_report(result_report) + return culprits + + def _defined_traverse(self, nodes: NodeList) -> NodeSet: + """ + run user defined `nodes` and determine if it is a culprit. + """ + culprits: NodeSet = set() + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, -1, -1) + if len(nodes) == 0: + report = ["All nodes are excluded by user"] + self.reports.append(report) + return culprits + + first_node_name = nodes[0].name + output_node_name = nodes[-1].name + report = [f"Defined graph from {first_node_name} to {output_node_name}"] + cur_nodes: NodeSet = set(nodes) + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, [output_node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + report.append(f"Found culprit {cur_nodes}") + self.print_report(report) + return culprits + + return culprits + + def _accumulate_traverse(self, nodes: NodeList) -> NodeSet: + culprits: NodeSet = set() + nodes_to_run: NodeSet = set() + + # find_all is not supported for accumulate traversal because all the + # ops run on NNPI. So we return after the first op that raises error. + if self.settings.find_all: + print("'Find All' mode is not supported in accumulate traversal.") + return culprits + + for node in nodes: + report: list[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f"Accumulate traverse iteration {self.iteration}.") + + nodes_to_run.add(node) + + node_name = node.name + if node_name is not None and isinstance(node_name, tuple): + node_name = node_name[0] + assert node_name is not None and isinstance(node_name, str), ( + f"minimize: node_name: {node_name}" + ) + + report.append(f"Add node: {node_name}") + + try: + split_module, submod_name = self._build_submodule(nodes_to_run) + self._run_and_compare(split_module, submod_name, [node_name]) + self.print_report(report) + except (FxNetMinimizerResultMismatchError, FxNetMinimizerRunFuncError): + culprits.add(node) + report.append(f"Found culprit {node}") + self.print_report(report) + return culprits + + return culprits + + def _skip_traverse_impl( + self, all_nodes: NodeList, start_idx: int, end_idx: int + ) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + culprits: NodeSet = set() + nodes: NodeList = all_nodes[start_idx:end_idx] + cur_nodes: NodeSet = set(nodes) + if self.exclusion_fn is not None: + self.exclusion_fn(nodes, start_idx, end_idx) + cur_nodes = set(nodes) + else: + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + report: list[str] = [] + self.reports.append(report) + self.iteration += 1 + report.append(f" Nodes block {self.iteration}.") + report.append( + f"From node index {start_idx} to {end_idx - 1}. " + f"Size of the interested node list is {len(nodes)}" + ) + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, []) + except FxNetMinimizerResultMismatchError: + culprits.update(cur_nodes) + report.append(f"Found culprit from numeric error: {cur_nodes}") + self.print_report(report) + return culprits + except FxNetMinimizerRunFuncError: + culprits.update(cur_nodes) + report.append(f"Found culprit from run error: {cur_nodes}") + self.print_report(report) + return culprits + else: + report.append("No discrepancy found.") + self.print_report(report) + return set() + + def _skip_traverse(self, all_nodes: NodeList, skip_nodes: list) -> NodeSet: + """ + Skip certain nodes in graph based on settings + """ + start_idx = 0 + num_nodes = len(all_nodes) + idx = 0 + culprits = set() + while idx < num_nodes: + node = all_nodes[idx] + if node.name in skip_nodes: # skip the node + if idx > start_idx: + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx) + start_idx = idx + 1 + elif idx == num_nodes - 1 and start_idx <= idx: # last node + culprits = self._skip_traverse_impl(all_nodes, start_idx, idx + 1) + idx += 1 + + return culprits + + def _collect_nodes(self, start: Optional[str], end: Optional[str]) -> NodeList: + """ + Collect nodes in the model that between nodes with name of `start` and `end`. + These two nodes are also included. + """ + nodes: NodeList = [] + add_node = start is None + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + if node.name == start: + add_node = True + + if add_node: + nodes.append(node) + + if node.name == end: + break + + return nodes + + def run_nodes(self, start: Optional[str] = None, end: Optional[str] = None): + """ + Run part of the model from `start` node to `end` node. If `start` is None + then we start from the beginning of the model. If `end` is None then we + stop at the end of the model. + + Args: + start: The name of the node which is the first node of the submodule + we want to run. If set to None, then we'll start with the first + node of the model. + end: The name of the node which is the last node of the submodule we + want to run. If set to None, we'll end with the last node of the + model. + """ + nodes = self._collect_nodes(start, end) + cur_nodes = set(nodes) + + for node in nodes: + if node in self.fusions: + cur_nodes.update(self.fusions[node]) + + output_names = [] + if self.settings.return_intermediate: + output_names = [node.name for node in nodes] + + try: + split_module, submod_name = self._build_submodule(cur_nodes) + self._run_and_compare(split_module, submod_name, output_names) + except ( + FxNetMinimizerRunFuncError, + FxNetMinimizerResultMismatchError, + ) as e: + print(e) + + def print_report(self, report: list[str]): + for i in range(len(report)): + if i > 0: + print(" . " + report[i]) + else: + print(report[i]) + + def print_reports(self): + for report in self.reports: + self.print_report(report) + + def minimize( + self, + start: Optional[str] = None, + end: Optional[str] = None, + skip_nodes: Optional[list] = None, + find_last_node: Optional[bool] = None, + ) -> NodeSet: + """ + Minimizing the model from node with name `start` to node with name `end` base + on self.settings. Find culprits that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors. + + Args: + start: The name of the node where we want to start minimizing. If set + to None, then we'll start with the first node of the model. + end: The name of the node where we want to terminate minimizing. If + set to None, we'll end with the last node of the model. + skip_nodes: The names of nodes where we want to skip during minimizing. + It'll create subgraphs without these skip nodes under the hood. + Only applicable in mode "skip". + find_last_node: True if only last_node of a culprits is needed in mode "block". + False if only the first_node of a culprits is needed. + Only applicable in mode "block". + + Returns: + nodes: A list of nodes that causes FxNetMinimizerRunFuncError or + FxNetMinimizerResultMismatchError errors during minimizing. + """ + + print(self.settings) + print(self.module.graph) + + nodes = self._collect_nodes(start, end) + + if self.settings.traverse_method == "sequential": + return self._sequential_traverse(nodes) + + if self.settings.traverse_method == "binary": + return self._binary_traverse(nodes) + + if self.settings.traverse_method == "accumulate": + return self._accumulate_traverse(nodes) + + if self.settings.traverse_method == "skip": + if skip_nodes is None: + raise RuntimeError( + "'skip_nodes' can't be None when 'traverse_method' is 'skip'." + ) + return self._skip_traverse(nodes, skip_nodes) + + if self.settings.traverse_method == "defined": + return self._defined_traverse(nodes) + + if self.settings.traverse_method == "block": + return self._block_traverse(nodes, find_last_node) + + raise RuntimeError(f"Unknown traverse method {self.settings.traverse_method}!") diff --git a/phivenv/Lib/site-packages/torch/fx/passes/operator_support.py b/phivenv/Lib/site-packages/torch/fx/passes/operator_support.py new file mode 100644 index 0000000000000000000000000000000000000000..57c59775b862de9c33121a8601d9559c6ceca2d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/operator_support.py @@ -0,0 +1,229 @@ +# mypy: allow-untyped-defs +import abc +import typing as t + +import torch +import torch.fx +from torch.fx._compatibility import compatibility + +from .shape_prop import TensorMetadata +from .tools_common import CALLABLE_NODE_OPS, get_node_target + + +__all__ = [ + "OperatorSupportBase", + "OperatorSupport", + "create_op_support", + "chain", + "OpSupports", + "any_chain", +] + +# fx.Node.target typename, as returned by `get_node_target()` +TargetTypeName = str + +# Arguments' dtypes for a given node, see `OperatorSupport` +SupportedArgumentDTypes = t.Optional[ + tuple[ + t.Sequence[t.Sequence[torch.dtype]], + dict[str, t.Sequence[torch.dtype]], + ] +] + +SupportDict = t.Mapping[TargetTypeName, SupportedArgumentDTypes] + + +@compatibility(is_backward_compatible=False) +class OperatorSupportBase(abc.ABC): + """Interface for determining if a fx.Node is supported by a backend""" + + @abc.abstractmethod + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + raise NotImplementedError + + +@compatibility(is_backward_compatible=False) +class OperatorSupport(OperatorSupportBase): + """ + `_support_dict` maps node.target typename to supported inputs dtypes. + + node.target typename is retrieved using helper function `get_node_target()` + + If supported inputs dtypes is None, it means any dtype is supported, else + we should see a tuple like (([dtypes], ...), {"name":[dtypes], ...}). + + The first tuple ([dtypes], ...) indicates what dtypes are supported for + inputs in node.args and the second dict {"name": [dtypes], ...} indicates + what dtypes are supported for inputs in node.kwargs. + + For inputs in args, if we don't want to check it, we can put None there, + e.g. (None, [torch.float]) indicates that we don't care about the type of + the first input in args. And for inputs in kwargs, if not listed, will not + be checked. + """ + + _support_dict: SupportDict + + def __init__(self, support_dict: t.Optional[SupportDict] = None): + self._support_dict = support_dict or {} + + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + """ + Args: + `submodules`: mapping from module name to the module. This can be + retrieved by calling model.named_modules(). + + `node`: a Fx node that we want to determine whether it's supported. + + Returns: + `is_supported`: whether the arg `node` is supported. + """ + if node.op not in CALLABLE_NODE_OPS: + return True + + target = get_node_target(submodules, node) + + # Target not found in _support_dict meaning that we don't support this op at all + if target not in self._support_dict: + return False + + # The rule for target is None meaning that we accept any dtype + if self._support_dict[target] is None: + return True + + args_dtypes, kwargs_dtypes = self._support_dict[target] # type: ignore[misc] + + # Check args dtypes + for i, dtypes in enumerate(args_dtypes): + if len(node.args) <= i: + break + + # None indicates we don't care about the dtype of args[i] + if dtypes is None: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.args[i], torch.fx.Node): + continue + + arg_dtype = _get_arg_dtype(node.args[i]) # type: ignore[arg-type] + if arg_dtype not in dtypes: + return False + + # Check kwargs dtypes + for k, dtypes in kwargs_dtypes.items(): + if k not in node.kwargs: + continue + + # If arg is not a node then we don't check it + if not isinstance(node.kwargs[k], torch.fx.Node): + continue + + kwarg_dtype = _get_arg_dtype(node.kwargs[k]) # type: ignore[arg-type] + if kwarg_dtype not in dtypes: + return False + + return True + + +# ====================================================================== +# Functional interfaces and utils for defining basic operator support logic +# and composing them into more complex ones +# ====================================================================== + +IsNodeSupported = t.Callable[[t.Mapping[str, torch.nn.Module], torch.fx.Node], bool] + + +@compatibility(is_backward_compatible=False) +def create_op_support(is_node_supported: IsNodeSupported) -> OperatorSupportBase: + """Wraps a `IsNodeSupported` function into an `OperatorSupportBase` instance + + `IsNodeSupported` has the same call signature as + `OperatorSupportBase.is_node_supported` + """ + + class FunctionalOperatorSupport(OperatorSupportBase): + def is_node_supported( + self, submodules: t.Mapping[str, torch.nn.Module], node: torch.fx.Node + ) -> bool: + return is_node_supported(submodules, node) + + return FunctionalOperatorSupport() + + +@compatibility(is_backward_compatible=False) +def chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns False if + any of it reports False. + """ + + def _chain(submods, node) -> bool: + return all(x.is_node_supported(submods, node) for x in op_support) + + return create_op_support(_chain) + + +@compatibility(is_backward_compatible=False) +def any_chain(*op_support: OperatorSupportBase) -> OperatorSupportBase: + """Combines a sequence of `OperatorSupportBase` instances to form a single `OperatorSupportBase` + instance by evaluating each input `OperatorSupportBase` instance, and returns True if + any of it reports True. + """ + + def _any_chain(submods, node) -> bool: + return any(x.is_node_supported(submods, node) for x in op_support) + + return create_op_support(_any_chain) + + +@compatibility(is_backward_compatible=False) +class OpSupports: + """A set of atomic `OperatorSupportBase` instances that can be combined together + to form more complex operator support logic. + """ + + @classmethod + def decline_if_input_dtype(cls, dtype: torch.dtype) -> OperatorSupportBase: + """Report a node as non-supported, if any of its arguments is of dtype""" + + def _decline_if_input_dtype( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + for arg in node.all_input_nodes: + arg_dtype = _get_arg_dtype(arg) + if arg_dtype == dtype: + return False + return True + + return create_op_support(_decline_if_input_dtype) + + @classmethod + def decline_if_node_in_names(cls, disallow_set: set[str]) -> OperatorSupportBase: + """ + If a node has a name that is in the disallow set, reported it as non-supported. + """ + + def _decline_if_node_in_names( + submodules: t.Mapping[str, torch.nn.Module], + node: torch.fx.Node, + ) -> bool: + return node.name not in disallow_set + + return create_op_support(_decline_if_node_in_names) + + +def _get_arg_dtype(arg: torch.fx.Node) -> t.Any: + assert isinstance(arg, torch.fx.Node) + tensor_meta = arg.meta.get("tensor_meta") # type: ignore[union-attr] + dtype = ( + tensor_meta.dtype + if isinstance(tensor_meta, TensorMetadata) + else arg.meta["type"] + ) + return dtype diff --git a/phivenv/Lib/site-packages/torch/fx/passes/param_fetch.py b/phivenv/Lib/site-packages/torch/fx/passes/param_fetch.py new file mode 100644 index 0000000000000000000000000000000000000000..13ebc05246a76590c6d4d6465cab95f1d2b22bd7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/param_fetch.py @@ -0,0 +1,96 @@ +from typing import Any, Callable + +import torch +import torch.nn as nn +from torch.fx._compatibility import compatibility +from torch.fx.graph_module import GraphModule + + +__all__ = [ + "default_matching", + "extract_attrs_for_lowering", + "lift_lowering_attrs_to_nodes", +] + + +# Matching method matches the attribute name of current version to the attribute name of `target_version` +@compatibility(is_backward_compatible=False) +def default_matching(name: str, target_version: int) -> str: + """Default matching method""" + return name + + +# This dict maps the nn.Module class name to the attribute name list that we want to fetch for lowering. +# The first integer in the tuple is the version number of the nn.Module class when we create the parameter list. +# If there's a version mismatch then it means the parameter names in the book might be mismatched with nn.Module. +module_fetch_book: dict[type, tuple[int, list[str], Callable[[str, int], str]]] = { + torch.nn.modules.linear.Linear: (1, ["weight", "bias"], default_matching), + torch.nn.modules.conv.Conv2d: ( + 1, + [ + "weight", + "bias", + "kernel_size", + "stride", + "padding", + "dilation", + "groups", + "padding_mode", + ], + default_matching, + ), + torch.nn.modules.batchnorm.BatchNorm2d: ( + 2, + ["weight", "bias", "running_mean", "running_var", "eps"], + default_matching, + ), + torch.nn.modules.pooling.AdaptiveAvgPool2d: (1, [], default_matching), + torch.nn.modules.pooling.MaxPool2d: ( + 1, + ["kernel_size", "stride", "padding", "dilation", "return_indices", "ceil_mode"], + default_matching, + ), + torch.nn.modules.activation.ReLU: (1, ["inplace"], default_matching), +} + + +@compatibility(is_backward_compatible=False) +def extract_attrs_for_lowering(mod: nn.Module) -> dict[str, Any]: + """If `mod` is in `module_fetch_book`, fetch the mod's attributes that in the `module_fetch_book` + after checking module's version is compatible with the `module_fetch_book`. + """ + attrs_for_lowering: dict[str, Any] = {} + attrs_for_lowering["name"] = torch.typename(mod) + + if type(mod) in module_fetch_book: + version, param_to_fetch, matching_method = module_fetch_book[type(mod)] + if version < mod._version: + raise RuntimeError( + f"Fetcher version {version} try to fetch {torch.typename(mod)} version {mod._version}, " + "please upgrade the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) + for attr in param_to_fetch: + attrs_for_lowering[attr] = getattr(mod, matching_method(attr, mod._version)) + else: + raise RuntimeError( + f"{torch.typename(mod)} is not in the module_fetch_book yet, " + "please add it to the module_fetch_book, open an issue and @842974287 " + "or report a bug to AIACC team directly." + ) + return attrs_for_lowering + + +@compatibility(is_backward_compatible=False) +def lift_lowering_attrs_to_nodes(fx_module: GraphModule) -> None: + """Recursively traverse all `fx_module` nodes and fetch the module's attributes if the node is a leaf module.""" + submodules = dict(fx_module.named_modules()) + + for node in fx_module.graph.nodes: + if node.op == "call_module": + if isinstance(submodules[node.target], GraphModule): + lift_lowering_attrs_to_nodes(submodules[node.target]) + else: + node.attrs_for_lowering = extract_attrs_for_lowering( + submodules[node.target] + ) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/pass_manager.py b/phivenv/Lib/site-packages/torch/fx/passes/pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..2afe24d8f0262b2cd3f77b0ed0e54f110619c5bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/pass_manager.py @@ -0,0 +1,253 @@ +# mypy: allow-untyped-defs +import logging +from functools import wraps +from inspect import unwrap +from typing import Callable, Optional + + +logger = logging.getLogger(__name__) + +__all__ = [ + "PassManager", + "inplace_wrapper", + "log_hook", + "loop_pass", + "this_before_that_pass_constraint", + "these_before_those_pass_constraint", +] + + +# for callables which modify object inplace and return something other than +# the object on which they act +def inplace_wrapper(fn: Callable) -> Callable: + """ + Convenience wrapper for passes which modify an object inplace. This + wrapper makes them return the modified object instead. + + Args: + fn (Callable[Object, Any]) + + Returns: + wrapped_fn (Callable[Object, Object]) + """ + + @wraps(fn) + def wrapped_fn(gm): + fn(gm) + return gm + + return wrapped_fn + + +def log_hook(fn: Callable, level=logging.INFO) -> Callable: + """ + Logs callable output. + + This is useful for logging output of passes. Note inplace_wrapper replaces + the pass output with the modified object. If we want to log the original + output, apply this wrapper before inplace_wrapper. + + + ``` + def my_pass(d: Dict) -> bool: + changed = False + if "foo" in d: + d["foo"] = "bar" + changed = True + return changed + + + pm = PassManager(passes=[inplace_wrapper(log_hook(my_pass))]) + ``` + + Args: + fn (Callable[Type1, Type2]) + level: logging level (e.g. logging.INFO) + + Returns: + wrapped_fn (Callable[Type1, Type2]) + """ + + @wraps(fn) + def wrapped_fn(gm): + val = fn(gm) + logger.log(level, "Ran pass %s\t Return value: %s", fn, val) + return val + + return wrapped_fn + + +def loop_pass( + base_pass: Callable, + n_iter: Optional[int] = None, + predicate: Optional[Callable] = None, +): + """ + Convenience wrapper for passes which need to be applied multiple times. + + Exactly one of `n_iter`or `predicate` must be specified. + + Args: + base_pass (Callable[Object, Object]): pass to be applied in loop + n_iter (int, optional): number of times to loop pass + predicate (Callable[Object, bool], optional): + + """ + assert (n_iter is not None) ^ (predicate is not None), ( + "Exactly one of `n_iter`or `predicate` must be specified." + ) + + @wraps(base_pass) + def new_pass(source): + output = source + if n_iter is not None and n_iter > 0: + for _ in range(n_iter): + output = base_pass(output) + elif predicate is not None: + while predicate(output): + output = base_pass(output) + else: + raise RuntimeError( + f"loop_pass must be given positive int n_iter (given " + f"{n_iter}) xor predicate (given {predicate})" + ) + return output + + return new_pass + + +# Pass Schedule Constraints: +# +# Implemented as 'depends on' operators. A constraint is satisfied iff a list +# has a valid partial ordering according to this comparison operator. +def _validate_pass_schedule_constraint( + constraint: Callable[[Callable, Callable], bool], passes: list[Callable] +): + for i, a in enumerate(passes): + for j, b in enumerate(passes[i + 1 :]): + if constraint(a, b): + continue + raise RuntimeError( + f"pass schedule constraint violated. Expected {a} before {b}" + f" but found {a} at index {i} and {b} at index{j} in pass" + f" list." + ) + + +def this_before_that_pass_constraint(this: Callable, that: Callable): + """ + Defines a partial order ('depends on' function) where `this` must occur + before `that`. + """ + + def depends_on(a: Callable, b: Callable): + return a != that or b != this + + return depends_on + + +def these_before_those_pass_constraint(these: Callable, those: Callable): + """ + Defines a partial order ('depends on' function) where `these` must occur + before `those`. Where the inputs are 'unwrapped' before comparison. + + For example, the following pass list and constraint list would be invalid. + ``` + passes = [ + loop_pass(pass_b, 3), + loop_pass(pass_a, 5), + ] + + constraints = [these_before_those_pass_constraint(pass_a, pass_b)] + ``` + + Args: + these (Callable): pass which should occur first + those (Callable): pass which should occur later + + Returns: + depends_on (Callable[[Object, Object], bool] + """ + + def depends_on(a: Callable, b: Callable): + return unwrap(a) != those or unwrap(b) != these + + return depends_on + + +class PassManager: + """ + Construct a PassManager. + + Collects passes and constraints. This defines the pass schedule, manages + pass constraints and pass execution. + + Args: + passes (Optional[List[Callable]]): list of passes. A pass is a + callable which modifies an object and returns modified object + constraint (Optional[List[Callable]]): list of constraints. A + constraint is a callable which takes two passes (A, B) and returns + True if A depends on B and False otherwise. See implementation of + `this_before_that_pass_constraint` for example. + """ + + passes: list[Callable] + constraints: list[Callable] + _validated: bool = False + + def __init__( + self, + passes=None, + constraints=None, + ): + self.passes = passes or [] + self.constraints = constraints or [] + + @classmethod + def build_from_passlist(cls, passes): + pm = PassManager(passes) + # TODO(alexbeloi): add constraint management/validation + return pm + + def add_pass(self, _pass: Callable): + self.passes.append(_pass) + self._validated = False + + def add_constraint(self, constraint): + self.constraints.append(constraint) + self._validated = False + + def remove_pass(self, _passes: list[str]): + if _passes is None: + return + passes_left = [ps for ps in self.passes if ps.__name__ not in _passes] + self.passes = passes_left + self._validated = False + + def replace_pass(self, _target, _replacement): + passes_left = [] + for ps in self.passes: + if ps.__name__ == _target.__name__: + passes_left.append(_replacement) + else: + passes_left.append(ps) + self.passes = passes_left + self._validated = False + + def validate(self): + """ + Validates that current pass schedule defined by `self.passes` is valid + according to all constraints in `self.constraints` + """ + if self._validated: + return + for constraint in self.constraints: + _validate_pass_schedule_constraint(constraint, self.passes) + self._validated = True + + def __call__(self, source): + self.validate() + out = source + for _pass in self.passes: + out = _pass(out) + return out diff --git a/phivenv/Lib/site-packages/torch/fx/passes/reinplace.py b/phivenv/Lib/site-packages/torch/fx/passes/reinplace.py new file mode 100644 index 0000000000000000000000000000000000000000..acfee6438878bb10c89a9ff613d285d1c98fb1f4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/reinplace.py @@ -0,0 +1,754 @@ +# mypy: allow-untyped-defs +import _operator +import itertools +from collections import defaultdict +from enum import Enum +from typing import Any, Callable + +import torch +from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode +from torch.fx import Node +from torch.fx._compatibility import compatibility +from torch.multiprocessing.reductions import StorageWeakRef +from torch.utils import _pytree as pytree +from torch.utils._pytree import tree_map_only + + +__all__ = ["reinplace"] + + +class _ViewType(Enum): + NonView = 0 + SingleOutputView = 1 + MultiOutputView = 2 + + +def _is_view_op(tgt): + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + return ( + first_arg.alias_info is not None and not first_arg.alias_info.is_write + ) + + +def _get_view_type(tgt) -> _ViewType: + if tgt is not None and isinstance(tgt, torch._ops.OpOverload): + schema = tgt._schema + if len(schema.arguments) > 0: + first_arg = schema.arguments[0] + # check if op is a view + if first_arg.alias_info is not None and not first_arg.alias_info.is_write: + # check if op is a multi-output view + if "*" in first_arg.alias_info.after_set: + return _ViewType.MultiOutputView + else: + return _ViewType.SingleOutputView + return _ViewType.NonView + + +# Stores a bunch of metadata related to functionalization each node. +# Relevant metadata: +# n.meta['fake_result']: FakeTensor (same type as the output of the node, but with FakeTenors instead of Tensors) +# The fake tensor output from running the current node +# n.meta['view_of']: Node +# If the current node n is a view of some base tensor, the 'view_of' field tells us which +# view node was used to generate the current node (a view tensor). +# This information actually makes `fake_result` redundant, but we can use `fake_result` +# to sanity check that our aliasing information is correct. +@compatibility(is_backward_compatible=False) +class _FunctionalizationMetadataProp(torch.fx.Interpreter): + def run_node(self, node: Node): + self.node_counter += 1 + result = super().run_node(node) + node.meta["fake_result"] = result + node.meta["node_idx"] = self.node_counter + + # (1) Update metadata with the list of nodes that are used by this node + # copy_() doesn't read from its first argument; it writes to it, overwriting previous data. + # We don't want to treat it as "being used as an input". + node_args = node.args + if node.target is torch.ops.aten.copy_.default: + node_args = node_args[1:] + + # (2) Update metadata to track aliasing information about view tensor nodes. + if node.op == "call_function": + view_type = _get_view_type(node.target) + if view_type == _ViewType.SingleOutputView: + assert isinstance(node.args[0], Node) + node.meta["view_of"] = node.args[0] + elif view_type == _ViewType.MultiOutputView: + self.multi_output_view_nodes[node] = node.args[0] + + # Check if we returned a multi-output view, + # and we're now grabbing the individual views from the output. + # + # For multi-output views, we want to map each output view to the base, + # but this mapping involves two separate nodes in FX IR. + # e.g. "a, b = x_1.split(...)" becomes: + # %split_tensor : [num_users=2] = call_function[target=torch.ops.aten.split.Tensor](args = (%x_1, 2), kwargs = {}) + # %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 0), kwargs = {}) + # %getitem_1 : [num_users=1] = call_function[target=operator.getitem](args = (%split_tensor, 1), kwargs = {}) + # And we'd like to set: + # getitem1.meta['view_of'] = x_1 + elif node.target is _operator.getitem: + list_arg = node.args[0] + maybe_base_of_view = self.multi_output_view_nodes.get(list_arg, None) + if maybe_base_of_view is not None: + # Note: we could also track indexing info here for multi-output views. + # I don't think this metadata is strictly needed for de-functionalization. + assert isinstance(maybe_base_of_view, Node) + node.meta["view_of"] = maybe_base_of_view + + if "view_of" in node.meta: + # We're linking the current node with its first argument as views. + # Assert here that this is actually the case, and their storages are the same. + assert isinstance(node.meta["fake_result"], FakeTensor) + assert isinstance(node.meta["view_of"].meta["fake_result"], FakeTensor) + view_storage = StorageWeakRef(node.meta["fake_result"]._typed_storage()) + base_storage = StorageWeakRef( + node.meta["view_of"].meta["fake_result"]._typed_storage() + ) + assert view_storage == base_storage + return result + + def propagate(self, *args): + self.multi_output_view_nodes = {} + self.node_counter = -1 + + with FakeTensorMode() as mode: + fake_args = [ + mode.from_tensor(a) if isinstance(a, torch.Tensor) else a for a in args + ] + return super().run(*fake_args) + + +def _schemas_match(functional_schema, inplace_schema): + names_match = ( + inplace_schema.name.endswith("_") + and inplace_schema.name[:-1] == functional_schema.name + ) + arg_types_match = len(functional_schema.arguments) == len( + inplace_schema.arguments + ) and all( + a1.type == a2.type + for a1, a2 in zip(functional_schema.arguments, inplace_schema.arguments) + ) + # for the inplace op, its first argument should be mutable + assert ( + inplace_schema.arguments[0].alias_info is not None + and inplace_schema.arguments[0].alias_info.is_write + ) + # and its remaining arguments shouldn't be. + assert all(a.alias_info is None for a in inplace_schema.arguments[1:]) + return names_match and arg_types_match + + +# TODO: this should be beefed up to be able to properly re-inplace with: +# - mutating ops (e.g. _fused_moving_avg_obs_fq_helper) +# - out= ops (e.g. angle -> angle.out) +# TODO: we should also figure this info out using torchgen. +def _maybe_get_inplace_op(op): + # __module__ seems broken; it returns torch._ops.aten which doesn't exist + if not isinstance(op, torch._ops.OpOverload): + return None + # Some view ops have inplace variants (as_strided_, etc), + # but we do NOT want the reinplacing pass to directly add these into the program. + # (they'll require extra special handling, aren't aren't really useful for perf anyway) + if _is_view_op(op): + return None + op_namespace = op.__module__.split(".")[-1] + op_base_name = op.overloadpacket.__name__ + maybe_namespace_module = getattr(torch.ops, op_namespace) + maybe_inplace_op = ( + None + if maybe_namespace_module is None + else getattr(maybe_namespace_module, f"{op_base_name}_", None) + ) + if maybe_inplace_op is None: + return None + + inplace_overloads = [ + getattr(maybe_inplace_op, overload_name) + for overload_name in maybe_inplace_op.overloads() + ] + inplace_overloads_with_matching_schemas = [ + f for f in inplace_overloads if _schemas_match(op._schema, f._schema) + ] + # Just because foo() and foo_() are both existing operators, + # They aren't guaranteed to have compatible schemas. + # For example, pow.Scalar(Scalar self, Tensor exponent) has no valid inplace variant, + # Even though several overloads of pow_ exist. + if len(inplace_overloads_with_matching_schemas) == 0: + return None + assert len(inplace_overloads_with_matching_schemas) == 1 + inplace_op = inplace_overloads_with_matching_schemas[0] + return inplace_op + + +_VIEW_INVERSE_MAP: dict[Callable[..., Any], Callable[..., Any]] = { + torch.ops.aten.diagonal_scatter.default: torch.ops.aten.diagonal.default, + torch.ops.aten.select_scatter.default: torch.ops.aten.select.int, + torch.ops.aten.slice_scatter.default: torch.ops.aten.slice.Tensor, + torch.ops.aten.as_strided_scatter.default: torch.ops.aten.as_strided.default, +} + + +# This function, given a set of set of (aliased) tensor nodes, +# Returns any nodes in the graph that *use* any of the aliases, that occur *after* op_index +# in the node ordering. +def _get_all_later_node_usages(tensor_aliases: set[Node], op_index: int): + def _add_if_tensor(x, set_): + if isinstance(x, FakeTensor): + set_.add(StorageWeakRef(x._typed_storage())) + + nodes_used_after = set() + for t in tensor_aliases: + # get all nodes that use the current alias + usage_nodes = t.users + for n in usage_nodes: + # We only care about usages after the current node + if "node_idx" not in n.meta or n.meta["node_idx"] <= op_index: + continue + # We also don't care about intermediate view ops. + # They only matter if their output is then used elsewhere + # (either in an out-of-place op, or as an output to the function). + if n in tensor_aliases: + if ( + isinstance(n.target, torch._ops.OpOverload) + or n.target == _operator.getitem + ): + continue + nodes_used_after.add(n) + return nodes_used_after + + +# Given an op that we're trying to re-inplace, "b = foo(a)", +# And given a {view}_scatter op that shows up later in the graph, "y = {view}_scatter(base, x, args...)" +# Then re-inplacing `foo()` would allow us to remove the `{view}_scatter` op entirely, IF: +# If there are any aliases in the alias_set(a) that satisfy: +# (1) The base of "alias", "alias_base", has the same size/stride/offset metadata as "base" +# (2) The output of running {view}(alias, args...) gives you the same size/stride/offset metadata +# as "alias" +def _get_view_inverse_node_usages( + later_node_usages: set[Node], self_aliases: set[Node] +) -> set[Node]: + def matching_view_metadata(a, b): + return ( + a.size() == b.size() + and a.stride() == b.stride() + and a.storage_offset() == b.storage_offset() + ) + + view_inverse_nodes = set() + # Go through them in node order, so we can see chains of view_scatter ops. + for n in sorted(later_node_usages, key=lambda x: x.meta["node_idx"]): + if n.target not in _VIEW_INVERSE_MAP: + continue + base = n.args[0] + mutated_view = n.args[1] + assert isinstance(base, Node) + assert isinstance(base.meta["fake_result"], FakeTensor) + assert isinstance(mutated_view, Node) + assert isinstance(mutated_view.meta["fake_result"], FakeTensor) + assert not isinstance(n.target, str) + # Check that this view_inverse op actually corresponds to taking doing the inverse + # of one of our existing self_alias nodes. + original_view = _VIEW_INVERSE_MAP[n.target] + for self_alias in self_aliases: + # We're looking for some alias of the self arg, "alias", + # that was created from some op `alias = foo(base, args...)` + # such that the current _scatter op "inverts" that foo call. + # We can check that by running the original op again, and checking that the strides match. + if "view_of" not in self_alias.meta: + continue + self_alias_base = self_alias.meta["view_of"] + try: + # The we're trying to re-use the args from the view_scatter call inside of the corresponding + # view op, which might throw. This just indicates that view_scatter op isn't a valid inverse + # of the current alias we're looking at. + view_replay_metadata = original_view( + self_alias_base.meta["fake_result"], *n.args[2:], **n.kwargs + ) + expected_metadata = self_alias.meta["fake_result"] + # If the alias and its base both have matching metadata, then this view_scatter op is valid to re-inplace. + if matching_view_metadata( + self_alias_base.meta["fake_result"], base.meta["fake_result"] + ) and matching_view_metadata(view_replay_metadata, expected_metadata): + view_inverse_nodes.add(n) + except Exception: + continue + + return view_inverse_nodes + + +@compatibility(is_backward_compatible=True) +def reinplace(gm, *sample_args): + """ + Given an fx.GraphModule, modifies it to perform "reinplacing", + mutating the nodes of the graph. + We look for out-of-place op call sites like `b = a.add(...)`, + and convert them to be inplace (`b = a.add_(...)`), + as long as the input to the current operator ("a") isn't re-used + anywhere later in the graph. + + This pass currently expects to operate on a **functional, ATen** graph. + This can be obtained by running `make_fx(functionalize(f))`. + + Sample inputs are needed to determine aliasing relationships of the inputs. + In general, we can't reinplace node `b = a.add(...)` if "a" aliases any of the + inputs to the program. + + Given a node "b = foo(a, args...) the algorithm for re-inplacing is as follows: + + (1) Perform some initial checks on the metadata of "a" and "args..." + that can disqualify them from being reinplaced. + + (1a) Check that the self argument we're attempting to reinplace + has acceptable dtype/size metadata to reinplace with. + + For example, if we have: + a = torch.ones(1) + b = torch.ones(10) + out = torch.add(a, b) + We can't turn that into + a.add_(b) + Because that would require resizing "a". + + Similarly, we can't convert torch.ge(a, b) into a.ge_(b), + because that would require changing a's dtype (from e.g. float32 to bool). + Note that in this specific example, we could technically do better.. + + If we see the pattern: + a_1 = a.ge(b) + a_2 = aten._to_copy(a_1, a.dtype) + Then we this should be valid to completely re-inplace + (this is exactly what functionalization will emit when it sees a.ge_(b)). + + This optimization is only really important for user programs + that directly use inplace comparison ops though. + + We also cannot re-inplace on tensors that have overlapping memory, + e.g. torch.ones(1).expand(4, 4).add_(1) + + (1b) Check if "a" is an alias of any of the program inputs. + + If it is, skip and move to the next node. + Inplace'ing an op that would cause it to mutate a program is not sound, + because that would be a side effect visible to the user. + + NOTE: there's a future optimization that we should make: + if "a" is a (alias of a) program input, but later in the program + there is a node that looks like "a.copy_(...)", + Then re-inplacing is ok to do - we are temporarily re-using a's buffer, + which will later be overwritten by the copy_() call. + + This will be an important optimization to have for programs that mutate + their inputs. It currently isn't implemented though. + + (1c) Check if "a" and "args..." alias + + For example, re-inplacing to create code like the below + isn't guaranteed to be sound: + + aten.mul_(a, a) + + (2) Check that "a" and all of its outstanding aliases are not used anywhere + later in the graph. If this is the case, then it's safe to re-inplace + to "b = foo_(a)". + + There are a few caveats to this, explained in more detail below: + (a) If "a" is used later as an argument to a view op, that is okay. + It's only a problem if "a" (or that view) is later passed + into a normal operator, or if it is returned as the program output. + (b) If "a" is a repeat argument in `foo()`, then don't reinplace. + Most ATen kernels don't make any guarantees that this is sound, + e.g. if you do aten.mul_(a, a). + So we'll just ban re-inplacing in this case. + It's only a problem if "a" (or that view) is later passed + (c) If "a" is used as an input into a view "inverse" / "scatter" + operator, it is potentially fine to re-inplace + (and remove that scatter operator from the graph). + See below for a more detailed example. + + NOTE: there is an optimization in this step that is crucial + to fully recovering performance from functionalization. + + Given this program: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a) + torch.ops.aten.fill_(b, 0) + return d + + Functionalization will emit the following: + def f(x): + a = torch.ops.aten.add(x, x) + b = torch.ops.aten.diagonal(a, 0, 1) + b_updated = torch.ops.aten.fill(b, 0) + a_updated = torch.ops.aten.diagonal_scatter(a, b_updated, 0, 1) + return a_updated + + Ordinarily, we would not be able to reinplace the fill, + because "b" aliases with "a" which is used by the diagonal_scatter call. + + "re-inplacing" is on the hook for figuring out that it is ok to + completely, the expensive diagonal_scatter call, if we re-inplace the add(). + + So, for every `alias in alias_set(a)`, instead of checking + that "alias" is not used anywhere later in the graph, + we check that + EITHER: + (a) alias is not used anywhere later in the graph + OR: + (b) alias is used exactly once later on in the graph, + in the following op: + + out = foo_scatter(alias, x, args...) + + where the following must hold: + (i) "foo_scatter" is the "inverse" operator for foo. + This only applies to "foo" ops that are view operators, + which view into a subset of the original tensor's memory. + In practice, there are ~4 operators where this applies: + diagonal -> diagonal_scatter + slice -> slice_scatter + select -> select_scatter + as_strided -> as_strided_scatter + (ii) "args..." are the same between the foo() and foo_scatter() calls. + + (3) Perform the actual re-inplacing on foo! + + (3b) is the common case, but special care is needed for {view}_scatter (3a) + + (3a) {view}_scatter ops. + + Consider this program: + a = torch.zeros(2, 2) + b = torch.ones(2) + a[0] = b + + Post functionalization, that will look like: + a = torch.zeros(2) + b = torch.ones(1) + a_updated = torch.select_scatter(a, b, 0, 0) + + In this case though, there is no "functional" op to re-inplace! + Instead, we'd like to directly remove toe select_scatter call. + We already know from (3) that this is valid, + because "a" has no later usages in the graph. + + We perform the re-inplacing on the {view}_scatter op like so + Before: + a_updated = torch.select_scatter(a, b, args...) + After: + a_slice = a.select(a, args...) + a_slice.copy_(b) + + (3b) Otherwise, replace the functional op with its inplace variant. + Before: + b = foo(a, args...) + After: + a.foo_(args...) + + (4) Finally, after converting either: + Before: + b = foo(a) + After: + foo_(a) + or + Before: + b = {slice}_scatter(a, mutated_slice, args...) + After: + slice = {slice}(a, args...) + slice.copy_(mutated_slice) + + We now need to find all later nodes that use "b" as an argument + and update them to take in "a" instead. + + Note that for the majority of inplace ops, this isn't actually necessary + (because most inplace ops return "self" as their output). + This isn't generally true for all mutable ops though, which is why + we need to actually replace all of the arguments. + + We also need to update our metadata of Dict[StorageWeakRef, Set[Node]], + That maps a given tensor storage to the set of all nodes that take in that storage + as an input. + Specifically, re-inplacing `b = foo(a)` causes "a" and "b"'s sets to get fused + together. + + (5) Any "view_inverse/scatter" nodes that were identified as "it's ok to ignore them" + during step (3) get manually deleted from the graph. + Their outputs are no longer used, so technically standard DCE would be able + to do this, but we can no longer run FX's DCE pass now that we have mutable + ops in the graph. + """ + _FunctionalizationMetadataProp(gm).propagate(*sample_args) + + # Useful debug printing + # def _print(x): + # if isinstance(x, FakeTensor): + # print(f'fake_result: {StorageWeakRef(x._typed_storage()).cdata}') + + # for n in gm.graph.nodes: + # print(n.format_node()) + # if hasattr(n, 'meta'): + # print(f'node_idx: {n.meta["node_idx"]}') + # if 'fake_result' in n.meta: + # tree_map(_print, n.meta['fake_result']) + # if 'view_of' in n.meta: + # print(f'view_of: {str(n.meta["view_of"])}') + # print() + + # We need to know which nodes correspond to inputs (or their aliases) + # so we know not to re-inplace them. + # NOTE: later, we'll need to add an optimization for fully recovering performance + # on programs that mutate inputs. + input_storages = { + StorageWeakRef(node.meta["fake_result"]._typed_storage()) + for node in gm.graph.nodes + if ( + node.op == "placeholder" + and isinstance(node.meta["fake_result"], torch.Tensor) + ) + } + + # We also need to know for a given node, what are all of its aliasing nodes. + storage_to_nodes: dict[StorageWeakRef, set[Node]] = defaultdict(set) + for n in gm.graph.nodes: + if "fake_result" in n.meta: + # Tree-mapping because some ops can return lists of tensors. + def _add_to_map(x): + if isinstance(x, FakeTensor): + storage_to_nodes[StorageWeakRef(x._typed_storage())].add(n) + + pytree.tree_map_(_add_to_map, n.meta["fake_result"]) + + # inplace-ify functional ops, subject to the constraints written below. + all_later_view_inverse_nodes_to_delete = set() + for node in gm.graph.nodes: + if node.op == "call_function": + # Today, the re-inplace pass on directly acts on: + # - functional ops with an inplace variant + # - {view}_scatter ops that can be potentially removed from the graph. + # Both of these ops take in tensor first args, so filtering on this condition + # makes the later code simpler. + # We should revisit this at some point though, particularly when we also want + # the reinplacer to be able to handle out= and mutable operators + # and tensorlist first args (like `_foreach_` ops). + if not isinstance(node.target, torch._ops.OpOverload): + continue + if len(node.target._schema.arguments) < 1: + continue + if type(node.target._schema.arguments[0].type) != torch.TensorType: + continue + + # Step 1a: Check that the self argument we're attempting to reinplace + # has the same size/stride as the output. + # For example, we shouldn't try to reinplace torch.add(scalar_tensor, larger_tensor) + # As it would require resizing scalar_tensor. + # (We could potentially swizzle this into larger_tensor.add_(scalar_tensor), + # this is probably an optimization to revisit later). + self_arg = node.args[0] + self_flattened = pytree.tree_leaves(self_arg.meta["fake_result"]) + node_flattened = pytree.tree_leaves(node.meta["fake_result"]) + self_has_wrong_metadata = False + if len(self_flattened) == len(node_flattened): + for self_meta, node_meta in zip(self_flattened, node_flattened): + if self_meta.numel() != node_meta.numel(): + self_has_wrong_metadata = True + if self_meta.dtype != node_meta.dtype: + self_has_wrong_metadata = True + # We also cannot re-inplace on tensors that have internal memory overlap. + # e.g. torch.ones(1).expand(4, 4).add_(1) + if torch._debug_has_internal_overlap(self_meta) == 1: + self_has_wrong_metadata = True + # Here, we (optimistically) assume that a.resize(b) is valid to re-inplace, + # Since users should never really be calling the functional "torch.ops.aten.resize" + # op directly in their programs. + if self_has_wrong_metadata and node.target != torch.ops.aten.resize.default: + continue + + # Step 1b: ensure that the op we're trying to re-inplace isn't a program input + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) + if self_arg_storage in input_storages: + # TODO: later, add the optimization for handling `copy_()` calls in the graph. + continue + if len([x for x in node.args if x is self_arg]) > 1: + # Step 1c: + # Calling stuff like aten.mul_(a, a) isn't guaranteed to be sound, + # so we prevent re-inplacing in this case. + continue + + self_arg_storage = StorageWeakRef( + self_arg.meta["fake_result"]._typed_storage() + ) + self_aliases = storage_to_nodes[self_arg_storage] + + # First, we find all later usages of any of the aliases of self_arg. + later_node_usages = _get_all_later_node_usages( + self_aliases, node.meta["node_idx"] + ) + # Then, we check if any of those later usages are actually view_scatter ops + # that are safe to fully remove. + later_view_inverse_node_usages = _get_view_inverse_node_usages( + later_node_usages, self_aliases + ) + + # Step 2: Check to see if the input to the op is re-used later in the graph. + # If not (same goes for its aliases), then this op is safe to re-in place. + # This is a slightly roundabout way to check that there are no later usages of the current self argument. + # (later_view_inverse_node_usages corresponds to "view_scatter" nodes that we are allowed to delete) + can_reinplace = len(later_node_usages - later_view_inverse_node_usages) == 0 + if not can_reinplace: + continue + + # Step 3a: Special handling for when we see *_scatter operators. + # When we see an operator like `b = torch.slice_scatter(a, ...)`, + # instead of trying to "inplace" it into a.slice_scatter_(..._), + # we would prefer to remove it from the graph entirely, + # and instead copy_() the slice directly into the larger tensor. + # See the description of the algorithm for a full example. + if ( + node.target in _VIEW_INVERSE_MAP + and node not in all_later_view_inverse_nodes_to_delete + ): + view_op = _VIEW_INVERSE_MAP[node.target] + # Before: + # base_updated = torch.ops.aten.slice_scatter.default(base, mutated_slice, args...) + # After: + # slice = torch.ops.aten.slice.default(base, args...) + # slice.copy_(mutated_slice) + with gm.graph.inserting_before(node): + mutated_slice_node = node.args[1] + remaining_slice_args = node.args[2:] + slice_node = gm.graph.create_node( + "call_function", + view_op, + (self_arg,) + tuple(remaining_slice_args), + node.kwargs, + ) + gm.graph.create_node( + "call_function", + torch.ops.aten.copy_.default, + ( + slice_node, + mutated_slice_node, + ), + {}, + ) + # Add the slice_scatter node to our "nodes to delete" list. + all_later_view_inverse_nodes_to_delete.add(node) + + else: + # Step 3b: Check to see if this operator has an inplace variant. + maybe_inplace_op = _maybe_get_inplace_op(node.target) + if maybe_inplace_op is None: + continue + # And if so, replace it with its inplace variant. + node.target = maybe_inplace_op + + # At this point, 'storage_to_nodes' will be stale. + # Now that we're inplacing `b = foo(a)`, we need to effectively + # union together the dict values for b and a's storage. + # Hmm... morally I think we also want to keep the `fake_result` metadata + # up to date here, but I'm not sure how easy it is to do. + # Maybe it's fine to wait until the end of the pass to update it. + curr_node_storage = StorageWeakRef( + node.meta["fake_result"]._typed_storage() + ) + storage_to_nodes[self_arg_storage].update( + storage_to_nodes[curr_node_storage] + ) + storage_to_nodes[curr_node_storage].update( + storage_to_nodes[self_arg_storage] + ) + + # Need to remember the view_scatter view nodes we found so we can remove them alter. + all_later_view_inverse_nodes_to_delete.update( + later_view_inverse_node_usages + ) + + # Step 4: + # Now that we've replaced b = a.foo() with a.foo_(), + # We need to replace any later usages of "b" with "a" + for old in itertools.chain([node], later_view_inverse_node_usages): + new = old.args[0] + nodes_to_update = [ + n for n in old.users if n.meta["node_idx"] > node.meta["node_idx"] + ] + for node_to_update in nodes_to_update: + + def replace_arg(a): + if a == old: + return new + return a + + # First, replace usages of "b" with "a" + node_to_update.args = tree_map_only( + Node, replace_arg, node_to_update.args + ) + node_to_update.kwargs = tree_map_only( + Node, replace_arg, node_to_update.kwargs + ) + + # Second, update our storage_to_nodes data structure. + old_flattened_res = pytree.tree_leaves(old.meta["fake_result"]) + node_flattened_res = pytree.tree_leaves( + node_to_update.meta["fake_result"] + ) + + old_res_storage = { + StorageWeakRef(x._typed_storage()) + for x in old_flattened_res + if isinstance(x, FakeTensor) + } + node_res_storage = { + StorageWeakRef(x._typed_storage()) + for x in node_flattened_res + if isinstance(x, FakeTensor) + } + + # This will happen if we're updating a view op, e.g. + # e.g. replacing + # x = view(old) + # x = view(new) + # When that happens, we need to make sure to keep our + # storage mapping up to date. + # + # We're checking for len(...) == 1 here because all view ops are guaranteed to return either a single tensor, + # or multiple tensors that all share the same storage. + # We can't just check equality because we might encounter FX nodes that return zero tensor outputs. + if ( + len(old_res_storage) == 1 + and len(node_res_storage) == 1 + and old_res_storage == node_res_storage + ): + new_flattened_res = pytree.tree_leaves(new.meta["fake_result"]) + new_res_storage = { + StorageWeakRef(x._typed_storage()) + for x in new_flattened_res + if isinstance(x, FakeTensor) + } + assert len(new_res_storage) == 1 + (new_ref,) = new_res_storage + (node_ref,) = node_res_storage + # Technically, "old_ref" and all its aliases will remain + # in our mapping. + # That should be fine though, since we deleted "old" + # from the graph at this point. + storage_to_nodes[node_ref].update(storage_to_nodes[new_ref]) + storage_to_nodes[new_ref].update(storage_to_nodes[node_ref]) + + # Step 4: delete any _scatter nodes that we de-functionalized + # Need to take care not to delete any of these nodes until after *all* modifications + # to the graph are finished. + for to_delete in all_later_view_inverse_nodes_to_delete: + gm.graph.erase_node(to_delete) + + gm.recompile() + return gm diff --git a/phivenv/Lib/site-packages/torch/fx/passes/runtime_assert.py b/phivenv/Lib/site-packages/torch/fx/passes/runtime_assert.py new file mode 100644 index 0000000000000000000000000000000000000000..4bebe596d8e78c3bedcfd0a42c7c3c7285deb969 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/runtime_assert.py @@ -0,0 +1,633 @@ +# mypy: allow-untyped-defs +import functools +import logging +import operator +import sys +from typing import Any, Optional, TYPE_CHECKING + + +# Import sympy and ShapeEnv during TYPE_CHECKING since importing sympy is slow +if TYPE_CHECKING: + import sympy + + from torch.fx.experimental.symbolic_shapes import ShapeEnv +else: + ShapeEnv = Any + +import torch +import torch.utils._pytree as pytree +from torch import fx +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx._compatibility import compatibility +from torch.fx._utils import lazy_format_graph_code +from torch.fx.experimental.proxy_tensor import py_sym_types +from torch.fx.experimental.sym_node import SymNode +from torch.fx.graph_module import GraphModule + + +__all__ = ["insert_deferred_runtime_asserts"] + +log = logging.getLogger(__name__) +graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code_verbose") + + +def _get_example_value(node: fx.Node) -> Optional[str]: + """ + Get the example value key for a node, since dynamo uses "example_value" + while non-strict export uses "val. + """ + if "example_value" in node.meta: + return node.meta["example_value"] + elif "val" in node.meta: + return node.meta["val"] + else: + return None + + +def _get_sym_val(node: fx.Node) -> Optional["sympy.Expr"]: + val = _get_example_value(node) + if isinstance(val, py_sym_types): + return val.node.expr + return None + + +@compatibility(is_backward_compatible=True) +def insert_deferred_runtime_asserts( + gm: GraphModule, + shape_env: ShapeEnv, + name: str, + export: bool = False, +) -> None: + """ + During tracing, we may have discovered that some data-dependent values + had runtime assert on them; e.g., torch.empty(x.item()) induces a runtime + that x.item() >= 0. This asserts can happen unpredictably during fake + tensor propagation, so we cannot conveniently insert them into the FX graph + when they occur. Instead, we accumulate them in the ShapeEnv, and in this + pass insert them into the graph as proper tests. + + This pass also deduplicates size-related computation, CSE-ing ops that produce + symbolic values and/or are involved in runtime asserts. Additionally, shape calls + (size/stride/storage_offset) are turned into compute on input sizes if possible, + allowing intermediate tensors to be freed earlier. For example, here dynamo will + DCE the cat and repeat calls: + + z = torch.cat([x, x], dim=0) # 2*s0 + w = z.repeat(y.shape[0]) # 2*s0*s1 + _w = w.shape[0] + # something with _w, but not w ... + + # turns into -> + _w0 = 2 * s0 + _w = _w0 * s1 + + # where s0, s1 are either SymInt graph inputs, or the result of added size calls + + Redundant torch._check or torch.ops.aten._assert_scalar.default calls that assert + the same expression, and redundant constrain_range calls are also deduplicated. + Additionally, because single-symbol bound checks (e.g. u0 >= 0, u0 <= 5) accumulate + information in the ShapeEnv, the ShapeEnv contains min/max bounds for each symbol, + and we delete all previous calls, adding bound checks at the end of this pass. + """ + + # Import sympy locally + import sympy + + from torch._export.passes._node_metadata_hook import _set_node_metadata_hook + from torch.fx.experimental.symbolic_shapes import ( + _get_placeholder_expr, + _has_uninterpretable_sympy_function, + CallMethodKey, + cast_symbool_to_symint_guardless, + ConvertIntKey, + DivideByKey, + free_symbols, + InnerTensorKey, + resolve_unbacked_bindings, + ) + from torch.utils._sympy.numbers import int_oo + from torch.utils._sympy.reference import ( + OptimizedPythonReferenceAnalysis, + PythonReferenceAnalysis, + ) + from torch.utils._sympy.value_ranges import ValueRanges + + # TODO: Request simplification on runtime asserts before emitting them + ras_by_symbol = shape_env.deferred_runtime_asserts.copy() + graph = gm.graph + tracer = fx.proxy.GraphAppendingTracer(graph) + graph_code_log.debug( + "%s", + lazy_format_graph_code( + f"pre insert_deferred_runtime_asserts {name}", gm, colored=True + ), + ) + + # We are going to mutate the dict + expr_to_proxy: dict[sympy.Expr, fx.Proxy] = {} + placeholders = set() + first_non_placeholder = None + for node in graph.nodes: + if node.op != "placeholder": + first_non_placeholder = node + break + else: + placeholders.add(node) + + def _is_intermediate_tensor_sym_call(node: fx.Node) -> bool: + """ + If a size/stride/storage offset call on an intermediate tensor, + we can try to compute the value from input shapes instead. + """ + return ( + (val := _get_sym_val(node)) is not None + and not isinstance(val, sympy.Number) + # this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported + and not _has_uninterpretable_sympy_function(val) + and any( + isinstance(arg, fx.Node) + and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size)) + and arg.op != "placeholder" + for arg in node.args + ) + ) + + # Figure out what key to use, val or example_value + val_key = "val" + for node in graph.nodes: + if "example_value" in node.meta: + val_key = "example_value" + break + elif "val" in node.meta: + break + + def _node_metadata_hook( + node: torch.fx.Node, + stack_trace: Optional[str] = None, + nn_module_stack: Optional[dict[str, Any]] = None, + ) -> None: + fake_args = pytree.tree_map( + lambda arg: ( + _get_example_value(arg) if isinstance(arg, torch.fx.Node) else arg + ), + node.args, + ) + try: + target = node.target + if node.op == "call_method": + assert isinstance(node.target, str) + target = getattr(fake_args[0], node.target) + fake_args = fake_args[1:] + node.meta[val_key] = target(*fake_args) # type: ignore[operator] + except NotImplementedError: + # This can happen when attempting to reify a symbol with an unsupported call_function node, + # e.g. with NestedTensors + sym_size.int via match_symbol(). + # This seems to be fine, as the node gets CSE'd and deleted later in favor of a SymInt graph input. + pass + if stack_trace is not None: + node.meta["stack_trace"] = stack_trace + if nn_module_stack is not None: + node.meta["nn_module_stack"] = nn_module_stack + + # Track asserts/checks we've added + added_asserts: set[sympy.Expr] = set() + constrained_unbacked_symbols: set[sympy.Symbol] = set() + + Analysis = PythonReferenceAnalysis if export else OptimizedPythonReferenceAnalysis + + def _sympy_interp(expr_to_proxy, expr): + # sympy_interp() with hash consing + from sympy import Integer, Number, Symbol + from sympy.logic.boolalg import BooleanAtom + + from torch.utils._sympy.interp import _run_sympy_handler, sympy_interp + + # hash cons + if expr in expr_to_proxy: + return expr_to_proxy[expr] + # base cases, don't cache + if isinstance(expr, (Integer, Number, Symbol, BooleanAtom)): + return sympy_interp(Analysis, expr_to_proxy, expr) + + # hash cons on arguments, run expr handler + expr_to_proxy[expr] = _run_sympy_handler( + Analysis, + [_sympy_interp(expr_to_proxy, arg) for arg in expr.args], + expr, + ) + return expr_to_proxy[expr] + + def _is_bound_expr_for_symbol(expr: "sympy.Expr") -> bool: + # This is probably unnecessary, but since torch._check() calls for single-symbol bounds + # like u0 >= 0, 10 >= u0 accumulate range info in the ShapeEnv, we designate these calls as redundant + # and instead add 2 runtime asserts at the end of this pass, if the min/max bounds are non-trivial. + if len(expr.args) != 2 or expr.func not in (sympy.LessThan, sympy.GreaterThan): + return False + lhs, rhs = expr.args + return (isinstance(lhs, sympy.Symbol) and isinstance(rhs, sympy.Number)) or ( + isinstance(rhs, sympy.Symbol) and isinstance(lhs, sympy.Number) + ) + + def add_runtime_asserts(ras): + for ra in ras: + if ( + # redundant + ra.expr in added_asserts + # if we've already added a constrain_range call for this symbol, + # then single-symbol bound asserts like u0 >= 0, u0 <= 5 are redundant. + or ( + len(ra.expr.free_symbols) == 1 + and next(iter(ra.expr.free_symbols)) in constrained_unbacked_symbols + and _is_bound_expr_for_symbol(ra.expr) + ) + # don't try to reify sympy functions we can't turn into FX nodes + or _has_uninterpretable_sympy_function(ra.expr) + ): + continue + + log.debug("inserting runtime assert %s", ra.expr) + # Need to process ALL free symbols, not just unbacked ones + fvs = free_symbols(ra.expr) + missing = fvs - expr_to_proxy.keys() + if missing: + i1 = min(missing, key=str) + # TODO: Remove relaxing assert on unbacked_symint https://github.com/pytorch/pytorch/issues/119689 + # assert shape_env.is_unbacked_symint(i1), i1 + ras_by_symbol.setdefault(i1, []).append(ra) + else: + # Convert the sympy expression into a sequence of FX + # nodes + with _set_node_metadata_hook(gm, _node_metadata_hook): + res = _sympy_interp(expr_to_proxy, ra.expr).node + + graph.call_function( + torch.ops.aten._assert_scalar.default, + # TODO: use ra.msg here, but it's pretty + # useless right now + ( + res, + f"Runtime assertion failed for expression {ra.expr} on node '{res}'", + ), + ) + added_asserts.add(ra.expr) + + nodes = list(graph.nodes) + for i, node in enumerate(nodes[:-1]): + # Placeholders can match symbols, but when we destructure them + # with size we have to make sure we insert the nodes after all + # the placeholders + with graph.inserting_before( + nodes[i + 1] if node not in placeholders else first_non_placeholder + ): + # Unfortunately, this logic still must remain because manual + # make_fx calls may not explicitly bind all symbolic ints as + # arguments to the function, so we must infer it from the other + # arguments + if ( + node in placeholders + and (example_value := _get_example_value(node)) is not None + ): + + def match_symbol(symint, cb): + if ( + isinstance(symint, torch.SymInt) + and isinstance(symint.node, SymNode) + and isinstance( + s := _get_placeholder_expr(symint.node), sympy.Symbol + ) + and s not in expr_to_proxy + ): + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer) + log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) + + match_symbol(example_value, lambda: node) + if isinstance(t := example_value, torch.Tensor): + for i, s in enumerate(t.size()): + match_symbol( + s, + lambda: graph.call_function( + torch.ops.aten.sym_size.int, (node, i) + ), + ) + if not is_sparse_any(t): + for i, s in enumerate(t.stride()): + match_symbol( + s, + lambda: graph.call_function( + torch.ops.aten.sym_stride.int, (node, i) + ), + ) + match_symbol( + t.storage_offset(), + lambda: graph.call_function( + torch.ops.aten.sym_storage_offset.default, (node,) + ), + ) + + # Handle asserts that aren't associated with any symbol. This + # doesn't really have to be in the loop as it will only run once, + # it just needs to happen right after the placeholders. + # insert this after placeholders & added sym nodes, and before non-placeholders. + if node == first_non_placeholder: + add_runtime_asserts(ras_by_symbol.pop(None, [])) # type: ignore[call-overload] + + # deduplicate asserts already present in graph, and remove trivial asserts + if node.target in ( + torch._check, + torch.ops.aten._assert_scalar.default, + ): + if ( + node.args[0] == True # noqa: E712 + or (assert_expr := _get_sym_val(node.args[0])) in expr_to_proxy + and assert_expr in added_asserts + ): + arg = node.args[0] + gm.graph.erase_node(node) + if isinstance(arg, fx.Node) and not arg.users: + gm.graph.erase_node(arg) + else: + added_asserts.add(assert_expr) # type: ignore[arg-type] + + # hash cons, replace function calls that return torch.SymInts with direct references to + # FX nodes built up to reify the sympy expression. + if ( + node.op != "placeholder" + and (sym_expr := _get_sym_val(node)) is not None + ): + # this guards against deleting calls like item() that produce new untracked symbols + def has_new_untracked_symbols(): + for symbol in sym_expr.free_symbols: + if symbol not in expr_to_proxy: + return True + return False + + # this guards against deleting calls that produce unbacked bindings we haven't yet seen. + # in this case looking at sym_expr.free_symbols might not be enough, if the example value has a hint + # (is backed), but produces an unbacked symbol. In this case keep the node alive. + resolved_unbacked_bindings = resolve_unbacked_bindings( + shape_env, node.meta.get("unbacked_bindings", {}) + ) + + assert resolved_unbacked_bindings is not None + + def has_new_unbacked_bindings(): + for key in resolved_unbacked_bindings.keys(): + if key not in expr_to_proxy: + return True + return False + + # maybe re-reify expression, replace current node + if ( + sym_expr in expr_to_proxy + or ( # example value is redundant + _is_intermediate_tensor_sym_call(node) + # shape call on intermediate tensor, turn into computation on input shapes + and not has_new_untracked_symbols() + ) + ) and not has_new_unbacked_bindings(): + if _is_intermediate_tensor_sym_call( + node + ): # reify from input shapes + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + expr_to_proxy[sym_expr] = _sympy_interp( + expr_to_proxy, sym_expr + ) # type: ignore[arg-type] + # won't try DCE-ing tensor compute here + hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] + node.replace_all_uses_with(hash_node) + gm.graph.erase_node(node) + log.debug( + "CSE node %s -> %s for expr %s", node, hash_node, sym_expr + ) + + # store node in hash cons, don't delete/replace + elif sym_expr not in expr_to_proxy and not isinstance( + sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom) + ): # don't hash cons primitives + expr_to_proxy[sym_expr] = fx.Proxy(node, tracer=tracer) # type: ignore[arg-type] + + # We add sym_constrain_range calls for symbols later in any case if they're size-like or range-constrained, + # so calls before that are redundant. + if node.target in ( + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten.sym_constrain_range_for_size.default, + ): + gm.graph.erase_node(node) + + defs = [] + + # AOTAutograd will create new symbols as the unbacked_bindings keys, which PropagateSymInts will set as + # equivalent, but the refinement calls we perform in this pass may struggle with associating the two. + # More concretely, when re-exporting/tracing, constraining only the new symbol may not communicate enough + # information about the old symbol when we re-export, raising errors on data-dependent guards. + # Call resolve_unbacked_bindings() to get the original symbol if present, otherwise we take it as is. + if unbacked_bindings := resolve_unbacked_bindings( + shape_env, node.meta.get("unbacked_bindings") + ): + for s, keypath in unbacked_bindings.items(): + defs.append(s) + + # TODO: some CSE when generating these nodes can probably + # help reduce graph size and improve compile time + def go(node, keypath): + if keypath == (): + return node + if ( + len(keypath) >= 2 + and isinstance(keypath[0], CallMethodKey) + and isinstance(keypath[1], pytree.SequenceKey) + ): + if keypath[0].name == "size": + return go( + graph.call_function( + torch.ops.aten.sym_size.int, + (node, keypath[1].idx), + ), + keypath[2:], + ) + if keypath[0].name == "stride": + return go( + graph.call_function( + torch.ops.aten.sym_stride.int, + (node, keypath[1].idx), + ), + keypath[2:], + ) + return go( + graph.call_method( + keypath[0].name, (node, keypath[1].idx) + ), + keypath[2:], + ) + elif isinstance(keypath[0], CallMethodKey): + return go( + graph.call_method(keypath[0].name, (node,)), keypath[1:] + ) + elif isinstance(keypath[0], pytree.SequenceKey): + return go( + graph.call_function( + operator.getitem, (node, keypath[0].idx) + ), + keypath[1:], + ) + elif isinstance(keypath[0], ConvertIntKey): + return go( + graph.call_function( + cast_symbool_to_symint_guardless, (node,) + ), + keypath[1:], + ) + elif isinstance(keypath[0], DivideByKey): + # TODO: need to assert divisibility + return go( + graph.call_function( + operator.floordiv, (node, keypath[0].divisor) + ), + keypath[1:], + ) + elif isinstance(keypath[0], InnerTensorKey): + return go( + graph.call_function( + getattr, (node, keypath[0].inner_name) + ), + keypath[1:], + ) + else: + raise AssertionError(f"unrecognized keypath {keypath}") + + if s not in expr_to_proxy: + with _set_node_metadata_hook(gm, _node_metadata_hook): + expr_to_proxy[s] = fx.Proxy( + go(node, keypath), tracer=tracer + ) + log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s]) + + for i0 in defs: + ras = ras_by_symbol.pop(i0, []) + # Before we perform any asserts, first apply range + # refinement. This is important, because if we are going + # to retrace the graph (and we typically are if we send + # the graph to AOTAutograd), we need to make sure we apply + # range refinement (ala _check_is_size) first, BEFORE we + # run any of the asserts. Otherwise, we may decide to + # perform substitutions based on the asserts which we then + # can't back out, because value ranges can only be applied + # to asserts.) + # + # A perhaps better long term plan is to avoid this order + # dependence by making it possible to refine ranges on + # arbitrary expressions, not just symbols. But it is not + # so easy to make use of this information, see + # https://twitter.com/ezyang/status/1745801370299482492 + # We actually made an attempt at this in + # https://github.com/pytorch/pytorch/pull/119043 + # which didn't work. + # + # Another ideas for how to do this: + # - Have bound_sympy be the source of truth of the ranges of any expression + # - Cache intermediate results for every subexpression of bound_sympy + # - This cache should be possible to edit to refine ranges + # + # One issue with this proposal is that if + # we have a bound on 2x, we are not going to be able to + # apply it for 4x. Similarly, we may have bounds for an + # equivalent expression that we are not applying because + # it's not a perfect match (e.g. x < y vs y > x)". + # + # The first issue we already have it and it's impossible + # to solve in general, so any implementation on a best + # effort basis should do. + # + # The second issue is a preexisting one. It can be mitigated + # with a normalization algorithm. In general, it may also + # be on a best effort basis, but since our grammar is not + # terribly difficult, chances are we could even fully + # normalize SymPy expressions... who knows. + if i0 in constrained_unbacked_symbols: + continue # constrain symbol just once + + if i0 in shape_env.size_like: + if export: + graph.call_function( + torch.ops.aten.sym_constrain_range_for_size.default, + (expr_to_proxy[i0].node,), + ) + else: + graph.call_function( + torch._check_is_size, (expr_to_proxy[i0].node,) + ) + + vr = shape_env.var_to_range[i0] + if vr.is_int and vr.upper == sys.maxsize - 1: + # treat upper bound == sys.maxsize - 1 for int symbols as +oo + # to avoid redundant runtime assert + vr = ValueRanges(vr.lower, int_oo) + if not shape_env._default_unspecified_value_range().issubset(vr): + # The runtime range is constrained, so add a runtime + # assert and also explicitly refine the range + # (refinement should not be necessary once runtime + # asserts cause refinement, but that's NYI) + def convert(s): + if s in (int_oo, -int_oo): + return None + try: + return int(s) + except TypeError: + return None + + if ( + expr_to_proxy[i0].node.target + != cast_symbool_to_symint_guardless + ): + # TODO(pianpwk): calling sym_constrain_range_for_size or adding bound asserts + # raises AOTAutograd errors on cast_symbool_to_symint_guardless + + with _set_node_metadata_hook( + gm, + functools.partial( + _node_metadata_hook, + stack_trace=node.meta.get("stack_trace"), + nn_module_stack=node.meta.get("nn_module_stack"), + ), + ): + if (min_val := convert(vr.lower)) is not None: + ge = _sympy_interp(expr_to_proxy, i0 >= min_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + ge, + f"Runtime assertion failed for expression {i0 >= min_val} on node '{ge}'", + ), + ) + added_asserts.add(i0 >= min_val) + if (max_val := convert(vr.upper)) is not None: + le = _sympy_interp(expr_to_proxy, i0 <= max_val).node + graph.call_function( + torch.ops.aten._assert_scalar.default, + ( + le, + f"Runtime assertion failed for expression {i0 <= max_val} on node '{le}'", + ), + ) + added_asserts.add(i0 <= max_val) + + constrained_unbacked_symbols.add(i0) + add_runtime_asserts(ras) + + # delete unused reified symbols + for expr, proxy in expr_to_proxy.items(): + if ( + isinstance(expr, sympy.Symbol) + and proxy.node.op != "placeholder" # keep placeholders intact + and not proxy.node.users + ): + log.debug("deleting unused reified symbol for %s", expr) + gm.graph.erase_node(proxy.node) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/shape_prop.py b/phivenv/Lib/site-packages/torch/fx/passes/shape_prop.py new file mode 100644 index 0000000000000000000000000000000000000000..489a96cedecac22946727b8fe623472f37b7e0c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/shape_prop.py @@ -0,0 +1,230 @@ +# mypy: ignore-errors + +import traceback +from typing import Any, NamedTuple, Optional + +import torch +import torch.fx +from torch._dispatch.python import enable_python_dispatcher +from torch._guards import detect_fake_mode +from torch._prims_common import definitely_contiguous_for_memory_format +from torch._subclasses.meta_utils import is_sparse_any +from torch.fx._compatibility import compatibility +from torch.fx.node import map_aggregate, Node + + +__all__ = ["TensorMetadata", "ShapeProp"] + + +@compatibility(is_backward_compatible=True) +class TensorMetadata(NamedTuple): + # TensorMetadata is a structure containing pertinent information + # about a tensor within a PyTorch program. + + # General Tensor metadata + shape: torch.Size + dtype: torch.dtype + requires_grad: bool + stride: tuple[int, ...] + memory_format: Optional[torch.memory_format] + + # Quantization metadata + is_quantized: bool + qparams: dict[str, Any] + + +# When include_contiguity is True, we will set contiguity when its always true for the tensor. +# Some tensors can represent both contiguous and non-contiguous tensors. e.g: (u0, u1) with (u2, u3). +# In such situation contiguity is not set. We could also make it a tri-state i.e: (definitely_contiguous, +# contiguous, and unknown). +def _extract_tensor_metadata( + result: torch.Tensor, include_contiguity=True +) -> TensorMetadata: + """ + Extract a TensorMetadata NamedTuple describing `result`. + """ + shape = result.shape + dtype = result.dtype + requires_grad = result.requires_grad + stride = result.stride() if not is_sparse_any(result) else () + + memory_format = None + + if include_contiguity and not is_sparse_any(result): + memory_formats = { + torch.contiguous_format, + torch.channels_last, + torch.channels_last_3d, + } + for query_format in memory_formats: + if definitely_contiguous_for_memory_format( + result, memory_format=query_format + ): + memory_format = query_format + break + + is_quantized = result.is_quantized + qparams: dict[str, Any] = {} + if is_quantized: + qscheme = result.qscheme() + qparams["qscheme"] = qscheme + if qscheme in {torch.per_tensor_affine, torch.per_tensor_symmetric}: + qparams["scale"] = result.q_scale() # type: ignore[assignment] + qparams["zero_point"] = result.q_zero_point() # type: ignore[assignment] + elif qscheme in { + torch.per_channel_affine, + torch.per_channel_affine_float_qparams, + torch.per_channel_symmetric, + }: + # In this branch, scale and zero_point are expected to be tensors, + # we store the values as immutable_list in TensorMetadata for + # easier serialization downstream + qparams["scale"] = result.q_per_channel_scales().tolist() # type: ignore[assignment] + qparams["zero_point"] = result.q_per_channel_zero_points().tolist() # type: ignore[assignment] + qparams["axis"] = result.q_per_channel_axis() # type: ignore[assignment] + + return TensorMetadata( + shape, dtype, requires_grad, stride, memory_format, is_quantized, qparams + ) + + +@compatibility(is_backward_compatible=True) +class ShapeProp(torch.fx.Interpreter): + """ + Execute an FX graph Node-by-Node and + record the shape and type of the result + into the corresponding node. + + Example: + In this example, we record the shape + and data type of a module given + an example input ``torch.randn(50, D_in)``. + We print the name, shape and dtype of each node. + + class TwoLayerNet(torch.nn.Module): + def __init__(self, D_in, H, D_out): + super().__init__() + self.linear1 = torch.nn.Linear(D_in, H) + self.linear2 = torch.nn.Linear(H, D_out) + def forward(self, x): + h_relu = self.linear1(x).clamp(min=0) + y_pred = self.linear2(h_relu) + return y_pred + N, D_in, H, D_out = 64, 1000, 100, 10 + x = torch.randn(N, D_in) + y = torch.randn(N, D_out) + model = TwoLayerNet(D_in, H, D_out) + gm = torch.fx.symbolic_trace(model) + sample_input = torch.randn(50, D_in) + ShapeProp(gm).propagate(sample_input) + + for node in gm.graph.nodes: + print(node.name, node.meta['tensor_meta'].dtype, + node.meta['tensor_meta'].shape) + + The output of this code is: + + x torch.float32 torch.Size([50, 1000]) + linear1 torch.float32 torch.Size([50, 100]) + clamp_1 torch.float32 torch.Size([50, 100]) + linear2 torch.float32 torch.Size([50, 10]) + output torch.float32 torch.Size([50, 10]) + + Args: + module (GraphModule): The module to be executed + fake_mode (FakeTensorMode): A fake mode for copying the gm + + """ + + def __init__(self, gm, fake_mode=None): + super().__init__(gm) + if fake_mode is None: + fake_mode = detect_fake_mode() + if fake_mode is not None: + from torch._dynamo.utils import deepcopy_to_fake_tensor + + # Note: + # We need fake execution cause the inputs are fake, however, we cannot fakify the module + # - because we need to write to the tensor_meta of the real module. So we fakify to + # produce a result (L131 below), to extract tensor meta, and then keep going. + # + # If we were to fakify, we would write to the wrong node, and then downstream fusion + # would be missing the tensor_meta. + # + # See torch/_inductor/overrides.py for where this is called upstream of fusion. + self.fake_module = deepcopy_to_fake_tensor(self.module, fake_mode) + self.fake_mode = fake_mode + else: + self.fake_module = None + self.fake_mode = None + + self.real_module = self.module + + def run_node(self, n: Node) -> Any: + from torch.fx.experimental.symbolic_shapes import ( + compute_unbacked_bindings, + rebind_unbacked, + ) + + try: + if self.fake_module is not None: + # Hacky swap. Alternatively, we could do this with overriding + # call_module and get_attr. + self.module = self.fake_module + try: + if self.fake_mode is not None: + with self.fake_mode, enable_python_dispatcher(): + result = super().run_node(n) + rebind_unbacked(self.fake_mode.shape_env, n, result) + else: + result = super().run_node(n) + finally: + self.module = self.real_module + except Exception as e: + traceback.print_exc() + raise RuntimeError( + f"ShapeProp error for: node={n.format_node()} with meta={n.meta}" + ) from e + + found_tensor = False + + def extract_tensor_meta(obj): + if isinstance(obj, torch.Tensor): + nonlocal found_tensor + found_tensor = True + return _extract_tensor_metadata(obj) + else: + return obj + + meta = map_aggregate(result, extract_tensor_meta) + if found_tensor: + n.meta["tensor_meta"] = meta + + if self.fake_mode: + if (shape_env := self.fake_mode.shape_env) and ( + symbol_to_path := compute_unbacked_bindings(shape_env, result) + ): + n.meta["unbacked_bindings"] = symbol_to_path + + n.meta["type"] = type(result) + return result + + def propagate(self, *args): + """ + Run `module` via interpretation and return the result and + record the shape and type of each node. + + Args: + *args (Tensor): the sample input. + + Returns: + Any: The value returned from executing the Module + """ + if self.fake_mode is not None: + fake_args = [ + self.fake_mode.from_tensor(t) if isinstance(t, torch.Tensor) else t + for t in args + ] + else: + fake_args = args + return super().run(*fake_args) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/split_module.py b/phivenv/Lib/site-packages/torch/fx/passes/split_module.py new file mode 100644 index 0000000000000000000000000000000000000000..5fbe7f36269a05c6861405388d35d07d1ffab32a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/split_module.py @@ -0,0 +1,639 @@ +# mypy: allow-untyped-defs +import inspect +import logging +from collections import OrderedDict +from typing import Any, Callable, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx._utils import lazy_format_graph_code +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node + + +__all__ = ["Partition", "split_module"] +log = _LOGGER = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=True) +class Partition: + def __init__(self, name: str): + self.name: str = name + self.submod_name = f"submod_{name}" + self.node_names: list[str] = [] + self.inputs: dict[str, None] = {} + self.outputs: dict[str, None] = {} + self.dependencies: dict[str, None] = {} + self.dependents: dict[str, None] = {} + self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + self.environment: dict[Node, Node] = {} + self.targets: dict[str, Any] = {} + + def __repr__(self) -> str: + return ( + f"name: {self.name},\n" + f" nodes: {self.node_names},\n" + f" inputs: {self.inputs},\n" + f" outputs: {self.outputs},\n" + f" partitions depended on: {self.dependencies},\n" + f" partition dependents: {self.dependents}" + ) + + +def _get_attr_from_qualname(mod: torch.nn.Module, qualname: str) -> Any: + attr_val = mod + for atom in qualname.split("."): # type: ignore[union-attr] + if not hasattr(attr_val, atom): + raise AttributeError(f"Node target {qualname} not found!") + attr_val = getattr(attr_val, atom) + return attr_val + + +# Creates subgraphs out of main graph +@compatibility(is_backward_compatible=True) +def split_module( + m: GraphModule, + root_m: torch.nn.Module, + split_callback: Callable[[Node], int], + qualname_map: Optional[dict[str, str]] = None, + keep_original_order: Optional[bool] = False, + keep_original_node_name: Optional[bool] = False, + keep_original_input_name: bool = True, +): + """ + Creates subgraphs out of main graph + + Args: + m (GraphModule): Graph module to split + root_m (torch.nn.Module): root nn module. Not currently used. Included + because the root nn module is usually transformed via + torch.fx._symbolic_trace.symbolic_trace (see example below) + split_callback (Callable[[Node], int]): Callable function + that maps a given Node instance to a numeric partition identifier. + split_module will use this function as the policy for which operations + appear in which partitions in the output Module. + qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a + mapping from new target names in the module after split to old target + names in the original module. + keep_original_order: Optional[bool]: keep the original order of the GraphModule + or use the Topological order of the new constructed GraphModule + keep_original_node_name: Optional[bool]: If the partitioned graphs should + have the same node names as the original graph. + keep_original_input_name: bool: If the partitioned graphs should + have the same input names as the original graph. + + Returns: + GraphModule: the module after split. + + Example: + + This is a sample setup: + + import torch + from torch.fx.symbolic_trace import symbolic_trace + from torch.fx.graph_module import GraphModule + from torch.fx.node import Node + from torch.fx.passes.split_module import split_module + + class MyModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.param = torch.nn.Parameter(torch.rand(3, 4)) + self.linear = torch.nn.Linear(4, 5) + + def forward(self, x, y): + z = self.linear(x + self.param).clamp(min=0.0, max=1.0) + w = self.linear(y).clamp(min=0.0, max=1.0) + return z + w + + # symbolically trace model + my_module = MyModule() + my_module_traced = symbolic_trace(my_module) + + # random mod partitioning + partition_counter = 0 + NPARTITIONS = 3 + + def mod_partition(node: Node): + global partition_counter + partition = partition_counter % NPARTITIONS + partition_counter = (partition_counter + 1) % NPARTITIONS + return partition + + # split module in module with submodules + module_with_submodules = split_module( + my_module_traced, my_module, mod_partition + ) + + Output looks like this. Original graph is broken into partitions + + > print(module_with_submodules) + GraphModule( + (submod_0): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_1): GraphModule( + (linear): Linear(in_features=4, out_features=5, bias=True) + ) + (submod_2): GraphModule() + ) + + def forward(self, x, y): + param = self.param + submod_0 = self.submod_0(x, param, y); x = param = y = None + getitem = submod_0[0] + getitem_1 = submod_0[1]; submod_0 = None + submod_1 = self.submod_1(getitem, getitem_1); getitem = getitem_1 = None + getitem_2 = submod_1[0] + getitem_3 = submod_1[1]; submod_1 = None + submod_2 = self.submod_2(getitem_2, getitem_3); getitem_2 = getitem_3 = None + return submod_2 + + Output of split module is the same as output of input traced module. + This is an example within a test setting: + + > orig_out = my_module_traced(x, y) + > submodules_out = module_with_submodules(x, y) + > self.assertEqual(orig_out, submodules_out) + True + """ + + log.debug( + "%s", + lazy_format_graph_code("pre split_module", m, colored=True), + ) + + def construct_graph( + node: Node, + base_mod_env: dict[str, Node], + base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule], + ): + if node.op == "placeholder": + default_value = ( + node.args[0] if len(node.args) > 0 else inspect.Signature.empty + ) + if keep_original_node_name: + args = ( + () if default_value is inspect.Signature.empty else (default_value,) + ) + base_mod_env[node.name] = base_mod_graph.create_node( + "placeholder", + node.name, + args=args, # type: ignore[arg-type] + type_expr=node.type, + ) + else: + base_mod_env[node.name] = base_mod_graph.placeholder( + node.target, # type: ignore[arg-type] + type_expr=node.type, + default_value=default_value, + ) + base_mod_env[node.name].meta = node.meta.copy() + elif node.op == "get_attr": + base_mod_env[node.name] = base_mod_graph.get_attr(node.target) # type: ignore[arg-type] + base_mod_env[node.name].meta = node.meta.copy() + assert isinstance(node.target, str) + attr_val = _get_attr_from_qualname(m, node.target) + base_mod_attrs[node.target] = attr_val # type: ignore[index] + return base_mod_env, base_mod_attrs + + import sympy + + partitions: dict[str, Partition] = {} + orig_nodes: dict[str, Node] = {} + symbol_to_node: dict[sympy.Symbol, Node] = {} + + def record_cross_partition_use(def_node: Node, use_node: Optional[Node]): + from torch.fx.experimental.symbolic_shapes import free_symbols + + defined = getattr(def_node, "_fx_partition", None) + used = getattr(use_node, "_fx_partition", None) + + log.debug( + "record_cross_partition_use %s (%s) %s (%s)", + def_node.name, + defined, + use_node.name if use_node is not None else "-", + used, + ) + + if defined != used: + if defined is not None: + def_partition = partitions[defined] + def_partition.outputs.setdefault(def_node.name) + if used is not None: + def_partition.dependents.setdefault(used) + + if used is not None: + use_partition = partitions[used] + use_partition.inputs.setdefault(def_node.name) + # We have made def_node an input to the use_partition. If + # this input has symbolic symbols in its size, those also must + # be made as inputs to the partition + if (def_val := def_node.meta.get("example_value")) is not None: + for s in sorted(free_symbols(def_val), key=str): + s_node = symbol_to_node[s] + use_partition.inputs.setdefault(s_node.name) + if symbol_to_node[s].op != "placeholder": + # If the node that defines the symbol is not a + # placeholder, we must make it an output of the + # partition. Note that this may be in a different + # partition than defined! Although, this doesn't + # really make a difference for correctness, since + # defined is guaranteed to have the symbol in + # scope and can return it; you just get less + # optimal codegen in this case. + s_defined = getattr(s_node, "_fx_partition", None) + if s_defined is not None: + s_def_partition = partitions[s_defined] + s_def_partition.outputs.setdefault(s_node.name) + s_def_partition.dependents.setdefault(used) + if defined is not None: + use_partition.dependencies.setdefault(defined) + + def instantiate_node_partition_mapping(node): + partition_name = str(split_callback(node)) + log.debug( + "instantiate_node_partition_mapping %s (%s)", node.name, partition_name + ) + + # add node to partitions + partition = partitions.get(partition_name) + if partition is None: + partitions[partition_name] = partition = Partition(partition_name) + + partition.node_names.append(node.name) + node._fx_partition = partition_name + + # Global State Nodes are nodes which by their global state effects, + # "taint" all downstream nodes while they are active. + GLOBAL_STATE_NODES = [ + torch.amp._enter_autocast, + torch.amp._exit_autocast, + torch._C._set_grad_enabled, + ] + + # For grad regions: + # ------------------------ + # 1. first region: we do nothing + # 2. subsequent regions: we insert the set_grad at the beginning + grad_regions: OrderedDict[Node, set[int]] = OrderedDict() + + # For autocast regions: + # ------------------------ + # 1. first region: we will only insert the _exit at the end + # 2. intermediate regions: we will insert both the + # _enter at the beginning and _exit at the end + # 3. last region: we will only insert _enter at the beginning + # We will do so in the order in which the autocasts were instantiated. + autocast_regions: OrderedDict[Node, set[int]] = OrderedDict() + autocast_exits: dict[Node, Optional[Node]] = {} + + active_grad = None + active_autocasts = set() + + for node in m.graph.nodes: + # This will prefer placeholder bindings, because those come first. + # This is a little dangerous though: it is possible that an unbacked + # symbol is used without any binding site for it, in which case we + # will get a KeyError not able to find it. I'd like to fix this by + # having passes.runtime_assert establish some invariants that I can + # rely on later, but this needs some extra work. Quick fix first. + # See https://github.com/pytorch/pytorch/issues/130534 + if ( + (val := node.meta.get("example_value")) is not None + and isinstance(val, (torch.SymInt, torch.SymFloat)) + and isinstance(s0 := val.node.expr, sympy.Symbol) + and s0 not in symbol_to_node + ): + symbol_to_node[val.node.expr] = node + + if node.op in ["placeholder", "get_attr", "output"]: + continue + + instantiate_node_partition_mapping(node) + + if node.op == "call_function" and node.target in GLOBAL_STATE_NODES: + if node.target == torch._C._set_grad_enabled: + assert len(node.args) == 1 + assert isinstance(node.args[0], bool) + active_grad = node + grad_regions[active_grad] = set({split_callback(node)}) + elif node.target == torch.amp._enter_autocast: + # Should all be python constants + assert all(not isinstance(arg, Node) for arg in node.args) + active_autocasts.add(node) + autocast_regions[node] = set({split_callback(node)}) + autocast_exits[node] = None + elif node.target == torch.amp._exit_autocast: + assert len(node.args) == 1 + autocast_regions[node.args[0]].add(split_callback(node)) + active_autocasts.remove(node.args[0]) + autocast_exits[node.args[0]] = node + + if active_grad is not None: + grad_regions[active_grad].add(split_callback(node)) + + for a in active_autocasts: + autocast_regions[a].add(split_callback(node)) + + assert all(v is not None for v in autocast_exits.values()), "autocast must exit" + + autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()} + grad_regions = {k: sorted(v) for k, v in grad_regions.items()} + + if _LOGGER.isEnabledFor(logging.DEBUG): + _LOGGER.debug("autocast_regions: %s", autocast_regions) + _LOGGER.debug("grad_regions: %s", grad_regions) + + assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions) + + # split nodes into partitions + highest_partition = -1 + for node in m.graph.nodes: + orig_nodes[node.name] = node + + # TODO currently placeholders/parameters aren't put into random partitions, + # rather they're added to the graphs where they are used down below + if node.op in ["placeholder", "get_attr"]: + continue + if node.op == "output": + torch.fx.graph.map_arg( + node.args[0], lambda n: record_cross_partition_use(n, None) + ) + continue + + if assert_monotonically_increasing: + pid = split_callback(node) + assert highest_partition <= pid, ( + "autocast or set_grad_enabled require monotonically increasing partitions:" + f"highest: {highest_partition}, this node's: {pid}" + ) + highest_partition = pid + + # do not capture cross-partition dependencies for global state nodes as they will be + # self-contained - their setup and unwind will be isolated to each partition submodule. + if node.target not in GLOBAL_STATE_NODES: + torch.fx.graph.map_arg( + node.args, lambda def_node: record_cross_partition_use(def_node, node) + ) + torch.fx.graph.map_arg( + node.kwargs, lambda def_node: record_cross_partition_use(def_node, node) + ) # noqa: B950 + + original_partition_order = list(partitions.keys()) + # find partitions with no dependencies + root_partitions: list[str] = [] + for partition_name, partition in partitions.items(): + if not len(partition.dependencies): + root_partitions.append(partition_name) + + # check partitions for circular dependencies and create topological partition ordering + sorted_partitions: list[str] = [] + while root_partitions: + root_partition = root_partitions.pop() + sorted_partitions.append(root_partition) + for dependent in partitions[root_partition].dependents: + partitions[dependent].dependencies.pop(root_partition) + if not partitions[dependent].dependencies: + root_partitions.append(dependent) + if len(sorted_partitions) != len(partitions): + raise RuntimeError("cycle exists between partitions!") + + # Enter prelude + for regions_mapping in [autocast_regions, grad_regions]: + for node, regions in regions_mapping.items(): + assert len(regions) > 0 + partitions[str(regions[0])].environment[node] = node + for r in regions[1:]: + partition = partitions[str(r)] + new_node = partition.graph.create_node( + op=node.op, + target=node.target, + args=tuple(arg for arg in node.args), + kwargs={}, + type_expr=node.type, + ) + new_node.meta = ( + node.meta.copy() + ) # is it really a good idea to copy this? + partition.environment[node] = new_node + + # add placeholders to partition inputs + for partition_name in sorted_partitions: + partition = partitions[partition_name] + new_inputs: dict[str, None] = {} + + counter = 0 + + for inp in partition.inputs: + orig_node = orig_nodes[inp] + # We don't pass in get_attr nodes as inputs to the partition, but + # instead set them as targets and use getattr within the module + + def add_placeholder(): + if keep_original_input_name: + name = inp + else: + nonlocal counter + name = f"arg_{counter}" + counter += 1 + placeholder = partition.graph.placeholder( + name, + type_expr=orig_nodes[inp].type, + ) + new_inputs[inp] = None + return placeholder + + if orig_node.op == "get_attr": + assert isinstance(orig_node.target, str) + + orig_attr = _get_attr_from_qualname(m, orig_node.target) + if isinstance(orig_attr, torch.nn.Module): + placeholder = partition.graph.get_attr(orig_node.target) + partition.targets[orig_node.target] = orig_attr + else: + placeholder = add_placeholder() + else: + placeholder = add_placeholder() + placeholder.meta = orig_nodes[inp].meta.copy() + partition.environment[orig_nodes[inp]] = placeholder + partition.inputs = new_inputs + + # Transform nodes and collect targets for partition's submodule + for node in m.graph.nodes: + if hasattr(node, "_fx_partition"): + partition = partitions[node._fx_partition] + + # swap out old graph nodes in kw/args with references to new nodes in this submodule + environment = partition.environment + gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n]) + gathered_kwargs = torch.fx.graph.map_arg( + node.kwargs, lambda n: environment[n] + ) + + if node.op not in ["call_module", "get_attr"]: + target = node.target + else: + target_attr = _get_attr_from_qualname(m, node.target) + target = node.target.replace(".", "_") + partition.targets[target] = target_attr + # Fill in the passed-in mapping from new qualname to old qualname + if qualname_map is not None: + # When creating the split module later, the submodules will have + # path prefix matching the corresponding partition's submod_name + qualname = f"{partition.submod_name}.{target}" + qualname_map[qualname] = node.target + + assert isinstance(gathered_args, tuple) + assert isinstance(gathered_kwargs, dict) + name = node.name if keep_original_node_name else None + new_node = partition.graph.create_node( + op=node.op, + target=target, + args=gathered_args, + kwargs=gathered_kwargs, + type_expr=node.type, + name=name, + ) + new_node.meta = node.meta.copy() + partition.environment[node] = new_node + + # Exit epilogue + for regions_mapping in [autocast_regions]: + for node in reversed(regions_mapping): + regions = regions_mapping[node] + assert len(regions) > 0 + for r in regions[:-1]: + partition = partitions[str(r)] + exit_node = autocast_exits[node] + assert exit_node is not None, "Missing exit node" + new_node = partition.graph.create_node( + op=exit_node.op, + target=exit_node.target, + args=(partition.environment[node],), + kwargs={}, + type_expr=exit_node.type, + ) + new_node.meta = ( + exit_node.meta.copy() + ) # is it really a good idea to copy this? + + # original module environment dict mapping node names to nodes + orig_mod_env: dict[str, Node] = {} + # Set up values to construct base module + base_mod_env: dict[str, Node] = {} + base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph() + base_mod_attrs: dict[str, torch.fx.graph_module.GraphModule] = {} + if not keep_original_order: + for node in m.graph.nodes: + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + + else: + # Go through the graph to construct the mapping dict + for node in m.graph.nodes: + orig_mod_env[node.name] = node + + # Do some things iterating over the partitions in topological order again: + # 1) Finish off submodule Graphs by setting corresponding outputs + # 2) Construct GraphModules for each submodule + # 3) Construct the base graph by emitting calls to those submodules in + # topological order or original order specified by keep_original_order + + construct_order_partitions = ( + sorted_partitions if not keep_original_order else original_partition_order + ) + + already_constructed_attr_nodes = set() + + # We actually need to insert the placeholder nodes in the original order + # otherwise graph signature will be wrong. + original_order = [node for node in m.graph.nodes if node.op == "placeholder"] + + for partition_name in construct_order_partitions: + partition = partitions[partition_name] + + # Set correct output values + output_vals = tuple( + partition.environment[orig_nodes[name]] for name in partition.outputs + ) + + # skip output node generation if there are no output values + num_output_vals = len(output_vals) + if num_output_vals == 1: + partition.graph.output(output_vals[0]) + elif num_output_vals > 1: + partition.graph.output(output_vals) + else: + # Invariant - Graph should always have an output node. + partition.graph.output(()) + + if keep_original_order: + # first get the attr nodes required by this partition + orig_mod_attr_nodes: list[Node] = [ + orig_mod_env[key] + for key in partition.inputs + if key not in original_order + ] + + for node in original_order: + if node in already_constructed_attr_nodes: + continue # already added this attr to the base graph + base_mod_env, _based_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + already_constructed_attr_nodes.add(node) + + # Construct GraphModule for this partition + for node in orig_mod_attr_nodes: # type: ignore[attr-defined] + if node in already_constructed_attr_nodes: + continue + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + already_constructed_attr_nodes.add(node) + + base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule( + partition.targets, partition.graph + ) # noqa: B950 + + # Emit call in base graph to this submodule + output_val = base_mod_graph.call_module( + partition.submod_name, + tuple(base_mod_env[name] for name in partition.inputs), + ) + + num_outputs = len(partition.outputs) + if num_outputs > 1: + # Unpack multiple return values from submodule + output_val_proxy = torch.fx.proxy.Proxy(output_val) + for i, output_name in enumerate(partition.outputs): + base_mod_env[output_name] = output_val_proxy[i].node # type: ignore[index] + elif num_outputs == 1: + base_mod_env[next(iter(partition.outputs))] = output_val + + # When keep_original_order=True and if the graph doesn't have any + # `call_function` node then `base_mod_graph`, `base_mod_env` and `base_mod_attrs` + # are never populated. + # For this case, we call `construct_graph` here which takes care of updating them. + if keep_original_order and not base_mod_env: + for node in m.graph.nodes: + base_mod_env, base_mod_attrs = construct_graph( + node, base_mod_env, base_mod_attrs + ) + + # Add output node to `base_mod_graph` (i.e. the split graph) which will be returned. + for node in m.graph.nodes: + if node.op == "output": + base_mod_graph.output( + torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name]) + ) # noqa: B950 + + ret = torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph) + log.debug( + "%s", + lazy_format_graph_code("post split_module", ret, colored=True), + ) + return ret diff --git a/phivenv/Lib/site-packages/torch/fx/passes/split_utils.py b/phivenv/Lib/site-packages/torch/fx/passes/split_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..7d9914ef296f1171291f2822fc9905b81751d5e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/split_utils.py @@ -0,0 +1,307 @@ +# mypy: allow-untyped-defs +import copy +from dataclasses import dataclass, field +from typing import Optional, Union + +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import map_arg +from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module + +from .tools_common import NodeList + + +__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"] + + +@compatibility(is_backward_compatible=False) +def getattr_recursive(obj, name): + for layer in name.split("."): + if hasattr(obj, layer): + obj = getattr(obj, layer) + else: + return None + return obj + + +@compatibility(is_backward_compatible=False) +def setattr_recursive(obj, attr, value): + if "." not in attr: + setattr(obj, attr, value) + else: + layer = attr.split(".") + setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value) + + +@compatibility(is_backward_compatible=False) +@dataclass +class Component: + """ + A component serves as a container for a subgraph we want to create afterwards. + """ + + graph: torch.fx.Graph + order: int + name: str + + # Stores the placeholder nodes in `graph`. + input_placeholders: list = field(default_factory=list) + + # Store the nodes in original graph that are placeholder in `graph`. + orig_inputs: list = field(default_factory=list) + + # Store the nodes in original graph that are outputs in `graph`. + orig_outputs: list = field(default_factory=list) + + # Mapping from get_attr node in original graph to get_attr node in `graph`. + getattr_maps: dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict) + constructor_args: list[str] = field(default_factory=list) + gm: Optional[torch.fx.GraphModule] = None + + +@compatibility(is_backward_compatible=False) +def split_by_tags( + gm: torch.fx.GraphModule, + tags: list[str], + return_fqn_mapping: bool = False, + return_tuple: bool = False, + GraphModuleCls: type[torch.fx.GraphModule] = torch.fx.GraphModule, +) -> Union[torch.fx.GraphModule, tuple[torch.fx.GraphModule, dict[str, str]]]: + """ + Splits a GraphModule using tags on its graph nodes. We honor the order of + tags. For example, we have tags = ["a", "b", "c"], the function will create + the initial submodules in the order of "a", "b", "c". + + To set a tag: + gm.graph.nodes[idx].tag = "mytag" + + This will result in all nodes with the same tag being extracted and placed in their + own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder + and output nodes are created when needed while get_attr nodes get copied to submodules + where they are used. + + Given the following module def: + + class SimpleModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.linear1 = torch.nn.Linear(...) + self.linear2 = torch.nn.Linear(...) + self.linear3 = torch.nn.Linear(...) + + def forward(self, in1, in2): + r1 = self.linear1(in1) + r2 = self.linear2(in2) + r3 = torch.cat([r1, r2]) + return self.linear3(r3) + + Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split: + + ro: + def forward(self, in1): + self = self.root + linear1 = self.linear1(in1) + return linear1 + + main: + def forward(self, in2, linear1): + self = self.root + linear2 = self.linear2(in2) + cat_1 = torch.cat([linear1, linear2]) + linear3 = self.linear3(cat_1) + return linear3 + + main: + def forward(self, in1, in2): + self = self.root + ro_0 = self.ro_0(in1) + main_1 = self.main_1(in2, ro_0) + return main_1 + + Returns: + split_gm: torch fx graph after split + orig_to_split_fqn_mapping: a map between the original fqn and the fqn + after split for call_module and get_attr. + """ + + def flatten(x: torch.fx.node.Argument) -> NodeList: + """ + Stores nodes in x to a list and returns the list. + """ + r: NodeList = [] + map_arg(x, r.append) + return r + + # Mapping from node in original module to node in created submodule. + node_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + + # Mapping from node in original module or created submodules to + # corresponding component. + node_to_component: dict[torch.fx.Node, Component] = {} + + # Mapping from tag to the corresponding component. + tag_to_component: dict[str, Component] = {} + + # Stores all components. + all_components: list[Component] = [] + + # Stores nodes that will be used in main graph. + used_in_main: dict[torch.fx.Node, None] = {} + + # Main graph after split. + main_g = torch.fx.Graph() + + # Mapping from node in original module to node in main graph after split. + main_remapping: dict[torch.fx.Node, torch.fx.Node] = {} + + # Output node of original module. + output_node: Optional[torch.fx.Node] = None + + # Create a component for each tag, we don't expect to create other components afterwards. + for tag in tags: + comp = Component(torch.fx.Graph(), len(all_components), f"{tag}") + all_components.append(comp) + tag_to_component[tag] = comp + + # Traverse the nodes in original graph and take care of them. + for node in gm.graph.nodes: + if node.op == "output": + if output_node is not None: + raise RuntimeError("Multiple output nodes in graph!") + output_node = node + continue + + # Placeholders in the original graph get copied to main graph. + if node.op == "placeholder": + main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type) + main_remapping[node].meta = copy.copy(node.meta) + continue + + # Get_attr nodes are ignored because we are not tagging them. + # Instead, we copy them directly to the submodules use them afterwards. + if node.op == "get_attr": + continue + + # Now we process callable nodes which are nodes with op of call_module, + # call_function or call_method. Every callable nodes should be tagged. + assert hasattr(node, "tag"), f"Node does not have tag: {node.format_node()}" + + upstream_components = [ + node_to_component[x] + for x in flatten(node.args) + flatten(node.kwargs) + if x.op not in {"placeholder", "get_attr"} + ] + + comp = tag_to_component[node.tag] + node_to_component[node] = comp + + # Max order of upperstream components. + mx = max((c.order for c in upstream_components), default=0) + + # Expect the component for `node` has higher order then its upstream components. + assert comp.order >= mx, ( + f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}" + ) + + # Map a input of `node` to nodes in the component's graph. + def remap_func(x): + # If input is a get_attr node, copy it to current component's graph. + # Returns the get_attr node in current component's graph. + if x.op == "get_attr": + if x not in comp.getattr_maps: + comp.getattr_maps[x] = comp.graph.get_attr( + x.target, type_expr=x.type + ) + comp.getattr_maps[x].meta = copy.copy(x.meta) + return comp.getattr_maps[x] + + # If input is not a placeholder, it should have been put into a component + # already. If it's the current component then we return the corresponding + # node in the component. + if x.op != "placeholder" and node_to_component[x] == comp: + return node_remapping[x] + + # If input is a placeholder or it's in other components, we want to make it + # as a placeholder in current component's graph. + if x not in comp.orig_inputs: + comp.orig_inputs.append(x) + placeholder = comp.graph.placeholder(x.name, type_expr=x.type) + placeholder.meta = copy.copy(x.meta) + comp.input_placeholders.append(placeholder) + used_in_main[x] = None + + return comp.input_placeholders[comp.orig_inputs.index(x)] + + n = comp.graph.node_copy(node, remap_func) + n.tag = node.tag # type: ignore[attr-defined] + node_remapping[node] = n + node_to_component[n] = comp + + if output_node is None: + raise RuntimeError("Graph had no output node!") + + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + # We don't need components mapping for nodes of type "get_attr" + # that are consumed by the output. Only need to make sure we create + # corresponding counterparts in the resulting graph. + main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type) + else: + # All component results consumed by the output node should be + # marked as "used in main". + used_in_main[x] = None + + # If a node is used in main graph then we mark it as an output in the component + # it belongs to. + for n in used_in_main: + if n.op != "placeholder": + node_to_component[n].orig_outputs.append(n) + + # Now we create a graphmodule for each component. + orig_to_split_fqn_mapping: dict[str, str] = {} + for comp in all_components: + outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs)) + + if return_tuple: + comp.graph.output(outs) + else: + # Take care of the args of FX output node. If there's a single + # output then the output node args is like (output_single), else + # if there're multiple outputs then the output node args is like + # ((output_0, output_1, ...)). + comp.graph.output(outs[0] if len(outs) == 1 else outs) + + comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module( + gm, subgraph=comp.graph, comp_name=comp.name + ) + orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping) + + # Create a call_module node in main graph. + main_node = main_g.call_module( + comp.name, + args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)), + kwargs=None, + ) + + if len(outs) == 1 and not return_tuple: + main_remapping[comp.orig_outputs[0]] = main_node + else: + for i, o in enumerate(comp.orig_outputs): + # Use Proxy to record getitem access. + main_remapping[o] = torch.fx.Proxy(main_node)[i].node # type: ignore[index] + + main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__)) + main_root = HolderModule({comp.name: comp.gm for comp in all_components}) + main_g._codegen = gm.graph._codegen + + # If the output nodes consumes get_attr directly in the original graph, + # then we need to make sure get_attr is copied to the new graph. + for x in flatten(output_node.args[0]): + if x.op == "get_attr": + setattr(main_root, x.name, getattr_recursive(gm, x.target)) # type: ignore[arg-type] + + result_gm = GraphModuleCls(main_root, main_g) + if return_fqn_mapping: + return result_gm, orig_to_split_fqn_mapping + + return result_gm diff --git a/phivenv/Lib/site-packages/torch/fx/passes/splitter_base.py b/phivenv/Lib/site-packages/torch/fx/passes/splitter_base.py new file mode 100644 index 0000000000000000000000000000000000000000..b1330a704339e823faace95cd8d5407d8bd31b11 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/splitter_base.py @@ -0,0 +1,925 @@ +# mypy: allow-untyped-defs +import argparse +import copy +import logging +from collections import defaultdict +from collections.abc import Iterable, Sequence +from dataclasses import dataclass +from typing import Any, NamedTuple, Optional + +import torch +from torch.fx._compatibility import compatibility +from torch.fx.node import map_arg +from torch.fx.passes.graph_manipulation import get_size_of_node + +from .graph_drawer import FxGraphDrawer +from .operator_support import get_node_target, OperatorSupportBase +from .shape_prop import ShapeProp +from .split_utils import split_by_tags +from .tools_common import ( + CALLABLE_NODE_OPS, + FxNetAccFusionsFinder, + is_node_output_tensor, + NodeList, + NodeSet, + Tensors, +) + + +__all__ = [ + "FxNetAccNodesFinder", + "FxNetSplitterInternalError", + "Subgraph", + "SplitResult", + "generate_inputs_for_submodules", +] +_LOGGER = logging.getLogger(__name__) + +DEFAULT_MIN_ACC_MODULE_SIZE = 1 +DEFAULT_SKIP_FUSION = False +DEFAULT_ALLOW_NON_TENSOR = False + + +class _SplitterSettingBase: + def __init__( + self, + min_acc_module_size=DEFAULT_MIN_ACC_MODULE_SIZE, + skip_fusion=DEFAULT_SKIP_FUSION, + allow_non_tensor=DEFAULT_ALLOW_NON_TENSOR, + max_acc_splits: int = -1, + ): + parser = argparse.ArgumentParser() + parser.add_argument( + "--min-acc-module-size", + "--min_acc_module_size", + required=False, + type=int, + help="Minimum size limit of an accelerator subgraph.", + ) + parser.add_argument( + "--max-acc-splits", + "--max_acc_splits", + required=False, + type=int, + help="Enforce a maximum number of split subgraphs.", + ) + parser.add_argument( + "--skip-fusion", + "--skip_fusion", + default=False, + action="store_true", + help="If true then no fusion groups. Fusion group is used to " + "enforce no non-tensor data flow between submodules. If we don't " + "have this constrain, setting this to false is recommended as it " + "can reduce overhead.", + ) + parser.add_argument( + "--allow-non-tensor", + "--allow_non_tensor", + default=False, + action="store_true", + help="For some backends non-tensor data flow between cpu and them " + "are not allowed. Therefore, if a node supported by accelerator but " + "it has non-tensor inputs or outputs to a cpu node we would want to " + "consider it as a cpu node during splitting. However, for some backends " + "we might not care about non-tensor data flow and we can set this option " + "to true to disable the functionality that prevent non-tensor data flow.", + ) + args, _unknown = parser.parse_known_args() + + self.min_acc_module_size: int = ( + args.min_acc_module_size + if args.min_acc_module_size + else min_acc_module_size + ) + self.skip_fusion: bool = args.skip_fusion if args.skip_fusion else skip_fusion + self.allow_non_tensor: bool = ( + args.allow_non_tensor if args.allow_non_tensor else allow_non_tensor + ) + self.max_acc_splits: int = max_acc_splits + + +@compatibility(is_backward_compatible=False) +class FxNetAccNodesFinder: + """ + Finds a set of nodes that can be supported on ACC, excluding nodes that have non-tensor + input/output to cpu nodes to prevent non-tensor data flow between backends and cpu. + + I.e. if we have a chain: + + ACC_NODE_1 -> ACC_NODE_2 -> ACC_NODE_3 -> CPU_NODE_1 + + where every ACC node produces non-tensor output, then they all should be treated as CPU nodes. + + This behavior can be turned off by passing allow_non_tensor=True. + """ + + def __init__( + self, + module: torch.fx.GraphModule, + operator_support: OperatorSupportBase, + allow_non_tensor: bool, + ): + self.module = module + self.operator_support = operator_support + self.allow_non_tensor = allow_non_tensor + self.acc_nodes: NodeSet = set() + + def reduce_acc_nodes_non_tensor_input_helper(self, cpu_worklist: NodeList): + """ + Transitively excludes nodes from ACC supported set. + For every node in the worklist: + - removes its downstream ACC nodes from ACC supported set, + - if any downstream ACC node produces non-tensor output, + then it gets added into the worklist. + """ + while cpu_worklist: + node = cpu_worklist.pop(0) + + for user in node.users: + if user in self.acc_nodes: + self.acc_nodes.remove(user) + if not is_node_output_tensor(user): + cpu_worklist.append(user) + + def reduce_acc_nodes_non_tensor_input(self): + """ + Excludes nodes from ACC supported set that have direct + upstream CPU nodes that produce non-tensor outputs. + """ + non_tensor_cpu_nodes: NodeList = [] + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + if node in self.acc_nodes: + continue + if is_node_output_tensor(node): + continue + non_tensor_cpu_nodes.append(node) + + self.reduce_acc_nodes_non_tensor_input_helper(non_tensor_cpu_nodes) + + def reduce_acc_nodes_non_tensor_output(self): + """ + Excludes nodes from ACC supported set that produce non-tensor + outputs and have downstream CPU nodes. + """ + while True: + new_cpu_nodes: NodeList = [] + + for acc_node in self.acc_nodes: + if is_node_output_tensor(acc_node): + continue + for user in acc_node.users: + if user not in self.acc_nodes: + new_cpu_nodes.append(acc_node) + break + + if not new_cpu_nodes: + break + + for new_cpu_node in new_cpu_nodes: + self.acc_nodes.remove(new_cpu_node) + + self.reduce_acc_nodes_non_tensor_input_helper(new_cpu_nodes) + + def __call__(self) -> NodeSet: + submodules = dict(self.module.named_modules()) + self.acc_nodes = { + n + for n in self.module.graph.nodes + if n.op in CALLABLE_NODE_OPS + and self.operator_support.is_node_supported(submodules, n) + } + + if not self.allow_non_tensor: + self.reduce_acc_nodes_non_tensor_input() + self.reduce_acc_nodes_non_tensor_output() + + return self.acc_nodes + + +@compatibility(is_backward_compatible=False) +class FxNetSplitterInternalError(Exception): + pass + + +@compatibility(is_backward_compatible=False) +@dataclass +class Subgraph: + is_acc: bool + nodes: NodeList + device_ordinal: Optional[int] = None + + +@compatibility(is_backward_compatible=False) +class SplitResult(NamedTuple): + """ + Stores the results of the splitter. + + Attributes: + split_module: root module after splitting. + submodule_inputs: a dict that maps submodule name to its inputs. + non_acc_submodule_prefix: the prefix for non acc submodules. For + acc submodule the prefix is alwasy "_run_on_acc_". + """ + + split_module: torch.fx.GraphModule + submodule_inputs: dict[str, Any] + non_acc_submodule_prefix: str + + +@compatibility(is_backward_compatible=False) +def generate_inputs_for_submodules( + model: torch.nn.Module, + inputs: Sequence[Any], + target_submodules: Iterable[str], + deepcopy: bool = False, +) -> dict[str, Any]: + """ + Generate inputs for targeting submdoules in the given model. Note that if two submodules refer to the same obj, this + function doesn't work. + + Args: + model: root model. + inputs: inputs to the root model. + target_submodules: submodules that we want to generate inputs for. + + Returns: + A dict that maps from submodule name to its inputs. + """ + + handles = [] + results = {} + submodule_to_names = {mod: name for name, mod in model.named_modules()} + + def pre_forward(module, module_inputs): + results[submodule_to_names[module]] = ( + copy.deepcopy(module_inputs) if deepcopy else module_inputs + ) + + for name, mod in model.named_modules(): + if name in target_submodules: + handles.append(mod.register_forward_pre_hook(pre_forward)) + + def clean_up_handles(): + for h in handles: + h.remove() + + try: + with torch.no_grad(): + model(*inputs) + except Exception as e: + clean_up_handles() + raise e + + clean_up_handles() + return results + + +class _SplitterBase: + """ + Splits a GraphModule into sub-GraphModules for execution on CPU or the accelerator. + Output is a GraphModule with supported and unsupported operators grouped into as few sub-GraphModules as possible. + Assumes that only "call_module", "call_function" and "call_method" from FX IR can potentially be executed on the accelerator. + + Given the following graph: + ==> b ==> + // \\ + a d + \\ // + ==> c ==> + + class SimpleModule(torch.nn.Module): + def forward(self, a): + b = torch.sin(a) + c = torch.cos(a) + d = b + c + return d + + and providing "operator_support" that indicates that 'b' and 'c' can be executed on the accelerator, + we will get the following split result: + + main: + def forward(self, a): + run_on_acc_0_0 = self._run_on_acc_0_0(a) + getitem = run_on_acc_0_0[0] + getitem_1 = run_on_acc_0_0[1] + run_on_cpu_1_1 = self._run_on_cpu_1_1(getitem, getitem_1) + return run_on_cpu_1_1 + + _run_on_acc_0_0: + def forward(self, a): + sin_1 = torch.sin(a) + cos_1 = torch.cos(a) + return (sin_1, cos_1) + + _run_on_cpu_1_1: + def forward(self, sin_1, cos_1): + add_1 = sin_1 + cos_1 + return add_1 + """ + + # PCIe bandwidth for the backend, default to 100 GB/s + PCIe_BW = 100 * 2**30 + + def __init__( + self, + module: torch.fx.GraphModule, + sample_input: Sequence[Any], + operator_support: OperatorSupportBase, + settings: _SplitterSettingBase, + non_acc_submodule_name: str = "_run_on_cpu_", + return_tuple: bool = False, + nodes_finder: Optional[FxNetAccNodesFinder] = None, + ): + """ + Preprocesses graph before splitting: + - finds nodes supported by ACC, + - finds fusion groups for ACC nodes having non-tensor IO, + - builds a graph of direct dependencies, + - builds a map of fused nodes to their fusions. + As a result we get self.acc_nodes, self.deps and self.fusions. + """ + assert isinstance(module, torch.fx.GraphModule) + + self.module = module + ShapeProp(self.module).propagate(*sample_input) + + self.settings = settings + self.operator_support = operator_support + self.sample_input = sample_input + if nodes_finder is None: + nodes_finder = FxNetAccNodesFinder( + self.module, self.operator_support, self.settings.allow_non_tensor + ) + self.acc_nodes = nodes_finder() + + if self.settings.skip_fusion: + self.fusions = {} + else: + self.fusions = FxNetAccFusionsFinder(module, self.acc_nodes)() + + # Modify deps to add more deps for fused nodes + self.deps = self.find_deps() + self.update_deps_for_fusions() + + self.non_acc_submodule_name = non_acc_submodule_name + self._node_submodule_map: dict[str, str] = {} + self._return_tuple = return_tuple + + self.tags: list[str] = [] + + # =============================================================== + # Helpers for ctor and initial state + # =============================================================== + + def get_node_submodule_map(self) -> dict[str, str]: + """Returns a map from node name to submodule name, e.g. + node: main_module_impl_impl_over_arch_unary_multiple_embedding + _pooling_embedding_pooling_sparse_entity_equivalence_key + _proxy_embedding_bag + maps to submodule name of: _run_on_acc_1 + """ + return self._node_submodule_map + + def find_deps(self) -> dict[torch.fx.Node, NodeSet]: + """ + Builds a graph of node dependencies. Leaf nodes don't have any + dependencies and the "output" node doesn't have nodes depending on it. + + Resulting graph has only direct dependencies, i.e. there are no + transitive dependencies. + """ + deps: dict[torch.fx.Node, NodeSet] = defaultdict(set) + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op != "output": + deps[user].add(node) + return deps + + def update_deps_for_fusions(self): + """ + Updates graph of dependencies so that: + - nodes from the same fusion depend on the same set of outer nodes, + - outer nodes depending on a fusion depend on all nodes in that fusion. + """ + for node in self.fusions: + fusion = self.fusions[node] + for fused_neighbor in fusion: + self.deps[node].update(self.deps[fused_neighbor] - fusion) + + for user in fused_neighbor.users: + if user not in fusion: + self.deps[user].add(node) + + # =============================================================== + # Helpers for preview + # =============================================================== + + def _lower_model_to_backend( + self, mod: torch.fx.GraphModule, inputs: Tensors + ) -> torch.nn.Module: + """ + Lower the model to a backend. + """ + + return mod + + def _find_culprit(self, mod: torch.fx.GraphModule, inputs: Tensors) -> str: + """ + When an error occurs during lowering or running the lowered mod, we use this + function to find culprits in the `mod` that causes the error. + """ + + return "Unable to find a culprit because _find_culprit() function is not implemented." + + def _draw_graph_based_on_node_support( + self, mod: torch.fx.GraphModule, supported_nodes: NodeList + ): + color_map = { + "default": "AliceBlue", + "supported": "chartreuse1", + "unsupported": "crimson", + } + + class CustomDrawer(FxGraphDrawer): + def _get_node_style(self, node): + template = super()._get_node_style(node) + if node in supported_nodes: + template["fillcolor"] = color_map["supported"] + elif node.op in CALLABLE_NODE_OPS: + template["fillcolor"] = color_map["unsupported"] + else: + template["fillcolor"] = color_map["default"] + + return template + + drawer = CustomDrawer(mod, "node_support", ignore_getattr=True) + dot_graph = drawer.get_main_dot_graph() + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw("node_support.dot") # type: ignore[attr-defined] + + def node_support_preview(self, dump_graph: bool = False): + submodules = dict(self.module.named_modules()) + + supported_nodes: NodeList = [] + supported_node_types = defaultdict(set) + unsupported_node_types = defaultdict(set) + + def get_dtype(arg): + tensor_meta = arg.meta.get("tensor_meta") + return getattr(tensor_meta, "dtype", None) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + target = get_node_target(submodules, node) + + # Store dtype of arg in node.args. If arg doesn't have dtype, i.e. not a tensor, we'll store None. + arg_dtypes = [ + get_dtype(arg) if isinstance(arg, torch.fx.Node) else None + for arg in node.args + ] + + # Find last non-None element. If all elements are None, return max_len. + last_index = len(arg_dtypes) - next( + ( + i + for i, dtype in enumerate(reversed(arg_dtypes)) + if dtype is not None + ), + len(arg_dtypes), + ) + + # Strip None elements at the end. + arg_dtypes_tuple = tuple(arg_dtypes[:last_index]) + kwarg_dtypes_tuple = tuple( + (k, get_dtype(arg)) + for k, arg in node.kwargs.items() + if isinstance(arg, torch.fx.Node) + ) + + if self.operator_support.is_node_supported(submodules, node): + supported_nodes.append(node) + supported_node_types[target].add((arg_dtypes_tuple, kwarg_dtypes_tuple)) + else: + unsupported_node_types[target].add( + (arg_dtypes_tuple, kwarg_dtypes_tuple) + ) + + if dump_graph: + self._draw_graph_based_on_node_support(self.module, supported_nodes) + + reports = "\nSupported node types in the model:\n" + for t, dtypes in supported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + reports += "\nUnsupported node types in the model:\n" + for t, dtypes in unsupported_node_types.items(): + for arg_dtypes_tuple, kwarg_dtypes_tuple in dtypes: + reports += f"{t}: ({arg_dtypes_tuple}, {dict(kwarg_dtypes_tuple)})\n" + + print(reports) + + # Return reports for testing purpose + return reports + + def split_preview(self, dump_graph: bool = False): + reports = "" + subgraphs = self.put_nodes_into_subgraphs() + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"Before removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_num = len([g for g in subgraphs if g.is_acc]) + cpu_subgraphs_num = len(subgraphs) - acc_subgraphs_num + reports += f"After removing small acc subgraphs, total {len(subgraphs)} subgraphs are created:" + reports += f" {acc_subgraphs_num} acc subgraphs and {cpu_subgraphs_num} cpu subgraphs.\n" + + for i, subgraph in enumerate(subgraphs): + reports += ( + f"_run_on_acc_{i}: " + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{i}: " + ) + reports += f"{len(subgraph.nodes)} node(s)\n" + + self.tag(subgraphs) + split_mod = self.split(remove_tag=True) + split_mod.eval() + + if dump_graph: + drawer = FxGraphDrawer(split_mod, "preview", ignore_getattr=True) + dot_graphs = drawer.get_all_dot_graphs() + for name, dot_graph in dot_graphs.items(): + # pyre-fixme[16]: `pydot.Dot` has no attribute `write_raw`. + dot_graph.write_raw(f"{name}.dot") # type: ignore[attr-defined] + + max_qps: float = self.PCIe_BW + bottleneck_module = "" + + for node in split_mod.graph.nodes: + if node.op == "call_module" and "acc" in node.target: + reports += f"\nProcessing acc submodule {node.target}\n" + + submod = getattr(split_mod, node.target) + + def get_submod_inputs(main_mod, submod, example_inputs): + sub_inputs = None + + def get_inputs(self, inputs): + nonlocal sub_inputs + sub_inputs = inputs + + handle = submod.register_forward_pre_hook(get_inputs) + main_mod(*example_inputs) + handle.remove() + return sub_inputs + + submod_inputs = get_submod_inputs(split_mod, submod, self.sample_input) + ShapeProp(submod).propagate(*submod_inputs) + + total_input_bytes = 0 + total_output_bytes = 0 + + reports += "Checking inputs...\n" + for n in submod.graph.nodes: + if n.op == "placeholder": + if not is_node_output_tensor(n): + reports += f"Input {n.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_input_bytes += get_size_of_node(submod, n)[0] + if n.op == "output": + output_node = n + + reports += "Checking outputs...\n" + + def get_bytes(node: torch.fx.Node): + nonlocal total_output_bytes + nonlocal reports + if not is_node_output_tensor(node): + reports += f"Output {node.name} is not a tensor, this might cause problems during lowering!\n" + else: + total_output_bytes += get_size_of_node(submod, node)[0] + + map_arg(output_node.args, get_bytes) # type: ignore[possibly-undefined] + qps = self.PCIe_BW / max(total_input_bytes, total_output_bytes) + reports += f"Total input size in bytes is {total_input_bytes}, total output size in bytes is {total_output_bytes}," + reports += f" theoretical max qps (bounds by PCIe bandwidth) for this submodule is {qps}.\n" + + if qps < max_qps: + max_qps = qps + bottleneck_module = node.target + + try: + lowered_submod = self._lower_model_to_backend(submod, submod_inputs) + except RuntimeError: + reports += "Run into an error during lowering!\n" + reports += self._find_culprit(submod, submod_inputs) + continue + + try: + lowered_submod(*submod_inputs) + except RuntimeError: + reports += "Run into an error during inference!\n" + reports += self._find_culprit(submod, submod_inputs) + else: + reports += "Lowering and running succeed!\n" + + reports += f"\nTheoretical max qps (bounds by PCIe bandwidth) for this model is {max_qps}," + reports += f" bottleneck is submodule {bottleneck_module}." + print(reports) + + # return the reports for testing purposes + return reports + + # =============================================================== + # Helpers for extend_acc_subgraph() method + # =============================================================== + + def find_reverse_deps( + self, tag_id: Optional[int] = None + ) -> dict[torch.fx.Node, NodeSet]: + """ + Builds reversed topological node dependencies, if tag_id is specified, + we ignore nodes that are in later subgraph i.e. nodes have greater tag_id. + """ + result: dict[torch.fx.Node, NodeSet] = defaultdict(set) + + for node in self.module.graph.nodes: + if node.op not in CALLABLE_NODE_OPS: + continue + + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + if tag_id is None or (int(user.tag.split("_")[-1]) < tag_id): + result[node].add(user) + + return result + + def update_reverse_deps_for_fusions(self, deps: dict[torch.fx.Node, NodeSet]): + processed_node = set() + + for node, fusion in self.fusions.items(): + if node in processed_node: + continue + + new_dep = set() + + # Create a new dependency set which include all the + # dependencies of the nodes in the fusion group + for n in fusion: + new_dep.update(deps[n]) + + # Exclude nodes in the fusion + new_dep.difference_update(fusion) + + # Update dependency + for n in fusion: + deps[n] = new_dep + + for arg in n.all_input_nodes: + if arg not in fusion: + deps[arg].update(fusion) + + processed_node.add(n) + + def find_parent_nodes_of_subgraph(self, tag: str) -> NodeSet: + """ + Finds parent nodes of the `tag` subgraph. + + Traverse the inputs of nodes in the subgraph, if input doesn't belong to the subgraph + and is not a placeholder, we consider it as the parent node of the subgraph. + """ + parent_nodes = set() + + for node in self.module.graph.nodes: + if node.op in CALLABLE_NODE_OPS and node.tag == tag: + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg.tag != tag: + parent_nodes.add(arg) + + return parent_nodes + + def extend_acc_subgraph(self, tag: str): + """ + Extend the acc subgraph with `tag` going the reversed topological direction. + """ + # Dict that maps node to its users and ignore users that + # are in the subgraph that has greater tag + deps = self.find_reverse_deps(tag_id=int(tag.split("_")[-1])) + self.update_reverse_deps_for_fusions(deps) + + # Parent nodes of the subgraph + parent_nodes = self.find_parent_nodes_of_subgraph(tag) + + visited_nodes: NodeSet = set() + + while parent_nodes: + node = None + + # Find a acc node that depends on visited nodes only + for n in parent_nodes: + if deps[n] <= visited_nodes and n in self.acc_nodes: + node = n + break + + if node is None: + break + + # Put the node into `tag` subgraph + node.tag = tag # type: ignore[attr-defined] + parent_nodes.remove(node) + visited_nodes.add(node) + + # If node is in a fusion group, add all fusion buddies to parent nodes + if node in self.fusions: + for fusion_node in self.fusions[node]: + if fusion_node not in visited_nodes: + parent_nodes.add(fusion_node) + + # Add inputs of the node to parent nodes + for arg in node.all_input_nodes: + if arg.op in CALLABLE_NODE_OPS and arg not in visited_nodes: + parent_nodes.add(arg) + + # =============================================================== + # Helpers for split() method + # =============================================================== + + def starter_nodes(self) -> tuple[NodeSet, NodeSet]: + """ + Finds nodes that consume module inputs or get_attr nodes. + """ + starter_cpu_nodes: NodeSet = set() + starter_acc_nodes: NodeSet = set() + for node in self.module.graph.nodes: + if node.op not in {"placeholder", "get_attr"}: + continue + for user in node.users: + if user in self.acc_nodes: + starter_acc_nodes.add(user) + else: + starter_cpu_nodes.add(user) + return starter_cpu_nodes, starter_acc_nodes + + def put_nodes_into_subgraphs(self) -> list[Subgraph]: + # We start graph traversal from leaf nodes + current_cpu_nodes, current_acc_nodes = self.starter_nodes() + visited_nodes: NodeSet = set() + + # Determine which subgraph to start from based on which subgraph has + # 0-dep node + acc_subgraph: bool = not any(len(self.deps[n]) == 0 for n in current_cpu_nodes) + + current_subgraph_nodes: NodeList = [] + + # Result accumulator + subgraphs: list[Subgraph] = [] + while current_cpu_nodes or current_acc_nodes: + # Find the first node that should belong to the current subgraph and has all dependencies resolved + current_nodes = current_acc_nodes if acc_subgraph else current_cpu_nodes + node = next( + (n for n in current_nodes if self.deps[n] <= visited_nodes), + None, + ) + + # If nothing was found, then it's time to flip the mode and start a new subgraph + if node is None: + if not current_subgraph_nodes: + raise FxNetSplitterInternalError("Subgraph can't be empty") + + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + acc_subgraph = not acc_subgraph + current_subgraph_nodes = [] + continue + + current_nodes.remove(node) + visited_nodes.add(node) + current_subgraph_nodes.append(node) + + # Add fusion buddies + if node in self.fusions: + if node in self.acc_nodes: + current_acc_nodes.update(self.fusions[node] - visited_nodes) + else: + current_cpu_nodes.update(self.fusions[node] - visited_nodes) + + # Put depending nodes into the queue + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + + # Add downstream nodes + if user in self.acc_nodes: + current_acc_nodes.add(user) + else: + current_cpu_nodes.add(user) + + # Check if the last subgraph was not created + if current_subgraph_nodes: + subgraphs.append( + Subgraph(is_acc=acc_subgraph, nodes=current_subgraph_nodes) + ) + + if not subgraphs: + raise FxNetSplitterInternalError("Couldn't create subgraphs") + + return subgraphs + + def remove_small_acc_subgraphs(self, subgraphs: list[Subgraph]) -> list[Subgraph]: + """ + This pass finds ACC submodules with less than specified size and merges + them with adjacent CPU submodules. + """ + result: list[Subgraph] = [] + for subgraph in subgraphs: + if subgraph.is_acc: + if len(subgraph.nodes) >= self.settings.min_acc_module_size: + result.append(subgraph) + else: + print( + "Eliminating acc subgraph because it's smaller than the threshold: " + f"{len(subgraph.nodes)} < {self.settings.min_acc_module_size}" + ) + if result: + result[-1].nodes.extend(subgraph.nodes) + else: + subgraph.is_acc = False + result.append(subgraph) + else: + if result and not result[-1].is_acc: + result[-1].nodes.extend(subgraph.nodes) + else: + result.append(subgraph) + return result + + def tag(self, subgraphs: list[Subgraph]): + self.tags = [] + for subgraph in subgraphs: + tag = ( + f"_run_on_acc_{len(self.tags)}" + if subgraph.is_acc + else f"{self.non_acc_submodule_name}{len(self.tags)}" + ) + self.tags.append(tag) + for node in subgraph.nodes: + if hasattr(node, "tag"): + raise FxNetSplitterInternalError(f"Node {node} was already tagged") + + node.tag = tag # type: ignore[attr-defined] + self._node_submodule_map[node.name] = tag + + def split(self, remove_tag: bool = False) -> torch.fx.GraphModule: + split_module = split_by_tags( + self.module, self.tags, return_tuple=self._return_tuple + ) + if remove_tag: + for node in self.module.graph.nodes: + if hasattr(node, "tag"): + del node.tag + return split_module # type: ignore[return-value] + + def __call__(self) -> torch.fx.GraphModule: + subgraphs = self.put_nodes_into_subgraphs() + subgraphs = self.remove_small_acc_subgraphs(subgraphs) + acc_subgraphs_count = len([s for s in subgraphs if s.is_acc]) + non_acc_subgraphs_count = len(subgraphs) - acc_subgraphs_count + print( + f"Got {acc_subgraphs_count} acc subgraphs and {non_acc_subgraphs_count} non-acc subgraphs" + ) + self.tag(subgraphs) + return self.split() + + def generate_split_results(self) -> SplitResult: + split_module = self() + submodule_names = [] + for name, _mod in split_module.named_children(): + submodule_names.append(name) + if ( + self.settings.max_acc_splits > 0 + and len(submodule_names) > self.settings.max_acc_splits + ): + raise ValueError( + "Cannot fulfill max_acc_splits limit. " + "This may cause split fragmentation and " + "result in performance issues." + ) + + submodule_inputs = generate_inputs_for_submodules( + split_module, self.sample_input, submodule_names + ) + return SplitResult(split_module, submodule_inputs, self.non_acc_submodule_name) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/tests/__init__.py b/phivenv/Lib/site-packages/torch/fx/passes/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/phivenv/Lib/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6871b844567f1a1ce67010b6515405ac07cfcede Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/tests/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8859ba4f1a1ab9b5296b64a49ba3b768280393e6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/tests/__pycache__/test_pass_manager.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/tests/test_pass_manager.py b/phivenv/Lib/site-packages/torch/fx/passes/tests/test_pass_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..153850adadf2ab3b4271bc190a87ebb49ca0e8f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/tests/test_pass_manager.py @@ -0,0 +1,56 @@ +import unittest + +from ..pass_manager import ( + inplace_wrapper, + PassManager, + these_before_those_pass_constraint, + this_before_that_pass_constraint, +) + + +class TestPassManager(unittest.TestCase): + def test_pass_manager_builder(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + pm.validate() + + def test_this_before_that_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + pm = PassManager(passes) + + # add unfulfillable constraint + pm.add_constraint(this_before_that_pass_constraint(passes[-1], passes[0])) + + self.assertRaises(RuntimeError, pm.validate) + + def test_these_before_those_pass_constraint(self) -> None: + passes = [lambda x: 2 * x for _ in range(10)] + constraint = these_before_those_pass_constraint(passes[-1], passes[0]) + pm = PassManager([inplace_wrapper(p) for p in passes]) + + # add unfulfillable constraint + pm.add_constraint(constraint) + + self.assertRaises(RuntimeError, pm.validate) + + def test_two_pass_managers(self) -> None: + """Make sure we can construct the PassManager twice and not share any + state between them""" + + passes = [lambda x: 2 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm1 = PassManager() + for p in passes: + pm1.add_pass(p) + pm1.add_constraint(constraint) + output1 = pm1(1) + self.assertEqual(output1, 2**3) + + passes = [lambda x: 3 * x for _ in range(3)] + constraint = these_before_those_pass_constraint(passes[0], passes[1]) + pm2 = PassManager() + for p in passes: + pm2.add_pass(p) + pm2.add_constraint(constraint) + output2 = pm2(1) + self.assertEqual(output2, 3**3) diff --git a/phivenv/Lib/site-packages/torch/fx/passes/tools_common.py b/phivenv/Lib/site-packages/torch/fx/passes/tools_common.py new file mode 100644 index 0000000000000000000000000000000000000000..ffde4dd6c7c43a243b34bcdf401cb584e063a507 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/tools_common.py @@ -0,0 +1,319 @@ +# mypy: allow-untyped-defs +import collections +import operator +from collections.abc import Mapping +from dataclasses import dataclass +from typing import Any, Optional, Union + +import torch +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.node import _get_qualified_name + + +__all__ = [ + "get_acc_ops_name", + "get_node_target", + "is_node_output_tensor", + "FxNetAccFusionsFinder", + "legalize_graph", +] + +Tensors = Union[tuple[torch.Tensor], list[torch.Tensor]] +TensorOrTensors = Union[torch.Tensor, Tensors] +NodeList = list[torch.fx.Node] +NodeSet = set[torch.fx.Node] +Names = list[str] +CALLABLE_NODE_OPS = {"call_module", "call_function", "call_method"} + + +@compatibility(is_backward_compatible=False) +def get_acc_ops_name(k): + if isinstance(k, str): + return k + elif k.__module__ and "acc_ops" in k.__module__: + return f"acc_ops.{k.__name__}" + else: + module = k.__module__.replace( + "torch._ops", "torch.ops" + ) # WAR for bug in how torch.ops assigns module + return f"{module if module else ''}.{k.__name__}" + + +@compatibility(is_backward_compatible=False) +def get_node_target( + submodules: Mapping[str, torch.nn.Module], node: torch.fx.Node +) -> str: + """ + Given a `node` returns its target typename. + + For "call_method" node, return node.target which is the name of that method being called. + This could potential lead to conflict but should be okay because normally it's on a tensor. + + For "call_function" node, return typename of node.target. + + For "call_module" node, return typename of the module that node.target point to. + + If seeing "_VariableFunctionsClass" in the target name string, it will be replaced by + "torch". e.g. _VariableFunctionsClass.relu would become torch.relu. + """ + + assert node.op in CALLABLE_NODE_OPS, ( + "Expect op types of " + ", ".join(CALLABLE_NODE_OPS) + f", but found {node.op}" + ) + + if node.op == "call_module": + assert isinstance(node.target, str) + submod = submodules[node.target] + submod_type = getattr(submod, "_base_class_origin", type(submod)) + return get_acc_ops_name(submod_type) + elif node.op == "call_function": + target: Any = node.target + return ( + f"acc_ops.{target.__name__}" + if target.__module__ is not None and "acc_ops" in target.__module__ + else _get_qualified_name(target) + ) + else: + assert isinstance(node.target, str) + return node.target + + +@compatibility(is_backward_compatible=False) +def is_node_output_tensor(node: torch.fx.Node) -> bool: + """Checks if the node output produces a Tensor or not. + + NOTE: This requires to run `ShapeProp` on the containing fx graph before + calling this function. This is because it works by checking the `type` + metadata on the node. This metadata is produced by the `ShapeProp`. + """ + type_ = node.meta.get("type", None) + return type_ is not None and issubclass(type_, torch.Tensor) + + +@compatibility(is_backward_compatible=False) +class FxNetAccFusionsFinder: + """ + Finds groups of connected ACC nodes that pass non-tensor data between each other. + Such groups are called fusion groups. + """ + + def __init__(self, module: torch.fx.GraphModule, acc_nodes: NodeSet): + self.module = module + self.nodes = list(module.graph.nodes) + self.acc_nodes = acc_nodes + + @dataclass + class FusionGroup: + # The smallest idx of nodes in the fusion group after topological sorting all the nodes in the model. + top_node_idx: int + + # Nodes in this fusion group. + nodes: NodeSet + + # Inputs to this fusion group. + inputs: NodeSet + + # Nodes that in the fusion group that haven't been processed yet. + nodes_need_process: NodeSet + + def add_node(self, node): + """ + Add a node to fusion group. + """ + if node in self.nodes: + return + + self.nodes_need_process.add(node) + self.nodes.add(node) + self.inputs.discard(node) + self.inputs.update( + { + n + for n in node.all_input_nodes + if n.op in CALLABLE_NODE_OPS and n not in self.nodes + } + ) + + def recursive_add_node( + self, + fusion_group: "FxNetAccFusionsFinder.FusionGroup", + inputs: Union[NodeSet, NodeList], + visited: Optional[NodeSet] = None, + ): + """ + Start from inputs and going reverse topological order. If any upstream node + is in the fusion group, add all the nodes in this path to fusion group. + """ + for arg in inputs: + # skip the node if already seen + if visited is not None: + if arg in visited: + continue + visited.add(arg) + + # Skip placeholder and get_attr because they won't be in the fusion group. + if arg.op not in CALLABLE_NODE_OPS: + continue + + # If the node has smaller idx, it's already an upstream node of the fusion + # group. We don't need to check it anymore. + if self.nodes.index(arg) < fusion_group.top_node_idx: + continue + + # If the node is in the fusion group, return True. + if arg in fusion_group.nodes: + return True + + # Check the upstream nodes of the node, if any of them is in the fusion group + # we'll add this node to fusion group and return True. + if self.recursive_add_node(fusion_group, arg.all_input_nodes, visited): + fusion_group.add_node(arg) + return True + + return False + + def __call__(self) -> dict[torch.fx.Node, NodeSet]: + result: dict[torch.fx.Node, NodeSet] = {} + acc_nodes = list(self.acc_nodes) + + for node in acc_nodes: + if node in result: + continue + if node.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in node.meta: + continue + if node not in self.acc_nodes: + continue + + fusion_group: FxNetAccFusionsFinder.FusionGroup = self.FusionGroup( + top_node_idx=self.nodes.index(node), + nodes={node}, + inputs=set(node.all_input_nodes), + nodes_need_process={node}, + ) + while fusion_group.nodes_need_process: + node = fusion_group.nodes_need_process.pop() + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Optionally add downstream nodes + if "tensor_meta" not in node.meta: + for user in node.users: + if user.op not in CALLABLE_NODE_OPS: + continue + if user in fusion_group.nodes: + continue + + fusion_group.add_node(user) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + # Add some upstream nodes + for arg in node.all_input_nodes: + if arg.op not in CALLABLE_NODE_OPS: + continue + if "tensor_meta" in arg.meta: + continue + if arg in fusion_group.nodes: + continue + + fusion_group.add_node(arg) + fusion_group.top_node_idx = min( + fusion_group.top_node_idx, self.nodes.index(arg) + ) + self.recursive_add_node( + fusion_group, + fusion_group.inputs, + visited=set(), + ) + + if not (set(fusion_group.nodes) <= self.acc_nodes): + self.acc_nodes -= fusion_group.nodes + else: + for n in fusion_group.nodes: + result[n] = fusion_group.nodes + + return result + + +@compatibility(is_backward_compatible=False) +def legalize_graph(gm: torch.fx.GraphModule) -> torch.fx.GraphModule: + """ + Replace the graph of the given GraphModule with one that contains the same nodes as the + original, but in topologically sorted order. + + This is used by the merge_matmul transformation below, which disturbs the topologically sorted + order of its input GraphModule, so that this order is restored before further transformation. + + Arguments: + gm: The graph module to topologically sort. It is modified in-place. + + Returns: + The graph module in-place sorted + """ + + # These operators are used for making runtime assertions before any + # data-dependent operators occur. We want to prioritize sorting these to + # ensure that these assertions appear before any data-dependent operations + # in the graph. + PRIORITIZED_OPS = [ + operator.add, + operator.mul, + operator.sub, + operator.floordiv, + operator.truediv, + operator.mod, + operator.le, + operator.lt, + operator.ge, + operator.gt, + operator.eq, + operator.ne, + torch.ops.aten.sym_constrain_range.default, + torch.ops.aten.sym_constrain_range_for_size.default, + torch.ops.aten._assert_async.msg, + torch.ops.aten.scalar_tensor.default, + torch.ops.aten._assert_scalar.default, + ] + + indeg = dict.fromkeys(gm.graph.nodes, 0) + new_graph = torch.fx.Graph() + # Track how many unfulfilled dependencies each node has + for node in gm.graph.nodes: + for user in node.users: + indeg[user] += 1 + queue: collections.deque = collections.deque() + # Add all nodes with no dependencies to the queue + for node in gm.graph.nodes: + if indeg[node] == 0: + queue.append(node) + env: dict[torch.fx.Node, torch.fx.Node] = {} + # Pop nodes from the queue, and add nodes that have had all their + # dependencies fulfilled + while len(queue) > 0: + cur = queue.popleft() + env[cur] = new_graph.node_copy(cur, lambda x: env[x]) + for user in cur.users: + indeg[user] -= 1 + if indeg[user] == 0: + if user.op == "call_function" and user.target in PRIORITIZED_OPS: + queue.appendleft(user) + else: + queue.append(user) + # If the new graph's size is not as large as the old one, then there must be + # a cycle (i.e. some node's dependencies were not satisfied.) + if len(new_graph.nodes) < len(gm.graph.nodes): + raise RuntimeError( + f"Input graph has cycles, unable to add {[node for node in indeg if indeg[node] != 0]}" + ) + new_graph._codegen = gm.graph._codegen + gm.graph = new_graph + return gm diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/__init__.py b/phivenv/Lib/site-packages/torch/fx/passes/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a78b1765a3c65d8cc4931ba9cc41df0a32189544 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/utils/__init__.py @@ -0,0 +1 @@ +from .common import compare_graphs, HolderModule, lift_subgraph_as_module diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb157c551554e4c8e4c0d802d46c78a5c4fb7b90 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/__init__.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4776b300b6d2ab969edc3f55807a042964e5f82 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/common.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f3469c3cc0f512ccbac72a57cecc9428647d81a Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/fuser_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..815b06bae45a1074f68c19c0f7e1c0cd684bf8c6 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab177148d2dda7d0c5687dc318fde2a227c8414f Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/matcher_with_name_node_map_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-39.pyc b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..519fc74c8492a1b45612453f8f66222cdbd39aa7 Binary files /dev/null and b/phivenv/Lib/site-packages/torch/fx/passes/utils/__pycache__/source_matcher_utils.cpython-39.pyc differ diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/common.py b/phivenv/Lib/site-packages/torch/fx/passes/utils/common.py new file mode 100644 index 0000000000000000000000000000000000000000..2be6dd226293216ae13002e0207cb3b6f5b24ad8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/utils/common.py @@ -0,0 +1,94 @@ +# mypy: allow-untyped-defs + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.passes.utils.matcher_utils import SubgraphMatcher +from torch.nn import Module + + +__all__ = ["HolderModule", "lift_subgraph_as_module", "compare_graphs"] + + +@compatibility(is_backward_compatible=False) +class HolderModule(Module): + """ + HolderModule is used to copy all the attributes from original module to submodules + that uses the attributes + """ + + def __init__(self, d): + super().__init__() + for k, v in d.items(): + self.add_module(k, v) + + +@compatibility(is_backward_compatible=False) +def lift_subgraph_as_module( + gm: GraphModule, + subgraph: Graph, + comp_name: str = "", + class_name: str = "GraphModule", +) -> tuple[GraphModule, dict[str, str]]: + """ + Create a GraphModule for subgraph, which copies the necessary attributes from the original parent graph_module. + + Args: + gm (GraphModule): parent graph module + + subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph + + comp_name (str): name for the new component + + class_name (str): name for the submodule + + """ + + # Loop through all module calls (call_module) and param fetches (get_attr) + # in this component, creating HolderModules as necessary to match the path. + # e.g. if in the original module there's a get_attr node fetches "conv.weight". + # We create a HolderModule as root -> add a HolderModule named "conv" -> + # make "weight" a attribute of "conv" HolderModule and point to conv.weight in + # the original module. + submodule = HolderModule({}) + orig_to_split_fqn_mapping: dict[str, str] = {} + for n in subgraph.nodes: + if n.op not in ("call_module", "get_attr"): + continue + + target = n.target + assert isinstance(target, str) + target_name_parts = target.split(".") + curr = submodule + orig_gm = gm + + for name in target_name_parts[:-1]: + if not hasattr(curr, name): + curr.add_module(name, HolderModule({})) + + curr = getattr(curr, name) + orig_gm = getattr(orig_gm, name) + + leaf_node_name = target_name_parts[-1] + leaf_node = getattr(orig_gm, leaf_node_name) + + orig_to_split_fqn_mapping[target] = f"{comp_name}.{target}" + # Relies on custom __setattr__ magic. + setattr(curr, leaf_node_name, leaf_node) + + return GraphModule(submodule, subgraph, class_name), orig_to_split_fqn_mapping + + +@compatibility(is_backward_compatible=False) +def compare_graphs(left: Graph, right: Graph) -> bool: + """ + Return True if two graphs are identical, i.e they + - have the same number of outputs in the same order + - have the same number of inputs in the same order + - have the same set of nodes, and identical connectivity + """ + + matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) + matches = matcher.match(right) + + return len(matches) > 0 diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/fuser_utils.py b/phivenv/Lib/site-packages/torch/fx/passes/utils/fuser_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..a293a921c30f8fa9949e78fca4227399d5dc3eec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/utils/fuser_utils.py @@ -0,0 +1,275 @@ +import copy +from queue import SimpleQueue +from typing import Optional as _Optional + +import torch.fx +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.graph_module import GraphModule +from torch.fx.node import Node +from torch.fx.passes.tools_common import legalize_graph, NodeList, NodeSet +from torch.fx.passes.utils import lift_subgraph_as_module # type: ignore[attr-defined] + + +@compatibility(is_backward_compatible=False) +def topo_sort(nodes: NodeList) -> NodeList: + # sort nodes according to the topological order + indegree_map = dict.fromkeys(nodes, 0) + candidates: SimpleQueue[Node] = SimpleQueue() + + for node in nodes: + for n in node.all_input_nodes: + if n in indegree_map: + indegree_map[node] += 1 + if indegree_map[node] == 0: + candidates.put(node) + + sorted_nodes: NodeList = [] + while not candidates.empty(): + node = candidates.get() + sorted_nodes.append(node) + + for n in node.users: + if n in indegree_map: + indegree_map[n] -= 1 + if indegree_map[n] == 0: + candidates.put(n) + + assert len(nodes) == len(sorted_nodes), ( + "topological sorted nodes doesn't have same length as input nodes" + ) + + return sorted_nodes + + +@compatibility(is_backward_compatible=False) +def validate_partition(partition: NodeList) -> bool: + # verify the partition does't form a dependency cycle in the original graph + # returns True for valid partition, False for invalid + + partition_set = set(partition) + + outputs: NodeList = [] + for node in partition_set: + for user_node in node.users: + if user_node not in partition_set: + # external user node, need to expose as an output + outputs.append(user_node) + + # Perform BFS on the partition outputs. + # If it reaches a node within the partition, then it found a cycle. + # This function takes the ownership of `root_nodes` and may modify it. + def bfs_find_cycle(root_nodes: NodeList) -> bool: + # Set used to exclude nodes that have already been visited. + # If a node has been visited, that node and all its children have + # been checked for cycles. + visited: NodeSet = set() + + # Start with `root_nodes` and traverse through (toward child nodes) + # their connected sub-graph. Nodes in `visited` won't be added + # to `queue` again. + queue: NodeList = root_nodes + while queue: + current = queue.pop() + visited.add(current) + if current in partition_set: + # Started from partition's `output` nodes, and reached + # another node in partition. Cycle! + return True + for user_node in current.users: + if user_node in visited: + continue + queue.append(user_node) + # `root_nodes` don't cause cycle. + return False + + # Use all output nodes as roots to traverse + # the graph to check cycles. + if bfs_find_cycle(outputs): + return False + + return True + + +@compatibility(is_backward_compatible=False) +def fuse_as_graphmodule( + gm: GraphModule, + nodes: NodeList, + module_name: str, + partition_lookup_table: _Optional[dict[Node, None]] = None, + *, + always_return_tuple: bool = False, +) -> tuple[GraphModule, tuple[Node, ...], tuple[Node, ...]]: + """ + Fuse nodes in graph_module into a GraphModule. + + Args: + gm (GraphModule): target graph_module + + nodes (List[Node]): list of nodes in `gm` to fuse, where the node must be topologically sorted + + module_name: class name for the fused GraphModule + + partition_lookup_table (Optional[Dict[Node, None]]): optional dict of nodes to speed up lookup + + always_return_tuple (bool): whether to always return a tuple, even if there is only one output + + Returns: + fused_gm (GraphModule): fused graph module, where its node is a copy of `nodes` in `gm` + + original_inputs (Tuple[Node, ...]): input nodes to `nodes` in original `gm` + + original_outputs (Tuple[Node, ...]): consumer nodes of `nodes` in original `gm` + + """ + + # assumption: nodes are already sorted in topo order + + for node in nodes: + assert node.graph.owning_module is gm, ( + f"{node} doesn't belong to passed in graph module {gm._get_name()}" + ) + assert not node._erased, f"{node} has been removed from owning graph" + assert node in gm.graph._find_nodes_lookup_table, ( + f"{node} is not found in graph module {gm._get_name()}" + ) + + # validates partition doesn't introduce dependency circles in the graph + assert validate_partition(nodes), "Invalid partition, found dependency cycles" + + # if no dict of partition nodes is provided, reconstruct it by nodes list to reduce lookup time + if partition_lookup_table is None: + partition_lookup_table = dict.fromkeys(nodes) + + subgraph = Graph() + + node_to_placeholder: dict[ + Node, Node + ] = {} # mapping of nodes from old graph to placeholder in new graph + node_map: dict[Node, Node] = {} # mapping of nodes from old graph to new graph + + # handles inputs through graph.node_copy's arg_transform functions + def remap_inputs(x: Node) -> Node: + if x.op == "get_attr": + # TODO: do we really need copy the get_attr node into the graph? + # do something here + pass + + if x in partition_lookup_table: + # x is inside subgraph, return the copied node + # the node should have been copied aleady, as we are copying graph in the topological order + return node_map[x] + + if x not in node_to_placeholder: + # x is not in subgraph, create a new placeholder for subgraph + placeholder_node = subgraph.placeholder(x.name, type_expr=x.type) + # copy all meta fields, even if some fields might be irrelvant for the placeholder node + placeholder_node.meta = copy.copy(x.meta) + node_to_placeholder[x] = placeholder_node + + return node_to_placeholder[x] + + # copy nodes in topological order + for node in nodes: + new_node = subgraph.node_copy(node, remap_inputs) + node_map[node] = new_node + + # handles outputs + output_mapping: dict[Node, Node] = {} # mapping from old output to new outputs + + for node in nodes: + for user_node in node.users: + if user_node not in partition_lookup_table: + # external user node, need to expose as an output + output_mapping[node] = node_map[node] + + # outs contain nodes in the new subgraph + outs = tuple(output_mapping.values()) + + if always_return_tuple: + # always return a tuple, even if there is only one output + subgraph.output(outs) + else: + # If there's a single output then return it directly, otherwise return a tuple. + subgraph.output(outs[0] if len(outs) == 1 else outs) + + # lint to ensure correctness + subgraph.lint() # type: ignore[no-untyped-call] + fused_gm: GraphModule + fused_gm, _ = lift_subgraph_as_module( + gm, subgraph, comp_name="", class_name=module_name + ) + + # sub_gm's input nodes in the original module + original_inputs: tuple[Node, ...] = tuple(node_to_placeholder.keys()) + + # sub_gm's outputs node in the original module + original_outputs: tuple[Node, ...] = tuple(output_mapping.keys()) + + return fused_gm, original_inputs, original_outputs + + +@compatibility(is_backward_compatible=False) +def insert_subgm( + gm: GraphModule, + sub_gm: GraphModule, + orig_inputs: tuple[Node, ...], + orig_outputs: tuple[Node, ...], +) -> GraphModule: + # add sub_gm into gm + submodule_name = sub_gm.__class__.__name__ + gm.add_submodule(submodule_name, sub_gm) + + # Create a call_module node in main graph. + module_node = gm.graph.call_module(submodule_name, args=orig_inputs, kwargs=None) + + output_node = sub_gm.graph.output_node() + if len(orig_outputs) == 1 and not isinstance(output_node.args[0], tuple): + # main_remapping[comp.orig_outputs[0]] = module_node + orig_outputs[0].replace_all_uses_with(module_node, propagate_meta=True) + else: + for i, orig_output in enumerate(orig_outputs): + # Use Proxy to record getitem access. + proxy_out = torch.fx.Proxy(module_node)[i].node # type: ignore[index] + orig_output.replace_all_uses_with(proxy_out, propagate_meta=True) + + module_node.meta["val"] = tuple( + orig_output.meta.get("val", None) for orig_output in orig_outputs + ) + return gm + + +@compatibility(is_backward_compatible=False) +def erase_nodes(gm: GraphModule, nodes: NodeList) -> None: + # erase original nodes in inversed topological order + for node in reversed(nodes): + gm.graph.erase_node(node) + + +@compatibility(is_backward_compatible=False) +def fuse_by_partitions( + gm: GraphModule, + partitions: list[dict[Node, None]], + prefix: str = "fused_", + always_return_tuple: bool = False, +) -> GraphModule: + for partition_id, partition in enumerate(partitions): + sorted_nodes = topo_sort(list(partition)) + + submodule_name = prefix + str(partition_id) + sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, + sorted_nodes, + submodule_name, + partition, + always_return_tuple=always_return_tuple, + ) + + insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) + + erase_nodes(gm, sorted_nodes) + + # topological sort original gm with newly created sub_gm + legalize_graph(gm) + + return gm diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/matcher_utils.py b/phivenv/Lib/site-packages/torch/fx/passes/utils/matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0c255d5808abbb566ce9a5925b297fc4d2d6380d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/utils/matcher_utils.py @@ -0,0 +1,440 @@ +# mypy: allow-untyped-defs +import copy +import logging +import os +from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any, Union + +import torch +from torch.fx import Graph, Node +from torch.fx._compatibility import compatibility + + +__all__ = ["SubgraphMatcher", "InternalMatch"] + + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger(): + logger = logging.getLogger(__name__) + + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = _init_logger() + + +@compatibility(is_backward_compatible=False) +@dataclass +class InternalMatch: + # Nodes from which the match was found + anchors: list[Node] + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: dict[Node, Node] = field(default_factory=dict) + + # nodes in target graph that are matched placeholder in pattern + placeholder_nodes: list[Node] = field(default_factory=list) + + # nodes in matched subgraph returned by output + returning_nodes: list[Node] = field(default_factory=list) + + # map from a string name to a node in the target graph + # only available if the matcher is `SubgraphMatcherWithNameNodesMap` + name_node_map: dict[str, Node] = field(default_factory=dict) + + def __copy__(self): + return InternalMatch( + anchors=self.anchors, + nodes_map=self.nodes_map.copy(), + placeholder_nodes=self.placeholder_nodes.copy(), + returning_nodes=self.returning_nodes.copy(), + ) + + +@compatibility(is_backward_compatible=False) +class SubgraphMatcher: + def __init__( + self, + pattern: Graph, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False, + ) -> None: + """ + Args: + pattern: the targeted matching pattern, represented in fx.Graph. + match_output: If True, output node in the pattern graph will be treated as a part of the targeted pattern. + If False, output node is ignored during match. + match_placeholder: If True, placeholder node in the pattern graph will be treated as a part of + the targeted pattern. If False, placeholder nodes will be used a wildcard. + remove_overlapping_matches: If True, in the case of overlapping matches, only the first match + will be returned. + ignore_literals: If True, will not check if literals are equal and + will instead treat them as wildcards. + """ + + self.pattern = pattern + self.match_output = match_output + self.match_placeholder = match_placeholder + self.remove_overlapping_matches = remove_overlapping_matches + self.ignore_literals = ignore_literals + + if len(pattern.nodes) == 0: + raise ValueError( + "SubgraphMatcher cannot be initialized with an empty pattern" + ) + + for node in pattern.nodes: + if node.op != "output": + assert len(node.users) > 0, ( + "SubgraphMatcher cannot be initialized with an pattern with dead code" + ) + + # TODO: assert pattern is a connected graph + + self.pattern_placeholder_nodes = [ + n for n in pattern.nodes if n.op == "placeholder" + ] + output_node = next(iter(reversed(pattern.nodes))) + # nodes returned by outputs + self.pattern_returning_nodes: list[Node] = output_node.all_input_nodes + + self.pattern_anchors: list[Node] = [] + if match_output: + self.pattern_anchors = [output_node] + else: + # If a node has output_node as the ONLY user, then this node is a graph sink, + # and should be matched against as an anchor + self.pattern_anchors = [ + n for n in output_node.all_input_nodes if len(n.users) == 1 + ] + + def _match_attributes(self, pn: Node, gn: Node) -> bool: + # Attributes matching is complicated. Right now we only support matching constant tensor + assert isinstance(pn.target, str), f"pn.target {pn.target} must be a string." + assert isinstance(gn.target, str), f"gn.target {gn.target} must be a string." + + pn_value = torch.fx.graph_module._get_attr(pn.graph.owning_module, pn.target) + gn_value = torch.fx.graph_module._get_attr(gn.graph.owning_module, gn.target) + + if type(pn_value) != type(gn_value): + return False + + # Don't require exact match on tensor values. + if isinstance(pn_value, torch.Tensor): + return isinstance(gn_value, torch.Tensor) + else: + raise RuntimeError(f"Unsupported type {pn_value} when matching attributes") + return False + + def _nodes_are_equal(self, pn: Node, gn: Node) -> bool: + # if exact match for placeholder is not required, then use placeholder as a wildcard + if not self.match_placeholder and pn.op == "placeholder": + return True + + if pn.op == gn.op: + if pn.op == "placeholder" or pn.op == "output": + return True + elif pn.op == "get_attr": + return self._match_attributes(pn, gn) + return pn.target == gn.target + return False + + def _is_contained(self, nodes_map: dict[Node, Node]) -> bool: + # `lookup` represents all the nodes in `original_graph` + # that are part of `pattern` + + # Placeholders can be used by other nodes in the graphs + lookup: dict[Node, Node] = { + gn: pn for pn, gn in nodes_map.items() if pn.op != "placeholder" + } + + for gn, pn in lookup.items(): + # nodes returned by output are allowed to be used in other areas of the graph + if pn in self.pattern_returning_nodes: + continue + + for user in gn.users: + # If this node has users that were not in `lookup`, then it must leak out of the + # pattern subgraph + if user not in lookup: + return False + return True + + def _remove_overlapping_matches( + self, matches: list[InternalMatch] + ) -> list[InternalMatch]: + non_overlapping_matches: list[InternalMatch] = [] + nodes_matched: set[Node] = set() + + for match in matches: + found_overlap = False + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"} and gn in nodes_matched: + found_overlap = True + break + + if not found_overlap: + non_overlapping_matches.append(match) + for pn, gn in match.nodes_map.items(): + if pn.op not in {"placeholder", "output"}: + nodes_matched.add(gn) + return non_overlapping_matches + + def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: + assert not (isinstance(pn, Node) and isinstance(gn, Node)), ( + "pn and gn cannot both be Node" + ) + + if isinstance(pn, Node) and not isinstance(gn, Node): + if pn.op == "placeholder": + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + match.nodes_map[pn] = gn + return True + else: + return False + elif not isinstance(pn, Node) and isinstance(gn, Node): + return False + else: + return type(gn) == type(pn) and gn == pn + + def _match_nodes(self, pn: Node, gn: Node, match: InternalMatch) -> bool: + logger.info(" matching %s to %s", pn, gn) + + assert isinstance(pn, Node) and isinstance(gn, Node), str( + f"pn and gn must be Node, pn: {pn}, gn: {gn}" + ) + + # Check if we've already matched these nodes in the current + # traversal + if pn in match.nodes_map: + return match.nodes_map[pn] == gn + + # TODO: use a more efficient way to check if gn is matched before: two-way dict + if gn in match.nodes_map.values(): + return False + + if not self._nodes_are_equal(pn, gn): + return False + + # Optimistically mark `pn` as a match for `gn`, and save a local copy of match + saved_match = copy.copy(match) + match.nodes_map[pn] = gn + + # Placeholder is a wildcard and can be matched with any python object + # (including list/tuple) + if pn.op == "placeholder": + return True + + # Recursively traverse upwards to check if `pn` is a true + # match for `gn` + match_found = True + + def _match_args(args1: Union[list, tuple], args2: Union[list, tuple]) -> bool: + if len(args1) != len(args2): + return False + + for a1, a2 in zip(args1, args2): + if isinstance(a1, Node) and isinstance(a2, Node): + matched = self._match_nodes(a1, a2, match) + elif isinstance(a1, (list, tuple)) and isinstance(a2, (list, tuple)): + matched = _match_args(a1, a2) + else: + matched = ( + self._match_literals(a1, a2, match) or self.ignore_literals + ) + + if not matched: + return False + + return True + + # Flatten all args/kwargs into 1 list of args + pn_args, gn_args = None, None + if ( + ( + len(pn.args) != len(gn.args) + or list(pn.kwargs.keys()) != list(gn.kwargs.keys()) + ) + and pn.op == "call_function" + and isinstance(pn.target, torch._ops.OpOverload) + ): + args_schema = pn.target._schema.arguments + + def get_all_arguments(orig_args, orig_kwargs): + all_args = [] + for i, schema in enumerate(args_schema): + if schema.name in orig_kwargs: + all_args.append(orig_kwargs[schema.name]) + elif not schema.kwarg_only and i < len(orig_args): + all_args.append(orig_args[i]) + else: + all_args.append(schema.default_value) + return all_args + + pn_args = get_all_arguments(pn.args, pn.kwargs) + gn_args = get_all_arguments(gn.args, gn.kwargs) + + elif len(pn.args) == len(gn.args) and list(pn.kwargs.keys()) == list( + gn.kwargs.keys() + ): + pn_args = list(pn.args) + gn_args = list(gn.args) + pn_args.extend(list(pn.kwargs.values())) + gn_args.extend(list(gn.kwargs.values())) + else: + match_found = False + + match_found = ( + match_found + and pn_args is not None + and gn_args is not None + and _match_args(pn_args, gn_args) + ) + + if not match_found: + # revert to saved_match before matching with current node + match = copy.copy(saved_match) + return False + + return True + + def match(self, graph: Graph) -> list[InternalMatch]: + """ + Returns: + The matched subgraphs. + Thre returned subgraph would be fully self-contained, meaning the nodes (except placeholder + and nodes returned by output) can only be consumed by nodes within the matched subgraph. + + Subgraph pattern matcher is implemented with the backtracking style in the following steps: + + 1. We first identify all the anchor nodes in the pattern graph. The anchor nodes + are the "sinks" (nodes with no user other than the output node) of the pattern graph. + One pattern graph could have multiple anchors if it has multiple return values. + + 2. In the target graph, we identify the potential candidate nodes that can be matched + with each anchor. These anchor-candidate pairs are the starting points for + pairwise per-node matching. + + 3. For each anchor-candidate pair, we simultaneously traverse backwards (DFS) in both + pattern and target graphs. For every pattern nodes along traversal path, we compare it + against the target nodes. In case any comparison failed, the match for this anchor-candidate + pair fails. A match is found when DFS completes traversing the graph. See `self._match_nodes` + for more details. + + 4. In the case of multiple anchors, every anchor will need to find a match using step 3. + In addition, the matches found between anchors need to have a common intersection node + in order for the match to be valid. This is implemented with backtracking. See `backtracking` + for more details. + + Notice: graph traversal must be done in the reverser order because a tensor can have multiple + consumers, but can only have a single producer. Only with reverser order, we can we jointly + traverse the pattern and target graph in a deterministic path. + + Warning: In theory, this backtracking algorithm have an **exponential** time complexity. However, + in practice, it's unlikely to blow up. + + """ + from torch.fx.passes.utils.fuser_utils import validate_partition + + # find candidate nodes to match with pattern anchors + match_candidates: dict[Node, list[Node]] = defaultdict(list) + for pattern_anchor in self.pattern_anchors: + for node in graph.nodes: + if self._nodes_are_equal(pattern_anchor, node): + match_candidates[pattern_anchor].append(node) + match_candidates_list = list(match_candidates.items()) + + logger.info("Initial match_candidates_list: %s\n", match_candidates_list) + + matches: list[InternalMatch] = [] + + def backtracking(anchor_index, match): + if anchor_index == len(match_candidates_list): + match.placeholder_nodes = [ + match.nodes_map[pn] for pn in self.pattern_placeholder_nodes + ] + match.returning_nodes = [ + match.nodes_map[pn] for pn in self.pattern_returning_nodes + ] + matches.append(match) + + logger.info("Found a match: %s\n", match) + return + + pattern_anchor, candidate_nodes = match_candidates_list[anchor_index] + saved_match = copy.copy(match) + + for node in candidate_nodes: + logger.info("Trying to match anchor %s to %s", pattern_anchor, node) + + match_found = self._match_nodes(pattern_anchor, node, match) + if match_found: + # match next anchor + backtracking(anchor_index + 1, match) + else: + logger.info( + "Failed to match anchor %s to %s\n", pattern_anchor, node + ) + + # revert to saved_match before matching with current anchor + match = copy.copy(saved_match) + + match = InternalMatch(anchors=self.pattern_anchors) + if match_candidates_list: + backtracking(0, match) + + # filter out the matches where the subgraph is not fully_contained + before = len(matches) + matches = [match for match in matches if self._is_contained(match.nodes_map)] + after = len(matches) + if before != after: + logger.info( + "Filtered out %s matches because they are not fully contained", + before - after, + ) + + # filter out the matches that form a cycle if the subgraph is fused + valid_matches = [] + for match in matches: + matched_compute_nodes = [ + gn + for pn, gn in match.nodes_map.items() + if pn.op not in {"placeholder", "output"} + ] + if validate_partition(matched_compute_nodes): + valid_matches.append(match) + if len(valid_matches) != len(matches): + logger.info( + "Filtered out %s matches because \ + matched subgraph would form a cycle if fused", + len(matches) - len(valid_matches), + ) + + if self.remove_overlapping_matches: + before = len(valid_matches) + matches = self._remove_overlapping_matches(valid_matches) + after = len(matches) + if before != after: + logger.info( + "Filtered out %s matches because matched subgraphs are overlapping", + before - after, + ) + + logger.info("Matches returned: %s", matches) + + return matches diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/phivenv/Lib/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..31a04e388a5198b01ebcdf754b8df5a091e8aa5e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -0,0 +1,114 @@ +from torch.fx import Graph, GraphModule, Node +from torch.fx._compatibility import compatibility + +from .matcher_utils import InternalMatch, SubgraphMatcher + + +__all__ = ["SubgraphMatcherWithNameNodeMap"] + + +def _split_to_graph_and_name_node_map( + gm: GraphModule, +) -> tuple[GraphModule, dict[str, Node]]: + from torch.fx.graph import _PyTreeInfo + from torch.utils._pytree import tree_flatten, tree_unflatten + + name_node_map = {} + for n in gm.graph.nodes: + if n.op == "output": + assert gm._out_spec is not None + output = tree_unflatten(n.args[0], gm._out_spec) + assert isinstance(output, tuple), ( + "Expecting the pattern graph to return a tuple" + ) + assert len(output) >= 2, ( + "Expecting the pattern graph to have at least two outputs" + ) + *out, name_node_map = output + flattened, out_spec = tree_flatten(out) + assert isinstance(name_node_map, dict), ( + "Expecting the input graph to have a dict output as the last element" + ) + n.args = (flattened,) + orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] + gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined] + orig_pytree_info.orig_args, orig_pytree_info.in_spec, out_spec + ) + gm.recompile() + return gm, name_node_map + + +@compatibility(is_backward_compatible=False) +class SubgraphMatcherWithNameNodeMap(SubgraphMatcher): + """Extends SubgraphMatcher to support querying the matched subgraph nodes through node name, + this requires pattern to have specific format (returning and additional dictionary at the output, + that has node name as key, and the node in the pattern graph as value, see Example for more details) + + Difference with SubgraphMatcher is that it takes a `pattern_gm` GraphModule as input during + initialization since we need to modify the graph (which requires `recompile` the GraphModule) + + Example:: + def pattern(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + return relu, {"conv": conv, "relu": relu} + + + def target_graph(x, weight): + conv = F.conv2d(x, weight) + relu = F.relu(conv) + relu *= 2 + return relu + + + pattern_gm = export_for_training(pattern, example_inputs).module() + target_gm = export_for_training(target_graph, example_inputs).module() + matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) + matches = matcher.match(target_gm) + for match in matches: + match.name_node_map["conv"].meta["annotation"] = ... + + """ + + def __init__( + self, + pattern_gm: GraphModule, + match_output: bool = False, + match_placeholder: bool = False, + remove_overlapping_matches: bool = True, + ignore_literals: bool = False, + ) -> None: + pattern_gm, name_node_map = _split_to_graph_and_name_node_map(pattern_gm) + self.name_node_map = name_node_map + super().__init__( + pattern_gm.graph, + match_output, + match_placeholder, + remove_overlapping_matches, + ignore_literals, + ) + + def match(self, graph: Graph) -> list[InternalMatch]: + """The returned InternalMatch will have name_node_map populated with a map + from node name (str) to the target node, e.g. + {"conv": target_conv_ndoe, "relu": target_relu_node} + + this requires the pattern graph returns an additional + output of node name to node, e.g. instead of: + ``` + def pattern(...): + ... + return relu + ``` + we should do: + ``` + def pattern(...): + ... + return relu, {"conv": conv, "relu": relu} + ``` instead + """ + internal_matches = super().match(graph) + for internal_match in internal_matches: + for k, n in self.name_node_map.items(): + internal_match.name_node_map[k] = internal_match.nodes_map[n] + return internal_matches diff --git a/phivenv/Lib/site-packages/torch/fx/passes/utils/source_matcher_utils.py b/phivenv/Lib/site-packages/torch/fx/passes/utils/source_matcher_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54f646dee3b4fb7da04952e3e4c44c831112a033 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/passes/utils/source_matcher_utils.py @@ -0,0 +1,162 @@ +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Callable, Optional + +from torch.fx._compatibility import compatibility +from torch.fx.graph import Graph +from torch.fx.node import Node + + +__all__ = ["get_source_partitions", "check_subgraphs_connected", "SourcePartition"] + + +# Set`PYTORCH_MATCHER_LOGLEVEL=INFO` to see debug logs +def _init_logger() -> logging.Logger: + logger = logging.getLogger(__name__) + + level = os.environ.get("PYTORCH_MATCHER_LOGLEVEL", "WARNING").upper() + logger.setLevel(level) + console = logging.StreamHandler() + formatter = logging.Formatter("%(filename)s > %(message)s") + console.setFormatter(formatter) + console.setLevel(level) + # add the handlers to the logger + logger.addHandler(console) + logger.propagate = False + return logger + + +logger = _init_logger() + + +@compatibility(is_backward_compatible=False) +@dataclass +class SourcePartition: + # Nodes in a particular partition + nodes: list[Node] + + # The source these nodes decomposed from + source: Any + + # Nodes in the graph that are needed as inputs to the partition + # These do not include the params of the partition + input_nodes: list[Node] = field(default_factory=list) + + # Nodes in the partition that are being used by nodes outside of the + # partition + output_nodes: list[Node] = field(default_factory=list) + + # Parameters that are being used + params: list[Node] = field(default_factory=list) + + +@compatibility(is_backward_compatible=False) # type: ignore[misc] +def get_source_partitions( + graph: Graph, + wanted_sources: list[Any], + filter_fn: Optional[Callable[[Node], bool]] = None, +) -> dict[Any, list[SourcePartition]]: + """ + Args: + graph: The graph we want to partition + wanted_sources: List of sources of nodes that were decomposed from this + source. This can be a function (ex. torch.nn.functional.linear) or a + leaf module type (ex. torch.nn.Linear). + + Returns: + Dictionary mapping sources that were given to a list of SourcePartitions + that correspond to the list of nodes that were decomposed from the given + source. + """ + modules: dict[type, dict[str, list[Node]]] = {} + + for node in graph.nodes: + # The metadata source_fn should contain a tuple of a unique name for the + # source, and the source function if the node is decomposed from a + # function, or the type of module if the node is decomposed from a leaf + # module + + # TODO: Bypass "torch_fn" when "source_fn_stack" because now "torch_fn" can + # be different from "source_fn_stack", for example for the add_ node + # decomposed from batch norm. We should remove the check on "source_fn_stack" + # after we fix "torch_fn". T199561090 + if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and ( + torch_fn := node.meta.get("torch_fn", None) + ) is not None: + node_fqn, source_fn = torch_fn + source_fn_name = source_fn.split(".")[1] + if source_fn_name in wanted_sources: + diff_modules = modules.setdefault(source_fn_name, {}) + partition = diff_modules.setdefault(node_fqn, []) + partition.append(node) + + if (source_fn_st := node.meta.get("source_fn_stack", None)) is not None: + source_fn = source_fn_st[-1] + if source_fn[1] in wanted_sources: + diff_modules = modules.setdefault(source_fn[1], {}) + partition = diff_modules.setdefault(source_fn[0], []) + partition.append(node) + + def make_partition(nodes: list[Node], module_type: type) -> SourcePartition: + input_nodes = set() + output_nodes = set() + params = set() + for node in nodes: + for arg in node.args: + if isinstance(arg, Node) and arg not in nodes and arg.op != "get_attr": + input_nodes.add(arg) + + if node.op == "get_attr": + params.add(node) + # get_attr nodes won't be output nodes + continue + + for user in node.users.keys(): + if user not in nodes: + output_nodes.add(node) + + return SourcePartition( + nodes, + module_type, + list(input_nodes), + list(output_nodes), + list(params), # type: ignore[arg-type] + ) + + ret: dict[type[Any], list[SourcePartition]] = {} + + if filter_fn: + # for each partition, we apply filter_fn to filter out all partitions that doesn't satisfy the + # filter condition + filtered_modules = {} + for tp, name_to_partition in modules.items(): + filtered_name_to_partition = { + name: partition + for name, partition in name_to_partition.items() + if all(map(filter_fn, partition)) + } + filtered_modules[tp] = filtered_name_to_partition + modules = filtered_modules + + for k, v in modules.items(): + ret[k] = [make_partition(partition, k) for partition in v.values()] + + return ret + + +@compatibility(is_backward_compatible=False) # type: ignore[misc] +def check_subgraphs_connected( + subgraph1: SourcePartition, subgraph2: SourcePartition +) -> bool: + """ + Given two subgraphs A and B (in the form of a list of nodes), checks if + A has nodes connecting to at least one node in B -- aka there exists a node + in B that uses a node in A (not the other way around). + """ + + for node in reversed(subgraph1.nodes): + for user in node.users.keys(): + if user in subgraph2.nodes: + return True + return False diff --git a/phivenv/Lib/site-packages/torch/fx/proxy.py b/phivenv/Lib/site-packages/torch/fx/proxy.py new file mode 100644 index 0000000000000000000000000000000000000000..96f3a837c0f1ca63d3025cd33a8f12dfed27a445 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/proxy.py @@ -0,0 +1,799 @@ +# mypy: ignore-errors + +import collections +import copy +import dis +import enum +import inspect +import logging +import operator +import sys +import traceback +from collections import OrderedDict +from collections.abc import Iterator +from dataclasses import fields, is_dataclass +from typing import Any, Callable, Optional + +import torch +import torch.fx.traceback as fx_traceback +from torch._C import _fx_map_aggregate as map_aggregate, _fx_map_arg as map_arg +from torch.utils._traceback import CapturedTraceback + +from ._compatibility import compatibility +from .graph import Graph, magic_methods, reflectable_magic_methods +from .immutable_collections import immutable_dict, immutable_list +from .node import Argument, base_types, Node, Target +from .operator_schemas import check_for_mutable_operation + + +__all__ = [ + "TracerBase", + "GraphAppendingTracer", + "TraceError", + "Proxy", + "MetaProxy", + "Attribute", + "ParameterProxy", + "Scope", + "ScopeContextManager", +] + + +log = logging.getLogger(__name__) + + +@compatibility(is_backward_compatible=False) +class Scope: + """Scope object that records the module path and the module type + of a module. Scope is used to track the information of the module + that contains a Node in a Graph of GraphModule. For example:: + + class Sub(torch.nn.Module): + def forward(self, x): + # This will be a call_method Node in GraphModule, + # scope for this would be (module_path="sub", module_type=Sub) + return x.transpose(1, 2) + + + class M(torch.nn.Module): + def __init__(self) -> None: + self.sub = Sub() + + def forward(self, x): + # This will be a call_method Node as well, + # scope for this would be (module_path="", None) + x = x.transpose(1, 2) + x = self.sub(x) + return x + + """ + + def __init__(self, module_path: str, module_type: Any): + super().__init__() + self.module_path = module_path + self.module_type = module_type + + +@compatibility(is_backward_compatible=False) +class ScopeContextManager: + """A context manager to track the Scope of Node during symbolic tracing. + When entering a forward function of a Module, we'll update the scope information of + the current module, and when we exit, we'll restore the previous scope information. + """ + + def __init__( + self, + scope: Scope, + current_scope: Scope, + ): + super().__init__() + # Keep a copy of prev scope to restore on exit + self._prev_scope = copy.copy(scope) + # Update scope to current scope + scope.module_path = current_scope.module_path + scope.module_type = current_scope.module_type + # Save a reference so we can restore it + self._scope = scope + + def __enter__(self): + return self._scope + + def __exit__(self, *args): + self._scope.module_path = self._prev_scope.module_path + self._scope.module_type = self._prev_scope.module_type + return + + +_COPY_META_FIELDS = [ + "nn_module_stack", + "torch_fn", + "source_fn_stack", + "original_aten", + "recompute", + "ac_graph_id", + "has_backward_hook", + "from_node", + "quantization_tag", # TODO deprecated + "_numeric_debug_handle", # TODO deprecated + "custom", + "partitioner_tag", +] + + +@compatibility(is_backward_compatible=True) +class TracerBase: + graph: Graph + record_stack_traces: bool = False + # Feature flag for mutable schema checking + # Enableby default in 1.12 + check_mutable_operations: bool = False + # Feature flag for assert tracing + trace_asserts: bool = False + # Feature flag for proxying accesses to buffer values + proxy_buffer_attributes: bool = False + + # Name of the function to be traced. It will only be used when + # ``root`` is an instance of ``nn.Module`` + traced_func_name: str = "forward" + + # Maps the containing module's name to the operator name + scope: Scope + + # Records the module call stack + module_stack: OrderedDict[str, tuple[str, Any]] + + # Mapping of node name to module scope + node_name_to_scope: dict[str, tuple[str, type]] + + @compatibility(is_backward_compatible=True) + def create_node( + self, + kind: str, + target: Target, + args: tuple[Argument, ...], + kwargs: dict[str, Argument], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + ) -> Node: + """ + Inserts a graph node given target, args, kwargs, and name. + + This method can be overridden to do extra checking, validation, or + modification of values used in node creation. For example, one might + want to disallow in-place operations from being recorded. + """ + + if kind == "call_function" and self.check_mutable_operations: + check_for_mutable_operation(target, args, kwargs) + + node = self.graph.create_node(kind, target, args, kwargs, name, type_expr) + # TODO node_name_to_scope will be depreciated in favor of + # node.meta['nn_module_stack'] + self.node_name_to_scope[node.name] = ( + self.scope.module_path, + self.scope.module_type, + ) + + # Optionally set stack trace on the created Node for debugging purposes + if fx_traceback.has_preserved_node_meta(): + current_meta: dict[str, Any] = fx_traceback.get_current_meta() + + stack_trace = current_meta.get("stack_trace") + if stack_trace: + node.stack_trace = stack_trace + # Explicitly set the stack_trace, nn_module_stack and source_fn on the node.meta + # If other meta fields are needed, they can be added here + for field in _COPY_META_FIELDS: + if field in current_meta: + node.meta[field] = copy.copy(current_meta[field]) + + # Here we decrement to account for the sequence_nr having + # just been incremented while tracing this lowered aten op. + new_seq_nr = torch.autograd._get_sequence_nr() - 1 + # The sequence_nr increments every time a new autograd Node + # is created. During the FWD pass we store the sequence_nr + # corresponding to the last autograd Node created on this fx + # node's meta. A single aten op can create multiple autograd + # nodes as is the case with in-place foreach ops. During the + # BWD pass we retrieve the sequence_nr stored on the current + # executing autograd Node. See NOTE [ Sequence Number ]. + if current_meta.get("in_grad_fn", 0) > 0: + new_seq_nr = current_meta["grad_fn_seq_nr"][-1] + node.meta["seq_nr"] = new_seq_nr + + elif self.module_stack: + node.meta["nn_module_stack"] = copy.copy(self.module_stack) + + log.debug("create_node %s", node) + return node + + @compatibility(is_backward_compatible=True) + def proxy(self, node: Node) -> "Proxy": + return Proxy(node, self) + + @compatibility(is_backward_compatible=True) + def create_proxy( + self, + kind: str, + target: Target, + args: tuple[Any, ...], + kwargs: dict[str, Any], + name: Optional[str] = None, + type_expr: Optional[Any] = None, + # fix noqa when updating bc tests + proxy_factory_fn: Callable[[Node], "Proxy"] = None, # noqa: RUF013 + ): + """ + Create a Node from the given arguments, then return the Node + wrapped in a Proxy object. + + If kind = 'placeholder', then we're creating a Node that + represents the parameter of a function. If we need to encode + a default parameter, we use the ``args`` tuple. ``args`` is + otherwise empty for ``placeholder`` Nodes. + """ + + args_ = self.create_arg(args) + kwargs_ = self.create_arg(kwargs) + assert isinstance(args_, tuple) + assert isinstance(kwargs_, dict) + + node = self.create_node(kind, target, args_, kwargs_, name, type_expr) + + if not proxy_factory_fn: + proxy = self.proxy(node) + else: + proxy = proxy_factory_fn(node) + + if self.record_stack_traces and not proxy.node.stack_trace: + from torch.fx.experimental.symbolic_shapes import uninteresting_files + + user_frame_summary = CapturedTraceback.extract().summary() + if user_frame_summary: + first_forward = -1 + for i, frame in enumerate(user_frame_summary): + if frame.name == "forward": + user_frame_summary = user_frame_summary[i:] + first_forward = i + break + + # Not having a "forward" call in the stacktrace implies the + # stacktrace will probably be irrelevant + if first_forward == -1: + user_frame_summary = [] + + stack_trace = [ + frame + for frame in user_frame_summary + if frame.filename not in uninteresting_files() + ] + stack_trace = traceback.StackSummary.from_list(stack_trace) + proxy.node.stack_trace = "".join(stack_trace.format()).strip() + + return proxy + + def _find_user_frame(self): + """ + Find the Python stack frame executing the user code during + symbolic tracing. + """ + # We have to do a little dance here. Basically, walk up the callstack and + # record the first frame not in the pytorch source. This is the frame executing + # the user code during tracing. + frame = inspect.currentframe() + + pt_files = [ + "torch/fx/proxy.py", + "torch/fx/_symbolic_trace.py", + "torch/fx/experimental/proxy_tensor.py", + "torch/_ops.py", + "torch/_tensor.py", + "torch/utils/_python_dispatch.py", + "torch/_prims_common/wrappers.py", + "torch/_refs/__init__.py", + "torch/_refs/nn/functional/__init__.py", + "torch/utils/_stats.py", + ] + while frame: + frame = frame.f_back + if frame and all( + not frame.f_code.co_filename.endswith(file) for file in pt_files + ): + break + + if not frame: + return None + + return frame + + @compatibility(is_backward_compatible=True) + def create_arg(self, a: Any) -> Argument: + """ + A method that lowers the objects seen as arguments during symbolic evaluation + into Argument types that can be stored in IR. + + Can be override to support more trace-specific types. + """ + # IMPORTANT: Are you here because you are trying to proxy a new type into + # the graph? Please Please Please contact someone on the PyTorch Compiler team; + # the considerations are subtle. + # + # 1) When you add a new type, all of the downstream consumers and pass writers + # need to handle the new type. torch.fx is intended to be easy to write + # passes for, so we will push back against new types. + # 2) In torch.compile's IR, there are only specific operations that go + # into the graph. In particular, Tensor operations should go into the graph, + # but non-Tensor operations shouldn't. What that means is that constructors + # for new types *SHOULD NOT* become nodes in the FX graph. + handler = _create_arg_bypass.get(type(a)) + if handler is not None: + # this is just a performance optimization and can be removed if needed + # for common types, we have a fast path to avoid isinstance() overhead + # this doesn't remove the checks below since we need to handle subclasses + return handler(self, a) + + if isinstance(a, Proxy): + return a.node # most common arg type goes first + elif hasattr(a, "__fx_create_arg__"): + return a.__fx_create_arg__(self) + # aggregates + elif isinstance(a, tuple): + if hasattr(a, "_fields"): + # NamedTuple constructors don't seem to like getting a generator + # expression as an argument to their constructor, so build this + # intermediate tuple and unpack it into the NamedTuple constructor + args = [self.create_arg(elem) for elem in a] + return type(a)(*args) # type: ignore[arg-type] + return type(a)([self.create_arg(elem) for elem in a]) + elif isinstance(a, list): + return [self.create_arg(elem) for elem in a] + elif isinstance(a, dict): + return _create_arg_dict(self, a) + elif isinstance(a, slice): + return slice( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) + + elif isinstance(a, range): + return range( + self.create_arg(a.start), + self.create_arg(a.stop), + self.create_arg(a.step), + ) + + elif isinstance(a, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)): + return a + + elif is_dataclass(a): + kwargs = { + field.name: self.create_arg(getattr(a, field.name)) + for field in fields(a) + } + return self.create_node("call_function", a.__class__, (), kwargs) + + elif isinstance(a, (*base_types, enum.Enum)) or a is None or a is ...: + return a + + raise NotImplementedError(f"argument of type: {type(a)}") + + @compatibility(is_backward_compatible=True) + def to_bool(self, obj: "Proxy") -> bool: + """Called when a proxy object is being converted to a boolean, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return a value. + """ + raise TraceError( + "symbolically traced variables cannot be used as inputs to control flow" + ) + + @compatibility(is_backward_compatible=True) + def iter(self, obj: "Proxy") -> Iterator: + """Called when a proxy object is being iterated over, such as + when used in control flow. Normally we don't know what to do because + we don't know the value of the proxy, but a custom tracer can attach more + information to the graph node using create_node and can choose to return an iterator. + """ + raise TraceError( + "Proxy object cannot be iterated. This can be " + "attempted when the Proxy is used in a loop or" + " as a *args or **kwargs function argument. " + "See the torch.fx docs on pytorch.org for a " + "more detailed explanation of what types of " + "control flow can be traced, and check out the" + " Proxy docstring for help troubleshooting " + "Proxy iteration errors" + ) + + @compatibility(is_backward_compatible=True) + def keys(self, obj: "Proxy") -> Any: + """Called when a proxy object is has the keys() method called. + This is what happens when ** is called on a proxy. This should return an + iterator it ** is suppose to work in your custom tracer. + """ + return Attribute(obj, "keys")() + + +# used in Proxy object when just appending to the graph while not tracing. +@compatibility(is_backward_compatible=True) +class GraphAppendingTracer(TracerBase): + def __init__(self, graph: Graph): + super().__init__() + self.graph = graph + self.scope = Scope("", None) + self.module_stack = collections.OrderedDict() + self.node_name_to_scope = {} + + +@compatibility(is_backward_compatible=False) +def assert_fn(x): + assert x + + +@compatibility(is_backward_compatible=True) +class TraceError(ValueError): + pass + + +@compatibility(is_backward_compatible=True) +class Proxy: + """ + ``Proxy`` objects are ``Node`` wrappers that flow through the + program during symbolic tracing and record all the operations + (``torch`` function calls, method calls, operators) that they touch + into the growing FX Graph. + + If you're doing graph transforms, you can wrap your own ``Proxy`` + method around a raw ``Node`` so that you can use the overloaded + operators to add additional things to a ``Graph``. + + ``Proxy`` objects cannot be iterated. In other words, the symbolic + tracer will throw an error if a ``Proxy`` is used in a loop or as + an ``*args``/``**kwargs`` function argument. + + There are two main ways around this: + 1. Factor out the untraceable logic into a top-level function and + use ``fx.wrap`` on it. + 2. If the control flow is static (i.e. the loop trip count is + based on some hyperparameter), the code can be kept in its original + position and refactored into something like:: + + for i in range(self.some_hyperparameter): + indexed_item = proxied_value[i] + + For a more detailed description into the Proxy internals, check out + the "Proxy" section in `torch/fx/README.md` + """ + + @compatibility(is_backward_compatible=True) + def __init__(self, node: Node, tracer: "Optional[TracerBase]" = None): + if tracer is None: + # This allows you to create a Proxy object around a raw Node + tracer = GraphAppendingTracer(node.graph) + self.tracer = tracer + self.node = node + + def __repr__(self) -> str: + return f"Proxy({self.node.name})" + + def __getattr__(self, k) -> "Attribute": + # note: not added to the graph yet, if this is a method call + # we peephole optimize to the method invocation + return Attribute(self, k) + + def __getstate__(self) -> dict: + return self.__dict__ + + def __deepcopy__(self, memo) -> dict: + # We have to explicitly override this method, because otherwise deepcopy + # will go to __getattr__(self, "__deepcopy__") and return a + # Attribute(__deepcopy__), and may go into an infinite loop in some cases. + import copy + + new_dict = {} + for k, v in self.__dict__.items(): + try: + new_obj = copy.deepcopy(v, memo) + except Exception: + log.warning( + "Shallow copy %s of Proxy because it cannot be deepcopied. " + "Proxy is created for node %s", + k, + self.node.name, + ) + new_obj = copy.copy(v) + new_dict[k] = new_obj + assert "node" in new_dict + assert "tracer" in new_dict + new_proxy = Proxy(new_dict["node"], new_dict["tracer"]) + for k, v in new_dict.items(): + new_proxy.__dict__[k] = v + return new_proxy + + def __setstate__(self, d): + # This is called when being unpickled/loaded. + self.__dict__ = d + + def __call__(self, *args, **kwargs) -> "Proxy": + return self.tracer.create_proxy( + "call_method", "__call__", (self,) + args, kwargs + ) + + def __iter__(self) -> Iterator["Proxy"]: + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + inst_list = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + + inst_idx = bisect_left( + inst_list, calling_frame.f_lasti, key=lambda x: x.offset + ) + else: + inst_idx = calling_frame.f_lasti // 2 + inst = inst_list[inst_idx] + if inst.opname == "UNPACK_SEQUENCE": + return (self[i] for i in range(inst.argval)) # type: ignore[index] + + return self.tracer.iter(self) + + def __abs__(self): + return self.tracer.create_proxy("call_function", operator.abs, (self,), {}) + + def __bool__(self) -> bool: + if self.tracer.trace_asserts: + # check if this boolean is used in an assertion, bytecode pattern for assertions + # is pretty stable for Python 3.7--3.9 + frame = inspect.currentframe() + assert frame is not None + calling_frame = frame.f_back + assert calling_frame is not None + insts = list(dis.get_instructions(calling_frame.f_code)) + if sys.version_info >= (3, 11): + from bisect import bisect_left + + cur = bisect_left(insts, calling_frame.f_lasti, key=lambda x: x.offset) + else: + cur = calling_frame.f_lasti // 2 + inst = insts[cur] + + if inst.opname == "POP_JUMP_IF_TRUE": + first = insts[cur + 1] + assert inst.arg is not None + last = insts[inst.arg // 2 - 1] + starts_with_assert = ( + first.opname == "LOAD_GLOBAL" + and first.argval == "AssertionError" + or first.opname == "LOAD_ASSERTION_ERROR" + ) + if starts_with_assert and last.opname == "RAISE_VARARGS": + self.tracer.create_proxy("call_function", assert_fn, (self,), {}) + return True + + return self.tracer.to_bool(self) + + @compatibility(is_backward_compatible=True) + def keys(self): + return self.tracer.keys(self) + + def __len__(self): + raise RuntimeError( + "'len' is not supported in symbolic tracing by default. If you want " + "this call to be recorded, please call torch.fx.wrap('len') at " + "module scope" + ) + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + tracers: dict[Any, None] = {} + + def find_tracer(a): + if isinstance(a, cls): + tracers[a.tracer] = None + + map_aggregate(args, find_tracer) + map_aggregate(kwargs, find_tracer) + + if len(tracers) > 1: + raise RuntimeError( + f"Found multiple different tracers {list(tracers.keys())} while " + f"trying to trace operations {orig_method}" + ) + tracer = next(iter(tracers.keys())) + + if isinstance(orig_method, torch._C.ScriptMethod): + args = (orig_method.owner,) + args + return tracer.create_proxy("call_method", orig_method.name, args, kwargs) + if torch.overrides.is_tensor_method_or_property(orig_method): + return tracer.create_proxy( + "call_method", orig_method.__name__, args, kwargs + ) + else: + if isinstance(orig_method, torch._ops.HigherOrderOperator): + # TODO: Define how to symbolically trace HigherOrderOperators + raise RuntimeError("Unable to symbolically trace HigherOrderOperators") + return tracer.create_proxy( + "call_function", + orig_method, + args, + kwargs, + name=tracer.graph._target_to_str(orig_method.__name__), + ) + + +@compatibility(is_backward_compatible=False) +class MetaProxy(Proxy): + """ + A Proxy subclass that propagates metadata (meta['val']) during graph tracing. + """ + + def __init__( + self, node: Node, tracer: "Optional[TracerBase]" = None, fake_mode=None + ): + super().__init__(node, tracer) + self.fake_mode = fake_mode + + def __repr__(self) -> str: + return f"MetaProxy({self.node.name})" + + @classmethod + def __torch_function__(cls, orig_method, types, args=None, kwargs=None): + args = args if args else () + kwargs = kwargs if kwargs else {} + + meta_proxy = None + for arg in args: + if isinstance(arg, MetaProxy): + meta_proxy = arg + break + + assert meta_proxy is not None, ( + "No MetaProxy found in arguments, but one is expected." + ) + + proxy = super().__torch_function__(orig_method, types, args, kwargs) + with meta_proxy.fake_mode: + proxy.node.meta["val"] = orig_method( + *[a.node.meta["val"] if isinstance(a, Proxy) else a for a in args], + **kwargs, + ) + return MetaProxy(proxy.node, proxy.tracer, meta_proxy.fake_mode) + + +@compatibility(is_backward_compatible=True) +class Attribute(Proxy): + @compatibility(is_backward_compatible=True) + def __init__(self, root: Proxy, attr: str): + self.root = root + self.attr = attr + self.tracer = root.tracer + self._node: Optional[Node] = None + + @property + def node(self): + # the node for attributes is added lazily, since most will just be method calls + # which do not rely on the getitem call + if self._node is None: + self._node = self.tracer.create_proxy( + "call_function", getattr, (self.root, self.attr), {} + ).node + return self._node + + def __call__(self, *args, **kwargs): + return self.tracer.create_proxy( + "call_method", self.attr, (self.root,) + args, kwargs + ) + + +@compatibility(is_backward_compatible=False) +class ParameterProxy(Proxy): + """ + A special proxy which lets "shape", "size", "dim", and a few other + attribute accesses pass through to the underlying module parameter object, + so that conditional tests on these attributes will not throw exception during tracing + """ + + def __init__(self, tracer: TracerBase, node: Node, name, param): + super().__init__(node, tracer) + assert isinstance(param, torch.nn.Parameter) + self.param = param + self.name = name + + def __repr__(self) -> str: + return f"ParameterProxy({self.name})" + + @property + def shape(self): + return self.param.shape + + def size(self): + return self.param.size() + + def dim(self): + return self.param.dim() + + @property + def ndim(self): + return self.param.ndim + + def numel(self): + return self.param.numel() + + def nelement(self): + return self.param.nelement() + + +for method in magic_methods: + + def _scope(method): + def impl(*args, **kwargs): + tracer = args[0].tracer + target = getattr(operator, method) + return tracer.create_proxy("call_function", target, args, kwargs) + + impl.__name__ = method + as_magic = f"__{method.strip('_')}__" + setattr(Proxy, as_magic, impl) + + _scope(method) + + +def _define_reflectable(orig_method_name): + method_name = f"__r{orig_method_name.strip('_')}__" + + def impl(self, rhs): + target = getattr(operator, orig_method_name) + return self.tracer.create_proxy("call_function", target, (rhs, self), {}) + + impl.__name__ = method_name + impl.__qualname__ = method_name + setattr(Proxy, method_name, impl) + + +for orig_method_name in reflectable_magic_methods: + _define_reflectable(orig_method_name) + + +def _no_nodes_error(arg): + raise RuntimeError( + "Keys for dictionaries used as an argument cannot contain a " + f"Node. Got key: {arg}" + ) + + +def _create_arg_dict(self, a): + r = {} + for k, v in a.items(): + if not isinstance(k, str): + # Check for invalid dict keys. We do not want a Proxy to appear + # anywhere within the key. Since keys can be collection types, + # we iterate through the key with map_arg + k = self.create_arg(k) + map_arg(k, _no_nodes_error) + r[k] = self.create_arg(v) + return r + + +_create_arg_bypass = { + t: lambda self, a: a + for t in [ + *base_types, + type(None), + type(...), + torch._ops.OpOverload, + torch._ops.HigherOrderOperator, + ] +} +_create_arg_bypass[Proxy] = lambda self, a: a.node +_create_arg_bypass[tuple] = lambda self, a: tuple([self.create_arg(elem) for elem in a]) +_create_arg_bypass[list] = lambda self, a: [self.create_arg(elem) for elem in a] +_create_arg_bypass[dict] = _create_arg_dict +_create_arg_bypass[immutable_list] = _create_arg_bypass[list] +_create_arg_bypass[immutable_dict] = _create_arg_bypass[dict] diff --git a/phivenv/Lib/site-packages/torch/fx/subgraph_rewriter.py b/phivenv/Lib/site-packages/torch/fx/subgraph_rewriter.py new file mode 100644 index 0000000000000000000000000000000000000000..fa710004abd2b7bdfb01fac6d30d78e69c1d493d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/subgraph_rewriter.py @@ -0,0 +1,428 @@ +import copy +from dataclasses import dataclass +from typing import Any, Callable, NamedTuple, Optional, TYPE_CHECKING, Union + +import torch + +from ._compatibility import compatibility +from ._symbolic_trace import symbolic_trace +from .graph import Graph +from .graph_module import GraphModule +from .node import Node + + +if TYPE_CHECKING: + from .passes.utils.matcher_with_name_node_map_utils import InternalMatch + +__all__ = [ + "Match", + "replace_pattern", + "replace_pattern_with_filters", + "ReplacedPatterns", +] + + +@compatibility(is_backward_compatible=True) +class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: dict[Node, Node] + + +@compatibility(is_backward_compatible=False) +@dataclass +class ReplacedPatterns: + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: dict[Node, Node] + # List of nodes that were added into the graph + replacements: list[Node] + + +def _replace_attributes(gm: GraphModule, replacement: torch.nn.Module) -> None: + gm.delete_all_unused_submodules() + + if isinstance(replacement, GraphModule): + replacement.graph.lint() + + def try_get_attr(gm: torch.nn.Module, target: str) -> Optional[Any]: + module_path, _, attr_name = target.rpartition(".") + try: + mod: torch.nn.Module = gm.get_submodule(module_path) + except AttributeError: + return None + attr = getattr(mod, attr_name, None) + return attr + + for node in gm.graph.nodes: + if node.op == "call_module" or node.op == "get_attr": + gm_attr = try_get_attr(gm, node.target) + replacement_attr = try_get_attr(replacement, node.target) + + # CASE 1: This target already exists as an attribute in our + # result GraphModule. Whether or not it exists in + # `replacement`, the existing submodule takes precedence. + if gm_attr is not None: + continue + + # CASE 2: The target exists as an attribute in `replacement` + # only, so we need to copy it over. + elif replacement_attr is not None: + new_attr = copy.deepcopy(replacement_attr) + if isinstance(replacement_attr, torch.nn.Module): + gm.add_submodule(node.target, new_attr) + else: + setattr(gm, node.target, new_attr) + + # CASE 3: The target doesn't exist as an attribute in `gm` + # or `replacement` + else: + raise RuntimeError( + 'Attempted to create a "', + node.op, + '" node during subgraph rewriting ' + f"with target {node.target}, but " + "the referenced attribute does not " + "exist in the replacement GraphModule", + ) + + gm.graph.lint() + + +@compatibility(is_backward_compatible=True) +def replace_pattern( + gm: GraphModule, + pattern: Union[Callable, GraphModule], + replacement: Union[Callable, GraphModule], +) -> list[Match]: + """ + Matches all possible non-overlapping sets of operators and their + data dependencies (``pattern``) in the Graph of a GraphModule + (``gm``), then replaces each of these matched subgraphs with another + subgraph (``replacement``). + + Args: + ``gm``: The GraphModule that wraps the Graph to operate on + ``pattern``: The subgraph to match in ``gm`` for replacement + ``replacement``: The subgraph to replace ``pattern`` with + + Returns: + List[Match]: A list of ``Match`` objects representing the places + in the original graph that ``pattern`` was matched to. The list + is empty if there are no matches. ``Match`` is defined as: + + .. code-block:: python + + class Match(NamedTuple): + # Node from which the match was found + anchor: Node + # Maps nodes in the pattern subgraph to nodes in the larger graph + nodes_map: Dict[Node, Node] + + Examples: + + .. code-block:: python + + import torch + from torch.fx import symbolic_trace, subgraph_rewriter + + + class M(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + def forward(self, x, w1, w2): + m1 = torch.cat([w1, w2]).sum() + m2 = torch.cat([w1, w2]).sum() + return x + torch.max(m1) + torch.max(m2) + + + def pattern(w1, w2): + return torch.cat([w1, w2]) + + + def replacement(w1, w2): + return torch.stack([w1, w2]) + + + traced_module = symbolic_trace(M()) + + subgraph_rewriter.replace_pattern(traced_module, pattern, replacement) + + The above code will first match ``pattern`` in the ``forward`` + method of ``traced_module``. Pattern-matching is done based on + use-def relationships, not node names. For example, if you had + ``p = torch.cat([a, b])`` in ``pattern``, you could match + ``m = torch.cat([a, b])`` in the original ``forward`` function, + despite the variable names being different (``p`` vs ``m``). + + The ``return`` statement in ``pattern`` is matched based on its + value only; it may or may not match to the ``return`` statement in + the larger graph. In other words, the pattern doesn't have to extend + to the end of the larger graph. + + When the pattern is matched, it will be removed from the larger + function and replaced by ``replacement``. If there are multiple + matches for ``pattern`` in the larger function, each non-overlapping + match will be replaced. In the case of a match overlap, the first + found match in the set of overlapping matches will be replaced. + ("First" here being defined as the first in a topological ordering + of the Nodes' use-def relationships. In most cases, the first Node + is the parameter that appears directly after ``self``, while the + last Node is whatever the function returns.) + + One important thing to note is that the parameters of the + ``pattern`` Callable must be used in the Callable itself, + and the parameters of the ``replacement`` Callable must match + the pattern. The first rule is why, in the above code block, the + ``forward`` function has parameters ``x, w1, w2``, but the + ``pattern`` function only has parameters ``w1, w2``. ``pattern`` + doesn't use ``x``, so it shouldn't specify ``x`` as a parameter. + As an example of the second rule, consider replacing + + .. code-block:: python + + def pattern(x, y): + return torch.neg(x) + torch.relu(y) + + with + + .. code-block:: python + + def replacement(x, y): + return torch.relu(x) + + In this case, ``replacement`` needs the same number of parameters + as ``pattern`` (both ``x`` and ``y``), even though the parameter + ``y`` isn't used in ``replacement``. + + After calling ``subgraph_rewriter.replace_pattern``, the generated + Python code looks like this: + + .. code-block:: python + + def forward(self, x, w1, w2): + stack_1 = torch.stack([w1, w2]) + sum_1 = stack_1.sum() + stack_2 = torch.stack([w1, w2]) + sum_2 = stack_2.sum() + max_1 = torch.max(sum_1) + add_1 = x + max_1 + max_2 = torch.max(sum_2) + add_2 = add_1 + max_2 + return add_2 + """ + match_and_replacements = _replace_pattern(gm, pattern, replacement) + return [ + Match(anchor=m.anchor, nodes_map=m.nodes_map) for m in match_and_replacements + ] + + +# Experimental API, not backward compatible +@compatibility(is_backward_compatible=False) +def replace_pattern_with_filters( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, + match_filters: Optional[ + list[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, + ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, +) -> list[ReplacedPatterns]: + """ + See replace_pattern for documentation. This function is an overload with an additional match_filter argument. + + Args: + ``match_filters``: A list of functions that take in + (match: InternalMatch, original_graph: Graph, pattern_graph: Graph) and return a boolean indicating + whether the match satisfies the condition. + See matcher_utils.py for definition of InternalMatch. + ``replacement_callback``: A function that takes in a match and returns a + Graph to be used as the replacement. This allows you to construct a + replacement graph based on the match. + """ + + return _replace_pattern( + gm, pattern, replacement, match_filters, ignore_literals, replacement_callback + ) + + +def _replace_pattern( + gm: GraphModule, + pattern: Union[Callable, Graph, GraphModule], + replacement: Union[Callable, Graph, GraphModule, None] = None, + match_filters: Optional[ + list[Callable[["InternalMatch", Graph, Graph], bool]] + ] = None, + ignore_literals: bool = False, + # Placed at the end to avoid breaking backward compatibility + replacement_callback: Optional[ + Callable[["InternalMatch", Graph, Graph], Graph] + ] = None, +) -> list[ReplacedPatterns]: + from torch.fx.passes.utils.matcher_utils import InternalMatch, SubgraphMatcher + + if match_filters is None: + match_filters = [] + + # Get the graphs for `gm`, `pattern`, `replacement` + original_graph: Graph = gm.graph + + if isinstance(pattern, GraphModule): + pattern_graph = pattern.graph + elif isinstance(pattern, Graph): + pattern_graph = pattern + else: + pattern_graph = symbolic_trace(pattern).graph + + matcher = SubgraphMatcher( + pattern_graph, + match_output=False, + match_placeholder=False, + remove_overlapping_matches=True, + ignore_literals=ignore_literals, + ) + _matches: list[InternalMatch] = matcher.match(original_graph) + + # Filter out matches that don't match the filter + _matches = [ + m + for m in _matches + if all( + match_filter(m, original_graph, pattern_graph) + for match_filter in match_filters + ) + ] + + if isinstance(replacement, GraphModule): + common_replacement_graph = replacement.graph + elif isinstance(replacement, Graph): + common_replacement_graph = replacement + elif callable(replacement): + common_replacement_graph = symbolic_trace(replacement).graph + else: + assert replacement_callback is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) + common_replacement_graph = None + + # As we progressively replace nodes, we'll need to keep track of how the match results should change + match_changed_node: dict[Node, Node] = {} + + match_and_replacements = [] + for match in _matches: + if replacement_callback is not None: + replacement_graph = replacement_callback( + match, original_graph, pattern_graph + ) + else: + assert common_replacement_graph is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) + replacement_graph = common_replacement_graph + replacement_placeholders = [ + n for n in replacement_graph.nodes if n.op == "placeholder" + ] + + # Build connecting between replacement graph's input and original graph input producer node + + # Initialize `val_map` with mappings from placeholder nodes in + # `replacement` to their corresponding node in `original_graph` + assert len(match.placeholder_nodes) == len(replacement_placeholders) + val_map: dict[Node, Node] = {} + for rn, gn in zip(replacement_placeholders, match.placeholder_nodes): + if isinstance(gn, Node): + val_map[rn] = match_changed_node.get(gn, gn) + if gn != val_map[rn]: + # Update match.placeholder_nodes and match.nodes_map with the node that replaced gn + gn_ind = match.placeholder_nodes.index(gn) + match.placeholder_nodes[gn_ind] = match_changed_node[gn] + map_key = list(match.nodes_map.keys())[ + list(match.nodes_map.values()).index(gn) + ] + match.nodes_map[map_key] = match_changed_node[gn] + else: + val_map[rn] = gn + + # Copy the replacement graph over + user_nodes: set[Node] = set() + for n in match.returning_nodes: + user_nodes.update(n.users) + + first_user_node = None + if len(user_nodes) == 0: + first_user_node = None + elif len(user_nodes) == 1: + first_user_node = next(iter(user_nodes)) + else: + # If there are multiple user nodes, we need to find the first user node + # in the current execution order of the `original_graph` + for n in original_graph.nodes: + if n in user_nodes: + first_user_node = n + break + + first_next_node = None + if first_user_node is None: + # no users, so we insert the replacement graph before the first next + # node of returning nodes + next_node = None + for n in reversed(original_graph.nodes): + if n in match.returning_nodes: + first_next_node = next_node + break + else: + next_node = n + insert_point = ( + first_user_node if first_user_node is not None else first_next_node + ) + assert insert_point is not None, "The insert point can't be None" + with original_graph.inserting_before(insert_point): + copied_returning_nodes = original_graph.graph_copy( + replacement_graph, val_map + ) + + if isinstance(copied_returning_nodes, Node): + copied_returning_nodes = (copied_returning_nodes,) + + # Get a list of nodes that have been replaced into the graph + replacement_nodes: list[Node] = [ + v for v in val_map.values() if v not in match.placeholder_nodes + ] + + # Hook the output Node of the replacement subgraph in to the + # original Graph at the correct location + assert len(match.returning_nodes) == len(copied_returning_nodes) # type: ignore[arg-type] + for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes): # type: ignore[arg-type] + gn.replace_all_uses_with(copied_node) + match_changed_node[gn] = copied_node + # Remove the original nodes + for node in reversed(pattern_graph.nodes): + if node.op != "placeholder" and node.op != "output": + gn = match.nodes_map[node] + gm.graph.erase_node(gn) + + match_and_replacements.append( + ReplacedPatterns( + anchor=match.anchors[0], + nodes_map=match.nodes_map, + replacements=replacement_nodes, + ) + ) + + # Update the passed-in GraphModule to reflect the new state of + # `original_graph` + gm.recompile() + + # If `replacement` was an nn.Module, we'll need to make sure that + # all the submodules have been copied over correctly + if isinstance(replacement, torch.nn.Module): + _replace_attributes(gm, replacement) + + return match_and_replacements diff --git a/phivenv/Lib/site-packages/torch/fx/tensor_type.py b/phivenv/Lib/site-packages/torch/fx/tensor_type.py new file mode 100644 index 0000000000000000000000000000000000000000..a388ae90206b3d3881dfb81b85a05ca3950314b7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/tensor_type.py @@ -0,0 +1,111 @@ +# mypy: allow-untyped-defs +from torch.fx.experimental.unification import Var # type: ignore[attr-defined] + +from ._compatibility import compatibility + + +@compatibility(is_backward_compatible=False) +class TensorType: + """ + TensorType defines a type for tensors, which consists of a list of dimensions. + Example: + class M(torch.nn.Module): + def forward(self, x:TensorType((1,2,3, Dyn)), y:TensorType((1,2,3, Dyn))): + return torch.add(x, y) + """ + + def __init__(self, dim): + self.__origin__ = TensorType + self.__args__ = dim + + def __repr__(self): + return f"TensorType[{self.__args__}]" + + def __eq__(self, other): + if isinstance(other, self.__class__): + return list(self.__args__) == list(other.__args__) + else: + return False + + @staticmethod + def __class_getitem__(*args): + if len(args) == 1 and isinstance(args[0], tuple): + args = args[0] + return TensorType(tuple(args)) + + +class _DynType: + """ + _DynType defines a type which stands for the absence of type information. + """ + + def __init__(self) -> None: + self.__name__ = "_DynType" + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __str__(self): + return "Dyn" + + def __repr__(self): + return "Dyn" + + +Dyn = _DynType() + + +@compatibility(is_backward_compatible=False) +def is_consistent(t1, t2): + """ + A binary relation denoted by ~ that determines if t1 is consistent with t2. + The relation is reflexive, symmetric but not transitive. + returns True if t1 and t2 are consistent and False otherwise. + Example: + Dyn ~ TensorType((1,2,3)) + int ~ Dyn + int ~ int + TensorType((1,Dyn,3)) ~ TensorType((1,2,3)) + """ + + if t1 == t2: + return True + + if t1 == Dyn or t2 == Dyn or isinstance(t1, Var) or isinstance(t2, Var): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and all( + is_consistent(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) + else: + return False + + +@compatibility(is_backward_compatible=False) +def is_more_precise(t1, t2): + """ + A binary relation denoted by <= that determines if t1 is more precise than t2. + The relation is reflexive and transitive. + returns True if t1 is more precise than t2 and False otherwise. + Example: + Dyn >= TensorType((1,2,3)) + int >= Dyn + int >= int + TensorType((1,Dyn,3)) <= TensorType((1,2,3)) + """ + if t1 == t2: + return True + + if isinstance(t2, _DynType): + return True + + if isinstance(t1, TensorType) and isinstance(t2, TensorType): + return len(t1.__args__) == len(t2.__args__) and all( + is_more_precise(elem1, elem2) + for elem1, elem2 in zip(t1.__args__, t2.__args__) + ) + + else: + return False diff --git a/phivenv/Lib/site-packages/torch/fx/traceback.py b/phivenv/Lib/site-packages/torch/fx/traceback.py new file mode 100644 index 0000000000000000000000000000000000000000..81e77c3796db567d5e55fc26e35c38b5d079823e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/fx/traceback.py @@ -0,0 +1,238 @@ +# mypy: allow-untyped-defs +import copy +import traceback +from contextlib import contextmanager +from enum import Enum +from typing import Any, Optional, Union + +from ._compatibility import compatibility +from .graph import Graph +from .node import Node + + +__all__ = [ + "preserve_node_meta", + "has_preserved_node_meta", + "set_stack_trace", + "set_grad_fn_seq_nr", + "reset_grad_fn_seq_nr", + "format_stack", + "set_current_meta", + "get_current_meta", + "NodeSource", + "NodeSourceAction", + "get_graph_provenance_json", +] + +current_meta: dict[str, Any] = {} +should_preserve_node_meta = False + + +@compatibility(is_backward_compatible=False) +class NodeSourceAction(Enum): + CREATE = "create" + REPLACE = "replace" + + +@compatibility(is_backward_compatible=False) +class NodeSource: + """ + NodeSource is a data structure that contains the provenance information of a node. + If node `a` is created from node `b`, then `a.meta["from_node"]` may contain NodeSource(b). + """ + + class NodeInfo: + def __init__(self, name: str, target: str, graph_id: int): + self.name = name + self.target = target + self.graph_id = graph_id + + pass_name: str + action: list["NodeSourceAction"] + from_node: list["NodeSource"] + node_info: Optional["NodeInfo"] + + def __init__( + self, + node: Optional[Node], + pass_name: str = "", + action: Optional[Union["NodeSourceAction", list["NodeSourceAction"]]] = None, + ): + self.pass_name = pass_name + + if action is None: + action = [] + elif not isinstance(action, list): + action = [action] + for a in action: + assert isinstance(a, NodeSourceAction) + self.action = action + if node: + self.node_info = self.NodeInfo( + name=node.name, target=str(node.target), graph_id=id(node.graph) + ) + self.from_node = ( + copy.deepcopy(node.meta["from_node"]) + if "from_node" in node.meta + else [] + ) + else: + self.node_info = None + self.from_node = [] + + @property + def name(self) -> str: + return self.node_info.name if self.node_info else "" + + @property + def target(self) -> str: + return self.node_info.target if self.node_info else "" + + @property + def graph_id(self) -> int: + return self.node_info.graph_id if self.node_info else -1 + + def __repr__(self): + return self.print_readable() + + def _get_action_string(self): + return "+".join([a.name.lower() for a in self.action]) + + def print_readable(self, indent=0): + if indent > 9: + return "" + result = "" + action_string = self._get_action_string() + result += ( + " " * indent * 4 + + f"(name={self.name}, pass_name={self.pass_name}, action={action_string}, graph_id={self.graph_id})\n" + ) + for item in self.from_node: + result += item.print_readable(indent + 1) + return result + + def to_dict(self) -> dict: + # Convert the object to a dictionary + action_string = self._get_action_string() + return { + "name": self.name, + "target": self.target, + "graph_id": self.graph_id, + "pass_name": self.pass_name, + "action": action_string, + "from_node": [node.to_dict() for node in self.from_node], + } + + +@compatibility(is_backward_compatible=False) +@contextmanager +def preserve_node_meta(enable=True): + global should_preserve_node_meta + global current_meta + # If enable is False, this context manager is a no-op + if not enable: + yield + else: + saved_should_preserve_node_meta = should_preserve_node_meta + # Shallow copy is OK since fields of current_meta are not mutated + saved_current_meta = current_meta.copy() + try: + should_preserve_node_meta = True + yield + finally: + should_preserve_node_meta = saved_should_preserve_node_meta + current_meta = saved_current_meta + + +@compatibility(is_backward_compatible=False) +def set_stack_trace(stack: list[str]): + global current_meta + + if should_preserve_node_meta and stack: + current_meta["stack_trace"] = "".join(stack) + + +@compatibility(is_backward_compatible=False) +def set_grad_fn_seq_nr(seq_nr): + global current_meta + + if should_preserve_node_meta: + # The seq_nr is captured by eager mode in the grad_fn during forward + current_meta["grad_fn_seq_nr"] = current_meta.get("grad_fn_seq_nr", []) + [ + seq_nr + ] + current_meta["in_grad_fn"] = current_meta.get("in_grad_fn", 0) + 1 + + +@compatibility(is_backward_compatible=False) +def reset_grad_fn_seq_nr(): + # NB: reset state properly, this would be helpful towards supporting + # reentrant autograd if we actually wanted to do that. + global current_meta + if should_preserve_node_meta: + current_level = current_meta.get("in_grad_fn", 0) + assert current_level > 0 + if current_level == 1: + del current_meta["in_grad_fn"] + del current_meta["grad_fn_seq_nr"] + else: + current_meta["in_grad_fn"] = current_level - 1 + current_meta["grad_fn_seq_nr"] = current_meta["grad_fn_seq_nr"][:-1] + + +@compatibility(is_backward_compatible=False) +def format_stack() -> list[str]: + if should_preserve_node_meta: + return [current_meta.get("stack_trace", "")] + else: + # fallback to traceback.format_stack() + return traceback.format_list(traceback.extract_stack()[:-1]) + + +@compatibility(is_backward_compatible=False) +def has_preserved_node_meta() -> bool: + return should_preserve_node_meta + + +@compatibility(is_backward_compatible=False) +@contextmanager +def set_current_meta(node, pass_name=""): + global current_meta + if should_preserve_node_meta and node.meta: + saved_meta = current_meta + try: + current_meta = node.meta.copy() + + # Update the "from_node" field in current_meta for provenance tracking. + # Instead of appending, overwrite the "from_node" field because current_meta + # will be assigned to the new node. The new NodeSource(node, ...) will + # include the information from the previous current_meta["from_node"]. + current_meta["from_node"] = [ + NodeSource(node, pass_name, NodeSourceAction.CREATE) + ] + yield + finally: + current_meta = saved_meta + else: + yield + + +@compatibility(is_backward_compatible=False) +def get_current_meta() -> dict[str, Any]: + return current_meta + + +@compatibility(is_backward_compatible=False) +def get_graph_provenance_json(graph: Graph) -> dict[str, Any]: + """ + Given an fx.Graph, return a json that contains the provenance information of each node. + """ + provenance_tracking_json = {} + for node in graph.nodes: + if node.op == "call_function": + provenance_tracking_json[node.name] = ( + [source.to_dict() for source in node.meta["from_node"]] + if "from_node" in node.meta + else [] + ) + return provenance_tracking_json diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ATen.h b/phivenv/Lib/site-packages/torch/include/ATen/ATen.h new file mode 100644 index 0000000000000000000000000000000000000000..60a33d74a04a0a1ae07d1e000b8c73bc3ca3cda4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ATen.h @@ -0,0 +1,37 @@ +#pragma once + +#if !defined(_MSC_VER) && __cplusplus < 201703L +#error C++17 or later compatible compiler is required to use ATen. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// TODO: try to remove this +// There is some back story, see https://github.com/pytorch/pytorch/issues/48684 +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/AccumulateType.h b/phivenv/Lib/site-packages/torch/include/ATen/AccumulateType.h new file mode 100644 index 0000000000000000000000000000000000000000..6043c3ed439db223a800dd343300c05c153b955f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/AccumulateType.h @@ -0,0 +1,173 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Defines the accumulation type for a scalar type. +// Example: +// using accscalar_t = acc_type; +// +// Accumulation types are an important concept in numeric computing +// because you frequently want to perform intermediate computations +// at a higher precision than the input and output precision, to avoid +// compounding internal rounding errors. Accumulation is the most +// well-known intermediate computation (it is of great importance for +// sum reduction and matrix multiply, for example), but in PyTorch +// acc_type ends up getting used for all sorts of other intermediate +// computations, so it perhaps would be more accurately (ahem) called an +// "accurate" type. acc_type is especially important for reduced +// precision operations like float16 and bfloat16, where relatively +// benign looking inputs can easily end up overflowing/underflowing. +// +// acc_type is parametrized by whether or not you are running on CUDA +// or not, because on CUDA double precision operations are expensive +// and so by default, we don't actually want to use double as an +// acc_type on CUDA. A lot of things are typed out below, but +// basically, the table is generated by a few rules: +// +// If bool: +// Use 'bool' as acc_type. +// If floating point: +// If CUDA, use 'float' as acc_type (unless scalar_t is double), +// otherwise (CPU) use 'double' +// If integral: +// Use 'int64_t' as acc_type +// +// You're not forced to use this template; if you happen to know +// something specific about your use case, you can specify your own +// desired behavior. This template, however, will give you a reasonable +// default that will work for all dtypes supported in PyTorch. + +#if defined(__CUDACC__) +#include +#include +#elif defined(__HIPCC__) +#include +#include +#endif + +namespace at { + +template +struct AccumulateTypeDevice {}; + +template +struct AccumulateType {}; + +template +struct AccumulateType { + using type = typename AccumulateTypeDevice::type; +}; + +template +struct AccumulateType { + using type = typename AccumulateTypeDevice::type; +}; + +template +using acc_type_device = typename AccumulateTypeDevice::type; + +template +using acc_type = typename AccumulateType::type; + +#define ACC_TYPE(t, acc_t, device_type) \ + template <> \ + struct AccumulateTypeDevice { \ + using type = acc_t; \ + }; +#define MPS_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::MPS) +#define XPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::XPU) +#define CUDA_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CUDA) +#define CPU_ACC_TYPE(t, acc_t) ACC_TYPE(t, acc_t, c10::DeviceType::CPU) + +MPS_ACC_TYPE(BFloat16, float) +MPS_ACC_TYPE(Half, float) +MPS_ACC_TYPE(Float8_e5m2, float) +MPS_ACC_TYPE(Float8_e4m3fn, float) +MPS_ACC_TYPE(Float8_e5m2fnuz, float) +MPS_ACC_TYPE(Float8_e4m3fnuz, float) +MPS_ACC_TYPE(float, float) +MPS_ACC_TYPE(double, float) +MPS_ACC_TYPE(int8_t, int64_t) +MPS_ACC_TYPE(uint8_t, int64_t) +MPS_ACC_TYPE(char, int64_t) +MPS_ACC_TYPE(int16_t, int64_t) +MPS_ACC_TYPE(int32_t, int64_t) +MPS_ACC_TYPE(int64_t, int64_t) +MPS_ACC_TYPE(bool, bool) +MPS_ACC_TYPE(c10::complex, c10::complex) +MPS_ACC_TYPE(c10::complex, c10::complex) +MPS_ACC_TYPE(c10::complex, c10::complex) + +XPU_ACC_TYPE(BFloat16, float) +XPU_ACC_TYPE(Half, float) +XPU_ACC_TYPE(Float8_e5m2, float) +XPU_ACC_TYPE(Float8_e4m3fn, float) +XPU_ACC_TYPE(Float8_e5m2fnuz, float) +XPU_ACC_TYPE(Float8_e4m3fnuz, float) +XPU_ACC_TYPE(float, float) +XPU_ACC_TYPE(double, double) +XPU_ACC_TYPE(int8_t, int64_t) +XPU_ACC_TYPE(uint8_t, int64_t) +XPU_ACC_TYPE(char, int64_t) +XPU_ACC_TYPE(int16_t, int64_t) +XPU_ACC_TYPE(int32_t, int64_t) +XPU_ACC_TYPE(int64_t, int64_t) +XPU_ACC_TYPE(bool, bool) +XPU_ACC_TYPE(c10::complex, c10::complex) +XPU_ACC_TYPE(c10::complex, c10::complex) +XPU_ACC_TYPE(c10::complex, c10::complex) + +#if defined(__CUDACC__) || defined(__HIPCC__) +CUDA_ACC_TYPE(half, float) +#endif +CUDA_ACC_TYPE(BFloat16, float) +CUDA_ACC_TYPE(Half, float) +CUDA_ACC_TYPE(Float8_e5m2, float) +CUDA_ACC_TYPE(Float8_e4m3fn, float) +CUDA_ACC_TYPE(Float8_e5m2fnuz, float) +CUDA_ACC_TYPE(Float8_e4m3fnuz, float) +CUDA_ACC_TYPE(float, float) +CUDA_ACC_TYPE(double, double) +CUDA_ACC_TYPE(int8_t, int64_t) +CUDA_ACC_TYPE(uint8_t, int64_t) +CUDA_ACC_TYPE(char, int64_t) +CUDA_ACC_TYPE(int16_t, int64_t) +CUDA_ACC_TYPE(int32_t, int64_t) +CUDA_ACC_TYPE(int64_t, int64_t) +CUDA_ACC_TYPE(bool, bool) +CUDA_ACC_TYPE(c10::complex, c10::complex) +CUDA_ACC_TYPE(c10::complex, c10::complex) +CUDA_ACC_TYPE(c10::complex, c10::complex) + +CPU_ACC_TYPE(BFloat16, float) +CPU_ACC_TYPE(Half, float) +CPU_ACC_TYPE(Float8_e5m2, float) +CPU_ACC_TYPE(Float8_e4m3fn, float) +CPU_ACC_TYPE(Float8_e5m2fnuz, float) +CPU_ACC_TYPE(Float8_e4m3fnuz, float) +CPU_ACC_TYPE(float, double) +CPU_ACC_TYPE(double, double) +CPU_ACC_TYPE(int8_t, int64_t) +CPU_ACC_TYPE(uint8_t, int64_t) +CPU_ACC_TYPE(char, int64_t) +CPU_ACC_TYPE(int16_t, int64_t) +CPU_ACC_TYPE(int32_t, int64_t) +CPU_ACC_TYPE(int64_t, int64_t) +CPU_ACC_TYPE(bool, bool) +CPU_ACC_TYPE(c10::complex, c10::complex) +CPU_ACC_TYPE(c10::complex, c10::complex) +CPU_ACC_TYPE(c10::complex, c10::complex) + +TORCH_API c10::ScalarType toAccumulateType( + c10::ScalarType type, + c10::DeviceType device); +TORCH_API c10::ScalarType toAccumulateType(c10::ScalarType type, bool is_cuda); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ArrayRef.h b/phivenv/Lib/site-packages/torch/include/ATen/ArrayRef.h new file mode 100644 index 0000000000000000000000000000000000000000..8c1febe4654361afa6b90cd38898b90cf8a8d17f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ArrayRef.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Backend.h b/phivenv/Lib/site-packages/torch/include/ATen/Backend.h new file mode 100644 index 0000000000000000000000000000000000000000..34b3b191549d2be6218da30bc2acab3baa215888 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Backend.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Backtrace.h b/phivenv/Lib/site-packages/torch/include/ATen/Backtrace.h new file mode 100644 index 0000000000000000000000000000000000000000..2d6eba46720207605fd2b6640ce48c9ae0bffd20 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Backtrace.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/BlasBackend.h b/phivenv/Lib/site-packages/torch/include/ATen/BlasBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..44b555d9b6bc2f57868e18ec3d58a29363a7564a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/BlasBackend.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include +#include + +namespace at { + +enum class BlasBackend : int8_t { Default, Cublas, Cublaslt, Ck }; + +inline std::string BlasBackendToString(at::BlasBackend backend) { + switch (backend) { + case BlasBackend::Default: + return "at::BlasBackend::Default"; + case BlasBackend::Cublas: + return "at::BlasBackend::Cublas"; + case BlasBackend::Cublaslt: + return "at::BlasBackend::Cublaslt"; + case BlasBackend::Ck: + return "at::BlasBackend::Ck"; + default: + TORCH_CHECK(false, "Unknown blas backend"); + } +} + +inline std::ostream& operator<<(std::ostream& stream, at::BlasBackend backend) { + return stream << BlasBackendToString(backend); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..002ef0e2d2a97d1e17a1288c0bf5f9f6adbb17db --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CPUApplyUtils.h @@ -0,0 +1,352 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace at { + +/* + * The basic strategy for apply is as follows: + * + * 1. Starting with the outermost index, loop until we reach a dimension where + * the data is no longer contiguous, i.e. the stride at that dimension is not + * equal to the size of the tensor defined by the outer dimensions. Let's call + * this outer (contiguous) tensor A. Note that if the Tensor is contiguous, then + * A is equal to the entire Tensor. Let's call the inner tensor B. + * + * 2. We loop through the indices in B, starting at its outermost dimension. For + * example, if B is a 2x2 matrix, then we do: + * + * B[0][0] + * B[0][1] + * B[1][0] + * B[1][1] + * + * We set the offset into the underlying storage as (storageOffset + stride_B * + * index_B), i.e. basically we compute the offset into the storage as we would + * normally for a Tensor. But because we are guaranteed the subsequent data is + * contiguous in memory, we can simply loop for sizeof(A) iterations and perform + * the operation, without having to follow the order described by the strides of + * A. + * + * 3. As an optimization, we merge dimensions of A that are contiguous in + * memory. For example, if A is a 3x3x3x3 tensor narrowed from a 3x3x4x3 tensor, + * then the first two dimensions can be merged for the purposes of APPLY, + * reducing the number of nested loops. + */ + +inline Tensor sort_strides(Tensor& tensor_) { + IntArrayRef strides = tensor_.strides(); + std::vector indices; + indices.reserve(tensor_.ndimension()); + for (const auto i : c10::irange(tensor_.ndimension())) { + indices.push_back(i); + } + std::sort(indices.begin(), indices.end(), [&strides](int64_t i1, int64_t i2) { + return strides[i1] > strides[i2]; + }); + Tensor tensor = tensor_.permute(indices); + return tensor; +} + +template +struct strided_tensor_iter_fixed { + public: + T* data_ = NULL; + int64_t dim_ = 0; + + // NOLINTNEXTLINE(*array*) + int64_t counter_[N] = {0}; + // NOLINTNEXTLINE(*array*) + int64_t sizes_[N] = {0}; + // NOLINTNEXTLINE(*array*) + int64_t strides_[N] = {0}; + + strided_tensor_iter_fixed(strided_tensor_iter_fixed const&) = delete; + strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed const& x) = + delete; + strided_tensor_iter_fixed(strided_tensor_iter_fixed&&) noexcept = default; + strided_tensor_iter_fixed& operator=(strided_tensor_iter_fixed&& x) noexcept = + default; + ~strided_tensor_iter_fixed() noexcept = default; + strided_tensor_iter_fixed( + Tensor& tensor, + [[maybe_unused]] bool sort_strides = false) + : data_(tensor.data_ptr()) { + std::memset(counter_, 0, sizeof(int64_t) * N); + if (tensor.dim() > 0) { + std::memcpy( + sizes_, tensor.sizes().data(), tensor.dim() * sizeof(int64_t)); + std::memcpy( + strides_, tensor.strides().data(), tensor.dim() * sizeof(int64_t)); + } + dim_ = std::get<1>(collapse_dims(sizes_, strides_, tensor.ndimension())); + } +}; + +template +struct strided_tensor_iter { + private: + public: + T* data_ = NULL; + int64_t dim_; + + std::vector counter_; + std::vector sizes_; + std::vector strides_; + + strided_tensor_iter(strided_tensor_iter const&) = delete; + strided_tensor_iter& operator=(strided_tensor_iter const& x) = delete; + strided_tensor_iter(strided_tensor_iter&&) noexcept = default; + strided_tensor_iter& operator=(strided_tensor_iter&&) noexcept = default; + ~strided_tensor_iter() noexcept = default; + strided_tensor_iter(Tensor& tensor) + : data_(tensor.data_ptr()), + dim_(tensor.ndimension()), + counter_(dim_, 0), + sizes_(tensor.sizes().vec()), + strides_(tensor.strides().vec()) { + dim_ = std::get<1>(collapse_dims(sizes_.data(), strides_.data(), dim_)); + } +}; + +inline bool _all_equal_numel(at::ArrayRef tensors) { + if (tensors.empty()) + return true; + int64_t all_numel = tensors[0].numel(); + for (const auto i : c10::irange(1, tensors.size())) { + if (tensors[i].numel() != all_numel) + return false; + } + return true; +} + +inline std::string _all_equal_numel_error(at::ArrayRef tensors) { + std::ostringstream oss; + oss << "inconsistent tensor size, expected "; + for (size_t i = 0; i < tensors.size() - 1; i++) { + oss << tensors[i].sizes() << ", "; + } + oss << "and " << tensors[tensors.size() - 1].sizes() + << " to have the same number of elements, but got "; + for (size_t i = 0; i < tensors.size() - 1; i++) { + oss << tensors[i].numel() << ", "; + } + oss << "and " << tensors[tensors.size() - 1].numel() + << " elements respectively"; + return oss.str(); +} + +inline bool _apply_preamble(ArrayRef tensors) { + checkDeviceType("CPU_tensor_apply", tensors, kCPU); + checkLayout("CPU_tensor_apply", tensors, kStrided); + if (!_all_equal_numel(tensors)) + TORCH_CHECK(false, _all_equal_numel_error(tensors)); + // An empty tensor has no elements + for (auto& t : tensors) + if (t.numel() == 0) + return false; + return true; +} + +inline int64_t _max_dim_tensors(ArrayRef tensors) { + int64_t dim = 0; + for (auto& t : tensors) + dim = std::max(dim, t.ndimension()); + return dim; +} + +inline void iterate(int64_t /*size*/) {} + +template +inline void iterate(int64_t size, Arg& iter, Args&... iter_tail) { + iter.counter_[iter.dim_ - 1] += size; + iter.data_ = iter.data_ + size * iter.strides_[iter.dim_ - 1]; + iterate(size, iter_tail...); +} + +inline bool iterate_continue() { + return true; +} + +template +inline bool iterate_continue(Arg& iter, Args&... iter_tail) { + return iter.counter_[iter.dim_ - 1] < iter.sizes_[iter.dim_ - 1] && + iterate_continue(iter_tail...); +} + +inline int64_t max_iterate_size() { + return std::numeric_limits::max(); +} + +template +inline int64_t max_iterate_size(Arg& iter, Args&... iter_tail) { + return std::min( + (iter.sizes_[iter.dim_ - 1] - iter.counter_[iter.dim_ - 1]), + max_iterate_size(iter_tail...)); +} + +inline void iterate_overflow() {} + +template +inline void iterate_overflow(Arg& iter, Args&... iter_tail) { + if (iter.counter_[iter.dim_ - 1] == iter.sizes_[iter.dim_ - 1]) { + for (int64_t i = iter.dim_ - 1; i > 0; i--) { + if (iter.counter_[i] == iter.sizes_[i]) { + iter.counter_[i] = 0; + iter.counter_[i - 1]++; + iter.data_ = iter.data_ - (iter.sizes_[i] * iter.strides_[i]) + + iter.strides_[i - 1]; + } + } + } + iterate_overflow(iter_tail...); +} + +inline void forward(int64_t /*offset*/) {} + +template +inline void forward(int64_t offset, Arg& iter, Args&... iter_tail) { + int64_t multi = offset; + for (int64_t i = iter.dim_ - 1; i >= 0; i--) { + int64_t inc = multi % iter.sizes_[i]; + multi = multi / iter.sizes_[i]; + iter.data_ = iter.data_ + inc * iter.strides_[i]; + iter.counter_[i] += inc; + } + forward(offset, iter_tail...); +} + +inline int64_t max_dim() { + return 0; +} + +template +inline int64_t max_dim(Arg& iter, Args&... iter_tail) { + return std::max(iter.dim_, max_dim(iter_tail...)); +} + +inline void apply_op() {} + +template +inline void apply_op( + int64_t numel, + int64_t offset, + const Op& op, + Args... iters) { + // For 0-dim tensors + if (numel == 1 && max_dim(iters...) == 0) { + op(*iters.data_...); + return; + } + if (offset > 0) + forward(offset, iters...); + // Splitting this into chunks helps the compiler create faster assembly + for (int64_t i = 0; i < numel;) { + for (; iterate_continue(iters...) && i < numel;) { + op(*iters.data_...); + iterate(1, iters...); + i++; + } + iterate_overflow(iters...); + } +} + +/* + Apply a pointwise operator to sequence of tensors + + The calling convention for op is a function/functor that takes the same + number of pointers of type scalar as the number of given tensors. For example, + to compute a = b * c, op would be of the form: + [](scalar* a_val, const scalar* b_val, const scalar* c_val) { a_val[0] = + b_val[0] * c_val[0]; }; +*/ + +template +inline void CPU_tensor_apply2(Tensor tensor1, Tensor tensor2, const Op op) { + if (!_apply_preamble({tensor1, tensor2})) + return; + if (_max_dim_tensors({tensor1, tensor2}) <= 8) { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter_fixed(tensor1), + strided_tensor_iter_fixed(tensor2)); + } else { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter(tensor1), + strided_tensor_iter(tensor2)); + } +} + +template +inline void CPU_tensor_apply3( + Tensor tensor1, + Tensor tensor2, + Tensor tensor3, + const Op op) { + if (!_apply_preamble({tensor1, tensor2, tensor3})) + return; + if (_max_dim_tensors({tensor1, tensor2, tensor3}) <= 8) { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter_fixed(tensor1), + strided_tensor_iter_fixed(tensor2), + strided_tensor_iter_fixed(tensor3)); + } else { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter(tensor1), + strided_tensor_iter(tensor2), + strided_tensor_iter(tensor3)); + } +} + +template < + typename scalar1, + typename scalar2, + typename scalar3, + typename scalar4, + typename Op> +inline void CPU_tensor_apply4( + Tensor tensor1, + Tensor tensor2, + Tensor tensor3, + Tensor tensor4, + const Op op) { + if (!_apply_preamble({tensor1, tensor2, tensor3, tensor4})) + return; + if (_max_dim_tensors({tensor1, tensor2, tensor3, tensor4}) <= 8) { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter_fixed(tensor1), + strided_tensor_iter_fixed(tensor2), + strided_tensor_iter_fixed(tensor3), + strided_tensor_iter_fixed(tensor4)); + } else { + apply_op( + tensor1.numel(), + 0, + op, + strided_tensor_iter(tensor1), + strided_tensor_iter(tensor2), + strided_tensor_iter(tensor3), + strided_tensor_iter(tensor4)); + } +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h b/phivenv/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..edac70c86b20ebc320825d09ffaa2e09dc737853 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CPUFixedAllocator.h @@ -0,0 +1,33 @@ +#pragma once + +#include +#include + +// This file creates a fake allocator that just throws exceptions if +// it is actually used. + +// state passed to the allocator is the std::function called +// when the blob is release by ATen + +namespace at { + +static void* cpu_fixed_malloc(void*, ptrdiff_t) { + TORCH_CHECK(false, "attempting to resize a tensor view of an external blob"); +} + +static void* cpu_fixed_realloc(void*, void*, ptrdiff_t) { + TORCH_CHECK(false, "attempting to resize a tensor view of an external blob"); +} + +static void cpu_fixed_free(void* state, void* allocation) { + auto on_release = static_cast*>(state); + (*on_release)(allocation); + delete on_release; +} + +static Allocator CPU_fixed_allocator = { + cpu_fixed_malloc, + cpu_fixed_realloc, + cpu_fixed_free}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CPUFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/CPUFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..a69389fc98430268922e35a6a54fdd47340ce13c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CPUFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..fde296ab9c5313cbb0958c3a6aa91549551e66cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CPUFunctions_inl.h @@ -0,0 +1,543 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..a0f9cfea8a89c1a88d120f1b0005ee6a33e277f1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CPUGeneratorImpl.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { + +struct TORCH_API CPUGeneratorImpl : public c10::GeneratorImpl { + // Constructors + CPUGeneratorImpl(uint64_t seed_in = default_rng_seed_val); + ~CPUGeneratorImpl() override = default; + + // CPUGeneratorImpl methods + std::shared_ptr clone() const; + void set_current_seed(uint64_t seed) override; + void set_offset(uint64_t offset) override; + uint64_t get_offset() const override; + uint64_t current_seed() const override; + uint64_t seed() override; + void set_state(const c10::TensorImpl& new_state) override; + c10::intrusive_ptr get_state() const override; + static c10::DeviceType device_type(); + uint32_t random(); + uint64_t random64(); + std::optional next_float_normal_sample(); + std::optional next_double_normal_sample(); + void set_next_float_normal_sample(std::optional randn); + void set_next_double_normal_sample(std::optional randn); + at::mt19937 engine(); + void set_engine(at::mt19937 engine); + + private: + CPUGeneratorImpl* clone_impl() const override; + at::mt19937 engine_; + std::optional next_float_normal_sample_; + std::optional next_double_normal_sample_; +}; + +namespace detail { + +TORCH_API const Generator& getDefaultCPUGenerator(); +TORCH_API Generator +createCPUGenerator(uint64_t seed_val = default_rng_seed_val); + +} // namespace detail + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CUDAFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/CUDAFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..1f5bbea0bedeebde0f94ca194a86bcd6b9b359d6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CUDAFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..b4b504c1684d95a06027d865979b0e776ddbff7d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CUDAFunctions_inl.h @@ -0,0 +1,628 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CachedTensorUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/CachedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..573ac8e18c2548bde3f97e1489be102d9bd89f3d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CachedTensorUtils.h @@ -0,0 +1,24 @@ +#pragma once + +#include + +namespace at::caching { + +// Some systems (just cudagraphs currently) will persist a static tensor output +// whose TensorImpl does not change across iterations. For these tensors caching +// dtype conversions is invalid. Additionally, there will be an extra reference +// count to these cached tensors that would prevent buffer inplacing and other +// checks on tensor uniqueness. If we are not using these systems the enabled +// flag will be false and we will avoid the hash lookup. + +TORCH_API bool is_cached_tensor(const at::Tensor& t); +TORCH_API void add_cached_tensor(const at::Tensor& t); +TORCH_API void remove_cached_tensor(const at::Tensor& t); +TORCH_API void set_cached_tensors_enabled(bool enable); + +// For gradient buffer stealing we will adjust the use count of tensors +// which are persisted by cudagraphs, just as we need to adjust reference +// count of tensors with hooks. +TORCH_API size_t adjusted_use_count(const at::Tensor& t); + +} // namespace at::caching diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CollapseDims.h b/phivenv/Lib/site-packages/torch/include/ATen/CollapseDims.h new file mode 100644 index 0000000000000000000000000000000000000000..b7ca0d9db788470049ff8ce48a433217ffeb5cc3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CollapseDims.h @@ -0,0 +1,94 @@ +#include +#include + +namespace at { + +/* +[collapse dims] Updates sizes, and strides to reflect a "collapse" of +the info, possibly excluding the optional excludeDim. A "collapsed" version +of the info is the fewest dims that order the tensor's elements in the same +way as the original info. If excludeDim is specified, the collapse is the +fewest dims that order the tensor's elements as the original and preserve the +excluded dimension, unless the tensor collapses to a point. + +This function returns a pair of values. + +1) The (new) index of the preserved dimension if excludeDim is +specified. 0 if the tensor is collapsed to a point. -1 +otherwise. + +2) The new number of dimensions. +*/ +template +inline std::pair collapse_dims( + T* sizes, + T* strides, + int64_t dims, + const int excludeDim = -1) { + TORCH_CHECK( + excludeDim >= -1 && excludeDim < dims, + "expected excluded dim between -1 and dims - 1"); + + int64_t stopDim = (excludeDim == -1) ? dims : excludeDim; + int64_t newIndex = -1; + int64_t oldIndex = 0; + int64_t remappedExcludedDim = -1; + + while (oldIndex < dims) { + // Finds a dimension to collapse into + for (; oldIndex < stopDim; ++oldIndex) { + if (sizes[oldIndex] == 1) { + continue; + } + + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + ++oldIndex; + break; + } + + // Collapses dims + for (; oldIndex < stopDim; ++oldIndex) { + if (sizes[oldIndex] == 1) { + continue; + } + + if (strides[newIndex] == sizes[oldIndex] * strides[oldIndex]) { + sizes[newIndex] *= sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + } else { + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + } + } + + // Handles excludeDim being set (oldIndex == excludeDim) + if (oldIndex != dims) { + // Preserves excluded dimension + ++newIndex; + sizes[newIndex] = sizes[oldIndex]; + strides[newIndex] = strides[oldIndex]; + remappedExcludedDim = newIndex; + + // Restarts iteration after excludeDim + ++oldIndex; + stopDim = dims; + } + } + + // Handles special case of all dims size 1 + if (newIndex == -1 || (newIndex == 0 && sizes[0] == 1)) { + dims = 1; + sizes[0] = 1; + strides[0] = 1; + + return std::pair(0, 1); + } + + dims = newIndex + 1; + return std::pair(remappedExcludedDim, dims); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..6b3054f184930832ac7883e4b92f04f9c730b7b3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..8ff7a19a59bb5ab0d29e88443d36ed8ce881738d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradFunctions_inl.h @@ -0,0 +1,560 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..b042483720ead1988e522e9911b0ba2a33e3e916 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..22915229c825ab0cf1aa8488a3bdb67931b96601 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeExplicitAutogradNonFunctionalFunctions_inl.h @@ -0,0 +1,323 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..7a21e3494e8382662c9a4d6e88acd3ae0c2161c4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..a13630df1da6addacfe34b4c7729610bca8e52e0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradFunctions_inl.h @@ -0,0 +1,501 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..f16ec0929e44f496c0b6339442d8a1afe7f81be5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..d5b7c77fbca654b74b428193f1da16b348f6d325 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/CompositeImplicitAutogradNestedTensorFunctions_inl.h @@ -0,0 +1,25 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include + + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Config.h b/phivenv/Lib/site-packages/torch/include/ATen/Config.h new file mode 100644 index 0000000000000000000000000000000000000000..ac281f1b8d1aa974a88dcb496b0c636fef586e79 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Config.h @@ -0,0 +1,22 @@ +#pragma once + +// Test these using #if AT_MKL_ENABLED(), not #ifdef, so that it's +// obvious if you forgot to include Config.h +// c.f. https://stackoverflow.com/questions/33759787/generating-an-error-if-checked-boolean-macro-is-not-defined +// +// DO NOT put the macros for CUDA libraries in this file; they belong in cuda/CUDAConfig.h + +#define AT_MKLDNN_ENABLED() 1 +#define AT_MKLDNN_ACL_ENABLED() 0 +#define AT_MKL_ENABLED() 1 +#define AT_MKL_SEQUENTIAL() 0 +#define AT_POCKETFFT_ENABLED() 0 +#define AT_NNPACK_ENABLED() 0 +#define CAFFE2_STATIC_LINK_CUDA() 0 +#define AT_BUILD_WITH_BLAS() 1 +#define AT_BUILD_WITH_LAPACK() 1 +#define AT_PARALLEL_OPENMP 1 +#define AT_PARALLEL_NATIVE 0 +#define AT_BLAS_F2C() 0 +#define AT_BLAS_USE_CBLAS_DOT() 0 +#define AT_KLEIDIAI_ENABLED() 0 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Context.h b/phivenv/Lib/site-packages/torch/include/ATen/Context.h new file mode 100644 index 0000000000000000000000000000000000000000..82d86f20c3dbc7de53f5b88bd86cf16f05ae5863 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Context.h @@ -0,0 +1,648 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { + +class Tensor; + +enum class TORCH_API Float32MatmulPrecision { HIGHEST, HIGH, MEDIUM }; + +class TORCH_API Context { + public: + Context(); + + const Generator& defaultGenerator(Device device) { + c10::DeviceType device_type = device.type(); + lazyInitDevice(device_type); + + if (device_type == at::kCPU) { + return at::detail::getDefaultCPUGenerator(); + } else { + return getAcceleratorHooksInterface(device_type) + .getDefaultGenerator(device.index()); + } + } + + const AcceleratorHooksInterface& getAcceleratorHooksInterface( + std::optional opt_device_type = std::nullopt) { + if (!opt_device_type.has_value()) { + opt_device_type = at::getAccelerator(true); + } + if (opt_device_type == at::kCUDA) { + return at::detail::getCUDAHooks(); + } else if (opt_device_type == at::kXPU) { + return at::detail::getXPUHooks(); + } else if (opt_device_type == at::kMPS) { + return at::detail::getMPSHooks(); + } else if (opt_device_type == at::kPrivateUse1) { + return at::detail::getPrivateUse1Hooks(); + } else if (opt_device_type == at::kMTIA) { + return at::detail::getMTIAHooks(); + } else if (opt_device_type == at::kHIP) { + return at::detail::getHIPHooks(); + } else if (opt_device_type == at::kHPU) { + return at::detail::getHPUHooks(); + } else { + TORCH_CHECK( + false, + opt_device_type.has_value() + ? c10::DeviceTypeName(opt_device_type.value()) + : "None", + " device type not an accelerator."); + } + } + + Device getDeviceFromPtr(void* data, c10::DeviceType device_type) { + lazyInitDevice(device_type); + + if (device_type == at::kCPU) { + return c10::DeviceType::CPU; + } else { + return getAcceleratorHooksInterface(device_type).getDeviceFromPtr(data); + } + } + + bool isPinnedPtr( + const void* data, + std::optional device_type = std::nullopt) { + auto opt_device_type = + device_type.has_value() ? device_type : at::getAccelerator(); + if (!opt_device_type.has_value() || // there is no accelerator + !at::isAccelerator( + opt_device_type.value())) { // passed device not an accelerator + return false; + } + if (!init_[static_cast(opt_device_type.value())].test_once()) { + // If the device is not initialized, no pointer can be pinned for it + return false; + } + return getAcceleratorHooksInterface(opt_device_type).isPinnedPtr(data); + } + + Allocator* getPinnedMemoryAllocator( + std::optional device_type = std::nullopt) { + auto opt_device_type = + device_type.has_value() ? device_type : at::getAccelerator(); + if (opt_device_type) { + lazyInitDevice(opt_device_type.value()); + } + return getAcceleratorHooksInterface(device_type).getPinnedMemoryAllocator(); + } + + void lazyInitDevice(c10::DeviceType device_type) { + if (device_type != at::kCPU) { + c10::call_once(init_[static_cast(device_type)], [&] { + getAcceleratorHooksInterface(device_type).init(); + }); + } + } + + static bool hasOpenMP(); + static bool hasMKL(); + static bool hasKleidiAI(); + static bool hasLAPACK(); + static bool hasMKLDNN(); + static bool hasMAGMA() { + return detail::getCUDAHooks().hasMAGMA(); + } + static bool hasCUDA() { + return detail::getCUDAHooks().hasCUDA(); + } + static bool hasMTIA() { + return detail::getMTIAHooks().hasMTIA(); + } + static bool hasCUDART() { + return detail::getCUDAHooks().hasCUDART(); + } + static long versionCUDART() { + return detail::getCUDAHooks().versionCUDART(); + } + static bool hasCuDNN() { + return detail::getCUDAHooks().hasCuDNN(); + } + static long versionCuDNN() { + return detail::getCUDAHooks().versionCuDNN(); + } + static bool hasCuSOLVER() { + return detail::getCUDAHooks().hasCuSOLVER(); + } + static bool hasCuBLASLt() { + return detail::getCUDAHooks().hasCuBLASLt(); + } + static bool hasROCM() { + return detail::getCUDAHooks().hasROCM(); + } + static bool hasHIP() { + return detail::getHIPHooks().hasHIP(); + } + static bool hasMPS() { + return detail::getMPSHooks().hasMPS(); + } + static bool hasIPU() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::IPU); + } + static bool hasXLA() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::XLA); + } + static bool hasXPU() { + return detail::getXPUHooks().hasXPU(); + } + static bool hasLazy() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::Lazy); + } + static bool hasMAIA() { + return c10::impl::hasDeviceGuardImpl(c10::DeviceType::MAIA); + } + static bool hasHPU() { + return detail::getHPUHooks().hasHPU(); + } + + static const at::cuda::NVRTC& getNVRTC() { + return detail::getCUDAHooks().nvrtc(); + } + + static bool setFlushDenormal(bool on); + + // NB: This method is *purely* whether or not a user requested + // that CuDNN was enabled, it doesn't actually say anything about + // whether or not CuDNN is actually usable. Use cudnn_is_acceptable + // to test this instead + bool userEnabledCuDNN() const; + void setUserEnabledCuDNN(bool e); + bool userEnabledMkldnn() const; + void setUserEnabledMkldnn(bool e); + bool benchmarkCuDNN() const; + void setBenchmarkCuDNN(bool); + int benchmarkLimitCuDNN() const; + void setBenchmarkLimitCuDNN(int); + bool deterministicCuDNN() const; + void setDeterministicCuDNN(bool); + bool deterministicMkldnn() const; + void setDeterministicMkldnn(bool); + bool userEnabledNNPACK() const; + void setUserEnabledNNPACK(bool e); + + // Note [Disabling Fused SDP Kernels] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Flash and Memory Efficient SDP kernels are enabled by default. + // However, they can be disabled by setting + // at::globalContext().setUserEnabledFlashSDP(false) flag. + // This is useful for debugging purposes. For example, if you want to + // compare the performance of the flash SDP kernels with the unfused + // kernel, you can disable the flash SDP kernels. By disabling + // the math SDP kernel, you can force your code to use flash kernels. + // The math SDP kernel can be disabled by setting + // at::globalContext().setUserEnabledMathSDP(false) flag. + void setSDPPriorityOrder(const std::vector& order); + std::array sDPPriorityOrder(); + + void setSDPUseFlash(bool); + bool userEnabledFlashSDP() const; + + void setSDPUseMemEfficient(bool); + bool userEnabledMemEfficientSDP() const; + + void setSDPUseMath(bool); + bool userEnabledMathSDP() const; + + void setSDPUseCuDNN(bool); + bool userEnabledCuDNNSDP() const; + + void setAllowFP16BF16ReductionMathSDP(bool); + bool allowFP16BF16ReductionMathSDP() const; + + void setSDPUseOverrideable(bool); + bool userEnabledOverrideableSDP() const; + + at::LinalgBackend linalgPreferredBackend() const; + void setLinalgPreferredBackend(at::LinalgBackend); + + at::BlasBackend blasPreferredBackend(); + void setBlasPreferredBackend(at::BlasBackend); + + at::ROCmFABackend getROCmFAPreferredBackend() const; + void setROCmFAPreferredBackend(at::ROCmFABackend); + + // Note [Enabling Deterministic Operations] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Operations in PyTorch that normally act nondeterministically, but have an + // alternate deterministic implementation, should satisfy the following + // requirements: + // + // * Include this comment: "See Note [Enabling Deterministic Operations]" + // + // * Check the value of `at::globalContext().deterministicAlgorithms()` to + // toggle + // between nondeterministic and deterministic implementations. + // + // * Have an entry in the list of PyTorch operations that toggle between + // nondeterministic + // and deterministic implementations, in the docstring of + // `use_deterministic_algorithms()` in torch/__init__.py + // + // `example_func()` below shows an example of toggling between + // nondeterministic and deterministic implementations: + // + // void example_func() { + // // See Note [Enabling Deterministic Operations] + // if (at::globalContext().deterministicAlgorithms()) { + // example_func_deterministic(); + // } else { + // example_func_nondeterministic(); + // } + // } + + bool deterministicAlgorithms() const; + bool deterministicAlgorithmsWarnOnly() const; + void setDeterministicAlgorithms(bool, bool); + bool deterministicFillUninitializedMemory() const; + void setDeterministicFillUninitializedMemory(bool); + + // Note [Writing Nondeterministic Operations] + // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + // Operations in PyTorch that act nondeterministically and do not have an + // alternate deterministic implementation should satisfy the following + // requirements: + // + // * Include this comment: "See Note [Writing Nondeterministic Operations]" + // + // * Include a comment explaining why the operation is nondeterministic. + // + // * Throw an error when `Context::deterministicAlgorithms()` is true. Most + // of the time, this should be accomplished by calling + // `at::globalContext().alertNotDeterminstic()`. However, if the + // nondeterministic behavior is caused by the CuBLAS workspace + // configuration in CUDA >= 10.2, + // `at::globalContext().alertCuBLASConfigNotDeterministic()` should be + // called instead (in this case, a comment explaining why the operation is + // nondeterministic is not necessary). See below for details on these + // methods. + // + // * Have an entry in the list of nondeterministic PyTorch operations in the + // docstring of `use_deterministic_algorithms()` in torch/__init__.py + // + // * Have a test function in `test/test_torch.py` whose name begins with + // `test_nondeterministic_alert_`. Alternatively, if CuBLAS workspace + // configuration is the reason for nondeterminism, the operation should be + // included in the `test_cublas_config_nondeterministic_alert` test. Any new + // tests should ideally follow a pattern similar to the existing ones. + // + // `example_func()` below shows an example of the comments and error-throwing + // code for a nondeterministic operation: + // + // void example_func() { + // // See Note [Writing Nondeterministic Operations] + // // Nondeterministic because + // at::globalContext().alertNondeterministic("example_func"); + // ... + // } + + // Throws an error if `Context::deterministicAlgorithms()` is true + static void alertNotDeterministic(std::string_view const& caller); + + // Throws an error if `Context::deterministicAlgorithms()` is true, CUDA + // >= 10.2, and CUBLAS_WORKSPACE_CONFIG is not set to either ":16:8" or + // ":4096:8". For more details: + // https://docs.nvidia.com/cuda/cublas/index.html#results-reproducibility + void alertCuBLASConfigNotDeterministic() const; + + void setFloat32MatmulPrecision(const std::string& s); + bool allowTF32CuDNN() const; + void setAllowTF32CuDNN(bool); + bool allowTF32OneDNN() const; + void setAllowTF32OneDNN(bool); + bool allowTF32CuBLAS() const; + void setAllowTF32CuBLAS(bool); + Float32MatmulPrecision float32MatmulPrecision() const; + void setFloat32MatmulPrecision(Float32MatmulPrecision p); + bool allowFP16ReductionCuBLAS() const; + void setAllowFP16ReductionCuBLAS(bool); + bool allowBF16ReductionCuBLAS() const; + void setAllowBF16ReductionCuBLAS(bool); + bool allowFP16AccumulationCuBLAS() const; + void setAllowFP16AccumulationCuBLAS(bool); + + // Matmuls can use a so-called "persistent" kernel which launches one CUDA + // block for each SM on the GPU, and each block then iterates over multiple + // output tiles. This allows to use software pipelining to hide the begin/end + // latencies (e.g., epilogue), especially when only one tile fits per SM. + // However, if some SMs are busy (e.g., with a background NCCL kernel), the + // matmul's blocks will be scheduled in two waves and, in the absence of some + // smart load balancing, the kernel will take twice as long. This flag allows + // to make matmuls target only a subset of the SMs, so they can fully schedule + // even next to a comms kernel, and only be a few percent slower. + std::optional _SMCarveout_EXPERIMENTAL() const; + void _setSMCarveout_EXPERIMENTAL(std::optional); + + at::QEngine qEngine() const; + void setQEngine(at::QEngine e); + static const std::vector& supportedQEngines(); + static bool isXNNPACKAvailable(); + void setCheckSparseTensorInvariants(bool e); + bool checkSparseTensorInvariants() const; + // This method is used to release the original weight after pre-packing. + // It should be called once before loading/running the model. + // NB: By default it is set to true for mobile builds. + void setReleaseWeightsWhenPrepacking(bool e); + bool releaseWeightsWhenPrepacking() const; + + void setDisplayVmapFallbackWarnings(bool enabled); + bool areVmapFallbackWarningsEnabled() const; + + bool isDefaultMobileCPUAllocatorSet(); + void setDefaultMobileCPUAllocator(); + void unsetDefaultMobileCPUAllocator(); + bool allowFP16ReductionCPU() const; + void setAllowFP16ReductionCPU(bool); + + // Preserved for BC + void lazyInitCUDA() { + TORCH_WARN_DEPRECATION( + "lazyInitCUDA is deprecated. Please use lazyInitDevice(at::kCUDA) instead.") + lazyInitDevice(at::kCUDA); + } + void lazyInitHIP() { + TORCH_WARN_DEPRECATION( + "lazyInitHIP is deprecated. Please use lazyInitDevice(at::kHIP) instead.") + lazyInitDevice(at::kHIP); + } + void lazyInitXPU() { + TORCH_WARN_DEPRECATION( + "lazyInitXPU is deprecated. Please use lazyInitDevice(at::kXPU) instead.") + lazyInitDevice(at::kXPU); + } + void lazyInitMTIA() { + TORCH_WARN_DEPRECATION( + "lazyInitMTIA is deprecated. Please use lazyInitDevice(at::kMTIA) instead.") + lazyInitDevice(at::kMTIA); + } + void lazyInitPrivateUse1() { + TORCH_WARN_DEPRECATION( + "lazyInitPrivateUse1 is deprecated. Please use lazyInitDevice(at::kPrivateUse1) instead.") + lazyInitDevice(at::kPrivateUse1); + } + + private: + static bool checkCuBLASConfigDeterministic(); + std::array init_; + bool enabled_cudnn = true; + bool deterministic_cudnn = false; + bool deterministic_mkldnn = false; + bool _deterministic_algorithms = false; + bool _deterministic_algorithms_warn_only = false; + bool _deterministic_fill_uninitialized_memory = true; + std::array sdp_priority_order = { + at::SDPBackend::flash_attention, + at::SDPBackend::efficient_attention, + at::SDPBackend::math, + at::SDPBackend::cudnn_attention}; + bool enabled_flashSDP = true; + bool enabled_mem_efficientSDP = true; + bool enabled_mathSDP = true; + bool enabled_cudnnSDP = true; + bool enabled_overrideable = true; + bool allow_fp16_bf16_reduction_mathSDP = false; + bool benchmark_cudnn = false; + Float32MatmulPrecision float32_matmul_precision = + c10::utils::check_env("TORCH_ALLOW_TF32_CUBLAS_OVERRIDE") == true + ? at::Float32MatmulPrecision::HIGH + : at::Float32MatmulPrecision::HIGHEST; + int benchmark_limit_cudnn = 10; + bool allow_tf32_cudnn = true; + bool allow_fp16_reduction_cublas = true; + bool allow_bf16_reduction_cublas = true; + bool allow_fp16_accumulation_cublas = false; + std::optional sm_carveout = std::nullopt; + bool enabled_mkldnn = true; + bool allow_tf32_onednn = false; + bool enabled_nnpack = true; + at::LinalgBackend linalg_preferred_backend = + (c10::utils::check_env("TORCH_LINALG_PREFER_CUSOLVER") == true || + c10::utils::check_env("TORCH_LINALG_PREFER_HIPSOLVER") == true) // alias + ? at::LinalgBackend::Cusolver + : at::LinalgBackend::Default; + at::BlasBackend blas_preferred_backend = + (c10::utils::check_env("TORCH_BLAS_PREFER_CUBLASLT") == true || + c10::utils::check_env("TORCH_BLAS_PREFER_HIPBLASLT") == true) // alias + ? at::BlasBackend::Cublaslt + : at::BlasBackend::Default; + at::ROCmFABackend rocm_fa_preferred_backend = + c10::utils::check_env("TORCH_ROCM_FA_PREFER_CK") == true + ? at::ROCmFABackend::Ck + : at::ROCmFABackend::Default; +#ifdef C10_MOBILE + bool release_original_weights = true; +#else + bool release_original_weights = false; +#endif + bool display_vmap_fallback_warnings_ = false; + std::optional quantized_engine = std::nullopt; + bool enable_sparse_tensor_invariant_checks = false; + bool allow_fp16_reduction_cpu = false; + + Allocator* prev_allocator_ptr_{nullptr}; +}; + +TORCH_API Context& globalContext(); + +inline void init() { + globalContext(); +} + +TORCH_API Allocator* getCPUAllocator(); + +inline DeprecatedTypeProperties& getDeprecatedTypeProperties( + Backend p, + ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + p, s); +} + +inline DeprecatedTypeProperties& CPU(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::CPU, s); +} + +inline DeprecatedTypeProperties& CUDA(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::CUDA, s); +} + +inline DeprecatedTypeProperties& HIP(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::HIP, s); +} + +inline DeprecatedTypeProperties& MPS(ScalarType s) { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + Backend::MPS, s); +} + +inline bool hasCUDA() { + return globalContext().hasCUDA(); +} + +inline bool hasMTIA() { + return globalContext().hasMTIA(); +} + +inline bool hasHIP() { + return globalContext().hasHIP(); +} + +inline bool hasIPU() { + return globalContext().hasIPU(); +} + +inline bool hasXLA() { + return globalContext().hasXLA(); +} + +inline bool hasMPS() { + return globalContext().hasMPS(); +} + +inline bool hasMAIA() { + return globalContext().hasMAIA(); +} + +inline bool hasXPU() { + return globalContext().hasXPU(); +} + +inline bool hasHPU() { + return globalContext().hasHPU(); +} + +// Despite its name, this function returns the number of *CUDA* GPUs. +inline size_t getNumGPUs() { + // WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS + // FUNCTION. If you are interested in interrogating the number of + // devices for a specific device type, add that function to the + // relevant library (e.g., similar to at::cuda::device_count()) + if (hasCUDA() && hasHIP()) { + TORCH_CHECK( + false, + "Enabling both CUDA and HIP in ATen is not supported, as HIP masquerades " + "to be CUDA (e.g., when you say CUDA, on a HIP build of ATen, this actually " + "means HIP. Rebuild PyTorch with one or the other disabled."); + } else if (hasCUDA()) { + return detail::getCUDAHooks().deviceCount(); + } else if (hasHIP()) { + return detail::getHIPHooks().getNumGPUs(); + } else { + return 0; + } +} + +inline bool hasOpenMP() { + return globalContext().hasOpenMP(); +} + +inline bool hasMKL() { + return globalContext().hasMKL(); +} + +inline bool hasKleidiAI() { + return globalContext().hasKleidiAI(); +} + +inline bool hasLAPACK() { + return globalContext().hasLAPACK(); +} + +inline bool hasMAGMA() { + return globalContext().hasMAGMA(); +} + +inline bool hasMKLDNN() { + return globalContext().hasMKLDNN(); +} + +inline void manual_seed(uint64_t seed) { + { + auto gen = globalContext().defaultGenerator(c10::DeviceType::CPU); + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + gen.set_current_seed(seed); + } + + const auto opt_device_type = at::getAccelerator(); + if (!opt_device_type.has_value()) { + return; + } + const auto num_gpus = globalContext() + .getAcceleratorHooksInterface(opt_device_type) + .deviceCount(); + for (const auto i : c10::irange(num_gpus)) { + auto gen = globalContext().defaultGenerator( + Device(opt_device_type.value(), static_cast(i))); + { + // See Note [Acquire lock when using random generators] + std::lock_guard lock(gen.mutex()); + gen.set_current_seed(seed); + } + } +} + +// When the global flag `allow_tf32` is set to true, cuBLAS handles are +// automatically configured to use math mode CUBLAS_TF32_TENSOR_OP_MATH. +// For some operators, such as addmv, TF32 offers no performance improvement +// but causes precision loss. To help this case, this class implements +// a RAII guard that can be used to quickly disable TF32 within its scope. +// +// Usage: +// NoTF32Guard disable_tf32; +struct TORCH_API NoTF32Guard { + NoTF32Guard(); + NoTF32Guard(NoTF32Guard&& other) = delete; + NoTF32Guard(const NoTF32Guard&) = delete; + NoTF32Guard& operator=(const NoTF32Guard&) = delete; + NoTF32Guard& operator=(NoTF32Guard&&) = delete; + ~NoTF32Guard(); + static bool should_disable_tf32(); + + private: + bool changed = false; +}; + +struct TORCH_API ROCmBackwardPassGuard { + ROCmBackwardPassGuard(); + ROCmBackwardPassGuard(ROCmBackwardPassGuard&& other) = delete; + ROCmBackwardPassGuard(const ROCmBackwardPassGuard&) = delete; + ROCmBackwardPassGuard& operator=(const ROCmBackwardPassGuard&) = delete; + ROCmBackwardPassGuard& operator=(ROCmBackwardPassGuard&&) = delete; + ~ROCmBackwardPassGuard(); + static bool is_backward_pass(); +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/DLConvertor.h b/phivenv/Lib/site-packages/torch/include/ATen/DLConvertor.h new file mode 100644 index 0000000000000000000000000000000000000000..481d30711d432a1697e71ad37a1aa6fef1668de1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/DLConvertor.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include +#include + +// this convertor will: +// 1) take a Tensor object and wrap it in the DLPack tensor +// 2) take a dlpack tensor and convert it to the ATen Tensor + +namespace at { + +TORCH_API ScalarType toScalarType(const DLDataType& dtype); +TORCH_API DLManagedTensor* toDLPack(const Tensor& src); +TORCH_API Tensor fromDLPack(DLManagedTensor* src); +TORCH_API Tensor +fromDLPack(DLManagedTensor* src, std::function deleter); +TORCH_API DLDataType getDLDataType(const Tensor& t); +TORCH_API DLDevice getDLContext(const Tensor& tensor, const int64_t& device_id); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Device.h b/phivenv/Lib/site-packages/torch/include/ATen/Device.h new file mode 100644 index 0000000000000000000000000000000000000000..77626cce2465850485e137b148845ee38b9ebb4d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Device.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h b/phivenv/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h new file mode 100644 index 0000000000000000000000000000000000000000..deed2c93fe3f6c0dc845dcaac72d8665fe40e261 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/DeviceAccelerator.h @@ -0,0 +1,81 @@ +#pragma once + +#include +#include + +#include +#include + +namespace at::accelerator { + +// Note [Accelerator Concept] +// This file defines the top level Accelerator concept for PyTorch. +// A device is an accelerator per the definition here if: +// - It is mutually exclusive with all other accelerators +// - It performs asynchronous compute via a Stream/Event system +// - It provides a set of common APIs as defined by AcceleratorHooksInterface +// +// As of today, accelerator devices are (in no particular order): +// CUDA, MTIA, XPU, HIP, MPS, PrivateUse1 + +// Ensures that only one accelerator is available (at +// compile time if possible) and return it. +// When checked is true, the returned optional always has a value. +TORCH_API std::optional getAccelerator(bool checked = false); + +// Check if the given device type is an accelerator. +TORCH_API bool isAccelerator(c10::DeviceType device_type); + +// Check if the given device type is an accelerator, not the excluded ones. +template < + typename... T, + typename = std::enable_if_t<(std::is_same_v && ...)>> +TORCH_API inline bool isAcceleratorExcluded( + c10::DeviceType device_type, + c10::DeviceType first_excluded, + T... rest_excluded) { + if constexpr (sizeof...(rest_excluded) > 0) { + return device_type != first_excluded && + isAcceleratorExcluded(device_type, rest_excluded...); + } else { + return device_type != first_excluded && isAccelerator(device_type); + } +} + +// Return the number of the device available. Note that this is *REQUIRED* to +// not raise any exception. +TORCH_API c10::DeviceIndex deviceCount(); + +// Set the current device index to the given device index. +TORCH_API void setDeviceIndex(c10::DeviceIndex device_index); + +// Get the current device index. +TORCH_API c10::DeviceIndex getDeviceIndex(); + +// Set the current stream to a given stream. Note that this API doesn't change +// the current device index. +TORCH_API void setCurrentStream(c10::Stream stream); + +// Get the current stream of the given device index. +TORCH_API c10::Stream getCurrentStream(c10::DeviceIndex device_index); + +// Wait (by blocking the calling thread) until all the work previously enqueued +// on the given device index has been completed. +TORCH_API void synchronizeDevice(c10::DeviceIndex device_index); + +// Set the current device index to the given device_index and return the +// original device index that was active before the change. +TORCH_API c10::DeviceIndex exchangeDevice(c10::DeviceIndex device_index); + +// Set the current device index to the given device_index. Avoid creating a new +// context if the context for device_index is not initialized. Return the +// original device index that was active before the change. +TORCH_API c10::DeviceIndex maybeExchangeDevice(c10::DeviceIndex device_index); + +} // namespace at::accelerator + +namespace at { +// Keep BC only +using at::accelerator::getAccelerator; +using at::accelerator::isAccelerator; +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/DeviceGuard.h b/phivenv/Lib/site-packages/torch/include/ATen/DeviceGuard.h new file mode 100644 index 0000000000000000000000000000000000000000..234e00b06c4091c632c5f4637915c8919f276a7c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/DeviceGuard.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include +#include // TensorList whyyyyy + +namespace at { + +// Are you here because you're wondering why DeviceGuard(tensor) no +// longer works? For code organization reasons, we have temporarily(?) +// removed this constructor from DeviceGuard. The new way to +// spell it is: +// +// OptionalDeviceGuard guard(device_of(tensor)); + +/// Return the Device of a Tensor, if the Tensor is defined. +inline std::optional device_of(const Tensor& t) { + if (t.defined()) { + return t.device(); + } else { + return std::nullopt; + } +} + +inline std::optional device_of(const std::optional& t) { + return t.has_value() ? device_of(t.value()) : std::nullopt; +} + +/// Return the Device of a TensorList, if the list is non-empty and +/// the first Tensor is defined. (This function implicitly assumes +/// that all tensors in the list have the same device.) +inline std::optional device_of(ITensorListRef t) { + if (!t.empty()) { + return device_of(t.front()); + } else { + return std::nullopt; + } +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/DimVector.h b/phivenv/Lib/site-packages/torch/include/ATen/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..0a854a378782824f756ff054d39965d259054351 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/DimVector.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Dimname.h b/phivenv/Lib/site-packages/torch/include/ATen/Dimname.h new file mode 100644 index 0000000000000000000000000000000000000000..9a93a8e38f8f25d42131a320ecf54a55c59bb481 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Dimname.h @@ -0,0 +1 @@ +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Dispatch.h b/phivenv/Lib/site-packages/torch/include/ATen/Dispatch.h new file mode 100644 index 0000000000000000000000000000000000000000..b07293542707cd69828c21757de1ae601375fa88 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Dispatch.h @@ -0,0 +1,807 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#ifdef __CUDACC__ +#include // For CUDA_VERSION +#endif + +#ifdef TEMPLATE_SELECTIVE_BUILD +#include +#else +namespace at { +/** + * The method should_include_kernel_dtype() returns true/false + * based on whether the switching code for a specific dtype should be + * included based on build time constants generated from tracing model + * execution. This method will be implemented via code-generation and + * included in this file when code-gen is ready. + */ +inline constexpr bool should_include_kernel_dtype( + const char* /*kernel_tag_str*/, + at::ScalarType /*scalar_type*/ +) { + return true; +} +} // namespace at +#endif + +/** + * In the Facebook internal build (using BUCK), this macro is enabled by + * passing in -c pt.enable_record_kernel_dtype=1 when building the tracer + * binary. + */ +#if defined ENABLE_RECORD_KERNEL_FUNCTION_DTYPE +namespace at::detail { +TORCH_API void record_kernel_function_dtype(std::string name); +} // namespace at::detail + +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) \ + at::detail::record_kernel_function_dtype( \ + std::string(NAME) + "$" + toString(enum_type)); +#else +#define RECORD_KERNEL_FUNCTION_DTYPE(NAME, enum_type) +#endif + +#define AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type) \ + do { \ + if constexpr (!at::should_include_kernel_dtype( \ + at_dispatch_name, enum_type)) { \ + TORCH_CHECK( \ + false, \ + "dtype '", \ + toString(enum_type), \ + "' not selected for kernel tag ", \ + at_dispatch_name); \ + } \ + } while (0) + +#define AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, HINT, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using HINT [[maybe_unused]] = c10::impl::ScalarTypeToCPPTypeT; \ + return __VA_ARGS__(); \ + } + +#define AT_DISPATCH_CASE(enum_type, ...) \ + AT_PRIVATE_CASE_TYPE_USING_HINT(enum_type, scalar_t, __VA_ARGS__) + +#define AT_DISPATCH_CASE_QINT(enum_type, scalar_type, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \ + [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \ + [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \ + return __VA_ARGS__(); \ + } + +#define AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + enum_type, scalar_type, bitwidth, qmin, qmax, ...) \ + case enum_type: { \ + AT_PRIVATE_CHECK_SELECTIVE_BUILD(enum_type); \ + using scalar_t = scalar_type; \ + using underlying_t [[maybe_unused]] = typename scalar_t::underlying; \ + [[maybe_unused]] const auto& SCALAR_TYPE = enum_type; \ + [[maybe_unused]] const auto& UNDERLYING_TYPE = toUnderlying(enum_type); \ + [[maybe_unused]] int bit_width = bitwidth; \ + [[maybe_unused]] int64_t quant_min = qmin; \ + [[maybe_unused]] int64_t quant_max = qmax; \ + return __VA_ARGS__(); \ + } + +namespace detail { + +inline at::ScalarType scalar_type(at::ScalarType s) { + return s; +} + +} // namespace detail + +// The AT_DISPATCH_* family of macros provides the ability to +// conveniently generate specializations of a kernel over all of the +// dtypes we care about in PyTorch. We call it "dispatch" because +// we are "dispatching" to the correct, dtype-specific kernel. +// +// A standard usage looks like: +// +// AT_DISPATCH_ALL_TYPES(self.scalar_type(), "op_name", [&] { +// // Your code here, with 'scalar_t' now defined to +// // be the dtype in question +// }); +// +// There are many variations of this macro, so it's important to +// understand exactly /which/ dtypes you want to get instantiated, as +// well as what the "default" set is. +// +// The default set of dtypes that are instantiated (e.g., by +// AT_DISPATCH_ALL_TYPES) are floating point types (float, double), +// and integral types (int32_t, int64_t, int16_t, int8_t, uint8_t), +// but NOT booleans (bool), half-precision floats (Half) or +// complex number (c10::complex, c10::complex). +// This "cut" is somewhat historical (the default types are the +// ones that TH historically supported), but it also reflects the +// fact that the non-default types are "poorly" behaved (booleans +// are NOT integers mod 2, half precision operations ~essentially +// don't exist on CPU, complex numbers are an experimental application). +// +// Here are the questions you should generally ask to decide which +// dispatch you want: +// +// 1. Is this an integral or floating point specific operation? +// (If so, you'll want one of the FLOATING or INTEGRAL macros.) +// +// 2. Should half be supported? (If you're on CPU, the answer is almost +// definitely no. If you do want support, use one of the AND_HALF +// macros) +// +// Much rarer situations: +// +// 3. Should bool be supported? (You often have to write your kernel +// differently if arithmetic operations are involved.) If so, +// Use AT_DISPATCH_ALL_TYPES_AND along with ScalarType::Bool +// +// 4. Should complex be supported? The answer is almost always no, +// unless you are working on "generic" code that should work on +// all dtypes. +// +// Parameters: +// ----------- +// +// 1. The NAME argument is a "tag" that is used to trace and then +// conditionally compile fragments of the case statements such +// that the kernel functions are specialized only for the dtypes +// that are needed. The NAME parameter *must* be a build time +// const char* (can't be std::string, etc...) +// +// Please ensure that the NAME is unique for every implementation +// or you run the risk of over-including code for the kernel +// functions. There is no risk of missing out on any code, so +// it's mostly a risk of a Type-2 error, and not a Type-1 error. +// +// Switch-like syntax: +// ------------------- +// There is also a switch-case like syntax which is useful if a kernel +// needs to be specialized for particular scalar types +// +// AT_DISPATCH_SWITCH(self.scalar_type(), "op_name", +// AT_DISPATCH_CASE_INTEGRAL_TYPES([&] { +// op_integral(iter); +// }) +// AT_DISPATCH_CASE_FLOATING_TYPES([&] { +// op_floating(iter); +// }) +// AT_DISPATCH_CASE(kBool, [&] { +// op_bool(iter); +// }) +// ); +// +// For each AT_DISPATCH_FOO macro, there is a corresponding +// AT_DISPATCH_CASE_FOO macro which can be used inside of an +// AT_DISPATCH_SWITCH block. + +// NB: the the_type variable is not used, but we have kept it for +// backwards compatibility. It's probably not used by anyone though; +// but we're just being safe (and it doesn't hurt.) Note we must +// use it to shut up warnings about unused store. + +#define AT_DISPATCH_SWITCH(TYPE, NAME, ...) \ + [&] { \ + const auto& the_type = TYPE; \ + constexpr const char* at_dispatch_name = NAME; \ + /* don't use TYPE again in case it is an expensive or side-effect op */ \ + at::ScalarType _st = ::detail::scalar_type(the_type); \ + RECORD_KERNEL_FUNCTION_DTYPE(at_dispatch_name, _st); \ + switch (_st) { \ + __VA_ARGS__ \ + default: \ + TORCH_CHECK_NOT_IMPLEMENTED( \ + false, \ + '"', \ + at_dispatch_name, \ + "\" not implemented for '", \ + toString(_st), \ + "'"); \ + } \ + }() + +#define AT_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(...) \ + AT_DISPATCH_CASE(at::ScalarType::Double, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND_HALF(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_FLOATING_TYPES_AND_HALF(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define AT_DISPATCH_REDUCED_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) + +#define AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) + +#define AT_DISPATCH_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_COMPLEX_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::ComplexDouble, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::ComplexFloat, __VA_ARGS__) + +#define AT_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_COMPLEX_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_COMPLEX_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(...) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1( \ + SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND1( \ + SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + ...) \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) + +#define AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_FLOATING_AND_COMPLEX_TYPES_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) + +#define AT_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_INTEGRAL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_INTEGRAL_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES(...) \ + AT_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_TYPES(...) \ + AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQInt32, at::qint32, __VA_ARGS__) + +#define AT_DISPATCH_QINT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_QINT_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_QINT_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_QINT_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_BYTE_TYPES(...) \ + AT_DISPATCH_CASE_QINT(at::kQInt8, at::qint8, __VA_ARGS__) \ + AT_DISPATCH_CASE_QINT(at::kQUInt8, at::quint8, __VA_ARGS__) + +#define AT_DISPATCH_QINT_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_QINT_BYTE_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(...) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt8, at::qint8, CHAR_BIT, SCHAR_MIN, SCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt8, at::quint8, CHAR_BIT, 0, UCHAR_MAX, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQInt32, \ + at::qint32, \ + CHAR_BIT * sizeof(int), \ + INT_MIN, \ + INT_MAX, \ + __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt4x2, at::quint4x2, 4, 0, 15, __VA_ARGS__) \ + AT_QINT_SUB_BYTE_PRIVATE_CASE_TYPE( \ + at::kQUInt2x4, at::quint2x4, 2, 0, 3, __VA_ARGS__) + +#define AT_DISPATCH_QINT_AND_SUB_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_QINT_AND_SUB_BYTE_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE_COMPLEX_TYPES(__VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, NAME, AT_DISPATCH_CASE_ALL_TYPES_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND(SCALARTYPE, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND2(SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND2( \ + SCALARTYPE1, SCALARTYPE2, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_ALL_TYPES(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND3( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \ + SCALARTYPE1, SCALARTYPE2, SCALARTYPE3, SCALARTYPE4, SCALARTYPE5, ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND5( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND6( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND7( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + SCALARTYPE8, \ + ...) \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX(__VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE1, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE2, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE3, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE4, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE5, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE6, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE7, __VA_ARGS__) \ + AT_DISPATCH_CASE(SCALARTYPE8, __VA_ARGS__) + +#define AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND8( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + SCALARTYPE8, \ + TYPE, \ + NAME, \ + ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND8( \ + SCALARTYPE1, \ + SCALARTYPE2, \ + SCALARTYPE3, \ + SCALARTYPE4, \ + SCALARTYPE5, \ + SCALARTYPE6, \ + SCALARTYPE7, \ + SCALARTYPE8, \ + __VA_ARGS__)) + +#define AT_DISPATCH_CASE_BIT_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Bits1x8, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits2x4, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits4x2, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits8, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Bits16, __VA_ARGS__) + +#define AT_DISPATCH_BIT_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_DISPATCH_CASE_BIT_TYPES(__VA_ARGS__)) + +#define AT_DISPATCH_INDEX_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Int, index_t, __VA_ARGS__) \ + AT_PRIVATE_CASE_TYPE_USING_HINT( \ + at::ScalarType::Long, index_t, __VA_ARGS__)) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Dispatch_v2.h b/phivenv/Lib/site-packages/torch/include/ATen/Dispatch_v2.h new file mode 100644 index 0000000000000000000000000000000000000000..65363f26ad7d81e8d89c5fc0d7167342e90decd2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Dispatch_v2.h @@ -0,0 +1,202 @@ +#include + +// This is a new implementation of the AT_DISPATCH macro family from +// ATen/Dispatch.h +// +// The intended usage is: +// +// ScalarType scalar_type; +// +// AT_DISPATCH_V2( +// scalar_type, +// "debug string", +// AT_WRAP([&] { +// ... code to specialize with scalar_t ... +// }), +// kHalf, +// AT_EXPAND(AT_ALL_TYPES), +// ... as many types arguments as needed ... +// ) +// +// For example, given an old style: +// +// AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2( +// kComplexHalf, +// kHalf, +// self.scalar_type(), +// "_local_scalar_dense_cpu", +// [&] { +// scalar_t value = *self.data_ptr(); +// r = Scalar(value); +// } +// ) +// +// You now write: +// +// AT_DISPATCH_V2( +// self.scalar_type(), +// "_local_scalar_dense_cpu", +// AT_WRAP([&] { +// scalar_t value = *self.data_ptr(); +// r = Scalar(value); +// }), +// AT_EXPAND(AT_ALL_TYPES), +// AT_EXPAND(AT_COMPLEX_TYPES), +// kComplexHalf, +// kHalf, +// ) +// +// Notably, it sports the following improvements: +// +// - It is not necessary to specify the arity (e.g., +// AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND{2,3,4,...}) +// when using the macro +// +// - It is not necessary to specify each dtype individually; if +// there is a set of related dtypes and you want to dispatch +// over all of them, you can simply say, e.g., AT_EXPAND(AT_INTEGRAL_TYPES) +// in your argument list. +// +// However, you must remember to wrap the payload body in AT_WRAP, or commas +// inside your lambda will be improperly handled. Furthermore, if you more +// entries to ScalarType than can be supported by this macro, it will fail +// with an obscure error (due to attempting to concatenate AT_AP with +// something that is not a number). +// +// The implementation strategy is to use the count arguments trick +// (e.g., as described in https://stackoverflow.com/a/2124385/23845) +// to discover how many dtypes have been passed, and then dispatch to a +// hand-written macro for each arity that applies as many DISPATCH_CASE as +// necessary. The hand-written macros can be regenerated for other arities +// with the script below. +// +// There is some delicacy in the implementation in controlling when +// macro expansion occurs, mediated with AT_EXPAND and AT_GUARD. I mostly +// relied on GPT4 to help me get it right. + +// Public API macros + +// See documentation above +#define AT_DISPATCH_V2(TYPE, NAME, BODY, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, AT_AP_VAR(AT_WRAP(BODY), TYPE, __VA_ARGS__)) + +// This macro lets you pass an arbitrary expression that may contain internal +// commas to another macro without having the commas causing the expression +// to be interpreted as being multiple arguments +#define AT_WRAP(...) __VA_ARGS__ + +#define AT_FLOAT8_TYPES \ + c10::kFloat8_e5m2, c10::kFloat8_e5m2fnuz, c10::kFloat8_e4m3fn, \ + c10::kFloat8_e4m3fnuz, c10::kFloat8_e8m0fnu + +#define AT_INTEGRAL_TYPES \ + c10::kByte, c10::kChar, c10::kInt, c10::kLong, c10::kShort +#define AT_FLOATING_TYPES c10::kDouble, c10::kFloat +#define AT_BAREBONES_UNSIGNED_TYPES c10::kUInt16, c10::kUInt32, c10::kUInt64 +#define AT_INTEGRAL_TYPES_V2 \ + AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES) +#define AT_COMPLEX_TYPES c10::kComplexDouble, c10::kComplexFloat +#define AT_QINT_TYPES c10::kQInt8, c10::kQUInt8, c10::kQInt32 +// NB: not *actually* all types +#define AT_ALL_TYPES AT_EXPAND(AT_INTEGRAL_TYPES), AT_EXPAND(AT_FLOATING_TYPES) +#define AT_ALL_TYPES_AND_COMPLEX \ + AT_EXPAND(AT_ALL_TYPES), AT_EXPAND(AT_COMPLEX_TYPES) + +// Helper macros + +#define AT_AP_VAR(N, T, ...) \ + AT_EXPAND(AT_CONCAT(AT_AP, AT_NUM_ARGS(__VA_ARGS__))(AT_WRAP(N), __VA_ARGS__)) +#define AT_CONCAT(a, b) AT_CONCAT_AUX(a, b) +#define AT_CONCAT_AUX(a, b) a##b +#define AT_EXPAND(X) X + +// Ensure we never have too many scalar types for the expansion here to +// support. To bump this, you must regenerate the macros below. +static_assert(static_cast(c10::ScalarType::NumOptions) < 60); + +// Python code to regenerate generate code below: +#if 0 + +num_args = 60 + +nums = ', '.join(str(i) for i in reversed(range(num_args+1))) +args = ', '.join(f'_{i}' for i in range(1, num_args+1)) + +print(f'#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, {nums}))') +print(f'#define AT_NUM_ARGS_AUX({args}, N, ...) N') + +for i in range(1, num_args+1): + args = ', '.join(f'_{i}' for i in range(1, i+1)) + cases = ' '.join([f'AT_DISPATCH_CASE(_{j}, N)' for j in range(1, i+1)]) + print(f'#define AT_AP{i}(N, {args}) {cases}') + +#endif + +// Begin generated code +// clang-format off + +#define AT_NUM_ARGS(...) AT_EXPAND(AT_NUM_ARGS_AUX(__VA_ARGS__, 60, 59, 58, 57, 56, 55, 54, 53, 52, 51, 50, 49, 48, 47, 46, 45, 44, 43, 42, 41, 40, 39, 38, 37, 36, 35, 34, 33, 32, 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0)) +#define AT_NUM_ARGS_AUX(_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60, N, ...) N +#define AT_AP1(N, _1) AT_DISPATCH_CASE(_1, N) +#define AT_AP2(N, _1, _2) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) +#define AT_AP3(N, _1, _2, _3) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) +#define AT_AP4(N, _1, _2, _3, _4) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) +#define AT_AP5(N, _1, _2, _3, _4, _5) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) +#define AT_AP6(N, _1, _2, _3, _4, _5, _6) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) +#define AT_AP7(N, _1, _2, _3, _4, _5, _6, _7) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) +#define AT_AP8(N, _1, _2, _3, _4, _5, _6, _7, _8) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) +#define AT_AP9(N, _1, _2, _3, _4, _5, _6, _7, _8, _9) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) +#define AT_AP10(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) +#define AT_AP11(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) +#define AT_AP12(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) +#define AT_AP13(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) +#define AT_AP14(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) +#define AT_AP15(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) +#define AT_AP16(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) +#define AT_AP17(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) +#define AT_AP18(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) +#define AT_AP19(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) +#define AT_AP20(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) +#define AT_AP21(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) +#define AT_AP22(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) +#define AT_AP23(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) +#define AT_AP24(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) +#define AT_AP25(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) +#define AT_AP26(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) +#define AT_AP27(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) +#define AT_AP28(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) +#define AT_AP29(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) +#define AT_AP30(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) +#define AT_AP31(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) +#define AT_AP32(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) +#define AT_AP33(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) +#define AT_AP34(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) +#define AT_AP35(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) +#define AT_AP36(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) +#define AT_AP37(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) +#define AT_AP38(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) +#define AT_AP39(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) +#define AT_AP40(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) +#define AT_AP41(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) +#define AT_AP42(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) +#define AT_AP43(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) +#define AT_AP44(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) +#define AT_AP45(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) +#define AT_AP46(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) +#define AT_AP47(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) +#define AT_AP48(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) +#define AT_AP49(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) +#define AT_AP50(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) +#define AT_AP51(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) +#define AT_AP52(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) +#define AT_AP53(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) +#define AT_AP54(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) +#define AT_AP55(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) +#define AT_AP56(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) +#define AT_AP57(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) +#define AT_AP58(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) +#define AT_AP59(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) +#define AT_AP60(N, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11, _12, _13, _14, _15, _16, _17, _18, _19, _20, _21, _22, _23, _24, _25, _26, _27, _28, _29, _30, _31, _32, _33, _34, _35, _36, _37, _38, _39, _40, _41, _42, _43, _44, _45, _46, _47, _48, _49, _50, _51, _52, _53, _54, _55, _56, _57, _58, _59, _60) AT_DISPATCH_CASE(_1, N) AT_DISPATCH_CASE(_2, N) AT_DISPATCH_CASE(_3, N) AT_DISPATCH_CASE(_4, N) AT_DISPATCH_CASE(_5, N) AT_DISPATCH_CASE(_6, N) AT_DISPATCH_CASE(_7, N) AT_DISPATCH_CASE(_8, N) AT_DISPATCH_CASE(_9, N) AT_DISPATCH_CASE(_10, N) AT_DISPATCH_CASE(_11, N) AT_DISPATCH_CASE(_12, N) AT_DISPATCH_CASE(_13, N) AT_DISPATCH_CASE(_14, N) AT_DISPATCH_CASE(_15, N) AT_DISPATCH_CASE(_16, N) AT_DISPATCH_CASE(_17, N) AT_DISPATCH_CASE(_18, N) AT_DISPATCH_CASE(_19, N) AT_DISPATCH_CASE(_20, N) AT_DISPATCH_CASE(_21, N) AT_DISPATCH_CASE(_22, N) AT_DISPATCH_CASE(_23, N) AT_DISPATCH_CASE(_24, N) AT_DISPATCH_CASE(_25, N) AT_DISPATCH_CASE(_26, N) AT_DISPATCH_CASE(_27, N) AT_DISPATCH_CASE(_28, N) AT_DISPATCH_CASE(_29, N) AT_DISPATCH_CASE(_30, N) AT_DISPATCH_CASE(_31, N) AT_DISPATCH_CASE(_32, N) AT_DISPATCH_CASE(_33, N) AT_DISPATCH_CASE(_34, N) AT_DISPATCH_CASE(_35, N) AT_DISPATCH_CASE(_36, N) AT_DISPATCH_CASE(_37, N) AT_DISPATCH_CASE(_38, N) AT_DISPATCH_CASE(_39, N) AT_DISPATCH_CASE(_40, N) AT_DISPATCH_CASE(_41, N) AT_DISPATCH_CASE(_42, N) AT_DISPATCH_CASE(_43, N) AT_DISPATCH_CASE(_44, N) AT_DISPATCH_CASE(_45, N) AT_DISPATCH_CASE(_46, N) AT_DISPATCH_CASE(_47, N) AT_DISPATCH_CASE(_48, N) AT_DISPATCH_CASE(_49, N) AT_DISPATCH_CASE(_50, N) AT_DISPATCH_CASE(_51, N) AT_DISPATCH_CASE(_52, N) AT_DISPATCH_CASE(_53, N) AT_DISPATCH_CASE(_54, N) AT_DISPATCH_CASE(_55, N) AT_DISPATCH_CASE(_56, N) AT_DISPATCH_CASE(_57, N) AT_DISPATCH_CASE(_58, N) AT_DISPATCH_CASE(_59, N) AT_DISPATCH_CASE(_60, N) + +// End generated code +// clang-format on diff --git a/phivenv/Lib/site-packages/torch/include/ATen/DynamicLibrary.h b/phivenv/Lib/site-packages/torch/include/ATen/DynamicLibrary.h new file mode 100644 index 0000000000000000000000000000000000000000..8529b4025ab01871f6b0e9da7074e25e2d617adc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/DynamicLibrary.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include +#include + +namespace c10 { + +class DynamicLibraryError : public Error { + using Error::Error; +}; + +} // namespace c10 + +namespace at { + +struct DynamicLibrary { + AT_DISALLOW_COPY_AND_ASSIGN(DynamicLibrary); + DynamicLibrary(DynamicLibrary&& other) = delete; + DynamicLibrary& operator=(DynamicLibrary&&) = delete; + + TORCH_API DynamicLibrary( + const char* name, + const char* alt_name = nullptr, + bool leak_handle = false); + + TORCH_API void* sym(const char* name); + + TORCH_API ~DynamicLibrary(); + + private: + bool leak_handle; + void* handle = nullptr; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/EmptyTensor.h b/phivenv/Lib/site-packages/torch/include/ATen/EmptyTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..b4b27acedf2229a6e8ff7b21fd699ff998ae32d0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/EmptyTensor.h @@ -0,0 +1,166 @@ +#pragma once +#include + +namespace at::detail { + +inline void check_size_nonnegative(ArrayRef size) { + for (const auto& x : size) { + TORCH_CHECK( + x >= 0, + "Trying to create tensor with negative dimension ", + x, + ": ", + size); + } +} + +inline void check_size_nonnegative(ArrayRef size) { + for (const auto& x : size) { + TORCH_CHECK( + x.expect_size(__FILE__, __LINE__), + "Trying to create tensor with negative dimension ", + x, + ": ", + size); + } +} + +TORCH_API size_t computeStorageNbytesContiguous( + IntArrayRef sizes, + size_t itemsize, + size_t storage_offset = 0); +TORCH_API SymInt computeStorageNbytesContiguous( + SymIntArrayRef sizes, + const SymInt& itemsize, + const SymInt& storage_offset = 0); +TORCH_API size_t computeStorageNbytes( + IntArrayRef sizes, + IntArrayRef strides, + size_t itemsize, + size_t storage_offset = 0); +TORCH_API SymInt computeStorageNbytes( + SymIntArrayRef sizes, + SymIntArrayRef strides, + const SymInt& itemsize, + const SymInt& storage_offset = 0); + +TORCH_API TensorBase empty_generic( + IntArrayRef size, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_generic_symint( + SymIntArrayRef size, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_strided_generic( + IntArrayRef size, + IntArrayRef stride, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type); + +TORCH_API TensorBase empty_strided_symint_generic( + SymIntArrayRef size, + SymIntArrayRef stride, + c10::Allocator* allocator, + c10::DispatchKeySet ks, + ScalarType scalar_type); + +TORCH_API TensorBase empty_cpu( + IntArrayRef size, + ScalarType dtype, + bool pin_memory = false, + std::optional memory_format_opt = std::nullopt); + +TORCH_API TensorBase empty_cpu( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_cpu(IntArrayRef size, const TensorOptions& options); + +TORCH_API TensorBase empty_strided_cpu( + IntArrayRef size, + IntArrayRef stride, + ScalarType dtype, + bool pin_memory = false); + +TORCH_API TensorBase empty_strided_cpu( + IntArrayRef size, + IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TORCH_API TensorBase empty_strided_cpu( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions& options); + +TORCH_API TensorBase empty_meta( + IntArrayRef size, + ScalarType dtype, + std::optional memory_format_opt = std::nullopt); + +TORCH_API TensorBase empty_meta( + IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_symint_meta( + SymIntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +TORCH_API TensorBase empty_meta(IntArrayRef size, const TensorOptions& options); + +TORCH_API TensorBase +empty_strided_meta(IntArrayRef size, IntArrayRef stride, ScalarType dtype); + +TORCH_API TensorBase empty_strided_meta( + IntArrayRef size, + IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +TORCH_API TensorBase empty_strided_meta( + IntArrayRef size, + IntArrayRef stride, + const TensorOptions& options); + +TORCH_API TensorBase empty_strided_symint_meta( + SymIntArrayRef size, + SymIntArrayRef stride, + ScalarType dtype); + +TORCH_API TensorBase empty_strided_symint_meta( + SymIntArrayRef size, + SymIntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt); + +TORCH_API TensorBase empty_strided_symint_meta( + SymIntArrayRef size, + SymIntArrayRef stride, + const TensorOptions& options); + +} // namespace at::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ExpandBase.h b/phivenv/Lib/site-packages/torch/include/ATen/ExpandBase.h new file mode 100644 index 0000000000000000000000000000000000000000..d59a2714455873cf776242bd04157130911c8b28 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ExpandBase.h @@ -0,0 +1,30 @@ +#include + +// Broadcasting utilities for working with TensorBase +namespace at { +namespace internal { +TORCH_API TensorBase expand_slow_path(const TensorBase& self, IntArrayRef size); +} // namespace internal + +inline c10::MaybeOwned expand_size( + const TensorBase& self, + IntArrayRef size) { + if (size.equals(self.sizes())) { + return c10::MaybeOwned::borrowed(self); + } + return c10::MaybeOwned::owned( + at::internal::expand_slow_path(self, size)); +} +c10::MaybeOwned expand_size(TensorBase&& self, IntArrayRef size) = + delete; + +inline c10::MaybeOwned expand_inplace( + const TensorBase& tensor, + const TensorBase& to_expand) { + return expand_size(to_expand, tensor.sizes()); +} +c10::MaybeOwned expand_inplace( + const TensorBase& tensor, + TensorBase&& to_expand) = delete; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ExpandUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/ExpandUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..d12731f110e563a8d55b5ae7cdeeb0fceeed62d2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ExpandUtils.h @@ -0,0 +1,535 @@ +#pragma once + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace at { + +TORCH_API std::vector infer_size(IntArrayRef a, IntArrayRef b); +TORCH_API std::vector infer_size_symint( + SymIntArrayRef a, + SymIntArrayRef b); +TORCH_API DimVector infer_size_dimvector(IntArrayRef a, IntArrayRef b); +TORCH_API SymDimVector +infer_size_symdimvector(SymIntArrayRef a, SymIntArrayRef b); + +// Named type instead of a pair/tuple so that we can be sure to +// construct the vectors in place and get NRVO. +template +struct InferExpandGeometryResult { + Container sizes; + Container strides; + explicit InferExpandGeometryResult(size_t ndim) + : sizes(ndim), strides(ndim) {} + explicit InferExpandGeometryResult(IntArrayRef sizes_, size_t ndim) + : sizes(sizes_.begin(), sizes_.end()), strides(ndim) {} +}; + +TORCH_API std::tuple, std::vector> +inferExpandGeometry( + IntArrayRef tensor_sizes, + IntArrayRef tensor_strides, + IntArrayRef sizes); + +TORCH_API InferExpandGeometryResult inferExpandGeometry_dimvector( + IntArrayRef tensor_sizes, + IntArrayRef tensor_strides, + IntArrayRef sizes); + +TORCH_API std::vector infer_dense_strides( + IntArrayRef tensor_sizes, + IntArrayRef tensor_strides); + +// True if input shapes are expandable +// NOTE: infer_size did a similar check, please keep them sync if change is +// needed +inline bool are_expandable(IntArrayRef shape1, IntArrayRef shape2) { + size_t ndim1 = shape1.size(); + size_t ndim2 = shape2.size(); + size_t ndim = ndim1 < ndim2 ? ndim1 : ndim2; + + for (int64_t i = static_cast(ndim) - 1; i >= 0; --i) { + if (shape1[--ndim1] == shape2[--ndim2] || shape1[ndim1] == 1 || + shape2[ndim2] == 1) { + continue; + } + return false; + } + return true; +} + +// avoid copy-construction of Tensor by using a reference_wrapper. +inline void check_defined( + std::initializer_list> tensors, + const char* api_name) { + for (auto& t : tensors) { + if (!t.get().defined()) { + TORCH_CHECK(false, api_name, "(...) called with an undefined Tensor"); + } + } +} + +// NOTE [ ExpandUtils Borrowing ] +// +// Functions in ExpandUtils return `c10::MaybeOwned` because +// expansion may not actually be needed, in which case we can improve +// efficiency by returning +// `c10::MaybeOwned::borrowed(to_expand)`. However, this means +// that you need to be careful: the returned `c10::MaybeOwned` +// must not outlive the original `Tensor` object that `to_expand` +// referred to! The deleted rvalue reference overloads of these +// functions help with this by preventing trivial use of a temporary +// resulting from a function call, but it is still possible to make a +// mistake. + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + const Tensor& to_expand) { + if (tensor.sym_sizes().equals(to_expand.sym_sizes())) { + return c10::MaybeOwned::borrowed(to_expand); + } + return c10::MaybeOwned::owned( + to_expand.expand_symint(tensor.sym_sizes())); +} + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + Tensor&& to_expand) = delete; + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + const Tensor& to_expand, + const char* api_name) { + check_defined({tensor, to_expand}, api_name); + return expand_inplace(tensor, to_expand); +} + +inline c10::MaybeOwned expand_inplace( + const Tensor& tensor, + Tensor&& to_expand, + const char* api_name) = delete; + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + const Tensor& to_expand2) { + if (tensor.sizes().equals(to_expand1.sizes()) && + tensor.sizes().equals((to_expand2.sizes()))) { + return std::make_tuple( + c10::MaybeOwned::borrowed(to_expand1), + c10::MaybeOwned::borrowed(to_expand2)); + } + + return std::make_tuple( + c10::MaybeOwned::owned(to_expand1.expand(tensor.sizes())), + c10::MaybeOwned::owned(to_expand2.expand(tensor.sizes()))); +} + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + Tensor&& to_expand1, + const Tensor& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + Tensor&& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace(const Tensor& tensor, Tensor&& to_expand1, Tensor&& to_expand2) = + delete; + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + const Tensor& to_expand2, + const char* api_name) { + check_defined({tensor, to_expand1, to_expand2}, api_name); + return expand_inplace(tensor, to_expand1, to_expand2); +} + +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + Tensor&& to_expand1, + const Tensor& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + const Tensor& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_inplace( + const Tensor& tensor, + Tensor&& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; + +// See NOTE [ ExpandUtils Borrowing ] above for `MaybeOwned` explanation. +inline std::tuple, c10::MaybeOwned> +expand_outplace(const Tensor& to_expand1, const Tensor& to_expand2) { + auto s1 = to_expand1.sym_sizes(); + auto s2 = to_expand2.sym_sizes(); + if (s1.equals(s2)) { + return std::make_tuple( + c10::MaybeOwned::borrowed(to_expand1), + c10::MaybeOwned::borrowed(to_expand2)); + } + + auto expanded_size = infer_size_symdimvector(s1, s2); + return std::make_tuple( + c10::MaybeOwned::owned(to_expand1.expand_symint(expanded_size)), + c10::MaybeOwned::owned(to_expand2.expand_symint(expanded_size))); +} + +inline std::tuple, c10::MaybeOwned> +expand_outplace(Tensor&& to_expand1, const Tensor& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace(const Tensor& to_expand1, Tensor&& to_expand2) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2) = delete; + +inline std::tuple, c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + const char* api_name) { + check_defined({to_expand1, to_expand2}, api_name); + return expand_outplace(to_expand1, to_expand2); +} + +inline std::tuple, c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; +inline std::tuple, c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + Tensor&& to_expand2, + const char* api_name) = delete; + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3) { + if (to_expand1.sizes().equals(to_expand2.sizes()) && + to_expand1.sizes().equals(to_expand3.sizes())) { + return std::make_tuple( + c10::MaybeOwned::borrowed(to_expand1), + c10::MaybeOwned::borrowed(to_expand2), + c10::MaybeOwned::borrowed(to_expand3)); + } + + auto expanded_size12 = + infer_size_dimvector(to_expand1.sizes(), to_expand2.sizes()); + auto expanded_size = + infer_size_dimvector(expanded_size12, to_expand3.sizes()); + return std::make_tuple( + c10::MaybeOwned::owned(to_expand1.expand(expanded_size)), + c10::MaybeOwned::owned(to_expand2.expand(expanded_size)), + c10::MaybeOwned::owned(to_expand3.expand(expanded_size))); +} + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + Tensor&& to_expand3) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace(Tensor&& to_expand1, Tensor&& to_expand2, Tensor&& to_expand3) = + delete; + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3, + const char* api_name) { + check_defined({to_expand1, to_expand2, to_expand3}, api_name); + return expand_outplace(to_expand1, to_expand2, to_expand3); +} + +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + const Tensor& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + Tensor&& to_expand2, + const Tensor& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + const Tensor& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + const Tensor& to_expand1, + Tensor&& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; +inline std::tuple< + c10::MaybeOwned, + c10::MaybeOwned, + c10::MaybeOwned> +expand_outplace( + Tensor&& to_expand1, + Tensor&& to_expand2, + Tensor&& to_expand3, + const char* api_name) = delete; + +inline c10::MaybeOwned expand_size( + const Tensor& to_expand, + IntArrayRef sizes) { + if (to_expand.sizes().equals(sizes)) { + return c10::MaybeOwned::borrowed(to_expand); + } + + return c10::MaybeOwned::owned(to_expand.expand(sizes)); +} + +inline c10::MaybeOwned expand_size( + Tensor&& to_expand, + IntArrayRef sizes) = delete; + +inline c10::MaybeOwned expand_size( + const Tensor& to_expand, + IntArrayRef sizes, + const char* api_name) { + check_defined({to_expand}, api_name); + return expand_size(to_expand, sizes); +} + +inline c10::MaybeOwned expand_size( + Tensor&& to_expand, + IntArrayRef sizes, + const char* api_name) = delete; + +inline std::vector expand_outplace(TensorList to_expand) { + // expands a list of Tensors; ignores undefined (null) tensors + bool first = true; + SymDimVector sizes; + for (const auto i : c10::irange(to_expand.size())) { + if (!to_expand[i].defined()) { + continue; + } else if (first) { + sizes = to_expand[i].sym_sizes(); + first = false; + } else { + sizes = infer_size_symdimvector(sizes, to_expand[i].sym_sizes()); + } + } + + std::vector result(to_expand.size()); + for (const auto i : c10::irange(to_expand.size())) { + if (!to_expand[i].defined()) { + continue; + } else if (to_expand[i].sym_sizes().equals(sizes)) { + result[i] = to_expand[i]; + } else { + result[i] = to_expand[i].expand_symint(sizes); + } + } + return result; +} + +template +inline Tensor _sum_to( + Tensor tensor, + const c10::ArrayRef shape, + bool always_return_non_view = false) { + if (shape.size() == 0) { + return tensor.sum(); + } + + auto sizes = at::symint::sizes(tensor); + c10::SmallVector reduce_dims; + const int64_t leading_dims = sizes.size() - shape.size(); + for (const auto i : c10::irange(leading_dims)) { + reduce_dims.push_back(i); + } + for (int64_t i = leading_dims; i < static_cast(sizes.size()); ++i) { + if (TORCH_GUARD_OR_FALSE(sym_eq(shape[i - leading_dims], 1)) && + TORCH_GUARD_OR_TRUE(sym_ne(sizes[i], 1))) { + reduce_dims.push_back(i); + } else { + // if we assume no reduction due to unbacked we ensure that at runtime. + TORCH_MAYBE_SYM_CHECK( + sym_eq(shape[i - leading_dims], sizes[i]), + "non-reduction path was assumed due to unabcked symbols expected those two sizes to be the same:", + shape[i - leading_dims], + ", ", + sizes[i]) + } + } + + if (!reduce_dims.empty()) { + tensor = tensor.sum(reduce_dims, /*keepdim=*/true); + } + + if (always_return_non_view) { + // This is only actually used by the functionalization pass. + // We want to be able to guarantee that this function doesn't return a view + // of the input. + return leading_dims > 0 ? at::symint::view_copy(tensor, shape) + : tensor.clone(); + } else { + return leading_dims > 0 ? at::symint::view(tensor, shape) : tensor; + } +} + +inline Tensor sum_to( + Tensor tensor, + const c10::SymIntArrayRef shape, + bool always_return_non_view = false) { + return _sum_to(std::move(tensor), shape, always_return_non_view); +} + +// Sums `tensor` repeatedly to produce a tensor of shape `shape`. +// Precondition: is_expandable_to(shape, tensor.sizes()) must be true +inline Tensor sum_to( + Tensor tensor, + const IntArrayRef shape, + bool always_return_non_view = false) { + return _sum_to(std::move(tensor), shape, always_return_non_view); +} + +inline bool is_expandable_to( + SymIntArrayRef shape, + c10::SymIntArrayRef desired) { + size_t ndim = shape.size(); + size_t target_dim = desired.size(); + if (ndim > target_dim) { + return false; + } + for (const auto i : c10::irange(ndim)) { + const auto& size = shape[ndim - i - 1]; + const auto& target = desired[target_dim - i - 1]; + if (size != target && size != 1) { + return false; + } + } + return true; +} + +inline bool is_expandable_to(IntArrayRef shape, IntArrayRef desired) { + auto sym_shape = c10::SymIntArrayRef( + reinterpret_cast(shape.data()), shape.size()); + auto sym_desired = c10::SymIntArrayRef( + reinterpret_cast(desired.data()), desired.size()); + return is_expandable_to(sym_shape, sym_desired); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Formatting.h b/phivenv/Lib/site-packages/torch/include/ATen/Formatting.h new file mode 100644 index 0000000000000000000000000000000000000000..e23b27ffd373180a1857a5491694eff11705f9a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Formatting.h @@ -0,0 +1 @@ +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h b/phivenv/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..6430caadfa947f57f76f1d7e218b4b4d60140f8b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/FuncTorchTLS.h @@ -0,0 +1,46 @@ +#pragma once + +#include +#include + +namespace at::functorch { + +// NOTE [functorch TLS in pytorch/pytorch] +// +// functorch lives out-of-tree. However, it has some TLS that needs to be +// propagated. The solution for that is we store a pointer to the TLS +// inside pytorch/pytorch and extend FuncTorchTLSBase inside functorch to +// include whatever functorch needs. +// +// We need to store a pointer due to the indirection: +// inside functorch, we will create a subclass of FunctorchTLSBase called +// FuncTorchTLSImpl that actually contains metadata, like the DynamicLayerStack. +// FuncTorchTLSBase doesn't have any metadata because it hasn't been defined +// yet. +// +// Here in pytorch/pytorch, we will pass around FuncTorchTLSBase*, but inside +// functorch, we will assign a FuncTorchTLSImpl* to the FunctorchTLSBase*. +// We can't directly pass around FunctorchTLSBase (without a pointer) because +// FuncTorchTLSImpl does not fit inside a FuncTorchTLSBase by virtue of having +// more elements. +struct TORCH_API FuncTorchTLSBase { + virtual ~FuncTorchTLSBase() = default; + virtual std::unique_ptr deepcopy() const = 0; + + virtual int64_t checkSupportsSingleLevelAutogradFunction() const = 0; + virtual void checkSupportsCppAutogradFunction() const = 0; + virtual void checkSupportsInplaceRequiresGrad() const = 0; + virtual void checkSupportsRetainGrad() const = 0; +}; + +// returns deepcopy of the functorch tls +TORCH_API std::unique_ptr getCopyOfFuncTorchTLS(); + +// sets the functorch tls. always does a deep copy. +TORCH_API void setFuncTorchTLS( + const std::shared_ptr& state); + +// get a mutable reference to the functorch tls +TORCH_API std::unique_ptr& functorchTLSAccessor(); + +} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/FunctionalStorageImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/FunctionalStorageImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..9cf573d8795ff3c557de2f51f502b21e70ce8d4f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/FunctionalStorageImpl.h @@ -0,0 +1,208 @@ +#pragma once + +#include + +#include + +namespace at::functionalization { + +// See Note [Functionalization Pass In Core] + +// ViewMeta is a class used by the functionalization pass to navigate between +// a base tensor and a view tensor. +// For example, if I call `b = a.view1(...)` +// the functionalization pass will generate and store a ViewMeta on b that looks +// like: +// +// ViewMeta( +// [](const Tensor& base, int64_t mutated_view_idx) { +// return base.view1(...); +// }, +// [](const at::Tensor& base, const at::Tensor& mutated_view, +// int64_t mutated_view_idx) -> at::Tensor { +// return at::functionalization::impl::view1_inverse(base, mutated_view, +// ...); +// } +// +// The forward_fn lambda describes how to replay view1 on a tensor. +// +// The reverse_fn lambda describes how, given a tensor that is already a view, +// how to get the corresponding base tensor. See Note [Functionalization Pass: +// View Inverses] for details. +struct ViewMeta { + ViewMeta( + std::function forward, + std::function reverse, + bool has_symbolic_inputs, + bool is_multi_output = false, + bool is_as_strided = false, + int64_t out_idx = 0) + : forward_fn(std::move(forward)), + reverse_fn(std::move(reverse)), + out_index(out_idx), + is_multi_output(is_multi_output), + is_as_strided(is_as_strided), + has_symbolic_inputs(has_symbolic_inputs) {} + + std::function forward_fn; + std::function reverse_fn; + // See Note [out_idx in ViewMeta] + int64_t out_index; + + // Tells us if this is a multi-output view + bool is_multi_output; + + bool is_as_strided; + + // Tells us if this view operation has any symbolic inputs + bool has_symbolic_inputs; + + // Returns a copy of the current ViewMeta, if out_idx matches the current + // out_index. Otherwise, returns a new ViewMeta with the same forward/reverse + // functions, but a new out index. + ViewMeta to_out_idx(int64_t out_idx); +}; + +// FunctionalStorageImpl is a subclass of StorageImpl used by the +// functionalization pass. It has no underlying data (similar to meta storage). +// It also knows how to reflect mutations to tensors in the absence of a valid +// data pointer. +// +// A storage represents the state shared by (potentially multiple) views of the +// same tensor. For example, in the following code: +// +// b = a.view1(...) +// c = b.view2(...) +// b.add_(1) +// --> storage.add_update(b, {view1_meta}) +// +// The call to add_(1) will result in a call to alias.add_update(b, +// {view1_meta}), queueing up the mutation from b onto the alias. Later, suppose +// c is used in an expression (e.g. you try to print c, or pass it to an +// operator). Doing so will involve "syncing" c. First we apply any pending +// updates to the alias, and then we regenerate c by replaying its views off of +// the updated alias. E.g: +// +// print(str(c)) +// --> c.sync_() +// --> alias.apply_updates() // after this, the alias will be updated to +// reflect the mutation to b +struct TORCH_API FunctionalStorageImpl : public c10::StorageImpl { + public: + struct Update { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const at::Tensor new_val; + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const std::vector view_metas; + }; + + explicit FunctionalStorageImpl(const Tensor& value); + + void add_update( + const Tensor& updated_val, + const std::vector& view_metas); + bool apply_updates(); + const Tensor& base() { + return base_; + } + size_t generation() const { + return generation_; + } + void freeze() { + frozen_ = true; + } + + c10::SymInt get_storage_size(bool before) { + if (before) { + return original_storage_size_; + } else { + return curr_storage_size_; + } + } + + ~FunctionalStorageImpl() override = default; + + void mark_mutation() { + mutation_counter_++; + } + void mark_mutation_during_no_grad_or_inference_mode() { + mutation_counter_during_no_grad_or_inference_mode_++; + } + void mark_mutation_hidden_from_autograd() { + mutation_counter_hidden_from_autograd_++; + } + + bool are_all_mutations_under_no_grad_or_inference_mode() const { + auto non_autograd_mutations = + mutation_counter_during_no_grad_or_inference_mode_ + + mutation_counter_hidden_from_autograd_; + // The <= is because both counters will technically be incremented, if we + // perform e.g. a triton kernel mutation under no_grad + return mutation_counter_ <= non_autograd_mutations; + } + + bool are_all_mutations_hidden_from_autograd() const { + // mutations under no_grad / inference_mode are technically not hidden from + // autograd - they change the version counter + return mutation_counter_ <= mutation_counter_hidden_from_autograd_; + } + + void mark_inductor_storage_resize(c10::SymInt new_size) { + inductor_storage_resized_ = true; + curr_storage_size_ = std::move(new_size); + } + + bool was_inductor_storage_resized() { + return inductor_storage_resized_; + } + + private: + // NB: base_ should always point to a tensor BELOW the current + // functionalization layer. This is mainly to avoid reference cycles. e.g. + // given `b = a.view(...)` Both a.storage_ and b.storage_ are a + // FunctionStorageImpl containing an Walualias, with contains a Tensor + // `base_`. In this case (where a and b are FunctionalTensorWrapper's), base_ + // should point not to a, but to a's unwrapped value, a.value_` See Note + // [Functionalization: Walualias Removal] for a diagram that shows this + // visually. + at::Tensor base_; + std::vector updates_; + // generation_ gets incremented every time a mutation is queued onto the + // alias. It is used to determine if a given tensor is "up to date", or if it + // needs to be regenerated from the alias. + size_t generation_ = 0; + // If frozen, no more mutations are allowed on this storage. Once frozen, a + // storage cannot be unfrozen. + bool frozen_ = false; + + // These mutation counters are bumped on the storage + // whenever a FunctionalTensorWrapper experiences a mutation. + // When the mutation is under no_grad, or comes from a triton kernel, we also + // bump the corresponding during_no_grad or hidden_from_autograd counters. Why + // do we need to detect these two situations separately from "normal" input + // mutations? (1) "normal" input mutations can mutate autograd metadata like + // .grad_fn, + // in which case they need to be replayed outside of the compiled graph + // (2) "no_grad" input mutations are generally safe to keep in the graph (and + // compile), + // but they bump the tensor's VC, so we need to mark_dirty() on the inputs + // in torch.compile + // (3) mutations that are fully hidden from autograd (e.g. from a triton + // kernel) + // do not mutate any autograd state, and be fully kept in the graph + // When we detect that an input was mutated, we need to be able to tell if: + // (1) all of the mutations were from triton kernels + // (2) all of the mutations were under no_grad + uint64_t mutation_counter_during_no_grad_or_inference_mode_ = 0; + uint64_t mutation_counter_ = 0; + uint64_t mutation_counter_hidden_from_autograd_ = 0; + + // Used to tell if: + // (1) There were any storage resizes on a graph input + // (2) The original/curr storage size tell us if these resizes result in a nop + bool inductor_storage_resized_ = false; + c10::SymInt original_storage_size_; + c10::SymInt curr_storage_size_; +}; + +} // namespace at::functionalization diff --git a/phivenv/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h b/phivenv/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h new file mode 100644 index 0000000000000000000000000000000000000000..2d01d7cec13eab40bb58fbaf639e2a76777ed035 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/FunctionalTensorWrapper.h @@ -0,0 +1,454 @@ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace at { + +// Note [Functionalization Pass In Core] +// The Functionalization pass is used to remove aliasing from a pytorch program. +// +// This is useful for backends that don't support aliasing, like XLA and Vulkan. +// It's also necessary in order to remove mutation from a program, which is +// needed in Functorch. +// +// Consider this program: +// a = torch.ones(...) +// b = a.view(...) +// b.add_(1) +// +// In this program, b is meant to alias with a due to the use of view(). At the +// end of the program, both a and b are full of 2's. However, backends that +// don't support aliasing aren't able to correctly implement the view() +// operator. Instead, they can opt into the Functionalization pass, which will +// sit between the user and the backend, and provide the necessary aliasing +// logic. +// +// The functionalization pass will turn the above program into a slightly +// different program that has the same semantics, transparently to the user, +// that backends like XLA/Vulkan are able to implement a = torch.ones(...) b = +// a.view_copy(...) # view() replaced with view_copy(). Backends like +// XLA/Vulkan can implement this! b.add_(1) a.add_(1) # Our functionalization +// pass machinery knows that a and b are aliased - it applies b's mutation to a +// too. +// +// So, how does the functionalization pass keep track of which tensors are +// aliased? The pass works by wrapping EVERY tensor in the program inside of a +// FunctionalTensorWrapper, which knows about its alias'd tensors. +// +// See Note [Functionalization: Alias Removal] for details on the aliasing +// machinery. See Note [Functionalization: Mutation Removal] for details on +// mutation removal. +struct TORCH_API FunctionalTensorWrapper : public c10::TensorImpl { + explicit FunctionalTensorWrapper(const Tensor& value); + // Additional constructor to create a FunctionalTensorWrapper directly from an + // underlying tensor that was created from a view. For example, the code b = + // a.view1() will generate a constructor call to FunctionalTensorWrapper(b, a, + // view1_meta) + explicit FunctionalTensorWrapper( + const Tensor& view_value, + const FunctionalTensorWrapper* base, + const functionalization::ViewMeta& meta); + + // Get the underlying, actual tensor, that doesn't know anything about + // functionalization. + const Tensor& value() const { + return value_; + } + // The concept of "level" is only ever important to functorch; it's exposed + // here as more of a hook for functorch to use. + int64_t level() const { + return level_; + } + void set_level(int64_t level) { + level_ = level; + } + bool has_metadata_mutation() const { + return has_metadata_mutation_; + } + + void mark_mutation() { + functional_storage_impl()->mark_mutation(); + } + // Denotes a mutation that's hidden from autograd, + // e.g. for the purposes of passing a tensor to a triton kernel + void mark_mutation_hidden_from_autograd() { + functional_storage_impl()->mark_mutation_hidden_from_autograd(); + } + void mark_mutation_during_no_grad_or_inference_mode() { + functional_storage_impl()->mark_mutation_during_no_grad_or_inference_mode(); + } + // Are all the mutations happening to the tensor hidden from autograd + bool are_all_mutations_hidden_from_autograd() const { + return functional_storage_impl()->are_all_mutations_hidden_from_autograd(); + } + // Did all mutations happen under no_grad or inference_mode + // (We also need to ignore mutations fully hidden from autograd here) + bool are_all_mutations_under_no_grad_or_inference_mode() const { + return functional_storage_impl() + ->are_all_mutations_under_no_grad_or_inference_mode(); + } + + void maybe_mark_symbolic(const functionalization::ViewMeta& meta) { + is_symbolic_ = is_symbolic_ | meta.has_symbolic_inputs; + } + + bool is_symbolic() const { + return is_symbolic_; + } + + // Runs the forward_fn of every ViewMeta collected in the current instance + // to some other base. + Tensor apply_view_metas(const Tensor& base); + + // Sync's the underlying tensor with its alias, if it's out of date. This + // involves two steps: 1) Apply any pending updates/mutations to the alias 2) + // Replay the views (if any) to regenerate the current tensor off of the + // updated alias. + void sync_(); + // Performs step (1) of the sync. This is its own public API because it's + // needed by view_inplace ops like transpose_. See Note [Functionalization + // Pass - Inplace View Ops] + void regenerate_from_base(); + // Performs step (2) of the sync. This is its own public API because it's + // needed by functorch. functorch wants to make sure that all input tensors to + // a functionalized program have been properly synced so it can properly + // propagate mutations to inputs. It can't just call sync_(), because the + // FunctionalTensorWrapper will look like it has no aliases and sync_ will be + // a noop. We use the reference count on storage_ to determine if the wrapper + // is aliased, and by the time functorch is ready to propagate updates to + // inputs, any intermediate views of the input created by the program will + // have been deallocated. This function also returns whether or not the base + // actually had any updates to apply. + bool apply_updates(); + // Takes the current state of value_ and snapshots it, sending it as a pending + // update to the alias. + void commit_update(); + // When any tensor is mutated, the tensor increments its alias's "generation". + // Separately, each tensor maintains its own "generation" counter, which is + // used to determine if it's up-to-date with its alias. The act of syncing a + // tensor will set a tensor's generation equal to its alias's generation. + bool is_up_to_date() const; + // Freezes the storage of this tensor, preventing subsequent mutations + void freeze_storage() const; + // Every FunctionalTensorWrapper contains a vector objects + // describing the series of view ops that ran to generate the current tensor + // from the base tensor. This method is used by inplace-view ops like + // transpose_. It appends a ViewMeta to the existing stack, and refreshes the + // tensor by replaying the views off of the alias. + void mutate_view_meta(const at::functionalization::ViewMeta& meta); + + // Custom implementation of self.set_(src) + void set__impl(const FunctionalTensorWrapper* other); + + // Custom implementation of resize_storage_bytes_(self, new_size) + void storage_resize_(const c10::SymInt& new_size); + + // Returns whether the current tensor's data was ever mutated + bool has_data_mutation(); + // + // Returns whether the current FunctionalTensorWrapper + // experienced a set_() call. + bool was_storage_changed() { + return was_storage_changed_; + } + + void set_storage_changed() { + was_storage_changed_ = true; + } + + // A FunctionalTensor is considered a base if its not a view of another + // tensor. + bool isBaseTensor() const { + return view_metas_.empty(); + } + + c10::SymInt get_storage_size(bool before) { + return functional_storage_impl()->get_storage_size(before); + } + + // Returns whether the FunctionalTensor experienced an + // untyped_storage().resize_() call + bool was_inductor_storage_resized() { + return functional_storage_impl()->was_inductor_storage_resized(); + } + + // The functionalization pass can be used to remove mutations. + // It does so by replacing any mutation op with it's corresponding + // out-of-place op, followed by a call to replace_(). e.g: + // + // a.add_(1) + // + // will turn into: + // + // tmp = a.add(1) + // a.replace_(tmp) + // + // replace_() swaps out the wrapped tensor, value_, with tmp. + void replace_(const Tensor& other, bool from_lazy_regenerate = false); + + bool is_multi_output_view() { + return is_multi_output_view_; + } + + // See Note[resize_() in functionalization pass] + void maybe_replace_storage(const Tensor& other); + + // Replaces the storage with a new functional storage, + // and clears the view_metas_ stack. + // WARNING: Calling this function will sever the aliasing relationship between + // the current FunctionalTensorWrapper and any of its outstanding aliases. + // Please only call if you know what you're doing. + void _unsafe_reset_storage(); + + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + + ~FunctionalTensorWrapper() override = default; + + // FunctionalTensorWrapper overrides all custom size/stride function, + // so that if the inner tensor has a custom implementation + // we make sure to call that implementation. + at::IntArrayRef sizes_custom() const override; + at::IntArrayRef strides_custom() const override; + int64_t dim_custom() const override; + int64_t numel_custom() const override; + bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + c10::SymIntArrayRef sym_sizes_custom() const override; + c10::SymInt sym_size_custom(int64_t d) const override; + c10::SymIntArrayRef sym_strides_custom() const override; + c10::SymInt sym_storage_offset_custom() const override; + c10::Device device_custom() const override; + c10::Layout layout_impl() const override; + + private: + const char* tensorimpl_type_name() const override; + void set_constructor_metadata(); + functionalization::FunctionalStorageImpl* functional_storage_impl() const; + + // This is used to re-implement shallow_copy_and_detach for + // FunctionalTensorWrapper. The implementation is identical, but we just need + // to return a subclass instead of a plain TensorImpl. + // TODO: maybe it's possible to arrange for that to happen automatically + // without an override here? + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + void shallow_copy_from(const c10::intrusive_ptr& impl) override; + void copy_tensor_metadata_and_refresh( + const FunctionalTensorWrapper* src_impl, + FunctionalTensorWrapper* dest_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const; + + // Note that value is not taken by reference: internally, the wrapper will + // change the value tensor that it points to over time. + Tensor value_; + int64_t level_{}; + // These two counters are used for identifying + // whether all the mutations on a given tensor are hidden from autograd or + // not. If we have an input mutation that is hidden from autograd, then once + // we convert the input mutation to a copy_() we know it will be safe to hide + // the copy_() from autograd as well. + bool has_metadata_mutation_ = false; + bool is_multi_output_view_ = false; + // Did the tensor experience a set_() call. + bool was_storage_changed_ = false; + // Did the tensor experience any view operation with symbolic int. + bool is_symbolic_ = false; + + size_t generation_ = 0; + std::vector view_metas_; + + protected: + static void copy_tensor_metadata( + const FunctionalTensorWrapper* src_impl, + FunctionalTensorWrapper* dest_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change); +}; + +// Utility functions for the functionalization pass. + +namespace functionalization { +namespace impl { + +TORCH_API inline FunctionalTensorWrapper* unsafeGetFunctionalWrapper( + const Tensor& tensor) { + auto functional_impl = + static_cast(tensor.unsafeGetTensorImpl()); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(functional_impl != nullptr); + return functional_impl; +} + +TORCH_API bool isBaseTensor(const at::Tensor& tensor); + +TORCH_API bool isFunctionalTensor(const at::Tensor& tensor); +TORCH_API bool isFunctionalTensor(const std::optional& t); +TORCH_API bool isFunctionalTensor( + const c10::List>& t_list); +TORCH_API bool isFunctionalTensor(ITensorListRef list); + +TORCH_API Tensor to_functional_tensor(const Tensor& tensor); +TORCH_API std::optional to_functional_tensor( + const std::optional& tensor); +TORCH_API c10::List> to_functional_tensor( + const c10::List>& t_list); +TORCH_API std::vector to_functional_tensor(ITensorListRef t_list); + +TORCH_API void freeze_functional_tensor(const Tensor& tensor); + +TORCH_API Tensor +from_functional_tensor(const Tensor& tensor, bool assert_functional = true); +TORCH_API std::optional from_functional_tensor( + const std::optional& t, + bool assert_functional = true); +TORCH_API c10::List> from_functional_tensor( + const c10::List>& t_list); +TORCH_API std::vector from_functional_tensor(ITensorListRef t_list); + +TORCH_API void sync(const at::Tensor& t); +TORCH_API void sync(const std::optional& t); +TORCH_API void sync(const c10::List>& t_list); +TORCH_API void sync(ITensorListRef t_list); + +TORCH_API void replace_(const Tensor& functional_tensor, const Tensor& other); +TORCH_API void replace_( + const ITensorListRef functional_tensor, + ITensorListRef other); + +TORCH_API void commit_update(const Tensor& functional_tensor); +TORCH_API void commit_update(ITensorListRef functional_tensor); + +TORCH_API void unsafe_reset_storage(const Tensor& functional_tensor); + +TORCH_API void mark_mutation_hidden_from_autograd( + const Tensor& functional_tensor); + +TORCH_API bool are_all_mutations_hidden_from_autograd( + const Tensor& functional_tensor); + +TORCH_API bool are_all_mutations_under_no_grad_or_inference_mode( + const Tensor& functional_tensor); + +// These two methods are XLA-specific logic and are no-ops +// for the normal functionalization flow. +TORCH_API void propagate_xla_data( + const Tensor& functional_tensor, + const Tensor& other); +TORCH_API void propagate_xla_data( + const ITensorListRef functional_tensor, + ITensorListRef other); + +TORCH_API void propagate_xla_data_direct( + const Tensor& tensor, + const Tensor& other); +TORCH_API void propagate_xla_data_direct( + const ITensorListRef tensor, + ITensorListRef other); + +Tensor create_functional_tensor_with_view_meta( + const Tensor& view_to_wrap, + const Tensor& base, + functionalization::ViewMeta meta, + int64_t out_idx = 0); +std::vector create_functional_tensor_with_view_meta( + ITensorListRef view_to_wrap, + const Tensor& base, + const functionalization::ViewMeta& meta); + +void mutate_view_meta( + const Tensor& self, + const functionalization::ViewMeta& meta); + +void set_sizes_strides_offset(const Tensor& out, const Tensor& meta_out); +void set_sizes_strides_offset( + const std::vector& outs, + const std::vector& meta_outs); + +// ~~~~~ TLS used in functionalization ~~~~~ + +TORCH_API bool getFunctionalizationReapplyViewsTLS(); +TORCH_API void setFunctionalizationReapplyViewsTLS(bool reapply_views); + +class TORCH_API FunctionalizationReapplyViewsGuard { + public: + FunctionalizationReapplyViewsGuard(bool reapply_views) + : prev_(getFunctionalizationReapplyViewsTLS()) { + setFunctionalizationReapplyViewsTLS(reapply_views); + } + + ~FunctionalizationReapplyViewsGuard() { + setFunctionalizationReapplyViewsTLS(prev_); + } + + FunctionalizationReapplyViewsGuard( + const FunctionalizationReapplyViewsGuard&) = delete; + FunctionalizationReapplyViewsGuard operator=( + const FunctionalizationReapplyViewsGuard&) = delete; + FunctionalizationReapplyViewsGuard(FunctionalizationReapplyViewsGuard&&) = + delete; + FunctionalizationReapplyViewsGuard operator=( + FunctionalizationReapplyViewsGuard&&) = delete; + + private: + bool prev_; +}; + +} // namespace impl + +// Helper function to call an out-of-place composite aten kernel that may use +// mutations / views internally, and functionalize them. +TORCH_API void functionalize_op_helper( + const c10::OperatorHandle& op, + torch::jit::Stack* stack); + +template +struct _functionalize_aten_op final {}; + +template +struct _functionalize_aten_op final { + static ReturnType call( + typename c10::maybe_keep_symint::type... args) { + using FuncType = ReturnType( + typename c10::maybe_keep_symint::type...); + auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow( + (const char*)Op::name, (const char*)Op::overload_name) + .typed(); + + return c10::impl::BoxedKernelWrapper::call( + c10::BoxedKernel::makeFromFunction(), + op, + // BoxedKernelWrapper knows to ignore this keyset argument, + // because functionalize_op_helper doesn't take in a DispatchKeySet + c10::DispatchKeySet(), + args...); + } +}; + +template +using functionalize_aten_op = + _functionalize_aten_op; + +template +using functionalize_aten_op_symint = + _functionalize_aten_op; + +} // namespace functionalization +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Functions.h b/phivenv/Lib/site-packages/torch/include/ATen/Functions.h new file mode 100644 index 0000000000000000000000000000000000000000..6978ebafc448c8fc56a985abd8d7243af3434fa4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Functions.h @@ -0,0 +1,1465 @@ +#pragma once + +// @generated by torchgen/gen.py from Functions.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from and \ + see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +// NOTE: [TORCH_ASSERT_ONLY_METHOD_OPERATORS] +// +// In ATen, certain generated headers files include the definitions of +// every single operator in PyTorch. Unfortunately this means every +// time an operator signature is updated or changed in +// native_functions.yaml, you (and every other PyTorch developer) need +// to recompile every source file that includes any of these headers. +// +// To break up these header dependencies, and improve incremental +// build times for all PyTorch developers. These headers are split +// into per-operator headers in the `ATen/ops` folder. This limits +// incremental builds to only changes to methods of `Tensor`, or files +// that use the specific operator being changed. With `at::sum` as an +// example, you should include +// +// // instead of ATen/Functions.h +// // instead of ATen/NativeFunctions.h +// // instead of ATen/Operators.h +// // instead of ATen/CPUFunctions.h +// +// However, even if you're careful to use this in your own code. +// `Functions.h` might be included indirectly through another header +// without you realising. To avoid this, you can add +// +// #define TORCH_ASSERT_ONLY_METHOD_OPERATORS +// +// to the top of your source file. This way any time the non-specific +// headers are included, the compiler will error out. +// +// Also, be aware that `ops` are not available in all build +// configurations (namely fb-internal) so you must guard these +// includes with `#ifdef AT_PER_OPERATOR_HEADERS`. e.g. +// +// #ifndef AT_PER_OPERATOR_HEADERS +// #include +// #else +// #include +// #endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + + + +// Special C++ only overloads for std()-like functions (See gh-40287) +// These are needed because int -> bool conversion takes precedence over int -> IntArrayRef +// So, for example std(0) would select the std(unbiased=False) overload +TORCH_API inline Tensor var(const Tensor& self, int dim) { + return at::var(self, IntArrayRef{dim}); +} +TORCH_API inline std::tuple var_mean(const Tensor& self, int dim) { + return at::var_mean(self, IntArrayRef{dim}); +} +TORCH_API inline Tensor std(const Tensor& self, int dim) { + return at::std(self, IntArrayRef{dim}); +} +TORCH_API inline std::tuple std_mean(const Tensor& self, int dim) { + return at::std_mean(self, IntArrayRef{dim}); +} + +inline int64_t numel(const Tensor& tensor) { + return tensor.numel(); +} + +inline int64_t size(const Tensor& tensor, int64_t dim) { + return tensor.size(dim); +} + +inline int64_t stride(const Tensor& tensor, int64_t dim) { + return tensor.stride(dim); +} + +inline bool is_complex(const Tensor& tensor) { + return tensor.is_complex(); +} + +inline bool is_floating_point(const Tensor& tensor) { + return tensor.is_floating_point(); +} + +inline bool is_signed(const Tensor& tensor) { + return tensor.is_signed(); +} + +inline bool is_inference(const Tensor& tensor) { + return tensor.is_inference(); +} + +inline bool _is_zerotensor(const Tensor& tensor) { + return tensor._is_zerotensor(); +} + +inline bool is_conj(const Tensor& tensor) { + return tensor.is_conj(); +} + +inline Tensor conj(const Tensor& tensor) { + return tensor.conj(); +} + +inline bool is_neg(const Tensor& tensor) { + return tensor.is_neg(); +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Generator.h b/phivenv/Lib/site-packages/torch/include/ATen/Generator.h new file mode 100644 index 0000000000000000000000000000000000000000..741e39f29dae4cca6cb39f8b1d385bb14ed1b6c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Generator.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/InferSize.h b/phivenv/Lib/site-packages/torch/include/ATen/InferSize.h new file mode 100644 index 0000000000000000000000000000000000000000..8517de86b4b5f2ae95d24a8d86fb2b1ada4068d5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/InferSize.h @@ -0,0 +1,104 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// Infers the size of a dim with size -1, if it exists. Also checks that new +// shape is compatible with the number of elements. +// +// templated to handle std::vector and DimVector use cases, see +// below +// +template +inline void infer_size_impl( + InputArrayRef shape, + NumelType numel, + ResultVec& res) { + NumelType newsize = 1; + // N.B. this is an index, not a sym dim! + std::optional infer_dim; + for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) { + if (TORCH_GUARD_OR_FALSE(sym_eq(shape[dim], -1))) { + if (infer_dim) { + throw std::runtime_error("only one dimension can be inferred"); + } + infer_dim = dim; + } else { + // in case of unbacked shape[dim] we assume it's not -1 and add a runtime + // assertion. + TORCH_MAYBE_SYM_CHECK( + sym_gt(shape[dim], -1), + "invalid shape dimension ", + shape[dim], + " at index ", + dim, + " of shape ", + shape); + newsize *= shape[dim]; + } + } + + auto set_infer_dim = [&]() { + // We have a degree of freedom here to select the dimension size; follow + // NumPy semantics and just bail. However, a nice error message is needed + // because users often use `view` as a way to flatten & unflatten + // dimensions and will otherwise be confused why + // empty_tensor.view( 0, 0) + // works yet + // empty_tensor.view(-1, 0) + // doesn't. + TORCH_CHECK( + newsize != 0, + "cannot reshape tensor of 0 elements into shape ", + shape, + " because the unspecified dimension size -1 can be any " + "value and is ambiguous"); + res[*infer_dim] = numel / newsize; + return; + }; + + if (infer_dim && newsize > 0 && numel % newsize == 0) { + set_infer_dim(); + return; + } + + TORCH_MAYBE_SYM_CHECK( + sym_eq(numel, newsize), + "shape '", + shape, + "' is invalid for input of size ", + numel); + if (infer_dim) { + set_infer_dim(); + } +} + +inline std::vector infer_size(IntArrayRef shape, int64_t numel) { + auto res = shape.vec(); + infer_size_impl(shape, numel, res); + return res; +} + +inline at::DimVector infer_size_dv(IntArrayRef shape, int64_t numel) { + auto res = at::DimVector(shape); + infer_size_impl(shape, numel, res); + return res; +} + +inline at::SymDimVector infer_size_dv( + c10::SymIntArrayRef shape, + c10::SymInt numel) { + auto res = at::SymDimVector(shape); + infer_size_impl( + shape, std::move(numel), res); + return res; +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h b/phivenv/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..58289fb41c6f66b85ca17297864e1639f0a78441 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/InitialTensorOptions.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +namespace at { + +// Represents the initial TensorOptions, before the "defaults" are ever changed. +// This is designed to be used in library code, where the explicit devices, +// dtypes, etc. are known. NOTE: this is not a stable API. +inline TensorOptions initialTensorOptions() { + return TensorOptions(kCPU).dtype(kFloat).layout(kStrided).requires_grad( + false); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Layout.h b/phivenv/Lib/site-packages/torch/include/ATen/Layout.h new file mode 100644 index 0000000000000000000000000000000000000000..11bda768d2fc435e5aa32c764097ef158fe4a315 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Layout.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h b/phivenv/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h new file mode 100644 index 0000000000000000000000000000000000000000..7a4a1961a5f57d0aed6a4bd9b07ae2ff7e094d8a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/LegacyBatchedFallback.h @@ -0,0 +1,25 @@ +#pragma once +#include +#include +#include + +namespace at { + +// If an operator doesn't have a batching rule implemented then we fallback +// to this implementation. The fallback only works on out-of-place operators +// that return only tensors with new memory. (e.g., no in-place operators, no +// view operations). +// +// The fallback effectively takes all of the BatchedTensors in `stack`, slices +// them, and runs `op` on all of the corresponding slices to produce slices +// of the outputs. The output slices then get `torch.stack`ed to create the +// final returns. +// +// The performance of the fallback is not very good because it introduces an +// extra copy from stacking the sliced outputs. Because of this, we prefer to +// write batching rules for operators whenever possible. +void batchedTensorForLoopFallback( + const c10::OperatorHandle& op, + torch::jit::Stack* stack); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..83152717640ed92b85d859e2ecacfaddc6e1dc2b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/LegacyBatchedTensorImpl.h @@ -0,0 +1,160 @@ +#pragma once + +#include + +#include +#include +#include + +namespace at { + +// We assume this in a few other places in the codebase, +// but there isn't a centralized definition. +constexpr int64_t kVmapMaxTensorDims = 64; + +// The valid vmap levels range from [0, 64). This effectively means that we +// support a maximum of 64 nested vmaps. +constexpr int64_t kVmapNumLevels = 64; + +// Store this number of elements of BatchDims on the stack. Most people will +// probably use <= 5 nested vmaps, but adjust this number as necessary. +constexpr int64_t kBatchDimsStackSize = 5; + +// a BatchDim represents a "private" dimension on a Tensor created inside of +// vmap. It is a (level, dim) tuple, with the `dim` indicating which dimension +// is being vmap'ed over and the `level` being an identifier for which vmap +// said dimension was created inside. The `dim` corresponds to a "physical +// dim" - it is a dimension index on the underlying physical tensor that is +// being vmapped over. +struct BatchDim { + BatchDim(int64_t level, int64_t dim) : dim_(dim), level_(level) {} + int64_t dim() const { + return dim_; + } + int64_t level() const { + return level_; + } + + private: + int64_t dim_; + int64_t level_; +}; + +using BatchDims = SmallVector; +using BatchDimsRef = ArrayRef; + +// A BatchedTensorImpl holds an underlying Tensor and a list of BatchDim +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +// +// The batch dimensions are treated as being "private"; they are not +// user-visible. For example, in the following Tensor, +// bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, dim=1)]) +// dimensions 0 and 1 are batch dimensions. +// +// bt.sizes() returns (5, 7); bt.sum(0) performs a reduction over the (public) +// dim 0, which is equivalent to dim 3 in the underlying ones(2, 3, 5, 7) +// tensor. +struct TORCH_API BatchedTensorImpl : public c10::TensorImpl { + explicit BatchedTensorImpl(Tensor value, BatchDims bdims); + + // Returns a reference to BatchDims that represent which dimensions of this + // tensor are private. + BatchDimsRef bdims() const { + return bdims_; + } + + // BatchedTensorImpl wraps a Tensor + const Tensor& value() const { + return value_; + } + + // Given a public dimension index, return the dimension index in the + // underlying value() tensor. For example, if we have + // bt = BatchedTensorImpl(ones(2, 3, 5, 7), [(lvl=1, dim=0), (lvl=2, + // dim=2)]) + // bt.actualDim(0) -> 1 + // bt.actualDim(1) -> 3 + // bt.actualDim(2) -> Error + int64_t actualDim(int64_t dim, bool wrap_dim = true) const; + + // We have to override this because we opted into CustomStrides + IntArrayRef strides_custom() const override; + // Override a bunch of methods inherited from TensorImpl to return error + // messages. + bool is_contiguous_custom(at::MemoryFormat memory_format) const override; + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; +#ifdef DEBUG + bool has_storage() const override; +#endif + + private: + // see NOTE: [BatchedTensorImpl levels invariant] + void checkInvariants() const; + const char* tensorimpl_type_name() const override; + + Tensor value_; + + // Note: [BatchedTensorImpl levels invariant] + // There is an invariant that the BatchDims must be stored in increasing + // `level` order. That is, for i < j, bdims_[i].level must be less than + // bdims_[j].level. + BatchDims bdims_; +}; + +// NB: We use the term "BatchedTensor" to mean a Tensor that is backed with a +// BatchedTensorImpl. +inline bool isBatchedTensor(const Tensor& tensor) { + return tensor.unsafeGetTensorImpl()->key_set().has(DispatchKey::Batched); +} + +// It is unsafe to call this on a Tensor that is not backed by a +// BatchedTensorImpl. Please use `maybeGetBatchedImpl` whenever possible. +inline BatchedTensorImpl* unsafeGetBatchedImpl(const Tensor& tensor) { + return static_cast(tensor.unsafeGetTensorImpl()); +} + +inline BatchedTensorImpl* maybeGetBatchedImpl(const Tensor& tensor) { + if (!isBatchedTensor(tensor)) { + return nullptr; + } + return unsafeGetBatchedImpl(tensor); +} + +// Returns a bitset. If bit i is set, then that means dim i is a batchdim. +inline std::bitset createBatchDimBitset( + BatchDimsRef bdims) { + std::bitset is_bdim; + for (const auto& bdim : bdims) { + is_bdim.set(bdim.dim()); + } + return is_bdim; +} + +// Creates a bitset for all of the levels present in `bdims` +inline std::bitset createVmapLevelsBitset(BatchDimsRef bdims) { + std::bitset result; + for (const auto& bdim : bdims) { + result.set(bdim.level()); + } + return result; +} + +inline std::ostream& operator<<(std::ostream& out, const BatchDim& bdim) { + out << "(lvl=" << bdim.level() << ", dim=" << bdim.dim() << ")"; + return out; +} + +// Use this to construct a BatchedTensor from a regular Tensor +TORCH_API Tensor makeBatched(Tensor tensor, BatchDims bdims); + +// Adds a batch dim to `tensor`, returning a BatchedTensor +TORCH_API Tensor addBatchDim(Tensor tensor, int64_t level, int64_t dim); + +// Checks if an inplace operation on self and other is "vmap compatible". +// See NOTE: [vmap-incompatible in-place operations] for the definition of this. +TORCH_API bool inplaceIsVmapCompatible(const Tensor& self, const Tensor& other); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h b/phivenv/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h new file mode 100644 index 0000000000000000000000000000000000000000..dfb093566ccbe05a23e1d474cad84166496eb402 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/LegacyVmapMode.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +namespace at::impl { + +// VmapMode contains a thread local count of how many nested vmaps +// we are currently inside. That number is known as the `vmap level`. +// VmapMode is used in the implementation of the Python `torch.vmap` API. +// +// NOTE: this is NOT the c++ api for torch.vmap. That doesn't exist yet. + +struct TORCH_API VmapMode { + // Returns the vmap level, aka the count of how many nested vmaps we're in. + static int64_t current_vmap_level(); + + // Increment the count of nested vmaps. If this causes the vmap level to be + // greater than 0, then it enables DispatchKey::VmapMode on all tensors. + static int64_t increment_nesting(); + + // Decrements the count of nested vmaps. If this causes the vmap level to be + // equal to 0, then it disables DispatchKey::VmapMode on all tensors. + static int64_t decrement_nesting(); +}; + +} // namespace at::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h b/phivenv/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h new file mode 100644 index 0000000000000000000000000000000000000000..13af3ad08ad24f59d81bf6d4ade0cb925d3a5b95 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/LegacyVmapTransforms.h @@ -0,0 +1,183 @@ +#pragma once + +#include +#include + +namespace at { + +// This file contains abstractions used for transforming *logical* vmap +// arguments into *physical* arguments. (Keep reading for definitions of these +// terms). + +// NOTE: [Logical vs physical args] +// Consider the following vmap. +// vmap(vmap(func, in_dims=(2,)), in_dims=(0,))(torch.ones(2, 3, 4)) +// This would produce a BatchedTensor wrapping a Tensor of size [2, 3, 4], +// with batch dims 0 and 2: +// BatchedTensor(ones(2, 3, 4), bdims=[(lvl=1,dim=0),(lvl=2,dim=2)]) +// +// We say the *logical* view of the tensor has size [3] -- tensors inside +// `func` appear to have size [3]. +// However, the *physical* underlying tensor (the one passed to vmap) has size +// [2, 3, 4]. +// +// This notion of logical vs physical also extends to non-tensor arguments. +// Consider the previous tensor; let's assume the user called +// `torch.sum(tensor, dim=0)` inside of `func`. Then the logical +// dimension they are reducing over is dim 0 but the physical dim is dim 1 +// (the first non-batch dimension) + +// Forward declared; see NOTE: [What is a VmapPhysicalView?] +struct VmapPhysicalView; + +// Most PyTorch operators take 4 or fewer inputs. +constexpr int64_t kVmapTransformStaticInputSize = 4; +using VmapPhysicalViewVec = + SmallVector; + +// Pytorch generally advertises good performance for <= 5 dims. +// (see ATen/core/DimVector.h). We add a few extra dims (~3) for vmap +// dimensions to get 8. Adjust this number as necessary +constexpr int64_t kVmapStaticDimVecSize = 8; +using VmapDimVector = SmallVector; +using VmapSymDimVector = SmallVector; + +// NOTE: [What is an VmapTransform?] +// An *VmapTransform* converts logical views of tensors to physical views. +// +// Batching rules use VmapTransforms to convert logical arguments to +// physical arguments, then call one or more at:: operator that handles the +// physical arguments, and then converts the physical result back to a logical +// argument. + +// VmapTransform for operators that take tensors with multiple batch dims. +// Given one or more logical views on Tensors, `logicalToPhysical` +// permutes all of the batch dims to the front of the tensor, aligns +// and expands the batch dims to match each other (according to their `level`), +// and returns a VmapPhysicalView on the tensor(s). +struct TORCH_API MultiBatchVmapTransform { + static VmapPhysicalView logicalToPhysical(const Tensor& logical_tensor); + static VmapPhysicalViewVec logicalToPhysical(ITensorListRef logical_tensors); +}; + +// VmapTransform for operators that broadcast all inputs. +// Given some logical views on Tensors, `logicalToPhysical`: +// - permutes all of the batch dims to the front of the tensors +// - aligns all the batch dims to the collective levels of all of the tensors. +// If a tensor does not have a batch dim for a vmap level, then it receives +// a size-one dimension for said level. +// - aligns the non-batch dims to have the same dimensionality, adding extra +// size-1 dimensions in between the batch dimensions and the non-batch +// dimensions so that the batch dimensions are lined up from the right. +// +// For example: given inputs of size (B, 2) and (B, 3, 2) where B is the batch +// dimension, BroadcastingVmapTransform returns VmapPhysicalViews that wrap +// tensors of size (B, 1, 2) and (B, 3, 2). +// +// Given inputs of size (B, 2) and (2,), BroadcastingVmapTransform returns +// VmapPhysicalViews wrapping tensors of size (B, 2) and (1, 2). We don't +// actually *need* to return a tensor of size (1, 2) for the second tensor +// because the broadcasting operation takes care of that for us, but we do +// it anyways to keep things simple. +struct TORCH_API BroadcastingVmapTransform { + static VmapPhysicalViewVec logicalToPhysical(TensorList logical_tensors); +}; + +// Forward declared, if you're reading this file head to toe, don't worry about +// it yet. +struct VmapPhysicalToLogicalMap; + +// NOTE: [What is a VmapPhysicalView?] +// VmapPhysicalView represents a physical view on a Tensor. +// +// One can use it to further convert logical dimension indices, logical shapes, +// and more to their physical variants, or convert a new (physical) tensor into +// a logical BatchedTensor. (TODO(rzou): some of these are not yet implemented). +// +// VmapPhysicalView stores a physical tensor with all of its batch dimensions at +// the front and some levels that correspond to said batch dimensions. +// +// The levels bitset specifies which vmap levels correspond to the batch +// dimensions at the front of the tensor. In particular, the number of set bits +// corresponds to the number of batch dimensions on `tensor` and the rightmost +// bit of `levels` specifies the maximum number of nested vmaps we are in at +// this point in time. +// For example, given: +// physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5, 6), levels={1, 3}) +// +// Rightmost bit of `levels` is 3 indicating the number of nested vmaps less +// than or equal to 3. +// bitset: 010100 +// ^ +// | +// levels: 012345 +struct TORCH_API VmapPhysicalView { + VmapPhysicalView(Tensor&& tensor, std::bitset levels) + : levels_(levels), tensor_(std::move(tensor)) { + TORCH_INTERNAL_ASSERT(!isBatchedTensor(tensor_)); + } + + Tensor& tensor() { + return tensor_; + } + const Tensor& tensor() const { + return tensor_; + } + + // Maps logical dim indices to physical dim indices. Also does dim wrapping. + // + // For example, given: + // physical_view = VmapPhysicalView(tensor=ones(2, 3, 4, 5), levels={1, 3}) + // + // Then physical_view.getPhysicalDims({0, 1}) returns {2, 3}. + // This is because the size of levels tell us that the first two dimensions + // of `tensor_` are batch dimensions, so a logical dim of `n` is actually + // a physical dim of `n + 2`. + VmapDimVector getPhysicalDims(OptionalIntArrayRef logical_dims) const; + int64_t getPhysicalDim(int64_t logical_dim) const; + + // Returns a VmapPhysicalToLogicalMap object. This can be used for + // mapping a physical tensor to a new logical tensor (BatchedTensor) + VmapPhysicalToLogicalMap getPhysicalToLogicalMap() const; + + // Maps a logical shape to a physical shape by pre-pending the batch + // sizes to the logical shape. + VmapDimVector getPhysicalShape(IntArrayRef logical_shape) const; + + int64_t numBatchDims() const; + + private: + int64_t numLogicalDims() const; + + std::bitset levels_; + Tensor tensor_; +}; + +// Convenience struct used for mapping a physical tensor (a non-BatchedTensor) +// to a logical one (BatchedTensor). It holds some levels that are used to do +// the mapping and assumes that the batch dimensions in the physical tensor all +// occur at the front of the tensor. +struct TORCH_API VmapPhysicalToLogicalMap { + VmapPhysicalToLogicalMap(std::bitset levels) + : levels_(levels) {} + + // Maps a physical tensor to a new logical tensor (BatchedTensor). + // Assumes that all of the "batch dimensions" are at the front + // of the physical tensor. For example, given: + // - x = rank-4 Tensor with size 2, 3, 5, 7 + // - levels = (2, 4) + // Returns: + // - BatchedTensor(x, bdims=[(dim=0,lvl=2), (dim=1, lvl=4)]) + Tensor apply(const Tensor& physical_tensor) const; + + // Given a vector of physical tensors, + // 1. maps each tensor to a new logical tensor. Assumes that all of the + // "batch dimensions" are at the front of the physical tensors. + // 2. stores the new logical tensors back into the passed-in vector. This is + // to avoid additional dynamic allocations. + void applyInplace(std::vector& physical_tensors) const; + + std::bitset levels_; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/LinalgBackend.h b/phivenv/Lib/site-packages/torch/include/ATen/LinalgBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..3b084d189d7fb61cc0f67ccc0be15614be7e490c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/LinalgBackend.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include +#include + +namespace at { + +enum class LinalgBackend : int8_t { Default, Cusolver, Magma }; + +inline std::string LinalgBackendToString(at::LinalgBackend backend) { + switch (backend) { + case LinalgBackend::Default: + return "at::LinalgBackend::Default"; + case LinalgBackend::Cusolver: + return "at::LinalgBackend::Cusolver"; + case LinalgBackend::Magma: + return "at::LinalgBackend::Magma"; + default: + TORCH_CHECK(false, "Unknown linalg backend"); + } +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::LinalgBackend backend) { + return stream << LinalgBackendToString(backend); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/MapAllocator.h b/phivenv/Lib/site-packages/torch/include/ATen/MapAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..d14d4dafe375580b18de895252c46087b6eb1b37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/MapAllocator.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include + +namespace at { + +enum MappedAllocatorModes { + ALLOCATOR_MAPPED_SHARED = 1, + ALLOCATOR_MAPPED_SHAREDMEM = 2, + ALLOCATOR_MAPPED_EXCLUSIVE = 4, + ALLOCATOR_MAPPED_NOCREATE = 8, + ALLOCATOR_MAPPED_KEEPFD = 16, + ALLOCATOR_MAPPED_FROMFD = 32, + ALLOCATOR_MAPPED_UNLINK = 64 +}; + +// Sentinel value/type to help distinguish the file descriptor constructor from +// the non-file descriptor constructor +enum WithFd { WITH_FD }; + +TORCH_API std::string NewProcessWideShmHandle(); + +class TORCH_API MapAllocator { + public: + MapAllocator(std::string_view filename, int flags, size_t size); + MapAllocator( + WithFd, + std::string_view filename, + int fd, + int flags, + size_t size); + MapAllocator(const MapAllocator&) = delete; + MapAllocator& operator=(const MapAllocator&) = delete; + MapAllocator(MapAllocator&&) = delete; + MapAllocator& operator=(MapAllocator&&) = delete; + + const char* filename() const { + return filename_.c_str(); + } + int fd() const { +#ifdef _WIN32 + TORCH_CHECK(false, "MapAllocator::fd() is unsupported on Windows"); +#else + return fd_; +#endif + } + ptrdiff_t size() const { + return size_; + } + // Return a pointer to the actual data for this allocator + // (in the case of the refcounted allocator, this is offset + // from the base pointer.) + virtual void* data() const { + return base_ptr_; + } + + int flags() const { + return flags_; + } + + static MapAllocator* fromDataPtr(const at::DataPtr&); + static at::DataPtr makeDataPtr( + std::string_view filename, + int flags, + size_t size, + size_t* actual_size_out); + static at::DataPtr makeDataPtr( + WithFd, + const char* filename, + int fd, + int flags, + size_t size, + size_t* actual_size_out); + + // Closes the data. Helps us avoid destructor shenanigans + virtual void close(); + + // This is very dangerous. You have to redefine this destructor for each + // subclass + virtual ~MapAllocator(); + + protected: + bool closed_ = false; + std::string filename_; + int flags_ = 0; + ptrdiff_t size_; /* mapped size */ +#ifdef _WIN32 + void* handle_; + void* event_; + std::string eventname_; +#else + int fd_ = -1; +#endif + void* base_ptr_ = nullptr; +}; + +// Base-from-member idiom +struct TORCH_API RefcountedMapAllocatorArgCheck { + RefcountedMapAllocatorArgCheck(int flags); +}; + +class TORCH_API RefcountedMapAllocator : private RefcountedMapAllocatorArgCheck, + public MapAllocator { + public: + RefcountedMapAllocator(const char* filename, int flags, size_t size); + RefcountedMapAllocator( + WithFd, + const char* filename, + int fd, + int flags, + size_t size); + + static RefcountedMapAllocator* fromDataPtr(const at::DataPtr&); + RefcountedMapAllocator(const RefcountedMapAllocator&) = delete; + RefcountedMapAllocator(RefcountedMapAllocator&&) = delete; + RefcountedMapAllocator& operator=(const RefcountedMapAllocator&) = delete; + RefcountedMapAllocator& operator=(RefcountedMapAllocator&&) = delete; + static at::DataPtr makeDataPtr( + const char* filename, + int flags, + size_t size, + size_t* actual_size_out); + static at::DataPtr makeDataPtr( + WithFd, + const char* filename, + int fd, + int flags, + size_t size, + size_t* actual_size_out); + + void* data() const override; + + void incref(); + int decref(); + void close() override; + + ~RefcountedMapAllocator() override { + RefcountedMapAllocator::close(); + } + + protected: + void checkFlags(); + void initializeAlloc(); +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/MatrixRef.h b/phivenv/Lib/site-packages/torch/include/ATen/MatrixRef.h new file mode 100644 index 0000000000000000000000000000000000000000..354d7c241c9a732b34b839c27b59a27f71d2ea2f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/MatrixRef.h @@ -0,0 +1,109 @@ +#pragma once +#include +#include + +namespace at { +/// MatrixRef - Like an ArrayRef, but with an extra recorded strides so that +/// we can easily view it as a multidimensional array. +/// +/// Like ArrayRef, this class does not own the underlying data, it is expected +/// to be used in situations where the data resides in some other buffer. +/// +/// This is intended to be trivially copyable, so it should be passed by +/// value. +/// +/// For now, 2D only (so the copies are actually cheap, without having +/// to write a SmallVector class) and contiguous only (so we can +/// return non-strided ArrayRef on index). +/// +/// P.S. dimension 0 indexes rows, dimension 1 indexes columns +template +class MatrixRef { + public: + typedef size_t size_type; + + private: + /// Underlying ArrayRef + ArrayRef arr; + + /// Stride of dim 0 (outer dimension) + size_type stride0; + + // Stride of dim 1 is assumed to be 1 + + public: + /// Construct an empty Matrixref. + /*implicit*/ MatrixRef() : arr(nullptr), stride0(0) {} + + /// Construct an MatrixRef from an ArrayRef and outer stride. + /*implicit*/ MatrixRef(ArrayRef arr, size_type stride0) + : arr(arr), stride0(stride0) { + TORCH_CHECK( + arr.size() % stride0 == 0, + "MatrixRef: ArrayRef size ", + arr.size(), + " not divisible by stride ", + stride0) + } + + /// @} + /// @name Simple Operations + /// @{ + + /// empty - Check if the matrix is empty. + bool empty() const { + return arr.empty(); + } + + const T* data() const { + return arr.data(); + } + + /// size - Get size a dimension + size_t size(size_t dim) const { + if (dim == 0) { + return arr.size() / stride0; + } else if (dim == 1) { + return stride0; + } else { + TORCH_CHECK( + 0, "MatrixRef: out of bounds dimension ", dim, "; expected 0 or 1"); + } + } + + size_t numel() const { + return arr.size(); + } + + /// equals - Check for element-wise equality. + bool equals(MatrixRef RHS) const { + return stride0 == RHS.stride0 && arr.equals(RHS.arr); + } + + /// @} + /// @name Operator Overloads + /// @{ + ArrayRef operator[](size_t Index) const { + return arr.slice(Index * stride0, stride0); + } + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + std::enable_if_t, MatrixRef>& operator=( + // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward) + U&& Temporary) = delete; + + /// Disallow accidental assignment from a temporary. + /// + /// The declaration here is extra complicated so that "arrayRef = {}" + /// continues to select the move assignment operator. + template + std::enable_if_t, MatrixRef>& operator=( + std::initializer_list) = delete; +}; + +} // end namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/MemoryOverlap.h b/phivenv/Lib/site-packages/torch/include/ATen/MemoryOverlap.h new file mode 100644 index 0000000000000000000000000000000000000000..f8427eef13cdd1741262f4dcdb84900389157e22 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/MemoryOverlap.h @@ -0,0 +1,42 @@ +#pragma once + +#include + +namespace c10 { +struct TensorImpl; +} + +namespace at { +class TensorBase; + +// MemOverlap: Whether or not there is memory overlap +// +// No: Absolutely no memory overlap +// Yes: Absolutely yes memory overlap +// TooHard: There might be memory overlap, but it was too expensive to compute. +// +// NB: Please update the python test for these if you renumber them. +enum class MemOverlap { No, Yes, TooHard }; + +enum class MemOverlapStatus { Full, Partial, No, TooHard }; + +TORCH_API MemOverlap has_internal_overlap(const TensorBase& t); +TORCH_API MemOverlap has_internal_overlap(c10::TensorImpl* t); + +TORCH_API void assert_no_internal_overlap(const TensorBase& t); +TORCH_API void assert_no_internal_overlap(c10::TensorImpl* t); + +TORCH_API MemOverlapStatus +get_overlap_status(const TensorBase& a, const TensorBase& b); +TORCH_API MemOverlapStatus +get_overlap_status(const c10::TensorImpl* a, const c10::TensorImpl* b); + +TORCH_API void assert_no_partial_overlap( + const TensorBase& a, + const TensorBase& b); +void assert_no_partial_overlap(c10::TensorImpl* a, c10::TensorImpl* b); + +TORCH_API void assert_no_overlap(const TensorBase& a, const TensorBase& b); +TORCH_API void assert_no_overlap(c10::TensorImpl* a, c10::TensorImpl* b); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/MetaFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/MetaFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..4a676f9d883198b1af84395035e2fc20cb0954c6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/MetaFunctions.h @@ -0,0 +1,29 @@ +#include + +// TODO Undo all logic introduced for Note [Avoiding Include Cycles In Static Dispatch] +// Code introduced to avoid cyclic dependency in static dispatch is no longer +// needed as static dispatch logic is moved from TensorBody.h, which caused cycles in the first place, +// to Operators.cpp for supporting multiple backends with multiple kernels. +// +// Note [Avoiding Include Cycles In Static Dispatch] +// In order to avoid #include cycles in the static dispatch build, we've carefully split out +// the static function definition files into {DispatchKey}Functions.h and {DispatchKey}Functions_inl.h. +// +// Without this split, the include cycle looks like TensorBody.h -> CPUFunctions.h -> TensorBody.h. +// - TensorBody.h #includes CPUFunctions.h in the static dispatch build, because the tensor methods +// all need to call into the fastpath C++ API defined in CPUFunctions.h. The methods are also all +// directly inlined into TensorBody.h. +// - CPUFunctions.h #includes TensorBody.h because it contains function declarations for the entire C++ API, +// which include functions that have defaultable std::optional arguments. +// That requires knowing the full Tensor class definition. +// +// We break the cycle by doing the following: +// - Split out CPUFunction.h into two files: CPUFunctions.h and CPUFunctions_inl.h +// - CPUFunction.h is a dummy file that just includes the Tensor class and includes CPUFunctions_inl., +// - CPUFunctions_inl.h includes everything else +// - (only in the static dispatch build) TensorBody.h makes sure to finish defining the Tensor class, +// and then it includes CPUFunctions_inl.h. +// - All other files that want the cpu fastpath functions can include CPUFunctions.h directly. +// - This also means that static dispatch build, CPUFunctions.h only needs to +// #include TensorBody.h, and it will automatically bring in CPUFunctions_inl.h. +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..0465a51e24beda86aa3ac4a1444fd31610545af0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/MetaFunctions_inl.h @@ -0,0 +1,326 @@ +#pragma once +// @generated by torchgen/gen.py from DispatchKeyFunctions_inl.h + +// NB: The implementing C++ file is RegisterDispatchKey.cpp + +// The only #includes we need are for custom classes that have defaults in the C++ API +#include +#include +#include + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + . \ + See NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/MethodOperators.h b/phivenv/Lib/site-packages/torch/include/ATen/MethodOperators.h new file mode 100644 index 0000000000000000000000000000000000000000..c9848f67d4b24fcb5d69f0396de5930271b4ac64 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/MethodOperators.h @@ -0,0 +1,443 @@ +#pragma once + +// @generated by torchgen/gen.py from MethodOperators.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +// Forward declarations of any types needed in the operator signatures. +// We can't directly include these classes because it will cause circular include dependencies. +// This file is included by TensorBody.h, which defines the Tensor class. +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { +namespace _ops { + +} // namespace _ops +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/NamedTensor.h b/phivenv/Lib/site-packages/torch/include/ATen/NamedTensor.h new file mode 100644 index 0000000000000000000000000000000000000000..b18f8d95b195a19fc5c78cc941b7ce6de28f4534 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/NamedTensor.h @@ -0,0 +1 @@ +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/NamedTensorUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/NamedTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..a8974221b785bf12e4d9d07fdc92c56803fcbcb2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/NamedTensorUtils.h @@ -0,0 +1,214 @@ +#pragma once +#include +#include +#include + +#include +#include + +namespace at { + +using NameVector = SmallVector; + +inline bool has_names(const ITensorListRef& tensors) { + return std::any_of(tensors.begin(), tensors.end(), [](const Tensor& t) { + return t.has_names(); + }); +} + +// Converts dim to an positional index. Errors if `dim` cannot be used to +// refer to any dimension of tensor. +TORCH_API int64_t dimname_to_position(const Tensor& tensor, Dimname dim); +TORCH_API std::vector dimnames_to_positions( + const Tensor& tensor, + DimnameList dims); + +// Unifies two DimnameList to produce a third. This is useful for implementing +// the named inference rule for binary broadcasting operations like add. +// +// There are three main constraints: +// 1) Check matching: Names must match positionally from the right. +// 2) Check misaligned: If a name `n` is in `names`, then it must appear at +// the same index from the right in other. +// 3) The output names are obtained by unifying the names individually from the +// right. +TORCH_API std::vector unify_from_right( + DimnameList names, + DimnameList other, + const char* action = "broadcast"); + +[[noreturn]] inline void reportNYIDimnameOverload(const char* op_name) { + TORCH_CHECK( + false, + op_name, + ": You passed a dimname (string) to this op in place of a dimension " + "index but it does not yet support this behavior. Please pass a dimension " + "index to work around this."); +} + +// [NOTE] Writing name inference rules +// +// Operators that support named tensors are either composed of operations that +// support named tensors or implement some name inference rule. An op that +// implements its own name inference rule generally looks like the following: +// +// Tensor op(...) { +// perform_shape_checks(...); +// # (1) +// auto maybe_outnames = compute_outnames(...); +// auto result = [&]() { +// NoNamesGuard guard; +// return op_impl(...); +// }(); +// # (2) +// propagate_names_if_nonempty(result, maybe_outnames); +// +// Each op has (1) a compute outnames step and (2) a propagate names step. +// +// compute_outnames is responsible for checking that input names match and +// determining what the output names should be. It returns either: +// - {} (if the inputs tensors are all unnamed) +// - non-empty outnames. +// +// propagate_names_if_nonempty propagates the outnames if they exist to the +// result tensors. +// +// The {} case is an optimization; if the user does not use named tensors they +// pay no perf cost for it. + +namespace namedinference { + +const Tensor& propagate_names_if_present_and_nonempty( + const Tensor& result, + std::optional maybe_names, + bool validate_names = false); +// Propagates `names` to `result` if `names` is not empty. +// `names` can be empty; see [NOTE] Writing name inference rules +// If `names` is not empty, `names.size()` should equal `result.dim()`. +// When in doubt, use this overload instead of the others. +TORCH_API const Tensor& propagate_names_if_nonempty( + const Tensor& result, + DimnameList maybe_names, + bool validate_names = false); + +// Propagates `names` to `result`. Only use this if we are certain that there +// are names to propagate (that names is not empty). +TORCH_API const Tensor& propagate_names( + const Tensor& result, + DimnameList names, + bool validate_names = false); + +// Propagates all names from src to result. +TORCH_API void propagate_names(const Tensor& result, const Tensor& src); + +// Propagates all names except for those at the excluded_idxs. +TORCH_API void propagate_names_except( + const Tensor& result, + const Tensor& src, + IntArrayRef excluded_idxs); + +// Used for reduction ops that have a `keepdim` arg. +TORCH_API void propagate_names_for_reduction( + const Tensor& result, + const Tensor& src, + IntArrayRef excluded_idxs, + bool keepdim); + +TORCH_API void propagate_names_for_expand( + const Tensor& result, + const Tensor& self); + +TORCH_API std::vector compute_cat_outnames( + const MaterializedITensorListRef& tensors); + +TORCH_API std::vector compute_broadcast_outnames( + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector broadcast_to_outnames( + const Tensor& tensor, + const Tensor& reference_tensor, + const char* op_name); + +TORCH_API std::vector compute_matmul_outnames( + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector compute_cdist_outnames( + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector compute_bmm_outnames( + const Tensor& result, + const Tensor& self, + const Tensor& other); + +TORCH_API std::vector compute_squeeze_outnames(const Tensor& tensor); +TORCH_API std::vector compute_squeeze_outnames( + const Tensor& tensor, + std::bitset dims); + +std::vector compute_diagonal_outnames( + const Tensor& tensor, + int64_t dim1, + int64_t dim2); + +// TensorImpl* overloads for Legacy TH/THC code. Use these sparingly. + +TORCH_API TensorImpl* propagate_names_if_nonempty( + TensorImpl* result, + DimnameList maybe_names, + bool validate_names = false); + +TORCH_API TensorImpl* propagate_names( + TensorImpl* result, + DimnameList names, + bool validate_names = false); + +TORCH_API void propagate_names(TensorImpl* result, /*const */ TensorImpl* src); + +TORCH_API inline void propagate_names( + const TensorBase& result, + DimnameList names, + bool validate_names = false) { + propagate_names(result.unsafeGetTensorImpl(), names, validate_names); +} + +TORCH_API inline void propagate_names_if_nonempty( + const TensorBase& result, + DimnameList names, + bool validate_names = false) { + propagate_names_if_nonempty( + result.unsafeGetTensorImpl(), names, validate_names); +} + +TORCH_API inline void propagate_names( + const TensorBase& result, + const TensorBase& src) { + propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl()); +} + +// result = m1 @ m2 + bias +TORCH_API std::vector propagate_names_for_addmm( + const Tensor& m1, + const Tensor& m2, + const Tensor& bias); + +TORCH_API std::vector propagate_names_for_addmv( + const Tensor& mat, + const Tensor& vec, + const Tensor& bias); + +TORCH_API void check_names_for_dot(TensorImpl* vec1, TensorImpl* vec2); + +TORCH_API std::vector compute_baddbmm_outnames( + const Tensor& result, + const Tensor& self, + const Tensor& other, + const Tensor& bias); + +TORCH_API bool are_names_equal(TensorImpl* self, TensorImpl* other); + +} // namespace namedinference + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/NativeFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/NativeFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..4572b2ce11d509c5e81d95f6c2f30bb058b601fc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/NativeFunctions.h @@ -0,0 +1,1355 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeFunctions.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + diff --git a/phivenv/Lib/site-packages/torch/include/ATen/NativeMetaFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/NativeMetaFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..b9aca975f77e3edee853d99e18c47d0e95ea2e01 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/NativeMetaFunctions.h @@ -0,0 +1,1341 @@ +#pragma once + +// @generated by torchgen/gen.py from NativeMetaFunctions.h + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +namespace meta { + + + +} // namespace meta +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/NestedTensorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/NestedTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..56b7bf40a78d02588e2d2d253c12017cedf8bed1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/NestedTensorImpl.h @@ -0,0 +1,286 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at::native { +struct NestedTensorImpl; +inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt); +int64_t get_numel_from_nested_size_tensor(const at::Tensor& tensor); +at::Tensor construct_nested_strides(const at::Tensor& nested_size); +at::Tensor construct_offsets(const at::Tensor& nested_size); + +struct TORCH_API NestedTensorImpl : public c10::TensorImpl { + explicit NestedTensorImpl( + Storage storage, + c10::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets); + + explicit NestedTensorImpl( + const at::Tensor& buffer, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets); + // assume contiguous, `nested_strides` and `offsets` + // can be infered from `nested_sizes` + explicit NestedTensorImpl( + const at::Tensor& buffer, + const at::Tensor& nested_sizes); + + // This constructor is used creating view tensors from nested tensors + explicit NestedTensorImpl( + c10::TensorImpl::ImplType impl_type, + const at::Tensor& base_tensor, + at::Tensor nested_sizes, + at::Tensor nested_strides, + at::Tensor storage_offsets); + + // TODO: don't expose private implementation details like this; in + // particular, resizing this tensor will mess up our dim() and + // callers cannot fix it. + const Tensor& get_nested_sizes() const { + return nested_sizes_; + } + // TODO: don't expose private implementation details like this + const Tensor& get_nested_strides() const { + return nested_strides_; + } + const Tensor& get_storage_offsets() const { + return storage_offsets_; + } + // Returns nullopt if the ith dimension is irregular. The ith dimension + // of a NestedTensor is regular if the unbound tensors match in + // size at the (i-1)th dimension. + std::optional opt_size(int64_t d) const; + + int64_t size(int64_t d) const { + std::optional optional_size = this->opt_size(d); + TORCH_CHECK( + optional_size.has_value(), + "Given dimension ", + d, + " is irregular and does not have a size."); + return *optional_size; + } + /** + * Return a view of the nested tensor as a 1 dimensional contiguous tensor. + * + * The buffer tensor created by this function shares the same storage_impl as + * the original nested tensor, and therefore can be seen as a view. + * + * @return A newly constructed view tensor + */ + at::Tensor get_buffer() const { + TORCH_CHECK( + nested_tensor_impl_is_contiguous(this), + "NestedTensor must be contiguous to get buffer."); + return get_unsafe_storage_as_tensor(); + } + /** + * If possible use get_buffer() instead. This function returns the storage + * as a tensor directly, which is not safe to use in general. If using this + * function, The caller must ensure to account for nested_sizes, + * nested_strides and storage_offsets. + * + * @return A newly constructed view tensor + */ + at::Tensor get_unsafe_storage_as_tensor() const { + auto buffer_key_set_ = generate_buffer_key_set(); + const auto buffer_size = get_buffer_size(); + auto buffer_tensor_impl = c10::make_intrusive( + c10::TensorImpl::VIEW, Storage(storage_), buffer_key_set_, data_type_); + buffer_tensor_impl->set_sizes_contiguous( + c10::makeArrayRef(static_cast(buffer_size))); + return Tensor(buffer_tensor_impl); + } + + size_t get_buffer_size() const { + return storage_.nbytes() / data_type_.itemsize(); + } + + protected: + const char* tensorimpl_type_name() const override; + + // TODO: numel_custom and is_contiguous_custom can be profitably overridden + // with real implementations + int64_t numel_custom() const override; + c10::SymInt sym_numel_custom() const override; + bool is_contiguous_custom(MemoryFormat) const override; + int64_t size_custom(int64_t d) const override { + return this->size(d); + } + c10::SymInt sym_size_custom(int64_t d) const override { + return c10::SymInt{this->size(d)}; + } + IntArrayRef sizes_custom() const override; + c10::SymIntArrayRef sym_sizes_custom() const override; + IntArrayRef strides_custom() const override; + c10::SymIntArrayRef sym_strides_custom() const override; + + // this one is real + int64_t dim_custom() const override; + + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override; + + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override; + + void shallow_copy_from(const c10::intrusive_ptr& impl) override { + copy_tensor_metadata( + /*src_impl=*/impl.get(), + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + } + + private: + // Must be called after any changes to our dim() to sync the state + // to TensorImpl. + void refresh_dim(); + + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const at::Tensor nested_sizes_, nested_strides_; + // The starting positions of the underlying tensors in contiguous buffer + // i.e. the buffer memory offsets to get the underlying tensors + // The reason to keep this metadata is that, without strong enough constraint + // it cannot be derived from `nested_sizes_` + // and `nested_strides_`: + // 1. when buffer has blanks, e.g. [tensor1, blank, tensor2] + // this can happen e.g. after slicing a nested tensor + // 2. when multiple tensors share a same memory + // 3. when the nesting ordering is changed, e.g. [tensor1, tensor3, tensor2] + // Some strong enough constraints are: + // 1. every underlying tensor is contiguous in memory + // && nesting in ascending order + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const at::Tensor storage_offsets_; + // NOTE: -1 here means the size is missing + // Optional to allow it to be computed lazily from nested. + // TODO: maybe we can remove this metadata since + // we can compute it from `nested_sizes_` + mutable std::optional> opt_sizes_; + + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const; + + /** + * Generates a non-nested key_set from a nested tensor. + * + * For many nested tensor kernel implementations a buffer tensor + * is generated and redispatched to a non-nested kernel this function + * generates the key set used by that buffer tensor + * + * @return Appropriate key set for non-nested tensor + */ + inline c10::DispatchKeySet generate_buffer_key_set() const { + auto buffer_key_set = this->key_set(); + const bool Autograd = buffer_key_set.has_any(c10::autograd_dispatch_keyset); + // Remove nested tensor specific keys + buffer_key_set = buffer_key_set - + c10::DispatchKeySet{ + c10::DispatchKey::NestedTensor, + c10::DispatchKey::AutogradNestedTensor}; + + // Add dense tensor specific keys + buffer_key_set = + buffer_key_set | c10::DispatchKeySet{c10::DispatchKey::Dense}; + buffer_key_set = Autograd + ? c10::DispatchKeySet{c10::DispatchKey::Autograd} | buffer_key_set + : buffer_key_set; + + return buffer_key_set; + } +}; + +inline NestedTensorImpl* get_nested_tensor_impl_or_null( + const at::Tensor& tensor) { + if (tensor.is_nested()) { + return static_cast(tensor.unsafeGetTensorImpl()); + } + return nullptr; +} + +inline NestedTensorImpl* get_nested_tensor_impl(const at::Tensor& tensor) { + TORCH_CHECK( + tensor.is_nested(), "get_nested_tensor_impl requires a NestedTensor."); + return static_cast(tensor.unsafeGetTensorImpl()); +} + +inline bool nested_tensor_impl_is_contiguous(const NestedTensorImpl* nt) { + int64_t ntensors = nt->size(0); + if (ntensors == 0) { + return true; + } + const Tensor &sizemat = nt->get_nested_sizes(), + &stridemat = nt->get_nested_strides(); + const int64_t* offsets_ptr = + nt->get_storage_offsets().const_data_ptr(); + int64_t orig_dim = sizemat.size(1); + // nesting scalars + if (orig_dim == 0) { + // each scalar must be contiguous + // if there is blank memory between underlying scalars + for (int64_t i = 0; i < ntensors; i++) { + if (offsets_ptr[i] != i) { + return false; + } + } + } + // nesting tensors + else { + // if any underlying tensor is non-contiguous + const int64_t *sizemat_ptr = sizemat.const_data_ptr(), + *stridemat_ptr = stridemat.const_data_ptr(); + for (int64_t i = 0; i < ntensors; i++) { + if (stridemat_ptr[orig_dim - 1] != 1) { + return false; + } + int64_t product = sizemat_ptr[orig_dim - 1]; + for (int64_t j = orig_dim - 2; j >= 0; j--) { + if (stridemat_ptr[j] != product) { + return false; + } + product *= sizemat_ptr[j]; + } + sizemat_ptr += orig_dim; + stridemat_ptr += orig_dim; + } + // if there is blank memory between underlying tensors + if (offsets_ptr[0] != 0) { + return false; + } + sizemat_ptr = sizemat.const_data_ptr(); + stridemat_ptr = stridemat.const_data_ptr(); + for (int64_t i = 1; i < ntensors; i++) { + if (offsets_ptr[i] != + offsets_ptr[i - 1] + *sizemat_ptr * *stridemat_ptr) { + return false; + } + sizemat_ptr += orig_dim; + stridemat_ptr += orig_dim; + } + } + // everything is fine + return true; +} + +inline const at::Tensor& get_nested_sizes(const at::Tensor& tensor) { + return get_nested_tensor_impl(tensor)->get_nested_sizes(); +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/NumericUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/NumericUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..d9a95536de5b328b3df2e1918d8744fbc1b18e24 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/NumericUtils.h @@ -0,0 +1,203 @@ +#pragma once + +#ifdef __HIPCC__ +#include +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { + +// std::isnan isn't performant to use on integral types; it will +// (uselessly) convert to floating point and then do the test. +// This function is. + +template , int> = 0> +inline C10_HOST_DEVICE bool _isnan(T /*val*/) { + return false; +} + +template , int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return ::isnan(val); +#else + return std::isnan(val); +#endif +} + +template ::value, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return std::isnan(val.real()) || std::isnan(val.imag()); +} + +template , int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return at::_isnan(static_cast(val)); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { + return at::_isnan(static_cast(val)); +} + +inline C10_HOST_DEVICE bool _isnan(at::BFloat16 val) { + return at::_isnan(static_cast(val)); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +template < + typename T, + std::enable_if_t, int> = 0> +inline C10_HOST_DEVICE bool _isnan(T val) { + return val.isnan(); +} + +// std::isinf isn't performant to use on integral types; it will +// (uselessly) convert to floating point and then do the test. +// This function is. + +template , int> = 0> +inline C10_HOST_DEVICE bool _isinf(T /*val*/) { + return false; +} + +template , int> = 0> +inline C10_HOST_DEVICE bool _isinf(T val) { +#if defined(__CUDACC__) || defined(__HIPCC__) + return ::isinf(val); +#else + return std::isinf(val); +#endif +} + +inline C10_HOST_DEVICE bool _isinf(at::Half val) { + return at::_isinf(static_cast(val)); +} + +inline C10_HOST_DEVICE bool _isinf(at::BFloat16 val) { + return at::_isinf(static_cast(val)); +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2 val) { + return val.isinf(); +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fn val [[maybe_unused]]) { + return false; +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e5m2fnuz val [[maybe_unused]]) { + return false; +} + +inline C10_HOST_DEVICE bool _isinf(at::Float8_e4m3fnuz val [[maybe_unused]]) { + return false; +} + +template +C10_HOST_DEVICE inline T exp(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __expf fast approximation for peak bandwidth + return __expf(x); +#else + return ::exp(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double exp(double x) { + return ::exp(x); +} + +template +C10_HOST_DEVICE inline T log(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __logf fast approximation for peak bandwidth + return __logf(x); +#else + return ::log(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double log(double x) { + return ::log(x); +} + +template +C10_HOST_DEVICE inline T log1p(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __logf fast approximation for peak bandwidth + // NOTE: There is no __log1pf so unfortunately we lose precision. + return __logf(1.0f + x); +#else + return ::log1p(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double log1p(double x) { + return ::log1p(x); +} + +template +C10_HOST_DEVICE inline T tan(T x) { + static_assert( + !std::is_same_v, + "this template must be used with float or less precise type"); +#if defined(__CUDA_ARCH__) || defined(__HIP_ARCH__) + // use __tanf fast approximation for peak bandwidth + return __tanf(x); +#else + return ::tan(x); +#endif +} + +template <> +C10_HOST_DEVICE inline double tan(double x) { + return ::tan(x); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/OpMathType.h b/phivenv/Lib/site-packages/torch/include/ATen/OpMathType.h new file mode 100644 index 0000000000000000000000000000000000000000..b540bdf4740e7664e2f0b2979897178b92b99a4f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/OpMathType.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// For FP16 or BFloat16 inputs, ops should perform internal math in FP32. +template +struct OpMathType { + using type = scalar_t; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType { + using type = float; +}; +template <> +struct OpMathType> { + using type = c10::complex; +}; + +template +using opmath_type = typename OpMathType::type; + +namespace { + +inline c10::ScalarType toOpMathType(const c10::ScalarType type) { + switch (type) { +#define DEFINE_CASE(scalar_t, TypeNum) \ + case ScalarType::TypeNum: \ + return CppTypeToScalarType>::value; + + AT_FORALL_SCALAR_TYPES_WITH_COMPLEX(DEFINE_CASE) +#undef DEFINE_CASE + + default: + TORCH_INTERNAL_ASSERT(false, "Unrecognized ScalarType: ", type); + } +} + +} // namespace + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/OpaqueTensorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/OpaqueTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..e4128783f04bc51befef326dfcea4cf22228b348 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/OpaqueTensorImpl.h @@ -0,0 +1,206 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { + +// An "Opaque" TensorImpl -- there are no strides and (for now) +// even data() is not supported (thus no pointer arithmetic). + +// NOTE: We could allow data() in the future, but would have to ensure pointer +// arithmetic code is properly guarded. +// +// NOTE: This does not support resize_ (and other metadata-changing ops) because +// of `shallow_copy_and_detach`. We would need to define an interface to +// "shallow copy" in order to add support. + +template +struct TORCH_API OpaqueTensorImpl : public TensorImpl { + // public constructor for now... + OpaqueTensorImpl( + at::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + c10::Device device, + OpaqueHandle opaque_handle, + c10::IntArrayRef sizes, + bool is_non_overlapping_and_dense = true) + : TensorImpl(key_set, data_type, device), + opaque_handle_(std::move(opaque_handle)) { + constructor_impl(sizes, is_non_overlapping_and_dense); + } + + OpaqueTensorImpl( + TensorImpl::ImplType impl_type, + c10::Storage&& storage, + at::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + OpaqueHandle opaque_handle, + c10::IntArrayRef sizes, + bool is_non_overlapping_and_dense = true) + : TensorImpl(impl_type, std::move(storage), key_set, data_type), + opaque_handle_(std::move(opaque_handle)) { + constructor_impl(sizes, is_non_overlapping_and_dense); + } + + // Destructor doesn't call release_resources because it's + // unnecessary; don't forget to change that if needed! + void release_resources() override { + TensorImpl::release_resources(); + opaque_handle_ = {}; + } + + void set_size(int64_t dim, int64_t new_size) override { + TORCH_CHECK(false, "opaque tensors do not have set_size"); + } + + void set_stride(int64_t dim, int64_t new_stride) override { + TORCH_CHECK(false, "opaque tensors do not have set_stride"); + } + + void set_storage_offset(int64_t storage_offset) override { + TORCH_CHECK(false, "opaque tensors do not have set_storage_offset"); + } + +#ifdef DEBUG + bool has_storage() const override { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + !storage_, "OpaqueTensorImpl assumes that storage_ is never set"); + return false; + } +#endif + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive>( + key_set(), + dtype(), + device(), + opaque_handle_, + sizes_and_strides_.sizes_arrayref()); + copy_tensor_metadata( + /*src_opaque_impl=*/this, + /*dest_opaque_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + auto impl = c10::make_intrusive>( + key_set(), + dtype(), + device(), + opaque_handle_, + sizes_and_strides_.sizes_arrayref()); + copy_tensor_metadata( + /*src_opaque_impl=*/this, + /*dest_opaque_impl=*/impl.get(), + /*version_counter=*/std::move(version_counter), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + + /** + * Shallow-copies data from another TensorImpl into this TensorImpl. + * + * For why this function doesn't check this TensorImpl's + * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. + */ + void shallow_copy_from(const c10::intrusive_ptr& impl) override { + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); + auto opaque_impl = + static_cast*>(impl.get()); + copy_tensor_metadata( + /*src_impl=*/opaque_impl, + /*dest_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + refresh_numel(); + } + + const OpaqueHandle& opaque_handle() const { + return opaque_handle_; + } + + OpaqueHandle& unsafe_opaque_handle() { + return opaque_handle_; + } + + protected: + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const OpaqueTensorImpl* src_opaque_impl, + OpaqueTensorImpl* dest_opaque_impl, + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_opaque_impl, + dest_opaque_impl, + version_counter, + allow_tensor_metadata_change); + + // OpaqueTensorImpl-specific fields. + dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; + } + + static void copy_tensor_metadata( + const OpaqueTensorImpl* src_opaque_impl, + OpaqueTensorImpl* dest_opaque_impl, + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_opaque_impl, + dest_opaque_impl, + std::move(version_counter), + allow_tensor_metadata_change); + + // OpaqueTensorImpl-specific fields. + dest_opaque_impl->opaque_handle_ = src_opaque_impl->opaque_handle_; + } + + private: + const char* tensorimpl_type_name() const override { + return "OpaqueTensorImpl"; + } + + void constructor_impl( + c10::IntArrayRef sizes, + bool is_non_overlapping_and_dense) { + set_storage_access_should_throw(); + set_custom_sizes_strides(SizesStridesPolicy::CustomStrides); + sizes_and_strides_.set_sizes(sizes); + refresh_numel(); + // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer) + is_non_overlapping_and_dense_ = is_non_overlapping_and_dense; + } + + OpaqueHandle opaque_handle_; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Operators.h b/phivenv/Lib/site-packages/torch/include/ATen/Operators.h new file mode 100644 index 0000000000000000000000000000000000000000..764cb106347f0ebef1e9cccac1892cb33c00b886 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Operators.h @@ -0,0 +1,1396 @@ +#pragma once + +// @generated by torchgen/gen.py from Operators.h + +#ifdef TORCH_ASSERT_NO_OPERATORS +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if your change would be better placed in \ + another file, or if a more specific header might achieve the same goal. \ + See NOTE: [Tensor vs. TensorBase] +#endif + +#if defined(AT_PER_OPERATOR_HEADERS) && defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider including a specific operator from \ + and see NOTE [TORCH_ASSERT_ONLY_METHOD_OPERATORS]. +#endif + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// Extension writers: do you write wrapper functions? Are you frustrated with +// resolving overloads of operators? Are you frustrated with dealing with +// pointer-to-methods and resolving overloads of pointer-to-methods?? Look no +// further, this is the utility for you. +// +// Given an operator schema: aten::op.overload(... +// +// Use ATEN_FN2(op, overload) to get a *function* version of the operator +// that is guaranteed to not be overloaded. This means that you can safely +// decltype(&ATEN_FN2(op, overload)) it. NB: the 2 means this macro takes 2 args. +// +// Given an operator schema without an overload name: aten::op(... +// +// Use ATEN_FN(op) to get an unambiguous *function* version of the operator. +// +// There is some interesting behavior for out= operations. +// ATEN_FN2(sin, out) gives a function that is *faithful* to the schema; +// that is, the order of arguments is exactly what it looks like in the schema. + +#define ATEN_FN2(op_name, overload) at::_ops::op_name##_##overload::call +#define ATEN_FN(op_name) at::_ops::op_name::call + +// Separately, ATEN_OP(op) and ATEN_OP2(op, overload) define a class containing compile-time +// metadata about a given aten operator. +// Notable data on the class includes: +// - ATEN_OP2(add, Tensor)::name // returns the string name: "add" +// - ATEN_OP2(add, Tensor)::overload_name // returns the string overload name: "Tensor" +// - ATEN_OP2(add, Tensor)::schema // returns the C++ schema type: at::Tensor (const at::Tensor &, const at::Tensor &, const at::Scalar &) +// - ATEN_OP2(add, Tensor)::schema_str // returns the string jit type: "add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor" + +#define ATEN_OP2(op_name, overload) at::_ops::op_name##_##overload +#define ATEN_OP(op_name) at::_ops::op_name + +// WARNING: Please do not call any of the ops in the _ops namespace directly. +// Use the ATEN_FN macros. We do not guarantee stability of the naming +// scheme for the functions in at::_ops + +// See Note [The ATen Operators API] for details of the at::_ops namespace + +namespace at { +namespace _ops { + +} // namespace _ops +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/PTThreadPool.h b/phivenv/Lib/site-packages/torch/include/ATen/PTThreadPool.h new file mode 100644 index 0000000000000000000000000000000000000000..d18d80161296db96fc6cc0c89ba4546490b6e5a4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/PTThreadPool.h @@ -0,0 +1,17 @@ +#pragma once + +#include +#include + +namespace at { + +class TORCH_API PTThreadPool : public c10::ThreadPool { + public: + explicit PTThreadPool(int pool_size, int numa_node_id = -1) + : c10::ThreadPool(pool_size, numa_node_id, []() { + c10::setThreadName("PTThreadPool"); + at::init_num_threads(); + }) {} +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/PadNd.h b/phivenv/Lib/site-packages/torch/include/ATen/PadNd.h new file mode 100644 index 0000000000000000000000000000000000000000..fb7b1558a37e596ecb7df928af2533707d020bec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/PadNd.h @@ -0,0 +1,12 @@ +#pragma once + +namespace at { + +enum class padding_mode { + reflect, + replicate, + circular, + constant, +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Parallel-inl.h b/phivenv/Lib/site-packages/torch/include/ATen/Parallel-inl.h new file mode 100644 index 0000000000000000000000000000000000000000..966aa4b6371df7442cade150cae890bd772e4491 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Parallel-inl.h @@ -0,0 +1,93 @@ +#pragma once + +#include +#include +#include + +namespace at { + +template +inline void parallel_for( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const F& f) { + TORCH_INTERNAL_ASSERT_DEBUG_ONLY(grain_size >= 0); + if (begin >= end) { + return; + } + +#ifdef INTRA_OP_PARALLEL + at::internal::lazy_init_num_threads(); + const auto numiter = end - begin; + const bool use_parallel = + (numiter > grain_size && numiter > 1 && !at::in_parallel_region() && + at::get_num_threads() > 1); + if (!use_parallel) { + internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + f(begin, end); + return; + } + + internal::invoke_parallel( + begin, end, grain_size, [&](int64_t begin, int64_t end) { + c10::ParallelGuard guard(true); + f(begin, end); + }); +#else + internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + f(begin, end); +#endif +} + +template +inline scalar_t parallel_reduce( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const scalar_t ident, + const F& f, + const SF& sf) { + TORCH_CHECK(grain_size >= 0); + if (begin >= end) { + return ident; + } + +#ifdef INTRA_OP_PARALLEL + at::internal::lazy_init_num_threads(); + const auto max_threads = at::get_num_threads(); + const bool use_parallel = + ((end - begin) > grain_size && !at::in_parallel_region() && + max_threads > 1); + if (!use_parallel) { + internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + return f(begin, end, ident); + } + + c10::SmallVector results(max_threads, ident); + internal::invoke_parallel( + begin, + end, + grain_size, + [&](const int64_t my_begin, const int64_t my_end) { + const auto tid = at::get_thread_num(); + c10::ParallelGuard guard(true); + results[tid] = f(my_begin, my_end, ident); + }); + + scalar_t result = ident; + for (auto partial_result : results) { + result = sf(result, partial_result); + } + return result; +#else + internal::ThreadIdGuard tid_guard(0); + c10::ParallelGuard guard(true); + return f(begin, end, ident); +#endif +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Parallel.h b/phivenv/Lib/site-packages/torch/include/ATen/Parallel.h new file mode 100644 index 0000000000000000000000000000000000000000..e191b802e0643febd7ee49472f20e1f67f8cebaf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Parallel.h @@ -0,0 +1,158 @@ +#pragma once +#include +#include +#include +#include + +namespace at { + +inline int64_t divup(int64_t x, int64_t y) { + return (x + y - 1) / y; +} + +// Called during new thread initialization +TORCH_API void init_num_threads(); + +// Sets the number of threads to be used in parallel region +TORCH_API void set_num_threads(int); + +// Returns the maximum number of threads that may be used in a parallel region +TORCH_API int get_num_threads(); + +// Returns the current thread number (starting from 0) +// in the current parallel region, or 0 in the sequential region +TORCH_API int get_thread_num(); + +// Checks whether the code runs in parallel region +TORCH_API bool in_parallel_region(); + +namespace internal { + +// Initialise num_threads lazily at first parallel call +inline void lazy_init_num_threads() { + thread_local bool init = false; + if (C10_UNLIKELY(!init)) { + at::init_num_threads(); + init = true; + } +} + +TORCH_API void set_thread_num(int); + +class TORCH_API ThreadIdGuard { + public: + ThreadIdGuard(int new_id) : old_id_(at::get_thread_num()) { + set_thread_num(new_id); + } + + ~ThreadIdGuard() { + set_thread_num(old_id_); + } + + private: + int old_id_; +}; + +} // namespace internal + +/* +parallel_for + +begin: index at which to start applying user function + +end: index at which to stop applying user function + +grain_size: number of elements per chunk. impacts the degree of parallelization + +f: user function applied in parallel to the chunks, signature: + void f(int64_t begin, int64_t end) + +Warning: parallel_for does NOT copy thread local +states from the current thread to the worker threads. +This means for example that Tensor operations CANNOT be used in the +body of your function, only data pointers. +*/ +template +inline void parallel_for( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const F& f); + +/* +parallel_reduce + +begin: index at which to start applying reduction + +end: index at which to stop applying reduction + +grain_size: number of elements per chunk. impacts number of elements in +intermediate results tensor and degree of parallelization. + +ident: identity for binary combination function sf. sf(ident, x) needs to return +x. + +f: function for reduction over a chunk. f needs to be of signature scalar_t +f(int64_t partial_begin, int64_t partial_end, scalar_t identifiy) + +sf: function to combine two partial results. sf needs to be of signature +scalar_t sf(scalar_t x, scalar_t y) + +For example, you might have a tensor of 10000 entires and want to sum together +all the elements. Parallel_reduce with a grain_size of 2500 will then allocate +an intermediate result tensor with 4 elements. Then it will execute the function +"f" you provide and pass the beginning and end index of these chunks, so +0-2499, 2500-4999, etc. and the combination identity. It will then write out +the result from each of these chunks into the intermediate result tensor. After +that it'll reduce the partial results from each chunk into a single number using +the combination function sf and the identity ident. For a total summation this +would be "+" and 0 respectively. This is similar to tbb's approach [1], where +you need to provide a function to accumulate a subrange, a function to combine +two partial results and an identity. + +Warning: parallel_reduce does NOT copy thread local +states from the current thread to the worker threads. +This means for example that Tensor operations CANNOT be used in the +body of your function, only data pointers. + +[1] https://software.intel.com/en-us/node/506154 +*/ +template +inline scalar_t parallel_reduce( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const scalar_t ident, + const F& f, + const SF& sf); + +// Returns a detailed string describing parallelization settings +TORCH_API std::string get_parallel_info(); + +// Sets number of threads used for inter-op parallelism +TORCH_API void set_num_interop_threads(int); + +// Returns the number of threads used for inter-op parallelism +TORCH_API size_t get_num_interop_threads(); + +// Launches inter-op parallel task +TORCH_API void launch(std::function func); +namespace internal { +void launch_no_thread_state(std::function fn); +} // namespace internal + +// Launches intra-op parallel task +TORCH_API void intraop_launch(const std::function& func); + +// Returns number of intra-op threads used by default +TORCH_API int intraop_default_num_threads(); + +} // namespace at + +#if AT_PARALLEL_OPENMP +#include // IWYU pragma: keep +#elif AT_PARALLEL_NATIVE +#include // IWYU pragma: keep +#endif + +#include // IWYU pragma: keep diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ParallelFuture.h b/phivenv/Lib/site-packages/torch/include/ATen/ParallelFuture.h new file mode 100644 index 0000000000000000000000000000000000000000..029716cde0d5ad5c41f653702ef4bb94af6b7174 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ParallelFuture.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include +#include + +namespace at { + +// Launches intra-op parallel task, returns a future +TORCH_API c10::intrusive_ptr intraop_launch_future( + const std::function& func); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ParallelNative.h b/phivenv/Lib/site-packages/torch/include/ATen/ParallelNative.h new file mode 100644 index 0000000000000000000000000000000000000000..9cea49149223c221462949dc3798ce983af21399 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ParallelNative.h @@ -0,0 +1,15 @@ +#pragma once + +#include + +#define INTRA_OP_PARALLEL + +namespace at::internal { + +TORCH_API void invoke_parallel( + const int64_t begin, + const int64_t end, + const int64_t grain_size, + const std::function& f); + +} // namespace at::internal diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ParallelOpenMP.h b/phivenv/Lib/site-packages/torch/include/ATen/ParallelOpenMP.h new file mode 100644 index 0000000000000000000000000000000000000000..40a8830c764543d90f8b0180fa3a91039a537d38 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ParallelOpenMP.h @@ -0,0 +1,54 @@ +#pragma once + +#include +#include +#include +#include + +#ifdef _OPENMP +#define INTRA_OP_PARALLEL + +#include +#endif + +#ifdef _OPENMP +namespace at::internal { +template +inline void invoke_parallel( + int64_t begin, + int64_t end, + int64_t grain_size, + const F& f) { + std::atomic_flag err_flag = ATOMIC_FLAG_INIT; + std::exception_ptr eptr; + +#pragma omp parallel + { + // choose number of tasks based on grain size and number of threads + // can't use num_threads clause due to bugs in GOMP's thread pool (See + // #32008) + int64_t num_threads = omp_get_num_threads(); + if (grain_size > 0) { + num_threads = std::min(num_threads, divup((end - begin), grain_size)); + } + + int64_t tid = omp_get_thread_num(); + int64_t chunk_size = divup((end - begin), num_threads); + int64_t begin_tid = begin + tid * chunk_size; + if (begin_tid < end) { + try { + internal::ThreadIdGuard tid_guard(tid); + f(begin_tid, std::min(end, chunk_size + begin_tid)); + } catch (...) { + if (!err_flag.test_and_set()) { + eptr = std::current_exception(); + } + } + } + } + if (eptr) { + std::rethrow_exception(eptr); + } +} +} // namespace at::internal +#endif // _OPENMP diff --git a/phivenv/Lib/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h b/phivenv/Lib/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h new file mode 100644 index 0000000000000000000000000000000000000000..60cdbca63a774ff9e28bbcd73430fafcc031bb37 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/PythonTorchFunctionTLS.h @@ -0,0 +1,36 @@ +#pragma once + +#include +#include + +namespace at::impl { + +enum TorchFunctionDisabledState { ENABLED, SUBCLASSES_DISABLED, ALL_DISABLED }; + +struct TORCH_API PythonTorchFunctionTLS { + static void set_disabled_state(TorchFunctionDisabledState disabled_state_); + static TorchFunctionDisabledState get_disabled_state(); + + static void push_onto_stack(std::shared_ptr mode); + static const std::shared_ptr pop_stack(); + static const std::shared_ptr& get_stack_at(int64_t idx); + static int64_t stack_len(); + + static const PythonTorchFunctionTLS& get_state(); + static void set_state(const PythonTorchFunctionTLS& state); + + private: + // The mode TLS is split into + // - disabled_state, which says which part of torch function are disabled + // - stack_, which is a vector of modes representing the stack of user + // defined modes + TorchFunctionDisabledState disabled_state_ = + TorchFunctionDisabledState::ENABLED; + std::vector> stack_; +}; + +TORCH_API bool torch_function_mode_enabled(); + +TORCH_API bool torch_function_all_disabled(); + +} // namespace at::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ROCmFABackend.h b/phivenv/Lib/site-packages/torch/include/ATen/ROCmFABackend.h new file mode 100644 index 0000000000000000000000000000000000000000..06773725d349e7a28abd6403da645eaca1dcc845 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ROCmFABackend.h @@ -0,0 +1,31 @@ +#pragma once + +#include + +#include +#include + +namespace at { + +enum class ROCmFABackend : int8_t { Default, AOTriton, Ck }; + +inline std::string ROCmFABackendToString(at::ROCmFABackend backend) { + switch (backend) { + case ROCmFABackend::Default: + return "at::ROCmFABackend::Default"; + case ROCmFABackend::AOTriton: + return "at::ROCmFABackend::AOTriton"; + case ROCmFABackend::Ck: + return "at::ROCmFABackend::Ck"; + default: + TORCH_CHECK(false, "Unknown ROCm flash attention backend") + } +} + +inline std::ostream& operator<<( + std::ostream& stream, + at::ROCmFABackend backend) { + return stream << ROCmFABackendToString(backend); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/RedispatchFunctions.h b/phivenv/Lib/site-packages/torch/include/ATen/RedispatchFunctions.h new file mode 100644 index 0000000000000000000000000000000000000000..d4021e6c884605536d580e065533cec1f85948ce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/RedispatchFunctions.h @@ -0,0 +1,25366 @@ +#pragma once + +// @generated by torchgen/gen.py from RedispatchFunctions.h + +#ifdef TORCH_ASSERT_ONLY_METHOD_OPERATORS +#error This change adds a dependency on all pytorch operators, meaning the \ + file will need to be re-compiled every time an operator is changed or added. \ + Consider using the at::_ops::{name}::redispatch() interface by including \ + the specific operator from +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +namespace redispatch { + + // aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Byte(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Byte::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Char(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Char(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Char::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Double(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Double(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Double::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Float(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Float(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Float::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Int(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Int(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Int::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Long(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Long(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Long::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Short(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Short(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Short::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_cast_Half(Tensor self, bool non_blocking=False) -> Tensor + inline at::Tensor _cast_Half(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking=false) { + return at::_ops::_cast_Half::redispatch(dispatchKeySet, self, non_blocking); + } + + // aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> () + inline void __dispatch__backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList inputs, const ::std::optional & gradient={}, ::std::optional retain_graph=::std::nullopt, bool create_graph=false) { + return at::_ops::_backward::redispatch(dispatchKeySet, self, inputs, gradient, retain_graph, create_graph); + } + + // aten::set_data(Tensor(a!) self, Tensor new_data) -> () + inline void __dispatch_set_data(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & new_data) { + return at::_ops::set_data::redispatch(dispatchKeySet, self, new_data); + } + + // aten::data(Tensor self) -> Tensor + inline at::Tensor __dispatch_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::data::redispatch(dispatchKeySet, self); + } + + // aten::is_leaf(Tensor self) -> bool + inline bool __dispatch_is_leaf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_leaf::redispatch(dispatchKeySet, self); + } + + // aten::output_nr(Tensor self) -> int + inline int64_t __dispatch_output_nr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::output_nr::redispatch(dispatchKeySet, self); + } + + // aten::_version(Tensor self) -> int + inline int64_t __dispatch__version(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_version::redispatch(dispatchKeySet, self); + } + + // aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!) + inline at::Tensor & __dispatch_requires_grad_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, bool requires_grad=true) { + return at::_ops::requires_grad_::redispatch(dispatchKeySet, self, requires_grad); + } + + // aten::retain_grad(Tensor(a!) self) -> () + inline void __dispatch_retain_grad(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::retain_grad::redispatch(dispatchKeySet, self); + } + + // aten::retains_grad(Tensor self) -> bool + inline bool __dispatch_retains_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::retains_grad::redispatch(dispatchKeySet, self); + } + + // aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a) + inline at::Tensor _fw_primal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level) { + return at::_ops::_fw_primal::redispatch(dispatchKeySet, self, level); + } + + // aten::_make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a) + inline at::Tensor _make_dual(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + return at::_ops::_make_dual::redispatch(dispatchKeySet, primal, tangent, level); + } + + // aten::_unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent) + inline ::std::tuple _unpack_dual(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dual, int64_t level) { + return at::_ops::_unpack_dual::redispatch(dispatchKeySet, dual, level); + } + + // aten::_new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor + inline at::Tensor _new_zeros_with_same_feature_meta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims=0) { + return at::_ops::_new_zeros_with_same_feature_meta::redispatch(dispatchKeySet, self, other, self_num_batch_dims); + } + + // aten::_has_same_storage_numel(Tensor self, Tensor other) -> bool + inline bool _has_same_storage_numel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_has_same_storage_numel::redispatch(dispatchKeySet, self, other); + } + + // aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!) + inline at::Tensor & rename_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional names) { + return at::_ops::rename_::redispatch(dispatchKeySet, self, names); + } + + // aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a) + inline at::Tensor rename(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional names) { + return at::_ops::rename::redispatch(dispatchKeySet, self, names); + } + + // aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a) + inline at::Tensor align_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList names) { + return at::_ops::align_to::redispatch(dispatchKeySet, self, names); + } + + // aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a) + inline at::Tensor align_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx) { + return at::_ops::align_to_ellipsis_idx::redispatch(dispatchKeySet, self, order, ellipsis_idx); + } + + // aten::align_as(Tensor self, Tensor other) -> Tensor + inline at::Tensor align_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::align_as::redispatch(dispatchKeySet, self, other); + } + + // aten::align_tensors(Tensor[] tensors) -> Tensor[] + inline ::std::vector align_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::align_tensors::redispatch(dispatchKeySet, tensors); + } + + // aten::_assert_async(Tensor self) -> () + inline void _assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_assert_async::redispatch(dispatchKeySet, self); + } + + // aten::_assert_async.msg(Tensor self, str assert_msg) -> () + inline void _assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view assert_msg) { + return at::_ops::_assert_async_msg::redispatch(dispatchKeySet, self, assert_msg); + } + + // aten::_assert_scalar(Scalar self, str assert_msg) -> () + inline void _assert_scalar(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, c10::string_view assert_msg) { + return at::_ops::_assert_scalar::redispatch(dispatchKeySet, self, assert_msg); + } + + // aten::_functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor + inline at::Tensor _functional_assert_scalar(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token) { + return at::_ops::_functional_assert_scalar::redispatch(dispatchKeySet, self, assert_msg, dep_token); + } + + // aten::_functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor + inline at::Tensor _functional_assert_async(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view assert_msg, const at::Tensor & dep_token) { + return at::_ops::_functional_assert_async_msg::redispatch(dispatchKeySet, self, assert_msg, dep_token); + } + + // aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> () + inline void _assert_tensor_metadata(c10::DispatchKeySet dispatchKeySet, const at::Tensor & a, at::OptionalIntArrayRef size=::std::nullopt, at::OptionalIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::redispatch(dispatchKeySet, a, size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*size)) : ::std::nullopt, stride.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*stride)) : ::std::nullopt, dtype, device, layout); + } + + // aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> () + inline void _assert_tensor_metadata_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & a, at::OptionalSymIntArrayRef size=::std::nullopt, at::OptionalSymIntArrayRef stride=::std::nullopt, ::std::optional dtype=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional layout=::std::nullopt) { + return at::_ops::_assert_tensor_metadata::redispatch(dispatchKeySet, a, size, stride, dtype, device, layout); + } + + // aten::_print(str s) -> () + inline void _print(c10::DispatchKeySet dispatchKeySet, c10::string_view s) { + return at::_ops::_print::redispatch(dispatchKeySet, s); + } + + // aten::sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> () + inline void sym_constrain_range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min=::std::nullopt, ::std::optional max=::std::nullopt) { + return at::_ops::sym_constrain_range::redispatch(dispatchKeySet, size, min, max); + } + + // aten::sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> () + inline void sym_constrain_range_for_size(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min=::std::nullopt, ::std::optional max=::std::nullopt) { + return at::_ops::sym_constrain_range_for_size::redispatch(dispatchKeySet, size, min, max); + } + + // aten::_functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor + inline at::Tensor _functional_sym_constrain_range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + return at::_ops::_functional_sym_constrain_range::redispatch(dispatchKeySet, size, min, max, dep_token); + } + + // aten::_functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor + inline at::Tensor _functional_sym_constrain_range_for_size(c10::DispatchKeySet dispatchKeySet, const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + return at::_ops::_functional_sym_constrain_range_for_size::redispatch(dispatchKeySet, size, min, max, dep_token); + } + + // aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor _make_dep_token(c10::DispatchKeySet dispatchKeySet, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::_make_dep_token::redispatch(dispatchKeySet, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor _make_dep_token(c10::DispatchKeySet dispatchKeySet, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::_make_dep_token::redispatch(dispatchKeySet, dtype, layout, device, pin_memory, memory_format); + } + + // aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a) + inline at::Tensor refine_names(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList names) { + return at::_ops::refine_names::redispatch(dispatchKeySet, self, names); + } + + // aten::_use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool + inline bool _use_cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank) { + return at::_ops::_use_cudnn_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank); + } + + // aten::_use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool + inline bool _use_cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank) { + return at::_ops::_use_cudnn_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank); + } + + // aten::_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + inline ::std::tuple _cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + return at::_ops::_cudnn_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + + // aten::_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor) + inline ::std::tuple _cudnn_ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + return at::_ops::_cudnn_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + + // aten::_use_cudnn_rnn_flatten_weight() -> bool + inline bool _use_cudnn_rnn_flatten_weight(c10::DispatchKeySet dispatchKeySet) { + return at::_ops::_use_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet); + } + + // aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor + inline at::Tensor _cudnn_rnn_flatten_weight(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + } + + // aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor + inline at::Tensor _cudnn_rnn_flatten_weight_symint(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + } + + // aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _cudnn_rnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state); + } + + // aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _cudnn_rnn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + + // aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + inline ::std::tuple> _cudnn_rnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask); + } + + // aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + inline ::std::tuple> _cudnn_rnn_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + + // aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _cudnn_init_dropout_state(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, at::TensorOptions options) { + return at::_ops::_cudnn_init_dropout_state::redispatch(dispatchKeySet, dropout, train, dropout_seed, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _cudnn_init_dropout_state(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_cudnn_init_dropout_state::redispatch(dispatchKeySet, dropout, train, dropout_seed, dtype, layout, device, pin_memory); + } + + // aten::_debug_has_internal_overlap(Tensor self) -> int + inline int64_t _debug_has_internal_overlap(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_debug_has_internal_overlap::redispatch(dispatchKeySet, self); + } + + // aten::_fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor) + inline ::std::tuple _fused_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::_fused_dropout::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::_masked_scale(Tensor self, Tensor mask, float scale) -> Tensor + inline at::Tensor _masked_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, double scale) { + return at::_ops::_masked_scale::redispatch(dispatchKeySet, self, mask, scale); + } + + // aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor) + inline ::std::tuple native_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, ::std::optional train) { + return at::_ops::native_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor + inline at::Tensor native_dropout_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, double scale) { + return at::_ops::native_dropout_backward::redispatch(dispatchKeySet, grad_output, mask, scale); + } + + // aten::_sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor) + inline ::std::tuple _sobol_engine_draw(c10::DispatchKeySet dispatchKeySet, const at::Tensor & quasi, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated, ::std::optional dtype) { + return at::_ops::_sobol_engine_draw::redispatch(dispatchKeySet, quasi, n, sobolstate, dimension, num_generated, dtype); + } + + // aten::_sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!) + inline at::Tensor & _sobol_engine_ff_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated) { + return at::_ops::_sobol_engine_ff_::redispatch(dispatchKeySet, self, n, sobolstate, dimension, num_generated); + } + + // aten::_sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!) + inline at::Tensor & _sobol_engine_scramble_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & ltm, int64_t dimension) { + return at::_ops::_sobol_engine_scramble_::redispatch(dispatchKeySet, self, ltm, dimension); + } + + // aten::_sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!) + inline at::Tensor & _sobol_engine_initialize_state_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dimension) { + return at::_ops::_sobol_engine_initialize_state_::redispatch(dispatchKeySet, self, dimension); + } + + // aten::_reshape_from_tensor(Tensor self, Tensor shape) -> Tensor + inline at::Tensor _reshape_from_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & shape) { + return at::_ops::_reshape_from_tensor::redispatch(dispatchKeySet, self, shape); + } + + // aten::_shape_as_tensor(Tensor self) -> Tensor + inline at::Tensor _shape_as_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_shape_as_tensor::redispatch(dispatchKeySet, self); + } + + // aten::dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + inline at::Tensor & dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::feature_dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor feature_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::feature_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + inline at::Tensor & feature_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::feature_dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor alpha_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::alpha_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + inline at::Tensor & alpha_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::alpha_dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor + inline at::Tensor feature_alpha_dropout(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, bool train) { + return at::_ops::feature_alpha_dropout::redispatch(dispatchKeySet, input, p, train); + } + + // aten::feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!) + inline at::Tensor & feature_alpha_dropout_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, bool train) { + return at::_ops::feature_alpha_dropout_::redispatch(dispatchKeySet, self, p, train); + } + + // aten::abs(Tensor self) -> Tensor + inline at::Tensor abs(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::abs::redispatch(dispatchKeySet, self); + } + + // aten::abs_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & abs_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::abs_::redispatch(dispatchKeySet, self); + } + + // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & abs_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & abs_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::absolute(Tensor self) -> Tensor + inline at::Tensor absolute(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::absolute::redispatch(dispatchKeySet, self); + } + + // aten::absolute_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & absolute_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::absolute_::redispatch(dispatchKeySet, self); + } + + // aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & absolute_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::absolute_out::redispatch(dispatchKeySet, self, out); + } + + // aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & absolute_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::absolute_out::redispatch(dispatchKeySet, self, out); + } + + // aten::angle(Tensor self) -> Tensor + inline at::Tensor angle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::angle::redispatch(dispatchKeySet, self); + } + + // aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & angle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::angle_out::redispatch(dispatchKeySet, self, out); + } + + // aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & angle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::angle_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_real(Tensor(a) self) -> Tensor(a) + inline at::Tensor view_as_real(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_real::redispatch(dispatchKeySet, self); + } + + // aten::view_as_complex(Tensor(a) self) -> Tensor(a) + inline at::Tensor view_as_complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_complex::redispatch(dispatchKeySet, self); + } + + // aten::sgn(Tensor self) -> Tensor + inline at::Tensor sgn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sgn::redispatch(dispatchKeySet, self); + } + + // aten::sgn_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sgn_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sgn_::redispatch(dispatchKeySet, self); + } + + // aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sgn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sgn_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sgn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sgn_out::redispatch(dispatchKeySet, self, out); + } + + // aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor chalf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::chalf::redispatch(dispatchKeySet, self, memory_format); + } + + // aten::real(Tensor(a) self) -> Tensor(a) + inline at::Tensor real(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::real::redispatch(dispatchKeySet, self); + } + + // aten::imag(Tensor(a) self) -> Tensor(a) + inline at::Tensor imag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::imag::redispatch(dispatchKeySet, self); + } + + // aten::_conj(Tensor(a) self) -> Tensor(a) + inline at::Tensor _conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_conj::redispatch(dispatchKeySet, self); + } + + // aten::conj(Tensor(a) self) -> Tensor(a) + inline at::Tensor __dispatch_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::conj::redispatch(dispatchKeySet, self); + } + + // aten::_conj_physical(Tensor self) -> Tensor + inline at::Tensor _conj_physical(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_conj_physical::redispatch(dispatchKeySet, self); + } + + // aten::conj_physical(Tensor self) -> Tensor + inline at::Tensor conj_physical(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::conj_physical::redispatch(dispatchKeySet, self); + } + + // aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conj_physical_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conj_physical_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::conj_physical_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & conj_physical_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::conj_physical_::redispatch(dispatchKeySet, self); + } + + // aten::resolve_conj(Tensor(a) self) -> Tensor(a) + inline at::Tensor resolve_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::resolve_conj::redispatch(dispatchKeySet, self); + } + + // aten::resolve_neg(Tensor(a) self) -> Tensor(a) + inline at::Tensor resolve_neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::resolve_neg::redispatch(dispatchKeySet, self); + } + + // aten::_neg_view(Tensor(a) self) -> Tensor(a) + inline at::Tensor _neg_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_neg_view::redispatch(dispatchKeySet, self); + } + + // aten::acos(Tensor self) -> Tensor + inline at::Tensor acos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::acos::redispatch(dispatchKeySet, self); + } + + // aten::acos_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & acos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::acos_::redispatch(dispatchKeySet, self); + } + + // aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & acos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & acos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccos(Tensor self) -> Tensor + inline at::Tensor arccos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arccos::redispatch(dispatchKeySet, self); + } + + // aten::arccos_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arccos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arccos_::redispatch(dispatchKeySet, self); + } + + // aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arccos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arccos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arccos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arccos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor + inline at::Tensor avg_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true) { + return at::_ops::avg_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad); + } + + // aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool1d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor) + inline ::std::tuple adaptive_max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool1d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::add_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor _add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & _add_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::_add_relu_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor _add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & _add_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addmv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv::redispatch(dispatchKeySet, self, mat, vec, beta, alpha); + } + + // aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addmv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv_::redispatch(dispatchKeySet, self, mat, vec, beta, alpha); + } + + // aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmv_out::redispatch(dispatchKeySet, self, mat, vec, beta, alpha, out); + } + + // aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmv_out::redispatch(dispatchKeySet, self, mat, vec, beta, alpha, out); + } + + // aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha); + } + + // aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addr_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr_::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha); + } + + // aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addr_out::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha, out); + } + + // aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addr_out::redispatch(dispatchKeySet, self, vec1, vec2, beta, alpha, out); + } + + // aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners); + } + + // aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator::redispatch(dispatchKeySet, theta, size, align_corners); + } + + // aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(size), align_corners); + } + + // aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor + inline at::Tensor affine_grid_generator_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_backward::redispatch(dispatchKeySet, grad, size, align_corners); + } + + // aten::_is_all_true(Tensor self) -> Tensor + inline at::Tensor _is_all_true(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_is_all_true::redispatch(dispatchKeySet, self); + } + + // aten::_is_any_true(Tensor self) -> Tensor + inline at::Tensor _is_any_true(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_is_any_true::redispatch(dispatchKeySet, self); + } + + // aten::_test_check_tensor(Tensor self) -> Tensor + inline at::Tensor _test_check_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_check_tensor::redispatch(dispatchKeySet, self); + } + + // aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor + inline at::Tensor _test_functorch_fallback(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_test_functorch_fallback::redispatch(dispatchKeySet, self, other); + } + + // aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::all_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::all_dims::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::all_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::all_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::all_dimname::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::all_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) { + return at::_ops::all_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool + inline bool allclose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) { + return at::_ops::allclose::redispatch(dispatchKeySet, self, other, rtol, atol, equal_nan); + } + + // aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::any_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::any_dims::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::any_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::any_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_dims_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::any_dimname::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::any_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out) { + return at::_ops::any_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::arange::redispatch(dispatchKeySet, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange::redispatch(dispatchKeySet, end, dtype, layout, device, pin_memory); + } + + // aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::arange_start::redispatch(dispatchKeySet, start, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange_start::redispatch(dispatchKeySet, start, end, dtype, layout, device, pin_memory); + } + + // aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::TensorOptions options={}) { + return at::_ops::arange_start_step::redispatch(dispatchKeySet, start, end, step, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor arange(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::arange_start_step::redispatch(dispatchKeySet, start, end, step, dtype, layout, device, pin_memory); + } + + // aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arange_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & end) { + return at::_ops::arange_out::redispatch(dispatchKeySet, end, out); + } + + // aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arange_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & end, at::Tensor & out) { + return at::_ops::arange_out::redispatch(dispatchKeySet, end, out); + } + + // aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arange_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step) { + return at::_ops::arange_start_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arange_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { + return at::_ops::arange_start_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::_dim_arange(Tensor like, int dim) -> Tensor + inline at::Tensor _dim_arange(c10::DispatchKeySet dispatchKeySet, const at::Tensor & like, int64_t dim) { + return at::_ops::_dim_arange::redispatch(dispatchKeySet, like, dim); + } + + // aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor argmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmax::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out) { + return at::_ops::argmax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor + inline at::Tensor argmin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmin::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argmin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::argmin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argmin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out) { + return at::_ops::argmin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::acosh(Tensor self) -> Tensor + inline at::Tensor acosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::acosh::redispatch(dispatchKeySet, self); + } + + // aten::acosh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & acosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::acosh_::redispatch(dispatchKeySet, self); + } + + // aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & acosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::acosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & acosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::acosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccosh(Tensor self) -> Tensor + inline at::Tensor arccosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arccosh::redispatch(dispatchKeySet, self); + } + + // aten::arccosh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arccosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arccosh_::redispatch(dispatchKeySet, self); + } + + // aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arccosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arccosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arccosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arccosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::asinh(Tensor self) -> Tensor + inline at::Tensor asinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::asinh::redispatch(dispatchKeySet, self); + } + + // aten::asinh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & asinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::asinh_::redispatch(dispatchKeySet, self); + } + + // aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::asinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::asinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsinh(Tensor self) -> Tensor + inline at::Tensor arcsinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arcsinh::redispatch(dispatchKeySet, self); + } + + // aten::arcsinh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arcsinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arcsinh_::redispatch(dispatchKeySet, self); + } + + // aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arcsinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arcsinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arcsinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arcsinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atanh(Tensor self) -> Tensor + inline at::Tensor atanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atanh::redispatch(dispatchKeySet, self); + } + + // aten::atanh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & atanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::atanh_::redispatch(dispatchKeySet, self); + } + + // aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::atanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::atanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctanh(Tensor self) -> Tensor + inline at::Tensor arctanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arctanh::redispatch(dispatchKeySet, self); + } + + // aten::arctanh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arctanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arctanh_::redispatch(dispatchKeySet, self); + } + + // aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arctanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arctanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + inline at::Tensor as_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a) + inline at::Tensor as_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided::redispatch(dispatchKeySet, self, size, stride, storage_offset); + } + + // aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) + inline const at::Tensor & as_strided_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!) + inline const at::Tensor & as_strided__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_::redispatch(dispatchKeySet, self, size, stride, storage_offset); + } + + // aten::asin(Tensor self) -> Tensor + inline at::Tensor asin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::asin::redispatch(dispatchKeySet, self); + } + + // aten::asin_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & asin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::asin_::redispatch(dispatchKeySet, self); + } + + // aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & asin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsin(Tensor self) -> Tensor + inline at::Tensor arcsin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arcsin::redispatch(dispatchKeySet, self); + } + + // aten::arcsin_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arcsin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arcsin_::redispatch(dispatchKeySet, self); + } + + // aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arcsin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arcsin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arcsin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arcsin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atan(Tensor self) -> Tensor + inline at::Tensor atan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atan::redispatch(dispatchKeySet, self); + } + + // aten::atan_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & atan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::atan_::redispatch(dispatchKeySet, self); + } + + // aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctan(Tensor self) -> Tensor + inline at::Tensor arctan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::arctan::redispatch(dispatchKeySet, self); + } + + // aten::arctan_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & arctan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::arctan_::redispatch(dispatchKeySet, self); + } + + // aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::arctan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::arctan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::atleast_1d(Tensor self) -> Tensor + inline at::Tensor atleast_1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atleast_1d::redispatch(dispatchKeySet, self); + } + + // aten::atleast_1d.Sequence(Tensor[] tensors) -> Tensor[] + inline ::std::vector atleast_1d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::atleast_1d_Sequence::redispatch(dispatchKeySet, tensors); + } + + // aten::atleast_2d(Tensor self) -> Tensor + inline at::Tensor atleast_2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atleast_2d::redispatch(dispatchKeySet, self); + } + + // aten::atleast_2d.Sequence(Tensor[] tensors) -> Tensor[] + inline ::std::vector atleast_2d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::atleast_2d_Sequence::redispatch(dispatchKeySet, tensors); + } + + // aten::atleast_3d(Tensor self) -> Tensor + inline at::Tensor atleast_3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::atleast_3d::redispatch(dispatchKeySet, self); + } + + // aten::atleast_3d.Sequence(Tensor[] tensors) -> Tensor[] + inline ::std::vector atleast_3d(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::atleast_3d_Sequence::redispatch(dispatchKeySet, tensors); + } + + // aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor baddbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & baddbmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & baddbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & baddbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::baddbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor baddbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_dtype::redispatch(dispatchKeySet, self, batch1, batch2, out_dtype, beta, alpha); + } + + // aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & baddbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::baddbmm_dtype_out::redispatch(dispatchKeySet, self, batch1, batch2, out_dtype, beta, alpha, out); + } + + // aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & baddbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::baddbmm_dtype_out::redispatch(dispatchKeySet, self, batch1, batch2, out_dtype, beta, alpha, out); + } + + // aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::bartlett_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::bartlett_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::bartlett_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor bartlett_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::bartlett_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor + inline at::Tensor batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled) { + return at::_ops::batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled); + } + + // aten::quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor + inline at::Tensor quantized_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) { + return at::_ops::quantized_batch_norm::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point); + } + + // aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int) + inline ::std::tuple _batch_norm_impl_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled) { + return at::_ops::_batch_norm_impl_index::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled); + } + + // aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _batch_norm_impl_index_backward(c10::DispatchKeySet dispatchKeySet, int64_t impl_index, const at::Tensor & input, const at::Tensor & grad_output, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var_transform, bool train, double eps, ::std::array output_mask, const at::Tensor & reservedSpace) { + return at::_ops::_batch_norm_impl_index_backward::redispatch(dispatchKeySet, impl_index, input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, train, eps, output_mask, reservedSpace); + } + + // aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor + inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli::redispatch(dispatchKeySet, self, generator); + } + + // aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::bernoulli_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & bernoulli_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli__Tensor::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & bernoulli_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p=0.5, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli__float::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor + inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_p::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor + inline at::Tensor bilinear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::bilinear::redispatch(dispatchKeySet, input1, input2, weight, bias); + } + + // aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor binary_cross_entropy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy::redispatch(dispatchKeySet, self, target, weight, reduction); + } + + // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_out::redispatch(dispatchKeySet, self, target, weight, reduction, out); + } + + // aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & out) { + return at::_ops::binary_cross_entropy_out::redispatch(dispatchKeySet, self, target, weight, reduction, out); + } + + // aten::binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor binary_cross_entropy_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction); + } + + // aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, grad_input); + } + + // aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::binary_cross_entropy_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, grad_input); + } + + // aten::binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor + inline at::Tensor binary_cross_entropy_with_logits(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, const ::std::optional & pos_weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_with_logits::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction); + } + + // aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor + inline at::Tensor bincount(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights={}, int64_t minlength=0) { + return at::_ops::bincount::redispatch(dispatchKeySet, self, weights, minlength); + } + + // aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor + inline at::Tensor bincount_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights={}, c10::SymInt minlength=0) { + return at::_ops::bincount::redispatch(dispatchKeySet, self, weights, minlength); + } + + // aten::bitwise_not(Tensor self) -> Tensor + inline at::Tensor bitwise_not(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::bitwise_not::redispatch(dispatchKeySet, self); + } + + // aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & bitwise_not_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::bitwise_not_::redispatch(dispatchKeySet, self); + } + + // aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_not_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::bitwise_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_not_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::bitwise_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copysign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::copysign_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copysign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::copysign_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor copysign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::copysign_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & copysign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::copysign__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor copysign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::copysign_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & copysign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::copysign__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copysign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::copysign_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copysign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::copysign_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_lazy_clone(Tensor self) -> Tensor + inline at::Tensor _lazy_clone(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_lazy_clone::redispatch(dispatchKeySet, self); + } + + // aten::logical_not(Tensor self) -> Tensor + inline at::Tensor logical_not(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::logical_not::redispatch(dispatchKeySet, self); + } + + // aten::logical_not_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & logical_not_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::logical_not_::redispatch(dispatchKeySet, self); + } + + // aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_not_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::logical_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_not_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::logical_not_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logical_xor(Tensor self, Tensor other) -> Tensor + inline at::Tensor logical_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_xor::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & logical_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_xor_::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_xor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logical_xor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_and(Tensor self, Tensor other) -> Tensor + inline at::Tensor logical_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_and::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & logical_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_and_::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_and_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logical_and_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_or(Tensor self, Tensor other) -> Tensor + inline at::Tensor logical_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_or::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & logical_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_or_::redispatch(dispatchKeySet, self, other); + } + + // aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logical_or_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logical_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logical_or_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::blackman_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::blackman_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::blackman_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor blackman_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::blackman_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::bmm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor bmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::bmm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::bmm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::bmm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor + inline at::Tensor bmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::bmm_dtype::redispatch(dispatchKeySet, self, mat2, out_dtype); + } + + // aten::bmm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::bmm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::bmm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out) { + return at::_ops::bmm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::broadcast_tensors(Tensor[] tensors) -> Tensor[] + inline ::std::vector broadcast_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::broadcast_tensors::redispatch(dispatchKeySet, tensors); + } + + // aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor broadcast_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::broadcast_to::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor broadcast_to_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::broadcast_to::redispatch(dispatchKeySet, self, size); + } + + // aten::_sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a) + inline at::Tensor _sparse_broadcast_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_sparse_broadcast_to::redispatch(dispatchKeySet, self, size); + } + + // aten::cat(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor cat(c10::DispatchKeySet dispatchKeySet, const at::ITensorListRef & tensors, int64_t dim=0) { + return at::_ops::cat::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::ITensorListRef & tensors, int64_t dim=0) { + return at::_ops::cat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cat_outf(c10::DispatchKeySet dispatchKeySet, const at::ITensorListRef & tensors, int64_t dim, at::Tensor & out) { + return at::_ops::cat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::cat.names(Tensor[] tensors, Dimname dim) -> Tensor + inline at::Tensor cat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) { + return at::_ops::cat_names::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) { + return at::_ops::cat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) { + return at::_ops::cat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor concat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concat::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::concat_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat.names(Tensor[] tensors, Dimname dim) -> Tensor + inline at::Tensor concat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concat_names::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) { + return at::_ops::concat_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor concatenate(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concatenate::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concatenate_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::concatenate_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concatenate_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::concatenate_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor + inline at::Tensor concatenate(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concatenate_names::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concatenate_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, at::Dimname dim) { + return at::_ops::concatenate_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & concatenate_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Dimname dim, at::Tensor & out) { + return at::_ops::concatenate_names_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::block_diag(Tensor[] tensors) -> Tensor + inline at::Tensor block_diag(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::block_diag::redispatch(dispatchKeySet, tensors); + } + + // aten::ceil(Tensor self) -> Tensor + inline at::Tensor ceil(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ceil::redispatch(dispatchKeySet, self); + } + + // aten::ceil_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & ceil_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::ceil_::redispatch(dispatchKeySet, self); + } + + // aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ceil_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ceil_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::chain_matmul(Tensor[] matrices) -> Tensor + inline at::Tensor chain_matmul(c10::DispatchKeySet dispatchKeySet, at::TensorList matrices) { + return at::_ops::chain_matmul::redispatch(dispatchKeySet, matrices); + } + + // aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & chain_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList matrices) { + return at::_ops::chain_matmul_out::redispatch(dispatchKeySet, matrices, out); + } + + // aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & chain_matmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList matrices, at::Tensor & out) { + return at::_ops::chain_matmul_out::redispatch(dispatchKeySet, matrices, out); + } + + // aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[] + inline ::std::vector unsafe_chunk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t chunks, int64_t dim=0) { + return at::_ops::unsafe_chunk::redispatch(dispatchKeySet, self, chunks, dim); + } + + // aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[] + inline ::std::vector chunk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t chunks, int64_t dim=0) { + return at::_ops::chunk::redispatch(dispatchKeySet, self, chunks, dim); + } + + // aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections, int64_t dim=0) { + return at::_ops::tensor_split_sections::redispatch(dispatchKeySet, self, sections, dim); + } + + // aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt sections, int64_t dim=0) { + return at::_ops::tensor_split_sections::redispatch(dispatchKeySet, self, sections, dim); + } + + // aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices, int64_t dim=0) { + return at::_ops::tensor_split_indices::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(indices), dim); + } + + // aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef indices, int64_t dim=0) { + return at::_ops::tensor_split_indices::redispatch(dispatchKeySet, self, indices, dim); + } + + // aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[] + inline ::std::vector tensor_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor_indices_or_sections, int64_t dim=0) { + return at::_ops::tensor_split_tensor_indices_or_sections::redispatch(dispatchKeySet, self, tensor_indices_or_sections, dim); + } + + // aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + inline at::Tensor clamp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clamp::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + inline at::Tensor clamp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clamp_Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) + inline at::Tensor & clamp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clamp_::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) + inline at::Tensor & clamp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clamp__Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clamp_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clamp_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clamp_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clamp_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clamp_max(Tensor self, Scalar max) -> Tensor + inline at::Tensor clamp_max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & max) { + return at::_ops::clamp_max::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor + inline at::Tensor clamp_max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & max) { + return at::_ops::clamp_max_Tensor::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!) + inline at::Tensor & clamp_max_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & max) { + return at::_ops::clamp_max_::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!) + inline at::Tensor & clamp_max_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & max) { + return at::_ops::clamp_max__Tensor::redispatch(dispatchKeySet, self, max); + } + + // aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & max) { + return at::_ops::clamp_max_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & max, at::Tensor & out) { + return at::_ops::clamp_max_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & max) { + return at::_ops::clamp_max_Tensor_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & max, at::Tensor & out) { + return at::_ops::clamp_max_Tensor_out::redispatch(dispatchKeySet, self, max, out); + } + + // aten::clamp_min(Tensor self, Scalar min) -> Tensor + inline at::Tensor clamp_min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min) { + return at::_ops::clamp_min::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor + inline at::Tensor clamp_min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & min) { + return at::_ops::clamp_min_Tensor::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!) + inline at::Tensor & clamp_min_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & min) { + return at::_ops::clamp_min_::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!) + inline at::Tensor & clamp_min_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & min) { + return at::_ops::clamp_min__Tensor::redispatch(dispatchKeySet, self, min); + } + + // aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & min) { + return at::_ops::clamp_min_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min, at::Tensor & out) { + return at::_ops::clamp_min_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & min) { + return at::_ops::clamp_min_Tensor_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clamp_min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & min, at::Tensor & out) { + return at::_ops::clamp_min_Tensor_out::redispatch(dispatchKeySet, self, min, out); + } + + // aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor + inline at::Tensor clip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clip::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor + inline at::Tensor clip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clip_Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!) + inline at::Tensor & clip_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clip_::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!) + inline at::Tensor & clip_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clip__Tensor::redispatch(dispatchKeySet, self, min, max); + } + + // aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max=::std::nullopt) { + return at::_ops::clip_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clip_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & min={}, const ::std::optional & max={}) { + return at::_ops::clip_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out) { + return at::_ops::clip_Tensor_out::redispatch(dispatchKeySet, self, min, max, out); + } + + // aten::cudnn_is_acceptable(Tensor self) -> bool + inline bool cudnn_is_acceptable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::cudnn_is_acceptable::redispatch(dispatchKeySet, self); + } + + // aten::complex(Tensor real, Tensor imag) -> Tensor + inline at::Tensor complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & real, const at::Tensor & imag) { + return at::_ops::complex::redispatch(dispatchKeySet, real, imag); + } + + // aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & complex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & real, const at::Tensor & imag) { + return at::_ops::complex_out::redispatch(dispatchKeySet, real, imag, out); + } + + // aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & complex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & real, const at::Tensor & imag, at::Tensor & out) { + return at::_ops::complex_out::redispatch(dispatchKeySet, real, imag, out); + } + + // aten::polar(Tensor abs, Tensor angle) -> Tensor + inline at::Tensor polar(c10::DispatchKeySet dispatchKeySet, const at::Tensor & abs, const at::Tensor & angle) { + return at::_ops::polar::redispatch(dispatchKeySet, abs, angle); + } + + // aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & polar_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & abs, const at::Tensor & angle) { + return at::_ops::polar_out::redispatch(dispatchKeySet, abs, angle, out); + } + + // aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & polar_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & abs, const at::Tensor & angle, at::Tensor & out) { + return at::_ops::polar_out::redispatch(dispatchKeySet, abs, angle, out); + } + + // aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + inline at::Tensor constant_pad_nd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value); + } + + // aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor + inline at::Tensor constant_pad_nd_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd::redispatch(dispatchKeySet, self, pad, value); + } + + // aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a) + inline at::Tensor __dispatch_contiguous(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::MemoryFormat memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::contiguous::redispatch(dispatchKeySet, self, memory_format); + } + + // aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups); + } + + // aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + + // aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple convolution_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::convolution_backward::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : ::std::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask); + } + + // aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple convolution_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + + // aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution_overrideable::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups); + } + + // aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor + inline at::Tensor convolution_overrideable_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution_overrideable::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + + // aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple convolution_backward_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::convolution_backward_overrideable::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask); + } + + // aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple convolution_backward_overrideable_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward_overrideable::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + + // aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + inline at::Tensor _convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + } + + // aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor + inline at::Tensor _convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + } + + // aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor + inline at::Tensor _convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + return at::_ops::_convolution_deprecated::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + } + + // aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor + inline at::Tensor _convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + return at::_ops::_convolution_deprecated::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + } + + // aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _convolution_mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_convolution_mode::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _convolution_mode_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_convolution_mode::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _convolution_double_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::_convolution_double_backward::redispatch(dispatchKeySet, ggI, ggW, ggb, gO, weight, self, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask); + } + + // aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _convolution_double_backward_symint(c10::DispatchKeySet dispatchKeySet, const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::_convolution_double_backward::redispatch(dispatchKeySet, ggI, ggW, ggb, gO, weight, self, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + + // aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv1d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv1d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv2d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv2d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv3d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv3d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv1d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding="valid", SymInt[1] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv1d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv2d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding="valid", SymInt[2] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv2d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, c10::string_view padding, at::IntArrayRef dilation=1, int64_t groups=1) { + return at::_ops::conv3d_padding::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), padding, c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding="valid", SymInt[3] dilation=1, SymInt groups=1) -> Tensor + inline at::Tensor conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1) { + return at::_ops::conv3d_padding::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, groups); + } + + // aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor + inline at::Tensor conv_tbc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad=0) { + return at::_ops::conv_tbc::redispatch(dispatchKeySet, self, weight, bias, pad); + } + + // aten::conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor) + inline ::std::tuple conv_tbc_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, int64_t pad) { + return at::_ops::conv_tbc_backward::redispatch(dispatchKeySet, self, input, weight, bias, pad); + } + + // aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor + inline at::Tensor conv_transpose1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) { + return at::_ops::conv_transpose1d::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor + inline at::Tensor conv_transpose1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::conv_transpose1d::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation); + } + + // aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor + inline at::Tensor conv_transpose2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) { + return at::_ops::conv_transpose2d_input::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor + inline at::Tensor conv_transpose2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::conv_transpose2d_input::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation); + } + + // aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor + inline at::Tensor conv_transpose3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, int64_t groups=1, at::IntArrayRef dilation=1) { + return at::_ops::conv_transpose3d_input::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), groups, c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor + inline at::Tensor conv_transpose3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymInt groups=1, c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::conv_transpose3d_input::redispatch(dispatchKeySet, input, weight, bias, stride, padding, output_padding, groups, dilation); + } + + // aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor + inline at::Tensor copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + inline at::Tensor & copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor + inline at::Tensor _copy_from(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, bool non_blocking=false) { + return at::_ops::_copy_from::redispatch(dispatchKeySet, self, dst, non_blocking); + } + + // aten::_copy_from_and_resize(Tensor self, Tensor dst) -> Tensor + inline at::Tensor _copy_from_and_resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst) { + return at::_ops::_copy_from_and_resize::redispatch(dispatchKeySet, self, dst); + } + + // aten::cos(Tensor self) -> Tensor + inline at::Tensor cos(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::cos::redispatch(dispatchKeySet, self); + } + + // aten::cos_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & cos_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::cos_::redispatch(dispatchKeySet, self); + } + + // aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cos_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cos_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cosh(Tensor self) -> Tensor + inline at::Tensor cosh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::cosh::redispatch(dispatchKeySet, self); + } + + // aten::cosh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & cosh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::cosh_::redispatch(dispatchKeySet, self); + } + + // aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cosh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cosh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + inline at::Tensor cosine_embedding_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin=0.0, int64_t reduction=at::Reduction::Mean) { + return at::_ops::cosine_embedding_loss::redispatch(dispatchKeySet, input1, input2, target, margin, reduction); + } + + // aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor + inline at::Tensor count_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::count_nonzero_dim_IntList::redispatch(dispatchKeySet, self, dim); + } + + // aten::count_nonzero(Tensor self, int? dim=None) -> Tensor + inline at::Tensor count_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt) { + return at::_ops::count_nonzero::redispatch(dispatchKeySet, self, dim); + } + + // aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor + inline at::Tensor cov(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t correction=1, const ::std::optional & fweights={}, const ::std::optional & aweights={}) { + return at::_ops::cov::redispatch(dispatchKeySet, self, correction, fweights, aweights); + } + + // aten::corrcoef(Tensor self) -> Tensor + inline at::Tensor corrcoef(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::corrcoef::redispatch(dispatchKeySet, self); + } + + // aten::cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid + inline at::Tensor cudnn_affine_grid_generator(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator::redispatch(dispatchKeySet, theta, N, C, H, W); + } + + // aten::cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta + inline at::Tensor cudnn_affine_grid_generator_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator_backward::redispatch(dispatchKeySet, grad, N, C, H, W); + } + + // aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple cudnn_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::cudnn_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + + // aten::cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor) + inline ::std::tuple cudnn_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace) { + return at::_ops::cudnn_batch_norm_backward::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace); + } + + // aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32); + } + + // aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + + // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_out::redispatch(dispatchKeySet, self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32); + } + + // aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor + inline at::Tensor cudnn_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + + // aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution_transpose::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution_transpose::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups); + } + + // aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple mps_convolution_transpose_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask); + } + + // aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple mps_convolution_transpose_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask); + } + + // aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups); + } + + // aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor cudnn_convolution_add_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + + // aten::cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output + inline at::Tensor cudnn_grid_sampler(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid) { + return at::_ops::cudnn_grid_sampler::redispatch(dispatchKeySet, self, grid); + } + + // aten::cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid) + inline ::std::tuple cudnn_grid_sampler_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) { + return at::_ops::cudnn_grid_sampler_backward::redispatch(dispatchKeySet, self, grid, grad_output); + } + + // aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::cummax::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim) { + return at::_ops::cummax_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummax_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummax_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummax_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummax_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () + inline void _cummax_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + return at::_ops::_cummax_helper::redispatch(dispatchKeySet, self, values, indices, dim); + } + + // aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::cummin::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim) { + return at::_ops::cummin_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummin_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) + inline ::std::tuple cummin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummin_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim) { + return at::_ops::cummin_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple cummin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::cummin_dimname_out::redispatch(dispatchKeySet, self, dim, values, indices); + } + + // aten::_cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> () + inline void _cummin_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + return at::_ops::_cummin_helper::redispatch(dispatchKeySet, self, values, indices, dim); + } + + // aten::cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor + inline at::Tensor cummaxmin_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & indices, int64_t dim) { + return at::_ops::cummaxmin_backward::redispatch(dispatchKeySet, grad, input, indices, dim); + } + + // aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor cumprod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumprod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumprod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumprod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumprod_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor cumprod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumprod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod__dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumprod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumprod_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumprod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumprod_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor + inline at::Tensor cumprod_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output) { + return at::_ops::cumprod_backward::redispatch(dispatchKeySet, grad, input, dim, output); + } + + // aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor cumsum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumsum_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumsum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumsum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumsum_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor cumsum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!) + inline at::Tensor & cumsum_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum__dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumsum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::cumsum_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cumsum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::cumsum_dimname_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + inline at::Tensor cumulative_trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) { + return at::_ops::cumulative_trapezoid_x::redispatch(dispatchKeySet, y, x, dim); + } + + // aten::cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor + inline at::Tensor cumulative_trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Scalar & dx=1, int64_t dim=-1) { + return at::_ops::cumulative_trapezoid_dx::redispatch(dispatchKeySet, y, dx, dim); + } + + // aten::ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + inline at::Tensor ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, int64_t reduction=at::Reduction::Mean, bool zero_infinity=false) { + return at::_ops::ctc_loss_IntList::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + + // aten::ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor + inline at::Tensor ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, int64_t reduction=at::Reduction::Mean, bool zero_infinity=false) { + return at::_ops::ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + + // aten::_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + inline ::std::tuple _ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + + // aten::_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor) + inline ::std::tuple _ctc_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss_Tensor::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + + // aten::_ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor + inline at::Tensor _ctc_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) { + return at::_ops::_ctc_loss_backward::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + + // aten::_ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor + inline at::Tensor _ctc_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) { + return at::_ops::_ctc_loss_backward_Tensor::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + + // aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor + inline at::Tensor diag_embed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) { + return at::_ops::diag_embed::redispatch(dispatchKeySet, self, offset, dim1, dim2); + } + + // aten::diagflat(Tensor self, int offset=0) -> Tensor + inline at::Tensor diagflat(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0) { + return at::_ops::diagflat::redispatch(dispatchKeySet, self, offset); + } + + // aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a) + inline at::Tensor diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal::redispatch(dispatchKeySet, self, offset, dim1, dim2); + } + + // aten::linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a) + inline at::Tensor linalg_diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) { + return at::_ops::linalg_diagonal::redispatch(dispatchKeySet, A, offset, dim1, dim2); + } + + // aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a) + inline at::Tensor diagonal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset=0) { + return at::_ops::diagonal_Dimname::redispatch(dispatchKeySet, self, outdim, dim1, dim2, offset); + } + + // aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + inline at::Tensor diagonal_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2); + } + + // aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor + inline at::Tensor diagonal_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2); + } + + // aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!) + inline at::Tensor & fill_diagonal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & fill_value, bool wrap=false) { + return at::_ops::fill_diagonal_::redispatch(dispatchKeySet, self, fill_value, wrap); + } + + // aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor + inline at::Tensor diff(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n=1, int64_t dim=-1, const ::std::optional & prepend={}, const ::std::optional & append={}) { + return at::_ops::diff::redispatch(dispatchKeySet, self, n, dim, prepend, append); + } + + // aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diff_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n=1, int64_t dim=-1, const ::std::optional & prepend={}, const ::std::optional & append={}) { + return at::_ops::diff_out::redispatch(dispatchKeySet, self, n, dim, prepend, append, out); + } + + // aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diff_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append, at::Tensor & out) { + return at::_ops::diff_out::redispatch(dispatchKeySet, self, n, dim, prepend, append, out); + } + + // aten::gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? dim=None, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & spacing=::std::nullopt, ::std::optional dim=::std::nullopt, int64_t edge_order=1) { + return at::_ops::gradient_scalarint::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & spacing, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_scalararray::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_array::redispatch(dispatchKeySet, self, dim, edge_order); + } + + // aten::gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ArrayRef spacing, ::std::optional dim=::std::nullopt, int64_t edge_order=1) { + return at::_ops::gradient_scalarrayint::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ArrayRef spacing, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_scalarrayarray::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList spacing, ::std::optional dim=::std::nullopt, int64_t edge_order=1) { + return at::_ops::gradient_tensorarrayint::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[] + inline ::std::vector gradient(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList spacing, at::IntArrayRef dim, int64_t edge_order=1) { + return at::_ops::gradient_tensorarray::redispatch(dispatchKeySet, self, spacing, dim, edge_order); + } + + // aten::div.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::div_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::div__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::div_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::div_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::div_Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) + inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::div__Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::div_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out) { + return at::_ops::div_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::div.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::div_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::div__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + inline at::Tensor div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::div_Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) + inline at::Tensor & div_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::div__Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::divide_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::divide__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::divide.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::divide_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::divide__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::divide_Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!) + inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::divide__Tensor_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + return at::_ops::divide_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out) { + return at::_ops::divide_out_mode::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor + inline at::Tensor divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::divide_Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!) + inline at::Tensor & divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::divide__Scalar_mode::redispatch(dispatchKeySet, self, other, rounding_mode); + } + + // aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor true_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::true_divide_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & true_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::true_divide__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & true_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::true_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & true_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::true_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor true_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::true_divide_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & true_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::true_divide__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::dot(Tensor self, Tensor tensor) -> Tensor + inline at::Tensor dot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor) { + return at::_ops::dot::redispatch(dispatchKeySet, self, tensor); + } + + // aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor) { + return at::_ops::dot_out::redispatch(dispatchKeySet, self, tensor, out); + } + + // aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor, at::Tensor & out) { + return at::_ops::dot_out::redispatch(dispatchKeySet, self, tensor, out); + } + + // aten::vdot(Tensor self, Tensor other) -> Tensor + inline at::Tensor vdot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::vdot::redispatch(dispatchKeySet, self, other); + } + + // aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & vdot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::vdot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & vdot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::vdot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor + inline at::Tensor einsum(c10::DispatchKeySet dispatchKeySet, c10::string_view equation, at::TensorList tensors, at::OptionalIntArrayRef path=::std::nullopt) { + return at::_ops::einsum::redispatch(dispatchKeySet, equation, tensors, path); + } + + // aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + inline at::Tensor embedding(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor + inline at::Tensor embedding_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor + inline at::Tensor embedding_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { + return at::_ops::embedding_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor + inline at::Tensor embedding_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { + return at::_ops::embedding_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse); + } + + // aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + inline at::Tensor embedding_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq); + } + + // aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor + inline at::Tensor embedding_dense_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq); + } + + // aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!) + inline at::Tensor & embedding_renorm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + return at::_ops::embedding_renorm_::redispatch(dispatchKeySet, self, indices, max_norm, norm_type); + } + + // aten::embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor + inline at::Tensor embedding_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_sparse_backward::redispatch(dispatchKeySet, grad, indices, num_weights, padding_idx, scale_grad_by_freq); + } + + // aten::_embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _embedding_bag_forward_only(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_forward_only::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + + // aten::_rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor) + inline ::std::tuple _rowwise_prune(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & mask, at::ScalarType compressed_indices_dtype) { + return at::_ops::_rowwise_prune::redispatch(dispatchKeySet, weight, mask, compressed_indices_dtype); + } + + // aten::row_stack(Tensor[] tensors) -> Tensor + inline at::Tensor row_stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::row_stack::redispatch(dispatchKeySet, tensors); + } + + // aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::row_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::row_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false) { + return at::_ops::embedding_bag::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset); + } + + // aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, ::std::optional padding_idx) { + return at::_ops::embedding_bag_padding_idx::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + + // aten::_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _embedding_bag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + + // aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_sparse_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_sparse_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_sparse_backward::redispatch(dispatchKeySet, grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_dense_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + + // aten::_embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor + inline at::Tensor _embedding_bag_per_sample_weights_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_per_sample_weights_backward::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx); + } + + // aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_memory_format::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_permuted(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, at::TensorOptions options={}) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_permuted(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, dtype, layout, device, pin_memory); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_permuted_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::TensorOptions options={}) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, size, physical_layout, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_permuted_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_permuted::redispatch(dispatchKeySet, size, physical_layout, dtype, layout, device, pin_memory); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_empty_strided_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_empty_strided::redispatch(dispatchKeySet, self, size, stride, dtype, layout, device, pin_memory); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_full(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_full(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_full_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_full_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_full::redispatch(dispatchKeySet, self, size, fill_value, dtype, layout, device, pin_memory); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_zeros_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_zeros_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_zeros::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor new_ones_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::new_ones::redispatch(dispatchKeySet, self, size, dtype, layout, device, pin_memory); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), scale, zero_point, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, double scale, int64_t zero_point, ::std::optional memory_format) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, scale, zero_point, memory_format); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), scale, zero_point, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, double scale, int64_t zero_point, ::std::optional memory_format) { + return at::_ops::_empty_affine_quantized::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory, scale, zero_point, memory_format); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::TensorOptions options={}, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, size, scales, zero_points, axis, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor + inline at::Tensor _empty_per_channel_affine_quantized_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::_empty_per_channel_affine_quantized::redispatch(dispatchKeySet, size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format); + } + + // aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) + inline const at::Tensor & resize_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format); + } + + // aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!) + inline const at::Tensor & resize__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_::redispatch(dispatchKeySet, self, size, memory_format); + } + + // aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!) + inline const at::Tensor & _resize_output_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device) { + return at::_ops::_resize_output_::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device); + } + + // aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!) + inline const at::Tensor & _resize_output__symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + return at::_ops::_resize_output_::redispatch(dispatchKeySet, self, size, device); + } + + // aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_quantized::redispatch(dispatchKeySet, size, qtensor, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_quantized(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_quantized::redispatch(dispatchKeySet, size, qtensor, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_out::redispatch(dispatchKeySet, size, memory_format, out); + } + + // aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_out::redispatch(dispatchKeySet, size, memory_format, out); + } + + // aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor empty_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::empty_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_strided(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_strided(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), dtype, layout, device, pin_memory); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_strided_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::TensorOptions options={}) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, size, stride, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor empty_strided_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::empty_strided::redispatch(dispatchKeySet, size, stride, dtype, layout, device, pin_memory); + } + + // aten::erf(Tensor self) -> Tensor + inline at::Tensor erf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::erf::redispatch(dispatchKeySet, self); + } + + // aten::erf_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & erf_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::erf_::redispatch(dispatchKeySet, self); + } + + // aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erfc(Tensor self) -> Tensor + inline at::Tensor erfc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::erfc::redispatch(dispatchKeySet, self); + } + + // aten::erfc_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & erfc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::erfc_::redispatch(dispatchKeySet, self); + } + + // aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp(Tensor self) -> Tensor + inline at::Tensor exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::exp::redispatch(dispatchKeySet, self); + } + + // aten::exp_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & exp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::exp_::redispatch(dispatchKeySet, self); + } + + // aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp2(Tensor self) -> Tensor + inline at::Tensor exp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::exp2::redispatch(dispatchKeySet, self); + } + + // aten::exp2_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & exp2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::exp2_::redispatch(dispatchKeySet, self); + } + + // aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::expm1(Tensor self) -> Tensor + inline at::Tensor expm1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::expm1::redispatch(dispatchKeySet, self); + } + + // aten::expm1_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & expm1_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::expm1_::redispatch(dispatchKeySet, self); + } + + // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expm1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expm1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + inline at::Tensor expand(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) { + return at::_ops::expand::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit); + } + + // aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a) + inline at::Tensor expand_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) { + return at::_ops::expand::redispatch(dispatchKeySet, self, size, implicit); + } + + // aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a) + inline at::Tensor expand_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::expand_as::redispatch(dispatchKeySet, self, other); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, at::TensorOptions options={}) { + return at::_ops::eye::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::TensorOptions options={}) { + return at::_ops::eye::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, at::TensorOptions options={}) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, dtype, layout, device, pin_memory); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, at::TensorOptions options={}) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor eye_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::eye_m::redispatch(dispatchKeySet, n, m, dtype, layout, device, pin_memory); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, at::Tensor & out) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::Tensor & out) { + return at::_ops::eye_out::redispatch(dispatchKeySet, n, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, int64_t m) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, int64_t m, at::Tensor & out) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n, c10::SymInt m) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eye_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, c10::SymInt m, at::Tensor & out) { + return at::_ops::eye_m_out::redispatch(dispatchKeySet, n, m, out); + } + + // aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t start_dim=0, int64_t end_dim=-1) { + return at::_ops::flatten_using_ints::redispatch(dispatchKeySet, self, start_dim, end_dim); + } + + // aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t start_dim, int64_t end_dim, at::Dimname out_dim) { + return at::_ops::flatten_named_out_dim::redispatch(dispatchKeySet, self, start_dim, end_dim, out_dim); + } + + // aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) { + return at::_ops::flatten_using_names::redispatch(dispatchKeySet, self, start_dim, end_dim, out_dim); + } + + // aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a) + inline at::Tensor flatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dims, at::Dimname out_dim) { + return at::_ops::flatten_DimnameList::redispatch(dispatchKeySet, self, dims, out_dim); + } + + // aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) + inline at::Tensor unflatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::IntArrayRef sizes) { + return at::_ops::unflatten_int::redispatch(dispatchKeySet, self, dim, c10::fromIntArrayRefSlow(sizes)); + } + + // aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a) + inline at::Tensor unflatten_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymIntArrayRef sizes) { + return at::_ops::unflatten_int::redispatch(dispatchKeySet, self, dim, sizes); + } + + // aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) + inline at::Tensor unflatten(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::IntArrayRef sizes, at::DimnameList names) { + return at::_ops::unflatten_Dimname::redispatch(dispatchKeySet, self, dim, c10::fromIntArrayRefSlow(sizes), names); + } + + // aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a) + inline at::Tensor unflatten_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) { + return at::_ops::unflatten_Dimname::redispatch(dispatchKeySet, self, dim, sizes, names); + } + + // aten::fill.Scalar(Tensor self, Scalar value) -> Tensor + inline at::Tensor fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & value) { + return at::_ops::fill_Scalar::redispatch(dispatchKeySet, self, value); + } + + // aten::fill.Tensor(Tensor self, Tensor value) -> Tensor + inline at::Tensor fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & value) { + return at::_ops::fill_Tensor::redispatch(dispatchKeySet, self, value); + } + + // aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!) + inline at::Tensor & fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & value) { + return at::_ops::fill__Scalar::redispatch(dispatchKeySet, self, value); + } + + // aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!) + inline at::Tensor & fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & value) { + return at::_ops::fill__Tensor::redispatch(dispatchKeySet, self, value); + } + + // aten::floor(Tensor self) -> Tensor + inline at::Tensor floor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::floor::redispatch(dispatchKeySet, self); + } + + // aten::floor_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & floor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::floor_::redispatch(dispatchKeySet, self); + } + + // aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::floor_divide(Tensor self, Tensor other) -> Tensor + inline at::Tensor floor_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::floor_divide::redispatch(dispatchKeySet, self, other); + } + + // aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & floor_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::floor_divide__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::floor_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::floor_divide_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor floor_divide(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::floor_divide_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & floor_divide_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::floor_divide__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::frac(Tensor self) -> Tensor + inline at::Tensor frac(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::frac::redispatch(dispatchKeySet, self); + } + + // aten::frac_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & frac_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::frac_::redispatch(dispatchKeySet, self); + } + + // aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & frac_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & frac_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::full_names::redispatch(dispatchKeySet, size, fill_value, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::full_names::redispatch(dispatchKeySet, size, fill_value, names, dtype, layout, device, pin_memory); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::full::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::full::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, dtype, layout, device, pin_memory); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::TensorOptions options={}) { + return at::_ops::full::redispatch(dispatchKeySet, size, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor full_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::full::redispatch(dispatchKeySet, size, fill_value, dtype, layout, device, pin_memory); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::full_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::full_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::full_out::redispatch(dispatchKeySet, size, fill_value, out); + } + + // aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::full_out::redispatch(dispatchKeySet, size, fill_value, out); + } + + // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor full_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::full_like::redispatch(dispatchKeySet, self, fill_value, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor full_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::full_like::redispatch(dispatchKeySet, self, fill_value, dtype, layout, device, pin_memory, memory_format); + } + + // aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor from_file(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, ::std::optional shared=::std::nullopt, ::std::optional size=0, at::TensorOptions options={}) { + return at::_ops::from_file::redispatch(dispatchKeySet, filename, shared, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor from_file(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, ::std::optional shared, ::std::optional size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::from_file::redispatch(dispatchKeySet, filename, shared, size, dtype, layout, device, pin_memory); + } + + // aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gcd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gcd_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gcd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::gcd_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gcd(Tensor self, Tensor other) -> Tensor + inline at::Tensor gcd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gcd::redispatch(dispatchKeySet, self, other); + } + + // aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & gcd_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::gcd_::redispatch(dispatchKeySet, self, other); + } + + // aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lcm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lcm_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lcm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::lcm_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lcm(Tensor self, Tensor other) -> Tensor + inline at::Tensor lcm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lcm::redispatch(dispatchKeySet, self, other); + } + + // aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & lcm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::lcm_::redispatch(dispatchKeySet, self, other); + } + + // aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor grid_sampler(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor grid_sampler_2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_2d::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple grid_sampler_2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_2d_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + + // aten::_grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor _grid_sampler_2d_cpu_fallback(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::_grid_sampler_2d_cpu_fallback::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::_grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor) + inline ::std::tuple _grid_sampler_2d_cpu_fallback_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::_grid_sampler_2d_cpu_fallback_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor + inline at::Tensor grid_sampler_3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_3d::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners); + } + + // aten::grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple grid_sampler_3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_3d_backward::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + + // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::hann_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hann_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::hann_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hann_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hann_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::hamming_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::hamming_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, at::TensorOptions options={}) { + return at::_ops::hamming_window_periodic_alpha::redispatch(dispatchKeySet, window_length, periodic, alpha, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window_periodic_alpha::redispatch(dispatchKeySet, window_length, periodic, alpha, dtype, layout, device, pin_memory); + } + + // aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, at::TensorOptions options={}) { + return at::_ops::hamming_window_periodic_alpha_beta::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor hamming_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::hamming_window_periodic_alpha_beta::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, dtype, layout, device, pin_memory); + } + + // aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::TensorOptions options={}) { + return at::_ops::kaiser_window::redispatch(dispatchKeySet, window_length, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::kaiser_window::redispatch(dispatchKeySet, window_length, dtype, layout, device, pin_memory); + } + + // aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::TensorOptions options={}) { + return at::_ops::kaiser_window_periodic::redispatch(dispatchKeySet, window_length, periodic, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::kaiser_window_periodic::redispatch(dispatchKeySet, window_length, periodic, dtype, layout, device, pin_memory); + } + + // aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, at::TensorOptions options={}) { + return at::_ops::kaiser_window_beta::redispatch(dispatchKeySet, window_length, periodic, beta, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor kaiser_window(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::kaiser_window_beta::redispatch(dispatchKeySet, window_length, periodic, beta, dtype, layout, device, pin_memory); + } + + // aten::hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor + inline at::Tensor hinge_embedding_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, double margin=1.0, int64_t reduction=at::Reduction::Mean) { + return at::_ops::hinge_embedding_loss::redispatch(dispatchKeySet, self, target, margin, reduction); + } + + // aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor + inline at::Tensor group_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t num_groups, const ::std::optional & weight={}, const ::std::optional & bias={}, double eps=1e-05, bool cudnn_enabled=true) { + return at::_ops::group_norm::redispatch(dispatchKeySet, input, num_groups, weight, bias, eps, cudnn_enabled); + } + + // aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps); + } + + // aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps); + } + + // aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask); + } + + // aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_group_norm_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask); + } + + // aten::_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor + inline at::Tensor _fft_r2c(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) { + return at::_ops::_fft_r2c::redispatch(dispatchKeySet, self, dim, normalization, onesided); + } + + // aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_r2c_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) { + return at::_ops::_fft_r2c_out::redispatch(dispatchKeySet, self, dim, normalization, onesided, out); + } + + // aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_r2c_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided, at::Tensor & out) { + return at::_ops::_fft_r2c_out::redispatch(dispatchKeySet, self, dim, normalization, onesided, out); + } + + // aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + inline at::Tensor _fft_c2r(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { + return at::_ops::_fft_c2r::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size); + } + + // aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor + inline at::Tensor _fft_c2r_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) { + return at::_ops::_fft_c2r::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2r_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2r_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, int64_t last_dim_size, at::Tensor & out) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2r_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2r_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, at::Tensor & out) { + return at::_ops::_fft_c2r_out::redispatch(dispatchKeySet, self, dim, normalization, last_dim_size, out); + } + + // aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + inline at::Tensor _fft_c2c(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward); + } + + // aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor + inline at::Tensor _fft_c2c_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c::redispatch(dispatchKeySet, self, dim, normalization, forward); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2c_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward, out); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2c_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dim), normalization, forward, out); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2c_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, dim, normalization, forward, out); + } + + // aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fft_c2c_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out) { + return at::_ops::_fft_c2c_out::redispatch(dispatchKeySet, self, dim, normalization, forward, out); + } + + // aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> () + inline void _validate_compressed_sparse_indices(c10::DispatchKeySet dispatchKeySet, bool is_crow, const at::Tensor & compressed_idx, const at::Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz) { + return at::_ops::_validate_compressed_sparse_indices::redispatch(dispatchKeySet, is_crow, compressed_idx, plain_idx, cdim, dim, nnz); + } + + // aten::_cufft_get_plan_cache_size(DeviceIndex device_index) -> int + inline int64_t _cufft_get_plan_cache_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) { + return at::_ops::_cufft_get_plan_cache_size::redispatch(dispatchKeySet, device_index); + } + + // aten::_cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int + inline int64_t _cufft_get_plan_cache_max_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) { + return at::_ops::_cufft_get_plan_cache_max_size::redispatch(dispatchKeySet, device_index); + } + + // aten::_cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> () + inline void _cufft_set_plan_cache_max_size(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index, int64_t max_size) { + return at::_ops::_cufft_set_plan_cache_max_size::redispatch(dispatchKeySet, device_index, max_size); + } + + // aten::_cufft_clear_plan_cache(DeviceIndex device_index) -> () + inline void _cufft_clear_plan_cache(c10::DispatchKeySet dispatchKeySet, at::DeviceIndex device_index) { + return at::_ops::_cufft_clear_plan_cache::redispatch(dispatchKeySet, device_index); + } + + // aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + inline at::Tensor index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices) { + return at::_ops::index_Tensor::redispatch(dispatchKeySet, self, indices); + } + + // aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<::std::optional> & indices) { + return at::_ops::index_Tensor_out::redispatch(dispatchKeySet, self, indices, out); + } + + // aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, at::Tensor & out) { + return at::_ops::index_Tensor_out::redispatch(dispatchKeySet, self, indices, out); + } + + // aten::_unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor + inline at::Tensor _unsafe_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices) { + return at::_ops::_unsafe_index_Tensor::redispatch(dispatchKeySet, self, indices); + } + + // aten::_unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor + inline at::Tensor _unsafe_masked_index(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Scalar & fill) { + return at::_ops::_unsafe_masked_index::redispatch(dispatchKeySet, self, mask, indices, fill); + } + + // aten::_unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor + inline at::Tensor _unsafe_masked_index_put_accumulate(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Tensor & values) { + return at::_ops::_unsafe_masked_index_put_accumulate::redispatch(dispatchKeySet, self, mask, indices, values); + } + + // aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy_out::redispatch(dispatchKeySet, self, dim, index, source, out); + } + + // aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, at::Tensor & out) { + return at::_ops::index_copy_out::redispatch(dispatchKeySet, self, dim, index, source, out); + } + + // aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!) + inline at::Tensor & index_copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy_::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor + inline at::Tensor index_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!) + inline at::Tensor & index_copy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy__dimname::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor + inline at::Tensor index_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + return at::_ops::index_copy_dimname::redispatch(dispatchKeySet, self, dim, index, source); + } + + // aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!) + inline at::Tensor & index_put_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::index_put_::redispatch(dispatchKeySet, self, indices, values, accumulate); + } + + // aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + inline at::Tensor index_put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::index_put::redispatch(dispatchKeySet, self, indices, values, accumulate); + } + + // aten::_unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor + inline at::Tensor _unsafe_index_put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::_unsafe_index_put::redispatch(dispatchKeySet, self, indices, values, accumulate); + } + + // aten::_index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!) + inline at::Tensor & _index_put_impl_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) { + return at::_ops::_index_put_impl_::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe); + } + + // aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor + inline at::Tensor instance_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) { + return at::_ops::instance_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled); + } + + // aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor + inline at::Tensor isclose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, double rtol=1e-05, double atol=1e-08, bool equal_nan=false) { + return at::_ops::isclose::redispatch(dispatchKeySet, self, other, rtol, atol, equal_nan); + } + + // aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Tensor_out::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert, out); + } + + // aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out) { + return at::_ops::isin_Tensor_Tensor_out::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert, out); + } + + // aten::isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor + inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Tensor::redispatch(dispatchKeySet, elements, test_elements, assume_unique, invert); + } + + // aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Scalar_out::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert, out); + } + + // aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert, at::Tensor & out) { + return at::_ops::isin_Tensor_Scalar_out::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert, out); + } + + // aten::isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor + inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Tensor_Scalar::redispatch(dispatchKeySet, elements, test_element, assume_unique, invert); + } + + // aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Scalar_Tensor_out::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert, out); + } + + // aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isin_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out) { + return at::_ops::isin_Scalar_Tensor_out::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert, out); + } + + // aten::isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor + inline at::Tensor isin(c10::DispatchKeySet dispatchKeySet, const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique=false, bool invert=false) { + return at::_ops::isin_Scalar_Tensor::redispatch(dispatchKeySet, element, test_elements, assume_unique, invert); + } + + // aten::isnan(Tensor self) -> Tensor + inline at::Tensor isnan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isnan::redispatch(dispatchKeySet, self); + } + + // aten::is_distributed(Tensor self) -> bool + inline bool is_distributed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_distributed::redispatch(dispatchKeySet, self); + } + + // aten::is_floating_point(Tensor self) -> bool + inline bool __dispatch_is_floating_point(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_floating_point::redispatch(dispatchKeySet, self); + } + + // aten::is_complex(Tensor self) -> bool + inline bool __dispatch_is_complex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_complex::redispatch(dispatchKeySet, self); + } + + // aten::is_conj(Tensor self) -> bool + inline bool __dispatch_is_conj(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_conj::redispatch(dispatchKeySet, self); + } + + // aten::_is_zerotensor(Tensor self) -> bool + inline bool __dispatch__is_zerotensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_is_zerotensor::redispatch(dispatchKeySet, self); + } + + // aten::is_neg(Tensor self) -> bool + inline bool __dispatch_is_neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_neg::redispatch(dispatchKeySet, self); + } + + // aten::isreal(Tensor self) -> Tensor + inline at::Tensor isreal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isreal::redispatch(dispatchKeySet, self); + } + + // aten::is_nonzero(Tensor self) -> bool + inline bool is_nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_nonzero::redispatch(dispatchKeySet, self); + } + + // aten::is_same_size(Tensor self, Tensor other) -> bool + inline bool is_same_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::is_same_size::redispatch(dispatchKeySet, self, other); + } + + // aten::is_signed(Tensor self) -> bool + inline bool __dispatch_is_signed(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_signed::redispatch(dispatchKeySet, self); + } + + // aten::is_inference(Tensor self) -> bool + inline bool __dispatch_is_inference(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_inference::redispatch(dispatchKeySet, self); + } + + // aten::kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor + inline at::Tensor kl_div(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, bool log_target=false) { + return at::_ops::kl_div::redispatch(dispatchKeySet, self, target, reduction, log_target); + } + + // aten::kron(Tensor self, Tensor other) -> Tensor + inline at::Tensor kron(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::kron::redispatch(dispatchKeySet, self, other); + } + + // aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kron_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::kron_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kron_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::kron_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool keepdim=false) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_values::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple kthvalue_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname::redispatch(dispatchKeySet, self, k, dim, keepdim); + } + + // aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim=false) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple kthvalue_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::kthvalue_dimname_out::redispatch(dispatchKeySet, self, k, dim, keepdim, values, indices); + } + + // aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor + inline at::Tensor layer_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight={}, const ::std::optional & bias={}, double eps=1e-05, bool cudnn_enable=true) { + return at::_ops::layer_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, cudnn_enable); + } + + // aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor + inline at::Tensor layer_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight={}, const ::std::optional & bias={}, double eps=1e-05, bool cudnn_enable=true) { + return at::_ops::layer_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, cudnn_enable); + } + + // aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps); + } + + // aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps); + } + + // aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask); + } + + // aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_layer_norm_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask); + } + + // aten::rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + inline at::Tensor rms_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight={}, ::std::optional eps=::std::nullopt) { + return at::_ops::rms_norm::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, eps); + } + + // aten::rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + inline at::Tensor rms_norm_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight={}, ::std::optional eps=::std::nullopt) { + return at::_ops::rms_norm::redispatch(dispatchKeySet, input, normalized_shape, weight, eps); + } + + // aten::_fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor + inline at::Tensor _fused_rms_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t normalized_shape_ndim, const at::Tensor & weight, double eps) { + return at::_ops::_fused_rms_norm::redispatch(dispatchKeySet, input, normalized_shape_ndim, weight, eps); + } + + // aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor + inline at::Tensor nan_to_num(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) { + return at::_ops::nan_to_num::redispatch(dispatchKeySet, self, nan, posinf, neginf); + } + + // aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!) + inline at::Tensor & nan_to_num_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) { + return at::_ops::nan_to_num_::redispatch(dispatchKeySet, self, nan, posinf, neginf); + } + + // aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nan_to_num_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional nan=::std::nullopt, ::std::optional posinf=::std::nullopt, ::std::optional neginf=::std::nullopt) { + return at::_ops::nan_to_num_out::redispatch(dispatchKeySet, self, nan, posinf, neginf, out); + } + + // aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nan_to_num_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf, at::Tensor & out) { + return at::_ops::nan_to_num_out::redispatch(dispatchKeySet, self, nan, posinf, neginf, out); + } + + // aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor + inline at::Tensor linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::linear::redispatch(dispatchKeySet, input, weight, bias); + } + + // aten::linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple linear_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::linear_backward::redispatch(dispatchKeySet, self, grad_output, weight, output_mask); + } + + // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::linear_out::redispatch(dispatchKeySet, input, weight, bias, out); + } + + // aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out) { + return at::_ops::linear_out::redispatch(dispatchKeySet, input, weight, bias, out); + } + + // aten::mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor + inline at::Tensor mkldnn_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::mkldnn_linear::redispatch(dispatchKeySet, self, weight, bias); + } + + // aten::mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor + inline at::Tensor mkldnn_linear_backward_input(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) { + return at::_ops::mkldnn_linear_backward_input::redispatch(dispatchKeySet, input_size, grad_output, weight); + } + + // aten::mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor) + inline ::std::tuple mkldnn_linear_backward_weights(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) { + return at::_ops::mkldnn_linear_backward_weights::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined); + } + + // aten::mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple mkldnn_linear_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::mkldnn_linear_backward::redispatch(dispatchKeySet, self, grad_output, weight, output_mask); + } + + // aten::_cslt_compress(Tensor input) -> Tensor + inline at::Tensor _cslt_compress(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::_cslt_compress::redispatch(dispatchKeySet, input); + } + + // aten::_cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, int split_k_mode=-1) -> Tensor + inline at::Tensor _cslt_sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias={}, const ::std::optional & alpha={}, ::std::optional out_dtype=::std::nullopt, bool transpose_result=false, int64_t alg_id=0, int64_t split_k=1, int64_t split_k_mode=-1) { + return at::_ops::_cslt_sparse_mm::redispatch(dispatchKeySet, compressed_A, dense_B, bias, alpha, out_dtype, transpose_result, alg_id, split_k, split_k_mode); + } + + // aten::_cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int + inline int64_t _cslt_sparse_mm_search(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias={}, const ::std::optional & alpha={}, ::std::optional out_dtype=::std::nullopt, bool transpose_result=false) { + return at::_ops::_cslt_sparse_mm_search::redispatch(dispatchKeySet, compressed_A, dense_B, bias, alpha, out_dtype, transpose_result); + } + + // aten::_sparse_semi_structured_tile(Tensor input, str algorithm="", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _sparse_semi_structured_tile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::string_view algorithm="", bool use_cutlass=true) { + return at::_ops::_sparse_semi_structured_tile::redispatch(dispatchKeySet, input, algorithm, use_cutlass); + } + + // aten::_sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor) + inline ::std::tuple _sparse_semi_structured_apply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & thread_masks) { + return at::_ops::_sparse_semi_structured_apply::redispatch(dispatchKeySet, input, thread_masks); + } + + // aten::_sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor + inline at::Tensor _sparse_semi_structured_apply_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & thread_masks) { + return at::_ops::_sparse_semi_structured_apply_dense::redispatch(dispatchKeySet, input, thread_masks); + } + + // aten::_sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor + inline at::Tensor _sparse_semi_structured_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & meta, const ::std::optional & bias={}, ::std::optional activation=::std::nullopt, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_sparse_semi_structured_linear::redispatch(dispatchKeySet, input, weight, meta, bias, activation, out_dtype); + } + + // aten::_sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor + inline at::Tensor _sparse_semi_structured_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_sparse_semi_structured_mm::redispatch(dispatchKeySet, mat1, mat1_meta, mat2, out_dtype); + } + + // aten::_sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor + inline at::Tensor _sparse_semi_structured_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, const at::Scalar & alpha=1, const at::Scalar & beta=1, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_sparse_semi_structured_addmm::redispatch(dispatchKeySet, input, mat1, mat1_meta, mat2, alpha, beta, out_dtype); + } + + // aten::_mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor + inline at::Tensor _mixed_dtypes_linear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & scale, const ::std::optional & bias={}, ::std::optional activation=::std::nullopt) { + return at::_ops::_mixed_dtypes_linear::redispatch(dispatchKeySet, input, weight, scale, bias, activation); + } + + // aten::fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + inline at::Tensor fbgemm_linear_int8_weight_fp32_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) { + return at::_ops::fbgemm_linear_int8_weight_fp32_activation::redispatch(dispatchKeySet, input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + + // aten::fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor + inline at::Tensor fbgemm_linear_int8_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) { + return at::_ops::fbgemm_linear_int8_weight::redispatch(dispatchKeySet, input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + + // aten::fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int) + inline ::std::tuple fbgemm_linear_quantize_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::fbgemm_linear_quantize_weight::redispatch(dispatchKeySet, input); + } + + // aten::fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor + inline at::Tensor fbgemm_pack_gemm_matrix_fp16(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::fbgemm_pack_gemm_matrix_fp16::redispatch(dispatchKeySet, input); + } + + // aten::_wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor + inline at::Tensor _wrapped_linear_prepack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & weight_scale, const at::Tensor & weight_zero_point, const at::Tensor & bias) { + return at::_ops::_wrapped_linear_prepack::redispatch(dispatchKeySet, weight, weight_scale, weight_zero_point, bias); + } + + // aten::_wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor + inline at::Tensor _wrapped_quantized_linear_prepacked(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & input_scale, const at::Tensor & input_zero_point, const at::Tensor & packed_weight, const at::Tensor & output_scale, const at::Tensor & output_zero_point, int64_t out_channel) { + return at::_ops::_wrapped_quantized_linear_prepacked::redispatch(dispatchKeySet, input, input_scale, input_zero_point, packed_weight, output_scale, output_zero_point, out_channel); + } + + // aten::fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + inline at::Tensor fbgemm_linear_fp16_weight_fp32_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) { + return at::_ops::fbgemm_linear_fp16_weight_fp32_activation::redispatch(dispatchKeySet, input, packed_weight, bias); + } + + // aten::fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor + inline at::Tensor fbgemm_linear_fp16_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) { + return at::_ops::fbgemm_linear_fp16_weight::redispatch(dispatchKeySet, input, packed_weight, bias); + } + + // aten::fbgemm_pack_quantized_matrix(Tensor input) -> Tensor + inline at::Tensor fbgemm_pack_quantized_matrix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input) { + return at::_ops::fbgemm_pack_quantized_matrix::redispatch(dispatchKeySet, input); + } + + // aten::fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor + inline at::Tensor fbgemm_pack_quantized_matrix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t K, int64_t N) { + return at::_ops::fbgemm_pack_quantized_matrix_KN::redispatch(dispatchKeySet, input, K, N); + } + + // aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor ldexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ldexp_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & ldexp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::ldexp_::redispatch(dispatchKeySet, self, other); + } + + // aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ldexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ldexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ldexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::ldexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, at::TensorOptions options={}) { + return at::_ops::linspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor linspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::linspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, dtype, layout, device, pin_memory); + } + + // aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, int64_t steps) { + return at::_ops::linspace_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Tensor & end, int64_t steps) { + return at::_ops::linspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Scalar & end, int64_t steps) { + return at::_ops::linspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Tensor & end, int64_t steps) { + return at::_ops::linspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, at::Tensor & out) { + return at::_ops::linspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, out); + } + + // aten::log(Tensor self) -> Tensor + inline at::Tensor log(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log::redispatch(dispatchKeySet, self); + } + + // aten::log_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log_::redispatch(dispatchKeySet, self); + } + + // aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log10(Tensor self) -> Tensor + inline at::Tensor log10(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log10::redispatch(dispatchKeySet, self); + } + + // aten::log10_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log10_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log10_::redispatch(dispatchKeySet, self); + } + + // aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log10_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log10_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log1p(Tensor self) -> Tensor + inline at::Tensor log1p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log1p::redispatch(dispatchKeySet, self); + } + + // aten::log1p_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log1p_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log1p_::redispatch(dispatchKeySet, self); + } + + // aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log1p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log1p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log2(Tensor self) -> Tensor + inline at::Tensor log2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log2::redispatch(dispatchKeySet, self); + } + + // aten::log2_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & log2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::log2_::redispatch(dispatchKeySet, self); + } + + // aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logaddexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logaddexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logaddexp_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp(Tensor self, Tensor other) -> Tensor + inline at::Tensor logaddexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp::redispatch(dispatchKeySet, self, other); + } + + // aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logaddexp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logaddexp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::logaddexp2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logaddexp2(Tensor self, Tensor other) -> Tensor + inline at::Tensor logaddexp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::logaddexp2::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::xlogy_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor + inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::xlogy_Scalar_Self::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor + inline at::Tensor xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::xlogy_Scalar_Other::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & xlogy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::xlogy__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & xlogy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::xlogy__Scalar_Other::redispatch(dispatchKeySet, self, other); + } + + // aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::xlogy_OutTensor::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::xlogy_OutTensor::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::xlogy_OutScalar_Self::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::xlogy_OutScalar_Self::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::xlogy_OutScalar_Other::redispatch(dispatchKeySet, self, other, out); + } + + // aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::xlogy_OutScalar_Other::redispatch(dispatchKeySet, self, other, out); + } + + // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace_Tensor_Tensor::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace_Tensor_Scalar::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base=10.0, at::TensorOptions options={}) { + return at::_ops::logspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, base, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor logspace(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::logspace_Scalar_Tensor::redispatch(dispatchKeySet, start, end, steps, base, dtype, layout, device, pin_memory); + } + + // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_Tensor_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_Tensor_Scalar_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base=10.0) { + return at::_ops::logspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logspace_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out) { + return at::_ops::logspace_Scalar_Tensor_out::redispatch(dispatchKeySet, start, end, steps, base, out); + } + + // aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::log_softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::log_softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::log_softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::log_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_log_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + inline at::Tensor _log_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_log_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype); + } + + // aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _log_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, out); + } + + // aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _log_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & out) { + return at::_ops::_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, out); + } + + // aten::_logcumsumexp(Tensor self, int dim) -> Tensor + inline at::Tensor _logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::_logcumsumexp::redispatch(dispatchKeySet, self, dim); + } + + // aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::_logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::_logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp(Tensor self, int dim) -> Tensor + inline at::Tensor logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::logcumsumexp::redispatch(dispatchKeySet, self, dim); + } + + // aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::logcumsumexp_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor + inline at::Tensor logcumsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::logcumsumexp_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim) { + return at::_ops::logcumsumexp_dimname_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logcumsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, at::Tensor & out) { + return at::_ops::logcumsumexp_dimname_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::logsumexp::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false) { + return at::_ops::logsumexp_names::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false) { + return at::_ops::logsumexp_names_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, at::Tensor & out) { + return at::_ops::logsumexp_names_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor + inline at::Tensor margin_ranking_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin=0.0, int64_t reduction=at::Reduction::Mean) { + return at::_ops::margin_ranking_loss::redispatch(dispatchKeySet, input1, input2, target, margin, reduction); + } + + // aten::matmul(Tensor self, Tensor other) -> Tensor + inline at::Tensor matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::matmul::redispatch(dispatchKeySet, self, other); + } + + // aten::matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor) + inline ::std::tuple matmul_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask) { + return at::_ops::matmul_backward::redispatch(dispatchKeySet, grad, self, other, mask); + } + + // aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::matrix_power(Tensor self, int n) -> Tensor + inline at::Tensor matrix_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n) { + return at::_ops::matrix_power::redispatch(dispatchKeySet, self, n); + } + + // aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & matrix_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n) { + return at::_ops::matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & matrix_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, at::Tensor & out) { + return at::_ops::matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::matrix_exp(Tensor self) -> Tensor + inline at::Tensor matrix_exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::matrix_exp::redispatch(dispatchKeySet, self); + } + + // aten::matrix_exp_backward(Tensor self, Tensor grad) -> Tensor + inline at::Tensor matrix_exp_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad) { + return at::_ops::matrix_exp_backward::redispatch(dispatchKeySet, self, grad); + } + + // aten::_aminmax(Tensor self) -> (Tensor, Tensor) + inline ::std::tuple _aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_aminmax::redispatch(dispatchKeySet, self); + } + + // aten::_aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple _aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::_aminmax_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max) + inline ::std::tuple aminmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::aminmax::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) + inline ::std::tuple aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & max, const at::Tensor & self, ::std::optional dim=::std::nullopt, bool keepdim=false) { + return at::_ops::aminmax_out::redispatch(dispatchKeySet, self, dim, keepdim, min, max); + } + + // aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max) + inline ::std::tuple aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max) { + return at::_ops::aminmax_out::redispatch(dispatchKeySet, self, dim, keepdim, min, max); + } + + // aten::_compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor + inline at::Tensor _compute_linear_combination(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & coefficients) { + return at::_ops::_compute_linear_combination::redispatch(dispatchKeySet, input, coefficients); + } + + // aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _compute_linear_combination_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & coefficients) { + return at::_ops::_compute_linear_combination_out::redispatch(dispatchKeySet, input, coefficients, out); + } + + // aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _compute_linear_combination_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & coefficients, at::Tensor & out) { + return at::_ops::_compute_linear_combination_out::redispatch(dispatchKeySet, input, coefficients, out); + } + + // aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::max_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & max, at::Tensor & max_values, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::max_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & max, at::Tensor & max_values) { + return at::_ops::max_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::max_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & max, at::Tensor & max_values, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::max_names_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & max, at::Tensor & max_values) { + return at::_ops::max_names_dim_max::redispatch(dispatchKeySet, self, dim, keepdim, max, max_values); + } + + // aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor + inline at::Tensor value_selecting_reduction_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t dim, const at::Tensor & indices, at::IntArrayRef sizes, bool keepdim) { + return at::_ops::value_selecting_reduction_backward::redispatch(dispatchKeySet, grad, dim, indices, c10::fromIntArrayRefSlow(sizes), keepdim); + } + + // aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor + inline at::Tensor value_selecting_reduction_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t dim, const at::Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim) { + return at::_ops::value_selecting_reduction_backward::redispatch(dispatchKeySet, grad, dim, indices, sizes, keepdim); + } + + // aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + inline at::Tensor amax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amax::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & amax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & amax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::amax_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + inline ::std::tuple max_pool1d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool1d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor mkldnn_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor quantized_max_pool1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool1d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor quantized_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor quantized_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor + inline at::Tensor max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean::redispatch(dispatchKeySet, self, dtype); + } + + // aten::mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::mean_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_dim::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::mean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_names_dim::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::mean_names_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::mean_names_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor nanmean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nanmean::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanmean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nanmean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanmean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::nanmean_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::median(Tensor self) -> Tensor + inline at::Tensor median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::median::redispatch(dispatchKeySet, self); + } + + // aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::median_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::median_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::median_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple median(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::median_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::median_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::median_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian(Tensor self) -> Tensor + inline at::Tensor nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::nanmedian::redispatch(dispatchKeySet, self); + } + + // aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::nanmedian_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::nanmedian_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::nanmedian_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple nanmedian(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::nanmedian_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::nanmedian_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::nanmedian_names_dim_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::min_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & min_indices, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::min_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices) { + return at::_ops::min_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::min_names_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & min, at::Tensor & min_indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::min_names_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices) { + return at::_ops::min_names_dim_min::redispatch(dispatchKeySet, self, dim, keepdim, min, min_indices); + } + + // aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor + inline at::Tensor amin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amin::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & amin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim={}, bool keepdim=false) { + return at::_ops::amin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & amin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::amin_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::_mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::_mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor _mps_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups); + } + + // aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple mps_convolution_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask); + } + + // aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple mps_convolution_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask); + } + + // aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor mkldnn_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::mkldnn_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor mkldnn_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::mkldnn_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups); + } + + // aten::mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple mkldnn_rnn_layer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) { + return at::_ops::mkldnn_rnn_layer::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train); + } + + // aten::mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple mkldnn_rnn_layer_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) { + return at::_ops::mkldnn_rnn_layer_backward::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace); + } + + // aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor) + inline ::std::tuple miopen_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::miopen_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + + // aten::miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor) + inline ::std::tuple miopen_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon) { + return at::_ops::miopen_batch_norm_backward::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon); + } + + // aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic); + } + + // aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + + // aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic); + } + + // aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_convolution_transpose_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + } + + // aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_depthwise_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic); + } + + // aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor + inline at::Tensor miopen_depthwise_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + + // aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::miopen_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::miopen_convolution_relu::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups); + } + + // aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_add_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::miopen_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups); + } + + // aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor + inline at::Tensor miopen_convolution_add_relu_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::miopen_convolution_add_relu::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + + // aten::miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple miopen_rnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::miopen_rnn::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + + // aten::miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[]) + inline ::std::tuple> miopen_rnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::miopen_rnn_backward::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + + // aten::mm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::mm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor + inline at::Tensor mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::mm_dtype::redispatch(dispatchKeySet, self, mat2, out_dtype); + } + + // aten::mm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + return at::_ops::mm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::mm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out) { + return at::_ops::mm_dtype_out::redispatch(dispatchKeySet, self, mat2, out_dtype, out); + } + + // aten::_int_mm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor _int_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::_int_mm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _int_mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::_int_mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _int_mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::_int_mm_out::redispatch(dispatchKeySet, self, mat2, out); + } + + // aten::_convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor + inline at::Tensor _convert_weight_to_int4pack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t innerKTiles) { + return at::_ops::_convert_weight_to_int4pack::redispatch(dispatchKeySet, self, innerKTiles); + } + + // aten::_weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor + inline at::Tensor _weight_int4pack_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + return at::_ops::_weight_int4pack_mm::redispatch(dispatchKeySet, self, mat2, qGroupSize, qScaleAndZeros); + } + + // aten::_weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor + inline at::Tensor _weight_int4pack_mm_with_scales_and_zeros(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScale, const at::Tensor & qZeros) { + return at::_ops::_weight_int4pack_mm_with_scales_and_zeros::redispatch(dispatchKeySet, self, mat2, qGroupSize, qScale, qZeros); + } + + // aten::_convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor + inline at::Tensor _convert_weight_to_int4pack_for_cpu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t innerKTiles) { + return at::_ops::_convert_weight_to_int4pack_for_cpu::redispatch(dispatchKeySet, self, innerKTiles); + } + + // aten::_weight_int4pack_mm_for_cpu(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor + inline at::Tensor _weight_int4pack_mm_for_cpu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + return at::_ops::_weight_int4pack_mm_for_cpu::redispatch(dispatchKeySet, self, mat2, qGroupSize, qScaleAndZeros); + } + + // aten::_dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor + inline at::Tensor _dyn_quant_pack_4bit_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weights, const at::Tensor & scales_zeros, const ::std::optional & bias, int64_t block_size, int64_t in_features, int64_t out_features) { + return at::_ops::_dyn_quant_pack_4bit_weight::redispatch(dispatchKeySet, weights, scales_zeros, bias, block_size, in_features, out_features); + } + + // aten::_dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor + inline at::Tensor _dyn_quant_matmul_4bit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & inp, const at::Tensor & packed_weights, int64_t block_size, int64_t in_features, int64_t out_features) { + return at::_ops::_dyn_quant_matmul_4bit::redispatch(dispatchKeySet, inp, packed_weights, block_size, in_features, out_features); + } + + // aten::_weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor + inline at::Tensor _weight_int8pack_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scales) { + return at::_ops::_weight_int8pack_mm::redispatch(dispatchKeySet, self, mat2, scales); + } + + // aten::_sparse_mm(Tensor sparse, Tensor dense) -> Tensor + inline at::Tensor _sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sparse, const at::Tensor & dense) { + return at::_ops::_sparse_mm::redispatch(dispatchKeySet, sparse, dense); + } + + // aten::_sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor + inline at::Tensor _sparse_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sparse, const at::Tensor & dense, c10::string_view reduce) { + return at::_ops::_sparse_mm_reduce::redispatch(dispatchKeySet, sparse, dense, reduce); + } + + // aten::_sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor + inline at::Tensor _sparse_sparse_matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_sparse_sparse_matmul::redispatch(dispatchKeySet, self, other); + } + + // aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool keepdim=false) { + return at::_ops::mode::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim=-1, bool keepdim=false) { + return at::_ops::mode_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::mode_values::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices) + inline ::std::tuple mode(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::mode_dimname::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool keepdim=false) { + return at::_ops::mode_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple mode_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices) { + return at::_ops::mode_dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, values, indices); + } + + // aten::mul.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor mul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::mul_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & mul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::mul__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::mul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::mul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor mul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::mul_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & mul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::mul__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor multiply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::multiply_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & multiply_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::multiply__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multiply_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::multiply_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multiply_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::multiply_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor multiply(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::multiply_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & multiply_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::multiply__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::mv(Tensor self, Tensor vec) -> Tensor + inline at::Tensor mv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec) { + return at::_ops::mv::redispatch(dispatchKeySet, self, vec); + } + + // aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec) { + return at::_ops::mv_out::redispatch(dispatchKeySet, self, vec, out); + } + + // aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec, at::Tensor & out) { + return at::_ops::mv_out::redispatch(dispatchKeySet, self, vec, out); + } + + // aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mvlgamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t p) { + return at::_ops::mvlgamma_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mvlgamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p, at::Tensor & out) { + return at::_ops::mvlgamma_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::mvlgamma(Tensor self, int p) -> Tensor + inline at::Tensor mvlgamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p) { + return at::_ops::mvlgamma::redispatch(dispatchKeySet, self, p); + } + + // aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!) + inline at::Tensor & mvlgamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t p) { + return at::_ops::mvlgamma_::redispatch(dispatchKeySet, self, p); + } + + // aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor + inline at::Tensor narrow_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) { + return at::_ops::narrow_copy::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor + inline at::Tensor narrow_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + return at::_ops::narrow_copy::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & narrow_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & narrow_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length, at::Tensor & out) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & narrow_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & narrow_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length, at::Tensor & out) { + return at::_ops::narrow_copy_out::redispatch(dispatchKeySet, self, dim, start, length, out); + } + + // aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) + inline at::Tensor narrow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t start, int64_t length) { + return at::_ops::narrow::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a) + inline at::Tensor narrow_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + return at::_ops::narrow::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) + inline at::Tensor narrow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & start, int64_t length) { + return at::_ops::narrow_Tensor::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a) + inline at::Tensor narrow_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & start, c10::SymInt length) { + return at::_ops::narrow_Tensor::redispatch(dispatchKeySet, self, dim, start, length); + } + + // aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_batch_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps) { + return at::_ops::native_batch_norm::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps); + } + + // aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps) { + return at::_ops::native_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) { + return at::_ops::native_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _native_batch_norm_legit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps); + } + + // aten::_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _native_batch_norm_legit_no_training(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_training::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _native_batch_norm_legit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _native_batch_norm_legit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) { + return at::_ops::_native_batch_norm_legit_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _native_batch_norm_legit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_stats::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps); + } + + // aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_stats_out::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd) { + return at::_ops::_native_batch_norm_legit_no_stats_out::redispatch(dispatchKeySet, input, weight, bias, training, momentum, eps, out, save_mean, save_invstd); + } + + // aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps) { + return at::_ops::batch_norm_stats::redispatch(dispatchKeySet, input, eps); + } + + // aten::batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor + inline at::Tensor batch_norm_elemt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + return at::_ops::batch_norm_elemt::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps); + } + + // aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & batch_norm_elemt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + return at::_ops::batch_norm_elemt_out::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps, out); + } + + // aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & batch_norm_elemt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out) { + return at::_ops::batch_norm_elemt_out::redispatch(dispatchKeySet, input, weight, bias, mean, invstd, eps, out); + } + + // aten::batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_gather_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + return at::_ops::batch_norm_gather_stats::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count); + } + + // aten::batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_gather_stats_with_counts(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + return at::_ops::batch_norm_gather_stats_with_counts::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts); + } + + // aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor) + inline ::std::tuple native_batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask) { + return at::_ops::native_batch_norm_backward::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask); + } + + // aten::batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple batch_norm_backward_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + return at::_ops::batch_norm_backward_reduce::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g); + } + + // aten::batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor + inline at::Tensor batch_norm_backward_elemt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + return at::_ops::batch_norm_backward_elemt::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } + + // aten::batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor) + inline ::std::tuple batch_norm_update_stats(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + return at::_ops::batch_norm_update_stats::redispatch(dispatchKeySet, input, running_mean, running_var, momentum); + } + + // aten::is_vulkan_available() -> bool + inline bool is_vulkan_available(c10::DispatchKeySet dispatchKeySet) { + return at::_ops::is_vulkan_available::redispatch(dispatchKeySet); + } + + // aten::_nnpack_available() -> bool + inline bool _nnpack_available(c10::DispatchKeySet dispatchKeySet) { + return at::_ops::_nnpack_available::redispatch(dispatchKeySet); + } + + // aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + inline at::Tensor _nnpack_spatial_convolution(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride=1) { + return at::_ops::_nnpack_spatial_convolution::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride)); + } + + // aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor + inline at::Tensor _nnpack_spatial_convolution_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1)) { + return at::_ops::_nnpack_spatial_convolution::redispatch(dispatchKeySet, input, weight, bias, padding, stride); + } + + // aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::ones_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::ones_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::ones::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::ones::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::ones::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor ones_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::ones::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::ones_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::ones_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::ones_out::redispatch(dispatchKeySet, size, out); + } + + // aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::ones_out::redispatch(dispatchKeySet, size, out); + } + + // aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor ones_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::ones_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor ones_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::ones_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor + inline at::Tensor pairwise_distance(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p=2, double eps=1e-06, bool keepdim=false) { + return at::_ops::pairwise_distance::redispatch(dispatchKeySet, x1, x2, p, eps, keepdim); + } + + // aten::cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor + inline at::Tensor cdist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p=2, ::std::optional compute_mode=::std::nullopt) { + return at::_ops::cdist::redispatch(dispatchKeySet, x1, x2, p, compute_mode); + } + + // aten::_euclidean_dist(Tensor x1, Tensor x2) -> Tensor + inline at::Tensor _euclidean_dist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2) { + return at::_ops::_euclidean_dist::redispatch(dispatchKeySet, x1, x2); + } + + // aten::_cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor + inline at::Tensor _cdist_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + return at::_ops::_cdist_forward::redispatch(dispatchKeySet, x1, x2, p, compute_mode); + } + + // aten::_cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor + inline at::Tensor _cdist_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) { + return at::_ops::_cdist_backward::redispatch(dispatchKeySet, grad, x1, x2, p, cdist); + } + + // aten::pdist(Tensor self, float p=2) -> Tensor + inline at::Tensor pdist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p=2) { + return at::_ops::pdist::redispatch(dispatchKeySet, self, p); + } + + // aten::_pdist_forward(Tensor self, float p=2) -> Tensor + inline at::Tensor _pdist_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p=2) { + return at::_ops::_pdist_forward::redispatch(dispatchKeySet, self, p); + } + + // aten::_pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor + inline at::Tensor _pdist_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) { + return at::_ops::_pdist_backward::redispatch(dispatchKeySet, grad, self, p, pdist); + } + + // aten::cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor + inline at::Tensor cosine_similarity(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, int64_t dim=1, double eps=1e-08) { + return at::_ops::cosine_similarity::redispatch(dispatchKeySet, x1, x2, dim, eps); + } + + // aten::permute(Tensor(a) self, int[] dims) -> Tensor(a) + inline at::Tensor permute(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::permute::redispatch(dispatchKeySet, self, dims); + } + + // aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) + inline at::Tensor movedim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + return at::_ops::movedim_intlist::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a) + inline at::Tensor movedim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t source, int64_t destination) { + return at::_ops::movedim_int::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a) + inline at::Tensor moveaxis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + return at::_ops::moveaxis_intlist::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a) + inline at::Tensor moveaxis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t source, int64_t destination) { + return at::_ops::moveaxis_int::redispatch(dispatchKeySet, self, source, destination); + } + + // aten::numpy_T(Tensor(a) self) -> Tensor(a) + inline at::Tensor numpy_T(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::numpy_T::redispatch(dispatchKeySet, self); + } + + // aten::matrix_H(Tensor(a) self) -> Tensor(a) + inline at::Tensor matrix_H(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::matrix_H::redispatch(dispatchKeySet, self); + } + + // aten::mT(Tensor(a) self) -> Tensor(a) + inline at::Tensor mT(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::mT::redispatch(dispatchKeySet, self); + } + + // aten::mH(Tensor(a) self) -> Tensor(a) + inline at::Tensor mH(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::mH::redispatch(dispatchKeySet, self); + } + + // aten::adjoint(Tensor(a) self) -> Tensor(a) + inline at::Tensor adjoint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::adjoint::redispatch(dispatchKeySet, self); + } + + // aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor + inline at::Tensor pixel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t upscale_factor) { + return at::_ops::pixel_shuffle::redispatch(dispatchKeySet, self, upscale_factor); + } + + // aten::pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor + inline at::Tensor pixel_unshuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t downscale_factor) { + return at::_ops::pixel_unshuffle::redispatch(dispatchKeySet, self, downscale_factor); + } + + // aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor channel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups) { + return at::_ops::channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor channel_shuffle_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups) { + return at::_ops::channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor native_channel_shuffle(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups) { + return at::_ops::native_channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor + inline at::Tensor native_channel_shuffle_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups) { + return at::_ops::native_channel_shuffle::redispatch(dispatchKeySet, self, groups); + } + + // aten::is_pinned(Tensor self, Device? device=None) -> bool + inline bool is_pinned(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional device=::std::nullopt) { + return at::_ops::is_pinned::redispatch(dispatchKeySet, self, device); + } + + // aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a) + inline at::Tensor pin_memory(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional device=::std::nullopt) { + return at::_ops::pin_memory::redispatch(dispatchKeySet, self, device); + } + + // aten::_pin_memory(Tensor self, Device? device=None) -> Tensor + inline at::Tensor _pin_memory(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional device=::std::nullopt) { + return at::_ops::_pin_memory::redispatch(dispatchKeySet, self, device); + } + + // aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor + inline at::Tensor pinverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond=1e-15) { + return at::_ops::pinverse::redispatch(dispatchKeySet, self, rcond); + } + + // aten::poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor + inline at::Tensor poisson_nll_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & target, bool log_input, bool full, double eps, int64_t reduction) { + return at::_ops::poisson_nll_loss::redispatch(dispatchKeySet, input, target, log_input, full, eps, reduction); + } + + // aten::rad2deg(Tensor self) -> Tensor + inline at::Tensor rad2deg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::rad2deg::redispatch(dispatchKeySet, self); + } + + // aten::rad2deg_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & rad2deg_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::rad2deg_::redispatch(dispatchKeySet, self); + } + + // aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rad2deg_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::rad2deg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rad2deg_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::rad2deg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::deg2rad(Tensor self) -> Tensor + inline at::Tensor deg2rad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::deg2rad::redispatch(dispatchKeySet, self); + } + + // aten::deg2rad_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & deg2rad_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::deg2rad_::redispatch(dispatchKeySet, self); + } + + // aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & deg2rad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::deg2rad_out::redispatch(dispatchKeySet, self, out); + } + + // aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & deg2rad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::deg2rad_out::redispatch(dispatchKeySet, self, out); + } + + // aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor scalar_tensor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, at::TensorOptions options={}) { + return at::_ops::scalar_tensor::redispatch(dispatchKeySet, s, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor scalar_tensor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::scalar_tensor::redispatch(dispatchKeySet, s, dtype, layout, device, pin_memory); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, dtype, layout, device, pin_memory); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, dtype, layout, device, pin_memory); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, size, generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator_with_names::redispatch(dispatchKeySet, size, generator, names, dtype, layout, device, pin_memory); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::rand::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::rand::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor rand_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::rand_generator::redispatch(dispatchKeySet, size, generator, dtype, layout, device, pin_memory); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::rand_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::rand_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::rand_out::redispatch(dispatchKeySet, size, out); + } + + // aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::rand_out::redispatch(dispatchKeySet, size, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::rand_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::rand_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor rand_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::rand_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint::redispatch(dispatchKeySet, high, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint::redispatch(dispatchKeySet, high, size, dtype, layout, device, pin_memory); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_generator::redispatch(dispatchKeySet, high, size, generator, dtype, layout, device, pin_memory); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low::redispatch(dispatchKeySet, low, high, size, dtype, layout, device, pin_memory); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randint_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randint_low_generator::redispatch(dispatchKeySet, low, high, size, generator, dtype, layout, device, pin_memory); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t high, at::IntArrayRef size) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt high, c10::SymIntArrayRef size) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, size, out); + } + + // aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::randint_out::redispatch(dispatchKeySet, high, size, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t high, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t high, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, size, generator, out); + } + + // aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_generator_out::redispatch(dispatchKeySet, high, size, generator, out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t low, int64_t high, at::IntArrayRef size) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, size, out); + } + + // aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::randint_low_out::redispatch(dispatchKeySet, low, high, size, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_outf(c10::DispatchKeySet dispatchKeySet, int64_t low, int64_t high, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, size, generator, out); + } + + // aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randint_low_generator_out::redispatch(dispatchKeySet, low, high, size, generator, out); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_Tensor::redispatch(dispatchKeySet, self, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_Tensor::redispatch(dispatchKeySet, self, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randint_like_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randint_like_low_dtype::redispatch(dispatchKeySet, self, low, high, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::randn::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::randn::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::TensorOptions options={}) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator::redispatch(dispatchKeySet, size, generator, dtype, layout, device, pin_memory); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, dtype, layout, device, pin_memory); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, dtype, layout, device, pin_memory); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, size, generator, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randn_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randn_generator_with_names::redispatch(dispatchKeySet, size, generator, names, dtype, layout, device, pin_memory); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::randn_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::randn_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::randn_out::redispatch(dispatchKeySet, size, out); + } + + // aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::randn_out::redispatch(dispatchKeySet, size, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional generator) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional generator) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::randn_generator_out::redispatch(dispatchKeySet, size, generator, out); + } + + // aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randn_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor randn_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::randn_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, at::TensorOptions options=at::kLong) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::TensorOptions options=at::kLong) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm::redispatch(dispatchKeySet, n, dtype, layout, device, pin_memory); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, dtype, layout, device, pin_memory); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional generator, at::TensorOptions options=at::kLong) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor randperm_symint(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::randperm_generator::redispatch(dispatchKeySet, n, generator, dtype, layout, device, pin_memory); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, at::Tensor & out) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, at::Tensor & out) { + return at::_ops::randperm_out::redispatch(dispatchKeySet, n, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, ::std::optional generator) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, ::std::optional generator, at::Tensor & out) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymInt n, ::std::optional generator) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randperm_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymInt n, ::std::optional generator, at::Tensor & out) { + return at::_ops::randperm_generator_out::redispatch(dispatchKeySet, n, generator, out); + } + + // aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step=1, at::TensorOptions options={}) { + return at::_ops::range_step::redispatch(dispatchKeySet, start, end, step, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::range_step::redispatch(dispatchKeySet, start, end, step, dtype, layout, device, pin_memory); + } + + // aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::TensorOptions options={}) { + return at::_ops::range::redispatch(dispatchKeySet, start, end, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor range(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::range::redispatch(dispatchKeySet, start, end, dtype, layout, device, pin_memory); + } + + // aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & range_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end) { + return at::_ops::range_out_::redispatch(dispatchKeySet, start, end, out); + } + + // aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & range_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, at::Tensor & out) { + return at::_ops::range_out_::redispatch(dispatchKeySet, start, end, out); + } + + // aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & range_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step) { + return at::_ops::range_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & range_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out) { + return at::_ops::range_out::redispatch(dispatchKeySet, start, end, step, out); + } + + // aten::ravel(Tensor(a) self) -> Tensor(a) + inline at::Tensor ravel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ravel::redispatch(dispatchKeySet, self); + } + + // aten::reciprocal(Tensor self) -> Tensor + inline at::Tensor reciprocal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::reciprocal::redispatch(dispatchKeySet, self); + } + + // aten::reciprocal_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & reciprocal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::reciprocal_::redispatch(dispatchKeySet, self); + } + + // aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reciprocal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reciprocal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::neg(Tensor self) -> Tensor + inline at::Tensor neg(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::neg::redispatch(dispatchKeySet, self); + } + + // aten::neg_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & neg_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::neg_::redispatch(dispatchKeySet, self); + } + + // aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & neg_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & neg_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::negative(Tensor self) -> Tensor + inline at::Tensor negative(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::negative::redispatch(dispatchKeySet, self); + } + + // aten::negative_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & negative_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::negative_::redispatch(dispatchKeySet, self); + } + + // aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & negative_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::negative_out::redispatch(dispatchKeySet, self, out); + } + + // aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & negative_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::negative_out::redispatch(dispatchKeySet, self, out); + } + + // aten::repeat(Tensor self, SymInt[] repeats) -> Tensor + inline at::Tensor repeat(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef repeats) { + return at::_ops::repeat::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats)); + } + + // aten::repeat(Tensor self, SymInt[] repeats) -> Tensor + inline at::Tensor repeat_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef repeats) { + return at::_ops::repeat::redispatch(dispatchKeySet, self, repeats); + } + + // aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_Tensor::redispatch(dispatchKeySet, repeats, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); + } + + // aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_Tensor::redispatch(dispatchKeySet, repeats, output_size); + } + + // aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_Tensor::redispatch(dispatchKeySet, self, repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); + } + + // aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_Tensor::redispatch(dispatchKeySet, self, repeats, dim, output_size); + } + + // aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_int::redispatch(dispatchKeySet, self, repeats, dim, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt); + } + + // aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor + inline at::Tensor repeat_interleave_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt repeats, ::std::optional dim=::std::nullopt, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_self_int::redispatch(dispatchKeySet, self, repeats, dim, output_size); + } + + // aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) + inline at::Tensor reshape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape) { + return at::_ops::reshape::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shape)); + } + + // aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a) + inline at::Tensor reshape_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shape) { + return at::_ops::reshape::redispatch(dispatchKeySet, self, shape); + } + + // aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _reshape_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_reshape_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _reshape_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::_reshape_copy::redispatch(dispatchKeySet, self, size); + } + + // aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + inline at::Tensor _reshape_alias(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::_reshape_alias::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a) + inline at::Tensor _reshape_alias_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::_reshape_alias::redispatch(dispatchKeySet, self, size, stride); + } + + // aten::_mkldnn_reshape(Tensor self, int[] shape) -> Tensor + inline at::Tensor _mkldnn_reshape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape) { + return at::_ops::_mkldnn_reshape::redispatch(dispatchKeySet, self, shape); + } + + // aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a) + inline at::Tensor reshape_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::reshape_as::redispatch(dispatchKeySet, self, other); + } + + // aten::round(Tensor self) -> Tensor + inline at::Tensor round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::round::redispatch(dispatchKeySet, self); + } + + // aten::round_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & round_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::round_::redispatch(dispatchKeySet, self); + } + + // aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::round.decimals(Tensor self, *, int decimals) -> Tensor + inline at::Tensor round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals) { + return at::_ops::round_decimals::redispatch(dispatchKeySet, self, decimals); + } + + // aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!) + inline at::Tensor & round_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t decimals) { + return at::_ops::round__decimals::redispatch(dispatchKeySet, self, decimals); + } + + // aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t decimals) { + return at::_ops::round_decimals_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals, at::Tensor & out) { + return at::_ops::round_decimals_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + inline at::Tensor rrelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu::redispatch(dispatchKeySet, self, lower, upper, training, generator); + } + + // aten::rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & rrelu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_::redispatch(dispatchKeySet, self, lower, upper, training, generator); + } + + // aten::relu(Tensor self) -> Tensor + inline at::Tensor relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::relu::redispatch(dispatchKeySet, self); + } + + // aten::relu_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::relu_::redispatch(dispatchKeySet, self); + } + + // aten::relu6(Tensor self) -> Tensor + inline at::Tensor relu6(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::relu6::redispatch(dispatchKeySet, self); + } + + // aten::relu6_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & relu6_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::relu6_::redispatch(dispatchKeySet, self); + } + + // aten::prelu(Tensor self, Tensor weight) -> Tensor + inline at::Tensor prelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight) { + return at::_ops::prelu::redispatch(dispatchKeySet, self, weight); + } + + // aten::_prelu_kernel(Tensor self, Tensor weight) -> Tensor + inline at::Tensor _prelu_kernel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight) { + return at::_ops::_prelu_kernel::redispatch(dispatchKeySet, self, weight); + } + + // aten::_prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor) + inline ::std::tuple _prelu_kernel_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight) { + return at::_ops::_prelu_kernel_backward::redispatch(dispatchKeySet, grad_output, self, weight); + } + + // aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gelu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_out::redispatch(dispatchKeySet, self, approximate, out); + } + + // aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gelu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view approximate, at::Tensor & out) { + return at::_ops::gelu_out::redispatch(dispatchKeySet, self, approximate, out); + } + + // aten::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!) + inline at::Tensor & gelu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_::redispatch(dispatchKeySet, self, approximate); + } + + // aten::gelu(Tensor self, *, str approximate='none') -> Tensor + inline at::Tensor gelu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu::redispatch(dispatchKeySet, self, approximate); + } + + // aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & gelu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, approximate, grad_input); + } + + // aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & gelu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate, at::Tensor & grad_input) { + return at::_ops::gelu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, approximate, grad_input); + } + + // aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor + inline at::Tensor gelu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate="none") { + return at::_ops::gelu_backward::redispatch(dispatchKeySet, grad_output, self, approximate); + } + + // aten::infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor + inline at::Tensor infinitely_differentiable_gelu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self) { + return at::_ops::infinitely_differentiable_gelu_backward::redispatch(dispatchKeySet, grad, self); + } + + // aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardshrink_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::hardshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardshrink_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out) { + return at::_ops::hardshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor + inline at::Tensor hardshrink(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::hardshrink::redispatch(dispatchKeySet, self, lambd); + } + + // aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardshrink_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::hardshrink_backward_grad_input::redispatch(dispatchKeySet, grad_out, self, lambd, grad_input); + } + + // aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardshrink_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input) { + return at::_ops::hardshrink_backward_grad_input::redispatch(dispatchKeySet, grad_out, self, lambd, grad_input); + } + + // aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor + inline at::Tensor hardshrink_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::hardshrink_backward::redispatch(dispatchKeySet, grad_out, self, lambd); + } + + // aten::rsqrt(Tensor self) -> Tensor + inline at::Tensor rsqrt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::rsqrt::redispatch(dispatchKeySet, self); + } + + // aten::rsqrt_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & rsqrt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::rsqrt_::redispatch(dispatchKeySet, self); + } + + // aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsqrt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsqrt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a) + inline at::Tensor select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, int64_t index) { + return at::_ops::select_Dimname::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + inline at::Tensor select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::select_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a) + inline at::Tensor select_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::select_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + inline at::Tensor select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index) { + return at::_ops::select_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index); + } + + // aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor + inline at::Tensor select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { + return at::_ops::select_backward::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index); + } + + // aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor _nested_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::_nested_select_backward::redispatch(dispatchKeySet, grad_output, self, dim, index); + } + + // aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor _nested_select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::_nested_select_backward::redispatch(dispatchKeySet, grad_output, self, dim, index); + } + + // aten::selu(Tensor self) -> Tensor + inline at::Tensor selu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::selu::redispatch(dispatchKeySet, self); + } + + // aten::selu_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & selu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::selu_::redispatch(dispatchKeySet, self); + } + + // aten::celu(Tensor self, Scalar alpha=1.0) -> Tensor + inline at::Tensor celu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha=1.0) { + return at::_ops::celu::redispatch(dispatchKeySet, self, alpha); + } + + // aten::celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!) + inline at::Tensor & celu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & alpha=1.0) { + return at::_ops::celu_::redispatch(dispatchKeySet, self, alpha); + } + + // aten::silu(Tensor self) -> Tensor + inline at::Tensor silu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::silu::redispatch(dispatchKeySet, self); + } + + // aten::silu_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & silu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::silu_::redispatch(dispatchKeySet, self); + } + + // aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & silu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::silu_out::redispatch(dispatchKeySet, self, out); + } + + // aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & silu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::silu_out::redispatch(dispatchKeySet, self, out); + } + + // aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & silu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::silu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & silu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::silu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::silu_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor silu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::silu_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::mish(Tensor self) -> Tensor + inline at::Tensor mish(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::mish::redispatch(dispatchKeySet, self); + } + + // aten::mish_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & mish_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::mish_::redispatch(dispatchKeySet, self); + } + + // aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mish_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::mish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mish_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::mish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::mish_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor mish_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::mish_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::sigmoid(Tensor self) -> Tensor + inline at::Tensor sigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sigmoid::redispatch(dispatchKeySet, self); + } + + // aten::sigmoid_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sigmoid_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sigmoid_::redispatch(dispatchKeySet, self); + } + + // aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::logit(Tensor self, float? eps=None) -> Tensor + inline at::Tensor logit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit::redispatch(dispatchKeySet, self, eps); + } + + // aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!) + inline at::Tensor & logit_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_::redispatch(dispatchKeySet, self, eps); + } + + // aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & logit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps, at::Tensor & out) { + return at::_ops::logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::sin(Tensor self) -> Tensor + inline at::Tensor sin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sin::redispatch(dispatchKeySet, self); + } + + // aten::sin_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sin_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sin_::redispatch(dispatchKeySet, self); + } + + // aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinc(Tensor self) -> Tensor + inline at::Tensor sinc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sinc::redispatch(dispatchKeySet, self); + } + + // aten::sinc_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sinc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sinc_::redispatch(dispatchKeySet, self); + } + + // aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sinc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sinc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinh(Tensor self) -> Tensor + inline at::Tensor sinh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sinh::redispatch(dispatchKeySet, self); + } + + // aten::sinh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sinh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sinh_::redispatch(dispatchKeySet, self); + } + + // aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sinh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sinh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::detach(Tensor(a) self) -> Tensor(a) + inline at::Tensor detach(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::detach::redispatch(dispatchKeySet, self); + } + + // aten::detach_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & detach_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::detach_::redispatch(dispatchKeySet, self); + } + + // aten::size.int(Tensor self, int dim) -> int + inline int64_t __dispatch_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::size_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::size.Dimname(Tensor self, Dimname dim) -> int + inline int64_t size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::size_Dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::sym_size.int(Tensor self, int dim) -> SymInt + inline c10::SymInt __dispatch_sym_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::sym_size_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::sym_numel(Tensor self) -> SymInt + inline c10::SymInt __dispatch_sym_numel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sym_numel::redispatch(dispatchKeySet, self); + } + + // aten::sym_storage_offset(Tensor self) -> SymInt + inline c10::SymInt __dispatch_sym_storage_offset(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sym_storage_offset::redispatch(dispatchKeySet, self); + } + + // aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_Tensor::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_Tensor::redispatch(dispatchKeySet, self, dim, start, end, step); + } + + // aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + inline at::Tensor slice_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + return at::_ops::slice_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step); + } + + // aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor + inline at::Tensor slice_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) { + return at::_ops::slice_backward::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step); + } + + // aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice_inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_inverse::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a) + inline at::Tensor slice_inverse_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_inverse::redispatch(dispatchKeySet, self, src, dim, start, end, step); + } + + // aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_scatter::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_scatter::redispatch(dispatchKeySet, self, src, dim, start, end, step); + } + + // aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + inline at::Tensor select_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index) { + return at::_ops::select_scatter::redispatch(dispatchKeySet, self, src, dim, index); + } + + // aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor + inline at::Tensor select_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) { + return at::_ops::select_scatter::redispatch(dispatchKeySet, self, src, dim, index); + } + + // aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor + inline at::Tensor diagonal_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_scatter::redispatch(dispatchKeySet, self, src, offset, dim1, dim2); + } + + // aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + inline at::Tensor as_strided_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + inline at::Tensor as_strided_scatter_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter::redispatch(dispatchKeySet, self, src, size, stride, storage_offset); + } + + // aten::smm(Tensor self, Tensor mat2) -> Tensor + inline at::Tensor smm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2) { + return at::_ops::smm::redispatch(dispatchKeySet, self, mat2); + } + + // aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::softmax_int_out::redispatch(dispatchKeySet, self, dim, dtype, out); + } + + // aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor + inline at::Tensor _softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype); + } + + // aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + return at::_ops::_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, grad_input); + } + + // aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & grad_input) { + return at::_ops::_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, input_dtype, grad_input); + } + + // aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::split_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_size, int64_t dim=0) { + return at::_ops::split_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_size), dim); + } + + // aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[] + inline ::std::vector split_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_size, int64_t dim=0) { + return at::_ops::split_sizes::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split_with_sizes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim); + } + + // aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + inline ::std::vector unsafe_split_with_sizes_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes::redispatch(dispatchKeySet, self, split_sizes, dim); + } + + // aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + inline ::std::vector split_with_sizes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim); + } + + // aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[] + inline ::std::vector split_with_sizes_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes::redispatch(dispatchKeySet, self, split_sizes, dim); + } + + // aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + inline ::std::vector hsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) { + return at::_ops::hsplit_int::redispatch(dispatchKeySet, self, sections); + } + + // aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + inline ::std::vector hsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) { + return at::_ops::hsplit_array::redispatch(dispatchKeySet, self, indices); + } + + // aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + inline ::std::vector vsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) { + return at::_ops::vsplit_int::redispatch(dispatchKeySet, self, sections); + } + + // aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + inline ::std::vector vsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) { + return at::_ops::vsplit_array::redispatch(dispatchKeySet, self, indices); + } + + // aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[] + inline ::std::vector dsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sections) { + return at::_ops::dsplit_int::redispatch(dispatchKeySet, self, sections); + } + + // aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[] + inline ::std::vector dsplit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef indices) { + return at::_ops::dsplit_array::redispatch(dispatchKeySet, self, indices); + } + + // aten::squeeze(Tensor(a) self) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::squeeze::redispatch(dispatchKeySet, self); + } + + // aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::squeeze_dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::squeeze_dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a) + inline at::Tensor squeeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze_dims::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::squeeze_::redispatch(dispatchKeySet, self); + } + + // aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!) + inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim) { + return at::_ops::squeeze__dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!) + inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze__dims::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!) + inline at::Tensor & squeeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim) { + return at::_ops::squeeze__dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor sspaddmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sspaddmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sspaddmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sspaddmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sspaddmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sspaddmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor + inline at::Tensor _chunk_cat(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, int64_t num_chunks) { + return at::_ops::_chunk_cat::redispatch(dispatchKeySet, tensors, dim, num_chunks); + } + + // aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _chunk_cat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim, int64_t num_chunks) { + return at::_ops::_chunk_cat_out::redispatch(dispatchKeySet, tensors, dim, num_chunks, out); + } + + // aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _chunk_cat_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, int64_t num_chunks, at::Tensor & out) { + return at::_ops::_chunk_cat_out::redispatch(dispatchKeySet, tensors, dim, num_chunks, out); + } + + // aten::stack(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::stack::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::_stack(Tensor[] tensors, int dim=0) -> Tensor + inline at::Tensor _stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim=0) { + return at::_ops::_stack::redispatch(dispatchKeySet, tensors, dim); + } + + // aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors, int64_t dim=0) { + return at::_ops::_stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, int64_t dim, at::Tensor & out) { + return at::_ops::_stack_out::redispatch(dispatchKeySet, tensors, dim, out); + } + + // aten::hstack(Tensor[] tensors) -> Tensor + inline at::Tensor hstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::hstack::redispatch(dispatchKeySet, tensors); + } + + // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::hstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::hstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::vstack(Tensor[] tensors) -> Tensor + inline at::Tensor vstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::vstack::redispatch(dispatchKeySet, tensors); + } + + // aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & vstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::vstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & vstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::vstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::dstack(Tensor[] tensors) -> Tensor + inline at::Tensor dstack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::dstack::redispatch(dispatchKeySet, tensors); + } + + // aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dstack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::dstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dstack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::dstack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor + inline at::Tensor stft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) { + return at::_ops::stft::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, normalized, onesided, return_complex, align_to_window); + } + + // aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode="reflect", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor + inline at::Tensor stft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, c10::string_view pad_mode="reflect", bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional return_complex=::std::nullopt, ::std::optional align_to_window=::std::nullopt) { + return at::_ops::stft_center::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex, align_to_window); + } + + // aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor + inline at::Tensor istft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n_fft, ::std::optional hop_length=::std::nullopt, ::std::optional win_length=::std::nullopt, const ::std::optional & window={}, bool center=true, bool normalized=false, ::std::optional onesided=::std::nullopt, ::std::optional length=::std::nullopt, bool return_complex=false) { + return at::_ops::istft::redispatch(dispatchKeySet, self, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex); + } + + // aten::stride.int(Tensor self, int dim) -> int + inline int64_t __dispatch_stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::stride_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::stride.Dimname(Tensor self, Dimname dim) -> int + inline int64_t stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::stride_Dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::sym_stride.int(Tensor self, int dim) -> SymInt + inline c10::SymInt __dispatch_sym_stride(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::sym_stride_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum::redispatch(dispatchKeySet, self, dtype); + } + + // aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_dim_IntList::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_dim_DimnameList::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_IntList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::sum_IntList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_DimnameList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::sum_DimnameList_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor + inline at::Tensor _nested_sum_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim=false) { + return at::_ops::_nested_sum_backward::redispatch(dispatchKeySet, grad, self, dim, keepdim); + } + + // aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor nansum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nansum::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nansum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::nansum_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nansum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::nansum_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor sum_to_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::sum_to_size::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor sum_to_size_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::sum_to_size::redispatch(dispatchKeySet, self, size); + } + + // aten::sqrt(Tensor self) -> Tensor + inline at::Tensor sqrt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sqrt::redispatch(dispatchKeySet, self); + } + + // aten::sqrt_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sqrt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sqrt_::redispatch(dispatchKeySet, self); + } + + // aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sqrt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sqrt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::square(Tensor self) -> Tensor + inline at::Tensor square(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::square::redispatch(dispatchKeySet, self); + } + + // aten::square_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & square_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::square_::redispatch(dispatchKeySet, self); + } + + // aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & square_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::square_out::redispatch(dispatchKeySet, self, out); + } + + // aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & square_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::square_out::redispatch(dispatchKeySet, self, out); + } + + // aten::std(Tensor self, bool unbiased=True) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::std::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::std_mean::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_mean_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_mean_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_mean_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple std_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_mean_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::std_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::std_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::std_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::std_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor std(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & std_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::std_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod::redispatch(dispatchKeySet, self, dtype); + } + + // aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_dim_int::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_int_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::prod_int_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_dim_Dimname::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_Dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::prod_Dimname_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::t(Tensor(a) self) -> Tensor(a) + inline at::Tensor t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::t::redispatch(dispatchKeySet, self); + } + + // aten::t_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & t_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::t_::redispatch(dispatchKeySet, self); + } + + // aten::tan(Tensor self) -> Tensor + inline at::Tensor tan(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::tan::redispatch(dispatchKeySet, self); + } + + // aten::tan_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & tan_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::tan_::redispatch(dispatchKeySet, self); + } + + // aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tanh(Tensor self) -> Tensor + inline at::Tensor tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::tanh::redispatch(dispatchKeySet, self); + } + + // aten::tanh_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & tanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::tanh_::redispatch(dispatchKeySet, self); + } + + // aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor + inline at::Tensor tensordot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) { + return at::_ops::tensordot::redispatch(dispatchKeySet, self, other, dims_self, dims_other); + } + + // aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tensordot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) { + return at::_ops::tensordot_out::redispatch(dispatchKeySet, self, other, dims_self, dims_other, out); + } + + // aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tensordot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other, at::Tensor & out) { + return at::_ops::tensordot_out::redispatch(dispatchKeySet, self, other, dims_self, dims_other, out); + } + + // aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor + inline at::Tensor threshold(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + return at::_ops::threshold::redispatch(dispatchKeySet, self, threshold, value); + } + + // aten::threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!) + inline at::Tensor & threshold_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + return at::_ops::threshold_::redispatch(dispatchKeySet, self, threshold, value); + } + + // aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & threshold_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + return at::_ops::threshold_out::redispatch(dispatchKeySet, self, threshold, value, out); + } + + // aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & threshold_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value, at::Tensor & out) { + return at::_ops::threshold_out::redispatch(dispatchKeySet, self, threshold, value, out); + } + + // aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & threshold_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) { + return at::_ops::threshold_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, threshold, grad_input); + } + + // aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & threshold_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold, at::Tensor & grad_input) { + return at::_ops::threshold_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, threshold, grad_input); + } + + // aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor + inline at::Tensor threshold_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) { + return at::_ops::threshold_backward::redispatch(dispatchKeySet, grad_output, self, threshold); + } + + // aten::tile(Tensor self, SymInt[] dims) -> Tensor + inline at::Tensor tile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::tile::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(dims)); + } + + // aten::tile(Tensor self, SymInt[] dims) -> Tensor + inline at::Tensor tile_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef dims) { + return at::_ops::tile::redispatch(dispatchKeySet, self, dims); + } + + // aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + inline at::Tensor transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_int::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a) + inline at::Tensor transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim0, at::Dimname dim1) { + return at::_ops::transpose_Dimname::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::_mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor + inline at::Tensor _mkldnn_transpose(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::_mkldnn_transpose::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + inline at::Tensor & transpose_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::_mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + inline at::Tensor & _mkldnn_transpose_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::_mkldnn_transpose_::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::one_hot(Tensor self, int num_classes=-1) -> Tensor + inline at::Tensor one_hot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_classes=-1) { + return at::_ops::one_hot::redispatch(dispatchKeySet, self, num_classes); + } + + // aten::flip(Tensor self, int[] dims) -> Tensor + inline at::Tensor flip(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::flip::redispatch(dispatchKeySet, self, dims); + } + + // aten::fliplr(Tensor self) -> Tensor + inline at::Tensor fliplr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::fliplr::redispatch(dispatchKeySet, self); + } + + // aten::flipud(Tensor self) -> Tensor + inline at::Tensor flipud(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::flipud::redispatch(dispatchKeySet, self); + } + + // aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + inline at::Tensor roll(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims); + } + + // aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor + inline at::Tensor roll_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll::redispatch(dispatchKeySet, self, shifts, dims); + } + + // aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor + inline at::Tensor rot90(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k=1, at::IntArrayRef dims={0,1}) { + return at::_ops::rot90::redispatch(dispatchKeySet, self, k, dims); + } + + // aten::trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + inline at::Tensor trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) { + return at::_ops::trapezoid_x::redispatch(dispatchKeySet, y, x, dim); + } + + // aten::trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor + inline at::Tensor trapezoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Scalar & dx=1, int64_t dim=-1) { + return at::_ops::trapezoid_dx::redispatch(dispatchKeySet, y, dx, dim); + } + + // aten::trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor + inline at::Tensor trapz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, const at::Tensor & x, int64_t dim=-1) { + return at::_ops::trapz_x::redispatch(dispatchKeySet, y, x, dim); + } + + // aten::trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor + inline at::Tensor trapz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & y, double dx=1, int64_t dim=-1) { + return at::_ops::trapz_dx::redispatch(dispatchKeySet, y, dx, dim); + } + + // aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _transform_bias_rescale_qkv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) { + return at::_ops::_transform_bias_rescale_qkv::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads); + } + + // aten::_nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor + inline at::Tensor _nested_tensor_from_mask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true) { + return at::_ops::_nested_tensor_from_mask::redispatch(dispatchKeySet, t, mask, mask_check); + } + + // aten::_nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool + inline bool _nested_tensor_from_mask_left_aligned(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask) { + return at::_ops::_nested_tensor_from_mask_left_aligned::redispatch(dispatchKeySet, t, mask); + } + + // aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor + inline at::Tensor _nested_from_padded(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) { + return at::_ops::_nested_from_padded::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213); + } + + // aten::_nested_tensor_size(Tensor self) -> Tensor + inline at::Tensor _nested_tensor_size(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_tensor_size::redispatch(dispatchKeySet, self); + } + + // aten::_nested_tensor_strides(Tensor self) -> Tensor + inline at::Tensor _nested_tensor_strides(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_tensor_strides::redispatch(dispatchKeySet, self); + } + + // aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor + inline at::Tensor _nested_tensor_storage_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_tensor_storage_offsets::redispatch(dispatchKeySet, self); + } + + // aten::_nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor + inline at::Tensor _nested_from_padded_and_nested_example(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & nt_example) { + return at::_ops::_nested_from_padded_and_nested_example::redispatch(dispatchKeySet, padded, nt_example); + } + + // aten::_nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a) + inline at::Tensor _nested_view_from_buffer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + return at::_ops::_nested_view_from_buffer::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets); + } + + // aten::_nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor + inline at::Tensor _nested_view_from_buffer_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + return at::_ops::_nested_view_from_buffer_copy::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets); + } + + // aten::_nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a) + inline at::Tensor _nested_view_from_jagged(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths={}, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}) { + return at::_ops::_nested_view_from_jagged::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + + // aten::_nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor + inline at::Tensor _nested_view_from_jagged_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths={}, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}) { + return at::_ops::_nested_view_from_jagged_copy::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + + // aten::_nested_get_values(Tensor(a) self) -> Tensor(a) + inline at::Tensor _nested_get_values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_values::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_values_copy(Tensor self) -> Tensor + inline at::Tensor _nested_get_values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_values_copy::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_offsets(Tensor self) -> Tensor + inline at::Tensor _nested_get_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_offsets::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_lengths(Tensor self) -> Tensor + inline at::Tensor _nested_get_lengths(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_lengths::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_ragged_idx(Tensor self) -> int + inline int64_t _nested_get_ragged_idx(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_ragged_idx::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_min_seqlen(Tensor self) -> Tensor + inline at::Tensor _nested_get_min_seqlen(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_min_seqlen::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_max_seqlen(Tensor self) -> Tensor + inline at::Tensor _nested_get_max_seqlen(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nested_get_max_seqlen::redispatch(dispatchKeySet, self); + } + + // aten::_nested_get_jagged_dummy(Tensor any) -> Tensor + inline at::Tensor _nested_get_jagged_dummy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & any) { + return at::_ops::_nested_get_jagged_dummy::redispatch(dispatchKeySet, any); + } + + // aten::_nested_compute_contiguous_strides_offsets(Tensor nested_size) -> (Tensor, Tensor) + inline ::std::tuple _nested_compute_contiguous_strides_offsets(c10::DispatchKeySet dispatchKeySet, const at::Tensor & nested_size) { + return at::_ops::_nested_compute_contiguous_strides_offsets::redispatch(dispatchKeySet, nested_size); + } + + // aten::_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor + inline at::Tensor _trilinear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim=1) { + return at::_ops::_trilinear::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); + } + + // aten::triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor + inline at::Tensor triplet_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & anchor, const at::Tensor & positive, const at::Tensor & negative, double margin=1.0, double p=2, double eps=1e-06, bool swap=false, int64_t reduction=at::Reduction::Mean) { + return at::_ops::triplet_margin_loss::redispatch(dispatchKeySet, anchor, positive, negative, margin, p, eps, swap, reduction); + } + + // aten::trunc(Tensor self) -> Tensor + inline at::Tensor trunc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::trunc::redispatch(dispatchKeySet, self); + } + + // aten::trunc_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & trunc_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::trunc_::redispatch(dispatchKeySet, self); + } + + // aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trunc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trunc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::fix(Tensor self) -> Tensor + inline at::Tensor fix(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::fix::redispatch(dispatchKeySet, self); + } + + // aten::fix_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & fix_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::fix_::redispatch(dispatchKeySet, self); + } + + // aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fix_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::fix_out::redispatch(dispatchKeySet, self, out); + } + + // aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fix_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::fix_out::redispatch(dispatchKeySet, self, out); + } + + // aten::type_as(Tensor self, Tensor other) -> Tensor + inline at::Tensor type_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::type_as::redispatch(dispatchKeySet, self, other); + } + + // aten::_has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool + inline bool _has_compatible_shallow_copy_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & from) { + return at::_ops::_has_compatible_shallow_copy_type::redispatch(dispatchKeySet, self, from); + } + + // aten::_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor) + inline ::std::tuple _unique(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted=true, bool return_inverse=false) { + return at::_ops::_unique::redispatch(dispatchKeySet, self, sorted, return_inverse); + } + + // aten::unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + inline ::std::tuple unique_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts); + } + + // aten::unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple unique_consecutive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool return_inverse=false, bool return_counts=false, ::std::optional dim=::std::nullopt) { + return at::_ops::unique_consecutive::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim); + } + + // aten::unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + inline ::std::tuple unique_dim_consecutive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim_consecutive::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts); + } + + // aten::_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _unique2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::_unique2::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts); + } + + // aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _unsafe_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_unsafe_view::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor _unsafe_view_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::_unsafe_view::redispatch(dispatchKeySet, self, size); + } + + // aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a) + inline at::Tensor unsqueeze(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze::redispatch(dispatchKeySet, self, dim); + } + + // aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!) + inline at::Tensor & unsqueeze_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze_::redispatch(dispatchKeySet, self, dim); + } + + // aten::vander(Tensor x, int? N=None, bool increasing=False) -> Tensor + inline at::Tensor vander(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, ::std::optional N=::std::nullopt, bool increasing=false) { + return at::_ops::vander::redispatch(dispatchKeySet, x, N, increasing); + } + + // aten::var(Tensor self, bool unbiased=True) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::var::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::var_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::var_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out) { + return at::_ops::var_names_out::redispatch(dispatchKeySet, self, dim, unbiased, keepdim, out); + } + + // aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor + inline at::Tensor var(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & var_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out) { + return at::_ops::var_correction_names_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out); + } + + // aten::var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool unbiased) { + return at::_ops::var_mean::redispatch(dispatchKeySet, self, unbiased); + } + + // aten::var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_mean_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_mean_correction::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim=false) { + return at::_ops::var_mean_names_dim::redispatch(dispatchKeySet, self, dim, unbiased, keepdim); + } + + // aten::var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor) + inline ::std::tuple var_mean(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_mean_correction_names::redispatch(dispatchKeySet, self, dim, correction, keepdim); + } + + // aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a) + inline at::Tensor view_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::view_as::redispatch(dispatchKeySet, self, other); + } + + // aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::where_self::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & where_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::where_self_out::redispatch(dispatchKeySet, condition, self, other, out); + } + + // aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & where_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::where_self_out::redispatch(dispatchKeySet, condition, self, other, out); + } + + // aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::where_ScalarSelf::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::where_ScalarOther::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor + inline at::Tensor where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition, const at::Scalar & self, const at::Scalar & other) { + return at::_ops::where_Scalar::redispatch(dispatchKeySet, condition, self, other); + } + + // aten::where(Tensor condition) -> Tensor[] + inline ::std::vector where(c10::DispatchKeySet dispatchKeySet, const at::Tensor & condition) { + return at::_ops::where::redispatch(dispatchKeySet, condition); + } + + // aten::norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor + inline at::Tensor norm_except_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, int64_t pow=2, int64_t dim=0) { + return at::_ops::norm_except_dim::redispatch(dispatchKeySet, v, pow, dim); + } + + // aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor + inline at::Tensor _weight_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) { + return at::_ops::_weight_norm::redispatch(dispatchKeySet, v, g, dim); + } + + // aten::_weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor) + inline ::std::tuple _weight_norm_interface(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) { + return at::_ops::_weight_norm_interface::redispatch(dispatchKeySet, v, g, dim); + } + + // aten::_weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + inline ::std::tuple _weight_norm_interface_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + return at::_ops::_weight_norm_interface_backward::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim); + } + + // aten::_weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor) + inline ::std::tuple _weight_norm_differentiable_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + return at::_ops::_weight_norm_differentiable_backward::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim); + } + + // aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::TensorOptions options={}) { + return at::_ops::zeros_names::redispatch(dispatchKeySet, size, names, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::zeros_names::redispatch(dispatchKeySet, size, names, dtype, layout, device, pin_memory); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _efficientzerotensor_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_efficientzerotensor::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::zeros::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::zeros::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::zeros::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor zeros_symint(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::zeros::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, size, out); + } + + // aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::zeros_out::redispatch(dispatchKeySet, size, out); + } + + // aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor zeros_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, ::std::optional memory_format=::std::nullopt) { + return at::_ops::zeros_like::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor zeros_like(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + return at::_ops::zeros_like::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, memory_format); + } + + // aten::_standard_gamma_grad(Tensor self, Tensor output) -> Tensor + inline at::Tensor _standard_gamma_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & output) { + return at::_ops::_standard_gamma_grad::redispatch(dispatchKeySet, self, output); + } + + // aten::_standard_gamma(Tensor self, Generator? generator=None) -> Tensor + inline at::Tensor _standard_gamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_standard_gamma::redispatch(dispatchKeySet, self, generator); + } + + // aten::_dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor + inline at::Tensor _dirichlet_grad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) { + return at::_ops::_dirichlet_grad::redispatch(dispatchKeySet, x, alpha, total); + } + + // aten::_sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor + inline at::Tensor _sample_dirichlet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_sample_dirichlet::redispatch(dispatchKeySet, self, generator); + } + + // aten::poisson(Tensor self, Generator? generator=None) -> Tensor + inline at::Tensor poisson(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::poisson::redispatch(dispatchKeySet, self, generator); + } + + // aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor + inline at::Tensor binomial(c10::DispatchKeySet dispatchKeySet, const at::Tensor & count, const at::Tensor & prob, ::std::optional generator=::std::nullopt) { + return at::_ops::binomial::redispatch(dispatchKeySet, count, prob, generator); + } + + // aten::native_norm(Tensor self, Scalar p=2) -> Tensor + inline at::Tensor native_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::native_norm::redispatch(dispatchKeySet, self, p); + } + + // aten::native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor + inline at::Tensor native_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + return at::_ops::native_norm_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype); + } + + // aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _batch_norm_with_update(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_with_update::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::_batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple _batch_norm_with_update_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_with_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out, save_mean, save_invstd, reserve); + } + + // aten::_batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple _batch_norm_with_update_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve) { + return at::_ops::_batch_norm_with_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out, save_mean, save_invstd, reserve); + } + + // aten::_batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _batch_norm_no_update(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_no_update::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor) + inline ::std::tuple batch_norm_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve) { + return at::_ops::batch_norm_backward::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_var, update, eps, output_mask, reserve); + } + + // aten::_sparse_sum(Tensor self) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_sparse_sum::redispatch(dispatchKeySet, self); + } + + // aten::_sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::_sparse_sum_dtype::redispatch(dispatchKeySet, self, dtype); + } + + // aten::_sparse_sum.dim(Tensor self, int[1] dim) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::_sparse_sum_dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::_sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor + inline at::Tensor _sparse_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::ScalarType dtype) { + return at::_ops::_sparse_sum_dim_dtype::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor + inline at::Tensor _sparse_sum_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::_sparse_sum_backward::redispatch(dispatchKeySet, grad, self, dim); + } + + // aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_csr_sum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_sum_dim_dtype::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::_sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_csr_prod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_prod_dim_dtype::redispatch(dispatchKeySet, self, dim, keepdim, dtype); + } + + // aten::_sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _sparse_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + inline at::Tensor _sparse_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, self); + } + + // aten::_sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_log_softmax_int::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_log_softmax_Dimname::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor + inline at::Tensor _sparse_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_log_softmax::redispatch(dispatchKeySet, self, dim, half_to_float); + } + + // aten::_sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor + inline at::Tensor _sparse_log_softmax_backward_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_log_softmax_backward_data::redispatch(dispatchKeySet, grad_output, output, dim, self); + } + + // aten::_spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor + inline at::Tensor _spdiags(c10::DispatchKeySet dispatchKeySet, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout=::std::nullopt) { + return at::_ops::_spdiags::redispatch(dispatchKeySet, diagonals, offsets, shape, layout); + } + + // aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype) { + return at::_ops::norm_ScalarOpt_dtype::redispatch(dispatchKeySet, self, p, dtype); + } + + // aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::norm_Scalar::redispatch(dispatchKeySet, self, p); + } + + // aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype); + } + + // aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::norm_ScalarOpt_dim::redispatch(dispatchKeySet, self, p, dim, keepdim); + } + + // aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::norm_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::norm_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::norm_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_names_ScalarOpt_dim_dtype::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype); + } + + // aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim=false) { + return at::_ops::norm_names_ScalarOpt_dim::redispatch(dispatchKeySet, self, p, dim, keepdim); + } + + // aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) { + return at::_ops::norm_names_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::norm_names_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim=false) { + return at::_ops::norm_names_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::Tensor & out) { + return at::_ops::norm_names_out::redispatch(dispatchKeySet, self, p, dim, keepdim, out); + } + + // aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent) + inline ::std::tuple frexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::frexp_Tensor::redispatch(dispatchKeySet, self); + } + + // aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent) + inline ::std::tuple frexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & mantissa, at::Tensor & exponent, const at::Tensor & self) { + return at::_ops::frexp_Tensor_out::redispatch(dispatchKeySet, self, mantissa, exponent); + } + + // aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent) + inline ::std::tuple frexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & mantissa, at::Tensor & exponent) { + return at::_ops::frexp_Tensor_out::redispatch(dispatchKeySet, self, mantissa, exponent); + } + + // aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor frobenius_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::frobenius_norm_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & frobenius_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::frobenius_norm_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & frobenius_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::frobenius_norm_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::nuclear_norm(Tensor self, bool keepdim=False) -> Tensor + inline at::Tensor nuclear_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool keepdim=false) { + return at::_ops::nuclear_norm::redispatch(dispatchKeySet, self, keepdim); + } + + // aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nuclear_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool keepdim=false) { + return at::_ops::nuclear_norm_out::redispatch(dispatchKeySet, self, keepdim, out); + } + + // aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nuclear_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool keepdim, at::Tensor & out) { + return at::_ops::nuclear_norm_out::redispatch(dispatchKeySet, self, keepdim, out); + } + + // aten::nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor + inline at::Tensor nuclear_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::nuclear_norm_dim::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nuclear_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::nuclear_norm_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nuclear_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::nuclear_norm_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor clone(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::clone::redispatch(dispatchKeySet, self, memory_format); + } + + // aten::positive(Tensor(a) self) -> Tensor(a) + inline at::Tensor positive(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::positive::redispatch(dispatchKeySet, self); + } + + // aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!) + inline const at::Tensor & resize_as_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_as_::redispatch(dispatchKeySet, self, the_template, memory_format); + } + + // aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!) + inline const at::Tensor & resize_as_sparse_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template) { + return at::_ops::resize_as_sparse_::redispatch(dispatchKeySet, self, the_template); + } + + // aten::zero_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & zero_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::zero_::redispatch(dispatchKeySet, self); + } + + // aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::sub_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sub_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor sub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::sub_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & sub_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::sub__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor sub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::sub_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & sub_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::sub__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & subtract_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::subtract_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & subtract_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::subtract_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor subtract(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::subtract_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & subtract_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::subtract__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor subtract(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::subtract_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & subtract_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::subtract__Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor + inline at::Tensor rsub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & heaviside_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & values) { + return at::_ops::heaviside_out::redispatch(dispatchKeySet, self, values, out); + } + + // aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & heaviside_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & values, at::Tensor & out) { + return at::_ops::heaviside_out::redispatch(dispatchKeySet, self, values, out); + } + + // aten::heaviside(Tensor self, Tensor values) -> Tensor + inline at::Tensor heaviside(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & values) { + return at::_ops::heaviside::redispatch(dispatchKeySet, self, values); + } + + // aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!) + inline at::Tensor & heaviside_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & values) { + return at::_ops::heaviside_::redispatch(dispatchKeySet, self, values); + } + + // aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor + inline at::Tensor rsub(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Scalar::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor _sparse_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::_sparse_addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_sampled_addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sparse_sampled_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_sampled_addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sparse_sampled_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor sparse_sampled_addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::sparse_sampled_addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::_sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor) + inline ::std::tuple _sparse_mm_reduce_impl(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, c10::string_view reduce) { + return at::_ops::_sparse_mm_reduce_impl::redispatch(dispatchKeySet, self, other, reduce); + } + + // aten::_sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor) + inline ::std::tuple _sparse_mm_reduce_impl_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_out, const at::Tensor & weight, c10::string_view reduce, const at::Tensor & arg_out, ::std::array output_mask) { + return at::_ops::_sparse_mm_reduce_impl_backward::redispatch(dispatchKeySet, self, grad_out, weight, reduce, arg_out, output_mask); + } + + // aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::addmm.dtype(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_dtype::redispatch(dispatchKeySet, self, mat1, mat2, out_dtype, beta, alpha); + } + + // aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_dtype_out::redispatch(dispatchKeySet, self, mat1, mat2, out_dtype, beta, alpha, out); + } + + // aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addmm_dtype_out::redispatch(dispatchKeySet, self, mat1, mat2, out_dtype, beta, alpha, out); + } + + // aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addmm_::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha); + } + + // aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _addmm_activation_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) { + return at::_ops::_addmm_activation_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu, out); + } + + // aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _addmm_activation_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu, at::Tensor & out) { + return at::_ops::_addmm_activation_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu, out); + } + + // aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor + inline at::Tensor _addmm_activation(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1, bool use_gelu=false) { + return at::_ops::_addmm_activation::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, use_gelu); + } + + // aten::_scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor + inline at::Tensor _scaled_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias={}, const ::std::optional & scale_result={}, ::std::optional out_dtype=::std::nullopt, bool use_fast_accum=false) { + return at::_ops::_scaled_mm::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum); + } + + // aten::_scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _scaled_mm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias={}, const ::std::optional & scale_result={}, ::std::optional out_dtype=::std::nullopt, bool use_fast_accum=false) { + return at::_ops::_scaled_mm_out::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); + } + + // aten::_scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _scaled_mm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum, at::Tensor & out) { + return at::_ops::_scaled_mm_out::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum, out); + } + + // aten::_scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor + inline at::Tensor _scaled_grouped_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & offs={}, const ::std::optional & bias={}, const ::std::optional & scale_result={}, ::std::optional out_dtype=::std::nullopt, bool use_fast_accum=false) { + return at::_ops::_scaled_grouped_mm::redispatch(dispatchKeySet, self, mat2, scale_a, scale_b, offs, bias, scale_result, out_dtype, use_fast_accum); + } + + // aten::_grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor + inline at::Tensor _grouped_mm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat2, const ::std::optional & offs={}, const ::std::optional & bias={}, ::std::optional out_dtype=::std::nullopt) { + return at::_ops::_grouped_mm::redispatch(dispatchKeySet, self, mat2, offs, bias, out_dtype); + } + + // aten::_sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _sparse_compressed_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t nnz, int64_t dense_dim, at::IntArrayRef size, at::IntArrayRef blocksize, at::ScalarType index_dtype, at::TensorOptions options) { + return at::_ops::_sparse_compressed_tensor_with_dims::redispatch(dispatchKeySet, nnz, dense_dim, size, blocksize, index_dtype, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _sparse_compressed_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t nnz, int64_t dense_dim, at::IntArrayRef size, at::IntArrayRef blocksize, at::ScalarType index_dtype, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_compressed_tensor_with_dims::redispatch(dispatchKeySet, nnz, dense_dim, size, blocksize, index_dtype, dtype, layout, device, pin_memory); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_csr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_csc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_bsr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsr_tensor_crow_col_value_size::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_bsc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsc_tensor_ccol_row_value_size::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_compressed_tensor_comp_plain_value::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_compressed_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_compressed_tensor_comp_plain_value::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_csr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_csc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_csc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_csc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_bsr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsr_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsr_tensor_crow_col_value::redispatch(dispatchKeySet, crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + + // aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::TensorOptions options) { + return at::_ops::sparse_bsc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_bsc_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_bsc_tensor_ccol_row_value::redispatch(dispatchKeySet, ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_compressed_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_compressed_tensor_unsafe::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_csr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_csr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_csc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_csc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_csc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_bsr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsr_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_bsr_tensor_unsafe::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}) { + return at::_ops::_sparse_bsc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _sparse_bsc_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_bsc_tensor_unsafe::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::sparse_coo_tensor_size::redispatch(dispatchKeySet, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::sparse_coo_tensor_size::redispatch(dispatchKeySet, size, dtype, layout, device, pin_memory); + } + + // aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::sparse_coo_tensor_indices::redispatch(dispatchKeySet, indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::sparse_coo_tensor_indices::redispatch(dispatchKeySet, indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::sparse_coo_tensor_indices_size::redispatch(dispatchKeySet, indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor sparse_coo_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::sparse_coo_tensor_indices_size::redispatch(dispatchKeySet, indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, c10::fromIntArrayRefSlow(size), c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, c10::fromIntArrayRefSlow(size), dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, at::TensorOptions options={}, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_unsafe_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_unsafe::redispatch(dispatchKeySet, indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> () + inline void _validate_sparse_coo_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional is_coalesced=::std::nullopt, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_coo_tensor_args::redispatch(dispatchKeySet, indices, values, size, is_coalesced, check_pinning); + } + + // aten::_validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? check_pinning=None) -> () + inline void _validate_sparse_compressed_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::Layout layout, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_compressed_tensor_args::redispatch(dispatchKeySet, compressed_indices, plain_indices, values, size, layout, check_pinning); + } + + // aten::_validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> () + inline void _validate_sparse_csr_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_csr_tensor_args::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, check_pinning); + } + + // aten::_validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> () + inline void _validate_sparse_csc_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_csc_tensor_args::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, check_pinning); + } + + // aten::_validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> () + inline void _validate_sparse_bsr_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_bsr_tensor_args::redispatch(dispatchKeySet, crow_indices, col_indices, values, size, check_pinning); + } + + // aten::_validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> () + inline void _validate_sparse_bsc_tensor_args(c10::DispatchKeySet dispatchKeySet, const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning=::std::nullopt) { + return at::_ops::_validate_sparse_bsc_tensor_args::redispatch(dispatchKeySet, ccol_indices, row_indices, values, size, check_pinning); + } + + // aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::TensorOptions options) { + return at::_ops::_sparse_coo_tensor_with_dims::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::_sparse_coo_tensor_with_dims::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, dtype, layout, device, pin_memory); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors_symint(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, at::TensorOptions options, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), is_coalesced); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor + inline at::Tensor _sparse_coo_tensor_with_dims_and_tensors_symint(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + + // aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) + inline const at::Tensor & sparse_resize_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!) + inline const at::Tensor & sparse_resize_and_clear_(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_and_clear_::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_mask(Tensor self, Tensor mask) -> Tensor + inline at::Tensor sparse_mask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::sparse_mask::redispatch(dispatchKeySet, self, mask); + } + + // aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor + inline at::Tensor _sparse_mask_projection(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches=false) { + return at::_ops::_sparse_mask_projection::redispatch(dispatchKeySet, self, mask, accumulate_matches); + } + + // aten::_to_cpu(Tensor[] tensors) -> Tensor[] + inline ::std::vector _to_cpu(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::_to_cpu::redispatch(dispatchKeySet, tensors); + } + + // aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor + inline at::Tensor to_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::to_dense::redispatch(dispatchKeySet, self, dtype, masked_grad); + } + + // aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor + inline at::Tensor _to_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::_to_dense::redispatch(dispatchKeySet, self, dtype, masked_grad); + } + + // aten::to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor + inline at::Tensor to_dense_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::to_dense_backward::redispatch(dispatchKeySet, grad, input, masked_grad); + } + + // aten::sparse_dim(Tensor self) -> int + inline int64_t sparse_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sparse_dim::redispatch(dispatchKeySet, self); + } + + // aten::_dimI(Tensor self) -> int + inline int64_t _dimI(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_dimI::redispatch(dispatchKeySet, self); + } + + // aten::dense_dim(Tensor self) -> int + inline int64_t dense_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::dense_dim::redispatch(dispatchKeySet, self); + } + + // aten::_dimV(Tensor self) -> int + inline int64_t _dimV(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_dimV::redispatch(dispatchKeySet, self); + } + + // aten::_nnz(Tensor self) -> int + inline int64_t _nnz(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_nnz::redispatch(dispatchKeySet, self); + } + + // aten::coalesce(Tensor(a) self) -> Tensor(a) + inline at::Tensor coalesce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::coalesce::redispatch(dispatchKeySet, self); + } + + // aten::_coalesce(Tensor self) -> Tensor + inline at::Tensor _coalesce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_coalesce::redispatch(dispatchKeySet, self); + } + + // aten::is_coalesced(Tensor self) -> bool + inline bool is_coalesced(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::is_coalesced::redispatch(dispatchKeySet, self); + } + + // aten::_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor _indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_indices::redispatch(dispatchKeySet, self); + } + + // aten::_values(Tensor(a) self) -> Tensor(a) + inline at::Tensor _values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_values::redispatch(dispatchKeySet, self); + } + + // aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!) + inline at::Tensor & _coalesced_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, bool coalesced) { + return at::_ops::_coalesced_::redispatch(dispatchKeySet, self, coalesced); + } + + // aten::indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::indices::redispatch(dispatchKeySet, self); + } + + // aten::values(Tensor(a) self) -> Tensor(a) + inline at::Tensor values(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::values::redispatch(dispatchKeySet, self); + } + + // aten::crow_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor crow_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::crow_indices::redispatch(dispatchKeySet, self); + } + + // aten::col_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor col_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::col_indices::redispatch(dispatchKeySet, self); + } + + // aten::ccol_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor ccol_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ccol_indices::redispatch(dispatchKeySet, self); + } + + // aten::row_indices(Tensor(a) self) -> Tensor(a) + inline at::Tensor row_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::row_indices::redispatch(dispatchKeySet, self); + } + + // aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hspmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mat1, const at::Tensor & mat2) { + return at::_ops::hspmm_out::redispatch(dispatchKeySet, mat1, mat2, out); + } + + // aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hspmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat2, at::Tensor & out) { + return at::_ops::hspmm_out::redispatch(dispatchKeySet, mat1, mat2, out); + } + + // aten::hspmm(Tensor mat1, Tensor mat2) -> Tensor + inline at::Tensor hspmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mat1, const at::Tensor & mat2) { + return at::_ops::hspmm::redispatch(dispatchKeySet, mat1, mat2); + } + + // aten::copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!) + inline at::Tensor & copy_sparse_to_sparse_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_sparse_to_sparse_::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[] + inline ::std::vector unbind(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0) { + return at::_ops::unbind_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[] + inline ::std::vector unbind(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim) { + return at::_ops::unbind_Dimname::redispatch(dispatchKeySet, self, dim); + } + + // aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + inline at::Tensor to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim) { + return at::_ops::to_sparse_sparse_dim::redispatch(dispatchKeySet, self, sparse_dim); + } + + // aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor + inline at::Tensor _to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim) { + return at::_ops::_to_sparse_sparse_dim::redispatch(dispatchKeySet, self, sparse_dim); + } + + // aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim); + } + + // aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim); + } + + // aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_csr::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csr::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_csc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_csc::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_csc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csc::redispatch(dispatchKeySet, self, dense_dim); + } + + // aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_bsr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_bsr::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_bsr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsr::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor to_sparse_bsc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::to_sparse_bsc::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor + inline at::Tensor _to_sparse_bsc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsc::redispatch(dispatchKeySet, self, blocksize, dense_dim); + } + + // aten::_to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor) + inline ::std::tuple _to_sparse_semi_structured(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dense) { + return at::_ops::_to_sparse_semi_structured::redispatch(dispatchKeySet, dense); + } + + // aten::to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor + inline at::Tensor to_mkldnn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::to_mkldnn::redispatch(dispatchKeySet, self, dtype); + } + + // aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv2d_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt); + } + + // aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv2d_weight_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size); + } + + // aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv3d_weight(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt); + } + + // aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor + inline at::Tensor mkldnn_reorder_conv3d_weight_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size); + } + + // aten::to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor + inline at::Tensor to_mkldnn_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input) { + return at::_ops::to_mkldnn_backward::redispatch(dispatchKeySet, grad, input); + } + + // aten::quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor + inline at::Tensor quantize_per_tensor_dynamic(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool reduce_range) { + return at::_ops::quantize_per_tensor_dynamic::redispatch(dispatchKeySet, self, dtype, reduce_range); + } + + // aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor + inline at::Tensor quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor::redispatch(dispatchKeySet, self, scale, zero_point, dtype); + } + + // aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor + inline at::Tensor quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, dtype); + } + + // aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[] + inline ::std::vector quantize_per_tensor(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensors::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype); + } + + // aten::quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor + inline at::Tensor quantize_per_channel(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) { + return at::_ops::quantize_per_channel::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype); + } + + // aten::dequantize.self(Tensor self) -> Tensor + inline at::Tensor dequantize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::dequantize_self::redispatch(dispatchKeySet, self); + } + + // aten::dequantize.tensors(Tensor[] tensors) -> Tensor[] + inline ::std::vector dequantize(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::dequantize_tensors::redispatch(dispatchKeySet, tensors); + } + + // aten::q_scale(Tensor self) -> float + inline double q_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_scale::redispatch(dispatchKeySet, self); + } + + // aten::q_zero_point(Tensor self) -> int + inline int64_t q_zero_point(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_zero_point::redispatch(dispatchKeySet, self); + } + + // aten::q_per_channel_scales(Tensor self) -> Tensor + inline at::Tensor q_per_channel_scales(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_per_channel_scales::redispatch(dispatchKeySet, self); + } + + // aten::q_per_channel_zero_points(Tensor self) -> Tensor + inline at::Tensor q_per_channel_zero_points(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_per_channel_zero_points::redispatch(dispatchKeySet, self); + } + + // aten::q_per_channel_axis(Tensor self) -> int + inline int64_t q_per_channel_axis(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::q_per_channel_axis::redispatch(dispatchKeySet, self); + } + + // aten::int_repr(Tensor self) -> Tensor + inline at::Tensor int_repr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::int_repr::redispatch(dispatchKeySet, self); + } + + // aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor + inline at::Tensor _make_per_tensor_quantized_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point) { + return at::_ops::_make_per_tensor_quantized_tensor::redispatch(dispatchKeySet, self, scale, zero_point); + } + + // aten::_make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor + inline at::Tensor _make_per_channel_quantized_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) { + return at::_ops::_make_per_channel_quantized_tensor::redispatch(dispatchKeySet, self, scale, zero_point, axis); + } + + // aten::qscheme(Tensor self) -> QScheme + inline at::QScheme qscheme(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::qscheme::redispatch(dispatchKeySet, self); + } + + // aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor + inline at::Tensor fake_quantize_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max); + } + + // aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor + inline at::Tensor fake_quantize_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max); + } + + // aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + inline ::std::tuple fake_quantize_per_tensor_affine_cachemask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max); + } + + // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + inline ::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max); + } + + // aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor + inline at::Tensor fake_quantize_per_tensor_affine_cachemask_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & mask) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_backward::redispatch(dispatchKeySet, grad, mask); + } + + // aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + inline at::Tensor _fake_quantize_learnable_per_tensor_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor); + } + + // aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _fake_quantize_learnable_per_tensor_affine_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_backward::redispatch(dispatchKeySet, grad, self, scale, zero_point, quant_min, quant_max, grad_factor); + } + + // aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor + inline at::Tensor fake_quantize_per_channel_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_channel_affine::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max); + } + + // aten::fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask) + inline ::std::tuple fake_quantize_per_channel_affine_cachemask(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_channel_affine_cachemask::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max); + } + + // aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor + inline at::Tensor fake_quantize_per_channel_affine_cachemask_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & mask) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_backward::redispatch(dispatchKeySet, grad, mask); + } + + // aten::_fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor + inline at::Tensor _fake_quantize_learnable_per_channel_affine(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_channel_affine::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + + // aten::_fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _fake_quantize_learnable_per_channel_affine_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_backward::redispatch(dispatchKeySet, grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + + // aten::fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor + inline at::Tensor fused_moving_avg_obs_fake_quant(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::fused_moving_avg_obs_fake_quant::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + + // aten::_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask) + inline ::std::tuple _fused_moving_avg_obs_fq_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::_fused_moving_avg_obs_fq_helper::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + + // aten::_choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int) + inline ::std::tuple _choose_qparams_per_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool reduce_range=false) { + return at::_ops::_choose_qparams_per_tensor::redispatch(dispatchKeySet, self, reduce_range); + } + + // aten::_saturate_weight_to_fp16(Tensor weight) -> Tensor + inline at::Tensor _saturate_weight_to_fp16(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight) { + return at::_ops::_saturate_weight_to_fp16::redispatch(dispatchKeySet, weight); + } + + // aten::choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor) + inline ::std::tuple choose_qparams_optimized(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width) { + return at::_ops::choose_qparams_optimized::redispatch(dispatchKeySet, input, numel, n_bins, ratio, bit_width); + } + + // aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a) + inline at::Tensor _autocast_to_reduced_precision(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) { + return at::_ops::_autocast_to_reduced_precision::redispatch(dispatchKeySet, self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); + } + + // aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a) + inline at::Tensor _autocast_to_full_precision(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool cuda_enabled, bool cpu_enabled) { + return at::_ops::_autocast_to_full_precision::redispatch(dispatchKeySet, self, cuda_enabled, cpu_enabled); + } + + // aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor _to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, bool non_blocking=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::_to_copy::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor _to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, ::std::optional memory_format) { + return at::_ops::_to_copy::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, non_blocking, memory_format); + } + + // aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorOptions options={}, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_dtype_layout::redispatch(dispatchKeySet, self, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt(), non_blocking, copy, c10::impl::check_tensor_options_and_extract_memory_format(options, memory_format)); + } + + // aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) { + return at::_ops::to_dtype_layout::redispatch(dispatchKeySet, self, dtype, layout, device, pin_memory, non_blocking, copy, memory_format); + } + + // aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Device device, at::ScalarType dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_device::redispatch(dispatchKeySet, self, device, dtype, non_blocking, copy, memory_format); + } + + // aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_dtype::redispatch(dispatchKeySet, self, dtype, non_blocking, copy, memory_format); + } + + // aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a) + inline at::Tensor to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, bool non_blocking=false, bool copy=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::to_other::redispatch(dispatchKeySet, self, other, non_blocking, copy, memory_format); + } + + // aten::meshgrid(Tensor[] tensors) -> Tensor[] + inline ::std::vector meshgrid(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::meshgrid::redispatch(dispatchKeySet, tensors); + } + + // aten::meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[] + inline ::std::vector meshgrid(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, c10::string_view indexing) { + return at::_ops::meshgrid_indexing::redispatch(dispatchKeySet, tensors, indexing); + } + + // aten::cartesian_prod(Tensor[] tensors) -> Tensor + inline at::Tensor cartesian_prod(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::cartesian_prod::redispatch(dispatchKeySet, tensors); + } + + // aten::combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor + inline at::Tensor combinations(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t r=2, bool with_replacement=false) { + return at::_ops::combinations::redispatch(dispatchKeySet, self, r, with_replacement); + } + + // aten::item(Tensor self) -> Scalar + inline at::Scalar item(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::item::redispatch(dispatchKeySet, self); + } + + // aten::result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & tensor, const at::Tensor & other) { + return at::_ops::result_type_Tensor::redispatch(dispatchKeySet, tensor, other); + } + + // aten::result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Tensor & tensor, const at::Scalar & other) { + return at::_ops::result_type_Scalar::redispatch(dispatchKeySet, tensor, other); + } + + // aten::result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Scalar & scalar, const at::Tensor & tensor) { + return at::_ops::result_type_Scalar_Tensor::redispatch(dispatchKeySet, scalar, tensor); + } + + // aten::result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType + inline at::ScalarType result_type(c10::DispatchKeySet dispatchKeySet, const at::Scalar & scalar1, const at::Scalar & scalar2) { + return at::_ops::result_type_Scalar_Scalar::redispatch(dispatchKeySet, scalar1, scalar2); + } + + // aten::can_cast(ScalarType from_, ScalarType to) -> bool + inline bool can_cast(c10::DispatchKeySet dispatchKeySet, at::ScalarType from_, at::ScalarType to) { + return at::_ops::can_cast::redispatch(dispatchKeySet, from_, to); + } + + // aten::promote_types(ScalarType type1, ScalarType type2) -> ScalarType + inline at::ScalarType promote_types(c10::DispatchKeySet dispatchKeySet, at::ScalarType type1, at::ScalarType type2) { + return at::_ops::promote_types::redispatch(dispatchKeySet, type1, type2); + } + + // aten::_local_scalar_dense(Tensor self) -> Scalar + inline at::Scalar _local_scalar_dense(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_local_scalar_dense::redispatch(dispatchKeySet, self); + } + + // aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _lstm_mps(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::_lstm_mps::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[]) + inline ::std::tuple,::std::vector> lstm_mps_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::lstm_mps_backward::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::_thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_lstm_cell::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias); + } + + // aten::_thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_lstm_cell_backward_impl(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias); + } + + // aten::_thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_lstm_cell_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_lstm_cell_backward::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias); + } + + // aten::_thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_differentiable_lstm_cell_backward(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const ::std::optional & input_bias, const ::std::optional & hidden_bias, const at::Tensor & cx, const at::Tensor & cy) { + return at::_ops::_thnn_differentiable_lstm_cell_backward::redispatch(dispatchKeySet, grad_hy, grad_cy, input_gates, hidden_gates, input_bias, hidden_bias, cx, cy); + } + + // aten::_thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor) + inline ::std::tuple _thnn_fused_gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_gru_cell::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + + // aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_fused_gru_cell_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_gru_cell_backward::redispatch(dispatchKeySet, grad_hy, workspace, has_bias); + } + + // aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _thnn_differentiable_gru_cell_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + return at::_ops::_thnn_differentiable_gru_cell_backward::redispatch(dispatchKeySet, grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + + // aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor) + inline ::std::tuple lstm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::lstm_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor) + inline ::std::tuple lstm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::lstm_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + inline ::std::tuple gru(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::gru_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + inline ::std::tuple gru(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::gru_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + inline ::std::tuple rnn_tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::rnn_tanh_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + inline ::std::tuple rnn_tanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::rnn_tanh_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor) + inline ::std::tuple rnn_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::rnn_relu_input::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + + // aten::rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor) + inline ::std::tuple rnn_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + return at::_ops::rnn_relu_data::redispatch(dispatchKeySet, data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + + // aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor) + inline ::std::tuple lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::lstm_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + inline at::Tensor gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::gru_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + inline at::Tensor rnn_tanh_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::rnn_tanh_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor + inline at::Tensor rnn_relu_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih={}, const ::std::optional & b_hh={}) { + return at::_ops::rnn_relu_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh); + } + + // aten::quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor) + inline ::std::tuple quantized_lstm_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_lstm_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + inline at::Tensor quantized_gru_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_gru_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + inline at::Tensor quantized_rnn_relu_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_rnn_relu_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor + inline at::Tensor quantized_rnn_tanh_cell(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + return at::_ops::quantized_rnn_tanh_cell::redispatch(dispatchKeySet, input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + + // aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor) + inline ::std::tuple _pack_padded_sequence(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & lengths, bool batch_first) { + return at::_ops::_pack_padded_sequence::redispatch(dispatchKeySet, input, lengths, batch_first); + } + + // aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor + inline at::Tensor _pack_padded_sequence_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) { + return at::_ops::_pack_padded_sequence_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(input_size), batch_sizes, batch_first); + } + + // aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor + inline at::Tensor _pack_padded_sequence_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) { + return at::_ops::_pack_padded_sequence_backward::redispatch(dispatchKeySet, grad, input_size, batch_sizes, batch_first); + } + + // aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor) + inline ::std::tuple _pad_packed_sequence(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, const at::Tensor & batch_sizes, bool batch_first, const at::Scalar & padding_value, int64_t total_length) { + return at::_ops::_pad_packed_sequence::redispatch(dispatchKeySet, data, batch_sizes, batch_first, padding_value, total_length); + } + + // aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source) { + return at::_ops::set__source_Storage::redispatch(dispatchKeySet, self, source); + } + + // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set__source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set__source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride); + } + + // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set__source_Tensor_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!) + inline at::Tensor & set__symint(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set__source_Tensor_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride); + } + + // aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & source) { + return at::_ops::set__source_Tensor::redispatch(dispatchKeySet, self, source); + } + + // aten::set_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & set_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::set_::redispatch(dispatchKeySet, self); + } + + // aten::lift(Tensor self) -> Tensor + inline at::Tensor lift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lift::redispatch(dispatchKeySet, self); + } + + // aten::lift_fresh(Tensor(a) self) -> Tensor(a) + inline at::Tensor lift_fresh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lift_fresh::redispatch(dispatchKeySet, self); + } + + // aten::lift_fresh_copy(Tensor self) -> Tensor + inline at::Tensor lift_fresh_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lift_fresh_copy::redispatch(dispatchKeySet, self); + } + + // aten::is_set_to(Tensor self, Tensor tensor) -> bool + inline bool is_set_to(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor) { + return at::_ops::is_set_to::redispatch(dispatchKeySet, self, tensor); + } + + // aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!) + inline at::Tensor & masked_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + return at::_ops::masked_fill__Scalar::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor + inline at::Tensor masked_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + return at::_ops::masked_fill_Scalar::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!) + inline at::Tensor & masked_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + return at::_ops::masked_fill__Tensor::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor + inline at::Tensor masked_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + return at::_ops::masked_fill_Tensor::redispatch(dispatchKeySet, self, mask, value); + } + + // aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!) + inline at::Tensor & masked_scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + return at::_ops::masked_scatter_::redispatch(dispatchKeySet, self, mask, source); + } + + // aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor + inline at::Tensor masked_scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + return at::_ops::masked_scatter::redispatch(dispatchKeySet, self, mask, source); + } + + // aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + inline at::Tensor masked_scatter_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, at::IntArrayRef sizes) { + return at::_ops::masked_scatter_backward::redispatch(dispatchKeySet, grad_output, mask, c10::fromIntArrayRefSlow(sizes)); + } + + // aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor + inline at::Tensor masked_scatter_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, c10::SymIntArrayRef sizes) { + return at::_ops::masked_scatter_backward::redispatch(dispatchKeySet, grad_output, mask, sizes); + } + + // aten::_masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor + inline at::Tensor _masked_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, ::std::optional dim=::std::nullopt, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_masked_softmax::redispatch(dispatchKeySet, self, mask, dim, mask_type); + } + + // aten::_masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor + inline at::Tensor _masked_softmax_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim=::std::nullopt) { + return at::_ops::_masked_softmax_backward::redispatch(dispatchKeySet, grad_output, output, mask, dim); + } + + // aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::view::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a) + inline at::Tensor view_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::view::redispatch(dispatchKeySet, self, size); + } + + // aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a) + inline at::Tensor view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::view_dtype::redispatch(dispatchKeySet, self, dtype); + } + + // aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!) + inline at::Tensor & put_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) { + return at::_ops::put_::redispatch(dispatchKeySet, self, index, source, accumulate); + } + + // aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor + inline at::Tensor put(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) { + return at::_ops::put::redispatch(dispatchKeySet, self, index, source, accumulate); + } + + // aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add_out::redispatch(dispatchKeySet, self, dim, index, source, alpha, out); + } + + // aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::index_add_out::redispatch(dispatchKeySet, self, dim, index, source, alpha, out); + } + + // aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & index_add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add_::redispatch(dispatchKeySet, self, dim, index, source, alpha); + } + + // aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + inline at::Tensor index_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add::redispatch(dispatchKeySet, self, dim, index, source, alpha); + } + + // aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor + inline at::Tensor index_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha=1) { + return at::_ops::index_add_dimname::redispatch(dispatchKeySet, self, dim, index, source, alpha); + } + + // aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) { + return at::_ops::index_reduce_out::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self, out); + } + + // aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self, at::Tensor & out) { + return at::_ops::index_reduce_out::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self, out); + } + + // aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!) + inline at::Tensor & index_reduce_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) { + return at::_ops::index_reduce_::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self); + } + + // aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor + inline at::Tensor index_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self=true) { + return at::_ops::index_reduce::redispatch(dispatchKeySet, self, dim, index, source, reduce, include_self); + } + + // aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) + inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill__int_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill_int_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!) + inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill__int_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill_int_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!) + inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill__Dimname_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!) + inline at::Tensor & index_fill_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill__Dimname_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill_Dimname_Scalar::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor + inline at::Tensor index_fill(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill_Dimname_Tensor::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_src::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter__src::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_src_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out) { + return at::_ops::scatter_src_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter_value::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter__value::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter_value_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out) { + return at::_ops::scatter_value_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + return at::_ops::scatter_reduce::redispatch(dispatchKeySet, self, dim, index, src, reduce); + } + + // aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + return at::_ops::scatter__reduce::redispatch(dispatchKeySet, self, dim, index, src, reduce); + } + + // aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + return at::_ops::scatter_reduce_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, out); + } + + // aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, at::Tensor & out) { + return at::_ops::scatter_reduce_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, out); + } + + // aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + return at::_ops::scatter_value_reduce::redispatch(dispatchKeySet, self, dim, index, value, reduce); + } + + // aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!) + inline at::Tensor & scatter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + return at::_ops::scatter__value_reduce::redispatch(dispatchKeySet, self, dim, index, value, reduce); + } + + // aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + return at::_ops::scatter_value_reduce_out::redispatch(dispatchKeySet, self, dim, index, value, reduce, out); + } + + // aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce, at::Tensor & out) { + return at::_ops::scatter_value_reduce_out::redispatch(dispatchKeySet, self, dim, index, value, reduce, out); + } + + // aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_dimname_src::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor + inline at::Tensor scatter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::scatter_dimname_value::redispatch(dispatchKeySet, self, dim, index, value); + } + + // aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!) + inline at::Tensor & scatter_add_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add_::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out) { + return at::_ops::scatter_add_out::redispatch(dispatchKeySet, self, dim, index, src, out); + } + + // aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor + inline at::Tensor scatter_add(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + return at::_ops::scatter_add_dimname::redispatch(dispatchKeySet, self, dim, index, src); + } + + // aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor + inline at::Tensor scatter_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) { + return at::_ops::scatter_reduce_two::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self); + } + + // aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!) + inline at::Tensor & scatter_reduce_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) { + return at::_ops::scatter_reduce__two::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self); + } + + // aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self=true) { + return at::_ops::scatter_reduce_two_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self, out); + } + + // aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scatter_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self, at::Tensor & out) { + return at::_ops::scatter_reduce_two_out::redispatch(dispatchKeySet, self, dim, index, src, reduce, include_self, out); + } + + // aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & eq_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::eq__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & eq_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::eq__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_and_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_and_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_and_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_and_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_and(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & bitwise_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_and__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_and_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_and__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __and__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__and___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __and__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__and___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __iand__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__iand___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __iand__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__iand___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_or_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_or_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_or(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & bitwise_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_or__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_or_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_or__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __or__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__or___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __or__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__or___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __ior__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__ior___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __ior__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__ior___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_xor_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_xor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_xor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_xor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_xor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & bitwise_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_xor__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_xor_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_xor__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __xor__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__xor___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __xor__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__xor___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __ixor__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__ixor___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __ixor__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__ixor___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __lshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__lshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __lshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__lshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __ilshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__ilshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __ilshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__ilshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_left_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_left_shift_Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_left_shift__Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_left_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_left_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_left_shift(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor __rshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__rshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor __rshift__(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__rshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & __irshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::__irshift___Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & __irshift__(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::__irshift___Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_right_shift_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_right_shift_Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_right_shift__Tensor_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::bitwise_right_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::bitwise_right_shift_Tensor_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor bitwise_right_shift(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) + inline at::Tensor & tril_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t diagonal=0) { + return at::_ops::tril_::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!) + inline at::Tensor & triu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t diagonal=0) { + return at::_ops::triu_::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::digamma_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & digamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::digamma_::redispatch(dispatchKeySet, self); + } + + // aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!) + inline at::Tensor & lerp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + return at::_ops::lerp__Scalar::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!) + inline at::Tensor & lerp_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + return at::_ops::lerp__Tensor::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!) + inline at::Tensor & addbmm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm_::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addbmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addbmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::addbmm_out::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha, out); + } + + // aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor + inline at::Tensor addbmm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::addbmm::redispatch(dispatchKeySet, self, batch1, batch2, beta, alpha); + } + + // aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) { + return at::_ops::random__from::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t to, ::std::optional generator=::std::nullopt) { + return at::_ops::random__to::redispatch(dispatchKeySet, self, to, generator); + } + + // aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & random_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::random_::redispatch(dispatchKeySet, self, generator); + } + + // aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & uniform_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double from=0, double to=1, ::std::optional generator=::std::nullopt) { + return at::_ops::uniform_::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & cauchy_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double median=0, double sigma=1, ::std::optional generator=::std::nullopt) { + return at::_ops::cauchy_::redispatch(dispatchKeySet, self, median, sigma, generator); + } + + // aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & log_normal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double mean=1, double std=2, ::std::optional generator=::std::nullopt) { + return at::_ops::log_normal_::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & exponential_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double lambd=1, ::std::optional generator=::std::nullopt) { + return at::_ops::exponential_::redispatch(dispatchKeySet, self, lambd, generator); + } + + // aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & geometric_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::geometric_::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::diag_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diag_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) { + return at::_ops::diag_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::diag(Tensor self, int diagonal=0) -> Tensor + inline at::Tensor diag(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::diag::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cross_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, ::std::optional dim=::std::nullopt) { + return at::_ops::cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cross_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional dim, at::Tensor & out) { + return at::_ops::cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor + inline at::Tensor cross(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, ::std::optional dim=::std::nullopt) { + return at::_ops::cross::redispatch(dispatchKeySet, self, other, dim); + } + + // aten::triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) { + return at::_ops::triu_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::triu(Tensor self, int diagonal=0) -> Tensor + inline at::Tensor triu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::triu::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal, at::Tensor & out) { + return at::_ops::tril_out::redispatch(dispatchKeySet, self, diagonal, out); + } + + // aten::tril(Tensor self, int diagonal=0) -> Tensor + inline at::Tensor tril(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t diagonal=0) { + return at::_ops::tril::redispatch(dispatchKeySet, self, diagonal); + } + + // aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor tril_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset=0, at::TensorOptions options=at::kLong) { + return at::_ops::tril_indices::redispatch(dispatchKeySet, row, col, offset, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor tril_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::tril_indices::redispatch(dispatchKeySet, row, col, offset, dtype, layout, device, pin_memory); + } + + // aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor triu_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset=0, at::TensorOptions options=at::kLong) { + return at::_ops::triu_indices::redispatch(dispatchKeySet, row, col, offset, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor triu_indices(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::triu_indices::redispatch(dispatchKeySet, row, col, offset, dtype, layout, device, pin_memory); + } + + // aten::trace(Tensor self) -> Tensor + inline at::Tensor trace(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::trace::redispatch(dispatchKeySet, self); + } + + // aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor + inline at::Tensor trace_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef sizes) { + return at::_ops::trace_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(sizes)); + } + + // aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor + inline at::Tensor trace_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef sizes) { + return at::_ops::trace_backward::redispatch(dispatchKeySet, grad, sizes); + } + + // aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ne_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ne_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ne_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::ne_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor ne(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ne_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ne_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ne_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ne_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::ne_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ne.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor ne(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ne_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & ne_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::ne__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & ne_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::ne__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & not_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::not_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & not_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::not_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor not_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::not_equal_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & not_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::not_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & not_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::not_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor not_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::not_equal_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & not_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::not_equal__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & not_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::not_equal__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::eq_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::eq_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor eq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::eq_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::eq_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & eq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::eq_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::eq.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor eq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::eq_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ge_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ge_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ge_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::ge_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor ge(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::ge_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ge_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ge_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ge_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::ge_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::ge.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor ge(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::ge_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & ge_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::ge__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & ge_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::ge__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::greater_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor greater_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_equal_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::greater_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor greater_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_equal_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & greater_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_equal__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & greater_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_equal__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::le_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::le_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor le(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::le_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::le_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & le_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::le_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::le.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor le(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::le_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & le_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::le__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & le_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::le__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::less_equal_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor less_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_equal_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_equal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_equal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::less_equal_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor less_equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_equal_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & less_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_equal__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & less_equal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_equal__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::gt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::gt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor gt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::gt_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::gt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::gt.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor gt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::gt_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & gt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::gt__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & gt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::gt__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::greater_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor greater(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & greater_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::greater_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::greater.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor greater(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & greater_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::greater__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & greater_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::greater__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::lt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::lt_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor lt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::lt_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::lt_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::lt.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor lt(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::lt_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & lt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::lt__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & lt_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::lt__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::less_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor less(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::less_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & less_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::less_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::less.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor less(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::less_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & less_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::less__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & less_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::less__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & take_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & index) { + return at::_ops::take_out::redispatch(dispatchKeySet, self, index, out); + } + + // aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & take_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, at::Tensor & out) { + return at::_ops::take_out::redispatch(dispatchKeySet, self, index, out); + } + + // aten::take(Tensor self, Tensor index) -> Tensor + inline at::Tensor take(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index) { + return at::_ops::take::redispatch(dispatchKeySet, self, index); + } + + // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & take_along_dim_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, ::std::optional dim=::std::nullopt) { + return at::_ops::take_along_dim_out::redispatch(dispatchKeySet, self, indices, dim, out); + } + + // aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & take_along_dim_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, ::std::optional dim, at::Tensor & out) { + return at::_ops::take_along_dim_out::redispatch(dispatchKeySet, self, indices, dim, out); + } + + // aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor + inline at::Tensor take_along_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, ::std::optional dim=::std::nullopt) { + return at::_ops::take_along_dim::redispatch(dispatchKeySet, self, indices, dim); + } + + // aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, at::Tensor & out) { + return at::_ops::index_select_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select(Tensor self, int dim, Tensor index) -> Tensor + inline at::Tensor index_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, const at::Tensor & index) { + return at::_ops::index_select_dimname_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, at::Tensor & out) { + return at::_ops::index_select_dimname_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor + inline at::Tensor index_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index) { + return at::_ops::index_select_dimname::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor + inline at::Tensor index_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, at::IntArrayRef self_sizes, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select_backward::redispatch(dispatchKeySet, grad, c10::fromIntArrayRefSlow(self_sizes), dim, index); + } + + // aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor + inline at::Tensor index_select_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const at::Tensor & index) { + return at::_ops::index_select_backward::redispatch(dispatchKeySet, grad, self_sizes, dim, index); + } + + // aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_select_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::masked_select_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_select_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, at::Tensor & out) { + return at::_ops::masked_select_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::masked_select(Tensor self, Tensor mask) -> Tensor + inline at::Tensor masked_select(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::masked_select::redispatch(dispatchKeySet, self, mask); + } + + // aten::masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor + inline at::Tensor masked_select_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & input, const at::Tensor & mask) { + return at::_ops::masked_select_backward::redispatch(dispatchKeySet, grad, input, mask); + } + + // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::nonzero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::nonzero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nonzero(Tensor self) -> Tensor + inline at::Tensor nonzero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::nonzero::redispatch(dispatchKeySet, self); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_static_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t size, int64_t fill_value=-1) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_static_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, int64_t fill_value, at::Tensor & out) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_static_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt size, int64_t fill_value=-1) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nonzero_static_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt size, int64_t fill_value, at::Tensor & out) { + return at::_ops::nonzero_static_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor + inline at::Tensor nonzero_static(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, int64_t fill_value=-1) { + return at::_ops::nonzero_static::redispatch(dispatchKeySet, self, size, fill_value); + } + + // aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor + inline at::Tensor nonzero_static_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt size, int64_t fill_value=-1) { + return at::_ops::nonzero_static::redispatch(dispatchKeySet, self, size, fill_value); + } + + // aten::nonzero_numpy(Tensor self) -> Tensor[] + inline ::std::vector nonzero_numpy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::nonzero_numpy::redispatch(dispatchKeySet, self); + } + + // aten::argwhere(Tensor self) -> Tensor + inline at::Tensor argwhere(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::argwhere::redispatch(dispatchKeySet, self); + } + + // aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gather_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gather_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out) { + return at::_ops::gather_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor + inline at::Tensor gather(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather::redispatch(dispatchKeySet, self, dim, index, sparse_grad); + } + + // aten::gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor + inline at::Tensor gather_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad) { + return at::_ops::gather_backward::redispatch(dispatchKeySet, grad, self, dim, index, sparse_grad); + } + + // aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gather_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather_dimname_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & gather_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out) { + return at::_ops::gather_dimname_out::redispatch(dispatchKeySet, self, dim, index, sparse_grad, out); + } + + // aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor + inline at::Tensor gather(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad=false) { + return at::_ops::gather_dimname::redispatch(dispatchKeySet, self, dim, index, sparse_grad); + } + + // aten::_gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor + inline at::Tensor _gather_sparse_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & grad) { + return at::_ops::_gather_sparse_backward::redispatch(dispatchKeySet, self, dim, index, grad); + } + + // aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addcmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addcmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcmul_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + inline at::Tensor addcmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + inline at::Tensor & addcmul_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcmul_::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & addcdiv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out) { + return at::_ops::addcdiv_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor + inline at::Tensor addcdiv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!) + inline at::Tensor & addcdiv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value=1) { + return at::_ops::addcdiv_::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor + inline at::Tensor cross_entropy_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100, double label_smoothing=0.0) { + return at::_ops::cross_entropy_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, label_smoothing); + } + + // aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor + inline at::Tensor cross_entropy_loss_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100, double label_smoothing=0.0) { + return at::_ops::cross_entropy_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, label_smoothing); + } + + // aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient) + inline ::std::tuple triangular_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & X, at::Tensor & M, const at::Tensor & self, const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) { + return at::_ops::triangular_solve_X::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular, X, M); + } + + // aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient) + inline ::std::tuple triangular_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular, at::Tensor & X, at::Tensor & M) { + return at::_ops::triangular_solve_X::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular, X, M); + } + + // aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient) + inline ::std::tuple triangular_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper=true, bool transpose=false, bool unitriangular=false) { + return at::_ops::triangular_solve::redispatch(dispatchKeySet, self, A, upper, transpose, unitriangular); + } + + // aten::_linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> () + inline void _linalg_check_errors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & info, c10::string_view api_name, bool is_matrix) { + return at::_ops::_linalg_check_errors::redispatch(dispatchKeySet, info, api_name, is_matrix); + } + + // aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_solve_triangular_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & B, bool upper, bool left=true, bool unitriangular=false) { + return at::_ops::linalg_solve_triangular_out::redispatch(dispatchKeySet, self, B, upper, left, unitriangular, out); + } + + // aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_solve_triangular_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular, at::Tensor & out) { + return at::_ops::linalg_solve_triangular_out::redispatch(dispatchKeySet, self, B, upper, left, unitriangular, out); + } + + // aten::linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor + inline at::Tensor linalg_solve_triangular(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & B, bool upper, bool left=true, bool unitriangular=false) { + return at::_ops::linalg_solve_triangular::redispatch(dispatchKeySet, self, B, upper, left, unitriangular); + } + + // aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor + inline at::Tensor linalg_vander(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, ::std::optional N=::std::nullopt) { + return at::_ops::linalg_vander::redispatch(dispatchKeySet, x, N.has_value() ? ::std::make_optional(c10::SymInt(*N)) : ::std::nullopt); + } + + // aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor + inline at::Tensor linalg_vander_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, ::std::optional N=::std::nullopt) { + return at::_ops::linalg_vander::redispatch(dispatchKeySet, x, N); + } + + // aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + inline ::std::tuple svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & V, const at::Tensor & self, bool some=true, bool compute_uv=true) { + return at::_ops::svd_U::redispatch(dispatchKeySet, self, some, compute_uv, U, S, V); + } + + // aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) + inline ::std::tuple svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some, bool compute_uv, at::Tensor & U, at::Tensor & S, at::Tensor & V) { + return at::_ops::svd_U::redispatch(dispatchKeySet, self, some, compute_uv, U, S, V); + } + + // aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V) + inline ::std::tuple svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some=true, bool compute_uv=true) { + return at::_ops::svd::redispatch(dispatchKeySet, self, some, compute_uv); + } + + // aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a) + inline at::Tensor swapaxes(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t axis0, int64_t axis1) { + return at::_ops::swapaxes::redispatch(dispatchKeySet, self, axis0, axis1); + } + + // aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!) + inline at::Tensor & swapaxes_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t axis0, int64_t axis1) { + return at::_ops::swapaxes_::redispatch(dispatchKeySet, self, axis0, axis1); + } + + // aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a) + inline at::Tensor swapdims(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::swapdims::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!) + inline at::Tensor & swapdims_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::swapdims_::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) { + return at::_ops::cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::cholesky(Tensor self, bool upper=False) -> Tensor + inline at::Tensor cholesky(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky::redispatch(dispatchKeySet, self, upper); + } + + // aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2, bool upper=false) { + return at::_ops::cholesky_solve_out::redispatch(dispatchKeySet, self, input2, upper, out); + } + + // aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, bool upper, at::Tensor & out) { + return at::_ops::cholesky_solve_out::redispatch(dispatchKeySet, self, input2, upper, out); + } + + // aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor + inline at::Tensor cholesky_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, bool upper=false) { + return at::_ops::cholesky_solve::redispatch(dispatchKeySet, self, input2, upper); + } + + // aten::_cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor + inline at::Tensor _cholesky_solve_helper(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper) { + return at::_ops::_cholesky_solve_helper::redispatch(dispatchKeySet, self, A, upper); + } + + // aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor + inline at::Tensor cholesky_inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky_inverse::redispatch(dispatchKeySet, self, upper); + } + + // aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_inverse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) { + return at::_ops::cholesky_inverse_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cholesky_inverse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) { + return at::_ops::cholesky_inverse_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + inline ::std::tuple qr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & Q, at::Tensor & R, const at::Tensor & self, bool some=true) { + return at::_ops::qr_Q::redispatch(dispatchKeySet, self, some, Q, R); + } + + // aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + inline ::std::tuple qr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some, at::Tensor & Q, at::Tensor & R) { + return at::_ops::qr_Q::redispatch(dispatchKeySet, self, some, Q, R); + } + + // aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R) + inline ::std::tuple qr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool some=true) { + return at::_ops::qr::redispatch(dispatchKeySet, self, some); + } + + // aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau) + inline ::std::tuple geqrf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & a, at::Tensor & tau, const at::Tensor & self) { + return at::_ops::geqrf_a::redispatch(dispatchKeySet, self, a, tau); + } + + // aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau) + inline ::std::tuple geqrf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & a, at::Tensor & tau) { + return at::_ops::geqrf_a::redispatch(dispatchKeySet, self, a, tau); + } + + // aten::geqrf(Tensor self) -> (Tensor a, Tensor tau) + inline ::std::tuple geqrf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::geqrf::redispatch(dispatchKeySet, self); + } + + // aten::orgqr(Tensor self, Tensor input2) -> Tensor + inline at::Tensor orgqr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2) { + return at::_ops::orgqr::redispatch(dispatchKeySet, self, input2); + } + + // aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & orgqr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2) { + return at::_ops::orgqr_out::redispatch(dispatchKeySet, self, input2, out); + } + + // aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & orgqr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, at::Tensor & out) { + return at::_ops::orgqr_out::redispatch(dispatchKeySet, self, input2, out); + } + + // aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ormqr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) { + return at::_ops::ormqr_out::redispatch(dispatchKeySet, self, input2, input3, left, transpose, out); + } + + // aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ormqr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose, at::Tensor & out) { + return at::_ops::ormqr_out::redispatch(dispatchKeySet, self, input2, input3, left, transpose, out); + } + + // aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor + inline at::Tensor ormqr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left=true, bool transpose=false) { + return at::_ops::ormqr::redispatch(dispatchKeySet, self, input2, input3, left, transpose); + } + + // aten::_lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info) + inline ::std::tuple _lu_with_info(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool pivot=true, bool check_errors=true) { + return at::_ops::_lu_with_info::redispatch(dispatchKeySet, self, pivot, check_errors); + } + + // aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lu_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) { + return at::_ops::lu_solve_out::redispatch(dispatchKeySet, self, LU_data, LU_pivots, out); + } + + // aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lu_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots, at::Tensor & out) { + return at::_ops::lu_solve_out::redispatch(dispatchKeySet, self, LU_data, LU_pivots, out); + } + + // aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor + inline at::Tensor lu_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) { + return at::_ops::lu_solve::redispatch(dispatchKeySet, self, LU_data, LU_pivots); + } + + // aten::lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U) + inline ::std::tuple lu_unpack(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data=true, bool unpack_pivots=true) { + return at::_ops::lu_unpack::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots); + } + + // aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + inline ::std::tuple lu_unpack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data=true, bool unpack_pivots=true) { + return at::_ops::lu_unpack_out::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U); + } + + // aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + inline ::std::tuple lu_unpack_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots, at::Tensor & P, at::Tensor & L, at::Tensor & U) { + return at::_ops::lu_unpack_out::redispatch(dispatchKeySet, LU_data, LU_pivots, unpack_data, unpack_pivots, P, L, U); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multinomial_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multinomial_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_samples, bool replacement, ::std::optional generator, at::Tensor & out) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multinomial_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multinomial_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator, at::Tensor & out) { + return at::_ops::multinomial_out::redispatch(dispatchKeySet, self, num_samples, replacement, generator, out); + } + + // aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + inline at::Tensor multinomial(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial::redispatch(dispatchKeySet, self, num_samples, replacement, generator); + } + + // aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor + inline at::Tensor multinomial_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt num_samples, bool replacement=false, ::std::optional generator=::std::nullopt) { + return at::_ops::multinomial::redispatch(dispatchKeySet, self, num_samples, replacement, generator); + } + + // aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lgamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lgamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lgamma_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & lgamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::lgamma_::redispatch(dispatchKeySet, self); + } + + // aten::lgamma(Tensor self) -> Tensor + inline at::Tensor lgamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::lgamma::redispatch(dispatchKeySet, self); + } + + // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & digamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & digamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::digamma(Tensor self) -> Tensor + inline at::Tensor digamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::digamma::redispatch(dispatchKeySet, self); + } + + // aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & polygamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, const at::Tensor & self) { + return at::_ops::polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & polygamma_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self, at::Tensor & out) { + return at::_ops::polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::polygamma(int n, Tensor self) -> Tensor + inline at::Tensor polygamma(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self) { + return at::_ops::polygamma::redispatch(dispatchKeySet, n, self); + } + + // aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!) + inline at::Tensor & polygamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, int64_t n) { + return at::_ops::polygamma_::redispatch(dispatchKeySet, self, n); + } + + // aten::erfinv(Tensor self) -> Tensor + inline at::Tensor erfinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::erfinv::redispatch(dispatchKeySet, self); + } + + // aten::erfinv_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & erfinv_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::erfinv_::redispatch(dispatchKeySet, self); + } + + // aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & erfinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::i0(Tensor self) -> Tensor + inline at::Tensor i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::i0::redispatch(dispatchKeySet, self); + } + + // aten::i0_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & i0_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::i0_::redispatch(dispatchKeySet, self); + } + + // aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sign(Tensor self) -> Tensor + inline at::Tensor sign(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::sign::redispatch(dispatchKeySet, self); + } + + // aten::sign_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & sign_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::sign_::redispatch(dispatchKeySet, self); + } + + // aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sign_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sign_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::signbit(Tensor self) -> Tensor + inline at::Tensor signbit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::signbit::redispatch(dispatchKeySet, self); + } + + // aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & signbit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::signbit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & signbit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::signbit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor + inline at::Tensor dist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p=2) { + return at::_ops::dist::redispatch(dispatchKeySet, self, other, p); + } + + // aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & atan2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::atan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & atan2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2_::redispatch(dispatchKeySet, self, other); + } + + // aten::atan2(Tensor self, Tensor other) -> Tensor + inline at::Tensor atan2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::atan2::redispatch(dispatchKeySet, self, other); + } + + // aten::arctan2(Tensor self, Tensor other) -> Tensor + inline at::Tensor arctan2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2::redispatch(dispatchKeySet, self, other); + } + + // aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctan2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & arctan2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::arctan2_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & arctan2_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::arctan2_::redispatch(dispatchKeySet, self, other); + } + + // aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lerp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + return at::_ops::lerp_Scalar_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lerp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out) { + return at::_ops::lerp_Scalar_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lerp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + return at::_ops::lerp_Tensor_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lerp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out) { + return at::_ops::lerp_Tensor_out::redispatch(dispatchKeySet, self, end, weight, out); + } + + // aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor + inline at::Tensor lerp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + return at::_ops::lerp_Scalar::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor + inline at::Tensor lerp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + return at::_ops::lerp_Tensor::redispatch(dispatchKeySet, self, end, weight); + } + + // aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & histc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) { + return at::_ops::histc_out::redispatch(dispatchKeySet, self, bins, min, max, out); + } + + // aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & histc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max, at::Tensor & out) { + return at::_ops::histc_out::redispatch(dispatchKeySet, self, bins, min, max, out); + } + + // aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor + inline at::Tensor histc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins=100, const at::Scalar & min=0, const at::Scalar & max=0) { + return at::_ops::histc::redispatch(dispatchKeySet, self, bins, min, max); + } + + // aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & hist, at::Tensor & bin_edges, const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bins_tensor_out::redispatch(dispatchKeySet, self, bins, weight, density, hist, bin_edges); + } + + // aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges) { + return at::_ops::histogram_bins_tensor_out::redispatch(dispatchKeySet, self, bins, weight, density, hist, bin_edges); + } + + // aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) + inline ::std::tuple histogram(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bins_tensor::redispatch(dispatchKeySet, self, bins, weight, density); + } + + // aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & hist, at::Tensor & bin_edges, const at::Tensor & self, int64_t bins=100, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bin_ct_out::redispatch(dispatchKeySet, self, bins, range, weight, density, hist, bin_edges); + } + + // aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges) + inline ::std::tuple histogram_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges) { + return at::_ops::histogram_bin_ct_out::redispatch(dispatchKeySet, self, bins, range, weight, density, hist, bin_edges); + } + + // aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges) + inline ::std::tuple histogram(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins=100, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogram_bin_ct::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::_histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[] + inline ::std::vector _histogramdd_bin_edges(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_bin_edges::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::_histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor + inline at::Tensor _histogramdd_from_bin_cts(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_cts::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor + inline at::Tensor _histogramdd_from_bin_tensors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_tensors::redispatch(dispatchKeySet, self, bins, weight, density); + } + + // aten::histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + inline ::std::tuple> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogramdd::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + inline ::std::tuple> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogramdd_int_bins::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges) + inline ::std::tuple> histogramdd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::histogramdd_TensorList_bins::redispatch(dispatchKeySet, self, bins, range, weight, density); + } + + // aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::fmod_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::fmod_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor fmod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::fmod_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & fmod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::fmod__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmod_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::fmod_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor fmod(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmod_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & fmod_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmod__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hypot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::hypot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hypot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::hypot_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::hypot(Tensor self, Tensor other) -> Tensor + inline at::Tensor hypot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::hypot::redispatch(dispatchKeySet, self, other); + } + + // aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & hypot_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::hypot_::redispatch(dispatchKeySet, self, other); + } + + // aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igamma_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::igamma_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igamma(Tensor self, Tensor other) -> Tensor + inline at::Tensor igamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igamma::redispatch(dispatchKeySet, self, other); + } + + // aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & igamma_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::igamma_::redispatch(dispatchKeySet, self, other); + } + + // aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igammac_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igammac_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & igammac_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::igammac_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::igammac(Tensor self, Tensor other) -> Tensor + inline at::Tensor igammac(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::igammac::redispatch(dispatchKeySet, self, other); + } + + // aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & igammac_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::igammac_::redispatch(dispatchKeySet, self, other); + } + + // aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nextafter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::nextafter_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nextafter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::nextafter_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::nextafter(Tensor self, Tensor other) -> Tensor + inline at::Tensor nextafter(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::nextafter::redispatch(dispatchKeySet, self, other); + } + + // aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & nextafter_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::nextafter_::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::remainder_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::remainder_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::remainder_Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!) + inline at::Tensor & remainder_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & other) { + return at::_ops::remainder__Scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::remainder_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::remainder_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor + inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::remainder_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!) + inline at::Tensor & remainder_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & other) { + return at::_ops::remainder__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor + inline at::Tensor remainder(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::remainder_Scalar_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::min(Tensor self) -> Tensor + inline at::Tensor min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::min::redispatch(dispatchKeySet, self); + } + + // aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::min_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::min_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::fmin(Tensor self, Tensor other) -> Tensor + inline at::Tensor fmin(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmin::redispatch(dispatchKeySet, self, other); + } + + // aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmin_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmin_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmin_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::fmin_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max(Tensor self) -> Tensor + inline at::Tensor max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::max::redispatch(dispatchKeySet, self); + } + + // aten::fmax(Tensor self, Tensor other) -> Tensor + inline at::Tensor fmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmax::redispatch(dispatchKeySet, self, other); + } + + // aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::fmax_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::fmax_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::maximum(Tensor self, Tensor other) -> Tensor + inline at::Tensor maximum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::maximum::redispatch(dispatchKeySet, self, other); + } + + // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & maximum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::maximum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & maximum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::maximum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max.other(Tensor self, Tensor other) -> Tensor + inline at::Tensor max(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::max_other::redispatch(dispatchKeySet, self, other); + } + + // aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::max_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::max_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::max_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::max_unary_out::redispatch(dispatchKeySet, self, out); + } + + // aten::minimum(Tensor self, Tensor other) -> Tensor + inline at::Tensor minimum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::minimum::redispatch(dispatchKeySet, self, other); + } + + // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & minimum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::minimum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & minimum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::minimum_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & min_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::min_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & min_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::min_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::min.other(Tensor self, Tensor other) -> Tensor + inline at::Tensor min(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::min_other::redispatch(dispatchKeySet, self, other); + } + + // aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor quantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::quantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor quantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile_scalar::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::quantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::quantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor nanquantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanquantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanquantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::nanquantile_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor + inline at::Tensor nanquantile(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile_scalar::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation); + } + + // aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanquantile_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double q, ::std::optional dim=::std::nullopt, bool keepdim=false, c10::string_view interpolation="linear") { + return at::_ops::nanquantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanquantile_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out) { + return at::_ops::nanquantile_scalar_out::redispatch(dispatchKeySet, self, q, dim, keepdim, interpolation, out); + } + + // aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t dim=-1, bool descending=false) { + return at::_ops::sort_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, ::std::optional stable, int64_t dim=-1, bool descending=false) { + return at::_ops::sort_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool descending=false) { + return at::_ops::sort::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, int64_t dim=-1, bool descending=false) { + return at::_ops::sort_stable::redispatch(dispatchKeySet, self, stable, dim, descending); + } + + // aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_dimname_values::redispatch(dispatchKeySet, self, dim, descending, values, indices); + } + + // aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple sort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices) { + return at::_ops::sort_dimname_values_stable::redispatch(dispatchKeySet, self, stable, dim, descending, values, indices); + } + + // aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices) + inline ::std::tuple sort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending=false) { + return at::_ops::sort_dimname_stable::redispatch(dispatchKeySet, self, stable, dim, descending); + } + + // aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & msort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::msort_out::redispatch(dispatchKeySet, self, out); + } + + // aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & msort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::msort_out::redispatch(dispatchKeySet, self, out); + } + + // aten::msort(Tensor self) -> Tensor + inline at::Tensor msort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::msort::redispatch(dispatchKeySet, self); + } + + // aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor + inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor + inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort_stable::redispatch(dispatchKeySet, self, stable, dim, descending); + } + + // aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argsort_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool stable, int64_t dim=-1, bool descending=false) { + return at::_ops::argsort_stable_out::redispatch(dispatchKeySet, self, stable, dim, descending, out); + } + + // aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & argsort_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out) { + return at::_ops::argsort_stable_out::redispatch(dispatchKeySet, self, stable, dim, descending, out); + } + + // aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor + inline at::Tensor argsort(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Dimname dim, bool descending=false) { + return at::_ops::argsort_dimname::redispatch(dispatchKeySet, self, dim, descending); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple topk_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple topk_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple topk_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & values, at::Tensor & indices, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + inline ::std::tuple topk_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices) { + return at::_ops::topk_values::redispatch(dispatchKeySet, self, k, dim, largest, sorted, values, indices); + } + + // aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + inline ::std::tuple topk(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::redispatch(dispatchKeySet, self, k, dim, largest, sorted); + } + + // aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices) + inline ::std::tuple topk_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt k, int64_t dim=-1, bool largest=true, bool sorted=true) { + return at::_ops::topk::redispatch(dispatchKeySet, self, k, dim, largest, sorted); + } + + // aten::all(Tensor self) -> Tensor + inline at::Tensor all(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::all::redispatch(dispatchKeySet, self); + } + + // aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::all_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & all_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::all_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::any(Tensor self) -> Tensor + inline at::Tensor any(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::any::redispatch(dispatchKeySet, self); + } + + // aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::any_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & any_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::any_all_out::redispatch(dispatchKeySet, self, out); + } + + // aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & renorm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + return at::_ops::renorm_out::redispatch(dispatchKeySet, self, p, dim, maxnorm, out); + } + + // aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & renorm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm, at::Tensor & out) { + return at::_ops::renorm_out::redispatch(dispatchKeySet, self, p, dim, maxnorm, out); + } + + // aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor + inline at::Tensor renorm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + return at::_ops::renorm::redispatch(dispatchKeySet, self, p, dim, maxnorm); + } + + // aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!) + inline at::Tensor & renorm_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + return at::_ops::renorm_::redispatch(dispatchKeySet, self, p, dim, maxnorm); + } + + // aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a) + inline at::Tensor unfold(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + return at::_ops::unfold::redispatch(dispatchKeySet, self, dimension, size, step); + } + + // aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + inline at::Tensor unfold_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step); + } + + // aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor + inline at::Tensor unfold_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step); + } + + // aten::equal(Tensor self, Tensor other) -> bool + inline bool equal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::equal::redispatch(dispatchKeySet, self, other); + } + + // aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::pow_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::pow_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::pow_Tensor_Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor + inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::pow_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::pow_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pow_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out) { + return at::_ops::pow_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + inline at::Tensor pow(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::pow_Tensor_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + inline at::Tensor & pow_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::pow__Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) + inline at::Tensor & pow_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::pow__Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::float_power_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::float_power_Tensor_Tensor_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor + inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::float_power_Tensor_Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::float_power_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out) { + return at::_ops::float_power_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Scalar(Scalar self, Tensor exponent) -> Tensor + inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & exponent) { + return at::_ops::float_power_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::float_power_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & float_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out) { + return at::_ops::float_power_Tensor_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor + inline at::Tensor float_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::float_power_Tensor_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!) + inline at::Tensor & float_power_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & exponent) { + return at::_ops::float_power__Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!) + inline at::Tensor & float_power_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Tensor & exponent) { + return at::_ops::float_power__Tensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & normal_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double mean=0, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor + inline at::Tensor normal_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean=0, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_functional::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mean, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_float_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, double std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_Tensor_float_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_float::redispatch(dispatchKeySet, mean, std, generator); + } + + // aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, double mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_float_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_Tensor::redispatch(dispatchKeySet, mean, std, generator); + } + + // aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_Tensor_Tensor_out::redispatch(dispatchKeySet, mean, std, generator, out); + } + + // aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & mean, const at::Tensor & std, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_Tensor_Tensor::redispatch(dispatchKeySet, mean, std, generator); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, ::std::optional generator=::std::nullopt, at::TensorOptions options={}) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, dtype, layout, device, pin_memory); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal_symint(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator=::std::nullopt, at::TensorOptions options={}) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, size, generator, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor normal_symint(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::normal_float_float::redispatch(dispatchKeySet, mean, std, size, generator, dtype, layout, device, pin_memory); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, double std, at::IntArrayRef size, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, double mean, double std, at::IntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, c10::fromIntArrayRefSlow(size), generator, out); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, size, generator, out); + } + + // aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_symint_outf(c10::DispatchKeySet dispatchKeySet, double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_float_float_out::redispatch(dispatchKeySet, mean, std, size, generator, out); + } + + // aten::alias(Tensor(a) self) -> Tensor(a) + inline at::Tensor alias(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::alias::redispatch(dispatchKeySet, self); + } + + // aten::_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> () + inline void _amp_foreach_non_finite_check_and_unscale_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_::redispatch(dispatchKeySet, self, found_inf, inv_scale); + } + + // aten::_amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!) + inline at::Tensor & _amp_update_scale_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + return at::_ops::_amp_update_scale_::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + } + + // aten::_foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_add_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_add__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add_List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () + inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add__List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_add_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_add__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[] + inline ::std::vector _foreach_add(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add_Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> () + inline void _foreach_add_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add__Tensor::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_sub_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_sub__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[] + inline ::std::vector _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_sub_List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> () + inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_sub__List::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_sub(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_sub_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_sub_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_sub__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_mul_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_mul__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_mul_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_mul__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_mul_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_mul__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[] + inline ::std::vector _foreach_mul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_mul_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> () + inline void _foreach_mul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_mul__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_div_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_div__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_div_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_div__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_div_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_div__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[] + inline ::std::vector _foreach_div(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_div_Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> () + inline void _foreach_div_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_div__Tensor::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_max_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_max__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_max_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_max__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_clamp_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_max_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_clamp_max_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_max__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_min_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_min__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_min_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_min__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_clamp_min(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_min_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_clamp_min_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_min__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_maximum_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_maximum__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_maximum_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_maximum__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_maximum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_maximum_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_maximum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_maximum__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[] + inline ::std::vector _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_minimum_Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> () + inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_minimum__Scalar::redispatch(dispatchKeySet, self, scalar); + } + + // aten::_foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[] + inline ::std::vector _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_minimum_List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> () + inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_minimum__List::redispatch(dispatchKeySet, self, other); + } + + // aten::_foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_minimum(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_minimum_ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> () + inline void _foreach_minimum_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_minimum__ScalarList::redispatch(dispatchKeySet, self, scalars); + } + + // aten::_foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + inline ::std::vector _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcdiv_Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcdiv_ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + inline ::std::vector _foreach_addcdiv(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcdiv_Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcdiv__Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () + inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcdiv__ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + inline void _foreach_addcdiv_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcdiv__Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[] + inline ::std::vector _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcmul_Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[] + inline ::std::vector _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcmul_ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[] + inline ::std::vector _foreach_addcmul(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcmul_Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> () + inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcmul__Scalar::redispatch(dispatchKeySet, self, tensor1, tensor2, value); + } + + // aten::_foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> () + inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcmul__ScalarList::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> () + inline void _foreach_addcmul_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcmul__Tensor::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars); + } + + // aten::_foreach_abs(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_abs(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_abs::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_abs_(Tensor(a!)[] self) -> () + inline void _foreach_abs_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_abs_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_acos(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_acos(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_acos::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_acos_(Tensor(a!)[] self) -> () + inline void _foreach_acos_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_acos_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_asin(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_asin(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_asin::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_asin_(Tensor(a!)[] self) -> () + inline void _foreach_asin_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_asin_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_atan(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_atan(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_atan::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_atan_(Tensor(a!)[] self) -> () + inline void _foreach_atan_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_atan_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_ceil(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_ceil(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_ceil::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_ceil_(Tensor(a!)[] self) -> () + inline void _foreach_ceil_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_ceil_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cos(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_cos(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cos::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cos_(Tensor(a!)[] self) -> () + inline void _foreach_cos_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cos_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cosh(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_cosh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cosh::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_cosh_(Tensor(a!)[] self) -> () + inline void _foreach_cosh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_cosh_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erf(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_erf(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erf::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erf_(Tensor(a!)[] self) -> () + inline void _foreach_erf_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erf_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erfc(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_erfc(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erfc::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_erfc_(Tensor(a!)[] self) -> () + inline void _foreach_erfc_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_erfc_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_exp(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_exp(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_exp::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_exp_(Tensor(a!)[] self) -> () + inline void _foreach_exp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_exp_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_expm1(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_expm1(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_expm1::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_expm1_(Tensor(a!)[] self) -> () + inline void _foreach_expm1_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_expm1_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_floor(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_floor(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_floor::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_floor_(Tensor(a!)[] self) -> () + inline void _foreach_floor_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_floor_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_frac(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_frac(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_frac::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_frac_(Tensor(a!)[] self) -> () + inline void _foreach_frac_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_frac_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[] + inline ::std::vector _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + return at::_ops::_foreach_lerp_List::redispatch(dispatchKeySet, self, tensors1, weights); + } + + // aten::_foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> () + inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + return at::_ops::_foreach_lerp__List::redispatch(dispatchKeySet, self, tensors1, weights); + } + + // aten::_foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[] + inline ::std::vector _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + return at::_ops::_foreach_lerp_Scalar::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> () + inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + return at::_ops::_foreach_lerp__Scalar::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[] + inline ::std::vector _foreach_lerp(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + return at::_ops::_foreach_lerp_ScalarList::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> () + inline void _foreach_lerp_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + return at::_ops::_foreach_lerp__ScalarList::redispatch(dispatchKeySet, self, tensors1, weight); + } + + // aten::_foreach_lgamma(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_lgamma(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_lgamma::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_lgamma_(Tensor(a!)[] self) -> () + inline void _foreach_lgamma_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_lgamma_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log_(Tensor(a!)[] self) -> () + inline void _foreach_log_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log10(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log10(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log10::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log10_(Tensor(a!)[] self) -> () + inline void _foreach_log10_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log10_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log1p(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log1p(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log1p::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log1p_(Tensor(a!)[] self) -> () + inline void _foreach_log1p_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log1p_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log2(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_log2(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log2::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_log2_(Tensor(a!)[] self) -> () + inline void _foreach_log2_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_log2_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_max(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_max(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_max::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_neg(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_neg(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_neg::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_neg_(Tensor(a!)[] self) -> () + inline void _foreach_neg_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_neg_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[] + inline ::std::vector _foreach_norm(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & ord=2, ::std::optional dtype=::std::nullopt) { + return at::_ops::_foreach_norm_Scalar::redispatch(dispatchKeySet, self, ord, dtype); + } + + // aten::_foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent) { + return at::_ops::_foreach_pow_List::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent) { + return at::_ops::_foreach_pow_Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef exponent) { + return at::_ops::_foreach_pow_ScalarList::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[] + inline ::std::vector _foreach_pow(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, at::TensorList exponent) { + return at::_ops::_foreach_pow_ScalarAndTensor::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> () + inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent) { + return at::_ops::_foreach_pow__List::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> () + inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent) { + return at::_ops::_foreach_pow__Scalar::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> () + inline void _foreach_pow_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef exponent) { + return at::_ops::_foreach_pow__ScalarList::redispatch(dispatchKeySet, self, exponent); + } + + // aten::_foreach_reciprocal(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_reciprocal(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_reciprocal::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_reciprocal_(Tensor(a!)[] self) -> () + inline void _foreach_reciprocal_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_reciprocal_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_round(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_round(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_round::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_round_(Tensor(a!)[] self) -> () + inline void _foreach_round_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_round_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_rsqrt(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_rsqrt(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_rsqrt::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_rsqrt_(Tensor(a!)[] self) -> () + inline void _foreach_rsqrt_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_rsqrt_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sigmoid(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sigmoid(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sigmoid::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sigmoid_(Tensor(a!)[] self) -> () + inline void _foreach_sigmoid_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sigmoid_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sign(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sign(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sign::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sign_(Tensor(a!)[] self) -> () + inline void _foreach_sign_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sign_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sin(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sin(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sin::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sin_(Tensor(a!)[] self) -> () + inline void _foreach_sin_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sin_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sinh(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sinh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sinh::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sinh_(Tensor(a!)[] self) -> () + inline void _foreach_sinh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sinh_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sqrt(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_sqrt(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sqrt::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_sqrt_(Tensor(a!)[] self) -> () + inline void _foreach_sqrt_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_sqrt_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tan(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_tan(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_tan::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tan_(Tensor(a!)[] self) -> () + inline void _foreach_tan_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_tan_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tanh(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_tanh(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_tanh::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_tanh_(Tensor(a!)[] self) -> () + inline void _foreach_tanh_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_tanh_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_trunc(Tensor[] self) -> Tensor[] + inline ::std::vector _foreach_trunc(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_trunc::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_trunc_(Tensor(a!)[] self) -> () + inline void _foreach_trunc_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_trunc_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_zero_(Tensor(a!)[] self) -> () + inline void _foreach_zero_(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_zero_::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> () + inline void _foreach_copy_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking=false) { + return at::_ops::_foreach_copy_::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::_foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out + inline ::std::vector _foreach_copy(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking=false) { + return at::_ops::_foreach_copy::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor + inline at::Tensor bucketize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Tensor::redispatch(dispatchKeySet, self, boundaries, out_int32, right); + } + + // aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Tensor_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out) { + return at::_ops::bucketize_Tensor_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor + inline at::Tensor bucketize(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Scalar::redispatch(dispatchKeySet, self, boundaries, out_int32, right); + } + + // aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor + inline at::Tensor searchsorted(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Tensor::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter); + } + + // aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & searchsorted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Tensor_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & searchsorted_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out) { + return at::_ops::searchsorted_Tensor_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor + inline at::Tensor searchsorted(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Scalar::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter); + } + + // aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & searchsorted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32=false, bool right=false, ::std::optional side=::std::nullopt, const ::std::optional & sorter={}) { + return at::_ops::searchsorted_Scalar_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & searchsorted_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out) { + return at::_ops::searchsorted_Scalar_out::redispatch(dispatchKeySet, sorted_sequence, self, out_int32, right, side, sorter, out); + } + + // aten::_convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor + inline at::Tensor _convert_indices_from_coo_to_csr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, bool out_int32=false) { + return at::_ops::_convert_indices_from_coo_to_csr::redispatch(dispatchKeySet, self, size, out_int32); + } + + // aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convert_indices_from_coo_to_csr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t size, bool out_int32=false) { + return at::_ops::_convert_indices_from_coo_to_csr_out::redispatch(dispatchKeySet, self, size, out_int32, out); + } + + // aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convert_indices_from_coo_to_csr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t size, bool out_int32, at::Tensor & out) { + return at::_ops::_convert_indices_from_coo_to_csr_out::redispatch(dispatchKeySet, self, size, out_int32, out); + } + + // aten::_convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor + inline at::Tensor _convert_indices_from_csr_to_coo(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32=false, bool transpose=false) { + return at::_ops::_convert_indices_from_csr_to_coo::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose); + } + + // aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convert_indices_from_csr_to_coo_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32=false, bool transpose=false) { + return at::_ops::_convert_indices_from_csr_to_coo_out::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose, out); + } + + // aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convert_indices_from_csr_to_coo_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose, at::Tensor & out) { + return at::_ops::_convert_indices_from_csr_to_coo_out::redispatch(dispatchKeySet, crow_indices, col_indices, out_int32, transpose, out); + } + + // aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mse_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::mse_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mse_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) { + return at::_ops::mse_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor mse_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::mse_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & mse_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::mse_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & mse_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::mse_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + inline at::Tensor mse_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::mse_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction); + } + + // aten::l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor l1_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::l1_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multi_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p=1, const at::Scalar & margin=1, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss_out::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction, out); + } + + // aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multi_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & out) { + return at::_ops::multi_margin_loss_out::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction, out); + } + + // aten::multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor multi_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p=1, const at::Scalar & margin=1, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss::redispatch(dispatchKeySet, self, target, p, margin, weight, reduction); + } + + // aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & multi_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction, grad_input); + } + + // aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & multi_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::multi_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction, grad_input); + } + + // aten::multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor + inline at::Tensor multi_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multi_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, p, margin, weight, reduction); + } + + // aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multilabel_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multilabel_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & multilabel_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) { + return at::_ops::multilabel_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor multilabel_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::multilabel_margin_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple multilabel_margin_loss_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & is_target, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::multilabel_margin_loss_forward_output::redispatch(dispatchKeySet, self, target, reduction, output, is_target); + } + + // aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple multilabel_margin_loss_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & output, at::Tensor & is_target) { + return at::_ops::multilabel_margin_loss_forward_output::redispatch(dispatchKeySet, self, target, reduction, output, is_target); + } + + // aten::multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target) + inline ::std::tuple multilabel_margin_loss_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::multilabel_margin_loss_forward::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & multilabel_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) { + return at::_ops::multilabel_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target, grad_input); + } + + // aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & multilabel_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target, at::Tensor & grad_input) { + return at::_ops::multilabel_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target, grad_input); + } + + // aten::multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor + inline at::Tensor multilabel_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) { + return at::_ops::multilabel_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, is_target); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & out) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out) { + return at::_ops::nll_loss_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss_nd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss_nd::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss_nd_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss_nd::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & out) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nll_loss2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out) { + return at::_ops::nll_loss2d_out::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, out); + } + + // aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, int64_t ignore_index=-100) { + return at::_ops::nll_loss2d::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor + inline at::Tensor nll_loss2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, int64_t reduction=at::Reduction::Mean, c10::SymInt ignore_index=-100) { + return at::_ops::nll_loss2d::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & total_weight, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple nll_loss2d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight) { + return at::_ops::nll_loss2d_forward_output::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index, output, total_weight); + } + + // aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss2d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index) { + return at::_ops::nll_loss2d_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight) + inline ::std::tuple nll_loss2d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + return at::_ops::nll_loss2d_forward::redispatch(dispatchKeySet, self, target, weight, reduction, ignore_index); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & nll_loss2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input) { + return at::_ops::nll_loss2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight, grad_input); + } + + // aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor + inline at::Tensor nll_loss2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + return at::_ops::nll_loss2d_backward::redispatch(dispatchKeySet, grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + + // aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & smooth_l1_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double beta=1.0) { + return at::_ops::smooth_l1_loss_out::redispatch(dispatchKeySet, self, target, reduction, beta, out); + } + + // aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & smooth_l1_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & out) { + return at::_ops::smooth_l1_loss_out::redispatch(dispatchKeySet, self, target, reduction, beta, out); + } + + // aten::smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor + inline at::Tensor smooth_l1_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double beta=1.0) { + return at::_ops::smooth_l1_loss::redispatch(dispatchKeySet, self, target, reduction, beta); + } + + // aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & smooth_l1_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + return at::_ops::smooth_l1_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta, grad_input); + } + + // aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & smooth_l1_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & grad_input) { + return at::_ops::smooth_l1_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta, grad_input); + } + + // aten::smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor + inline at::Tensor smooth_l1_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + return at::_ops::smooth_l1_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, beta); + } + + // aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & huber_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double delta=1.0) { + return at::_ops::huber_loss_out::redispatch(dispatchKeySet, self, target, reduction, delta, out); + } + + // aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & huber_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & out) { + return at::_ops::huber_loss_out::redispatch(dispatchKeySet, self, target, reduction, delta, out); + } + + // aten::huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor + inline at::Tensor huber_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean, double delta=1.0) { + return at::_ops::huber_loss::redispatch(dispatchKeySet, self, target, reduction, delta); + } + + // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & huber_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + return at::_ops::huber_loss_backward_out::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta, grad_input); + } + + // aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & huber_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & grad_input) { + return at::_ops::huber_loss_backward_out::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta, grad_input); + } + + // aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor + inline at::Tensor huber_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + return at::_ops::huber_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction, delta); + } + + // aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & soft_margin_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::soft_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & soft_margin_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out) { + return at::_ops::soft_margin_loss_out::redispatch(dispatchKeySet, self, target, reduction, out); + } + + // aten::soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor + inline at::Tensor soft_margin_loss(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, int64_t reduction=at::Reduction::Mean) { + return at::_ops::soft_margin_loss::redispatch(dispatchKeySet, self, target, reduction); + } + + // aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & soft_margin_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::soft_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & soft_margin_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input) { + return at::_ops::soft_margin_loss_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, target, reduction, grad_input); + } + + // aten::soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor + inline at::Tensor soft_margin_loss_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + return at::_ops::soft_margin_loss_backward::redispatch(dispatchKeySet, grad_output, self, target, reduction); + } + + // aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & elu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) { + return at::_ops::elu_out::redispatch(dispatchKeySet, self, alpha, scale, input_scale, out); + } + + // aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & elu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, at::Tensor & out) { + return at::_ops::elu_out::redispatch(dispatchKeySet, self, alpha, scale, input_scale, out); + } + + // aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor + inline at::Tensor elu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) { + return at::_ops::elu::redispatch(dispatchKeySet, self, alpha, scale, input_scale); + } + + // aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & elu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) { + return at::_ops::elu_backward_grad_input::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result, grad_input); + } + + // aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & elu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result, at::Tensor & grad_input) { + return at::_ops::elu_backward_grad_input::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result, grad_input); + } + + // aten::elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor + inline at::Tensor elu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) { + return at::_ops::elu_backward::redispatch(dispatchKeySet, grad_output, alpha, scale, input_scale, is_result, self_or_result); + } + + // aten::elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!) + inline at::Tensor & elu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & alpha=1, const at::Scalar & scale=1, const at::Scalar & input_scale=1) { + return at::_ops::elu_::redispatch(dispatchKeySet, self, alpha, scale, input_scale); + } + + // aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=-1) { + return at::_ops::glu_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::glu_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::glu(Tensor self, int dim=-1) -> Tensor + inline at::Tensor glu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=-1) { + return at::_ops::glu::redispatch(dispatchKeySet, self, dim); + } + + // aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & glu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) { + return at::_ops::glu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, dim, grad_input); + } + + // aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & glu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, at::Tensor & grad_input) { + return at::_ops::glu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, dim, grad_input); + } + + // aten::glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor + inline at::Tensor glu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) { + return at::_ops::glu_backward::redispatch(dispatchKeySet, grad_output, self, dim); + } + + // aten::glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor + inline at::Tensor glu_jvp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_jvp::redispatch(dispatchKeySet, glu, x, dx, dim); + } + + // aten::glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor + inline at::Tensor glu_backward_jvp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_backward_jvp::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim); + } + + // aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardsigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::hardsigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardsigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::hardsigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardsigmoid(Tensor self) -> Tensor + inline at::Tensor hardsigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::hardsigmoid::redispatch(dispatchKeySet, self); + } + + // aten::hardsigmoid_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & hardsigmoid_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::hardsigmoid_::redispatch(dispatchKeySet, self); + } + + // aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardsigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardsigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardsigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::hardsigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor hardsigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardsigmoid_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardtanh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) { + return at::_ops::hardtanh_out::redispatch(dispatchKeySet, self, min_val, max_val, out); + } + + // aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardtanh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & out) { + return at::_ops::hardtanh_out::redispatch(dispatchKeySet, self, min_val, max_val, out); + } + + // aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor + inline at::Tensor hardtanh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) { + return at::_ops::hardtanh::redispatch(dispatchKeySet, self, min_val, max_val); + } + + // aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardtanh_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + return at::_ops::hardtanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, min_val, max_val, grad_input); + } + + // aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & hardtanh_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & grad_input) { + return at::_ops::hardtanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, min_val, max_val, grad_input); + } + + // aten::hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor + inline at::Tensor hardtanh_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + return at::_ops::hardtanh_backward::redispatch(dispatchKeySet, grad_output, self, min_val, max_val); + } + + // aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!) + inline at::Tensor & hardtanh_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & min_val=-1, const at::Scalar & max_val=1) { + return at::_ops::hardtanh_::redispatch(dispatchKeySet, self, min_val, max_val); + } + + // aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardswish_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::hardswish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardswish_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::hardswish_out::redispatch(dispatchKeySet, self, out); + } + + // aten::hardswish(Tensor self) -> Tensor + inline at::Tensor hardswish(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::hardswish::redispatch(dispatchKeySet, self); + } + + // aten::hardswish_(Tensor(a!) self) -> Tensor(a!) + inline at::Tensor & hardswish_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self) { + return at::_ops::hardswish_::redispatch(dispatchKeySet, self); + } + + // aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor hardswish_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardswish_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & leaky_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & negative_slope=0.01) { + return at::_ops::leaky_relu_out::redispatch(dispatchKeySet, self, negative_slope, out); + } + + // aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & leaky_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & negative_slope, at::Tensor & out) { + return at::_ops::leaky_relu_out::redispatch(dispatchKeySet, self, negative_slope, out); + } + + // aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor + inline at::Tensor leaky_relu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & negative_slope=0.01) { + return at::_ops::leaky_relu::redispatch(dispatchKeySet, self, negative_slope); + } + + // aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & leaky_relu_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) { + return at::_ops::leaky_relu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result, grad_input); + } + + // aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & leaky_relu_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result, at::Tensor & grad_input) { + return at::_ops::leaky_relu_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result, grad_input); + } + + // aten::leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor + inline at::Tensor leaky_relu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) { + return at::_ops::leaky_relu_backward::redispatch(dispatchKeySet, grad_output, self, negative_slope, self_is_result); + } + + // aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!) + inline at::Tensor & leaky_relu_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, const at::Scalar & negative_slope=0.01) { + return at::_ops::leaky_relu_::redispatch(dispatchKeySet, self, negative_slope); + } + + // aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::log_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_sigmoid_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::log_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::log_sigmoid(Tensor self) -> Tensor + inline at::Tensor log_sigmoid(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log_sigmoid::redispatch(dispatchKeySet, self); + } + + // aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple log_sigmoid_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & buffer, const at::Tensor & self) { + return at::_ops::log_sigmoid_forward_output::redispatch(dispatchKeySet, self, output, buffer); + } + + // aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple log_sigmoid_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & output, at::Tensor & buffer) { + return at::_ops::log_sigmoid_forward_output::redispatch(dispatchKeySet, self, output, buffer); + } + + // aten::log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer) + inline ::std::tuple log_sigmoid_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::log_sigmoid_forward::redispatch(dispatchKeySet, self); + } + + // aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & log_sigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) { + return at::_ops::log_sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, buffer, grad_input); + } + + // aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & log_sigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer, at::Tensor & grad_input) { + return at::_ops::log_sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, buffer, grad_input); + } + + // aten::log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor + inline at::Tensor log_sigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) { + return at::_ops::log_sigmoid_backward::redispatch(dispatchKeySet, grad_output, self, buffer); + } + + // aten::rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise_out::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator, out); + } + + // aten::rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator, at::Tensor & out) { + return at::_ops::rrelu_with_noise_out::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator, out); + } + + // aten::rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor + inline at::Tensor rrelu_with_noise(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator); + } + + // aten::rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor + inline at::Tensor rrelu_with_noise_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) { + return at::_ops::rrelu_with_noise_backward::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result); + } + + // aten::rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise_::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator); + } + + // aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softplus_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20) { + return at::_ops::softplus_out::redispatch(dispatchKeySet, self, beta, threshold, out); + } + + // aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softplus_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & out) { + return at::_ops::softplus_out::redispatch(dispatchKeySet, self, beta, threshold, out); + } + + // aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor + inline at::Tensor softplus(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & beta=1, const at::Scalar & threshold=20) { + return at::_ops::softplus::redispatch(dispatchKeySet, self, beta, threshold); + } + + // aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & softplus_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + return at::_ops::softplus_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, beta, threshold, grad_input); + } + + // aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & softplus_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & grad_input) { + return at::_ops::softplus_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, beta, threshold, grad_input); + } + + // aten::softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor + inline at::Tensor softplus_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + return at::_ops::softplus_backward::redispatch(dispatchKeySet, grad_output, self, beta, threshold); + } + + // aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softshrink_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::softshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & softshrink_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out) { + return at::_ops::softshrink_out::redispatch(dispatchKeySet, self, lambd, out); + } + + // aten::softshrink(Tensor self, Scalar lambd=0.5) -> Tensor + inline at::Tensor softshrink(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & lambd=0.5) { + return at::_ops::softshrink::redispatch(dispatchKeySet, self, lambd); + } + + // aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & softshrink_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::softshrink_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, lambd, grad_input); + } + + // aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & softshrink_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input) { + return at::_ops::softshrink_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, lambd, grad_input); + } + + // aten::softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor + inline at::Tensor softshrink_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) { + return at::_ops::softshrink_backward::redispatch(dispatchKeySet, grad_output, self, lambd); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor + inline at::Tensor mkldnn_adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::mkldnn_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::mkldnn_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::mkldnn_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor mkldnn_adaptive_avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor _adaptive_avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor adaptive_avg_pool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::adaptive_avg_pool3d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor + inline at::Tensor _adaptive_avg_pool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input) { + return at::_ops::adaptive_avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, grad_input); + } + + // aten::_adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor + inline at::Tensor _adaptive_avg_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool3d_backward::redispatch(dispatchKeySet, grad_output, self); + } + + // aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor) + inline ::std::tuple adaptive_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool2d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + inline at::Tensor adaptive_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, indices); + } + + // aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple adaptive_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_out::redispatch(dispatchKeySet, self, output_size, out, indices); + } + + // aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor) + inline ::std::tuple adaptive_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_max_pool3d::redispatch(dispatchKeySet, self, output_size); + } + + // aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & adaptive_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::adaptive_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, indices, grad_input); + } + + // aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor + inline at::Tensor adaptive_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + return at::_ops::adaptive_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, indices); + } + + // aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out) { + return at::_ops::avg_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + inline at::Tensor avg_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool2d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input) { + return at::_ops::avg_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + inline at::Tensor avg_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out) { + return at::_ops::avg_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, out); + } + + // aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor + inline at::Tensor avg_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true, ::std::optional divisor_override=::std::nullopt) { + return at::_ops::avg_pool3d::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input) { + return at::_ops::avg_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, grad_input); + } + + // aten::avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor + inline at::Tensor avg_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + return at::_ops::avg_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + + // aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool2d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices) { + return at::_ops::fractional_max_pool2d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor) + inline ::std::tuple fractional_max_pool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool2d::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples); + } + + // aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & fractional_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & fractional_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::fractional_max_pool2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor + inline at::Tensor fractional_max_pool2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool2d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices); + } + + // aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool3d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fractional_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices) { + return at::_ops::fractional_max_pool3d_output::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples, output, indices); + } + + // aten::fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor) + inline ::std::tuple fractional_max_pool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + return at::_ops::fractional_max_pool3d::redispatch(dispatchKeySet, self, kernel_size, output_size, random_samples); + } + + // aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & fractional_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & fractional_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::fractional_max_pool3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices, grad_input); + } + + // aten::fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor + inline at::Tensor fractional_max_pool3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + return at::_ops::fractional_max_pool3d_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, output_size, indices); + } + + // aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool2d_with_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool2d_with_indices_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices) { + return at::_ops::max_pool2d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + inline ::std::tuple max_pool2d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & max_pool2d_with_indices_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool2d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & max_pool2d_with_indices_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::max_pool2d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor + inline at::Tensor max_pool2d_with_indices_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool2d_with_indices_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + + // aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool3d_with_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::Tensor & indices, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool3d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple max_pool3d_with_indices_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices) { + return at::_ops::max_pool3d_with_indices_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out, indices); + } + + // aten::max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor) + inline ::std::tuple max_pool3d_with_indices(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool3d_with_indices::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode); + } + + // aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & max_pool3d_with_indices_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool3d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & max_pool3d_with_indices_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input) { + return at::_ops::max_pool3d_with_indices_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices, grad_input); + } + + // aten::max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor + inline at::Tensor max_pool3d_with_indices_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + return at::_ops::max_pool3d_with_indices_backward::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, output_size, out); + } + + // aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::max_unpool2d_out::redispatch(dispatchKeySet, self, indices, output_size, out); + } + + // aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + inline at::Tensor max_unpool2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size) { + return at::_ops::max_unpool2d::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size)); + } + + // aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor + inline at::Tensor max_unpool2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) { + return at::_ops::max_unpool2d::redispatch(dispatchKeySet, self, indices, output_size); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding, out); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding, out); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, output_size, stride, padding, out); + } + + // aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_unpool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::max_unpool3d_out::redispatch(dispatchKeySet, self, indices, output_size, stride, padding, out); + } + + // aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + inline at::Tensor max_unpool3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, at::IntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d::redispatch(dispatchKeySet, self, indices, c10::fromIntArrayRefSlow(output_size), stride, padding); + } + + // aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor + inline at::Tensor max_unpool3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::max_unpool3d::redispatch(dispatchKeySet, self, indices, output_size, stride, padding); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d::redispatch(dispatchKeySet, self, padding); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor reflection_pad1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad2d::redispatch(dispatchKeySet, self, padding); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor reflection_pad2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::reflection_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d::redispatch(dispatchKeySet, self, padding); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & reflection_pad3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::reflection_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::reflection_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor reflection_pad3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::reflection_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad1d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d::redispatch(dispatchKeySet, self, padding); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor + inline at::Tensor replication_pad1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad1d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad2d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d::redispatch(dispatchKeySet, self, padding); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor + inline at::Tensor replication_pad2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad2d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), out); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & replication_pad3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::replication_pad3d_out::redispatch(dispatchKeySet, self, padding, out); + } + + // aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d::redispatch(dispatchKeySet, self, padding); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding), grad_input); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & replication_pad3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input) { + return at::_ops::replication_pad3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, padding, grad_input); + } + + // aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef padding) { + return at::_ops::replication_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, c10::fromIntArrayRefSlow(padding)); + } + + // aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor + inline at::Tensor replication_pad3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + return at::_ops::replication_pad3d_backward::redispatch(dispatchKeySet, grad_output, self, padding); + } + + // aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor + inline at::Tensor _pad_circular(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad) { + return at::_ops::_pad_circular::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad)); + } + + // aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor + inline at::Tensor _pad_circular_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad) { + return at::_ops::_pad_circular::redispatch(dispatchKeySet, self, pad); + } + + // aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor + inline at::Tensor _pad_enum(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, int64_t mode, ::std::optional value=::std::nullopt) { + return at::_ops::_pad_enum::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), mode, value); + } + + // aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor + inline at::Tensor _pad_enum_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, int64_t mode, ::std::optional value=::std::nullopt) { + return at::_ops::_pad_enum::redispatch(dispatchKeySet, self, pad, mode, value); + } + + // aten::pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor + inline at::Tensor pad(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, c10::string_view mode="constant", ::std::optional value=::std::nullopt) { + return at::_ops::pad::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), mode, value); + } + + // aten::pad(Tensor self, SymInt[] pad, str mode="constant", float? value=None) -> Tensor + inline at::Tensor pad_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode="constant", ::std::optional value=::std::nullopt) { + return at::_ops::pad::redispatch(dispatchKeySet, self, pad, mode, value); + } + + // aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_linear1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_linear1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_linear1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_linear1d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bilinear2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bilinear2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bilinear2d_aa_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bilinear2d_aa_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_trilinear3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_trilinear3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_trilinear3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_trilinear3d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bicubic2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bicubic2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_bicubic2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bicubic2d_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bicubic2d_aa_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors); + } + + // aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::_upsample_bicubic2d_aa_vec::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors); + } + + // aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest1d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact1d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact1d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact2d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact2d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor upsample_nearest3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest3d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact3d_vec::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors); + } + + // aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::_upsample_nearest_exact3d_vec::redispatch(dispatchKeySet, input, output_size, scale_factors); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales, out); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales, out); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales, out); + } + + // aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_linear1d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales, out); + } + + // aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales); + } + + // aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d::redispatch(dispatchKeySet, self, output_size, align_corners, scales); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_linear1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_linear1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales, grad_input); + } + + // aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales); + } + + // aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor + inline at::Tensor upsample_linear1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_linear1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bilinear2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bilinear2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bilinear2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bilinear2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bilinear2d_aa_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bilinear2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bilinear2d_aa_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bilinear2d_aa_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_bicubic2d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_bicubic2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_bicubic2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_bicubic2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_bicubic2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_bicubic2d_aa_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w, out); + } + + // aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa::redispatch(dispatchKeySet, self, output_size, align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_bicubic2d_aa_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_bicubic2d_aa_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w, grad_input); + } + + // aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_h, scales_w); + } + + // aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_bicubic2d_aa_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_bicubic2d_aa_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_trilinear3d_out::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d::redispatch(dispatchKeySet, self, output_size, align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_trilinear3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_trilinear3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_trilinear3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_trilinear3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::upsample_nearest1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact1d_out::redispatch(dispatchKeySet, self, output_size, scales, out); + } + + // aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + inline at::Tensor upsample_nearest1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales); + } + + // aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + inline at::Tensor upsample_nearest1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d::redispatch(dispatchKeySet, self, output_size, scales); + } + + // aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales); + } + + // aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d::redispatch(dispatchKeySet, self, output_size, scales); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::upsample_nearest1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact1d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact1d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales, grad_input); + } + + // aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + inline at::Tensor upsample_nearest1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales); + } + + // aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + inline at::Tensor upsample_nearest1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::upsample_nearest1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales); + } + + // aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales); + } + + // aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor + inline at::Tensor _upsample_nearest_exact1d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales=::std::nullopt) { + return at::_ops::_upsample_nearest_exact1d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact2d_out::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w, out); + } + + // aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w); + } + + // aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d::redispatch(dispatchKeySet, self, output_size, scales_h, scales_w); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w); + } + + // aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_h, scales_w); + } + + // aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact2d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_h, scales_w); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::upsample_nearest3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out) { + return at::_ops::_upsample_nearest_exact3d_out::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w, out); + } + + // aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d::redispatch(dispatchKeySet, self, output_size, scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & upsample_nearest3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::upsample_nearest3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & _upsample_nearest_exact3d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input) { + return at::_ops::_upsample_nearest_exact3d_backward_grad_input::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w, grad_input); + } + + // aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w); + } + + // aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor upsample_nearest3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::upsample_nearest3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef output_size, at::IntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(output_size), c10::fromIntArrayRefSlow(input_size), scales_d, scales_h, scales_w); + } + + // aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor + inline at::Tensor _upsample_nearest_exact3d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d=::std::nullopt, ::std::optional scales_h=::std::nullopt, ::std::optional scales_w=::std::nullopt) { + return at::_ops::_upsample_nearest_exact3d_backward::redispatch(dispatchKeySet, grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + + // aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & sigmoid_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & sigmoid_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input) { + return at::_ops::sigmoid_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor + inline at::Tensor sigmoid_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::sigmoid_backward::redispatch(dispatchKeySet, grad_output, output); + } + + // aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & logit_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, eps, grad_input); + } + + // aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & logit_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps, at::Tensor & grad_input) { + return at::_ops::logit_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, eps, grad_input); + } + + // aten::logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor + inline at::Tensor logit_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::logit_backward::redispatch(dispatchKeySet, grad_output, self, eps); + } + + // aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & tanh_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::tanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!) + inline at::Tensor & tanh_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input) { + return at::_ops::tanh_backward_grad_input::redispatch(dispatchKeySet, grad_output, output, grad_input); + } + + // aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor + inline at::Tensor tanh_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output) { + return at::_ops::tanh_backward::redispatch(dispatchKeySet, grad_output, output); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_transpose3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_transpose3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation, out); + } + + // aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef output_padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_transpose3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_transpose3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef output_padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_transpose3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & thnn_conv2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & thnn_conv2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & thnn_conv2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & thnn_conv2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::thnn_conv2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor + inline at::Tensor thnn_conv2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::thnn_conv2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor + inline at::Tensor thnn_conv2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::thnn_conv2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & _slow_conv2d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & _slow_conv2d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & output) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & _slow_conv2d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & _slow_conv2d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output) { + return at::_ops::_slow_conv2d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + inline at::Tensor _slow_conv2d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::_slow_conv2d_forward::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor + inline at::Tensor _slow_conv2d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::_slow_conv2d_forward::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias) { + return at::_ops::_slow_conv2d_backward_grad_input::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, grad_input, grad_weight, grad_bias); + } + + // aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple _slow_conv2d_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask); + } + + // aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias) + inline ::std::tuple _slow_conv2d_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conv_depthwise2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conv_depthwise2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conv_depthwise2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conv_depthwise2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::_conv_depthwise2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + inline at::Tensor _conv_depthwise2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::_conv_depthwise2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor + inline at::Tensor _conv_depthwise2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::_conv_depthwise2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + inline at::Tensor conv_depthwise3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::conv_depthwise3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor + inline at::Tensor conv_depthwise3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::conv_depthwise3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), out); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out) { + return at::_ops::slow_conv3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, out); + } + + // aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor + inline at::Tensor slow_conv3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0) { + return at::_ops::slow_conv3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor + inline at::Tensor slow_conv3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0)) { + return at::_ops::slow_conv3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & slow_conv3d_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & slow_conv3d_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & output) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & slow_conv3d_forward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!) + inline at::Tensor & slow_conv3d_forward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output) { + return at::_ops::slow_conv3d_forward_output::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, output); + } + + // aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + inline at::Tensor slow_conv3d_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding) { + return at::_ops::slow_conv3d_forward::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding)); + } + + // aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor + inline at::Tensor slow_conv3d_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + return at::_ops::slow_conv3d_forward::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding); + } + + // aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated2d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated2d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated2d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated2d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated3d(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated3d::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation)); + } + + // aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor + inline at::Tensor slow_conv_dilated3d_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated3d::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col2im_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride, out); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col2im_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride, out); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col2im_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride, out); + } + + // aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col2im_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::col2im_out::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride, out); + } + + // aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + inline at::Tensor col2im(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), kernel_size, dilation, padding, stride); + } + + // aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + inline at::Tensor col2im_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::col2im::redispatch(dispatchKeySet, self, output_size, kernel_size, dilation, padding, stride); + } + + // aten::column_stack(Tensor[] tensors) -> Tensor + inline at::Tensor column_stack(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::column_stack::redispatch(dispatchKeySet, tensors); + } + + // aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & column_stack_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::column_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & column_stack_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::column_stack_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & im2col_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::im2col_out::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride, out); + } + + // aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & im2col_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::im2col_out::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride, out); + } + + // aten::im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor + inline at::Tensor im2col(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + return at::_ops::im2col::redispatch(dispatchKeySet, self, kernel_size, dilation, padding, stride); + } + + // aten::isfinite(Tensor self) -> Tensor + inline at::Tensor isfinite(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isfinite::redispatch(dispatchKeySet, self); + } + + // aten::isinf(Tensor self) -> Tensor + inline at::Tensor isinf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isinf::redispatch(dispatchKeySet, self); + } + + // aten::record_stream(Tensor(a!) self, Stream s) -> () + inline void record_stream(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, at::Stream s) { + return at::_ops::record_stream::redispatch(dispatchKeySet, self, s); + } + + // aten::isposinf(Tensor self) -> Tensor + inline at::Tensor isposinf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isposinf::redispatch(dispatchKeySet, self); + } + + // aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isposinf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isposinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isposinf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isposinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isneginf(Tensor self) -> Tensor + inline at::Tensor isneginf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::isneginf::redispatch(dispatchKeySet, self); + } + + // aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isneginf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isneginf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isneginf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isneginf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor + inline at::Tensor _add_batch_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t batch_dim, int64_t level) { + return at::_ops::_add_batch_dim::redispatch(dispatchKeySet, self, batch_dim, level); + } + + // aten::_remove_batch_dim(Tensor self, int level, SymInt batch_size, int out_dim) -> Tensor + inline at::Tensor _remove_batch_dim(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, int64_t batch_size, int64_t out_dim) { + return at::_ops::_remove_batch_dim::redispatch(dispatchKeySet, self, level, batch_size, out_dim); + } + + // aten::_remove_batch_dim(Tensor self, int level, SymInt batch_size, int out_dim) -> Tensor + inline at::Tensor _remove_batch_dim_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, c10::SymInt batch_size, int64_t out_dim) { + return at::_ops::_remove_batch_dim::redispatch(dispatchKeySet, self, level, batch_size, out_dim); + } + + // aten::special_entr(Tensor self) -> Tensor + inline at::Tensor special_entr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_entr::redispatch(dispatchKeySet, self); + } + + // aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_entr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_entr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_entr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_entr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtri(Tensor self) -> Tensor + inline at::Tensor special_ndtri(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_ndtri::redispatch(dispatchKeySet, self); + } + + // aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_ndtri_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_ndtri_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_ndtri_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_ndtri_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log_ndtr(Tensor self) -> Tensor + inline at::Tensor special_log_ndtr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_log_ndtr::redispatch(dispatchKeySet, self); + } + + // aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_log_ndtr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_log_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_log_ndtr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_log_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_expm1(Tensor self) -> Tensor + inline at::Tensor special_expm1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_expm1::redispatch(dispatchKeySet, self); + } + + // aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_expm1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_expm1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_exp2(Tensor self) -> Tensor + inline at::Tensor special_exp2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_exp2::redispatch(dispatchKeySet, self); + } + + // aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_exp2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_exp2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_exp2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_psi(Tensor self) -> Tensor + inline at::Tensor special_psi(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_psi::redispatch(dispatchKeySet, self); + } + + // aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_psi_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_psi_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_psi_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_psi_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_digamma(Tensor self) -> Tensor + inline at::Tensor special_digamma(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_digamma::redispatch(dispatchKeySet, self); + } + + // aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_digamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_digamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_digamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_gammaln(Tensor self) -> Tensor + inline at::Tensor special_gammaln(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_gammaln::redispatch(dispatchKeySet, self); + } + + // aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaln_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_gammaln_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaln_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_gammaln_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erf(Tensor self) -> Tensor + inline at::Tensor special_erf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erf::redispatch(dispatchKeySet, self); + } + + // aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfc(Tensor self) -> Tensor + inline at::Tensor special_erfc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erfc::redispatch(dispatchKeySet, self); + } + + // aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfcx(Tensor self) -> Tensor + inline at::Tensor special_erfcx(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erfcx::redispatch(dispatchKeySet, self); + } + + // aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfcx_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erfcx_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfcx_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erfcx_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfinv(Tensor self) -> Tensor + inline at::Tensor special_erfinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_erfinv::redispatch(dispatchKeySet, self); + } + + // aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_erfinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_erfinv_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtr(Tensor self) -> Tensor + inline at::Tensor special_ndtr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_ndtr::redispatch(dispatchKeySet, self); + } + + // aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_ndtr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_ndtr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_ndtr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_xlog1py(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlog1py::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor + inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlog1py_self_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor special_xlog1py(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlog1py_other_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlog1py_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlog1py_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlog1py_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlog1py_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlog1py_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlog1py_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::special_xlog1py_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlogy::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor + inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlogy_self_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor special_xlogy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlogy_other_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_xlogy_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlogy_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_xlogy_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_xlogy_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_xlogy_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_xlogy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::special_xlogy_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_zeta::redispatch(dispatchKeySet, self, other); + } + + // aten::special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor + inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_zeta_self_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor + inline at::Tensor special_zeta(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_zeta_other_scalar::redispatch(dispatchKeySet, self, other); + } + + // aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_zeta_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_zeta_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::special_zeta_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_zeta_self_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::special_zeta_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_zeta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::special_zeta_other_scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_i0(Tensor self) -> Tensor + inline at::Tensor special_i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i0::redispatch(dispatchKeySet, self); + } + + // aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i0e(Tensor self) -> Tensor + inline at::Tensor special_i0e(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i0e::redispatch(dispatchKeySet, self); + } + + // aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i0e_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i0e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i0e_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i0e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1(Tensor self) -> Tensor + inline at::Tensor special_i1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i1::redispatch(dispatchKeySet, self); + } + + // aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1e(Tensor self) -> Tensor + inline at::Tensor special_i1e(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_i1e::redispatch(dispatchKeySet, self); + } + + // aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i1e_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_i1e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_i1e_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_i1e_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_logit(Tensor self, float? eps=None) -> Tensor + inline at::Tensor special_logit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::special_logit::redispatch(dispatchKeySet, self, eps); + } + + // aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional eps=::std::nullopt) { + return at::_ops::special_logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional eps, at::Tensor & out) { + return at::_ops::special_logit_out::redispatch(dispatchKeySet, self, eps, out); + } + + // aten::special_polygamma(int n, Tensor self) -> Tensor + inline at::Tensor special_polygamma(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self) { + return at::_ops::special_polygamma::redispatch(dispatchKeySet, n, self); + } + + // aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_polygamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, const at::Tensor & self) { + return at::_ops::special_polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_polygamma_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_polygamma_out::redispatch(dispatchKeySet, n, self, out); + } + + // aten::special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor + inline at::Tensor special_logsumexp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::special_logsumexp::redispatch(dispatchKeySet, self, dim, keepdim); + } + + // aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logsumexp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false) { + return at::_ops::special_logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_logsumexp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out) { + return at::_ops::special_logsumexp_out::redispatch(dispatchKeySet, self, dim, keepdim, out); + } + + // aten::special_expit(Tensor self) -> Tensor + inline at::Tensor special_expit(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_expit::redispatch(dispatchKeySet, self); + } + + // aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_expit_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_expit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_expit_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_expit_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_sinc(Tensor self) -> Tensor + inline at::Tensor special_sinc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_sinc::redispatch(dispatchKeySet, self); + } + + // aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_sinc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_sinc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_sinc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_round(Tensor self, *, int decimals=0) -> Tensor + inline at::Tensor special_round(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals=0) { + return at::_ops::special_round::redispatch(dispatchKeySet, self, decimals); + } + + // aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_round_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t decimals=0) { + return at::_ops::special_round_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_round_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t decimals, at::Tensor & out) { + return at::_ops::special_round_out::redispatch(dispatchKeySet, self, decimals, out); + } + + // aten::special_log1p(Tensor self) -> Tensor + inline at::Tensor special_log1p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_log1p::redispatch(dispatchKeySet, self); + } + + // aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_log1p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_log1p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor special_log_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::special_log_softmax::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammainc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammainc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammainc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_gammainc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammainc(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_gammainc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammainc::redispatch(dispatchKeySet, self, other); + } + + // aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaincc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammaincc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_gammaincc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::special_gammaincc_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::special_gammaincc(Tensor self, Tensor other) -> Tensor + inline at::Tensor special_gammaincc(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::special_gammaincc::redispatch(dispatchKeySet, self, other); + } + + // aten::special_multigammaln(Tensor self, int p) -> Tensor + inline at::Tensor special_multigammaln(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p) { + return at::_ops::special_multigammaln::redispatch(dispatchKeySet, self, p); + } + + // aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_multigammaln_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t p) { + return at::_ops::special_multigammaln_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_multigammaln_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t p, at::Tensor & out) { + return at::_ops::special_multigammaln_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor special_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::special_softmax::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_fft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_fft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ifft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ifft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_rfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_rfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_irfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_irfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_hfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_hfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ihfft(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor + inline at::Tensor fft_ihfft_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft::redispatch(dispatchKeySet, self, n, dim, norm); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n.has_value() ? ::std::make_optional(c10::SymInt(*n)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional n=::std::nullopt, int64_t dim=-1, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft_out::redispatch(dispatchKeySet, self, n, dim, norm, out); + } + + // aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_fft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_fft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_ifft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_ifft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_rfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_rfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_irfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_irfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_hfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_hfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_ihfft2(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor + inline at::Tensor fft_ihfft2_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft2_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::IntArrayRef dim={-2,-1}, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfft2_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfft2_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_fftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_fftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_fftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_ifftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_ifftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ifftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ifftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_rfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_rfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_rfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_irfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_irfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_irfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_irfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_hfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_hfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_hfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_hfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_ihfftn(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm); + } + + // aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor + inline at::Tensor fft_ihfftn_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn::redispatch(dispatchKeySet, self, s, dim, norm); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfftn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfftn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*s)) : ::std::nullopt, dim, norm, out); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfftn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::OptionalSymIntArrayRef s=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, ::std::optional norm=::std::nullopt) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_ihfftn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out) { + return at::_ops::fft_ihfftn_out::redispatch(dispatchKeySet, self, s, dim, norm, out); + } + + // aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_fftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d=1.0, at::TensorOptions options={}) { + return at::_ops::fft_fftfreq::redispatch(dispatchKeySet, n, d, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_fftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::fft_fftfreq::redispatch(dispatchKeySet, n, d, dtype, layout, device, pin_memory); + } + + // aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftfreq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, double d=1.0) { + return at::_ops::fft_fftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_fftfreq_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, at::Tensor & out) { + return at::_ops::fft_fftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_rfftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d=1.0, at::TensorOptions options={}) { + return at::_ops::fft_rfftfreq::redispatch(dispatchKeySet, n, d, c10::optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt()); + } + + // aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor fft_rfftfreq(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + return at::_ops::fft_rfftfreq::redispatch(dispatchKeySet, n, d, dtype, layout, device, pin_memory); + } + + // aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftfreq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t n, double d=1.0) { + return at::_ops::fft_rfftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fft_rfftfreq_outf(c10::DispatchKeySet dispatchKeySet, int64_t n, double d, at::Tensor & out) { + return at::_ops::fft_rfftfreq_out::redispatch(dispatchKeySet, n, d, out); + } + + // aten::fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor + inline at::Tensor fft_fftshift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt) { + return at::_ops::fft_fftshift::redispatch(dispatchKeySet, self, dim); + } + + // aten::fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor + inline at::Tensor fft_ifftshift(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt) { + return at::_ops::fft_ifftshift::redispatch(dispatchKeySet, self, dim); + } + + // aten::linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info) + inline ::std::tuple linalg_cholesky_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false, bool check_errors=false) { + return at::_ops::linalg_cholesky_ex::redispatch(dispatchKeySet, self, upper, check_errors); + } + + // aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info) + inline ::std::tuple linalg_cholesky_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & L, at::Tensor & info, const at::Tensor & self, bool upper=false, bool check_errors=false) { + return at::_ops::linalg_cholesky_ex_L::redispatch(dispatchKeySet, self, upper, check_errors, L, info); + } + + // aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info) + inline ::std::tuple linalg_cholesky_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, bool check_errors, at::Tensor & L, at::Tensor & info) { + return at::_ops::linalg_cholesky_ex_L::redispatch(dispatchKeySet, self, upper, check_errors, L, info); + } + + // aten::linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor + inline at::Tensor linalg_cholesky(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper=false) { + return at::_ops::linalg_cholesky::redispatch(dispatchKeySet, self, upper); + } + + // aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cholesky_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool upper=false) { + return at::_ops::linalg_cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cholesky_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool upper, at::Tensor & out) { + return at::_ops::linalg_cholesky_out::redispatch(dispatchKeySet, self, upper, out); + } + + // aten::linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor + inline at::Tensor linalg_cross(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t dim=-1) { + return at::_ops::linalg_cross::redispatch(dispatchKeySet, self, other, dim); + } + + // aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cross_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, int64_t dim=-1) { + return at::_ops::linalg_cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cross_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t dim, at::Tensor & out) { + return at::_ops::linalg_cross_out::redispatch(dispatchKeySet, self, other, dim, out); + } + + // aten::linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots) + inline ::std::tuple linalg_lu_factor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu_factor::redispatch(dispatchKeySet, A, pivot); + } + + // aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) + inline ::std::tuple linalg_lu_factor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu_factor_out::redispatch(dispatchKeySet, A, pivot, LU, pivots); + } + + // aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots) + inline ::std::tuple linalg_lu_factor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, at::Tensor & LU, at::Tensor & pivots) { + return at::_ops::linalg_lu_factor_out::redispatch(dispatchKeySet, A, pivot, LU, pivots); + } + + // aten::linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info) + inline ::std::tuple linalg_lu_factor_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true, bool check_errors=false) { + return at::_ops::linalg_lu_factor_ex::redispatch(dispatchKeySet, A, pivot, check_errors); + } + + // aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) + inline ::std::tuple linalg_lu_factor_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info, const at::Tensor & A, bool pivot=true, bool check_errors=false) { + return at::_ops::linalg_lu_factor_ex_out::redispatch(dispatchKeySet, A, pivot, check_errors, LU, pivots, info); + } + + // aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) + inline ::std::tuple linalg_lu_factor_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, bool check_errors, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info) { + return at::_ops::linalg_lu_factor_ex_out::redispatch(dispatchKeySet, A, pivot, check_errors, LU, pivots, info); + } + + // aten::linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U) + inline ::std::tuple linalg_lu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu::redispatch(dispatchKeySet, A, pivot); + } + + // aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + inline ::std::tuple linalg_lu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & P, at::Tensor & L, at::Tensor & U, const at::Tensor & A, bool pivot=true) { + return at::_ops::linalg_lu_out::redispatch(dispatchKeySet, A, pivot, P, L, U); + } + + // aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) + inline ::std::tuple linalg_lu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool pivot, at::Tensor & P, at::Tensor & L, at::Tensor & U) { + return at::_ops::linalg_lu_out::redispatch(dispatchKeySet, A, pivot, P, L, U); + } + + // aten::linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor + inline at::Tensor linalg_lu_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left=true, bool adjoint=false) { + return at::_ops::linalg_lu_solve::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint); + } + + // aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_lu_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left=true, bool adjoint=false) { + return at::_ops::linalg_lu_solve_out::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint, out); + } + + // aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_lu_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint, at::Tensor & out) { + return at::_ops::linalg_lu_solve_out::redispatch(dispatchKeySet, LU, pivots, B, left, adjoint, out); + } + + // aten::_linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots) + inline ::std::tuple _linalg_det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::_linalg_det::redispatch(dispatchKeySet, A); + } + + // aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) + inline ::std::tuple _linalg_det_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A) { + return at::_ops::_linalg_det_result::redispatch(dispatchKeySet, A, result, LU, pivots); + } + + // aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) + inline ::std::tuple _linalg_det_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots) { + return at::_ops::_linalg_det_result::redispatch(dispatchKeySet, A, result, LU, pivots); + } + + // aten::linalg_det(Tensor A) -> Tensor + inline at::Tensor linalg_det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::linalg_det::redispatch(dispatchKeySet, A); + } + + // aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_det_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A) { + return at::_ops::linalg_det_out::redispatch(dispatchKeySet, A, out); + } + + // aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_det_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & out) { + return at::_ops::linalg_det_out::redispatch(dispatchKeySet, A, out); + } + + // aten::det(Tensor self) -> Tensor + inline at::Tensor det(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::det::redispatch(dispatchKeySet, self); + } + + // aten::linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info) + inline ::std::tuple linalg_ldl_factor_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian=false, bool check_errors=false) { + return at::_ops::linalg_ldl_factor_ex::redispatch(dispatchKeySet, self, hermitian, check_errors); + } + + // aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) + inline ::std::tuple linalg_ldl_factor_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info, const at::Tensor & self, bool hermitian=false, bool check_errors=false) { + return at::_ops::linalg_ldl_factor_ex_out::redispatch(dispatchKeySet, self, hermitian, check_errors, LD, pivots, info); + } + + // aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) + inline ::std::tuple linalg_ldl_factor_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian, bool check_errors, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info) { + return at::_ops::linalg_ldl_factor_ex_out::redispatch(dispatchKeySet, self, hermitian, check_errors, LD, pivots, info); + } + + // aten::linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots) + inline ::std::tuple linalg_ldl_factor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian=false) { + return at::_ops::linalg_ldl_factor::redispatch(dispatchKeySet, self, hermitian); + } + + // aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots) + inline ::std::tuple linalg_ldl_factor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & LD, at::Tensor & pivots, const at::Tensor & self, bool hermitian=false) { + return at::_ops::linalg_ldl_factor_out::redispatch(dispatchKeySet, self, hermitian, LD, pivots); + } + + // aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots) + inline ::std::tuple linalg_ldl_factor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool hermitian, at::Tensor & LD, at::Tensor & pivots) { + return at::_ops::linalg_ldl_factor_out::redispatch(dispatchKeySet, self, hermitian, LD, pivots); + } + + // aten::linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor + inline at::Tensor linalg_ldl_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian=false) { + return at::_ops::linalg_ldl_solve::redispatch(dispatchKeySet, LD, pivots, B, hermitian); + } + + // aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_ldl_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian=false) { + return at::_ops::linalg_ldl_solve_out::redispatch(dispatchKeySet, LD, pivots, B, hermitian, out); + } + + // aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_ldl_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_ldl_solve_out::redispatch(dispatchKeySet, LD, pivots, B, hermitian, out); + } + + // aten::linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values) + inline ::std::tuple linalg_lstsq(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & b, ::std::optional rcond=::std::nullopt, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_lstsq::redispatch(dispatchKeySet, self, b, rcond, driver); + } + + // aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) + inline ::std::tuple linalg_lstsq_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values, const at::Tensor & self, const at::Tensor & b, ::std::optional rcond=::std::nullopt, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_lstsq_out::redispatch(dispatchKeySet, self, b, rcond, driver, solution, residuals, rank, singular_values); + } + + // aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) + inline ::std::tuple linalg_lstsq_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values) { + return at::_ops::linalg_lstsq_out::redispatch(dispatchKeySet, self, b, rcond, driver, solution, residuals, rank, singular_values); + } + + // aten::linalg_matmul(Tensor self, Tensor other) -> Tensor + inline at::Tensor linalg_matmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::linalg_matmul::redispatch(dispatchKeySet, self, other); + } + + // aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::linalg_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::linalg_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor + inline at::Tensor linalg_vecdot(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & y, int64_t dim=-1) { + return at::_ops::linalg_vecdot::redispatch(dispatchKeySet, x, y, dim); + } + + // aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_vecdot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & y, int64_t dim=-1) { + return at::_ops::linalg_vecdot_out::redispatch(dispatchKeySet, x, y, dim, out); + } + + // aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_vecdot_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & y, int64_t dim, at::Tensor & out) { + return at::_ops::linalg_vecdot_out::redispatch(dispatchKeySet, x, y, dim, out); + } + + // aten::linalg_matrix_exp(Tensor self) -> Tensor + inline at::Tensor linalg_matrix_exp(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::linalg_matrix_exp::redispatch(dispatchKeySet, self); + } + + // aten::_linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots) + inline ::std::tuple _linalg_slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::_linalg_slogdet::redispatch(dispatchKeySet, A); + } + + // aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) + inline ::std::tuple _linalg_slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots, const at::Tensor & A) { + return at::_ops::_linalg_slogdet_sign::redispatch(dispatchKeySet, A, sign, logabsdet, LU, pivots); + } + + // aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) + inline ::std::tuple _linalg_slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots) { + return at::_ops::_linalg_slogdet_sign::redispatch(dispatchKeySet, A, sign, logabsdet, LU, pivots); + } + + // aten::linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet) + inline ::std::tuple linalg_slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::linalg_slogdet::redispatch(dispatchKeySet, A); + } + + // aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + inline ::std::tuple linalg_slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, const at::Tensor & A) { + return at::_ops::linalg_slogdet_out::redispatch(dispatchKeySet, A, sign, logabsdet); + } + + // aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + inline ::std::tuple linalg_slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet) { + return at::_ops::linalg_slogdet_out::redispatch(dispatchKeySet, A, sign, logabsdet); + } + + // aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet) + inline ::std::tuple slogdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::slogdet::redispatch(dispatchKeySet, self); + } + + // aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + inline ::std::tuple slogdet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & sign, at::Tensor & logabsdet, const at::Tensor & self) { + return at::_ops::slogdet_out::redispatch(dispatchKeySet, self, sign, logabsdet); + } + + // aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet) + inline ::std::tuple slogdet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & sign, at::Tensor & logabsdet) { + return at::_ops::slogdet_out::redispatch(dispatchKeySet, self, sign, logabsdet); + } + + // aten::logdet(Tensor self) -> Tensor + inline at::Tensor logdet(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::logdet::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors) + inline ::std::tuple linalg_eig(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::linalg_eig::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eig_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigenvalues, at::Tensor & eigenvectors, const at::Tensor & self) { + return at::_ops::linalg_eig_out::redispatch(dispatchKeySet, self, eigenvalues, eigenvectors); + } + + // aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eig_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & eigenvalues, at::Tensor & eigenvectors) { + return at::_ops::linalg_eig_out::redispatch(dispatchKeySet, self, eigenvalues, eigenvectors); + } + + // aten::_linalg_eigvals(Tensor self) -> Tensor + inline at::Tensor _linalg_eigvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_linalg_eigvals::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eigvals(Tensor self) -> Tensor + inline at::Tensor linalg_eigvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::linalg_eigvals::redispatch(dispatchKeySet, self); + } + + // aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_eigvals_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::linalg_eigvals_out::redispatch(dispatchKeySet, self, out); + } + + // aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_eigvals_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::linalg_eigvals_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_linalg_eigh(Tensor A, str UPLO="L", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors) + inline ::std::tuple _linalg_eigh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view UPLO="L", bool compute_v=true) { + return at::_ops::_linalg_eigh::redispatch(dispatchKeySet, A, UPLO, compute_v); + } + + // aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple _linalg_eigh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigenvalues, at::Tensor & eigenvectors, const at::Tensor & A, c10::string_view UPLO="L", bool compute_v=true) { + return at::_ops::_linalg_eigh_eigenvalues::redispatch(dispatchKeySet, A, UPLO, compute_v, eigenvalues, eigenvectors); + } + + // aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO="L", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple _linalg_eigh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view UPLO, bool compute_v, at::Tensor & eigenvalues, at::Tensor & eigenvectors) { + return at::_ops::_linalg_eigh_eigenvalues::redispatch(dispatchKeySet, A, UPLO, compute_v, eigenvalues, eigenvectors); + } + + // aten::linalg_eigh(Tensor self, str UPLO="L") -> (Tensor eigenvalues, Tensor eigenvectors) + inline ::std::tuple linalg_eigh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigh::redispatch(dispatchKeySet, self, UPLO); + } + + // aten::linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eigh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & eigvals, at::Tensor & eigvecs, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigh_eigvals::redispatch(dispatchKeySet, self, UPLO, eigvals, eigvecs); + } + + // aten::linalg_eigh.eigvals(Tensor self, str UPLO="L", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) + inline ::std::tuple linalg_eigh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs) { + return at::_ops::linalg_eigh_eigvals::redispatch(dispatchKeySet, self, UPLO, eigvals, eigvecs); + } + + // aten::linalg_eigvalsh(Tensor self, str UPLO="L") -> Tensor + inline at::Tensor linalg_eigvalsh(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigvalsh::redispatch(dispatchKeySet, self, UPLO); + } + + // aten::linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_eigvalsh_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view UPLO="L") { + return at::_ops::linalg_eigvalsh_out::redispatch(dispatchKeySet, self, UPLO, out); + } + + // aten::linalg_eigvalsh.out(Tensor self, str UPLO="L", *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_eigvalsh_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view UPLO, at::Tensor & out) { + return at::_ops::linalg_eigvalsh_out::redispatch(dispatchKeySet, self, UPLO, out); + } + + // aten::linalg_householder_product(Tensor input, Tensor tau) -> Tensor + inline at::Tensor linalg_householder_product(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tau) { + return at::_ops::linalg_householder_product::redispatch(dispatchKeySet, input, tau); + } + + // aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_householder_product_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & tau) { + return at::_ops::linalg_householder_product_out::redispatch(dispatchKeySet, input, tau, out); + } + + // aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_householder_product_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tau, at::Tensor & out) { + return at::_ops::linalg_householder_product_out::redispatch(dispatchKeySet, input, tau, out); + } + + // aten::linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info) + inline ::std::tuple linalg_inv_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool check_errors=false) { + return at::_ops::linalg_inv_ex::redispatch(dispatchKeySet, A, check_errors); + } + + // aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info) + inline ::std::tuple linalg_inv_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & inverse, at::Tensor & info, const at::Tensor & A, bool check_errors=false) { + return at::_ops::linalg_inv_ex_inverse::redispatch(dispatchKeySet, A, check_errors, inverse, info); + } + + // aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info) + inline ::std::tuple linalg_inv_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool check_errors, at::Tensor & inverse, at::Tensor & info) { + return at::_ops::linalg_inv_ex_inverse::redispatch(dispatchKeySet, A, check_errors, inverse, info); + } + + // aten::linalg_inv(Tensor A) -> Tensor + inline at::Tensor linalg_inv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A) { + return at::_ops::linalg_inv::redispatch(dispatchKeySet, A); + } + + // aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_inv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A) { + return at::_ops::linalg_inv_out::redispatch(dispatchKeySet, A, out); + } + + // aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_inv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, at::Tensor & out) { + return at::_ops::linalg_inv_out::redispatch(dispatchKeySet, A, out); + } + + // aten::inverse(Tensor self) -> Tensor + inline at::Tensor inverse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::inverse::redispatch(dispatchKeySet, self); + } + + // aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & inverse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::inverse_out::redispatch(dispatchKeySet, self, out); + } + + // aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & inverse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::inverse_out::redispatch(dispatchKeySet, self, out); + } + + // aten::inner(Tensor self, Tensor other) -> Tensor + inline at::Tensor inner(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::inner::redispatch(dispatchKeySet, self, other); + } + + // aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & inner_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::inner_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & inner_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::inner_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::outer(Tensor self, Tensor vec2) -> Tensor + inline at::Tensor outer(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::outer::redispatch(dispatchKeySet, self, vec2); + } + + // aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & outer_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::outer_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & outer_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out) { + return at::_ops::outer_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::ger(Tensor self, Tensor vec2) -> Tensor + inline at::Tensor ger(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::ger::redispatch(dispatchKeySet, self, vec2); + } + + // aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ger_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & vec2) { + return at::_ops::ger_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ger_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out) { + return at::_ops::ger_out::redispatch(dispatchKeySet, self, vec2, out); + } + + // aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & ord=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm_ord_str::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & ord=::std::nullopt, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_norm_ord_str_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_norm_ord_str_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_vector_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_vector_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_vector_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & ord=2, at::OptionalIntArrayRef dim=::std::nullopt, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_vector_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_vector_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_vector_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_matrix_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_matrix_norm_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor + inline at::Tensor linalg_matrix_norm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord="fro", at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm_str_ord::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype); + } + + // aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view ord="fro", at::IntArrayRef dim={-2,-1}, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::linalg_matrix_norm_str_ord_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::linalg_matrix_norm_str_ord_out::redispatch(dispatchKeySet, self, ord, dim, keepdim, dtype, out); + } + + // aten::_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + inline ::std::tuple _linalg_svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, ::std::optional driver=::std::nullopt) { + return at::_ops::_linalg_svd::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver); + } + + // aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + inline ::std::tuple _linalg_svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & Vh, const at::Tensor & A, bool full_matrices=false, bool compute_uv=true, ::std::optional driver=::std::nullopt) { + return at::_ops::_linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver, U, S, Vh); + } + + // aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + inline ::std::tuple _linalg_svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh) { + return at::_ops::_linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, compute_uv, driver, U, S, Vh); + } + + // aten::linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh) + inline ::std::tuple linalg_svd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices=true, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svd::redispatch(dispatchKeySet, A, full_matrices, driver); + } + + // aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + inline ::std::tuple linalg_svd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & U, at::Tensor & S, at::Tensor & Vh, const at::Tensor & A, bool full_matrices=true, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, driver, U, S, Vh); + } + + // aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) + inline ::std::tuple linalg_svd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, bool full_matrices, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh) { + return at::_ops::linalg_svd_U::redispatch(dispatchKeySet, A, full_matrices, driver, U, S, Vh); + } + + // aten::linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor + inline at::Tensor linalg_svdvals(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svdvals::redispatch(dispatchKeySet, A, driver); + } + + // aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_svdvals_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A, ::std::optional driver=::std::nullopt) { + return at::_ops::linalg_svdvals_out::redispatch(dispatchKeySet, A, driver, out); + } + + // aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_svdvals_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, ::std::optional driver, at::Tensor & out) { + return at::_ops::linalg_svdvals_out::redispatch(dispatchKeySet, A, driver, out); + } + + // aten::linalg_cond(Tensor self, Scalar? p=None) -> Tensor + inline at::Tensor linalg_cond(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p=::std::nullopt) { + return at::_ops::linalg_cond::redispatch(dispatchKeySet, self, p); + } + + // aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cond_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p=::std::nullopt) { + return at::_ops::linalg_cond_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cond_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::Tensor & out) { + return at::_ops::linalg_cond_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_cond.p_str(Tensor self, str p) -> Tensor + inline at::Tensor linalg_cond(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view p) { + return at::_ops::linalg_cond_p_str::redispatch(dispatchKeySet, self, p); + } + + // aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cond_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::string_view p) { + return at::_ops::linalg_cond_p_str_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_cond_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::string_view p, at::Tensor & out) { + return at::_ops::linalg_cond_p_str_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_tensor::redispatch(dispatchKeySet, self, atol, rtol, hermitian); + } + + // aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_tensor_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_atol_rtol_tensor_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_float::redispatch(dispatchKeySet, self, atol, rtol, hermitian); + } + + // aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_pinv_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond, bool hermitian=false) { + return at::_ops::linalg_pinv::redispatch(dispatchKeySet, self, rcond, hermitian); + } + + // aten::linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor + inline at::Tensor linalg_pinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & rcond, bool hermitian=false) { + return at::_ops::linalg_pinv_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian); + } + + // aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double rcond, bool hermitian=false) { + return at::_ops::linalg_pinv_out::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double rcond, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_out::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & rcond, bool hermitian=false) { + return at::_ops::linalg_pinv_out_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_pinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & rcond, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_pinv_out_rcond_tensor::redispatch(dispatchKeySet, self, rcond, hermitian, out); + } + + // aten::_linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info) + inline ::std::tuple _linalg_solve_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::_linalg_solve_ex::redispatch(dispatchKeySet, A, B, left, check_errors); + } + + // aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) + inline ::std::tuple _linalg_solve_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::_linalg_solve_ex_result::redispatch(dispatchKeySet, A, B, left, check_errors, result, LU, pivots, info); + } + + // aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) + inline ::std::tuple _linalg_solve_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info) { + return at::_ops::_linalg_solve_ex_result::redispatch(dispatchKeySet, A, B, left, check_errors, result, LU, pivots, info); + } + + // aten::linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info) + inline ::std::tuple linalg_solve_ex(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::linalg_solve_ex::redispatch(dispatchKeySet, A, B, left, check_errors); + } + + // aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info) + inline ::std::tuple linalg_solve_ex_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & result, at::Tensor & info, const at::Tensor & A, const at::Tensor & B, bool left=true, bool check_errors=false) { + return at::_ops::linalg_solve_ex_out::redispatch(dispatchKeySet, A, B, left, check_errors, result, info); + } + + // aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info) + inline ::std::tuple linalg_solve_ex_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & info) { + return at::_ops::linalg_solve_ex_out::redispatch(dispatchKeySet, A, B, left, check_errors, result, info); + } + + // aten::linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor + inline at::Tensor linalg_solve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true) { + return at::_ops::linalg_solve::redispatch(dispatchKeySet, A, B, left); + } + + // aten::_spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor + inline at::Tensor _spsolve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left=true) { + return at::_ops::_spsolve::redispatch(dispatchKeySet, A, B, left); + } + + // aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_solve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & A, const at::Tensor & B, bool left=true) { + return at::_ops::linalg_solve_out::redispatch(dispatchKeySet, A, B, left, out); + } + + // aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_solve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, const at::Tensor & B, bool left, at::Tensor & out) { + return at::_ops::linalg_solve_out::redispatch(dispatchKeySet, A, B, left, out); + } + + // aten::linalg_tensorinv(Tensor self, int ind=2) -> Tensor + inline at::Tensor linalg_tensorinv(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t ind=2) { + return at::_ops::linalg_tensorinv::redispatch(dispatchKeySet, self, ind); + } + + // aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_tensorinv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t ind=2) { + return at::_ops::linalg_tensorinv_out::redispatch(dispatchKeySet, self, ind, out); + } + + // aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_tensorinv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t ind, at::Tensor & out) { + return at::_ops::linalg_tensorinv_out::redispatch(dispatchKeySet, self, ind, out); + } + + // aten::linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor + inline at::Tensor linalg_tensorsolve(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims=::std::nullopt) { + return at::_ops::linalg_tensorsolve::redispatch(dispatchKeySet, self, other, dims); + } + + // aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_tensorsolve_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims=::std::nullopt) { + return at::_ops::linalg_tensorsolve_out::redispatch(dispatchKeySet, self, other, dims, out); + } + + // aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_tensorsolve_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims, at::Tensor & out) { + return at::_ops::linalg_tensorsolve_out::redispatch(dispatchKeySet, self, other, dims, out); + } + + // aten::linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R) + inline ::std::tuple linalg_qr(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view mode="reduced") { + return at::_ops::linalg_qr::redispatch(dispatchKeySet, A, mode); + } + + // aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + inline ::std::tuple linalg_qr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & Q, at::Tensor & R, const at::Tensor & A, c10::string_view mode="reduced") { + return at::_ops::linalg_qr_out::redispatch(dispatchKeySet, A, mode, Q, R); + } + + // aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R) + inline ::std::tuple linalg_qr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & A, c10::string_view mode, at::Tensor & Q, at::Tensor & R) { + return at::_ops::linalg_qr_out::redispatch(dispatchKeySet, A, mode, Q, R); + } + + // aten::linalg_matrix_power(Tensor self, int n) -> Tensor + inline at::Tensor linalg_matrix_power(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n) { + return at::_ops::linalg_matrix_power::redispatch(dispatchKeySet, self, n); + } + + // aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_power_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t n) { + return at::_ops::linalg_matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_power_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t n, at::Tensor & out) { + return at::_ops::linalg_matrix_power_out::redispatch(dispatchKeySet, self, n, out); + } + + // aten::linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor::redispatch(dispatchKeySet, input, atol, rtol, hermitian); + } + + // aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const ::std::optional & atol={}, const ::std::optional & rtol={}, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor_out::redispatch(dispatchKeySet, input, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor_out::redispatch(dispatchKeySet, input, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_float::redispatch(dispatchKeySet, self, atol, rtol, hermitian); + } + + // aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_atol_rtol_float_out::redispatch(dispatchKeySet, self, atol, rtol, hermitian, out); + } + + // aten::linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank::redispatch(dispatchKeySet, self, tol, hermitian); + } + + // aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_out::redispatch(dispatchKeySet, self, tol, hermitian, out); + } + + // aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double tol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_out::redispatch(dispatchKeySet, self, tol, hermitian, out); + } + + // aten::linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor + inline at::Tensor linalg_matrix_rank(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian); + } + + // aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & tol, bool hermitian=false) { + return at::_ops::linalg_matrix_rank_out_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian, out); + } + + // aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_rank_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & tol, bool hermitian, at::Tensor & out) { + return at::_ops::linalg_matrix_rank_out_tol_tensor::redispatch(dispatchKeySet, input, tol, hermitian, out); + } + + // aten::linalg_multi_dot(Tensor[] tensors) -> Tensor + inline at::Tensor linalg_multi_dot(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::linalg_multi_dot::redispatch(dispatchKeySet, tensors); + } + + // aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_multi_dot_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::linalg_multi_dot_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_multi_dot_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::linalg_multi_dot_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor + inline at::Tensor nested_to_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=::std::nullopt) { + return at::_ops::nested_to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size); + } + + // aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor + inline at::Tensor _test_serialization_subcmul(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_test_serialization_subcmul::redispatch(dispatchKeySet, self, other, alpha); + } + + // aten::_test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor + inline at::Tensor _test_parallel_materialize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t num_parallel, bool skip_first=false) { + return at::_ops::_test_parallel_materialize::redispatch(dispatchKeySet, self, num_parallel, skip_first); + } + + // aten::_test_optional_intlist(Tensor values, int[]? addends) -> Tensor + inline at::Tensor _test_optional_intlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_intlist::redispatch(dispatchKeySet, values, addends); + } + + // aten::_test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor + inline at::Tensor _test_optional_filled_intlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_filled_intlist::redispatch(dispatchKeySet, values, addends); + } + + // aten::_test_optional_floatlist(Tensor values, float[]? addends) -> Tensor + inline at::Tensor _test_optional_floatlist(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, ::std::optional> addends) { + return at::_ops::_test_optional_floatlist::redispatch(dispatchKeySet, values, addends); + } + + // aten::_test_string_default(Tensor dummy, str a="\"'\\", str b='"\'\\') -> Tensor + inline at::Tensor _test_string_default(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, c10::string_view a="\"'\\", c10::string_view b="\"'\\") { + return at::_ops::_test_string_default::redispatch(dispatchKeySet, dummy, a, b); + } + + // aten::_test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor + inline at::Tensor _test_ambiguous_defaults(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, int64_t a=1, int64_t b=1) { + return at::_ops::_test_ambiguous_defaults_a::redispatch(dispatchKeySet, dummy, a, b); + } + + // aten::_test_ambiguous_defaults.b(Tensor dummy, int a=2, str b="2") -> Tensor + inline at::Tensor _test_ambiguous_defaults(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dummy, int64_t a, c10::string_view b) { + return at::_ops::_test_ambiguous_defaults_b::redispatch(dispatchKeySet, dummy, a, b); + } + + // aten::_test_warn_in_autograd(Tensor self) -> Tensor + inline at::Tensor _test_warn_in_autograd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_warn_in_autograd::redispatch(dispatchKeySet, self); + } + + // aten::_test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor + inline at::Tensor _test_autograd_multiple_dispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage::redispatch(dispatchKeySet, self); + } + + // aten::_test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor + inline at::Tensor _test_autograd_multiple_dispatch(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool b) { + return at::_ops::_test_autograd_multiple_dispatch_ntonly::redispatch(dispatchKeySet, self, b); + } + + // aten::_test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a) + inline at::Tensor _test_autograd_multiple_dispatch_view(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_view::redispatch(dispatchKeySet, self); + } + + // aten::_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor + inline at::Tensor _test_autograd_multiple_dispatch_view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy::redispatch(dispatchKeySet, self); + } + + // aten::segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor + inline at::Tensor segment_reduce(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & indices={}, const ::std::optional & offsets={}, int64_t axis=0, bool unsafe=false, const ::std::optional & initial=::std::nullopt) { + return at::_ops::segment_reduce::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial); + } + + // aten::_segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor + inline at::Tensor _segment_reduce_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & offsets={}, int64_t axis=0, const ::std::optional & initial=::std::nullopt) { + return at::_ops::_segment_reduce_backward::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial); + } + + // aten::pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side="right") -> Tensor + inline at::Tensor pad_sequence(c10::DispatchKeySet dispatchKeySet, at::TensorList sequences, bool batch_first=false, double padding_value=0.0, c10::string_view padding_side="right") { + return at::_ops::pad_sequence::redispatch(dispatchKeySet, sequences, batch_first, padding_value, padding_side); + } + + // aten::flatten_dense_tensors(Tensor[] tensors) -> Tensor + inline at::Tensor flatten_dense_tensors(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors) { + return at::_ops::flatten_dense_tensors::redispatch(dispatchKeySet, tensors); + } + + // aten::unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[] + inline ::std::vector unflatten_dense_tensors(c10::DispatchKeySet dispatchKeySet, const at::Tensor & flat, at::TensorList tensors) { + return at::_ops::unflatten_dense_tensors::redispatch(dispatchKeySet, flat, tensors); + } + + // aten::_nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor + inline at::Tensor _nested_tensor_from_tensor_list(c10::DispatchKeySet dispatchKeySet, at::TensorList list, ::std::optional dtype=::std::nullopt, ::std::optional layout=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional pin_memory=::std::nullopt) { + return at::_ops::_nested_tensor_from_tensor_list::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory); + } + + // aten::_fw_primal_copy(Tensor self, int level) -> Tensor + inline at::Tensor _fw_primal_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level) { + return at::_ops::_fw_primal_copy::redispatch(dispatchKeySet, self, level); + } + + // aten::_make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor + inline at::Tensor _make_dual_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + return at::_ops::_make_dual_copy::redispatch(dispatchKeySet, primal, tangent, level); + } + + // aten::view_as_real_copy(Tensor self) -> Tensor + inline at::Tensor view_as_real_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_real_copy::redispatch(dispatchKeySet, self); + } + + // aten::view_as_complex_copy(Tensor self) -> Tensor + inline at::Tensor view_as_complex_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::view_as_complex_copy::redispatch(dispatchKeySet, self); + } + + // aten::_conj_copy(Tensor self) -> Tensor + inline at::Tensor _conj_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_conj_copy::redispatch(dispatchKeySet, self); + } + + // aten::_neg_view_copy(Tensor self) -> Tensor + inline at::Tensor _neg_view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_neg_view_copy::redispatch(dispatchKeySet, self); + } + + // aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + inline at::Tensor as_strided_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt); + } + + // aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor + inline at::Tensor as_strided_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy::redispatch(dispatchKeySet, self, size, stride, storage_offset); + } + + // aten::_sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor + inline at::Tensor _sparse_broadcast_to_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_sparse_broadcast_to_copy::redispatch(dispatchKeySet, self, size); + } + + // aten::diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor + inline at::Tensor diagonal_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_copy::redispatch(dispatchKeySet, self, offset, dim1, dim2); + } + + // aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor + inline at::Tensor expand_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit); + } + + // aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor + inline at::Tensor expand_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy::redispatch(dispatchKeySet, self, size, implicit); + } + + // aten::permute_copy(Tensor self, int[] dims) -> Tensor + inline at::Tensor permute_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::permute_copy::redispatch(dispatchKeySet, self, dims); + } + + // aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor + inline at::Tensor _reshape_alias_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::_reshape_alias_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor + inline at::Tensor _reshape_alias_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::_reshape_alias_copy::redispatch(dispatchKeySet, self, size, stride); + } + + // aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor select_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::select_copy_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor + inline at::Tensor select_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::select_copy_int::redispatch(dispatchKeySet, self, dim, index); + } + + // aten::detach_copy(Tensor self) -> Tensor + inline at::Tensor detach_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::detach_copy::redispatch(dispatchKeySet, self); + } + + // aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_copy_Tensor::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step); + } + + // aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor + inline at::Tensor slice_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_copy_Tensor::redispatch(dispatchKeySet, self, dim, start, end, step); + } + + // aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector split_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[] + inline ::std::vector split_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor::redispatch(dispatchKeySet, self, split_size, dim); + } + + // aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + inline ::std::vector split_with_sizes_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim); + } + + // aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[] + inline ::std::vector split_with_sizes_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy::redispatch(dispatchKeySet, self, split_sizes, dim); + } + + // aten::squeeze_copy(Tensor self) -> Tensor + inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::squeeze_copy::redispatch(dispatchKeySet, self); + } + + // aten::squeeze_copy.dim(Tensor self, int dim) -> Tensor + inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::squeeze_copy_dim::redispatch(dispatchKeySet, self, dim); + } + + // aten::squeeze_copy.dims(Tensor self, int[] dim) -> Tensor + inline at::Tensor squeeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze_copy_dims::redispatch(dispatchKeySet, self, dim); + } + + // aten::t_copy(Tensor self) -> Tensor + inline at::Tensor t_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::t_copy::redispatch(dispatchKeySet, self); + } + + // aten::transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor + inline at::Tensor transpose_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_copy_int::redispatch(dispatchKeySet, self, dim0, dim1); + } + + // aten::unsqueeze_copy(Tensor self, int dim) -> Tensor + inline at::Tensor unsqueeze_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze_copy::redispatch(dispatchKeySet, self, dim); + } + + // aten::_indices_copy(Tensor self) -> Tensor + inline at::Tensor _indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::_values_copy(Tensor self) -> Tensor + inline at::Tensor _values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::_values_copy::redispatch(dispatchKeySet, self); + } + + // aten::indices_copy(Tensor self) -> Tensor + inline at::Tensor indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::values_copy(Tensor self) -> Tensor + inline at::Tensor values_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::values_copy::redispatch(dispatchKeySet, self); + } + + // aten::crow_indices_copy(Tensor self) -> Tensor + inline at::Tensor crow_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::crow_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::col_indices_copy(Tensor self) -> Tensor + inline at::Tensor col_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::col_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::ccol_indices_copy(Tensor self) -> Tensor + inline at::Tensor ccol_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::ccol_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::row_indices_copy(Tensor self) -> Tensor + inline at::Tensor row_indices_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::row_indices_copy::redispatch(dispatchKeySet, self); + } + + // aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[] + inline ::std::vector unbind_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim=0) { + return at::_ops::unbind_copy_int::redispatch(dispatchKeySet, self, dim); + } + + // aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> () + inline void unbind_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t dim=0) { + return at::_ops::unbind_copy_int_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> () + inline void unbind_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::TensorList out) { + return at::_ops::unbind_copy_int_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim, at::TensorList out) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + return at::_ops::split_copy_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_with_sizes_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_with_sizes_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_with_sizes_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void split_with_sizes_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::split_with_sizes_copy_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::view_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::view_copy::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size)); + } + + // aten::view_copy(Tensor self, SymInt[] size) -> Tensor + inline at::Tensor view_copy_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::view_copy::redispatch(dispatchKeySet, self, size); + } + + // aten::view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor + inline at::Tensor view_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::view_copy_dtype::redispatch(dispatchKeySet, self, dtype); + } + + // aten::unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor + inline at::Tensor unfold_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + return at::_ops::unfold_copy::redispatch(dispatchKeySet, self, dimension, size, step); + } + + // aten::alias_copy(Tensor self) -> Tensor + inline at::Tensor alias_copy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::alias_copy::redispatch(dispatchKeySet, self); + } + + // aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + inline at::Tensor to_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt); + } + + // aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor + inline at::Tensor to_padded_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor::redispatch(dispatchKeySet, self, padding, output_size); + } + + // aten::_jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + inline at::Tensor _jagged_to_padded_dense_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::TensorList offsets, at::IntArrayRef max_lengths, double padding_value=0.0) { + return at::_ops::_jagged_to_padded_dense_forward::redispatch(dispatchKeySet, values, offsets, c10::fromIntArrayRefSlow(max_lengths), padding_value); + } + + // aten::_jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor + inline at::Tensor _jagged_to_padded_dense_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::TensorList offsets, c10::SymIntArrayRef max_lengths, double padding_value=0.0) { + return at::_ops::_jagged_to_padded_dense_forward::redispatch(dispatchKeySet, values, offsets, max_lengths, padding_value); + } + + // aten::_padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor + inline at::Tensor _padded_dense_to_jagged_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L=::std::nullopt) { + return at::_ops::_padded_dense_to_jagged_forward::redispatch(dispatchKeySet, dense, offsets, total_L.has_value() ? ::std::make_optional(c10::SymInt(*total_L)) : ::std::nullopt); + } + + // aten::_padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor + inline at::Tensor _padded_dense_to_jagged_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L=::std::nullopt) { + return at::_ops::_padded_dense_to_jagged_forward::redispatch(dispatchKeySet, dense, offsets, total_L); + } + + // aten::_nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + inline at::Tensor _nested_from_padded_tensor(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}, ::std::optional sum_S=::std::nullopt) { + return at::_ops::_nested_from_padded_tensor::redispatch(dispatchKeySet, padded, offsets, dummy, ragged_idx, min_seqlen, max_seqlen, sum_S.has_value() ? ::std::make_optional(c10::SymInt(*sum_S)) : ::std::nullopt); + } + + // aten::_nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor + inline at::Tensor _nested_from_padded_tensor_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}, ::std::optional sum_S=::std::nullopt) { + return at::_ops::_nested_from_padded_tensor::redispatch(dispatchKeySet, padded, offsets, dummy, ragged_idx, min_seqlen, max_seqlen, sum_S); + } + + // aten::_nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor + inline at::Tensor _nested_tensor_softmax_with_shape(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & query) { + return at::_ops::_nested_tensor_softmax_with_shape::redispatch(dispatchKeySet, self, query); + } + + // aten::_safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor + inline at::Tensor _safe_softmax(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional dtype=::std::nullopt) { + return at::_ops::_safe_softmax::redispatch(dispatchKeySet, self, dim, dtype); + } + + // aten::_transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor + inline at::Tensor _transformer_encoder_layer_fwd(c10::DispatchKeySet dispatchKeySet, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask={}, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_transformer_encoder_layer_fwd::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type); + } + + // aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor) + inline ::std::tuple _native_multi_head_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}, bool need_weights=true, bool average_attn_weights=true, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_native_multi_head_attention::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type); + } + + // aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor + inline at::Tensor scaled_dot_product_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, ::std::optional scale=::std::nullopt, bool enable_gqa=false) { + return at::_ops::scaled_dot_product_attention::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); + } + + // aten::_fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int + inline int64_t _fused_sdp_choice(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, ::std::optional scale=::std::nullopt, bool enable_gqa=false) { + return at::_ops::_fused_sdp_choice::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); + } + + // aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_attention_math(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, const ::std::optional & dropout_mask={}, ::std::optional scale=::std::nullopt, bool enable_gqa=false) { + return at::_ops::_scaled_dot_product_attention_math::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale, enable_gqa); + } + + // aten::_scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_attention_math_for_mps(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask={}, double dropout_p=0.0, bool is_causal=false, const ::std::optional & dropout_mask={}, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_attention_math_for_mps::redispatch(dispatchKeySet, query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale); + } + + // aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + inline ::std::tuple _scaled_dot_product_flash_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention::redispatch(dispatchKeySet, query, key, value, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp) + inline ::std::tuple _scaled_dot_product_flash_attention_for_cpu(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p=0.0, bool is_causal=false, const ::std::optional & attn_mask={}, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu::redispatch(dispatchKeySet, query, key, value, dropout_p, is_causal, attn_mask, scale); + } + + // aten::_scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _scaled_dot_product_fused_attention_overrideable(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias={}, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable::redispatch(dispatchKeySet, query, key, value, attn_bias, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + inline ::std::tuple _scaled_dot_product_flash_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + inline ::std::tuple _scaled_dot_product_flash_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // aten::_scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value) + inline ::std::tuple _scaled_dot_product_flash_attention_for_cpu_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, double dropout_p, bool is_causal, const ::std::optional & attn_mask={}, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, attn_mask, scale); + } + + // aten::_scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias) + inline ::std::tuple _scaled_dot_product_fused_attention_overrideable_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable_backward::redispatch(dispatchKeySet, grad_out, query, key, value, attn_bias, grad_input_mask, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // aten::_scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias) + inline ::std::tuple _scaled_dot_product_fused_attention_overrideable_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable_backward::redispatch(dispatchKeySet, grad_out, query, key, value, attn_bias, grad_input_mask, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + + // aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset) + inline ::std::tuple _scaled_dot_product_efficient_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_efficient_attention::redispatch(dispatchKeySet, query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, scale); + } + + // aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_efficient_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array grad_input_mask, bool is_causal=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, attn_bias, out, logsumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale); + } + + // aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _scaled_dot_product_cudnn_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_cudnn_attention::redispatch(dispatchKeySet, query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_cudnn_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_cudnn_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + + // aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _scaled_dot_product_cudnn_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale=::std::nullopt) { + return at::_ops::_scaled_dot_product_cudnn_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + + // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + inline ::std::tuple _flash_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt, const ::std::optional & seqused_k={}, const ::std::optional & alibi_slopes={}) { + return at::_ops::_flash_attention_forward::redispatch(dispatchKeySet, query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left.has_value() ? ::std::make_optional(c10::SymInt(*window_size_left)) : ::std::nullopt, window_size_right.has_value() ? ::std::make_optional(c10::SymInt(*window_size_right)) : ::std::nullopt, seqused_k, alibi_slopes); + } + + // aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask) + inline ::std::tuple _flash_attention_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt, const ::std::optional & seqused_k={}, const ::std::optional & alibi_slopes={}) { + return at::_ops::_flash_attention_forward::redispatch(dispatchKeySet, query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left, window_size_right, seqused_k, alibi_slopes); + } + + // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _flash_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, int64_t max_q, int64_t max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { + return at::_ops::_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left.has_value() ? ::std::make_optional(c10::SymInt(*window_size_left)) : ::std::nullopt, window_size_right.has_value() ? ::std::make_optional(c10::SymInt(*window_size_right)) : ::std::nullopt); + } + + // aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor) + inline ::std::tuple _flash_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale=::std::nullopt, ::std::optional window_size_left=::std::nullopt, ::std::optional window_size_right=::std::nullopt) { + return at::_ops::_flash_attention_backward::redispatch(dispatchKeySet, grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right); + } + + // aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + inline ::std::tuple _efficient_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, ::std::optional max_seqlen_q, ::std::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp=false, ::std::optional scale=::std::nullopt, const ::std::optional & seqlen_k={}, ::std::optional window_size=::std::nullopt) { + return at::_ops::_efficient_attention_forward::redispatch(dispatchKeySet, query, key, value, bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q.has_value() ? ::std::make_optional(c10::SymInt(*max_seqlen_q)) : ::std::nullopt, max_seqlen_k.has_value() ? ::std::make_optional(c10::SymInt(*max_seqlen_k)) : ::std::nullopt, dropout_p, custom_mask_type, compute_log_sumexp, scale, seqlen_k, window_size); + } + + // aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k) + inline ::std::tuple _efficient_attention_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, ::std::optional max_seqlen_q, ::std::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp=false, ::std::optional scale=::std::nullopt, const ::std::optional & seqlen_k={}, ::std::optional window_size=::std::nullopt) { + return at::_ops::_efficient_attention_forward::redispatch(dispatchKeySet, query, key, value, bias, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, custom_mask_type, compute_log_sumexp, scale, seqlen_k, window_size); + } + + // aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _efficient_attention_backward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, int64_t max_seqlen_q, int64_t max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale=::std::nullopt, ::std::optional num_splits_key=::std::nullopt, ::std::optional window_size=::std::nullopt, bool shared_storage_dqdkdv=false) { + return at::_ops::_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + } + + // aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor) + inline ::std::tuple _efficient_attention_backward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale=::std::nullopt, ::std::optional num_splits_key=::std::nullopt, ::std::optional window_size=::std::nullopt, bool shared_storage_dqdkdv=false) { + return at::_ops::_efficient_attention_backward::redispatch(dispatchKeySet, grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + } + + // aten::_cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _cudnn_attention_forward(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, int64_t max_q, int64_t max_k, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_cudnn_attention_forward::redispatch(dispatchKeySet, query, key, value, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask) + inline ::std::tuple _cudnn_attention_forward_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, bool compute_log_sumexp, double dropout_p=0.0, bool is_causal=false, bool return_debug_mask=false, ::std::optional scale=::std::nullopt) { + return at::_ops::_cudnn_attention_forward::redispatch(dispatchKeySet, query, key, value, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + + // aten::_triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor + inline at::Tensor _triton_scaled_dot_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0) { + return at::_ops::_triton_scaled_dot_attention::redispatch(dispatchKeySet, q, k, v, dropout_p); + } + + // aten::_fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!) + inline at::Tensor & _fill_mem_eff_dropout_mask_(c10::DispatchKeySet dispatchKeySet, at::Tensor & self, double dropout_p, int64_t seed, int64_t offset) { + return at::_ops::_fill_mem_eff_dropout_mask_::redispatch(dispatchKeySet, self, dropout_p, seed, offset); + } + + // aten::_triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor + inline at::Tensor _triton_multi_head_attention(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}) { + return at::_ops::_triton_multi_head_attention::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask); + } + + // aten::special_airy_ai(Tensor x) -> Tensor + inline at::Tensor special_airy_ai(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_airy_ai::redispatch(dispatchKeySet, x); + } + + // aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_airy_ai_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_airy_ai_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_airy_ai_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_airy_ai_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_bessel_j0(Tensor self) -> Tensor + inline at::Tensor special_bessel_j0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_j0::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_j0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_j0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_j0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_j0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_j1(Tensor self) -> Tensor + inline at::Tensor special_bessel_j1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_j1::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_j1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_j1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_j1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_j1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y0(Tensor self) -> Tensor + inline at::Tensor special_bessel_y0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_y0::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_y0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_y0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_y0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_y0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y1(Tensor self) -> Tensor + inline at::Tensor special_bessel_y1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_bessel_y1::redispatch(dispatchKeySet, self); + } + + // aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_y1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_bessel_y1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_bessel_y1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_bessel_y1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_t_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_hermite_polynomial_h(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_h_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_h_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_h_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_h_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_h_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_h_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_h_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_hermite_polynomial_he(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_he_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_he_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_hermite_polynomial_he_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_he_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_hermite_polynomial_he_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_hermite_polynomial_he_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_hermite_polynomial_he_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l::redispatch(dispatchKeySet, x, n); + } + + // aten::special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_laguerre_polynomial_l(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_laguerre_polynomial_l_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_laguerre_polynomial_l_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_laguerre_polynomial_l_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_laguerre_polynomial_l_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_laguerre_polynomial_l_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_laguerre_polynomial_l_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_laguerre_polynomial_l_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p::redispatch(dispatchKeySet, x, n); + } + + // aten::special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_legendre_polynomial_p(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_legendre_polynomial_p_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_legendre_polynomial_p_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_legendre_polynomial_p_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_legendre_polynomial_p_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_legendre_polynomial_p_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_legendre_polynomial_p_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_legendre_polynomial_p_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_modified_bessel_i0(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_i0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i0::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_i0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_i0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_i0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_i1(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_i1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i1::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_i1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_i1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_i1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k0(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_k0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k0::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_k0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_k0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_k0_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k1(Tensor self) -> Tensor + inline at::Tensor special_modified_bessel_k1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k1::redispatch(dispatchKeySet, self); + } + + // aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_k1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::special_modified_bessel_k1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_modified_bessel_k1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::special_modified_bessel_k1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::special_scaled_modified_bessel_k0(Tensor x) -> Tensor + inline at::Tensor special_scaled_modified_bessel_k0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k0::redispatch(dispatchKeySet, x); + } + + // aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_scaled_modified_bessel_k0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_scaled_modified_bessel_k0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_scaled_modified_bessel_k0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_scaled_modified_bessel_k1(Tensor x) -> Tensor + inline at::Tensor special_scaled_modified_bessel_k1(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k1::redispatch(dispatchKeySet, x); + } + + // aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_scaled_modified_bessel_k1_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_scaled_modified_bessel_k1_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_scaled_modified_bessel_k1_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_scaled_modified_bessel_k1_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_t(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_t_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_t_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_u(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_u_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_u_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_v(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_v_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_v_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor + inline at::Tensor special_shifted_chebyshev_polynomial_w(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar::redispatch(dispatchKeySet, x, n); + } + + // aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_w_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & x, const at::Tensor & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & x, const at::Tensor & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Scalar & n) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_shifted_chebyshev_polynomial_w_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Scalar & n, at::Tensor & out) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar_out::redispatch(dispatchKeySet, x, n, out); + } + + // aten::special_spherical_bessel_j0(Tensor x) -> Tensor + inline at::Tensor special_spherical_bessel_j0(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x) { + return at::_ops::special_spherical_bessel_j0::redispatch(dispatchKeySet, x); + } + + // aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_spherical_bessel_j0_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x) { + return at::_ops::special_spherical_bessel_j0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & special_spherical_bessel_j0_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, at::Tensor & out) { + return at::_ops::special_spherical_bessel_j0_out::redispatch(dispatchKeySet, x, out); + } + + // aten::_foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor + inline at::Tensor _foobar(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool arg1=true, bool arg2=true, bool arg3=true) { + return at::_ops::_foobar::redispatch(dispatchKeySet, self, arg1, arg2, arg3); + } + + // aten::_fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adam_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adam_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam__tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adamw_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adamw_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw__tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_sgd_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_sgd_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd__tensor_lr::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adagrad_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + + // aten::_fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> () + inline void _fused_adagrad_(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad__tensor_lr::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + + // aten::_propagate_xla_data(Tensor input, Tensor output) -> () + inline void _propagate_xla_data(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & output) { + return at::_ops::_propagate_xla_data::redispatch(dispatchKeySet, input, output); + } + + // aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _new_zeros_with_same_feature_meta_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims=0) { + return at::_ops::_new_zeros_with_same_feature_meta_out::redispatch(dispatchKeySet, self, other, self_num_batch_dims, out); + } + + // aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _new_zeros_with_same_feature_meta_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims, at::Tensor & out) { + return at::_ops::_new_zeros_with_same_feature_meta_out::redispatch(dispatchKeySet, self, other, self_num_batch_dims, out); + } + + // aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _cudnn_ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + return at::_ops::_cudnn_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity, out0, out1); + } + + // aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _cudnn_ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_cudnn_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity, out0, out1); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_rnn_flatten_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_rnn_flatten_weight_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, int64_t input_size, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_rnn_flatten_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_rnn_flatten_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out) { + return at::_ops::_cudnn_rnn_flatten_weight_out::redispatch(dispatchKeySet, weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional, out); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _cudnn_rnn_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::_cudnn_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, c10::fromIntArrayRefSlow(batch_sizes), dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> () + inline void _cudnn_rnn_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + return at::_ops::_cudnn_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_init_dropout_state_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, double dropout, bool train, int64_t dropout_seed) { + return at::_ops::_cudnn_init_dropout_state_out::redispatch(dispatchKeySet, dropout, train, dropout_seed, out); + } + + // aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cudnn_init_dropout_state_outf(c10::DispatchKeySet dispatchKeySet, double dropout, bool train, int64_t dropout_seed, at::Tensor & out) { + return at::_ops::_cudnn_init_dropout_state_out::redispatch(dispatchKeySet, dropout, train, dropout_seed, out); + } + + // aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fused_dropout_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::_fused_dropout_out::redispatch(dispatchKeySet, self, p, generator, out0, out1); + } + + // aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fused_dropout_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_fused_dropout_out::redispatch(dispatchKeySet, self, p, generator, out0, out1); + } + + // aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_scale_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, double scale) { + return at::_ops::_masked_scale_out::redispatch(dispatchKeySet, self, mask, scale, out); + } + + // aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_scale_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, double scale, at::Tensor & out) { + return at::_ops::_masked_scale_out::redispatch(dispatchKeySet, self, mask, scale, out); + } + + // aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple native_dropout_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double p, ::std::optional train) { + return at::_ops::native_dropout_out::redispatch(dispatchKeySet, input, p, train, out0, out1); + } + + // aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple native_dropout_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double p, ::std::optional train, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::native_dropout_out::redispatch(dispatchKeySet, input, p, train, out0, out1); + } + + // aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_dropout_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & mask, double scale) { + return at::_ops::native_dropout_backward_out::redispatch(dispatchKeySet, grad_output, mask, scale, out); + } + + // aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_dropout_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & mask, double scale, at::Tensor & out) { + return at::_ops::native_dropout_backward_out::redispatch(dispatchKeySet, grad_output, mask, scale, out); + } + + // aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_physical_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_physical_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_conj_physical_out::redispatch(dispatchKeySet, self, out); + } + + // aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, bool ceil_mode=false, bool count_include_pad=true) { + return at::_ops::avg_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, out); + } + + // aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & avg_pool1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out) { + return at::_ops::avg_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, ceil_mode, count_include_pad, out); + } + + // aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::adaptive_avg_pool1d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & adaptive_avg_pool1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::adaptive_avg_pool1d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::_add_relu_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::_add_relu_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::add_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & add_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::add_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & affine_grid_generator_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, at::IntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & affine_grid_generator_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, at::IntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, c10::fromIntArrayRefSlow(size), align_corners, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & affine_grid_generator_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, size, align_corners, out); + } + + // aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & affine_grid_generator_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out) { + return at::_ops::affine_grid_generator_out::redispatch(dispatchKeySet, theta, size, align_corners, out); + } + + // aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_functorch_fallback_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_test_functorch_fallback_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_functorch_fallback_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::_test_functorch_fallback_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bartlett_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::bartlett_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bartlett_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::bartlett_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bartlett_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::bartlett_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bartlett_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::bartlett_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) { + return at::_ops::quantized_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point, out); + } + + // aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point, at::Tensor & out) { + return at::_ops::quantized_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, mean, var, eps, output_scale, output_zero_point, out); + } + + // aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_Tensor_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & p, ::std::optional generator, at::Tensor & out) { + return at::_ops::bernoulli_Tensor_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::bernoulli.Tensor(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor + inline at::Tensor bernoulli(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & p, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_Tensor::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p=0.5, ::std::optional generator=::std::nullopt) { + return at::_ops::bernoulli_float_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bernoulli_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out) { + return at::_ops::bernoulli_float_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_with_logits_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight={}, const ::std::optional & pos_weight={}, int64_t reduction=at::Reduction::Mean) { + return at::_ops::binary_cross_entropy_with_logits_out::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction, out); + } + + // aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binary_cross_entropy_with_logits_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction, at::Tensor & out) { + return at::_ops::binary_cross_entropy_with_logits_out::redispatch(dispatchKeySet, self, target, weight, pos_weight, reduction, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bincount_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & weights={}, int64_t minlength=0) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bincount_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights, int64_t minlength, at::Tensor & out) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bincount_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & weights={}, c10::SymInt minlength=0) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bincount_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength, at::Tensor & out) { + return at::_ops::bincount_out::redispatch(dispatchKeySet, self, weights, minlength, out); + } + + // aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & blackman_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::blackman_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & blackman_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::blackman_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & blackman_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::blackman_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & blackman_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::blackman_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & block_diag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList tensors) { + return at::_ops::block_diag_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & block_diag_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::Tensor & out) { + return at::_ops::block_diag_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & constant_pad_nd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & constant_pad_nd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef pad, const at::Scalar & value, at::Tensor & out) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(pad), value, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & constant_pad_nd_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value=0) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, pad, value, out); + } + + // aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & constant_pad_nd_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value, at::Tensor & out) { + return at::_ops::constant_pad_nd_out::redispatch(dispatchKeySet, self, pad, value, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, at::Tensor & out) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out) { + return at::_ops::convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : ::std::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalIntArrayRef bias_sizes, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*bias_sizes)) : ::std::nullopt, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_overrideable_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_overrideable_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, at::Tensor & out) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, out); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_overrideable_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & convolution_overrideable_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out) { + return at::_ops::convolution_overrideable_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, out); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple convolution_backward_overrideable_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::convolution_backward_overrideable_out::redispatch(dispatchKeySet, grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask, out0, out1, out2); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), transposed, c10::fromIntArrayRefSlow(output_padding), groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out) { + return at::_ops::_convolution_out::redispatch(dispatchKeySet, input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32, out); + } + + // aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_tbc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad=0) { + return at::_ops::conv_tbc_out::redispatch(dispatchKeySet, self, weight, bias, pad, out); + } + + // aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_tbc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad, at::Tensor & out) { + return at::_ops::conv_tbc_out::redispatch(dispatchKeySet, self, weight, bias, pad, out); + } + + // aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out) { + return at::_ops::copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & dst, bool non_blocking=false) { + return at::_ops::_copy_from_out::redispatch(dispatchKeySet, self, dst, non_blocking, out); + } + + // aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, bool non_blocking, at::Tensor & out) { + return at::_ops::_copy_from_out::redispatch(dispatchKeySet, self, dst, non_blocking, out); + } + + // aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_and_resize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & dst) { + return at::_ops::_copy_from_and_resize_out::redispatch(dispatchKeySet, self, dst, out); + } + + // aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _copy_from_and_resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & dst, at::Tensor & out) { + return at::_ops::_copy_from_and_resize_out::redispatch(dispatchKeySet, self, dst, out); + } + + // aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & count_nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::count_nonzero_dim_IntList_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & count_nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::count_nonzero_dim_IntList_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & count_nonzero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dim=::std::nullopt) { + return at::_ops::count_nonzero_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & count_nonzero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dim, at::Tensor & out) { + return at::_ops::count_nonzero_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_affine_grid_generator_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator_out::redispatch(dispatchKeySet, theta, N, C, H, W, out); + } + + // aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_affine_grid_generator_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out) { + return at::_ops::cudnn_affine_grid_generator_out::redispatch(dispatchKeySet, theta, N, C, H, W, out); + } + + // aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_affine_grid_generator_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) { + return at::_ops::cudnn_affine_grid_generator_backward_out::redispatch(dispatchKeySet, grad, N, C, H, W, out); + } + + // aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_affine_grid_generator_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out) { + return at::_ops::cudnn_affine_grid_generator_backward_out::redispatch(dispatchKeySet, grad, N, C, H, W, out); + } + + // aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple cudnn_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::cudnn_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2, out3); + } + + // aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple cudnn_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::cudnn_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2, out3); + } + + // aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple cudnn_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace) { + return at::_ops::cudnn_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace, out0, out1, out2); + } + + // aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple cudnn_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::cudnn_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace, out0, out1, out2); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out) { + return at::_ops::cudnn_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, out); + } + + // aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::_mps_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, padding, output_padding, stride, dilation, groups, out); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask, out0, out1); + } + + // aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mps_convolution_transpose_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::mps_convolution_transpose_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask, out0, out1); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_relu_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_relu_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_relu_out::redispatch(dispatchKeySet, self, weight, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_add_relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_add_relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_add_relu_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_convolution_add_relu_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::cudnn_convolution_add_relu_out::redispatch(dispatchKeySet, self, weight, z, alpha, bias, stride, padding, dilation, groups, out); + } + + // aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_grid_sampler_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & grid) { + return at::_ops::cudnn_grid_sampler_out::redispatch(dispatchKeySet, self, grid, out); + } + + // aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cudnn_grid_sampler_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, at::Tensor & out) { + return at::_ops::cudnn_grid_sampler_out::redispatch(dispatchKeySet, self, grid, out); + } + + // aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple cudnn_grid_sampler_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) { + return at::_ops::cudnn_grid_sampler_backward_out::redispatch(dispatchKeySet, self, grid, grad_output, out0, out1); + } + + // aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple cudnn_grid_sampler_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::cudnn_grid_sampler_backward_out::redispatch(dispatchKeySet, self, grid, grad_output, out0, out1); + } + + // aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_ctc_loss_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank=0, bool zero_infinity=false) { + return at::_ops::_ctc_loss_Tensor_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _ctc_loss_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_ctc_loss_Tensor_out::redispatch(dispatchKeySet, log_probs, targets, input_lengths, target_lengths, blank, zero_infinity, out0, out1); + } + + // aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _ctc_loss_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity=false) { + return at::_ops::_ctc_loss_backward_out::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity, out); + } + + // aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _ctc_loss_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity, at::Tensor & out) { + return at::_ops::_ctc_loss_backward_out::redispatch(dispatchKeySet, grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity, out); + } + + // aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diag_embed_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=-2, int64_t dim2=-1) { + return at::_ops::diag_embed_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diag_embed_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diag_embed_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2, out); + } + + // aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, offset, dim1, dim2, out); + } + + // aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::div_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::div_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + return at::_ops::div_Scalar_mode_out::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & div_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode, at::Tensor & out) { + return at::_ops::div_Scalar_mode_out::redispatch(dispatchKeySet, self, other, rounding_mode, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx=-1, bool scale_grad_by_freq=false, bool sparse=false) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out) { + return at::_ops::embedding_out::redispatch(dispatchKeySet, weight, indices, padding_idx, scale_grad_by_freq, sparse, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_dense_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_dense_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq, at::Tensor & out) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_dense_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_dense_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, at::Tensor & out) { + return at::_ops::embedding_dense_backward_out::redispatch(dispatchKeySet, grad_output, indices, num_weights, padding_idx, scale_grad_by_freq, out); + } + + // aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_renorm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + return at::_ops::embedding_renorm_out::redispatch(dispatchKeySet, self, indices, max_norm, norm_type, out); + } + + // aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & embedding_renorm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type, at::Tensor & out) { + return at::_ops::embedding_renorm_out::redispatch(dispatchKeySet, self, indices, max_norm, norm_type, out); + } + + // aten::embedding_renorm(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor + inline at::Tensor embedding_renorm(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + return at::_ops::embedding_renorm::redispatch(dispatchKeySet, self, indices, max_norm, norm_type); + } + + // aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_forward_only_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_forward_only_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_forward_only_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::_embedding_bag_forward_only_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq=false, int64_t mode=0, bool sparse=false, const ::std::optional & per_sample_weights={}, bool include_last_offset=false, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _embedding_bag_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::_embedding_bag_out::redispatch(dispatchKeySet, weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx, out0, out1, out2, out3); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_dense_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_dense_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, int64_t num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx, at::Tensor & out) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_dense_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_dense_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx, at::Tensor & out) { + return at::_ops::_embedding_bag_dense_backward_out::redispatch(dispatchKeySet, grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx, out); + } + + // aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_per_sample_weights_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx=-1) { + return at::_ops::_embedding_bag_per_sample_weights_backward_out::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx, out); + } + + // aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _embedding_bag_per_sample_weights_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx, at::Tensor & out) { + return at::_ops::_embedding_bag_per_sample_weights_backward_out::redispatch(dispatchKeySet, grad, weight, indices, offsets, offset2bag, mode, padding_idx, out); + } + + // aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_names_out::redispatch(dispatchKeySet, size, names, memory_format, out); + } + + // aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_names_out::redispatch(dispatchKeySet, size, names, memory_format, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_permuted_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, at::IntArrayRef physical_layout) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_permuted_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), physical_layout, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_permuted_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, at::IntArrayRef physical_layout) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, size, physical_layout, out); + } + + // aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_permuted_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out) { + return at::_ops::empty_permuted_out::redispatch(dispatchKeySet, size, physical_layout, out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::new_empty_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_strided_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_strided_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_strided_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_empty_strided_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::new_empty_strided_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), fill_value, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_full_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out) { + return at::_ops::new_full_out::redispatch(dispatchKeySet, self, size, fill_value, out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_zeros_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_zeros_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_zeros_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::new_zeros_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_ones_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_ones_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & new_ones_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::new_ones_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_affine_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scale, zero_point, memory_format, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_affine_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, double scale, int64_t zero_point, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scale, zero_point, memory_format, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_affine_quantized_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, double scale=1, int64_t zero_point=0, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, size, scale, zero_point, memory_format, out); + } + + // aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_affine_quantized_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, double scale, int64_t zero_point, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_affine_quantized_out::redispatch(dispatchKeySet, size, scale, zero_point, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_per_channel_affine_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_per_channel_affine_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), scales, zero_points, axis, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_per_channel_affine_quantized_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format=c10::MemoryFormat::Contiguous) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, size, scales, zero_points, axis, memory_format, out); + } + + // aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _empty_per_channel_affine_quantized_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_empty_per_channel_affine_quantized_out::redispatch(dispatchKeySet, size, scales, zero_points, axis, memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format, const at::Tensor & out) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, size, memory_format, out); + } + + // aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format, const at::Tensor & out) { + return at::_ops::resize_out::redispatch(dispatchKeySet, self, size, memory_format, out); + } + + // aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), memory_format); + } + + // aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor resize_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize::redispatch(dispatchKeySet, self, size, memory_format); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & _resize_output_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::Device device) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device, out); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & _resize_output_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device, const at::Tensor & out) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device, out); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & _resize_output_symint_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, size, device, out); + } + + // aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & _resize_output_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device, const at::Tensor & out) { + return at::_ops::_resize_output_out::redispatch(dispatchKeySet, self, size, device, out); + } + + // aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor + inline at::Tensor _resize_output(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Device device) { + return at::_ops::_resize_output::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), device); + } + + // aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor + inline at::Tensor _resize_output_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + return at::_ops::_resize_output::redispatch(dispatchKeySet, self, size, device); + } + + // aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_quantized_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_quantized_out::redispatch(dispatchKeySet, size, qtensor, memory_format, out); + } + + // aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_quantized_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_quantized_out::redispatch(dispatchKeySet, size, qtensor, memory_format, out); + } + + // aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::empty_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::empty_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_strided_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_strided_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_strided_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, size, stride, out); + } + + // aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & empty_strided_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::empty_strided_out::redispatch(dispatchKeySet, size, stride, out); + } + + // aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & value) { + return at::_ops::fill_Scalar_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & value, at::Tensor & out) { + return at::_ops::fill_Scalar_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & value) { + return at::_ops::fill_Tensor_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & value, at::Tensor & out) { + return at::_ops::fill_Tensor_out::redispatch(dispatchKeySet, self, value, out); + } + + // aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_divide_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::floor_divide_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & floor_divide_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::floor_divide_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names) { + return at::_ops::full_names_out::redispatch(dispatchKeySet, size, fill_value, names, out); + } + + // aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, at::Tensor & out) { + return at::_ops::full_names_out::redispatch(dispatchKeySet, size, fill_value, names, out); + } + + // aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & fill_value, ::std::optional memory_format=::std::nullopt) { + return at::_ops::full_like_out::redispatch(dispatchKeySet, self, fill_value, memory_format, out); + } + + // aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & full_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & fill_value, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::full_like_out::redispatch(dispatchKeySet, self, fill_value, memory_format, out); + } + + // aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & from_file_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::string_view filename, ::std::optional shared=::std::nullopt, ::std::optional size=0) { + return at::_ops::from_file_out::redispatch(dispatchKeySet, filename, shared, size, out); + } + + // aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & from_file_outf(c10::DispatchKeySet dispatchKeySet, c10::string_view filename, ::std::optional shared, ::std::optional size, at::Tensor & out) { + return at::_ops::from_file_out::redispatch(dispatchKeySet, filename, shared, size, out); + } + + // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_2d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) { + return at::_ops::grid_sampler_2d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_2d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::grid_sampler_2d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _grid_sampler_2d_cpu_fallback_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::_grid_sampler_2d_cpu_fallback_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _grid_sampler_2d_cpu_fallback_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) { + return at::_ops::_grid_sampler_2d_cpu_fallback_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + return at::_ops::grid_sampler_3d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & grid_sampler_3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out) { + return at::_ops::grid_sampler_3d_out::redispatch(dispatchKeySet, input, grid, interpolation_mode, padding_mode, align_corners, out); + } + + // aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + return at::_ops::grid_sampler_3d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple grid_sampler_3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::grid_sampler_3d_backward_out::redispatch(dispatchKeySet, grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask, out0, out1); + } + + // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hann_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::hann_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hann_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::hann_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hann_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::hann_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hann_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::hann_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::hamming_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::hamming_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::hamming_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::hamming_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double alpha) { + return at::_ops::hamming_window_periodic_alpha_out::redispatch(dispatchKeySet, window_length, periodic, alpha, out); + } + + // aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, at::Tensor & out) { + return at::_ops::hamming_window_periodic_alpha_out::redispatch(dispatchKeySet, window_length, periodic, alpha, out); + } + + // aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double alpha, double beta) { + return at::_ops::hamming_window_periodic_alpha_beta_out::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, out); + } + + // aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hamming_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double alpha, double beta, at::Tensor & out) { + return at::_ops::hamming_window_periodic_alpha_beta_out::redispatch(dispatchKeySet, window_length, periodic, alpha, beta, out); + } + + // aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length) { + return at::_ops::kaiser_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, at::Tensor & out) { + return at::_ops::kaiser_window_out::redispatch(dispatchKeySet, window_length, out); + } + + // aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic) { + return at::_ops::kaiser_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, at::Tensor & out) { + return at::_ops::kaiser_window_periodic_out::redispatch(dispatchKeySet, window_length, periodic, out); + } + + // aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t window_length, bool periodic, double beta) { + return at::_ops::kaiser_window_beta_out::redispatch(dispatchKeySet, window_length, periodic, beta, out); + } + + // aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & kaiser_window_outf(c10::DispatchKeySet dispatchKeySet, int64_t window_length, bool periodic, double beta, at::Tensor & out) { + return at::_ops::kaiser_window_beta_out::redispatch(dispatchKeySet, window_length, periodic, beta, out); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_out::redispatch(dispatchKeySet, input, weight, bias, N, C, HxW, group, eps, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_group_norm_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_group_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask, out0, out1, out2); + } + + // aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_put_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false) { + return at::_ops::index_put_out::redispatch(dispatchKeySet, self, indices, values, accumulate, out); + } + + // aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_put_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, at::Tensor & out) { + return at::_ops::index_put_out::redispatch(dispatchKeySet, self, indices, values, accumulate, out); + } + + // aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _index_put_impl_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) { + return at::_ops::_index_put_impl_out::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe, out); + } + + // aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _index_put_impl_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe, at::Tensor & out) { + return at::_ops::_index_put_impl_out::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe, out); + } + + // aten::_index_put_impl(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor + inline at::Tensor _index_put_impl(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate=false, bool unsafe=false) { + return at::_ops::_index_put_impl::redispatch(dispatchKeySet, self, indices, values, accumulate, unsafe); + } + + // aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isnan_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isnan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isnan_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isnan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::IntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, c10::fromIntArrayRefSlow(normalized_shape), weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_out::redispatch(dispatchKeySet, input, normalized_shape, weight, bias, eps, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, at::IntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, c10::fromIntArrayRefSlow(normalized_shape), mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_layer_norm_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_layer_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask, out0, out1, out2); + } + + // aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple linear_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple linear_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_linear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias={}) { + return at::_ops::mkldnn_linear_out::redispatch(dispatchKeySet, self, weight, bias, out); + } + + // aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_linear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out) { + return at::_ops::mkldnn_linear_out::redispatch(dispatchKeySet, self, weight, bias, out); + } + + // aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_linear_backward_input_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) { + return at::_ops::mkldnn_linear_backward_input_out::redispatch(dispatchKeySet, input_size, grad_output, weight, out); + } + + // aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_linear_backward_input_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight, at::Tensor & out) { + return at::_ops::mkldnn_linear_backward_input_out::redispatch(dispatchKeySet, input_size, grad_output, weight, out); + } + + // aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mkldnn_linear_backward_weights_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) { + return at::_ops::mkldnn_linear_backward_weights_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined, out0, out1); + } + + // aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple mkldnn_linear_backward_weights_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::mkldnn_linear_backward_weights_out::redispatch(dispatchKeySet, grad_output, input, weight, bias_defined, out0, out1); + } + + // aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mkldnn_linear_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + return at::_ops::mkldnn_linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mkldnn_linear_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::mkldnn_linear_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, output_mask, out0, out1, out2); + } + + // aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple matmul_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask) { + return at::_ops::matmul_backward_out::redispatch(dispatchKeySet, grad, self, other, mask, out0, out1); + } + + // aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple matmul_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::matmul_backward_out::redispatch(dispatchKeySet, grad, self, other, mask, out0, out1); + } + + // aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self) { + return at::_ops::_aminmax_out::redispatch(dispatchKeySet, self, out0, out1); + } + + // aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_aminmax_out::redispatch(dispatchKeySet, self, out0, out1); + } + + // aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, int64_t dim, bool keepdim=false) { + return at::_ops::_aminmax_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out0, out1); + } + + // aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _aminmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_aminmax_dim_out::redispatch(dispatchKeySet, self, dim, keepdim, out0, out1); + } + + // aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::mkldnn_max_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_max_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::mkldnn_max_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool1d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool1d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::quantized_max_pool1d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::quantized_max_pool2d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride={}, at::IntArrayRef padding=0, at::IntArrayRef dilation=1, bool ceil_mode=false) { + return at::_ops::quantized_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantized_max_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out) { + return at::_ops::quantized_max_pool3d_out::redispatch(dispatchKeySet, self, kernel_size, stride, padding, dilation, ceil_mode, out); + } + + // aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & median_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::median_out::redispatch(dispatchKeySet, self, out); + } + + // aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & median_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::median_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanmedian_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::nanmedian_out::redispatch(dispatchKeySet, self, out); + } + + // aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & nanmedian_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::nanmedian_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mps_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::_mps_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1, out2); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, output_mask, out0, out1, out2); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask, out0, out1, out2); + } + + // aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple mps_convolution_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::mps_convolution_backward_out::redispatch(dispatchKeySet, self, grad_output, weight, padding, stride, dilation, groups, output_mask, out0, out1, out2); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::Tensor & out) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, out); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out) { + return at::_ops::mkldnn_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, out); + } + + // aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple mkldnn_rnn_layer_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) { + return at::_ops::mkldnn_rnn_layer_out::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train, out0, out1, out2, out3); + } + + // aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple mkldnn_rnn_layer_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::mkldnn_rnn_layer_out::redispatch(dispatchKeySet, input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train, out0, out1, out2, out3); + } + + // aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple mkldnn_rnn_layer_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) { + return at::_ops::mkldnn_rnn_layer_backward_out::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace, out0, out1, out2, out3, out4, out5, out6); + } + + // aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!)) + inline ::std::tuple mkldnn_rnn_layer_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6) { + return at::_ops::mkldnn_rnn_layer_backward_out::redispatch(dispatchKeySet, input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace, out0, out1, out2, out3, out4, out5, out6); + } + + // aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + return at::_ops::miopen_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2); + } + + // aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::miopen_batch_norm_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon, out0, out1, out2); + } + + // aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon) { + return at::_ops::miopen_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, out0, out1, out2); + } + + // aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple miopen_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::miopen_batch_norm_backward_out::redispatch(dispatchKeySet, input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, out0, out1, out2); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef output_padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(output_padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_transpose_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_convolution_transpose_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_convolution_transpose_out::redispatch(dispatchKeySet, self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_depthwise_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_depthwise_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_depthwise_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & miopen_depthwise_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out) { + return at::_ops::miopen_depthwise_convolution_out::redispatch(dispatchKeySet, self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic, out); + } + + // aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple miopen_rnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + return at::_ops::miopen_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple miopen_rnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::miopen_rnn_out::redispatch(dispatchKeySet, input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, out0, out1, out2, out3, out4); + } + + // aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> () + inline void miopen_rnn_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + return at::_ops::miopen_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> () + inline void miopen_rnn_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + return at::_ops::miopen_rnn_backward_out::redispatch(dispatchKeySet, input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + + // aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sparse_matmul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::_sparse_sparse_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sparse_matmul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::_sparse_sparse_matmul_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mul_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::mul_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mul_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::mul_Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_native_batch_norm_legit_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out) + inline ::std::tuple _native_batch_norm_legit_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, bool training, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_functional::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, training, momentum, eps); + } + + // aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_no_training_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_native_batch_norm_legit_no_training_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2); + } + + // aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _native_batch_norm_legit_no_training_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_native_batch_norm_legit_no_training_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2); + } + + // aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, double eps) { + return at::_ops::batch_norm_stats_out::redispatch(dispatchKeySet, input, eps, out0, out1); + } + + // aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_stats_out::redispatch(dispatchKeySet, input, eps, out0, out1); + } + + // aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + return at::_ops::batch_norm_gather_stats_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1); + } + + // aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_gather_stats_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, count, out0, out1); + } + + // aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_with_counts_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + return at::_ops::batch_norm_gather_stats_with_counts_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1); + } + + // aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_gather_stats_with_counts_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_gather_stats_with_counts_out::redispatch(dispatchKeySet, input, mean, invstd, running_mean, running_var, momentum, eps, counts, out0, out1); + } + + // aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask) { + return at::_ops::native_batch_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, out0, out1, out2); + } + + // aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple native_batch_norm_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::native_batch_norm_backward_out::redispatch(dispatchKeySet, grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask, out0, out1, out2); + } + + // aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple batch_norm_backward_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + return at::_ops::batch_norm_backward_reduce_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3); + } + + // aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple batch_norm_backward_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::batch_norm_backward_reduce_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g, out0, out1, out2, out3); + } + + // aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & batch_norm_backward_elemt_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + return at::_ops::batch_norm_backward_elemt_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out); + } + + // aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & batch_norm_backward_elemt_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out) { + return at::_ops::batch_norm_backward_elemt_out::redispatch(dispatchKeySet, grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count, out); + } + + // aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_update_stats_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + return at::_ops::batch_norm_update_stats_out::redispatch(dispatchKeySet, input, running_mean, running_var, momentum, out0, out1); + } + + // aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple batch_norm_update_stats_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::batch_norm_update_stats_out::redispatch(dispatchKeySet, input, running_mean, running_var, momentum, out0, out1); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nnpack_spatial_convolution_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride=1) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nnpack_spatial_convolution_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nnpack_spatial_convolution_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride=c10::SymInt(1)) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, padding, stride, out); + } + + // aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nnpack_spatial_convolution_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::_nnpack_spatial_convolution_out::redispatch(dispatchKeySet, input, weight, bias, padding, stride, out); + } + + // aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names) { + return at::_ops::ones_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::ones_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::ones_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ones_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::ones_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _euclidean_dist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x1, const at::Tensor & x2) { + return at::_ops::_euclidean_dist_out::redispatch(dispatchKeySet, x1, x2, out); + } + + // aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _euclidean_dist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, at::Tensor & out) { + return at::_ops::_euclidean_dist_out::redispatch(dispatchKeySet, x1, x2, out); + } + + // aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cdist_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + return at::_ops::_cdist_forward_out::redispatch(dispatchKeySet, x1, x2, p, compute_mode, out); + } + + // aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cdist_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode, at::Tensor & out) { + return at::_ops::_cdist_forward_out::redispatch(dispatchKeySet, x1, x2, p, compute_mode, out); + } + + // aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cdist_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) { + return at::_ops::_cdist_backward_out::redispatch(dispatchKeySet, grad, x1, x2, p, cdist, out); + } + + // aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cdist_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist, at::Tensor & out) { + return at::_ops::_cdist_backward_out::redispatch(dispatchKeySet, grad, x1, x2, p, cdist, out); + } + + // aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pdist_forward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p=2) { + return at::_ops::_pdist_forward_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pdist_forward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, at::Tensor & out) { + return at::_ops::_pdist_forward_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pdist_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) { + return at::_ops::_pdist_backward_out::redispatch(dispatchKeySet, grad, self, p, pdist, out); + } + + // aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pdist_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist, at::Tensor & out) { + return at::_ops::_pdist_backward_out::redispatch(dispatchKeySet, grad, self, p, pdist, out); + } + + // aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pixel_shuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t upscale_factor) { + return at::_ops::pixel_shuffle_out::redispatch(dispatchKeySet, self, upscale_factor, out); + } + + // aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pixel_shuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t upscale_factor, at::Tensor & out) { + return at::_ops::pixel_shuffle_out::redispatch(dispatchKeySet, self, upscale_factor, out); + } + + // aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pixel_unshuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t downscale_factor) { + return at::_ops::pixel_unshuffle_out::redispatch(dispatchKeySet, self, downscale_factor, out); + } + + // aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & pixel_unshuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t downscale_factor, at::Tensor & out) { + return at::_ops::pixel_unshuffle_out::redispatch(dispatchKeySet, self, downscale_factor, out); + } + + // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & channel_shuffle_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t groups) { + return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out); + } + + // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & channel_shuffle_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t groups, at::Tensor & out) { + return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out); + } + + // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & channel_shuffle_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt groups) { + return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out); + } + + // aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & channel_shuffle_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt groups, at::Tensor & out) { + return at::_ops::channel_shuffle_out::redispatch(dispatchKeySet, self, groups, out); + } + + // aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pin_memory_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional device=::std::nullopt) { + return at::_ops::_pin_memory_out::redispatch(dispatchKeySet, self, device, out); + } + + // aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _pin_memory_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional device, at::Tensor & out) { + return at::_ops::_pin_memory_out::redispatch(dispatchKeySet, self, device, out); + } + + // aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scalar_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & s) { + return at::_ops::scalar_tensor_out::redispatch(dispatchKeySet, s, out); + } + + // aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & scalar_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & s, at::Tensor & out) { + return at::_ops::scalar_tensor_out::redispatch(dispatchKeySet, s, out); + } + + // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names) { + return at::_ops::rand_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out); + } + + // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::rand_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out); + } + + // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional names) { + return at::_ops::rand_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::rand_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional generator, ::std::optional names) { + return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out); + } + + // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out) { + return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out); + } + + // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names) { + return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out); + } + + // aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out) { + return at::_ops::rand_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out); + } + + // aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::rand_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rand_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::rand_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t high, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out); + } + + // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t high, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out); + } + + // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt high, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out); + } + + // aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt high, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::randint_like_out::redispatch(dispatchKeySet, self, high, memory_format, out); + } + + // aten::randint_like.Tensor_out(Tensor self, Tensor high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & high, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_Tensor_out::redispatch(dispatchKeySet, self, high, memory_format, out); + } + + // aten::randint_like.Tensor_out(Tensor self, Tensor high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & high, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::randint_like_Tensor_out::redispatch(dispatchKeySet, self, high, memory_format, out); + } + + // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t low, int64_t high, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out); + } + + // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t low, int64_t high, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out); + } + + // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out); + } + + // aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randint_like_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::randint_like_low_dtype_out::redispatch(dispatchKeySet, self, low, high, memory_format, out); + } + + // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names) { + return at::_ops::randn_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out); + } + + // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::randn_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), names, out); + } + + // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional names) { + return at::_ops::randn_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::randn_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional generator, ::std::optional names) { + return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out); + } + + // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out) { + return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), generator, names, out); + } + + // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names) { + return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out); + } + + // aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out) { + return at::_ops::randn_generator_with_names_out::redispatch(dispatchKeySet, size, generator, names, out); + } + + // aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::randn_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & randn_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::randn_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef repeats) { + return at::_ops::repeat_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats), out); + } + + // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef repeats, at::Tensor & out) { + return at::_ops::repeat_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(repeats), out); + } + + // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef repeats) { + return at::_ops::repeat_out::redispatch(dispatchKeySet, self, repeats, out); + } + + // aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef repeats, at::Tensor & out) { + return at::_ops::repeat_out::redispatch(dispatchKeySet, self, repeats, out); + } + + // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_interleave_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & repeats, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt, out); + } + + // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_interleave_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional output_size, at::Tensor & out) { + return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size.has_value() ? ::std::make_optional(c10::SymInt(*output_size)) : ::std::nullopt, out); + } + + // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_interleave_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & repeats, ::std::optional output_size=::std::nullopt) { + return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size, out); + } + + // aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & repeat_interleave_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & repeats, ::std::optional output_size, at::Tensor & out) { + return at::_ops::repeat_interleave_Tensor_out::redispatch(dispatchKeySet, repeats, output_size, out); + } + + // aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mkldnn_reshape_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef shape) { + return at::_ops::_mkldnn_reshape_out::redispatch(dispatchKeySet, self, shape, out); + } + + // aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mkldnn_reshape_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shape, at::Tensor & out) { + return at::_ops::_mkldnn_reshape_out::redispatch(dispatchKeySet, self, shape, out); + } + + // aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & relu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::relu_out::redispatch(dispatchKeySet, self, out); + } + + // aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & relu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::relu_out::redispatch(dispatchKeySet, self, out); + } + + // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index) { + return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index, out); + } + + // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t index, at::Tensor & out) { + return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, index, out); + } + + // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { + return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index, out); + } + + // aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index, at::Tensor & out) { + return at::_ops::select_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, index, out); + } + + // aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & celu_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & alpha=1.0) { + return at::_ops::celu_out::redispatch(dispatchKeySet, self, alpha, out); + } + + // aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & celu_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::celu_out::redispatch(dispatchKeySet, self, alpha, out); + } + + // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step) { + return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step, out); + } + + // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, at::IntArrayRef input_sizes, int64_t dim, int64_t start, int64_t end, int64_t step, at::Tensor & out) { + return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, c10::fromIntArrayRefSlow(input_sizes), dim, start, end, step, out); + } + + // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) { + return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step, out); + } + + // aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step, at::Tensor & out) { + return at::_ops::slice_backward_out::redispatch(dispatchKeySet, grad_output, input_sizes, dim, start, end, step, out); + } + + // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out); + } + + // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, int64_t step, at::Tensor & out) { + return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out); + } + + // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start, end, step, out); + } + + // aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step, at::Tensor & out) { + return at::_ops::slice_scatter_out::redispatch(dispatchKeySet, self, src, dim, start, end, step, out); + } + + // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, int64_t index, at::Tensor & out) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index, at::Tensor & out) { + return at::_ops::select_scatter_out::redispatch(dispatchKeySet, self, src, dim, index, out); + } + + // aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_scatter_out::redispatch(dispatchKeySet, self, src, offset, dim1, dim2, out); + } + + // aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_scatter_out::redispatch(dispatchKeySet, self, src, offset, dim1, dim2, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, size, stride, storage_offset, out); + } + + // aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_scatter_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_scatter_out::redispatch(dispatchKeySet, self, src, size, stride, storage_offset, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, int64_t split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t split_size, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymInt split_size, int64_t dim=0) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_Tensor_out::redispatch(dispatchKeySet, self, split_size, dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_with_sizes_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_with_sizes_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(split_sizes), dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_with_sizes_symint_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim=0) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> () + inline void unsafe_split_with_sizes_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + return at::_ops::unsafe_split_with_sizes_out::redispatch(dispatchKeySet, self, split_sizes, dim, out); + } + + // aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::sum_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::sum_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple std_mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::std_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple std_mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::std_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::prod_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::prod_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mkldnn_transpose_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::_mkldnn_transpose_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _mkldnn_transpose_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out) { + return at::_ops::_mkldnn_transpose_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & flip_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::flip_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & flip_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::flip_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & roll_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & roll_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(shifts), dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & roll_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims={}) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, shifts, dims, out); + } + + // aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & roll_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::roll_out::redispatch(dispatchKeySet, self, shifts, dims, out); + } + + // aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rot90_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t k=1, at::IntArrayRef dims={0,1}) { + return at::_ops::rot90_out::redispatch(dispatchKeySet, self, k, dims, out); + } + + // aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rot90_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t k, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::rot90_out::redispatch(dispatchKeySet, self, k, dims, out); + } + + // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _transform_bias_rescale_qkv_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) { + return at::_ops::_transform_bias_rescale_qkv_out::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads, out0, out1, out2); + } + + // aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _transform_bias_rescale_qkv_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_transform_bias_rescale_qkv_out::redispatch(dispatchKeySet, qkv, qkv_bias, num_heads, out0, out1, out2); + } + + // aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_from_mask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & t, const at::Tensor & mask, bool mask_check=true) { + return at::_ops::_nested_tensor_from_mask_out::redispatch(dispatchKeySet, t, mask, mask_check, out); + } + + // aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_from_mask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & t, const at::Tensor & mask, bool mask_check, at::Tensor & out) { + return at::_ops::_nested_tensor_from_mask_out::redispatch(dispatchKeySet, t, mask, mask_check, out); + } + + // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_from_padded_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213=false) { + return at::_ops::_nested_from_padded_out::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213, out); + } + + // aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_from_padded_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213, at::Tensor & out) { + return at::_ops::_nested_from_padded_out::redispatch(dispatchKeySet, padded, cpu_nested_shape_example, fuse_transform_0213, out); + } + + // aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_size_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_tensor_size_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_size_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_tensor_size_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_strides_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_tensor_strides_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_strides_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_tensor_strides_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_storage_offsets_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_tensor_storage_offsets_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_storage_offsets_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_tensor_storage_offsets_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_from_padded_and_nested_example_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & padded, const at::Tensor & nt_example) { + return at::_ops::_nested_from_padded_and_nested_example_out::redispatch(dispatchKeySet, padded, nt_example, out); + } + + // aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_from_padded_and_nested_example_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & padded, const at::Tensor & nt_example, at::Tensor & out) { + return at::_ops::_nested_from_padded_and_nested_example_out::redispatch(dispatchKeySet, padded, nt_example, out); + } + + // aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_view_from_buffer_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + return at::_ops::_nested_view_from_buffer_copy_out::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets, out); + } + + // aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_view_from_buffer_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets, at::Tensor & out) { + return at::_ops::_nested_view_from_buffer_copy_out::redispatch(dispatchKeySet, self, nested_size, nested_strides, offsets, out); + } + + // aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_view_from_jagged_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths={}, int64_t ragged_idx=1, const ::std::optional & min_seqlen={}, const ::std::optional & max_seqlen={}) { + return at::_ops::_nested_view_from_jagged_copy_out::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen, out); + } + + // aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_view_from_jagged_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, at::Tensor & out) { + return at::_ops::_nested_view_from_jagged_copy_out::redispatch(dispatchKeySet, self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen, out); + } + + // aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_get_values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_nested_get_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_get_values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_nested_get_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _trilinear_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim=1) { + return at::_ops::_trilinear_out::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, out); + } + + // aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _trilinear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim, at::Tensor & out) { + return at::_ops::_trilinear_out::redispatch(dispatchKeySet, i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim, out); + } + + // aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _unique_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, bool sorted=true, bool return_inverse=false) { + return at::_ops::_unique_out::redispatch(dispatchKeySet, self, sorted, return_inverse, out0, out1); + } + + // aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _unique_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted, bool return_inverse, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_unique_out::redispatch(dispatchKeySet, self, sorted, return_inverse, out0, out1); + } + + // aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, int64_t dim, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim_out::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::unique_dim_out::redispatch(dispatchKeySet, self, dim, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_consecutive_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, bool return_inverse=false, bool return_counts=false, ::std::optional dim=::std::nullopt) { + return at::_ops::unique_consecutive_out::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim, out0, out1, out2); + } + + // aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_consecutive_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::unique_consecutive_out::redispatch(dispatchKeySet, self, return_inverse, return_counts, dim, out0, out1, out2); + } + + // aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_consecutive_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, int64_t dim, bool return_inverse=false, bool return_counts=false) { + return at::_ops::unique_dim_consecutive_out::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts, out0, out1, out2); + } + + // aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple unique_dim_consecutive_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::unique_dim_consecutive_out::redispatch(dispatchKeySet, self, dim, return_inverse, return_counts, out0, out1, out2); + } + + // aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _unique2_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & self, bool sorted=true, bool return_inverse=false, bool return_counts=false) { + return at::_ops::_unique2_out::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _unique2_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_unique2_out::redispatch(dispatchKeySet, self, sorted, return_inverse, return_counts, out0, out1, out2); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _unsafe_view_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::_unsafe_view_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple var_mean_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, at::OptionalIntArrayRef dim=::std::nullopt, const ::std::optional & correction=::std::nullopt, bool keepdim=false) { + return at::_ops::var_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple var_mean_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::var_mean_correction_out::redispatch(dispatchKeySet, self, dim, correction, keepdim, out0, out1); + } + + // aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & v, const at::Tensor & g, int64_t dim=0) { + return at::_ops::_weight_norm_interface_out::redispatch(dispatchKeySet, v, g, dim, out0, out1); + } + + // aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & v, const at::Tensor & g, int64_t dim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_weight_norm_interface_out::redispatch(dispatchKeySet, v, g, dim, out0, out1); + } + + // aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + return at::_ops::_weight_norm_interface_backward_out::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim, out0, out1); + } + + // aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _weight_norm_interface_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_weight_norm_interface_backward_out::redispatch(dispatchKeySet, grad_w, saved_v, saved_g, saved_norms, dim, out0, out1); + } + + // aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size, ::std::optional names) { + return at::_ops::zeros_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, ::std::optional names, at::Tensor & out) { + return at::_ops::zeros_names_out::redispatch(dispatchKeySet, size, names, out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _efficientzerotensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _efficientzerotensor_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, c10::fromIntArrayRefSlow(size), out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _efficientzerotensor_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, c10::SymIntArrayRef size) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, size, out); + } + + // aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _efficientzerotensor_symint_outf(c10::DispatchKeySet dispatchKeySet, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::_efficientzerotensor_out::redispatch(dispatchKeySet, size, out); + } + + // aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_like_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::zeros_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zeros_like_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::zeros_like_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_grad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & output) { + return at::_ops::_standard_gamma_grad_out::redispatch(dispatchKeySet, self, output, out); + } + + // aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_grad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & output, at::Tensor & out) { + return at::_ops::_standard_gamma_grad_out::redispatch(dispatchKeySet, self, output, out); + } + + // aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_standard_gamma_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _standard_gamma_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::_standard_gamma_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _dirichlet_grad_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) { + return at::_ops::_dirichlet_grad_out::redispatch(dispatchKeySet, x, alpha, total, out); + } + + // aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _dirichlet_grad_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total, at::Tensor & out) { + return at::_ops::_dirichlet_grad_out::redispatch(dispatchKeySet, x, alpha, total, out); + } + + // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sample_dirichlet_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::_sample_dirichlet_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sample_dirichlet_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::_sample_dirichlet_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & poisson_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::poisson_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & poisson_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::poisson_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binomial_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & count, const at::Tensor & prob, ::std::optional generator=::std::nullopt) { + return at::_ops::binomial_out::redispatch(dispatchKeySet, count, prob, generator, out); + } + + // aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & binomial_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & count, const at::Tensor & prob, ::std::optional generator, at::Tensor & out) { + return at::_ops::binomial_out::redispatch(dispatchKeySet, count, prob, generator, out); + } + + // aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::native_norm_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, at::Tensor & out) { + return at::_ops::native_norm_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + return at::_ops::native_norm_ScalarOpt_dim_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & native_norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::native_norm_ScalarOpt_dim_dtype_out::redispatch(dispatchKeySet, self, p, dim, keepdim, dtype, out); + } + + // aten::_batch_norm_with_update_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out) + inline ::std::tuple _batch_norm_with_update_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_with_update_functional::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps); + } + + // aten::_batch_norm_no_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _batch_norm_no_update_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps) { + return at::_ops::_batch_norm_no_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2, out3); + } + + // aten::_batch_norm_no_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!)) + inline ::std::tuple _batch_norm_no_update_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3) { + return at::_ops::_batch_norm_no_update_out::redispatch(dispatchKeySet, input, weight, bias, running_mean, running_var, momentum, eps, out0, out1, out2, out3); + } + + // aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::_sparse_sum_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::_sparse_sum_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sum_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::_sparse_sum_backward_out::redispatch(dispatchKeySet, grad, self, dim, out); + } + + // aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_sum_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::_sparse_sum_backward_out::redispatch(dispatchKeySet, grad, self, dim, out); + } + + // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_csr_sum_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_sum_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_csr_sum_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::_sparse_csr_sum_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_csr_prod_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim, bool keepdim=false, ::std::optional dtype=::std::nullopt) { + return at::_ops::_sparse_csr_prod_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_csr_prod_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out) { + return at::_ops::_sparse_csr_prod_dim_dtype_out::redispatch(dispatchKeySet, self, dim, keepdim, dtype, out); + } + + // aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_sparse_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_sparse_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_log_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, bool half_to_float) { + return at::_ops::_sparse_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_log_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out) { + return at::_ops::_sparse_log_softmax_out::redispatch(dispatchKeySet, self, dim, half_to_float, out); + } + + // aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_log_softmax_backward_data_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + return at::_ops::_sparse_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_log_softmax_backward_data_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_sparse_log_softmax_backward_data_out::redispatch(dispatchKeySet, grad_output, output, dim, self, out); + } + + // aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _spdiags_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout=::std::nullopt) { + return at::_ops::_spdiags_out::redispatch(dispatchKeySet, diagonals, offsets, shape, layout, out); + } + + // aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _spdiags_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout, at::Tensor & out) { + return at::_ops::_spdiags_out::redispatch(dispatchKeySet, diagonals, offsets, shape, layout, out); + } + + // aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype) { + return at::_ops::norm_ScalarOpt_dtype_out::redispatch(dispatchKeySet, self, p, dtype, out); + } + + // aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::norm_ScalarOpt_dtype_out::redispatch(dispatchKeySet, self, p, dtype, out); + } + + // aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & p=2) { + return at::_ops::norm_Scalar_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & norm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & p, at::Tensor & out) { + return at::_ops::norm_Scalar_out::redispatch(dispatchKeySet, self, p, out); + } + + // aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clone_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional memory_format=::std::nullopt) { + return at::_ops::clone_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & clone_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::clone_out::redispatch(dispatchKeySet, self, memory_format, out); + } + + // aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_as_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_as_out::redispatch(dispatchKeySet, self, the_template, memory_format, out); + } + + // aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_as_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format, const at::Tensor & out) { + return at::_ops::resize_as_out::redispatch(dispatchKeySet, self, the_template, memory_format, out); + } + + // aten::resize_as(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor + inline at::Tensor resize_as(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format=::std::nullopt) { + return at::_ops::resize_as::redispatch(dispatchKeySet, self, the_template, memory_format); + } + + // aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_as_sparse_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, const at::Tensor & the_template) { + return at::_ops::resize_as_sparse_out::redispatch(dispatchKeySet, self, the_template, out); + } + + // aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & resize_as_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template, const at::Tensor & out) { + return at::_ops::resize_as_sparse_out::redispatch(dispatchKeySet, self, the_template, out); + } + + // aten::resize_as_sparse(Tensor self, Tensor the_template) -> Tensor + inline at::Tensor resize_as_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & the_template) { + return at::_ops::resize_as_sparse::redispatch(dispatchKeySet, self, the_template); + } + + // aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zero_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & zero_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::zero(Tensor self) -> Tensor + inline at::Tensor zero(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::zero::redispatch(dispatchKeySet, self); + } + + // aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::sub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::sub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::rsub_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha=1) { + return at::_ops::rsub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rsub_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::rsub_Scalar_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_addmm_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta=1, const at::Scalar & alpha=1) { + return at::_ops::_sparse_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_addmm_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out) { + return at::_ops::_sparse_addmm_out::redispatch(dispatchKeySet, self, mat1, mat2, beta, alpha, out); + } + + // aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_coo_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::IntArrayRef size) { + return at::_ops::sparse_coo_tensor_size_out::redispatch(dispatchKeySet, size, out); + } + + // aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_coo_tensor_outf(c10::DispatchKeySet dispatchKeySet, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::sparse_coo_tensor_size_out::redispatch(dispatchKeySet, size, out); + } + + // aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size) { + return at::_ops::_sparse_coo_tensor_with_dims_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, out); + } + + // aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_sparse_coo_tensor_with_dims_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, is_coalesced, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced, at::Tensor & out) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, c10::fromIntArrayRefSlow(size), indices, values, is_coalesced, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced=::std::nullopt) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, is_coalesced, out); + } + + // aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_symint_outf(c10::DispatchKeySet dispatchKeySet, int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced, at::Tensor & out) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors_out::redispatch(dispatchKeySet, sparse_dim, dense_dim, size, indices, values, is_coalesced, out); + } + + // aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & sparse_resize_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & sparse_resize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out) { + return at::_ops::sparse_resize_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor + inline at::Tensor sparse_resize(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & sparse_resize_and_clear_out(c10::DispatchKeySet dispatchKeySet, const at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_and_clear_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!) + inline const at::Tensor & sparse_resize_and_clear_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out) { + return at::_ops::sparse_resize_and_clear_out::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim, out); + } + + // aten::sparse_resize_and_clear(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor + inline at::Tensor sparse_resize_and_clear(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + return at::_ops::sparse_resize_and_clear::redispatch(dispatchKeySet, self, size, sparse_dim, dense_dim); + } + + // aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_mask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask) { + return at::_ops::sparse_mask_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & sparse_mask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, at::Tensor & out) { + return at::_ops::sparse_mask_out::redispatch(dispatchKeySet, self, mask, out); + } + + // aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_mask_projection_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches=false) { + return at::_ops::_sparse_mask_projection_out::redispatch(dispatchKeySet, self, mask, accumulate_matches, out); + } + + // aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_mask_projection_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out) { + return at::_ops::_sparse_mask_projection_out::redispatch(dispatchKeySet, self, mask, accumulate_matches, out); + } + + // aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_dense_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt, ::std::optional masked_grad=::std::nullopt) { + return at::_ops::_to_dense_out::redispatch(dispatchKeySet, self, dtype, masked_grad, out); + } + + // aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_dense_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad, at::Tensor & out) { + return at::_ops::_to_dense_out::redispatch(dispatchKeySet, self, dtype, masked_grad, out); + } + + // aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _coalesce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_coalesce_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _coalesce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_coalesce_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _coalesced_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool coalesced) { + return at::_ops::_coalesced_out::redispatch(dispatchKeySet, self, coalesced, out); + } + + // aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _coalesced_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool coalesced, at::Tensor & out) { + return at::_ops::_coalesced_out::redispatch(dispatchKeySet, self, coalesced, out); + } + + // aten::_coalesced(Tensor self, bool coalesced) -> Tensor + inline at::Tensor _coalesced(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool coalesced) { + return at::_ops::_coalesced::redispatch(dispatchKeySet, self, coalesced); + } + + // aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copy_sparse_to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_sparse_to_sparse_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & copy_sparse_to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out) { + return at::_ops::copy_sparse_to_sparse_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::copy_sparse_to_sparse(Tensor self, Tensor src, bool non_blocking=False) -> Tensor + inline at::Tensor copy_sparse_to_sparse(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & src, bool non_blocking=false) { + return at::_ops::copy_sparse_to_sparse::redispatch(dispatchKeySet, self, src, non_blocking); + } + + // aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t sparse_dim) { + return at::_ops::_to_sparse_sparse_dim_out::redispatch(dispatchKeySet, self, sparse_dim, out); + } + + // aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t sparse_dim, at::Tensor & out) { + return at::_ops::_to_sparse_sparse_dim_out::redispatch(dispatchKeySet, self, sparse_dim, out); + } + + // aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional layout=::std::nullopt, at::OptionalIntArrayRef blocksize=::std::nullopt, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_out::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim, out); + } + + // aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_out::redispatch(dispatchKeySet, self, layout, blocksize, dense_dim, out); + } + + // aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csr_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_csr_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_csc_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_csc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_csc_out::redispatch(dispatchKeySet, self, dense_dim, out); + } + + // aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_bsr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsr_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_bsr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_bsr_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_bsc_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim=::std::nullopt) { + return at::_ops::_to_sparse_bsc_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_sparse_bsc_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out) { + return at::_ops::_to_sparse_bsc_out::redispatch(dispatchKeySet, self, blocksize, dense_dim, out); + } + + // aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_mkldnn_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional dtype=::std::nullopt) { + return at::_ops::to_mkldnn_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_mkldnn_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional dtype, at::Tensor & out) { + return at::_ops::to_mkldnn_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv2d_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv2d_weight_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::OptionalIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv2d_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv2d_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv2d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv3d_weight_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef padding=0, at::IntArrayRef stride=1, at::IntArrayRef dilation=1, int64_t groups=1, at::OptionalIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv3d_weight_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef padding, at::IntArrayRef stride, at::IntArrayRef dilation, int64_t groups, at::OptionalIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(dilation), groups, input_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*input_size)) : ::std::nullopt, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv3d_weight_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef dilation=c10::SymInt(1), c10::SymInt groups=1, at::OptionalSymIntArrayRef input_size=::std::nullopt) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_reorder_conv3d_weight_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out) { + return at::_ops::mkldnn_reorder_conv3d_weight_out::redispatch(dispatchKeySet, self, padding, stride, dilation, groups, input_size, out); + } + + // aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_dynamic_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::ScalarType dtype, bool reduce_range) { + return at::_ops::quantize_per_tensor_dynamic_out::redispatch(dispatchKeySet, self, dtype, reduce_range, out); + } + + // aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_dynamic_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, bool reduce_range, at::Tensor & out) { + return at::_ops::quantize_per_tensor_dynamic_out::redispatch(dispatchKeySet, self, dtype, reduce_range, out); + } + + // aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::quantize_per_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::quantize_per_tensor_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, dtype, out); + } + + // aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> () + inline void quantize_per_tensor_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) { + return at::_ops::quantize_per_tensor_tensors_out::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype, out); + } + + // aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> () + inline void quantize_per_tensor_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype, at::TensorList out) { + return at::_ops::quantize_per_tensor_tensors_out::redispatch(dispatchKeySet, tensors, scales, zero_points, dtype, out); + } + + // aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_channel_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) { + return at::_ops::quantize_per_channel_out::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype, out); + } + + // aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & quantize_per_channel_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::quantize_per_channel_out::redispatch(dispatchKeySet, self, scales, zero_points, axis, dtype, out); + } + + // aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dequantize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::dequantize_self_out::redispatch(dispatchKeySet, self, out); + } + + // aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dequantize_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::dequantize_self_out::redispatch(dispatchKeySet, self, out); + } + + // aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> () + inline void dequantize_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList tensors) { + return at::_ops::dequantize_tensors_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> () + inline void dequantize_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList tensors, at::TensorList out) { + return at::_ops::dequantize_tensors_out::redispatch(dispatchKeySet, tensors, out); + } + + // aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & q_per_channel_scales_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::q_per_channel_scales_out::redispatch(dispatchKeySet, self, out); + } + + // aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & q_per_channel_scales_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::q_per_channel_scales_out::redispatch(dispatchKeySet, self, out); + } + + // aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & q_per_channel_zero_points_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::q_per_channel_zero_points_out::redispatch(dispatchKeySet, self, out); + } + + // aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & q_per_channel_zero_points_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::q_per_channel_zero_points_out::redispatch(dispatchKeySet, self, out); + } + + // aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & int_repr_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::int_repr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & int_repr_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::int_repr_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_per_tensor_quantized_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double scale, int64_t zero_point) { + return at::_ops::_make_per_tensor_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, out); + } + + // aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_per_tensor_quantized_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, at::Tensor & out) { + return at::_ops::_make_per_tensor_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, out); + } + + // aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_per_channel_quantized_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) { + return at::_ops::_make_per_channel_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, out); + } + + // aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_per_channel_quantized_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, at::Tensor & out) { + return at::_ops::_make_per_channel_quantized_tensor_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, out); + } + + // aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_tensor_affine_cachemask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, out0, out1); + } + + // aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_tensor_affine_cachemask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out::redispatch(dispatchKeySet, self, scale, zero_point, fake_quant_enabled, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fake_quantize_learnable_per_tensor_affine_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor, out); + } + + // aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fake_quantize_learnable_per_tensor_affine_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, quant_min, quant_max, grad_factor, out); + } + + // aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_channel_affine_cachemask_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, out0, out1); + } + + // aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple fake_quantize_per_channel_affine_cachemask_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, out0, out1); + } + + // aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fake_quantize_learnable_per_channel_affine_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor=1.0) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor, out); + } + + // aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fake_quantize_learnable_per_channel_affine_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_out::redispatch(dispatchKeySet, self, scale, zero_point, axis, quant_min, quant_max, grad_factor, out); + } + + // aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) + inline ::std::tuple _fused_moving_avg_obs_fq_helper_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::_fused_moving_avg_obs_fq_helper_out::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant, out0, out1); + } + + // aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!)) + inline ::std::tuple _fused_moving_avg_obs_fq_helper_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_fused_moving_avg_obs_fq_helper_out::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant, out0, out1); + } + + // aten::_fused_moving_avg_obs_fq_helper_functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out) + inline ::std::tuple _fused_moving_avg_obs_fq_helper_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, const at::Tensor & running_min, const at::Tensor & running_max, const at::Tensor & scale, const at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant=false, bool symmetric_quant=false) { + return at::_ops::_fused_moving_avg_obs_fq_helper_functional::redispatch(dispatchKeySet, self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + + // aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool non_blocking=false, ::std::optional memory_format=::std::nullopt) { + return at::_ops::_to_copy_out::redispatch(dispatchKeySet, self, non_blocking, memory_format, out); + } + + // aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _to_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool non_blocking, ::std::optional memory_format, at::Tensor & out) { + return at::_ops::_to_copy_out::redispatch(dispatchKeySet, self, non_blocking, memory_format, out); + } + + // aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _lstm_mps_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::_lstm_mps_out::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5); + } + + // aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!)) + inline ::std::tuple _lstm_mps_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5) { + return at::_ops::_lstm_mps_out::redispatch(dispatchKeySet, input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2, out3, out4, out5); + } + + // aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> () + inline void lstm_mps_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::TensorList out1, at::TensorList out2, const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + return at::_ops::lstm_mps_backward_out::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2); + } + + // aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> () + inline void lstm_mps_backward_outf(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::TensorList out1, at::TensorList out2) { + return at::_ops::lstm_mps_backward_out::redispatch(dispatchKeySet, grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_lstm_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_thnn_fused_lstm_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, cx, input_bias, hidden_bias, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_backward_impl_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl_out::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias, out0, out1, out2); + } + + // aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _thnn_fused_lstm_cell_backward_impl_outf(c10::DispatchKeySet dispatchKeySet, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl_out::redispatch(dispatchKeySet, grad_hy, grad_cy, cx, cy, workspace, has_bias, out0, out1, out2); + } + + // aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _thnn_fused_gru_cell_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias={}, const ::std::optional & hidden_bias={}) { + return at::_ops::_thnn_fused_gru_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias, out0, out1); + } + + // aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _thnn_fused_gru_cell_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_thnn_fused_gru_cell_out::redispatch(dispatchKeySet, input_gates, hidden_gates, hx, input_bias, hidden_bias, out0, out1); + } + + // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _thnn_fused_gru_cell_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + return at::_ops::_thnn_fused_gru_cell_backward_out::redispatch(dispatchKeySet, grad_hy, workspace, has_bias, out0, out1, out2, out3, out4); + } + + // aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!)) + inline ::std::tuple _thnn_fused_gru_cell_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4) { + return at::_ops::_thnn_fused_gru_cell_backward_out::redispatch(dispatchKeySet, grad_hy, workspace, has_bias, out0, out1, out2, out3, out4); + } + + // aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _pack_padded_sequence_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & input, const at::Tensor & lengths, bool batch_first) { + return at::_ops::_pack_padded_sequence_out::redispatch(dispatchKeySet, input, lengths, batch_first, out0, out1); + } + + // aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _pack_padded_sequence_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, const at::Tensor & lengths, bool batch_first, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_pack_padded_sequence_out::redispatch(dispatchKeySet, input, lengths, batch_first, out0, out1); + } + + // aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source) { + return at::_ops::set_source_Storage_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, at::Tensor & out) { + return at::_ops::set_source_Storage_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Storage(Tensor self, Storage source) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source) { + return at::_ops::set_source_Storage::redispatch(dispatchKeySet, self, source); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, size, stride, out); + } + + // aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::set_source_Storage_storage_offset_out::redispatch(dispatchKeySet, self, source, storage_offset, size, stride, out); + } + + // aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, int64_t storage_offset, at::IntArrayRef size, at::IntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride)); + } + + // aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor + inline at::Tensor set_symint(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride={}) { + return at::_ops::set_source_Storage_storage_offset::redispatch(dispatchKeySet, self, source, storage_offset, size, stride); + } + + // aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & source) { + return at::_ops::set_source_Tensor_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & source, at::Tensor & out) { + return at::_ops::set_source_Tensor_out::redispatch(dispatchKeySet, self, source, out); + } + + // aten::set.source_Tensor(Tensor self, Tensor source) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & source) { + return at::_ops::set_source_Tensor::redispatch(dispatchKeySet, self, source); + } + + // aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::set_out::redispatch(dispatchKeySet, self, out); + } + + // aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & set_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::set_out::redispatch(dispatchKeySet, self, out); + } + + // aten::set(Tensor self) -> Tensor + inline at::Tensor set(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self) { + return at::_ops::set::redispatch(dispatchKeySet, self); + } + + // aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::lift_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lift_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::lift_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lift_fresh_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::lift_fresh_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & lift_fresh_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::lift_fresh_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + return at::_ops::masked_fill_Scalar_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value, at::Tensor & out) { + return at::_ops::masked_fill_Scalar_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + return at::_ops::masked_fill_Tensor_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value, at::Tensor & out) { + return at::_ops::masked_fill_Tensor_out::redispatch(dispatchKeySet, self, mask, value, out); + } + + // aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_scatter_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + return at::_ops::masked_scatter_out::redispatch(dispatchKeySet, self, mask, source, out); + } + + // aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & masked_scatter_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source, at::Tensor & out) { + return at::_ops::masked_scatter_out::redispatch(dispatchKeySet, self, mask, source, out); + } + + // aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_softmax_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & mask, ::std::optional dim=::std::nullopt, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_masked_softmax_out::redispatch(dispatchKeySet, self, mask, dim, mask_type, out); + } + + // aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_softmax_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type, at::Tensor & out) { + return at::_ops::_masked_softmax_out::redispatch(dispatchKeySet, self, mask, dim, mask_type, out); + } + + // aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_softmax_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim=::std::nullopt) { + return at::_ops::_masked_softmax_backward_out::redispatch(dispatchKeySet, grad_output, output, mask, dim, out); + } + + // aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _masked_softmax_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim, at::Tensor & out) { + return at::_ops::_masked_softmax_backward_out::redispatch(dispatchKeySet, grad_output, output, mask, dim, out); + } + + // aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & put_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate=false) { + return at::_ops::put_out::redispatch(dispatchKeySet, self, index, source, accumulate, out); + } + + // aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & put_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate, at::Tensor & out) { + return at::_ops::put_out::redispatch(dispatchKeySet, self, index, source, accumulate, out); + } + + // aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + return at::_ops::index_fill_int_Scalar_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out) { + return at::_ops::index_fill_int_Scalar_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_fill_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + return at::_ops::index_fill_int_Tensor_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & index_fill_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value, at::Tensor & out) { + return at::_ops::index_fill_int_Tensor_out::redispatch(dispatchKeySet, self, dim, index, value, out); + } + + // aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_and_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_and_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_and_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_or_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_or_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_or_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_xor_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_xor_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_xor_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __lshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__lshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __lshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::__lshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __lshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__lshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __lshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::__lshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_left_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_left_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_left_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __rshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Scalar & other) { + return at::_ops::__rshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __rshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Scalar & other, at::Tensor & out) { + return at::_ops::__rshift___Scalar_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __rshift___out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other) { + return at::_ops::__rshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & __rshift___outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::__rshift___Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::bitwise_right_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bitwise_right_shift_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::bitwise_right_shift_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_from_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator, at::Tensor & out) { + return at::_ops::random_from_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::random.from(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor + inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_from::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_to_out::redispatch(dispatchKeySet, self, to, generator, out); + } + + // aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t to, ::std::optional generator, at::Tensor & out) { + return at::_ops::random_to_out::redispatch(dispatchKeySet, self, to, generator, out); + } + + // aten::random.to(Tensor self, int to, *, Generator? generator=None) -> Tensor + inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t to, ::std::optional generator=::std::nullopt) { + return at::_ops::random_to::redispatch(dispatchKeySet, self, to, generator); + } + + // aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::random_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & random_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator, at::Tensor & out) { + return at::_ops::random_out::redispatch(dispatchKeySet, self, generator, out); + } + + // aten::random(Tensor self, *, Generator? generator=None) -> Tensor + inline at::Tensor random(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, ::std::optional generator=::std::nullopt) { + return at::_ops::random::redispatch(dispatchKeySet, self, generator); + } + + // aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & uniform_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double from=0, double to=1, ::std::optional generator=::std::nullopt) { + return at::_ops::uniform_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & uniform_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double from, double to, ::std::optional generator, at::Tensor & out) { + return at::_ops::uniform_out::redispatch(dispatchKeySet, self, from, to, generator, out); + } + + // aten::uniform(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor + inline at::Tensor uniform(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double from=0, double to=1, ::std::optional generator=::std::nullopt) { + return at::_ops::uniform::redispatch(dispatchKeySet, self, from, to, generator); + } + + // aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cauchy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double median=0, double sigma=1, ::std::optional generator=::std::nullopt) { + return at::_ops::cauchy_out::redispatch(dispatchKeySet, self, median, sigma, generator, out); + } + + // aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & cauchy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double median, double sigma, ::std::optional generator, at::Tensor & out) { + return at::_ops::cauchy_out::redispatch(dispatchKeySet, self, median, sigma, generator, out); + } + + // aten::cauchy(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor + inline at::Tensor cauchy(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double median=0, double sigma=1, ::std::optional generator=::std::nullopt) { + return at::_ops::cauchy::redispatch(dispatchKeySet, self, median, sigma, generator); + } + + // aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double mean=1, double std=2, ::std::optional generator=::std::nullopt) { + return at::_ops::log_normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & log_normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out) { + return at::_ops::log_normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::log_normal(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor + inline at::Tensor log_normal(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean=1, double std=2, ::std::optional generator=::std::nullopt) { + return at::_ops::log_normal::redispatch(dispatchKeySet, self, mean, std, generator); + } + + // aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exponential_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double lambd=1, ::std::optional generator=::std::nullopt) { + return at::_ops::exponential_out::redispatch(dispatchKeySet, self, lambd, generator, out); + } + + // aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & exponential_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double lambd, ::std::optional generator, at::Tensor & out) { + return at::_ops::exponential_out::redispatch(dispatchKeySet, self, lambd, generator, out); + } + + // aten::exponential(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor + inline at::Tensor exponential(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double lambd=1, ::std::optional generator=::std::nullopt) { + return at::_ops::exponential::redispatch(dispatchKeySet, self, lambd, generator); + } + + // aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & geometric_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::geometric_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & geometric_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out) { + return at::_ops::geometric_out::redispatch(dispatchKeySet, self, p, generator, out); + } + + // aten::geometric(Tensor self, float p, *, Generator? generator=None) -> Tensor + inline at::Tensor geometric(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double p, ::std::optional generator=::std::nullopt) { + return at::_ops::geometric::redispatch(dispatchKeySet, self, p, generator); + } + + // aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t row, int64_t col, int64_t offset=0) { + return at::_ops::tril_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & tril_indices_outf(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, at::Tensor & out) { + return at::_ops::tril_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_indices_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, int64_t row, int64_t col, int64_t offset=0) { + return at::_ops::triu_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & triu_indices_outf(c10::DispatchKeySet dispatchKeySet, int64_t row, int64_t col, int64_t offset, at::Tensor & out) { + return at::_ops::triu_indices_out::redispatch(dispatchKeySet, row, col, offset, out); + } + + // aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trace_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::trace_out::redispatch(dispatchKeySet, self, out); + } + + // aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & trace_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::trace_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cholesky_solve_helper_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & A, bool upper) { + return at::_ops::_cholesky_solve_helper_out::redispatch(dispatchKeySet, self, A, upper, out); + } + + // aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _cholesky_solve_helper_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & A, bool upper, at::Tensor & out) { + return at::_ops::_cholesky_solve_helper_out::redispatch(dispatchKeySet, self, A, upper, out); + } + + // aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p=2) { + return at::_ops::dist_out::redispatch(dispatchKeySet, self, other, p, out); + } + + // aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & dist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & other, const at::Scalar & p, at::Tensor & out) { + return at::_ops::dist_out::redispatch(dispatchKeySet, self, other, p, out); + } + + // aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> () + inline void _histogramdd_bin_edges_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_bin_edges_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> () + inline void _histogramdd_bin_edges_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::TensorList out) { + return at::_ops::_histogramdd_bin_edges_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _histogramdd_from_bin_cts_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range=::std::nullopt, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_cts_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _histogramdd_from_bin_cts_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & out) { + return at::_ops::_histogramdd_from_bin_cts_out::redispatch(dispatchKeySet, self, bins, range, weight, density, out); + } + + // aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _histogramdd_from_bin_tensors_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::TensorList bins, const ::std::optional & weight={}, bool density=false) { + return at::_ops::_histogramdd_from_bin_tensors_out::redispatch(dispatchKeySet, self, bins, weight, density, out); + } + + // aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _histogramdd_from_bin_tensors_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density, at::Tensor & out) { + return at::_ops::_histogramdd_from_bin_tensors_out::redispatch(dispatchKeySet, self, bins, weight, density, out); + } + + // aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & other) { + return at::_ops::remainder_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & remainder_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & other, at::Tensor & out) { + return at::_ops::remainder_Scalar_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, at::IntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, c10::fromIntArrayRefSlow(input_sizes), dim, size, step, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step, out); + } + + // aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out) { + return at::_ops::unfold_backward_out::redispatch(dispatchKeySet, grad_in, input_sizes, dim, size, step, out); + } + + // aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double mean=0, double std=1, ::std::optional generator=::std::nullopt) { + return at::_ops::normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & normal_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out) { + return at::_ops::normal_out::redispatch(dispatchKeySet, self, mean, std, generator, out); + } + + // aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> () + inline void _amp_foreach_non_finite_check_and_unscale_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::redispatch(dispatchKeySet, self, found_inf, inv_scale, out); + } + + // aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> () + inline void _amp_foreach_non_finite_check_and_unscale_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::redispatch(dispatchKeySet, self, found_inf, inv_scale, out); + } + + // aten::_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out) + inline ::std::tuple<::std::vector,at::Tensor> _amp_foreach_non_finite_check_and_unscale(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale::redispatch(dispatchKeySet, self, found_inf, inv_scale); + } + + // aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _amp_update_scale_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + return at::_ops::_amp_update_scale_out::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval, out); + } + + // aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _amp_update_scale_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor & out) { + return at::_ops::_amp_update_scale_out::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval, out); + } + + // aten::_amp_update_scale(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out) + inline ::std::tuple _amp_update_scale(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + return at::_ops::_amp_update_scale::redispatch(dispatchKeySet, self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + } + + // aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_add_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_add_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + return at::_ops::_foreach_add_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_add_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_add_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_add_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_add_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, const at::Scalar & alpha, at::TensorList out) { + return at::_ops::_foreach_add_Tensor_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_sub_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_sub_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other, const at::Scalar & alpha=1) { + return at::_ops::_foreach_sub_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> () + inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + return at::_ops::_foreach_sub_List_out::redispatch(dispatchKeySet, self, other, alpha, out); + } + + // aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_sub_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_sub_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_sub_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_sub_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_mul_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_mul_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_mul_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_mul_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_mul_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_mul_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_mul_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + inline void _foreach_mul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, at::TensorList out) { + return at::_ops::_foreach_mul_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_div_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_div_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_div_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_div_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_div_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_div_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Tensor & other) { + return at::_ops::_foreach_div_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> () + inline void _foreach_div_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Tensor & other, at::TensorList out) { + return at::_ops::_foreach_div_Tensor_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_max_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_clamp_max_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_max_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_clamp_max_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_max_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_clamp_max_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_clamp_min_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_clamp_min_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_clamp_min_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_clamp_min_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_clamp_min_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_clamp_min_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_clamp_min_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_maximum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_maximum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_maximum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_maximum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_maximum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_maximum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_maximum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & scalar) { + return at::_ops::_foreach_minimum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + return at::_ops::_foreach_minimum_Scalar_out::redispatch(dispatchKeySet, self, scalar, out); + } + + // aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList other) { + return at::_ops::_foreach_minimum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList other, at::TensorList out) { + return at::_ops::_foreach_minimum_List_out::redispatch(dispatchKeySet, self, other, out); + } + + // aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef scalars) { + return at::_ops::_foreach_minimum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_minimum_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_minimum_ScalarList_out::redispatch(dispatchKeySet, self, scalars, out); + } + + // aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcdiv_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + return at::_ops::_foreach_addcdiv_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcdiv_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_addcdiv_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcdiv_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcdiv_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + return at::_ops::_foreach_addcdiv_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value=1) { + return at::_ops::_foreach_addcmul_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + return at::_ops::_foreach_addcmul_Scalar_out::redispatch(dispatchKeySet, self, tensor1, tensor2, value, out); + } + + // aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + return at::_ops::_foreach_addcmul_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + return at::_ops::_foreach_addcmul_ScalarList_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + return at::_ops::_foreach_addcmul_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> () + inline void _foreach_addcmul_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + return at::_ops::_foreach_addcmul_Tensor_out::redispatch(dispatchKeySet, self, tensor1, tensor2, scalars, out); + } + + // aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_abs_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_abs_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_abs_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_acos_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_acos_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_acos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_asin_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_asin_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_asin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_atan_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_atan_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_atan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_ceil_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_ceil_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_ceil_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_cos_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_cos_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_cos_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_cosh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_cosh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_cosh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erf_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erf_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_erf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erfc_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_erfc_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_erfc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_exp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_exp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_expm1_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_expm1_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_expm1_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_floor_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_floor_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_floor_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_frac_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_frac_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_frac_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + return at::_ops::_foreach_lerp_List_out::redispatch(dispatchKeySet, self, tensors1, weights, out); + } + + // aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::TensorList weights, at::TensorList out) { + return at::_ops::_foreach_lerp_List_out::redispatch(dispatchKeySet, self, tensors1, weights, out); + } + + // aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + return at::_ops::_foreach_lerp_Scalar_out::redispatch(dispatchKeySet, self, tensors1, weight, out); + } + + // aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, const at::Scalar & weight, at::TensorList out) { + return at::_ops::_foreach_lerp_Scalar_out::redispatch(dispatchKeySet, self, tensors1, weight, out); + } + + // aten::_foreach_lerp.ScalarList_out(Tensor[] self, Tensor[] tensors1, Scalar[] weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + return at::_ops::_foreach_lerp_ScalarList_out::redispatch(dispatchKeySet, self, tensors1, weight, out); + } + + // aten::_foreach_lerp.ScalarList_out(Tensor[] self, Tensor[] tensors1, Scalar[] weight, *, Tensor(a!)[] out) -> () + inline void _foreach_lerp_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList tensors1, at::ArrayRef weight, at::TensorList out) { + return at::_ops::_foreach_lerp_ScalarList_out::redispatch(dispatchKeySet, self, tensors1, weight, out); + } + + // aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_lgamma_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_lgamma_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_lgamma_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log10_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log10_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log10_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log1p_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log1p_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log1p_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log2_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_log2_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_log2_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_max_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_max_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_max_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_max_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_neg_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_neg_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_neg_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, ScalarType? dtype=None, *, Tensor(a!)[] out) -> () + inline void _foreach_norm_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & ord=2, ::std::optional dtype=::std::nullopt) { + return at::_ops::_foreach_norm_Scalar_out::redispatch(dispatchKeySet, self, ord, dtype, out); + } + + // aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, ScalarType? dtype=None, *, Tensor(a!)[] out) -> () + inline void _foreach_norm_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & ord, ::std::optional dtype, at::TensorList out) { + return at::_ops::_foreach_norm_Scalar_out::redispatch(dispatchKeySet, self, ord, dtype, out); + } + + // aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList exponent) { + return at::_ops::_foreach_pow_List_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList exponent, at::TensorList out) { + return at::_ops::_foreach_pow_List_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, const at::Scalar & exponent) { + return at::_ops::_foreach_pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, const at::Scalar & exponent, at::TensorList out) { + return at::_ops::_foreach_pow_Scalar_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::ArrayRef exponent) { + return at::_ops::_foreach_pow_ScalarList_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> () + inline void _foreach_pow_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::ArrayRef exponent, at::TensorList out) { + return at::_ops::_foreach_pow_ScalarList_out::redispatch(dispatchKeySet, self, exponent, out); + } + + // aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_reciprocal_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_reciprocal_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_reciprocal_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_round_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_round_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_round_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_rsqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_rsqrt_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_rsqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_rsqrt_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_rsqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sigmoid_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sigmoid_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sigmoid_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sign_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sign_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sign_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sin_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sin_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sin_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sinh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sinh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sinh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sqrt_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_sqrt_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_sqrt_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tan_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tan_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_tan_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tanh_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_tanh_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_tanh_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_trunc_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_trunc_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_trunc_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_zero_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self) { + return at::_ops::_foreach_zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> () + inline void _foreach_zero_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList out) { + return at::_ops::_foreach_zero_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_foreach_zero(Tensor[] self) -> Tensor[] self_out + inline ::std::vector _foreach_zero(c10::DispatchKeySet dispatchKeySet, at::TensorList self) { + return at::_ops::_foreach_zero::redispatch(dispatchKeySet, self); + } + + // aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> () + inline void _foreach_copy_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList src, bool non_blocking=false) { + return at::_ops::_foreach_copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> () + inline void _foreach_copy_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList src, bool non_blocking, at::TensorList out) { + return at::_ops::_foreach_copy_out::redispatch(dispatchKeySet, self, src, non_blocking, out); + } + + // aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32=false, bool right=false) { + return at::_ops::bucketize_Scalar_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & bucketize_outf(c10::DispatchKeySet dispatchKeySet, const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out) { + return at::_ops::bucketize_Scalar_out::redispatch(dispatchKeySet, self, boundaries, out_int32, right, out); + } + + // aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_jvp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_jvp_out::redispatch(dispatchKeySet, glu, x, dx, dim, out); + } + + // aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_jvp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim, at::Tensor & out) { + return at::_ops::glu_jvp_out::redispatch(dispatchKeySet, glu, x, dx, dim, out); + } + + // aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_backward_jvp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) { + return at::_ops::glu_backward_jvp_out::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim, out); + } + + // aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & glu_backward_jvp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim, at::Tensor & out) { + return at::_ops::glu_backward_jvp_out::redispatch(dispatchKeySet, grad_x, grad_glu, x, dgrad_glu, dx, dim, out); + } + + // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardswish_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::hardswish_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & hardswish_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::hardswish_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out) + inline ::std::tuple rrelu_with_noise_functional(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower=0.125, const at::Scalar & upper=0.3333333333333333, bool training=false, ::std::optional generator=::std::nullopt) { + return at::_ops::rrelu_with_noise_functional::redispatch(dispatchKeySet, self, noise, lower, upper, training, generator); + } + + // aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) { + return at::_ops::rrelu_with_noise_backward_out::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result, out); + } + + // aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & rrelu_with_noise_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result, at::Tensor & out) { + return at::_ops::rrelu_with_noise_backward_out::redispatch(dispatchKeySet, grad_output, self, noise, lower, upper, training, self_is_result, out); + } + + // aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_adaptive_avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & mkldnn_adaptive_avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool2d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool2d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(output_size), out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef output_size) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool3d_out::redispatch(dispatchKeySet, self, output_size, out); + } + + // aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad_output, const at::Tensor & self) { + return at::_ops::_adaptive_avg_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _adaptive_avg_pool3d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_adaptive_avg_pool3d_backward_out::redispatch(dispatchKeySet, grad_output, self, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, align_corners, scale_factors, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors, out); + } + + // aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_bilinear2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_bilinear2d_vec_out::redispatch(dispatchKeySet, input, output_size, align_corners, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size, scale_factors, out); + } + + // aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & upsample_nearest2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out) { + return at::_ops::upsample_nearest2d_vec_out::redispatch(dispatchKeySet, input, output_size, scale_factors, out); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask, out0, out1, out2); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, c10::fromIntArrayRefSlow(kernel_size), c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), output_mask, out0, out1, out2); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask, out0, out1, out2); + } + + // aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!)) + inline ::std::tuple _slow_conv2d_backward_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2) { + return at::_ops::_slow_conv2d_backward_output_mask_out::redispatch(dispatchKeySet, grad_output, self, weight, kernel_size, stride, padding, output_mask, out0, out1, out2); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_depthwise3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_depthwise3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_depthwise3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & conv_depthwise3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::conv_depthwise3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated2d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated2d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated2d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated2d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated2d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated3d_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias={}, at::IntArrayRef stride=1, at::IntArrayRef padding=0, at::IntArrayRef dilation=1) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated3d_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, at::IntArrayRef kernel_size, const ::std::optional & bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, c10::fromIntArrayRefSlow(kernel_size), bias, c10::fromIntArrayRefSlow(stride), c10::fromIntArrayRefSlow(padding), c10::fromIntArrayRefSlow(dilation), out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated3d_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias={}, c10::SymIntArrayRef stride=c10::SymInt(1), c10::SymIntArrayRef padding=c10::SymInt(0), c10::SymIntArrayRef dilation=c10::SymInt(1)) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slow_conv_dilated3d_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out) { + return at::_ops::slow_conv_dilated3d_out::redispatch(dispatchKeySet, self, weight, kernel_size, bias, stride, padding, dilation, out); + } + + // aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isinf_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::isinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & isinf_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::isinf_out::redispatch(dispatchKeySet, self, out); + } + + // aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_exp_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::linalg_matrix_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & linalg_matrix_exp_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::linalg_matrix_exp_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_intlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_intlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out) { + return at::_ops::_test_optional_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_filled_intlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, at::OptionalIntArrayRef addends) { + return at::_ops::_test_optional_filled_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_filled_intlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out) { + return at::_ops::_test_optional_filled_intlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_floatlist_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & values, ::std::optional> addends) { + return at::_ops::_test_optional_floatlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_optional_floatlist_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & values, ::std::optional> addends, at::Tensor & out) { + return at::_ops::_test_optional_floatlist_out::redispatch(dispatchKeySet, values, addends, out); + } + + // aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_warn_in_autograd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_test_warn_in_autograd_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_warn_in_autograd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_test_warn_in_autograd_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_autograd_multiple_dispatch_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_autograd_multiple_dispatch_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_autograd_multiple_dispatch_view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _test_autograd_multiple_dispatch_view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & segment_reduce_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & indices={}, const ::std::optional & offsets={}, int64_t axis=0, bool unsafe=false, const ::std::optional & initial=::std::nullopt) { + return at::_ops::segment_reduce_out::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial, out); + } + + // aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & segment_reduce_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial, at::Tensor & out) { + return at::_ops::segment_reduce_out::redispatch(dispatchKeySet, data, reduce, lengths, indices, offsets, axis, unsafe, initial, out); + } + + // aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _segment_reduce_backward_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths={}, const ::std::optional & offsets={}, int64_t axis=0, const ::std::optional & initial=::std::nullopt) { + return at::_ops::_segment_reduce_backward_out::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial, out); + } + + // aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _segment_reduce_backward_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & offsets, int64_t axis, const ::std::optional & initial, at::Tensor & out) { + return at::_ops::_segment_reduce_backward_out::redispatch(dispatchKeySet, grad, output, data, reduce, lengths, offsets, axis, initial, out); + } + + // aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_from_tensor_list_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, at::TensorList list, ::std::optional dtype=::std::nullopt, ::std::optional layout=::std::nullopt, ::std::optional device=::std::nullopt, ::std::optional pin_memory=::std::nullopt) { + return at::_ops::_nested_tensor_from_tensor_list_out::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory, out); + } + + // aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _nested_tensor_from_tensor_list_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, at::Tensor & out) { + return at::_ops::_nested_tensor_from_tensor_list_out::redispatch(dispatchKeySet, list, dtype, layout, device, pin_memory, out); + } + + // aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fw_primal_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t level) { + return at::_ops::_fw_primal_copy_out::redispatch(dispatchKeySet, self, level, out); + } + + // aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _fw_primal_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t level, at::Tensor & out) { + return at::_ops::_fw_primal_copy_out::redispatch(dispatchKeySet, self, level, out); + } + + // aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_dual_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + return at::_ops::_make_dual_copy_out::redispatch(dispatchKeySet, primal, tangent, level, out); + } + + // aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _make_dual_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & primal, const at::Tensor & tangent, int64_t level, at::Tensor & out) { + return at::_ops::_make_dual_copy_out::redispatch(dispatchKeySet, primal, tangent, level, out); + } + + // aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_as_real_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::view_as_real_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_as_real_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::view_as_real_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_as_complex_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::view_as_complex_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_as_complex_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::view_as_complex_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_conj_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _conj_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_conj_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _neg_view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_neg_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _neg_view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_neg_view_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), storage_offset.has_value() ? ::std::make_optional(c10::SymInt(*storage_offset)) : ::std::nullopt, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset=::std::nullopt) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, size, stride, storage_offset, out); + } + + // aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & as_strided_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out) { + return at::_ops::as_strided_copy_out::redispatch(dispatchKeySet, self, size, stride, storage_offset, out); + } + + // aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_broadcast_to_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::_sparse_broadcast_to_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _sparse_broadcast_to_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::_sparse_broadcast_to_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t offset=0, int64_t dim1=0, int64_t dim2=1) { + return at::_ops::diagonal_copy_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & diagonal_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out) { + return at::_ops::diagonal_copy_out::redispatch(dispatchKeySet, self, offset, dim1, dim2, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expand_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expand_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, bool implicit, at::Tensor & out) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), implicit, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expand_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit=false) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, size, implicit, out); + } + + // aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & expand_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, bool implicit, at::Tensor & out) { + return at::_ops::expand_copy_out::redispatch(dispatchKeySet, self, size, implicit, out); + } + + // aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & permute_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dims) { + return at::_ops::permute_copy_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & permute_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out) { + return at::_ops::permute_copy_out::redispatch(dispatchKeySet, self, dims, out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _reshape_alias_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _reshape_alias_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::IntArrayRef stride, at::Tensor & out) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), c10::fromIntArrayRefSlow(stride), out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _reshape_alias_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _reshape_alias_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out) { + return at::_ops::_reshape_alias_copy_out::redispatch(dispatchKeySet, self, size, stride, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, int64_t index) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, int64_t index, at::Tensor & out) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim, c10::SymInt index) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & select_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out) { + return at::_ops::select_copy_int_out::redispatch(dispatchKeySet, self, dim, index, out); + } + + // aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & detach_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::detach_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & detach_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::detach_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, int64_t step=1) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, int64_t step, at::Tensor & out) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start.has_value() ? ::std::make_optional(c10::SymInt(*start)) : ::std::nullopt, end.has_value() ? ::std::make_optional(c10::SymInt(*end)) : ::std::nullopt, step, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim=0, ::std::optional start=::std::nullopt, ::std::optional end=::std::nullopt, c10::SymInt step=1) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start, end, step, out); + } + + // aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & slice_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step, at::Tensor & out) { + return at::_ops::slice_copy_Tensor_out::redispatch(dispatchKeySet, self, dim, start, end, step, out); + } + + // aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::squeeze_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::squeeze_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::squeeze_copy_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::squeeze_copy_dim_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef dim) { + return at::_ops::squeeze_copy_dims_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & squeeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out) { + return at::_ops::squeeze_copy_dims_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & t_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::t_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & t_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::t_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & transpose_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim0, int64_t dim1) { + return at::_ops::transpose_copy_int_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & transpose_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out) { + return at::_ops::transpose_copy_int_out::redispatch(dispatchKeySet, self, dim0, dim1, out); + } + + // aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unsqueeze_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dim) { + return at::_ops::unsqueeze_copy_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unsqueeze_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dim, at::Tensor & out) { + return at::_ops::unsqueeze_copy_out::redispatch(dispatchKeySet, self, dim, out); + } + + // aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::_values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & values_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & values_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::values_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & crow_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::crow_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & crow_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::crow_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::col_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & col_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::col_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ccol_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::ccol_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & ccol_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::ccol_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_indices_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::row_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & row_indices_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::row_indices_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::IntArrayRef size) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::IntArrayRef size, at::Tensor & out) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, c10::fromIntArrayRefSlow(size), out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, c10::SymIntArrayRef size) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out) { + return at::_ops::view_copy_out::redispatch(dispatchKeySet, self, size, out); + } + + // aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, at::ScalarType dtype) { + return at::_ops::view_copy_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & view_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::ScalarType dtype, at::Tensor & out) { + return at::_ops::view_copy_dtype_out::redispatch(dispatchKeySet, self, dtype, out); + } + + // aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + return at::_ops::unfold_copy_out::redispatch(dispatchKeySet, self, dimension, size, step, out); + } + + // aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & unfold_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, int64_t dimension, int64_t size, int64_t step, at::Tensor & out) { + return at::_ops::unfold_copy_out::redispatch(dispatchKeySet, self, dimension, size, step, out); + } + + // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & alias_copy_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self) { + return at::_ops::alias_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & alias_copy_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, at::Tensor & out) { + return at::_ops::alias_copy_out::redispatch(dispatchKeySet, self, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_padded_tensor_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_padded_tensor_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size, at::Tensor & out) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size.has_value() ? ::std::make_optional(c10::fromIntArrayRefSlow(*output_size)) : ::std::nullopt, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_padded_tensor_symint_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size=::std::nullopt) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size, out); + } + + // aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & to_padded_tensor_symint_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size, at::Tensor & out) { + return at::_ops::to_padded_tensor_out::redispatch(dispatchKeySet, self, padding, output_size, out); + } + + // aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _transformer_encoder_layer_fwd_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask={}, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_transformer_encoder_layer_fwd_out::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type, out); + } + + // aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _transformer_encoder_layer_fwd_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type, at::Tensor & out) { + return at::_ops::_transformer_encoder_layer_fwd_out::redispatch(dispatchKeySet, src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type, out); + } + + // aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _native_multi_head_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out0, at::Tensor & out1, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}, bool need_weights=true, bool average_attn_weights=true, ::std::optional mask_type=::std::nullopt) { + return at::_ops::_native_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type, out0, out1); + } + + // aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!)) + inline ::std::tuple _native_multi_head_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type, at::Tensor & out0, at::Tensor & out1) { + return at::_ops::_native_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type, out0, out1); + } + + // aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _triton_scaled_dot_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p=0.0) { + return at::_ops::_triton_scaled_dot_attention_out::redispatch(dispatchKeySet, q, k, v, dropout_p, out); + } + + // aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _triton_scaled_dot_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p, at::Tensor & out) { + return at::_ops::_triton_scaled_dot_attention_out::redispatch(dispatchKeySet, q, k, v, dropout_p, out); + } + + // aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _triton_multi_head_attention_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask={}) { + return at::_ops::_triton_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, out); + } + + // aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _triton_multi_head_attention_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, at::Tensor & out) { + return at::_ops::_triton_multi_head_attention_out::redispatch(dispatchKeySet, query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, out); + } + + // aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _foobar_out(c10::DispatchKeySet dispatchKeySet, at::Tensor & out, const at::Tensor & self, bool arg1=true, bool arg2=true, bool arg3=true) { + return at::_ops::_foobar_out::redispatch(dispatchKeySet, self, arg1, arg2, arg3, out); + } + + // aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!) + inline at::Tensor & _foobar_outf(c10::DispatchKeySet dispatchKeySet, const at::Tensor & self, bool arg1, bool arg2, bool arg3, at::Tensor & out) { + return at::_ops::_foobar_out::redispatch(dispatchKeySet, self, arg1, arg2, arg3, out); + } + + // aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adam_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adam_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adam_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adam.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adam_tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adamw_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adamw_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adamw_tensor_lr_out::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adamw.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adamw_tensor_lr::redispatch(dispatchKeySet, self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + + // aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_sgd_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_tensor_lr_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_sgd_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_sgd_tensor_lr_out::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + + // aten::_fused_sgd.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_sgd_tensor_lr::redispatch(dispatchKeySet, self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + + // aten::_fused_adagrad.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adagrad_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out, Tensor[] state_steps_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector> _fused_adagrad(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + + // aten::_fused_adagrad.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_out(c10::DispatchKeySet dispatchKeySet, at::TensorList out, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_tensor_lr_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> () + inline void _fused_adagrad_outf(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + return at::_ops::_fused_adagrad_tensor_lr_out::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + + // aten::_fused_adagrad.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out) + inline ::std::tuple<::std::vector,::std::vector,::std::vector> _fused_adagrad(c10::DispatchKeySet dispatchKeySet, at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale={}, const ::std::optional & found_inf={}) { + return at::_ops::_fused_adagrad_tensor_lr::redispatch(dispatchKeySet, self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } +} // namespace redispatch + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/RegistrationDeclarations.h b/phivenv/Lib/site-packages/torch/include/ATen/RegistrationDeclarations.h new file mode 100644 index 0000000000000000000000000000000000000000..6c910b40974c827c2434d610f76ae8dc07b35566 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/RegistrationDeclarations.h @@ -0,0 +1,3167 @@ +// This file contains all native_functions that can be registered to +// and the schema string that they should be registered with + +at::Tensor _cast_Byte(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Byte(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Char(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Char(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Double(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Double(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Float(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Float(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Int(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Int(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Long(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Long(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Short(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Short(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _cast_Half(const at::Tensor & self, bool non_blocking); // {"schema": "aten::_cast_Half(Tensor self, bool non_blocking=False) -> Tensor", "dispatch": "False", "default": "True"} +void _backward(const at::Tensor & self, at::TensorList inputs, const ::std::optional & gradient, ::std::optional retain_graph, bool create_graph); // {"schema": "aten::_backward(Tensor self, Tensor[] inputs, Tensor? gradient=None, bool? retain_graph=None, bool create_graph=False) -> ()", "dispatch": "False", "default": "True"} +void set_data(at::Tensor & self, const at::Tensor & new_data); // {"schema": "aten::set_data(Tensor(a!) self, Tensor new_data) -> ()", "dispatch": "False", "default": "True"} +at::Tensor data(const at::Tensor & self); // {"schema": "aten::data(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +bool is_leaf(const at::Tensor & self); // {"schema": "aten::is_leaf(Tensor self) -> bool", "dispatch": "False", "default": "True"} +int64_t output_nr(const at::Tensor & self); // {"schema": "aten::output_nr(Tensor self) -> int", "dispatch": "False", "default": "True"} +int64_t _version(const at::Tensor & self); // {"schema": "aten::_version(Tensor self) -> int", "dispatch": "False", "default": "True"} +at::Tensor & requires_grad_(at::Tensor & self, bool requires_grad); // {"schema": "aten::requires_grad_(Tensor(a!) self, bool requires_grad=True) -> Tensor(a!)", "dispatch": "False", "default": "True"} +void retain_grad(at::Tensor & self); // {"schema": "aten::retain_grad(Tensor(a!) self) -> ()", "dispatch": "False", "default": "True"} +bool retains_grad(const at::Tensor & self); // {"schema": "aten::retains_grad(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor _fw_primal(const at::Tensor & self, int64_t level); // {"schema": "aten::_fw_primal(Tensor(a) self, int level) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor _make_dual(const at::Tensor & primal, const at::Tensor & tangent, int64_t level); // {"schema": "aten::_make_dual(Tensor(a) primal, Tensor tangent, int level) -> Tensor(a)", "dispatch": "True", "default": "True"} +::std::tuple _unpack_dual(const at::Tensor & dual, int64_t level); // {"schema": "aten::_unpack_dual(Tensor(a) dual, int level) -> (Tensor(a) primal, Tensor tangent)", "dispatch": "False", "default": "True"} +at::Tensor _new_zeros_with_same_feature_meta(const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims); // {"schema": "aten::_new_zeros_with_same_feature_meta(Tensor self, Tensor other, *, int self_num_batch_dims=0) -> Tensor", "dispatch": "True", "default": "True"} +bool _has_same_storage_numel(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::_has_same_storage_numel(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "True"} +at::Tensor & rename_(at::Tensor & self, ::std::optional names); // {"schema": "aten::rename_(Tensor(a!) self, Dimname[]? names) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor rename(const at::Tensor & self, ::std::optional names); // {"schema": "aten::rename(Tensor(a) self, Dimname[]? names) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor align_to(const at::Tensor & self, at::DimnameList names); // {"schema": "aten::align_to(Tensor(a) self, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor align_to(const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx); // {"schema": "aten::align_to.ellipsis_idx(Tensor(a) self, Dimname[] order, int ellipsis_idx) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor align_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::align_as(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector align_tensors(at::TensorList tensors); // {"schema": "aten::align_tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +void _assert_async(const at::Tensor & self); // {"schema": "aten::_assert_async(Tensor self) -> ()", "dispatch": "True", "default": "False"} +void _assert_async(const at::Tensor & self, c10::string_view assert_msg); // {"schema": "aten::_assert_async.msg(Tensor self, str assert_msg) -> ()", "dispatch": "True", "default": "False"} +void _assert_scalar(const at::Scalar & self, c10::string_view assert_msg); // {"schema": "aten::_assert_scalar(Scalar self, str assert_msg) -> ()", "dispatch": "True", "default": "True"} +at::Tensor _functional_assert_scalar(const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token); // {"schema": "aten::_functional_assert_scalar(Scalar self, str assert_msg, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _functional_assert_async(const at::Tensor & self, c10::string_view assert_msg, const at::Tensor & dep_token); // {"schema": "aten::_functional_assert_async.msg(Tensor self, str assert_msg, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "False"} +void _assert_tensor_metadata(const at::Tensor & a, at::OptionalSymIntArrayRef size, at::OptionalSymIntArrayRef stride, ::std::optional dtype, ::std::optional device, ::std::optional layout); // {"schema": "aten::_assert_tensor_metadata(Tensor a, SymInt[]? size=None, SymInt[]? stride=None, ScalarType? dtype=None, *, Device? device=None, Layout? layout=None) -> ()", "dispatch": "True", "default": "True"} +void _print(c10::string_view s); // {"schema": "aten::_print(str s) -> ()", "dispatch": "True", "default": "True"} +void sym_constrain_range(const at::Scalar & size, ::std::optional min, ::std::optional max); // {"schema": "aten::sym_constrain_range(Scalar size, *, int? min=None, int? max=None) -> ()", "dispatch": "True", "default": "True"} +void sym_constrain_range_for_size(const at::Scalar & size, ::std::optional min, ::std::optional max); // {"schema": "aten::sym_constrain_range_for_size(Scalar size, *, int? min=None, int? max=None) -> ()", "dispatch": "True", "default": "True"} +at::Tensor _functional_sym_constrain_range(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token); // {"schema": "aten::_functional_sym_constrain_range(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _functional_sym_constrain_range_for_size(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token); // {"schema": "aten::_functional_sym_constrain_range_for_size(Scalar size, int? min, int? max, Tensor dep_token) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _make_dep_token(::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::_make_dep_token(*, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor refine_names(const at::Tensor & self, at::DimnameList names); // {"schema": "aten::refine_names(Tensor(a) self, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"} +bool _use_cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank); // {"schema": "aten::_use_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank) -> bool", "dispatch": "True", "default": "False"} +bool _use_cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank); // {"schema": "aten::_use_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> bool", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity); // {"schema": "aten::_cudnn_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity); // {"schema": "aten::_cudnn_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank, bool deterministic, bool zero_infinity) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +bool _use_cudnn_rnn_flatten_weight(); // {"schema": "aten::_use_cudnn_rnn_flatten_weight() -> bool", "dispatch": "False", "default": "True"} +at::Tensor _cudnn_rnn_flatten_weight(at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional); // {"schema": "aten::_cudnn_rnn_flatten_weight(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_rnn(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state); // {"schema": "aten::_cudnn_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple> _cudnn_rnn_backward(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask); // {"schema": "aten::_cudnn_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])", "dispatch": "True", "default": "False"} +at::Tensor _cudnn_init_dropout_state(double dropout, bool train, int64_t dropout_seed, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_cudnn_init_dropout_state(float dropout, bool train, int dropout_seed, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "False"} +int64_t _debug_has_internal_overlap(const at::Tensor & self); // {"schema": "aten::_debug_has_internal_overlap(Tensor self) -> int", "dispatch": "False", "default": "True"} +::std::tuple _fused_dropout(const at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::_fused_dropout(Tensor self, float p, Generator? generator=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _masked_scale(const at::Tensor & self, const at::Tensor & mask, double scale); // {"schema": "aten::_masked_scale(Tensor self, Tensor mask, float scale) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple native_dropout(const at::Tensor & input, double p, ::std::optional train); // {"schema": "aten::native_dropout(Tensor input, float p, bool? train) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor native_dropout_backward(const at::Tensor & grad_output, const at::Tensor & mask, double scale); // {"schema": "aten::native_dropout_backward(Tensor grad_output, Tensor mask, float scale) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _sobol_engine_draw(const at::Tensor & quasi, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated, ::std::optional dtype); // {"schema": "aten::_sobol_engine_draw(Tensor quasi, int n, Tensor sobolstate, int dimension, int num_generated, ScalarType? dtype) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor & _sobol_engine_ff_(at::Tensor & self, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated); // {"schema": "aten::_sobol_engine_ff_(Tensor(a!) self, int n, Tensor sobolstate, int dimension, int num_generated) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & _sobol_engine_scramble_(at::Tensor & self, const at::Tensor & ltm, int64_t dimension); // {"schema": "aten::_sobol_engine_scramble_(Tensor(a!) self, Tensor ltm, int dimension) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & _sobol_engine_initialize_state_(at::Tensor & self, int64_t dimension); // {"schema": "aten::_sobol_engine_initialize_state_(Tensor(a!) self, int dimension) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor _reshape_from_tensor(const at::Tensor & self, const at::Tensor & shape); // {"schema": "aten::_reshape_from_tensor(Tensor self, Tensor shape) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _shape_as_tensor(const at::Tensor & self); // {"schema": "aten::_shape_as_tensor(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor feature_dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & feature_dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::feature_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor alpha_dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & alpha_dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor feature_alpha_dropout(const at::Tensor & input, double p, bool train); // {"schema": "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & feature_alpha_dropout_(at::Tensor & self, double p, bool train); // {"schema": "aten::feature_alpha_dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor abs(const at::Tensor & self); // {"schema": "aten::abs(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & abs_(at::Tensor & self); // {"schema": "aten::abs_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & abs_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::abs.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor absolute(const at::Tensor & self); // {"schema": "aten::absolute(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & absolute_(at::Tensor & self); // {"schema": "aten::absolute_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & absolute_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::absolute.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor angle(const at::Tensor & self); // {"schema": "aten::angle(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & angle_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::angle.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor view_as_real(const at::Tensor & self); // {"schema": "aten::view_as_real(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor view_as_complex(const at::Tensor & self); // {"schema": "aten::view_as_complex(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor sgn(const at::Tensor & self); // {"schema": "aten::sgn(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sgn_(at::Tensor & self); // {"schema": "aten::sgn_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sgn_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sgn.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor chalf(const at::Tensor & self, ::std::optional memory_format); // {"schema": "aten::chalf(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor real(const at::Tensor & self); // {"schema": "aten::real(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor imag(const at::Tensor & self); // {"schema": "aten::imag(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _conj(const at::Tensor & self); // {"schema": "aten::_conj(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor conj(const at::Tensor & self); // {"schema": "aten::conj(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _conj_physical(const at::Tensor & self); // {"schema": "aten::_conj_physical(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor conj_physical(const at::Tensor & self); // {"schema": "aten::conj_physical(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & conj_physical_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & conj_physical_(at::Tensor & self); // {"schema": "aten::conj_physical_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resolve_conj(const at::Tensor & self); // {"schema": "aten::resolve_conj(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor resolve_neg(const at::Tensor & self); // {"schema": "aten::resolve_neg(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _neg_view(const at::Tensor & self); // {"schema": "aten::_neg_view(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor acos(const at::Tensor & self); // {"schema": "aten::acos(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & acos_(at::Tensor & self); // {"schema": "aten::acos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & acos_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::acos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arccos(const at::Tensor & self); // {"schema": "aten::arccos(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arccos_(at::Tensor & self); // {"schema": "aten::arccos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arccos_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arccos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor avg_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad); // {"schema": "aten::avg_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor adaptive_avg_pool1d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool1d(Tensor self, int[1] output_size) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple adaptive_max_pool1d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool1d(Tensor self, int[1] output_size) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor add(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & add_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::add_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & add_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::add.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _add_relu(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _add_relu_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _add_relu_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::_add_relu.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _add_relu(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _add_relu_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::_add_relu_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor add(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::add.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & add_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::add_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor addmv(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addmv_(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmv_(Tensor(a!) self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addmv_out(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addmv.out(Tensor self, Tensor mat, Tensor vec, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addr(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addr_(at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addr_(Tensor(a!) self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addr_out(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addr.out(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor affine_grid_generator(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners); // {"schema": "aten::affine_grid_generator(Tensor theta, SymInt[] size, bool align_corners) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor affine_grid_generator_backward(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners); // {"schema": "aten::affine_grid_generator_backward(Tensor grad, SymInt[] size, bool align_corners) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _is_all_true(const at::Tensor & self); // {"schema": "aten::_is_all_true(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _is_any_true(const at::Tensor & self); // {"schema": "aten::_is_any_true(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_check_tensor(const at::Tensor & self); // {"schema": "aten::_test_check_tensor(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_functorch_fallback(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::_test_functorch_fallback(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor all(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::all.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor all(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::all.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & all_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); // {"schema": "aten::all.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & all_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::all.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor all(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::all.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & all_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); // {"schema": "aten::all.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +bool allclose(const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan); // {"schema": "aten::allclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> bool", "dispatch": "True", "default": "True"} +at::Tensor any(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::any.dim(Tensor self, int dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor any(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::any.dims(Tensor self, int[]? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & any_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out); // {"schema": "aten::any.out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & any_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::any.dims_out(Tensor self, int[]? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor any(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::any.dimname(Tensor self, Dimname dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & any_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & out); // {"schema": "aten::any.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor arange(const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::arange(Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor arange(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::arange.start(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor arange(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::arange.start_step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & arange_out(const at::Scalar & end, at::Tensor & out); // {"schema": "aten::arange.out(Scalar end, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & arange_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); // {"schema": "aten::arange.start_out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _dim_arange(const at::Tensor & like, int64_t dim); // {"schema": "aten::_dim_arange(Tensor like, int dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor argmax(const at::Tensor & self, ::std::optional dim, bool keepdim); // {"schema": "aten::argmax(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & argmax_out(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); // {"schema": "aten::argmax.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor argmin(const at::Tensor & self, ::std::optional dim, bool keepdim); // {"schema": "aten::argmin(Tensor self, int? dim=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & argmin_out(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & out); // {"schema": "aten::argmin.out(Tensor self, int? dim=None, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor acosh(const at::Tensor & self); // {"schema": "aten::acosh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & acosh_(at::Tensor & self); // {"schema": "aten::acosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & acosh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::acosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arccosh(const at::Tensor & self); // {"schema": "aten::arccosh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arccosh_(at::Tensor & self); // {"schema": "aten::arccosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arccosh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arccosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor asinh(const at::Tensor & self); // {"schema": "aten::asinh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & asinh_(at::Tensor & self); // {"schema": "aten::asinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & asinh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::asinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arcsinh(const at::Tensor & self); // {"schema": "aten::arcsinh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arcsinh_(at::Tensor & self); // {"schema": "aten::arcsinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arcsinh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arcsinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor atanh(const at::Tensor & self); // {"schema": "aten::atanh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & atanh_(at::Tensor & self); // {"schema": "aten::atanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & atanh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::atanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arctanh(const at::Tensor & self); // {"schema": "aten::arctanh(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arctanh_(at::Tensor & self); // {"schema": "aten::arctanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arctanh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arctanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor as_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided(Tensor(a) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a)", "dispatch": "True", "default": "False"} +const at::Tensor & as_strided_(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided_(Tensor(a!) self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor asin(const at::Tensor & self); // {"schema": "aten::asin(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & asin_(at::Tensor & self); // {"schema": "aten::asin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & asin_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::asin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arcsin(const at::Tensor & self); // {"schema": "aten::arcsin(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arcsin_(at::Tensor & self); // {"schema": "aten::arcsin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arcsin_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arcsin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor atan(const at::Tensor & self); // {"schema": "aten::atan(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & atan_(at::Tensor & self); // {"schema": "aten::atan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & atan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::atan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor arctan(const at::Tensor & self); // {"schema": "aten::arctan(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arctan_(at::Tensor & self); // {"schema": "aten::arctan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arctan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::arctan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor atleast_1d(const at::Tensor & self); // {"schema": "aten::atleast_1d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector atleast_1d(at::TensorList tensors); // {"schema": "aten::atleast_1d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor atleast_2d(const at::Tensor & self); // {"schema": "aten::atleast_2d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector atleast_2d(at::TensorList tensors); // {"schema": "aten::atleast_2d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor atleast_3d(const at::Tensor & self); // {"schema": "aten::atleast_3d(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector atleast_3d(at::TensorList tensors); // {"schema": "aten::atleast_3d.Sequence(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & baddbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::baddbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & baddbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::baddbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor baddbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::baddbmm.dtype(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & baddbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::baddbmm.dtype_out(Tensor self, Tensor batch1, Tensor batch2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bartlett_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::bartlett_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bartlett_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::bartlett_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantized_batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point); // {"schema": "aten::quantized_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_impl_index(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::_batch_norm_impl_index(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> (Tensor, Tensor, Tensor, Tensor, int)", "dispatch": "False", "default": "True"} +::std::tuple _batch_norm_impl_index_backward(int64_t impl_index, const at::Tensor & input, const at::Tensor & grad_output, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var_transform, bool train, double eps, ::std::array output_mask, const at::Tensor & reservedSpace); // {"schema": "aten::_batch_norm_impl_index_backward(int impl_index, Tensor input, Tensor grad_output, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var_transform, bool train, float eps, bool[3] output_mask, Tensor reservedSpace) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor bernoulli(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::bernoulli(Tensor self, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bernoulli_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::bernoulli.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bernoulli_(at::Tensor & self, const at::Tensor & p, ::std::optional generator); // {"schema": "aten::bernoulli_.Tensor(Tensor(a!) self, Tensor p, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bernoulli_(at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::bernoulli_.float(Tensor(a!) self, float p=0.5, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bernoulli(const at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::bernoulli.p(Tensor self, float p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bilinear(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & weight, const ::std::optional & bias); // {"schema": "aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor binary_cross_entropy(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & binary_cross_entropy_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & out); // {"schema": "aten::binary_cross_entropy.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor binary_cross_entropy_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & binary_cross_entropy_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::binary_cross_entropy_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor binary_cross_entropy_with_logits(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction); // {"schema": "aten::binary_cross_entropy_with_logits(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bincount(const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength); // {"schema": "aten::bincount(Tensor self, Tensor? weights=None, SymInt minlength=0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor bitwise_not(const at::Tensor & self); // {"schema": "aten::bitwise_not(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_not_(at::Tensor & self); // {"schema": "aten::bitwise_not_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_not_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::bitwise_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & copysign_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::copysign.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor copysign(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::copysign.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copysign_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::copysign_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor copysign(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::copysign.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copysign_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::copysign_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & copysign_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::copysign.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _lazy_clone(const at::Tensor & self); // {"schema": "aten::_lazy_clone(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logical_not(const at::Tensor & self); // {"schema": "aten::logical_not(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_not_(at::Tensor & self); // {"schema": "aten::logical_not_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_not_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::logical_not.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logical_xor(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_xor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_xor_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_xor_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_xor_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logical_xor.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logical_and(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_and(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_and_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_and_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_and_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logical_and.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logical_or(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_or(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logical_or_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logical_or_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logical_or_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logical_or.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor blackman_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::blackman_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor blackman_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::blackman_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bmm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::bmm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bmm_out(const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::bmm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bmm(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype); // {"schema": "aten::bmm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & bmm_out(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out); // {"schema": "aten::bmm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector broadcast_tensors(at::TensorList tensors); // {"schema": "aten::broadcast_tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor broadcast_to(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::broadcast_to(Tensor(a) self, SymInt[] size) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _sparse_broadcast_to(const at::Tensor & self, at::IntArrayRef size); // {"schema": "aten::_sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor cat(const at::ITensorListRef & tensors, int64_t dim); // {"schema": "aten::cat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cat_out(const at::ITensorListRef & tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::cat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cat(at::TensorList tensors, at::Dimname dim); // {"schema": "aten::cat.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cat_out(at::TensorList tensors, at::Dimname dim, at::Tensor & out); // {"schema": "aten::cat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concat(at::TensorList tensors, int64_t dim); // {"schema": "aten::concat(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concat_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::concat.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concat(at::TensorList tensors, at::Dimname dim); // {"schema": "aten::concat.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concat_out(at::TensorList tensors, at::Dimname dim, at::Tensor & out); // {"schema": "aten::concat.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concatenate(at::TensorList tensors, int64_t dim); // {"schema": "aten::concatenate(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concatenate_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::concatenate.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor concatenate(at::TensorList tensors, at::Dimname dim); // {"schema": "aten::concatenate.names(Tensor[] tensors, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & concatenate_out(at::TensorList tensors, at::Dimname dim, at::Tensor & out); // {"schema": "aten::concatenate.names_out(Tensor[] tensors, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor block_diag(at::TensorList tensors); // {"schema": "aten::block_diag(Tensor[] tensors) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ceil(const at::Tensor & self); // {"schema": "aten::ceil(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ceil_(at::Tensor & self); // {"schema": "aten::ceil_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ceil_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::ceil.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor chain_matmul(at::TensorList matrices); // {"schema": "aten::chain_matmul(Tensor[] matrices) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & chain_matmul_out(at::TensorList matrices, at::Tensor & out); // {"schema": "aten::chain_matmul.out(Tensor[] matrices, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::vector unsafe_chunk(const at::Tensor & self, int64_t chunks, int64_t dim); // {"schema": "aten::unsafe_chunk(Tensor self, int chunks, int dim=0) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector chunk(const at::Tensor & self, int64_t chunks, int64_t dim); // {"schema": "aten::chunk(Tensor(a -> *) self, int chunks, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector tensor_split(const at::Tensor & self, c10::SymInt sections, int64_t dim); // {"schema": "aten::tensor_split.sections(Tensor(a -> *) self, SymInt sections, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector tensor_split(const at::Tensor & self, c10::SymIntArrayRef indices, int64_t dim); // {"schema": "aten::tensor_split.indices(Tensor(a -> *) self, SymInt[] indices, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector tensor_split(const at::Tensor & self, const at::Tensor & tensor_indices_or_sections, int64_t dim); // {"schema": "aten::tensor_split.tensor_indices_or_sections(Tensor(a -> *) self, Tensor tensor_indices_or_sections, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +at::Tensor clamp(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor clamp(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & clamp_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clamp_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clamp.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & clamp_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clamp.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor clamp_max(const at::Tensor & self, const at::Scalar & max); // {"schema": "aten::clamp_max(Tensor self, Scalar max) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor clamp_max(const at::Tensor & self, const at::Tensor & max); // {"schema": "aten::clamp_max.Tensor(Tensor self, Tensor max) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & clamp_max_(at::Tensor & self, const at::Scalar & max); // {"schema": "aten::clamp_max_(Tensor(a!) self, Scalar max) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_max_(at::Tensor & self, const at::Tensor & max); // {"schema": "aten::clamp_max_.Tensor(Tensor(a!) self, Tensor max) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_max_out(const at::Tensor & self, const at::Scalar & max, at::Tensor & out); // {"schema": "aten::clamp_max.out(Tensor self, Scalar max, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & clamp_max_out(const at::Tensor & self, const at::Tensor & max, at::Tensor & out); // {"schema": "aten::clamp_max.Tensor_out(Tensor self, Tensor max, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor clamp_min(const at::Tensor & self, const at::Scalar & min); // {"schema": "aten::clamp_min(Tensor self, Scalar min) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor clamp_min(const at::Tensor & self, const at::Tensor & min); // {"schema": "aten::clamp_min.Tensor(Tensor self, Tensor min) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & clamp_min_(at::Tensor & self, const at::Scalar & min); // {"schema": "aten::clamp_min_(Tensor(a!) self, Scalar min) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_min_(at::Tensor & self, const at::Tensor & min); // {"schema": "aten::clamp_min_.Tensor(Tensor(a!) self, Tensor min) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clamp_min_out(const at::Tensor & self, const at::Scalar & min, at::Tensor & out); // {"schema": "aten::clamp_min.out(Tensor self, Scalar min, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & clamp_min_out(const at::Tensor & self, const at::Tensor & min, at::Tensor & out); // {"schema": "aten::clamp_min.Tensor_out(Tensor self, Tensor min, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor clip(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip(Tensor self, Scalar? min=None, Scalar? max=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor clip(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip.Tensor(Tensor self, Tensor? min=None, Tensor? max=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & clip_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip_(Tensor(a!) self, Scalar? min=None, Scalar? max=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & clip_(at::Tensor & self, const ::std::optional & min, const ::std::optional & max); // {"schema": "aten::clip_.Tensor(Tensor(a!) self, Tensor? min=None, Tensor? max=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & clip_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clip.out(Tensor self, Scalar? min=None, Scalar? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & clip_out(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max, at::Tensor & out); // {"schema": "aten::clip.Tensor_out(Tensor self, Tensor? min=None, Tensor? max=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +bool cudnn_is_acceptable(const at::Tensor & self); // {"schema": "aten::cudnn_is_acceptable(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor complex(const at::Tensor & real, const at::Tensor & imag); // {"schema": "aten::complex(Tensor real, Tensor imag) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & complex_out(const at::Tensor & real, const at::Tensor & imag, at::Tensor & out); // {"schema": "aten::complex.out(Tensor real, Tensor imag, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor polar(const at::Tensor & abs, const at::Tensor & angle); // {"schema": "aten::polar(Tensor abs, Tensor angle) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & polar_out(const at::Tensor & abs, const at::Tensor & angle, at::Tensor & out); // {"schema": "aten::polar.out(Tensor abs, Tensor angle, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor constant_pad_nd(const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value); // {"schema": "aten::constant_pad_nd(Tensor self, SymInt[] pad, Scalar value=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor contiguous(const at::Tensor & self, at::MemoryFormat memory_format); // {"schema": "aten::contiguous(Tensor(a) self, *, MemoryFormat memory_format=contiguous_format) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups); // {"schema": "aten::convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::convolution_backward(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor convolution_overrideable(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups); // {"schema": "aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward_overrideable(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)", "dispatch": "True", "default": "True"} +at::Tensor _convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32); // {"schema": "aten::_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled); // {"schema": "aten::_convolution.deprecated(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, int[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _convolution_mode(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_convolution_mode(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, str padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _convolution_double_backward(const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::_convolution_double_backward(Tensor? ggI, Tensor? ggW, Tensor? ggb, Tensor gO, Tensor weight, Tensor self, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor conv1d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv2d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv3d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv3d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv1d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv1d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, str padding=\"valid\", SymInt[1] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv2d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv2d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, str padding=\"valid\", SymInt[2] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv3d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::conv3d.padding(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, str padding=\"valid\", SymInt[3] dilation=1, SymInt groups=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv_tbc(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad); // {"schema": "aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple conv_tbc_backward(const at::Tensor & self, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, int64_t pad); // {"schema": "aten::conv_tbc_backward(Tensor self, Tensor input, Tensor weight, Tensor bias, int pad) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor conv_transpose1d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor? bias=None, SymInt[1] stride=1, SymInt[1] padding=0, SymInt[1] output_padding=0, SymInt groups=1, SymInt[1] dilation=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv_transpose2d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose2d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt groups=1, SymInt[2] dilation=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor conv_transpose3d(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_transpose3d.input(Tensor input, Tensor weight, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt groups=1, SymInt[3] dilation=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor copy(const at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy(Tensor self, Tensor src, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copy_(at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _copy_from(const at::Tensor & self, const at::Tensor & dst, bool non_blocking); // {"schema": "aten::_copy_from(Tensor self, Tensor dst, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _copy_from_and_resize(const at::Tensor & self, const at::Tensor & dst); // {"schema": "aten::_copy_from_and_resize(Tensor self, Tensor dst) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cos(const at::Tensor & self); // {"schema": "aten::cos(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cos_(at::Tensor & self); // {"schema": "aten::cos_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cos_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::cos.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cosh(const at::Tensor & self); // {"schema": "aten::cosh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cosh_(at::Tensor & self); // {"schema": "aten::cosh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cosh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::cosh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cosine_embedding_loss(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction); // {"schema": "aten::cosine_embedding_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor count_nonzero(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::count_nonzero.dim_IntList(Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor count_nonzero(const at::Tensor & self, ::std::optional dim); // {"schema": "aten::count_nonzero(Tensor self, int? dim=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor cov(const at::Tensor & self, int64_t correction, const ::std::optional & fweights, const ::std::optional & aweights); // {"schema": "aten::cov(Tensor self, *, int correction=1, Tensor? fweights=None, Tensor? aweights=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor corrcoef(const at::Tensor & self); // {"schema": "aten::corrcoef(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cudnn_affine_grid_generator(const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W); // {"schema": "aten::cudnn_affine_grid_generator(Tensor theta, int N, int C, int H, int W) -> Tensor grid", "dispatch": "True", "default": "False"} +at::Tensor cudnn_affine_grid_generator_backward(const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W); // {"schema": "aten::cudnn_affine_grid_generator_backward(Tensor grad, int N, int C, int H, int W) -> Tensor grad_theta", "dispatch": "True", "default": "False"} +::std::tuple cudnn_batch_norm(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon); // {"schema": "aten::cudnn_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple cudnn_batch_norm_backward(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace); // {"schema": "aten::cudnn_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32); // {"schema": "aten::cudnn_convolution(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & cudnn_convolution_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out); // {"schema": "aten::cudnn_convolution.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution_transpose(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32); // {"schema": "aten::cudnn_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _mps_convolution_transpose(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_mps_convolution_transpose(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mps_convolution_transpose_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::mps_convolution_transpose_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution_relu(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::cudnn_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cudnn_convolution_add_relu(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::cudnn_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cudnn_grid_sampler(const at::Tensor & self, const at::Tensor & grid); // {"schema": "aten::cudnn_grid_sampler(Tensor self, Tensor grid) -> Tensor output", "dispatch": "True", "default": "False"} +::std::tuple cudnn_grid_sampler_backward(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output); // {"schema": "aten::cudnn_grid_sampler_backward(Tensor self, Tensor grid, Tensor grad_output) -> (Tensor grad_self, Tensor grad_grid)", "dispatch": "True", "default": "False"} +::std::tuple cummax(const at::Tensor & self, int64_t dim); // {"schema": "aten::cummax(Tensor self, int dim) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple cummax_out(const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"} +::std::tuple cummax(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple cummax_out(const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +void _cummax_helper(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim); // {"schema": "aten::_cummax_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()", "dispatch": "True", "default": "False"} +::std::tuple cummin(const at::Tensor & self, int64_t dim); // {"schema": "aten::cummin(Tensor self, int dim) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple cummin_out(const at::Tensor & self, int64_t dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummin.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"} +::std::tuple cummin(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::cummin.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple cummin_out(const at::Tensor & self, at::Dimname dim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::cummin.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +void _cummin_helper(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim); // {"schema": "aten::_cummin_helper(Tensor self, Tensor(a!) values, Tensor(b!) indices, int dim) -> ()", "dispatch": "True", "default": "False"} +at::Tensor cummaxmin_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & indices, int64_t dim); // {"schema": "aten::cummaxmin_backward(Tensor grad, Tensor input, Tensor indices, int dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cumprod(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumprod(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cumprod_(at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumprod_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cumprod_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumprod.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cumprod(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumprod.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cumprod_(at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumprod_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & cumprod_out(const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor cumprod_backward(const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output); // {"schema": "aten::cumprod_backward(Tensor grad, Tensor input, int dim, Tensor output) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cumsum(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cumsum_(at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::cumsum_(Tensor(a!) self, int dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cumsum_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumsum.out(Tensor self, int dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cumsum(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumsum.dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cumsum_(at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::cumsum_.dimname(Tensor(a!) self, Dimname dim, *, ScalarType? dtype=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & cumsum_out(const at::Tensor & self, at::Dimname dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::cumsum.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor cumulative_trapezoid(const at::Tensor & y, const at::Tensor & x, int64_t dim); // {"schema": "aten::cumulative_trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cumulative_trapezoid(const at::Tensor & y, const at::Scalar & dx, int64_t dim); // {"schema": "aten::cumulative_trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, int64_t reduction, bool zero_infinity); // {"schema": "aten::ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, int64_t reduction, bool zero_infinity); // {"schema": "aten::ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _ctc_loss(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss.Tensor(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _ctc_loss_backward(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss_backward(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _ctc_loss_backward(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity); // {"schema": "aten::_ctc_loss_backward.Tensor(Tensor grad, Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor diag_embed(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diag_embed(Tensor self, int offset=0, int dim1=-2, int dim2=-1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor diagflat(const at::Tensor & self, int64_t offset); // {"schema": "aten::diagflat(Tensor self, int offset=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor diagonal(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal(Tensor(a) self, int offset=0, int dim1=0, int dim2=1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor linalg_diagonal(const at::Tensor & A, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::linalg_diagonal(Tensor(a) A, *, int offset=0, int dim1=-2, int dim2=-1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor diagonal(const at::Tensor & self, at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset); // {"schema": "aten::diagonal.Dimname(Tensor(a) self, *, Dimname outdim, Dimname dim1, Dimname dim2, int offset=0) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor diagonal_backward(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_backward(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fill_diagonal_(at::Tensor & self, const at::Scalar & fill_value, bool wrap); // {"schema": "aten::fill_diagonal_(Tensor(a!) self, Scalar fill_value, bool wrap=False) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor diff(const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append); // {"schema": "aten::diff(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & diff_out(const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append, at::Tensor & out); // {"schema": "aten::diff.out(Tensor self, int n=1, int dim=-1, Tensor? prepend=None, Tensor? append=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, const ::std::optional & spacing, ::std::optional dim, int64_t edge_order); // {"schema": "aten::gradient.scalarint(Tensor self, *, Scalar? spacing=None, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, const at::Scalar & spacing, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.scalararray(Tensor self, *, Scalar spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.array(Tensor self, *, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::ArrayRef spacing, ::std::optional dim, int64_t edge_order); // {"schema": "aten::gradient.scalarrayint(Tensor self, *, Scalar[] spacing, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::ArrayRef spacing, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.scalarrayarray(Tensor self, *, Scalar[] spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::TensorList spacing, ::std::optional dim, int64_t edge_order); // {"schema": "aten::gradient.tensorarrayint(Tensor self, *, Tensor[] spacing, int? dim=None, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector gradient(const at::Tensor & self, at::TensorList spacing, at::IntArrayRef dim, int64_t edge_order); // {"schema": "aten::gradient.tensorarray(Tensor self, *, Tensor[] spacing, int[] dim, int edge_order=1) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor div(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::div.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::div.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor div(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::div.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::div_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out); // {"schema": "aten::div.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor div(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::div.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor div(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::div.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & div_(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::div_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::divide.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::divide.Tensor_mode(Tensor self, Tensor other, *, str? rounding_mode) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode); // {"schema": "aten::divide_.Tensor_mode(Tensor(a!) self, Tensor other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & divide_out(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode, at::Tensor & out); // {"schema": "aten::divide.out_mode(Tensor self, Tensor other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor divide(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::divide.Scalar_mode(Tensor self, Scalar other, *, str? rounding_mode) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & divide_(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode); // {"schema": "aten::divide_.Scalar_mode(Tensor(a!) self, Scalar other, *, str? rounding_mode) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor true_divide(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::true_divide.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & true_divide_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::true_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & true_divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::true_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor true_divide(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::true_divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & true_divide_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::true_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor dot(const at::Tensor & self, const at::Tensor & tensor); // {"schema": "aten::dot(Tensor self, Tensor tensor) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & dot_out(const at::Tensor & self, const at::Tensor & tensor, at::Tensor & out); // {"schema": "aten::dot.out(Tensor self, Tensor tensor, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor vdot(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::vdot(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & vdot_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::vdot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor einsum(c10::string_view equation, at::TensorList tensors, at::OptionalIntArrayRef path); // {"schema": "aten::einsum(str equation, Tensor[] tensors, *, int[]? path=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor embedding(const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse); // {"schema": "aten::embedding(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor embedding_backward(const at::Tensor & grad, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse); // {"schema": "aten::embedding_backward(Tensor grad, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor embedding_dense_backward(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq); // {"schema": "aten::embedding_dense_backward(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & embedding_renorm_(at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type); // {"schema": "aten::embedding_renorm_(Tensor(a!) self, Tensor indices, float max_norm, float norm_type) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor embedding_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq); // {"schema": "aten::embedding_sparse_backward(Tensor grad, Tensor indices, int num_weights, int padding_idx, bool scale_grad_by_freq) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _embedding_bag_forward_only(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx); // {"schema": "aten::_embedding_bag_forward_only(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _rowwise_prune(const at::Tensor & weight, const at::Tensor & mask, at::ScalarType compressed_indices_dtype); // {"schema": "aten::_rowwise_prune(Tensor weight, Tensor mask, ScalarType compressed_indices_dtype) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor row_stack(at::TensorList tensors); // {"schema": "aten::row_stack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & row_stack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::row_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple embedding_bag(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset); // {"schema": "aten::embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple embedding_bag(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, ::std::optional padding_idx); // {"schema": "aten::embedding_bag.padding_idx(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, bool include_last_offset, int? padding_idx) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _embedding_bag(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx); // {"schema": "aten::_embedding_bag(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _embedding_bag_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, bool sparse, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _embedding_bag_sparse_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_sparse_backward(Tensor grad, Tensor indices, Tensor offsets, Tensor offset2bag, Tensor bag_size, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _embedding_bag_dense_backward(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx); // {"schema": "aten::_embedding_bag_dense_backward(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _embedding_bag_per_sample_weights_backward(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx); // {"schema": "aten::_embedding_bag_per_sample_weights_backward(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor empty(at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor empty(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty.memory_format(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor empty_permuted(c10::SymIntArrayRef size, at::IntArrayRef physical_layout, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::empty_permuted(SymInt[] size, int[] physical_layout, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_empty(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_empty(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_empty_strided(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_empty_strided(Tensor self, SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_full(const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_full(Tensor self, SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_zeros(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_zeros(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor new_ones(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::new_ones(Tensor self, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _empty_affine_quantized(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, double scale, int64_t zero_point, ::std::optional memory_format); // {"schema": "aten::_empty_affine_quantized(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _empty_per_channel_affine_quantized(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::_empty_per_channel_affine_quantized(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=contiguous_format) -> Tensor", "dispatch": "True", "default": "False"} +const at::Tensor & resize_(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format); // {"schema": "aten::resize_(Tensor(a!) self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +const at::Tensor & _resize_output_(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device); // {"schema": "aten::_resize_output_(Tensor(a!) self, SymInt[] size, Device device) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor empty_quantized(at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty_quantized(int[] size, Tensor qtensor, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & empty_out(c10::SymIntArrayRef size, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty.out(SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor empty_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::empty_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor empty_strided(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::empty_strided(SymInt[] size, SymInt[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor erf(const at::Tensor & self); // {"schema": "aten::erf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & erf_(at::Tensor & self); // {"schema": "aten::erf_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & erf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor erfc(const at::Tensor & self); // {"schema": "aten::erfc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & erfc_(at::Tensor & self); // {"schema": "aten::erfc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & erfc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor exp(const at::Tensor & self); // {"schema": "aten::exp(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & exp_(at::Tensor & self); // {"schema": "aten::exp_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & exp_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor exp2(const at::Tensor & self); // {"schema": "aten::exp2(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & exp2_(at::Tensor & self); // {"schema": "aten::exp2_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & exp2_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor expm1(const at::Tensor & self); // {"schema": "aten::expm1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & expm1_(at::Tensor & self); // {"schema": "aten::expm1_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & expm1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor expand(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit); // {"schema": "aten::expand(Tensor(a) self, SymInt[] size, *, bool implicit=False) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor expand_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor eye(c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::eye(SymInt n, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor eye(c10::SymInt n, c10::SymInt m, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::eye.m(SymInt n, SymInt m, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & eye_out(c10::SymInt n, at::Tensor & out); // {"schema": "aten::eye.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & eye_out(c10::SymInt n, c10::SymInt m, at::Tensor & out); // {"schema": "aten::eye.m_out(SymInt n, SymInt m, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim); // {"schema": "aten::flatten.using_ints(Tensor(a) self, int start_dim=0, int end_dim=-1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor flatten(const at::Tensor & self, int64_t start_dim, int64_t end_dim, at::Dimname out_dim); // {"schema": "aten::flatten.named_out_dim(Tensor(a) self, int start_dim, int end_dim, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor flatten(const at::Tensor & self, at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim); // {"schema": "aten::flatten.using_names(Tensor(a) self, Dimname start_dim, Dimname end_dim, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor flatten(const at::Tensor & self, at::DimnameList dims, at::Dimname out_dim); // {"schema": "aten::flatten.DimnameList(Tensor(a) self, Dimname[] dims, Dimname out_dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor unflatten(const at::Tensor & self, int64_t dim, c10::SymIntArrayRef sizes); // {"schema": "aten::unflatten.int(Tensor(a) self, int dim, SymInt[] sizes) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor unflatten(const at::Tensor & self, at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names); // {"schema": "aten::unflatten.Dimname(Tensor(a) self, Dimname dim, SymInt[] sizes, Dimname[] names) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor fill(const at::Tensor & self, const at::Scalar & value); // {"schema": "aten::fill.Scalar(Tensor self, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor fill(const at::Tensor & self, const at::Tensor & value); // {"schema": "aten::fill.Tensor(Tensor self, Tensor value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fill_(at::Tensor & self, const at::Scalar & value); // {"schema": "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & fill_(at::Tensor & self, const at::Tensor & value); // {"schema": "aten::fill_.Tensor(Tensor(a!) self, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor floor(const at::Tensor & self); // {"schema": "aten::floor(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & floor_(at::Tensor & self); // {"schema": "aten::floor_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & floor_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::floor.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor floor_divide(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::floor_divide(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & floor_divide_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::floor_divide_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & floor_divide_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::floor_divide.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor floor_divide(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::floor_divide.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & floor_divide_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::floor_divide_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor frac(const at::Tensor & self); // {"schema": "aten::frac(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & frac_(at::Tensor & self); // {"schema": "aten::frac_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & frac_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::frac.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor full(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::full.names(int[] size, Scalar fill_value, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor full(c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::full(SymInt[] size, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & full_out(c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out); // {"schema": "aten::full.out(SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor full_like(const at::Tensor & self, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::full_like(Tensor self, Scalar fill_value, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor from_file(c10::string_view filename, ::std::optional shared, ::std::optional size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::from_file(str filename, bool? shared=None, int? size=0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & gcd_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::gcd.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gcd(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gcd(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gcd_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gcd_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lcm_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::lcm.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lcm(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lcm(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lcm_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lcm_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor grid_sampler(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor grid_sampler_2d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler_2d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple grid_sampler_2d_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask); // {"schema": "aten::grid_sampler_2d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _grid_sampler_2d_cpu_fallback(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::_grid_sampler_2d_cpu_fallback(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple _grid_sampler_2d_cpu_fallback_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::_grid_sampler_2d_cpu_fallback_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor grid_sampler_3d(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners); // {"schema": "aten::grid_sampler_3d(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple grid_sampler_3d_backward(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask); // {"schema": "aten::grid_sampler_3d_backward(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor hann_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hann_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hann_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hann_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, bool periodic, double alpha, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window.periodic_alpha(int window_length, bool periodic, float alpha, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hamming_window(int64_t window_length, bool periodic, double alpha, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::hamming_window.periodic_alpha_beta(int window_length, bool periodic, float alpha, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor kaiser_window(int64_t window_length, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::kaiser_window(int window_length, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor kaiser_window(int64_t window_length, bool periodic, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::kaiser_window.periodic(int window_length, bool periodic, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor kaiser_window(int64_t window_length, bool periodic, double beta, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::kaiser_window.beta(int window_length, bool periodic, float beta, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor hinge_embedding_loss(const at::Tensor & self, const at::Tensor & target, double margin, int64_t reduction); // {"schema": "aten::hinge_embedding_loss(Tensor self, Tensor target, float margin=1.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor group_norm(const at::Tensor & input, int64_t num_groups, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enabled); // {"schema": "aten::group_norm(Tensor input, int num_groups, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enabled=True) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple native_group_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps); // {"schema": "aten::native_group_norm(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple native_group_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask); // {"schema": "aten::native_group_norm_backward(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _fft_r2c(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided); // {"schema": "aten::_fft_r2c(Tensor self, int[] dim, int normalization, bool onesided) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fft_r2c_out(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided, at::Tensor & out); // {"schema": "aten::_fft_r2c.out(Tensor self, int[] dim, int normalization, bool onesided, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _fft_c2r(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size); // {"schema": "aten::_fft_c2r(Tensor self, int[] dim, int normalization, SymInt last_dim_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fft_c2r_out(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size, at::Tensor & out); // {"schema": "aten::_fft_c2r.out(Tensor self, int[] dim, int normalization, SymInt last_dim_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _fft_c2c(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward); // {"schema": "aten::_fft_c2c(Tensor self, SymInt[] dim, int normalization, bool forward) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fft_c2c_out(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward, at::Tensor & out); // {"schema": "aten::_fft_c2c.out(Tensor self, SymInt[] dim, int normalization, bool forward, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +void _validate_compressed_sparse_indices(bool is_crow, const at::Tensor & compressed_idx, const at::Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz); // {"schema": "aten::_validate_compressed_sparse_indices(bool is_crow, Tensor compressed_idx, Tensor plain_idx, int cdim, int dim, int nnz) -> ()", "dispatch": "True", "default": "False"} +int64_t _cufft_get_plan_cache_size(at::DeviceIndex device_index); // {"schema": "aten::_cufft_get_plan_cache_size(DeviceIndex device_index) -> int", "dispatch": "False", "default": "True"} +int64_t _cufft_get_plan_cache_max_size(at::DeviceIndex device_index); // {"schema": "aten::_cufft_get_plan_cache_max_size(DeviceIndex device_index) -> int", "dispatch": "False", "default": "True"} +void _cufft_set_plan_cache_max_size(at::DeviceIndex device_index, int64_t max_size); // {"schema": "aten::_cufft_set_plan_cache_max_size(DeviceIndex device_index, int max_size) -> ()", "dispatch": "False", "default": "True"} +void _cufft_clear_plan_cache(at::DeviceIndex device_index); // {"schema": "aten::_cufft_clear_plan_cache(DeviceIndex device_index) -> ()", "dispatch": "False", "default": "True"} +at::Tensor index(const at::Tensor & self, const c10::List<::std::optional> & indices); // {"schema": "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_out(const at::Tensor & self, const c10::List<::std::optional> & indices, at::Tensor & out); // {"schema": "aten::index.Tensor_out(Tensor self, Tensor?[] indices, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _unsafe_index(const at::Tensor & self, const c10::List<::std::optional> & indices); // {"schema": "aten::_unsafe_index.Tensor(Tensor self, Tensor?[] indices) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _unsafe_masked_index(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Scalar & fill); // {"schema": "aten::_unsafe_masked_index(Tensor self, Tensor mask, Tensor?[] indices, Scalar fill) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _unsafe_masked_index_put_accumulate(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Tensor & values); // {"schema": "aten::_unsafe_masked_index_put_accumulate(Tensor self, Tensor mask, Tensor?[] indices, Tensor values) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_copy_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, at::Tensor & out); // {"schema": "aten::index_copy.out(Tensor self, int dim, Tensor index, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & index_copy_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy_(Tensor(a!) self, int dim, Tensor index, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_copy(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy(Tensor self, int dim, Tensor index, Tensor source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_copy_(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy_.dimname(Tensor(a!) self, Dimname dim, Tensor index, Tensor source) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor index_copy(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source); // {"schema": "aten::index_copy.dimname(Tensor self, Dimname dim, Tensor index, Tensor source) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & index_put_(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate); // {"schema": "aten::index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_put(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate); // {"schema": "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _unsafe_index_put(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate); // {"schema": "aten::_unsafe_index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _index_put_impl_(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe); // {"schema": "aten::_index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor instance_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled); // {"schema": "aten::instance_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool use_input_stats, float momentum, float eps, bool cudnn_enabled) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor isclose(const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan); // {"schema": "aten::isclose(Tensor self, Tensor other, float rtol=1e-05, float atol=1e-08, bool equal_nan=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & isin_out(const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out); // {"schema": "aten::isin.Tensor_Tensor_out(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isin(const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert); // {"schema": "aten::isin.Tensor_Tensor(Tensor elements, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isin_out(const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert, at::Tensor & out); // {"schema": "aten::isin.Tensor_Scalar_out(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isin(const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert); // {"schema": "aten::isin.Tensor_Scalar(Tensor elements, Scalar test_element, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isin_out(const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert, at::Tensor & out); // {"schema": "aten::isin.Scalar_Tensor_out(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isin(const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert); // {"schema": "aten::isin.Scalar_Tensor(Scalar element, Tensor test_elements, *, bool assume_unique=False, bool invert=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor isnan(const at::Tensor & self); // {"schema": "aten::isnan(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +bool is_distributed(const at::Tensor & self); // {"schema": "aten::is_distributed(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_floating_point(const at::Tensor & self); // {"schema": "aten::is_floating_point(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_complex(const at::Tensor & self); // {"schema": "aten::is_complex(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_conj(const at::Tensor & self); // {"schema": "aten::is_conj(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool _is_zerotensor(const at::Tensor & self); // {"schema": "aten::_is_zerotensor(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_neg(const at::Tensor & self); // {"schema": "aten::is_neg(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor isreal(const at::Tensor & self); // {"schema": "aten::isreal(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +bool is_nonzero(const at::Tensor & self); // {"schema": "aten::is_nonzero(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_same_size(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::is_same_size(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "True"} +bool is_signed(const at::Tensor & self); // {"schema": "aten::is_signed(Tensor self) -> bool", "dispatch": "False", "default": "True"} +bool is_inference(const at::Tensor & self); // {"schema": "aten::is_inference(Tensor self) -> bool", "dispatch": "False", "default": "True"} +at::Tensor kl_div(const at::Tensor & self, const at::Tensor & target, int64_t reduction, bool log_target); // {"schema": "aten::kl_div(Tensor self, Tensor target, int reduction=Mean, *, bool log_target=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor kron(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::kron(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & kron_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::kron.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple kthvalue(const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim); // {"schema": "aten::kthvalue(Tensor self, SymInt k, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple kthvalue_out(const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::kthvalue.values(Tensor self, SymInt k, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple kthvalue(const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim); // {"schema": "aten::kthvalue.dimname(Tensor self, SymInt k, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple kthvalue_out(const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::kthvalue.dimname_out(Tensor self, SymInt k, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor layer_norm(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enable); // {"schema": "aten::layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple native_layer_norm(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps); // {"schema": "aten::native_layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple native_layer_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask); // {"schema": "aten::native_layer_norm_backward(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor rms_norm(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, ::std::optional eps); // {"schema": "aten::rms_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _fused_rms_norm(const at::Tensor & input, int64_t normalized_shape_ndim, const at::Tensor & weight, double eps); // {"schema": "aten::_fused_rms_norm(Tensor input, int normalized_shape_ndim, Tensor weight, float eps) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor nan_to_num(const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf); // {"schema": "aten::nan_to_num(Tensor self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & nan_to_num_(at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf); // {"schema": "aten::nan_to_num_(Tensor(a!) self, float? nan=None, float? posinf=None, float? neginf=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & nan_to_num_out(const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf, at::Tensor & out); // {"schema": "aten::nan_to_num.out(Tensor self, float? nan=None, float? posinf=None, float? neginf=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor linear(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias); // {"schema": "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple linear_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask); // {"schema": "aten::linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & linear_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out); // {"schema": "aten::linear.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor mkldnn_linear(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias); // {"schema": "aten::mkldnn_linear(Tensor self, Tensor weight, Tensor? bias=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_linear_backward_input(at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight); // {"schema": "aten::mkldnn_linear_backward_input(int[] input_size, Tensor grad_output, Tensor weight) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mkldnn_linear_backward_weights(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined); // {"schema": "aten::mkldnn_linear_backward_weights(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple mkldnn_linear_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask); // {"schema": "aten::mkldnn_linear_backward(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _cslt_compress(const at::Tensor & input); // {"schema": "aten::_cslt_compress(Tensor input) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _cslt_sparse_mm(const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias, const ::std::optional & alpha, ::std::optional out_dtype, bool transpose_result, int64_t alg_id, int64_t split_k, int64_t split_k_mode); // {"schema": "aten::_cslt_sparse_mm(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False, int alg_id=0, int split_k=1, int split_k_mode=-1) -> Tensor", "dispatch": "True", "default": "False"} +int64_t _cslt_sparse_mm_search(const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias, const ::std::optional & alpha, ::std::optional out_dtype, bool transpose_result); // {"schema": "aten::_cslt_sparse_mm_search(Tensor compressed_A, Tensor dense_B, Tensor? bias=None, Tensor? alpha=None, ScalarType? out_dtype=None, bool transpose_result=False) -> int", "dispatch": "True", "default": "False"} +::std::tuple _sparse_semi_structured_tile(const at::Tensor & input, c10::string_view algorithm, bool use_cutlass); // {"schema": "aten::_sparse_semi_structured_tile(Tensor input, str algorithm=\"\", bool use_cutlass=True) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _sparse_semi_structured_apply(const at::Tensor & input, const at::Tensor & thread_masks); // {"schema": "aten::_sparse_semi_structured_apply(Tensor input, Tensor thread_masks) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_apply_dense(const at::Tensor & input, const at::Tensor & thread_masks); // {"schema": "aten::_sparse_semi_structured_apply_dense(Tensor input, Tensor thread_masks) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_linear(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & meta, const ::std::optional & bias, ::std::optional activation, ::std::optional out_dtype); // {"schema": "aten::_sparse_semi_structured_linear(Tensor input, Tensor weight, Tensor meta, *, Tensor? bias=None, str? activation=None, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_mm(const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, ::std::optional out_dtype); // {"schema": "aten::_sparse_semi_structured_mm(Tensor mat1, Tensor mat1_meta, Tensor mat2, *, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_semi_structured_addmm(const at::Tensor & input, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, const at::Scalar & alpha, const at::Scalar & beta, ::std::optional out_dtype); // {"schema": "aten::_sparse_semi_structured_addmm(Tensor input, Tensor mat1, Tensor mat1_meta, Tensor mat2, *, Scalar alpha=1, Scalar beta=1, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _mixed_dtypes_linear(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & scale, const ::std::optional & bias, ::std::optional activation); // {"schema": "aten::_mixed_dtypes_linear(Tensor input, Tensor weight, Tensor scale, *, Tensor? bias=None, str? activation=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor fbgemm_linear_int8_weight_fp32_activation(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias); // {"schema": "aten::fbgemm_linear_int8_weight_fp32_activation(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_int8_weight(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias); // {"schema": "aten::fbgemm_linear_int8_weight(Tensor input, Tensor weight, Tensor packed, Tensor col_offsets, Scalar weight_scale, Scalar weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple fbgemm_linear_quantize_weight(const at::Tensor & input); // {"schema": "aten::fbgemm_linear_quantize_weight(Tensor input) -> (Tensor, Tensor, float, int)", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_pack_gemm_matrix_fp16(const at::Tensor & input); // {"schema": "aten::fbgemm_pack_gemm_matrix_fp16(Tensor input) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _wrapped_linear_prepack(const at::Tensor & weight, const at::Tensor & weight_scale, const at::Tensor & weight_zero_point, const at::Tensor & bias); // {"schema": "aten::_wrapped_linear_prepack(Tensor weight, Tensor weight_scale, Tensor weight_zero_point, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _wrapped_quantized_linear_prepacked(const at::Tensor & input, const at::Tensor & input_scale, const at::Tensor & input_zero_point, const at::Tensor & packed_weight, const at::Tensor & output_scale, const at::Tensor & output_zero_point, int64_t out_channel); // {"schema": "aten::_wrapped_quantized_linear_prepacked(Tensor input, Tensor input_scale, Tensor input_zero_point, Tensor packed_weight, Tensor output_scale, Tensor output_zero_point, int out_channel) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_fp16_weight_fp32_activation(const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias); // {"schema": "aten::fbgemm_linear_fp16_weight_fp32_activation(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_linear_fp16_weight(const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias); // {"schema": "aten::fbgemm_linear_fp16_weight(Tensor input, Tensor packed_weight, Tensor bias) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_pack_quantized_matrix(const at::Tensor & input); // {"schema": "aten::fbgemm_pack_quantized_matrix(Tensor input) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fbgemm_pack_quantized_matrix(const at::Tensor & input, int64_t K, int64_t N); // {"schema": "aten::fbgemm_pack_quantized_matrix.KN(Tensor input, int K, int N) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor ldexp(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ldexp.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & ldexp_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ldexp_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & ldexp_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::ldexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace(Scalar start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor linspace(const at::Tensor & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace.Tensor_Tensor(Tensor start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor linspace(const at::Tensor & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace.Tensor_Scalar(Tensor start, Scalar end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor linspace(const at::Scalar & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::linspace.Scalar_Tensor(Scalar start, Tensor end, int steps, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linspace_out(const at::Scalar & start, const at::Scalar & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.out(Scalar start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & linspace_out(const at::Tensor & start, const at::Tensor & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & linspace_out(const at::Tensor & start, const at::Scalar & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & linspace_out(const at::Scalar & start, const at::Tensor & end, int64_t steps, at::Tensor & out); // {"schema": "aten::linspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log(const at::Tensor & self); // {"schema": "aten::log(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log_(at::Tensor & self); // {"schema": "aten::log_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log10(const at::Tensor & self); // {"schema": "aten::log10(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log10_(at::Tensor & self); // {"schema": "aten::log10_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log10_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log10.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log1p(const at::Tensor & self); // {"schema": "aten::log1p(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log1p_(at::Tensor & self); // {"schema": "aten::log1p_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log1p_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log2(const at::Tensor & self); // {"schema": "aten::log2(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log2_(at::Tensor & self); // {"schema": "aten::log2_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log2_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & logaddexp_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logaddexp.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logaddexp(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logaddexp(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logaddexp2_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::logaddexp2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logaddexp2(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::logaddexp2(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor xlogy(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::xlogy.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor xlogy(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::xlogy.Scalar_Self(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor xlogy(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::xlogy.Scalar_Other(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::xlogy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::xlogy_.Scalar_Other(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::xlogy.OutTensor(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & xlogy_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::xlogy.OutScalar_Self(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & xlogy_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::xlogy.OutScalar_Other(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace(Scalar start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace.Tensor_Tensor(Tensor start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace.Tensor_Scalar(Tensor start, Scalar end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor logspace(const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::logspace.Scalar_Tensor(Scalar start, Tensor end, int steps, float base=10.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logspace_out(const at::Scalar & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.out(Scalar start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & logspace_out(const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.Tensor_Tensor_out(Tensor start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logspace_out(const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.Tensor_Scalar_out(Tensor start, Scalar end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & logspace_out(const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, at::Tensor & out); // {"schema": "aten::logspace.Scalar_Tensor_out(Scalar start, Tensor end, int steps, float base=10.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & log_softmax_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::log_softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log_softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _log_softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _log_softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _log_softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype); // {"schema": "aten::_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _log_softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & out); // {"schema": "aten::_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _logcumsumexp(const at::Tensor & self, int64_t dim); // {"schema": "aten::_logcumsumexp(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _logcumsumexp_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::_logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logcumsumexp(const at::Tensor & self, int64_t dim); // {"schema": "aten::logcumsumexp(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logcumsumexp_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::logcumsumexp.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor logcumsumexp(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::logcumsumexp.dimname(Tensor self, Dimname dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & logcumsumexp_out(const at::Tensor & self, at::Dimname dim, at::Tensor & out); // {"schema": "aten::logcumsumexp.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logsumexp_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor logsumexp(const at::Tensor & self, at::DimnameList dim, bool keepdim); // {"schema": "aten::logsumexp.names(Tensor self, Dimname[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & logsumexp_out(const at::Tensor & self, at::DimnameList dim, bool keepdim, at::Tensor & out); // {"schema": "aten::logsumexp.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor margin_ranking_loss(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction); // {"schema": "aten::margin_ranking_loss(Tensor input1, Tensor input2, Tensor target, float margin=0.0, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor matmul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple matmul_backward(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask); // {"schema": "aten::matmul_backward(Tensor grad, Tensor self, Tensor other, bool[2] mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & matmul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor matrix_power(const at::Tensor & self, int64_t n); // {"schema": "aten::matrix_power(Tensor self, int n) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & matrix_power_out(const at::Tensor & self, int64_t n, at::Tensor & out); // {"schema": "aten::matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor matrix_exp(const at::Tensor & self); // {"schema": "aten::matrix_exp(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor matrix_exp_backward(const at::Tensor & self, const at::Tensor & grad); // {"schema": "aten::matrix_exp_backward(Tensor self, Tensor grad) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _aminmax(const at::Tensor & self); // {"schema": "aten::_aminmax(Tensor self) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _aminmax(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::_aminmax.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple aminmax(const at::Tensor & self, ::std::optional dim, bool keepdim); // {"schema": "aten::aminmax(Tensor self, *, int? dim=None, bool keepdim=False) -> (Tensor min, Tensor max)", "dispatch": "True", "default": "True"} +::std::tuple aminmax_out(const at::Tensor & self, ::std::optional dim, bool keepdim, at::Tensor & min, at::Tensor & max); // {"schema": "aten::aminmax.out(Tensor self, *, int? dim=None, bool keepdim=False, Tensor(a!) min, Tensor(b!) max) -> (Tensor(a!) min, Tensor(b!) max)", "dispatch": "True", "default": "False"} +at::Tensor _compute_linear_combination(const at::Tensor & input, const at::Tensor & coefficients); // {"schema": "aten::_compute_linear_combination(Tensor input, Tensor coefficients) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _compute_linear_combination_out(const at::Tensor & input, const at::Tensor & coefficients, at::Tensor & out); // {"schema": "aten::_compute_linear_combination.out(Tensor input, Tensor coefficients, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple max(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple max_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & max, at::Tensor & max_values); // {"schema": "aten::max.dim_max(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple max(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::max.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple max_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & max, at::Tensor & max_values); // {"schema": "aten::max.names_dim_max(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) max, Tensor(b!) max_values) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor value_selecting_reduction_backward(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim); // {"schema": "aten::value_selecting_reduction_backward(Tensor grad, int dim, Tensor indices, SymInt[] sizes, bool keepdim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor amax(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::amax(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & amax_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::amax.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple max_pool1d_with_indices(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool1d_with_indices(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor max_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool2d_backward(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::mkldnn_max_pool3d_backward(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantized_max_pool1d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool1d(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantized_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantized_max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::quantized_max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor mean(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::mean(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mean_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::mean.dtype_out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor mean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::mean.dim(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::mean.out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mean(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::mean.names_dim(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & mean_out(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::mean.names_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nanmean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::nanmean(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nanmean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::nanmean.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor median(const at::Tensor & self); // {"schema": "aten::median(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple median(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::median.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple median_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::median.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple median(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::median.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple median_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::median.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor nanmedian(const at::Tensor & self); // {"schema": "aten::nanmedian(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple nanmedian(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::nanmedian.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple nanmedian_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::nanmedian.dim_values(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple nanmedian(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::nanmedian.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple nanmedian_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::nanmedian.names_dim_values(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +::std::tuple min(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::min.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple min_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices); // {"schema": "aten::min.dim_min(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple min(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::min.names_dim(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple min_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & min, at::Tensor & min_indices); // {"schema": "aten::min.names_dim_min(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) min, Tensor(b!) min_indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor amin(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::amin(Tensor self, int[1] dim=[], bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & amin_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::amin.out(Tensor self, int[1] dim=[], bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _mps_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::_mps_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mps_convolution_backward(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask); // {"schema": "aten::mps_convolution_backward(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::mkldnn_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_rnn_layer(const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train); // {"schema": "aten::mkldnn_rnn_layer(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple mkldnn_rnn_layer_backward(const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace); // {"schema": "aten::mkldnn_rnn_layer_backward(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple miopen_batch_norm(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon); // {"schema": "aten::miopen_batch_norm(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple miopen_batch_norm_backward(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon); // {"schema": "aten::miopen_batch_norm_backward(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution_transpose(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_convolution_transpose(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_depthwise_convolution(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic); // {"schema": "aten::miopen_depthwise_convolution(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution_relu(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::miopen_convolution_relu(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor miopen_convolution_add_relu(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups); // {"schema": "aten::miopen_convolution_add_relu(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple miopen_rnn(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state); // {"schema": "aten::miopen_rnn(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple> miopen_rnn_backward(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask); // {"schema": "aten::miopen_rnn_backward(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask) -> (Tensor, Tensor, Tensor, Tensor[])", "dispatch": "True", "default": "False"} +at::Tensor mm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::mm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mm_out(const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mm(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype); // {"schema": "aten::mm.dtype(Tensor self, Tensor mat2, ScalarType out_dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & mm_out(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype, at::Tensor & out); // {"schema": "aten::mm.dtype_out(Tensor self, Tensor mat2, ScalarType out_dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _int_mm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::_int_mm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _int_mm_out(const at::Tensor & self, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::_int_mm.out(Tensor self, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _convert_weight_to_int4pack(const at::Tensor & self, int64_t innerKTiles); // {"schema": "aten::_convert_weight_to_int4pack(Tensor self, int innerKTiles) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int4pack_mm(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros); // {"schema": "aten::_weight_int4pack_mm(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int4pack_mm_with_scales_and_zeros(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScale, const at::Tensor & qZeros); // {"schema": "aten::_weight_int4pack_mm_with_scales_and_zeros(Tensor self, Tensor mat2, int qGroupSize, Tensor qScale, Tensor qZeros) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _convert_weight_to_int4pack_for_cpu(const at::Tensor & self, int64_t innerKTiles); // {"schema": "aten::_convert_weight_to_int4pack_for_cpu(Tensor self, int innerKTiles) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int4pack_mm_for_cpu(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros); // {"schema": "aten::_weight_int4pack_mm_for_cpu(Tensor self, Tensor mat2, int qGroupSize, Tensor qScaleAndZeros) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _dyn_quant_pack_4bit_weight(const at::Tensor & weights, const at::Tensor & scales_zeros, const ::std::optional & bias, int64_t block_size, int64_t in_features, int64_t out_features); // {"schema": "aten::_dyn_quant_pack_4bit_weight(Tensor weights, Tensor scales_zeros, Tensor? bias, int block_size, int in_features, int out_features) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _dyn_quant_matmul_4bit(const at::Tensor & inp, const at::Tensor & packed_weights, int64_t block_size, int64_t in_features, int64_t out_features); // {"schema": "aten::_dyn_quant_matmul_4bit(Tensor inp, Tensor packed_weights, int block_size, int in_features, int out_features) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _weight_int8pack_mm(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scales); // {"schema": "aten::_weight_int8pack_mm(Tensor self, Tensor mat2, Tensor scales) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_mm(const at::Tensor & sparse, const at::Tensor & dense); // {"schema": "aten::_sparse_mm(Tensor sparse, Tensor dense) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_mm(const at::Tensor & sparse, const at::Tensor & dense, c10::string_view reduce); // {"schema": "aten::_sparse_mm.reduce(Tensor sparse, Tensor dense, str reduce) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sparse_matmul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::_sparse_sparse_matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple mode(const at::Tensor & self, int64_t dim, bool keepdim); // {"schema": "aten::mode(Tensor self, int dim=-1, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "False"} +::std::tuple mode_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::mode.values(Tensor self, int dim=-1, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"} +::std::tuple mode(const at::Tensor & self, at::Dimname dim, bool keepdim); // {"schema": "aten::mode.dimname(Tensor self, Dimname dim, bool keepdim=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple mode_out(const at::Tensor & self, at::Dimname dim, bool keepdim, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::mode.dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +at::Tensor mul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mul_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::mul_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mul(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::mul.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mul_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::mul_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor multiply(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::multiply.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & multiply_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::multiply_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & multiply_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::multiply.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor multiply(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::multiply.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & multiply_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::multiply_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor mv(const at::Tensor & self, const at::Tensor & vec); // {"schema": "aten::mv(Tensor self, Tensor vec) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mv_out(const at::Tensor & self, const at::Tensor & vec, at::Tensor & out); // {"schema": "aten::mv.out(Tensor self, Tensor vec, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mvlgamma_out(const at::Tensor & self, int64_t p, at::Tensor & out); // {"schema": "aten::mvlgamma.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mvlgamma(const at::Tensor & self, int64_t p); // {"schema": "aten::mvlgamma(Tensor self, int p) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mvlgamma_(at::Tensor & self, int64_t p); // {"schema": "aten::mvlgamma_(Tensor(a!) self, int p) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor narrow_copy(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length); // {"schema": "aten::narrow_copy(Tensor self, int dim, SymInt start, SymInt length) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & narrow_copy_out(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length, at::Tensor & out); // {"schema": "aten::narrow_copy.out(Tensor self, int dim, SymInt start, SymInt length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor narrow(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length); // {"schema": "aten::narrow(Tensor(a) self, int dim, SymInt start, SymInt length) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor narrow(const at::Tensor & self, int64_t dim, const at::Tensor & start, c10::SymInt length); // {"schema": "aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, SymInt length) -> Tensor(a)", "dispatch": "False", "default": "True"} +::std::tuple native_batch_norm(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps); // {"schema": "aten::native_batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple native_batch_norm_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd); // {"schema": "aten::native_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit_no_training(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit_no_training(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple _native_batch_norm_legit_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd); // {"schema": "aten::_native_batch_norm_legit.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, bool training, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd) -> (Tensor(d!), Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit.no_stats(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _native_batch_norm_legit_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd); // {"schema": "aten::_native_batch_norm_legit.no_stats_out(Tensor input, Tensor? weight, Tensor? bias, bool training, float momentum, float eps, *, Tensor(a!) out, Tensor(b!) save_mean, Tensor(c!) save_invstd) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_stats(const at::Tensor & input, double eps); // {"schema": "aten::batch_norm_stats(Tensor input, float eps) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor batch_norm_elemt(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps); // {"schema": "aten::batch_norm_elemt(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & batch_norm_elemt_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps, at::Tensor & out); // {"schema": "aten::batch_norm_elemt.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor invstd, float eps, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_gather_stats(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count); // {"schema": "aten::batch_norm_gather_stats(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_gather_stats_with_counts(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts); // {"schema": "aten::batch_norm_gather_stats_with_counts(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple native_batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask); // {"schema": "aten::native_batch_norm_backward(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_backward_reduce(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g); // {"schema": "aten::batch_norm_backward_reduce(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor batch_norm_backward_elemt(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count); // {"schema": "aten::batch_norm_backward_elemt(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple batch_norm_update_stats(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum); // {"schema": "aten::batch_norm_update_stats(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +bool is_vulkan_available(); // {"schema": "aten::is_vulkan_available() -> bool", "dispatch": "False", "default": "True"} +bool _nnpack_available(); // {"schema": "aten::_nnpack_available() -> bool", "dispatch": "False", "default": "True"} +at::Tensor _nnpack_spatial_convolution(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride); // {"schema": "aten::_nnpack_spatial_convolution(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ones(at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::ones.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ones(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::ones(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ones_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::ones.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor ones_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::ones_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor pairwise_distance(const at::Tensor & x1, const at::Tensor & x2, double p, double eps, bool keepdim); // {"schema": "aten::pairwise_distance(Tensor x1, Tensor x2, float p=2, float eps=1e-06, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor cdist(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode); // {"schema": "aten::cdist(Tensor x1, Tensor x2, float p=2, int? compute_mode=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _euclidean_dist(const at::Tensor & x1, const at::Tensor & x2); // {"schema": "aten::_euclidean_dist(Tensor x1, Tensor x2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _cdist_forward(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode); // {"schema": "aten::_cdist_forward(Tensor x1, Tensor x2, float p, int? compute_mode) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _cdist_backward(const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist); // {"schema": "aten::_cdist_backward(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor pdist(const at::Tensor & self, double p); // {"schema": "aten::pdist(Tensor self, float p=2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _pdist_forward(const at::Tensor & self, double p); // {"schema": "aten::_pdist_forward(Tensor self, float p=2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _pdist_backward(const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist); // {"schema": "aten::_pdist_backward(Tensor grad, Tensor self, float p, Tensor pdist) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cosine_similarity(const at::Tensor & x1, const at::Tensor & x2, int64_t dim, double eps); // {"schema": "aten::cosine_similarity(Tensor x1, Tensor x2, int dim=1, float eps=1e-08) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor permute(const at::Tensor & self, at::IntArrayRef dims); // {"schema": "aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor movedim(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination); // {"schema": "aten::movedim.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor movedim(const at::Tensor & self, int64_t source, int64_t destination); // {"schema": "aten::movedim.int(Tensor(a) self, int source, int destination) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor moveaxis(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination); // {"schema": "aten::moveaxis.intlist(Tensor(a) self, int[] source, int[] destination) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor moveaxis(const at::Tensor & self, int64_t source, int64_t destination); // {"schema": "aten::moveaxis.int(Tensor(a) self, int source, int destination) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor numpy_T(const at::Tensor & self); // {"schema": "aten::numpy_T(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor matrix_H(const at::Tensor & self); // {"schema": "aten::matrix_H(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor mT(const at::Tensor & self); // {"schema": "aten::mT(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor mH(const at::Tensor & self); // {"schema": "aten::mH(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor adjoint(const at::Tensor & self); // {"schema": "aten::adjoint(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor pixel_shuffle(const at::Tensor & self, int64_t upscale_factor); // {"schema": "aten::pixel_shuffle(Tensor self, int upscale_factor) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor pixel_unshuffle(const at::Tensor & self, int64_t downscale_factor); // {"schema": "aten::pixel_unshuffle(Tensor self, int downscale_factor) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor channel_shuffle(const at::Tensor & self, c10::SymInt groups); // {"schema": "aten::channel_shuffle(Tensor self, SymInt groups) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor native_channel_shuffle(const at::Tensor & self, c10::SymInt groups); // {"schema": "aten::native_channel_shuffle(Tensor self, SymInt groups) -> Tensor", "dispatch": "True", "default": "True"} +bool is_pinned(const at::Tensor & self, ::std::optional device); // {"schema": "aten::is_pinned(Tensor self, Device? device=None) -> bool", "dispatch": "True", "default": "True"} +at::Tensor pin_memory(const at::Tensor & self, ::std::optional device); // {"schema": "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _pin_memory(const at::Tensor & self, ::std::optional device); // {"schema": "aten::_pin_memory(Tensor self, Device? device=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor pinverse(const at::Tensor & self, double rcond); // {"schema": "aten::pinverse(Tensor self, float rcond=1e-15) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor poisson_nll_loss(const at::Tensor & input, const at::Tensor & target, bool log_input, bool full, double eps, int64_t reduction); // {"schema": "aten::poisson_nll_loss(Tensor input, Tensor target, bool log_input, bool full, float eps, int reduction) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor rad2deg(const at::Tensor & self); // {"schema": "aten::rad2deg(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rad2deg_(at::Tensor & self); // {"schema": "aten::rad2deg_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rad2deg_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::rad2deg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor deg2rad(const at::Tensor & self); // {"schema": "aten::deg2rad(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & deg2rad_(at::Tensor & self); // {"schema": "aten::deg2rad_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & deg2rad_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::deg2rad.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor scalar_tensor(const at::Scalar & s, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::scalar_tensor(Scalar s, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rand(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::rand.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::rand.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::rand.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor rand_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::rand_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint(SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint.generator(SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint.low(SymInt low, SymInt high, SymInt[] size, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randint.low_generator(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::randint.out(SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randint.generator_out(SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::randint.low_out(SymInt low, SymInt high, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_out(c10::SymInt low, c10::SymInt high, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randint.low_generator_out(SymInt low, SymInt high, SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like(Tensor self, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, const at::Tensor & high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like.Tensor(Tensor self, Tensor high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randint_like(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randint_like.low_dtype(Tensor self, SymInt low, SymInt high, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn.generator(SymInt[] size, *, Generator? generator, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn.names(SymInt[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randn(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randn.generator_with_names(SymInt[] size, *, Generator? generator, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::randn.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randn.generator_out(SymInt[] size, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor randn_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::randn_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randperm(c10::SymInt n, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randperm(SymInt n, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor randperm(c10::SymInt n, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::randperm.generator(SymInt n, *, Generator? generator, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & randperm_out(c10::SymInt n, at::Tensor & out); // {"schema": "aten::randperm.out(SymInt n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randperm_out(c10::SymInt n, ::std::optional generator, at::Tensor & out); // {"schema": "aten::randperm.generator_out(SymInt n, *, Generator? generator, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor range(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::range.step(Scalar start, Scalar end, Scalar step=1, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor range(const at::Scalar & start, const at::Scalar & end, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::range(Scalar start, Scalar end, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & range_out(const at::Scalar & start, const at::Scalar & end, at::Tensor & out); // {"schema": "aten::range.out_(Scalar start, Scalar end, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & range_out(const at::Scalar & start, const at::Scalar & end, const at::Scalar & step, at::Tensor & out); // {"schema": "aten::range.out(Scalar start, Scalar end, Scalar step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ravel(const at::Tensor & self); // {"schema": "aten::ravel(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor reciprocal(const at::Tensor & self); // {"schema": "aten::reciprocal(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reciprocal_(at::Tensor & self); // {"schema": "aten::reciprocal_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & reciprocal_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::reciprocal.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor neg(const at::Tensor & self); // {"schema": "aten::neg(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & neg_(at::Tensor & self); // {"schema": "aten::neg_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & neg_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::neg.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor negative(const at::Tensor & self); // {"schema": "aten::negative(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & negative_(at::Tensor & self); // {"schema": "aten::negative_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & negative_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::negative.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor repeat(const at::Tensor & self, c10::SymIntArrayRef repeats); // {"schema": "aten::repeat(Tensor self, SymInt[] repeats) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor repeat_interleave(const at::Tensor & repeats, ::std::optional output_size); // {"schema": "aten::repeat_interleave.Tensor(Tensor repeats, *, SymInt? output_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor repeat_interleave(const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size); // {"schema": "aten::repeat_interleave.self_Tensor(Tensor self, Tensor repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor repeat_interleave(const at::Tensor & self, c10::SymInt repeats, ::std::optional dim, ::std::optional output_size); // {"schema": "aten::repeat_interleave.self_int(Tensor self, SymInt repeats, int? dim=None, *, SymInt? output_size=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor reshape(const at::Tensor & self, c10::SymIntArrayRef shape); // {"schema": "aten::reshape(Tensor(a) self, SymInt[] shape) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _reshape_copy(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::_reshape_copy(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _reshape_alias(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::_reshape_alias(Tensor(a) self, SymInt[] size, SymInt[] stride) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _mkldnn_reshape(const at::Tensor & self, at::IntArrayRef shape); // {"schema": "aten::_mkldnn_reshape(Tensor self, int[] shape) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor reshape_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::reshape_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor round(const at::Tensor & self); // {"schema": "aten::round(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & round_(at::Tensor & self); // {"schema": "aten::round_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & round_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::round.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor round(const at::Tensor & self, int64_t decimals); // {"schema": "aten::round.decimals(Tensor self, *, int decimals) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & round_(at::Tensor & self, int64_t decimals); // {"schema": "aten::round_.decimals(Tensor(a!) self, *, int decimals) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & round_out(const at::Tensor & self, int64_t decimals, at::Tensor & out); // {"schema": "aten::round.decimals_out(Tensor self, *, int decimals, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor rrelu(const at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu(Tensor self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & rrelu_(at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_(Tensor(a!) self, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor relu(const at::Tensor & self); // {"schema": "aten::relu(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & relu_(at::Tensor & self); // {"schema": "aten::relu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor relu6(const at::Tensor & self); // {"schema": "aten::relu6(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & relu6_(at::Tensor & self); // {"schema": "aten::relu6_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor prelu(const at::Tensor & self, const at::Tensor & weight); // {"schema": "aten::prelu(Tensor self, Tensor weight) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _prelu_kernel(const at::Tensor & self, const at::Tensor & weight); // {"schema": "aten::_prelu_kernel(Tensor self, Tensor weight) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _prelu_kernel_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight); // {"schema": "aten::_prelu_kernel_backward(Tensor grad_output, Tensor self, Tensor weight) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & gelu_out(const at::Tensor & self, c10::string_view approximate, at::Tensor & out); // {"schema": "aten::gelu.out(Tensor self, *, str approximate='none', Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & gelu_(at::Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu_(Tensor(a!) self, *, str approximate='none') -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor gelu(const at::Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu(Tensor self, *, str approximate='none') -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gelu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate, at::Tensor & grad_input); // {"schema": "aten::gelu_backward.grad_input(Tensor grad_output, Tensor self, *, str approximate='none', Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gelu_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate); // {"schema": "aten::gelu_backward(Tensor grad_output, Tensor self, *, str approximate='none') -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor infinitely_differentiable_gelu_backward(const at::Tensor & grad, const at::Tensor & self); // {"schema": "aten::infinitely_differentiable_gelu_backward(Tensor grad, Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & hardshrink_out(const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out); // {"schema": "aten::hardshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardshrink(const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::hardshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hardshrink_backward_out(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input); // {"schema": "aten::hardshrink_backward.grad_input(Tensor grad_out, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardshrink_backward(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::hardshrink_backward(Tensor grad_out, Tensor self, Scalar lambd) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor rsqrt(const at::Tensor & self); // {"schema": "aten::rsqrt(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rsqrt_(at::Tensor & self); // {"schema": "aten::rsqrt_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rsqrt_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::rsqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor select(const at::Tensor & self, at::Dimname dim, int64_t index); // {"schema": "aten::select.Dimname(Tensor(a) self, Dimname dim, int index) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor select(const at::Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::select.int(Tensor(a) self, int dim, SymInt index) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor select_backward(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index); // {"schema": "aten::select_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_select_backward(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::_nested_select_backward(Tensor grad_output, Tensor self, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor selu(const at::Tensor & self); // {"schema": "aten::selu(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & selu_(at::Tensor & self); // {"schema": "aten::selu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor celu(const at::Tensor & self, const at::Scalar & alpha); // {"schema": "aten::celu(Tensor self, Scalar alpha=1.0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & celu_(at::Tensor & self, const at::Scalar & alpha); // {"schema": "aten::celu_(Tensor(a!) self, Scalar alpha=1.0) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor silu(const at::Tensor & self); // {"schema": "aten::silu(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & silu_(at::Tensor & self); // {"schema": "aten::silu_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & silu_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::silu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & silu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); // {"schema": "aten::silu_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor silu_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::silu_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor mish(const at::Tensor & self); // {"schema": "aten::mish(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mish_(at::Tensor & self); // {"schema": "aten::mish_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mish_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::mish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mish_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::mish_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sigmoid(const at::Tensor & self); // {"schema": "aten::sigmoid(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sigmoid_(at::Tensor & self); // {"schema": "aten::sigmoid_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sigmoid_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logit(const at::Tensor & self, ::std::optional eps); // {"schema": "aten::logit(Tensor self, float? eps=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & logit_(at::Tensor & self, ::std::optional eps); // {"schema": "aten::logit_(Tensor(a!) self, float? eps=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & logit_out(const at::Tensor & self, ::std::optional eps, at::Tensor & out); // {"schema": "aten::logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sin(const at::Tensor & self); // {"schema": "aten::sin(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sin_(at::Tensor & self); // {"schema": "aten::sin_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sin_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sin.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sinc(const at::Tensor & self); // {"schema": "aten::sinc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sinc_(at::Tensor & self); // {"schema": "aten::sinc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sinc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sinh(const at::Tensor & self); // {"schema": "aten::sinh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sinh_(at::Tensor & self); // {"schema": "aten::sinh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sinh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sinh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor detach(const at::Tensor & self); // {"schema": "aten::detach(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & detach_(at::Tensor & self); // {"schema": "aten::detach_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +int64_t size(const at::Tensor & self, int64_t dim); // {"schema": "aten::size.int(Tensor self, int dim) -> int", "dispatch": "False", "default": "True"} +int64_t size(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::size.Dimname(Tensor self, Dimname dim) -> int", "dispatch": "False", "default": "True"} +c10::SymInt sym_size(const at::Tensor & self, int64_t dim); // {"schema": "aten::sym_size.int(Tensor self, int dim) -> SymInt", "dispatch": "False", "default": "True"} +c10::SymInt sym_numel(const at::Tensor & self); // {"schema": "aten::sym_numel(Tensor self) -> SymInt", "dispatch": "False", "default": "True"} +c10::SymInt sym_storage_offset(const at::Tensor & self); // {"schema": "aten::sym_storage_offset(Tensor self) -> SymInt", "dispatch": "False", "default": "True"} +at::Tensor slice(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice.Tensor(Tensor(a) self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor slice_backward(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step); // {"schema": "aten::slice_backward(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor slice_inverse(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice_inverse(Tensor(a) self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor slice_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice_scatter(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor select_scatter(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index); // {"schema": "aten::select_scatter(Tensor self, Tensor src, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor diagonal_scatter(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_scatter(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor as_strided_scatter(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided_scatter(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor smm(const at::Tensor & self, const at::Tensor & mat2); // {"schema": "aten::smm(Tensor self, Tensor mat2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & softmax_out(const at::Tensor & self, int64_t dim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::softmax.int_out(Tensor self, int dim, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype); // {"schema": "aten::_softmax_backward_data(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype, at::Tensor & grad_input); // {"schema": "aten::_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, ScalarType input_dtype, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector unsafe_split(const at::Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::unsafe_split.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector split(const at::Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::split.Tensor(Tensor(a -> *) self, SymInt split_size, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector split(const at::Tensor & self, c10::SymIntArrayRef split_size, int64_t dim); // {"schema": "aten::split.sizes(Tensor(a -> *) self, SymInt[] split_size, int dim=0) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector unsafe_split_with_sizes(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::unsafe_split_with_sizes(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector split_with_sizes(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::split_with_sizes(Tensor(a -> *) self, SymInt[] split_sizes, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector hsplit(const at::Tensor & self, int64_t sections); // {"schema": "aten::hsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector hsplit(const at::Tensor & self, at::IntArrayRef indices); // {"schema": "aten::hsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector vsplit(const at::Tensor & self, int64_t sections); // {"schema": "aten::vsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector vsplit(const at::Tensor & self, at::IntArrayRef indices); // {"schema": "aten::vsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector dsplit(const at::Tensor & self, int64_t sections); // {"schema": "aten::dsplit.int(Tensor(a -> *) self, int sections) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +::std::vector dsplit(const at::Tensor & self, at::IntArrayRef indices); // {"schema": "aten::dsplit.array(Tensor(a -> *) self, int[] indices) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +at::Tensor squeeze(const at::Tensor & self); // {"schema": "aten::squeeze(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor squeeze(const at::Tensor & self, int64_t dim); // {"schema": "aten::squeeze.dim(Tensor(a) self, int dim) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor squeeze(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::squeeze.dimname(Tensor(a) self, Dimname dim) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor squeeze(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::squeeze.dims(Tensor(a) self, int[] dim) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self); // {"schema": "aten::squeeze_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self, int64_t dim); // {"schema": "aten::squeeze_.dim(Tensor(a!) self, int dim) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::squeeze_.dims(Tensor(a!) self, int[] dim) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_(at::Tensor & self, at::Dimname dim); // {"schema": "aten::squeeze_.dimname(Tensor(a!) self, Dimname dim) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor sspaddmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::sspaddmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & sspaddmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sspaddmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _chunk_cat(at::TensorList tensors, int64_t dim, int64_t num_chunks); // {"schema": "aten::_chunk_cat(Tensor[] tensors, int dim, int num_chunks) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _chunk_cat_out(at::TensorList tensors, int64_t dim, int64_t num_chunks, at::Tensor & out); // {"schema": "aten::_chunk_cat.out(Tensor[] tensors, int dim, int num_chunks, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor stack(at::TensorList tensors, int64_t dim); // {"schema": "aten::stack(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & stack_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _stack(at::TensorList tensors, int64_t dim); // {"schema": "aten::_stack(Tensor[] tensors, int dim=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _stack_out(at::TensorList tensors, int64_t dim, at::Tensor & out); // {"schema": "aten::_stack.out(Tensor[] tensors, int dim=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor hstack(at::TensorList tensors); // {"schema": "aten::hstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & hstack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::hstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor vstack(at::TensorList tensors); // {"schema": "aten::vstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & vstack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::vstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor dstack(at::TensorList tensors); // {"schema": "aten::dstack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & dstack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::dstack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor stft(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window); // {"schema": "aten::stft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor stft(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, c10::string_view pad_mode, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window); // {"schema": "aten::stft.center(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, str pad_mode=\"reflect\", bool normalized=False, bool? onesided=None, bool? return_complex=None, bool? align_to_window=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor istft(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, bool normalized, ::std::optional onesided, ::std::optional length, bool return_complex); // {"schema": "aten::istft(Tensor self, int n_fft, int? hop_length=None, int? win_length=None, Tensor? window=None, bool center=True, bool normalized=False, bool? onesided=None, int? length=None, bool return_complex=False) -> Tensor", "dispatch": "False", "default": "True"} +int64_t stride(const at::Tensor & self, int64_t dim); // {"schema": "aten::stride.int(Tensor self, int dim) -> int", "dispatch": "False", "default": "True"} +int64_t stride(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::stride.Dimname(Tensor self, Dimname dim) -> int", "dispatch": "False", "default": "True"} +c10::SymInt sym_stride(const at::Tensor & self, int64_t dim); // {"schema": "aten::sym_stride.int(Tensor self, int dim) -> SymInt", "dispatch": "False", "default": "True"} +at::Tensor sum(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sum(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::sum.dim_IntList(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sum(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::sum.dim_DimnameList(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & sum_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::sum.IntList_out(Tensor self, int[1]? dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & sum_out(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::sum.DimnameList_out(Tensor self, Dimname[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor _nested_sum_backward(const at::Tensor & grad, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim); // {"schema": "aten::_nested_sum_backward(Tensor grad, Tensor self, int[1]? dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor nansum(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::nansum(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & nansum_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::nansum.out(Tensor self, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sum_to_size(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::sum_to_size(Tensor self, SymInt[] size) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sqrt(const at::Tensor & self); // {"schema": "aten::sqrt(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sqrt_(at::Tensor & self); // {"schema": "aten::sqrt_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sqrt_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sqrt.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor square(const at::Tensor & self); // {"schema": "aten::square(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & square_(at::Tensor & self); // {"schema": "aten::square_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & square_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::square.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, bool unbiased); // {"schema": "aten::std(Tensor self, bool unbiased=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::std.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple std_mean(const at::Tensor & self, bool unbiased); // {"schema": "aten::std_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple std_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::std_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple std_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple std_mean(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::std_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple std_mean(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::std.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::std.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor std(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::std.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::std.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor std(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::std.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & std_out(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::std.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor prod(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::prod(Tensor self, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor prod(const at::Tensor & self, int64_t dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::prod.dim_int(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & prod_out(const at::Tensor & self, int64_t dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::prod.int_out(Tensor self, int dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor prod(const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::prod.dim_Dimname(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & prod_out(const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::prod.Dimname_out(Tensor self, Dimname dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor t(const at::Tensor & self); // {"schema": "aten::t(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & t_(at::Tensor & self); // {"schema": "aten::t_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor tan(const at::Tensor & self); // {"schema": "aten::tan(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tan_(at::Tensor & self); // {"schema": "aten::tan_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & tan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::tan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tanh(const at::Tensor & self); // {"schema": "aten::tanh(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tanh_(at::Tensor & self); // {"schema": "aten::tanh_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & tanh_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::tanh.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tensordot(const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other); // {"schema": "aten::tensordot(Tensor self, Tensor other, int[] dims_self, int[] dims_other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & tensordot_out(const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other, at::Tensor & out); // {"schema": "aten::tensordot.out(Tensor self, Tensor other, int[] dims_self, int[] dims_other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor threshold(const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value); // {"schema": "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & threshold_(at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value); // {"schema": "aten::threshold_(Tensor(a!) self, Scalar threshold, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & threshold_out(const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::threshold.out(Tensor self, Scalar threshold, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & threshold_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold, at::Tensor & grad_input); // {"schema": "aten::threshold_backward.grad_input(Tensor grad_output, Tensor self, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor threshold_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold); // {"schema": "aten::threshold_backward(Tensor grad_output, Tensor self, Scalar threshold) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor tile(const at::Tensor & self, c10::SymIntArrayRef dims); // {"schema": "aten::tile(Tensor self, SymInt[] dims) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor transpose(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose.int(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor transpose(const at::Tensor & self, at::Dimname dim0, at::Dimname dim1); // {"schema": "aten::transpose.Dimname(Tensor(a) self, Dimname dim0, Dimname dim1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _mkldnn_transpose(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::_mkldnn_transpose(Tensor self, int dim0, int dim1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & transpose_(at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mkldnn_transpose_(at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::_mkldnn_transpose_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor one_hot(const at::Tensor & self, int64_t num_classes); // {"schema": "aten::one_hot(Tensor self, int num_classes=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor flip(const at::Tensor & self, at::IntArrayRef dims); // {"schema": "aten::flip(Tensor self, int[] dims) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor fliplr(const at::Tensor & self); // {"schema": "aten::fliplr(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor flipud(const at::Tensor & self); // {"schema": "aten::flipud(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor roll(const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims); // {"schema": "aten::roll(Tensor self, SymInt[1] shifts, int[1] dims=[]) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor rot90(const at::Tensor & self, int64_t k, at::IntArrayRef dims); // {"schema": "aten::rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor trapezoid(const at::Tensor & y, const at::Tensor & x, int64_t dim); // {"schema": "aten::trapezoid.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trapezoid(const at::Tensor & y, const at::Scalar & dx, int64_t dim); // {"schema": "aten::trapezoid.dx(Tensor y, *, Scalar dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trapz(const at::Tensor & y, const at::Tensor & x, int64_t dim); // {"schema": "aten::trapz.x(Tensor y, Tensor x, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trapz(const at::Tensor & y, double dx, int64_t dim); // {"schema": "aten::trapz.dx(Tensor y, *, float dx=1, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _transform_bias_rescale_qkv(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads); // {"schema": "aten::_transform_bias_rescale_qkv(Tensor qkv, Tensor qkv_bias, int num_heads) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_from_mask(const at::Tensor & t, const at::Tensor & mask, bool mask_check); // {"schema": "aten::_nested_tensor_from_mask(Tensor t, Tensor mask, bool mask_check=True) -> Tensor", "dispatch": "True", "default": "False"} +bool _nested_tensor_from_mask_left_aligned(const at::Tensor & t, const at::Tensor & mask); // {"schema": "aten::_nested_tensor_from_mask_left_aligned(Tensor t, Tensor mask) -> bool", "dispatch": "True", "default": "False"} +at::Tensor _nested_from_padded(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213); // {"schema": "aten::_nested_from_padded(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_size(const at::Tensor & self); // {"schema": "aten::_nested_tensor_size(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_strides(const at::Tensor & self); // {"schema": "aten::_nested_tensor_strides(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_storage_offsets(const at::Tensor & self); // {"schema": "aten::_nested_tensor_storage_offsets(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_from_padded_and_nested_example(const at::Tensor & padded, const at::Tensor & nt_example); // {"schema": "aten::_nested_from_padded_and_nested_example(Tensor padded, Tensor nt_example) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_view_from_buffer(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets); // {"schema": "aten::_nested_view_from_buffer(Tensor(a) self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _nested_view_from_buffer_copy(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets); // {"schema": "aten::_nested_view_from_buffer_copy(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_view_from_jagged(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen); // {"schema": "aten::_nested_view_from_jagged(Tensor(a) self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _nested_view_from_jagged_copy(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen); // {"schema": "aten::_nested_view_from_jagged_copy(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_get_values(const at::Tensor & self); // {"schema": "aten::_nested_get_values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_values_copy(const at::Tensor & self); // {"schema": "aten::_nested_get_values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _nested_get_offsets(const at::Tensor & self); // {"schema": "aten::_nested_get_offsets(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_lengths(const at::Tensor & self); // {"schema": "aten::_nested_get_lengths(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +int64_t _nested_get_ragged_idx(const at::Tensor & self); // {"schema": "aten::_nested_get_ragged_idx(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_min_seqlen(const at::Tensor & self); // {"schema": "aten::_nested_get_min_seqlen(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_max_seqlen(const at::Tensor & self); // {"schema": "aten::_nested_get_max_seqlen(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_get_jagged_dummy(const at::Tensor & any); // {"schema": "aten::_nested_get_jagged_dummy(Tensor any) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _nested_compute_contiguous_strides_offsets(const at::Tensor & nested_size); // {"schema": "aten::_nested_compute_contiguous_strides_offsets(Tensor nested_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _trilinear(const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim); // {"schema": "aten::_trilinear(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor triplet_margin_loss(const at::Tensor & anchor, const at::Tensor & positive, const at::Tensor & negative, double margin, double p, double eps, bool swap, int64_t reduction); // {"schema": "aten::triplet_margin_loss(Tensor anchor, Tensor positive, Tensor negative, float margin=1.0, float p=2, float eps=1e-06, bool swap=False, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor trunc(const at::Tensor & self); // {"schema": "aten::trunc(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & trunc_(at::Tensor & self); // {"schema": "aten::trunc_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & trunc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::trunc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fix(const at::Tensor & self); // {"schema": "aten::fix(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fix_(at::Tensor & self); // {"schema": "aten::fix_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & fix_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::fix.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor type_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::type_as(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +bool _has_compatible_shallow_copy_type(const at::Tensor & self, const at::Tensor & from); // {"schema": "aten::_has_compatible_shallow_copy_type(Tensor self, Tensor from) -> bool", "dispatch": "False", "default": "True"} +::std::tuple _unique(const at::Tensor & self, bool sorted, bool return_inverse); // {"schema": "aten::_unique(Tensor self, bool sorted=True, bool return_inverse=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple unique_dim(const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts); // {"schema": "aten::unique_dim(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple unique_consecutive(const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim); // {"schema": "aten::unique_consecutive(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple unique_dim_consecutive(const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts); // {"schema": "aten::unique_dim_consecutive(Tensor self, int dim, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _unique2(const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts); // {"schema": "aten::_unique2(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _unsafe_view(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::_unsafe_view(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor unsqueeze(const at::Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze(Tensor(a) self, int dim) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & unsqueeze_(at::Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze_(Tensor(a!) self, int dim) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor vander(const at::Tensor & x, ::std::optional N, bool increasing); // {"schema": "aten::vander(Tensor x, int? N=None, bool increasing=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, bool unbiased); // {"schema": "aten::var(Tensor self, bool unbiased=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::var.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & var_out(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::var.out(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & var_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::var.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor var(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::var.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & var_out(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim, at::Tensor & out); // {"schema": "aten::var.names_out(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor var(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & var_out(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim, at::Tensor & out); // {"schema": "aten::var.correction_names_out(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, bool unbiased); // {"schema": "aten::var_mean(Tensor self, bool unbiased=True) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim); // {"schema": "aten::var_mean.dim(Tensor self, int[1]? dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var_mean.correction(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple var_mean(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim); // {"schema": "aten::var_mean.names_dim(Tensor self, Dimname[1] dim, bool unbiased=True, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple var_mean(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim); // {"schema": "aten::var_mean.correction_names(Tensor self, Dimname[1] dim, *, Scalar? correction=None, bool keepdim=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor view_as(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::view_as(Tensor(a) self, Tensor other) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor where(const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::where.self(Tensor condition, Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & where_out(const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::where.self_out(Tensor condition, Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor where(const at::Tensor & condition, const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::where.ScalarSelf(Tensor condition, Scalar self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor where(const at::Tensor & condition, const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::where.ScalarOther(Tensor condition, Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor where(const at::Tensor & condition, const at::Scalar & self, const at::Scalar & other); // {"schema": "aten::where.Scalar(Tensor condition, Scalar self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector where(const at::Tensor & condition); // {"schema": "aten::where(Tensor condition) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor norm_except_dim(const at::Tensor & v, int64_t pow, int64_t dim); // {"schema": "aten::norm_except_dim(Tensor v, int pow=2, int dim=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _weight_norm(const at::Tensor & v, const at::Tensor & g, int64_t dim); // {"schema": "aten::_weight_norm(Tensor v, Tensor g, int dim=0) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _weight_norm_interface(const at::Tensor & v, const at::Tensor & g, int64_t dim); // {"schema": "aten::_weight_norm_interface(Tensor v, Tensor g, int dim=0) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _weight_norm_interface_backward(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim); // {"schema": "aten::_weight_norm_interface_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _weight_norm_differentiable_backward(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim); // {"schema": "aten::_weight_norm_differentiable_backward(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor zeros(at::IntArrayRef size, ::std::optional names, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::zeros.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _efficientzerotensor(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_efficientzerotensor(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor zeros(c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::zeros(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & zeros_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::zeros.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor zeros_like(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format); // {"schema": "aten::zeros_like(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _standard_gamma_grad(const at::Tensor & self, const at::Tensor & output); // {"schema": "aten::_standard_gamma_grad(Tensor self, Tensor output) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _standard_gamma(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::_standard_gamma(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _dirichlet_grad(const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total); // {"schema": "aten::_dirichlet_grad(Tensor x, Tensor alpha, Tensor total) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sample_dirichlet(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::_sample_dirichlet(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor poisson(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::poisson(Tensor self, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor binomial(const at::Tensor & count, const at::Tensor & prob, ::std::optional generator); // {"schema": "aten::binomial(Tensor count, Tensor prob, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor native_norm(const at::Tensor & self, const at::Scalar & p); // {"schema": "aten::native_norm(Tensor self, Scalar p=2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor native_norm(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::native_norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_with_update(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps); // {"schema": "aten::_batch_norm_with_update(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_with_update_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, at::Tensor & running_mean, at::Tensor & running_var, double momentum, double eps, at::Tensor & out, at::Tensor & save_mean, at::Tensor & save_invstd, at::Tensor & reserve); // {"schema": "aten::_batch_norm_with_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor(a!) running_mean, Tensor(b!) running_var, float momentum, float eps, *, Tensor(d!) out, Tensor(e!) save_mean, Tensor(f!) save_invstd, Tensor(g!) reserve) -> (Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))", "dispatch": "True", "default": "False"} +::std::tuple _batch_norm_no_update(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps); // {"schema": "aten::_batch_norm_no_update(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_backward(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve); // {"schema": "aten::batch_norm_backward(Tensor grad_out, Tensor input, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, bool update, float eps, bool[3] output_mask, Tensor reserve) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor _sparse_sum(const at::Tensor & self); // {"schema": "aten::_sparse_sum(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sum(const at::Tensor & self, at::ScalarType dtype); // {"schema": "aten::_sparse_sum.dtype(Tensor self, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sum(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::_sparse_sum.dim(Tensor self, int[1] dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _sparse_sum(const at::Tensor & self, at::IntArrayRef dim, at::ScalarType dtype); // {"schema": "aten::_sparse_sum.dim_dtype(Tensor self, int[1] dim, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_sum_backward(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::_sparse_sum_backward(Tensor grad, Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_csr_sum(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::_sparse_csr_sum.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_csr_prod(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::_sparse_csr_prod.dim_dtype(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::_sparse_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::_sparse_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_sparse_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self); // {"schema": "aten::_sparse_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_log_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::_sparse_log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_log_softmax(const at::Tensor & self, at::Dimname dim, ::std::optional dtype); // {"schema": "aten::_sparse_log_softmax.Dimname(Tensor self, Dimname dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_log_softmax(const at::Tensor & self, int64_t dim, bool half_to_float); // {"schema": "aten::_sparse_log_softmax(Tensor self, int dim, bool half_to_float) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_log_softmax_backward_data(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self); // {"schema": "aten::_sparse_log_softmax_backward_data(Tensor grad_output, Tensor output, int dim, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _spdiags(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout); // {"schema": "aten::_spdiags(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype); // {"schema": "aten::norm.ScalarOpt_dtype(Tensor self, Scalar? p, *, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor norm(const at::Tensor & self, const at::Scalar & p); // {"schema": "aten::norm.Scalar(Tensor self, Scalar p=2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype); // {"schema": "aten::norm.ScalarOpt_dim_dtype(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::norm.ScalarOpt_dim(Tensor self, Scalar? p, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::norm.dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::norm.out(Tensor self, Scalar? p, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype); // {"schema": "aten::norm.names_ScalarOpt_dim_dtype(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor norm(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim); // {"schema": "aten::norm.names_ScalarOpt_dim(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::norm.names_dtype_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::Tensor & out); // {"schema": "aten::norm.names_out(Tensor self, Scalar? p, Dimname[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple frexp(const at::Tensor & self); // {"schema": "aten::frexp.Tensor(Tensor self) -> (Tensor mantissa, Tensor exponent)", "dispatch": "True", "default": "True"} +::std::tuple frexp_out(const at::Tensor & self, at::Tensor & mantissa, at::Tensor & exponent); // {"schema": "aten::frexp.Tensor_out(Tensor self, *, Tensor(a!) mantissa, Tensor(b!) exponent) -> (Tensor(a!) mantissa, Tensor(b!) exponent)", "dispatch": "True", "default": "False"} +at::Tensor frobenius_norm(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::frobenius_norm.dim(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & frobenius_norm_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::frobenius_norm.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nuclear_norm(const at::Tensor & self, bool keepdim); // {"schema": "aten::nuclear_norm(Tensor self, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nuclear_norm_out(const at::Tensor & self, bool keepdim, at::Tensor & out); // {"schema": "aten::nuclear_norm.out(Tensor self, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nuclear_norm(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::nuclear_norm.dim(Tensor self, int[2] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nuclear_norm_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::nuclear_norm.dim_out(Tensor self, int[2] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor clone(const at::Tensor & self, ::std::optional memory_format); // {"schema": "aten::clone(Tensor self, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor positive(const at::Tensor & self); // {"schema": "aten::positive(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +const at::Tensor & resize_as_(const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format); // {"schema": "aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & resize_as_sparse_(const at::Tensor & self, const at::Tensor & the_template); // {"schema": "aten::resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & zero_(at::Tensor & self); // {"schema": "aten::zero_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & sub_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sub.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::sub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sub_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::sub_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor sub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::sub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sub_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::sub_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & subtract_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor subtract(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & subtract_(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor subtract(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & subtract_(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor rsub(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & heaviside_out(const at::Tensor & self, const at::Tensor & values, at::Tensor & out); // {"schema": "aten::heaviside.out(Tensor self, Tensor values, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor heaviside(const at::Tensor & self, const at::Tensor & values); // {"schema": "aten::heaviside(Tensor self, Tensor values) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & heaviside_(at::Tensor & self, const at::Tensor & values); // {"schema": "aten::heaviside_(Tensor(a!) self, Tensor values) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor rsub(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha); // {"schema": "aten::rsub.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _sparse_addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::_sparse_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sparse_sampled_addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sparse_sampled_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sparse_sampled_addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::sparse_sampled_addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _sparse_mm_reduce_impl(const at::Tensor & self, const at::Tensor & other, c10::string_view reduce); // {"schema": "aten::_sparse_mm_reduce_impl(Tensor self, Tensor other, str reduce) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _sparse_mm_reduce_impl_backward(const at::Tensor & self, const at::Tensor & grad_out, const at::Tensor & weight, c10::string_view reduce, const at::Tensor & arg_out, ::std::array output_mask); // {"schema": "aten::_sparse_mm_reduce_impl_backward(Tensor self, Tensor grad_out, Tensor weight, str reduce, Tensor arg_out, bool[2] output_mask) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor addmm(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmm.dtype(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addmm.dtype_out(Tensor self, Tensor mat1, Tensor mat2, ScalarType out_dtype, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & addmm_(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addmm_(Tensor(a!) self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _addmm_activation_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu, at::Tensor & out); // {"schema": "aten::_addmm_activation.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _addmm_activation(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu); // {"schema": "aten::_addmm_activation(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, bool use_gelu=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _scaled_mm(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum); // {"schema": "aten::_scaled_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _scaled_mm_out(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum, at::Tensor & out); // {"schema": "aten::_scaled_mm.out(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _scaled_grouped_mm(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & offs, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum); // {"schema": "aten::_scaled_grouped_mm(Tensor self, Tensor mat2, Tensor scale_a, Tensor scale_b, Tensor? offs=None, Tensor? bias=None, Tensor? scale_result=None, ScalarType? out_dtype=None, bool use_fast_accum=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _grouped_mm(const at::Tensor & self, const at::Tensor & mat2, const ::std::optional & offs, const ::std::optional & bias, ::std::optional out_dtype); // {"schema": "aten::_grouped_mm(Tensor self, Tensor mat2, Tensor? offs=None, Tensor? bias=None, ScalarType? out_dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_compressed_tensor_with_dims(int64_t nnz, int64_t dense_dim, at::IntArrayRef size, at::IntArrayRef blocksize, at::ScalarType index_dtype, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_compressed_tensor_with_dims(int nnz, int dense_dim, int[] size, int[] blocksize, ScalarType index_dtype, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_compressed_tensor(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_compressed_tensor.comp_plain_value_size(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_csr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_csc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsr_tensor.crow_col_value_size(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsc_tensor.ccol_row_value_size(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_compressed_tensor(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_compressed_tensor.comp_plain_value(Tensor compressed_indices, Tensor plain_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_csr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_csc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_csc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsr_tensor(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_bsc_tensor(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_bsc_tensor.ccol_row_value(Tensor ccol_indices, Tensor row_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_compressed_tensor_unsafe(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_compressed_tensor_unsafe(Tensor compressed_indices, Tensor plain_indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_csr_tensor_unsafe(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_csc_tensor_unsafe(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_csc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_bsr_tensor_unsafe(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_bsr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_bsc_tensor_unsafe(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_bsc_tensor_unsafe(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_coo_tensor(at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor sparse_coo_tensor(const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor sparse_coo_tensor(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::sparse_coo_tensor.indices_size(Tensor indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _sparse_coo_tensor_unsafe(const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::_sparse_coo_tensor_unsafe(Tensor indices, Tensor values, SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool? is_coalesced=None) -> Tensor", "dispatch": "False", "default": "True"} +void _validate_sparse_coo_tensor_args(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional is_coalesced, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size, bool? is_coalesced=None, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_compressed_tensor_args(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::Layout layout, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_compressed_tensor_args(Tensor compressed_indices, Tensor plain_indices, Tensor values, int[] size, Layout layout, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_csr_tensor_args(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_csc_tensor_args(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_csc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_bsr_tensor_args(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_bsr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +void _validate_sparse_bsc_tensor_args(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning); // {"schema": "aten::_validate_sparse_bsc_tensor_args(Tensor ccol_indices, Tensor row_indices, Tensor values, int[] size, bool? check_pinning=None) -> ()", "dispatch": "False", "default": "True"} +at::Tensor _sparse_coo_tensor_with_dims(int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_coo_tensor_with_dims_and_tensors(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced); // {"schema": "aten::_sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False, bool? is_coalesced=None) -> Tensor", "dispatch": "True", "default": "False"} +const at::Tensor & sparse_resize_(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)", "dispatch": "True", "default": "False"} +const at::Tensor & sparse_resize_and_clear_(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sparse_mask(const at::Tensor & self, const at::Tensor & mask); // {"schema": "aten::sparse_mask(Tensor self, Tensor mask) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _sparse_mask_projection(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches); // {"schema": "aten::_sparse_mask_projection(Tensor self, Tensor mask, bool accumulate_matches=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector _to_cpu(at::TensorList tensors); // {"schema": "aten::_to_cpu(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor to_dense(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad); // {"schema": "aten::to_dense(Tensor self, ScalarType? dtype=None, *, bool? masked_grad=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_dense(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad); // {"schema": "aten::_to_dense(Tensor self, ScalarType? dtype=None, bool? masked_grad=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_dense_backward(const at::Tensor & grad, const at::Tensor & input, ::std::optional masked_grad); // {"schema": "aten::to_dense_backward(Tensor grad, Tensor input, bool? masked_grad=None) -> Tensor", "dispatch": "False", "default": "True"} +int64_t sparse_dim(const at::Tensor & self); // {"schema": "aten::sparse_dim(Tensor self) -> int", "dispatch": "True", "default": "True"} +int64_t _dimI(const at::Tensor & self); // {"schema": "aten::_dimI(Tensor self) -> int", "dispatch": "True", "default": "False"} +int64_t dense_dim(const at::Tensor & self); // {"schema": "aten::dense_dim(Tensor self) -> int", "dispatch": "True", "default": "True"} +int64_t _dimV(const at::Tensor & self); // {"schema": "aten::_dimV(Tensor self) -> int", "dispatch": "True", "default": "False"} +int64_t _nnz(const at::Tensor & self); // {"schema": "aten::_nnz(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor coalesce(const at::Tensor & self); // {"schema": "aten::coalesce(Tensor(a) self) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _coalesce(const at::Tensor & self); // {"schema": "aten::_coalesce(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +bool is_coalesced(const at::Tensor & self); // {"schema": "aten::is_coalesced(Tensor self) -> bool", "dispatch": "True", "default": "True"} +at::Tensor _indices(const at::Tensor & self); // {"schema": "aten::_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor _values(const at::Tensor & self); // {"schema": "aten::_values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor & _coalesced_(at::Tensor & self, bool coalesced); // {"schema": "aten::_coalesced_(Tensor(a!) self, bool coalesced) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor indices(const at::Tensor & self); // {"schema": "aten::indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor values(const at::Tensor & self); // {"schema": "aten::values(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor crow_indices(const at::Tensor & self); // {"schema": "aten::crow_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor col_indices(const at::Tensor & self); // {"schema": "aten::col_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor ccol_indices(const at::Tensor & self); // {"schema": "aten::ccol_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor row_indices(const at::Tensor & self); // {"schema": "aten::row_indices(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & hspmm_out(const at::Tensor & mat1, const at::Tensor & mat2, at::Tensor & out); // {"schema": "aten::hspmm.out(Tensor mat1, Tensor mat2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hspmm(const at::Tensor & mat1, const at::Tensor & mat2); // {"schema": "aten::hspmm(Tensor mat1, Tensor mat2) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & copy_sparse_to_sparse_(at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy_sparse_to_sparse_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector unbind(const at::Tensor & self, int64_t dim); // {"schema": "aten::unbind.int(Tensor(a -> *) self, int dim=0) -> Tensor(a)[]", "dispatch": "True", "default": "True"} +::std::vector unbind(const at::Tensor & self, at::Dimname dim); // {"schema": "aten::unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]", "dispatch": "False", "default": "True"} +at::Tensor to_sparse(const at::Tensor & self, int64_t sparse_dim); // {"schema": "aten::to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse(const at::Tensor & self, int64_t sparse_dim); // {"schema": "aten::_to_sparse.sparse_dim(Tensor self, int sparse_dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::_to_sparse(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_csr(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_csr(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_csr(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_csc(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_csc(const at::Tensor & self, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_csc(Tensor self, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_bsr(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_bsr(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_bsr(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_sparse_bsc(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _to_sparse_bsc(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim); // {"schema": "aten::_to_sparse_bsc(Tensor self, int[2] blocksize, int? dense_dim=None) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _to_sparse_semi_structured(const at::Tensor & dense); // {"schema": "aten::_to_sparse_semi_structured(Tensor dense) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor to_mkldnn(const at::Tensor & self, ::std::optional dtype); // {"schema": "aten::to_mkldnn(Tensor self, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_reorder_conv2d_weight(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size); // {"schema": "aten::mkldnn_reorder_conv2d_weight(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_reorder_conv3d_weight(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size); // {"schema": "aten::mkldnn_reorder_conv3d_weight(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor to_mkldnn_backward(const at::Tensor & grad, const at::Tensor & input); // {"schema": "aten::to_mkldnn_backward(Tensor grad, Tensor input) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantize_per_tensor_dynamic(const at::Tensor & self, at::ScalarType dtype, bool reduce_range); // {"schema": "aten::quantize_per_tensor_dynamic(Tensor self, ScalarType dtype, bool reduce_range) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantize_per_tensor(const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype); // {"schema": "aten::quantize_per_tensor(Tensor self, float scale, int zero_point, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor quantize_per_tensor(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype); // {"schema": "aten::quantize_per_tensor.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector quantize_per_tensor(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype); // {"schema": "aten::quantize_per_tensor.tensors(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype) -> Tensor[]", "dispatch": "True", "default": "False"} +at::Tensor quantize_per_channel(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype); // {"schema": "aten::quantize_per_channel(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor dequantize(const at::Tensor & self); // {"schema": "aten::dequantize.self(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector dequantize(at::TensorList tensors); // {"schema": "aten::dequantize.tensors(Tensor[] tensors) -> Tensor[]", "dispatch": "True", "default": "False"} +double q_scale(const at::Tensor & self); // {"schema": "aten::q_scale(Tensor self) -> float", "dispatch": "True", "default": "False"} +int64_t q_zero_point(const at::Tensor & self); // {"schema": "aten::q_zero_point(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor q_per_channel_scales(const at::Tensor & self); // {"schema": "aten::q_per_channel_scales(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor q_per_channel_zero_points(const at::Tensor & self); // {"schema": "aten::q_per_channel_zero_points(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +int64_t q_per_channel_axis(const at::Tensor & self); // {"schema": "aten::q_per_channel_axis(Tensor self) -> int", "dispatch": "True", "default": "False"} +at::Tensor int_repr(const at::Tensor & self); // {"schema": "aten::int_repr(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _make_per_tensor_quantized_tensor(const at::Tensor & self, double scale, int64_t zero_point); // {"schema": "aten::_make_per_tensor_quantized_tensor(Tensor self, float scale, int zero_point) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _make_per_channel_quantized_tensor(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis); // {"schema": "aten::_make_per_channel_quantized_tensor(Tensor self, Tensor scale, Tensor zero_point, int axis) -> Tensor", "dispatch": "True", "default": "False"} +at::QScheme qscheme(const at::Tensor & self); // {"schema": "aten::qscheme(Tensor self) -> QScheme", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_tensor_affine(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fake_quantize_per_tensor_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine.tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple fake_quantize_per_tensor_affine_cachemask(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask(Tensor self, float scale, int zero_point, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max); // {"schema": "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_tensor_affine_cachemask_backward(const at::Tensor & grad, const at::Tensor & mask); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _fake_quantize_learnable_per_tensor_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _fake_quantize_learnable_per_tensor_affine_backward(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_channel_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple fake_quantize_per_channel_affine_cachemask(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +at::Tensor fake_quantize_per_channel_affine_cachemask_backward(const at::Tensor & grad, const at::Tensor & mask); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask_backward(Tensor grad, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _fake_quantize_learnable_per_channel_affine(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _fake_quantize_learnable_per_channel_affine_backward(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine_backward(Tensor grad, Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor fused_moving_avg_obs_fake_quant(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::fused_moving_avg_obs_fake_quant(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _fused_moving_avg_obs_fq_helper(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::_fused_moving_avg_obs_fq_helper(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask)", "dispatch": "True", "default": "False"} +::std::tuple _choose_qparams_per_tensor(const at::Tensor & self, bool reduce_range); // {"schema": "aten::_choose_qparams_per_tensor(Tensor self, bool reduce_range=False) -> (float, int)", "dispatch": "False", "default": "True"} +at::Tensor _saturate_weight_to_fp16(const at::Tensor & weight); // {"schema": "aten::_saturate_weight_to_fp16(Tensor weight) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple choose_qparams_optimized(const at::Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width); // {"schema": "aten::choose_qparams_optimized(Tensor input, int numel, int n_bins, float ratio, int bit_width) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor _autocast_to_reduced_precision(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype); // {"schema": "aten::_autocast_to_reduced_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled, ScalarType cuda_dtype, ScalarType cpu_dtype) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _autocast_to_full_precision(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled); // {"schema": "aten::_autocast_to_full_precision(Tensor(a) self, bool cuda_enabled, bool cpu_enabled) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor _to_copy(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, ::std::optional memory_format); // {"schema": "aten::_to_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor to(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.dtype_layout(Tensor(a) self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor to(const at::Tensor & self, at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.device(Tensor(a) self, Device device, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor to(const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor to(const at::Tensor & self, const at::Tensor & other, bool non_blocking, bool copy, ::std::optional memory_format); // {"schema": "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)", "dispatch": "False", "default": "True"} +::std::vector meshgrid(at::TensorList tensors); // {"schema": "aten::meshgrid(Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +::std::vector meshgrid(at::TensorList tensors, c10::string_view indexing); // {"schema": "aten::meshgrid.indexing(Tensor[] tensors, *, str indexing) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor cartesian_prod(at::TensorList tensors); // {"schema": "aten::cartesian_prod(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor combinations(const at::Tensor & self, int64_t r, bool with_replacement); // {"schema": "aten::combinations(Tensor self, int r=2, bool with_replacement=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Scalar item(const at::Tensor & self); // {"schema": "aten::item(Tensor self) -> Scalar", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Tensor & tensor, const at::Tensor & other); // {"schema": "aten::result_type.Tensor(Tensor tensor, Tensor other) -> ScalarType", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Tensor & tensor, const at::Scalar & other); // {"schema": "aten::result_type.Scalar(Tensor tensor, Scalar other) -> ScalarType", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Scalar & scalar, const at::Tensor & tensor); // {"schema": "aten::result_type.Scalar_Tensor(Scalar scalar, Tensor tensor) -> ScalarType", "dispatch": "False", "default": "True"} +at::ScalarType result_type(const at::Scalar & scalar1, const at::Scalar & scalar2); // {"schema": "aten::result_type.Scalar_Scalar(Scalar scalar1, Scalar scalar2) -> ScalarType", "dispatch": "False", "default": "True"} +bool can_cast(at::ScalarType from_, at::ScalarType to); // {"schema": "aten::can_cast(ScalarType from_, ScalarType to) -> bool", "dispatch": "False", "default": "True"} +at::ScalarType promote_types(at::ScalarType type1, at::ScalarType type2); // {"schema": "aten::promote_types(ScalarType type1, ScalarType type2) -> ScalarType", "dispatch": "False", "default": "True"} +at::Scalar _local_scalar_dense(const at::Tensor & self); // {"schema": "aten::_local_scalar_dense(Tensor self) -> Scalar", "dispatch": "True", "default": "False"} +::std::tuple _lstm_mps(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::_lstm_mps(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple,::std::vector> lstm_mps_backward(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::lstm_mps_backward(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor[], Tensor[])", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_lstm_cell(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias); // {"schema": "aten::_thnn_fused_lstm_cell(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_lstm_cell_backward_impl(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_lstm_cell_backward_impl(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_lstm_cell_backward(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _thnn_differentiable_lstm_cell_backward(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const ::std::optional & input_bias, const ::std::optional & hidden_bias, const at::Tensor & cx, const at::Tensor & cy); // {"schema": "aten::_thnn_differentiable_lstm_cell_backward(Tensor? grad_hy, Tensor? grad_cy, Tensor input_gates, Tensor hidden_gates, Tensor? input_bias, Tensor? hidden_bias, Tensor cx, Tensor cy) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _thnn_fused_gru_cell(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias); // {"schema": "aten::_thnn_fused_gru_cell(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_fused_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias); // {"schema": "aten::_thnn_fused_gru_cell_backward(Tensor grad_hy, Tensor workspace, bool has_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _thnn_differentiable_gru_cell_backward(const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias); // {"schema": "aten::_thnn_differentiable_gru_cell_backward(Tensor grad_hy, Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias, Tensor? hidden_bias) -> (Tensor, Tensor, Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple lstm(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::lstm.input(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple lstm(const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::lstm.data(Tensor data, Tensor batch_sizes, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple gru(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::gru.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple gru(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::gru.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_tanh(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::rnn_tanh.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_tanh(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::rnn_tanh.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_relu(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first); // {"schema": "aten::rnn_relu.input(Tensor input, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple rnn_relu(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional); // {"schema": "aten::rnn_relu.data(Tensor data, Tensor batch_sizes, Tensor hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple lstm_cell(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor gru_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor rnn_tanh_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor rnn_relu_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh); // {"schema": "aten::rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor? b_ih=None, Tensor? b_hh=None) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple quantized_lstm_cell(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_lstm_cell(Tensor input, Tensor[] hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor quantized_gru_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_gru_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantized_rnn_relu_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_rnn_relu_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantized_rnn_tanh_cell(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh); // {"schema": "aten::quantized_rnn_tanh_cell(Tensor input, Tensor hx, Tensor w_ih, Tensor w_hh, Tensor b_ih, Tensor b_hh, Tensor packed_ih, Tensor packed_hh, Tensor col_offsets_ih, Tensor col_offsets_hh, Scalar scale_ih, Scalar scale_hh, Scalar zero_point_ih, Scalar zero_point_hh) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _pack_padded_sequence(const at::Tensor & input, const at::Tensor & lengths, bool batch_first); // {"schema": "aten::_pack_padded_sequence(Tensor input, Tensor lengths, bool batch_first) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor _pack_padded_sequence_backward(const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first); // {"schema": "aten::_pack_padded_sequence_backward(Tensor grad, SymInt[] input_size, Tensor batch_sizes, bool batch_first) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple _pad_packed_sequence(const at::Tensor & data, const at::Tensor & batch_sizes, bool batch_first, const at::Scalar & padding_value, int64_t total_length); // {"schema": "aten::_pad_packed_sequence(Tensor data, Tensor batch_sizes, bool batch_first, Scalar padding_value, int total_length) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +at::Tensor & set_(at::Tensor & self, at::Storage source); // {"schema": "aten::set_.source_Storage(Tensor(a!) self, Storage source) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & set_(at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set_.source_Storage_storage_offset(Tensor(a!) self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & set_(at::Tensor & self, const at::Tensor & source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set_.source_Tensor_storage_offset(Tensor(a!) self, Tensor source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & set_(at::Tensor & self, const at::Tensor & source); // {"schema": "aten::set_.source_Tensor(Tensor(a!) self, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & set_(at::Tensor & self); // {"schema": "aten::set_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lift(const at::Tensor & self); // {"schema": "aten::lift(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor lift_fresh(const at::Tensor & self); // {"schema": "aten::lift_fresh(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor lift_fresh_copy(const at::Tensor & self); // {"schema": "aten::lift_fresh_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +bool is_set_to(const at::Tensor & self, const at::Tensor & tensor); // {"schema": "aten::is_set_to(Tensor self, Tensor tensor) -> bool", "dispatch": "True", "default": "False"} +at::Tensor & masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); // {"schema": "aten::masked_fill_.Scalar(Tensor(a!) self, Tensor mask, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value); // {"schema": "aten::masked_fill.Scalar(Tensor self, Tensor mask, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & masked_fill_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value); // {"schema": "aten::masked_fill_.Tensor(Tensor(a!) self, Tensor mask, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_fill(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value); // {"schema": "aten::masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & masked_scatter_(at::Tensor & self, const at::Tensor & mask, const at::Tensor & source); // {"schema": "aten::masked_scatter_(Tensor(a!) self, Tensor mask, Tensor source) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_scatter(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source); // {"schema": "aten::masked_scatter(Tensor self, Tensor mask, Tensor source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor masked_scatter_backward(const at::Tensor & grad_output, const at::Tensor & mask, c10::SymIntArrayRef sizes); // {"schema": "aten::masked_scatter_backward(Tensor grad_output, Tensor mask, SymInt[] sizes) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _masked_softmax(const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type); // {"schema": "aten::_masked_softmax(Tensor self, Tensor mask, int? dim=None, int? mask_type=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _masked_softmax_backward(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim); // {"schema": "aten::_masked_softmax_backward(Tensor grad_output, Tensor output, Tensor mask, int? dim=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor view(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::view(Tensor(a) self, SymInt[] size) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor view(const at::Tensor & self, at::ScalarType dtype); // {"schema": "aten::view.dtype(Tensor(a) self, ScalarType dtype) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor & put_(at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate); // {"schema": "aten::put_(Tensor(a!) self, Tensor index, Tensor source, bool accumulate=False) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor put(const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate); // {"schema": "aten::put(Tensor self, Tensor index, Tensor source, bool accumulate=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_add_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::index_add.out(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & index_add_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha); // {"schema": "aten::index_add_(Tensor(a!) self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_add(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha); // {"schema": "aten::index_add(Tensor self, int dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor index_add(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha); // {"schema": "aten::index_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor source, *, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & index_reduce_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self, at::Tensor & out); // {"schema": "aten::index_reduce.out(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & index_reduce_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self); // {"schema": "aten::index_reduce_(Tensor(a!) self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor index_reduce(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self); // {"schema": "aten::index_reduce(Tensor self, int dim, Tensor index, Tensor source, str reduce, *, bool include_self=True) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill_.int_Scalar(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor index_fill(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill.int_Scalar(Tensor self, int dim, Tensor index, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill_.int_Tensor(Tensor(a!) self, int dim, Tensor index, Tensor value) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor index_fill(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill.int_Tensor(Tensor self, int dim, Tensor index, Tensor value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill_.Dimname_Scalar(Tensor(a!) self, Dimname dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & index_fill_(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill_.Dimname_Tensor(Tensor(a!) self, Dimname dim, Tensor index, Tensor value) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor index_fill(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::index_fill.Dimname_Scalar(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor index_fill(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value); // {"schema": "aten::index_fill.Dimname_Tensor(Tensor self, Dimname dim, Tensor index, Tensor value) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter.src(Tensor self, int dim, Tensor index, Tensor src) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_.src(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out); // {"schema": "aten::scatter.src_out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::scatter.value(Tensor self, int dim, Tensor index, Scalar value) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::scatter_.value(Tensor(a!) self, int dim, Tensor index, Scalar value) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::scatter.value_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce); // {"schema": "aten::scatter.reduce(Tensor self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce); // {"schema": "aten::scatter_.reduce(Tensor(a!) self, int dim, Tensor index, Tensor src, *, str reduce) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, at::Tensor & out); // {"schema": "aten::scatter.reduce_out(Tensor self, int dim, Tensor index, Tensor src, *, str reduce, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce); // {"schema": "aten::scatter.value_reduce(Tensor self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce); // {"schema": "aten::scatter_.value_reduce(Tensor(a!) self, int dim, Tensor index, Scalar value, *, str reduce) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce, at::Tensor & out); // {"schema": "aten::scatter.value_reduce_out(Tensor self, int dim, Tensor index, Scalar value, *, str reduce, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter.dimname_src(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value); // {"schema": "aten::scatter.dimname_value(Tensor self, Dimname dim, Tensor index, Scalar value) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter_add(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_add(Tensor self, int dim, Tensor index, Tensor src) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_add_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_add_(Tensor(a!) self, int dim, Tensor index, Tensor src) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_add_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, at::Tensor & out); // {"schema": "aten::scatter_add.out(Tensor self, int dim, Tensor index, Tensor src, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor scatter_add(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src); // {"schema": "aten::scatter_add.dimname(Tensor self, Dimname dim, Tensor index, Tensor src) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor scatter_reduce(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self); // {"schema": "aten::scatter_reduce.two(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & scatter_reduce_(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self); // {"schema": "aten::scatter_reduce_.two(Tensor(a!) self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scatter_reduce_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self, at::Tensor & out); // {"schema": "aten::scatter_reduce.two_out(Tensor self, int dim, Tensor index, Tensor src, str reduce, *, bool include_self=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & eq_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::eq_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & eq_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::eq_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_and.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bitwise_and_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_and.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_and(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_and.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_and(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_and.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_and(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_and.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_and_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_and_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor __and__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__and__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor __and__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__and__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & __iand__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__iand__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & __iand__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__iand__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & bitwise_or_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_or.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bitwise_or_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_or.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_or(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_or.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_or(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_or.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_or(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_or.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_or_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_or_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_or_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_or_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor __or__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__or__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor __or__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__or__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & __ior__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__ior__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & __ior__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__ior__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & bitwise_xor_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_xor.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & bitwise_xor_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_xor.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_xor(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_xor.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_xor(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_xor.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor bitwise_xor(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_xor.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_xor_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_xor_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_xor_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_xor_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor __xor__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__xor__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor __xor__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__xor__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & __ixor__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__ixor__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & __ixor__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__ixor__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor __lshift__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__lshift__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor __lshift__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__lshift__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & __ilshift__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__ilshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & __ilshift__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__ilshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_left_shift(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_left_shift.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_left_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_left_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_left_shift(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_left_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_left_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_left_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_left_shift(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_left_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor __rshift__(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__rshift__.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor __rshift__(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__rshift__.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & __irshift__(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::__irshift__.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & __irshift__(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::__irshift__.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_right_shift(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_right_shift.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::bitwise_right_shift_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_right_shift.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bitwise_right_shift(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_right_shift.Tensor_Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::bitwise_right_shift_.Tensor_Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::bitwise_right_shift.Tensor_Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bitwise_right_shift(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::bitwise_right_shift.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tril_(at::Tensor & self, int64_t diagonal); // {"schema": "aten::tril_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & triu_(at::Tensor & self, int64_t diagonal); // {"schema": "aten::triu_(Tensor(a!) self, int diagonal=0) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & digamma_(at::Tensor & self); // {"schema": "aten::digamma_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); // {"schema": "aten::lerp_.Scalar(Tensor(a!) self, Tensor end, Scalar weight) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lerp_(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); // {"schema": "aten::lerp_.Tensor(Tensor(a!) self, Tensor end, Tensor weight) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addbmm_(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addbmm_(Tensor(a!) self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & addbmm_out(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::addbmm.out(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addbmm(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha); // {"schema": "aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & random_(at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator); // {"schema": "aten::random_.from(Tensor(a!) self, int from, int? to, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & random_(at::Tensor & self, int64_t to, ::std::optional generator); // {"schema": "aten::random_.to(Tensor(a!) self, int to, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & random_(at::Tensor & self, ::std::optional generator); // {"schema": "aten::random_(Tensor(a!) self, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & uniform_(at::Tensor & self, double from, double to, ::std::optional generator); // {"schema": "aten::uniform_(Tensor(a!) self, float from=0, float to=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & cauchy_(at::Tensor & self, double median, double sigma, ::std::optional generator); // {"schema": "aten::cauchy_(Tensor(a!) self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & log_normal_(at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::log_normal_(Tensor(a!) self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & exponential_(at::Tensor & self, double lambd, ::std::optional generator); // {"schema": "aten::exponential_(Tensor(a!) self, float lambd=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & geometric_(at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::geometric_(Tensor(a!) self, float p, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & diag_out(const at::Tensor & self, int64_t diagonal, at::Tensor & out); // {"schema": "aten::diag.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor diag(const at::Tensor & self, int64_t diagonal); // {"schema": "aten::diag(Tensor self, int diagonal=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & cross_out(const at::Tensor & self, const at::Tensor & other, ::std::optional dim, at::Tensor & out); // {"schema": "aten::cross.out(Tensor self, Tensor other, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor cross(const at::Tensor & self, const at::Tensor & other, ::std::optional dim); // {"schema": "aten::cross(Tensor self, Tensor other, int? dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & triu_out(const at::Tensor & self, int64_t diagonal, at::Tensor & out); // {"schema": "aten::triu.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor triu(const at::Tensor & self, int64_t diagonal); // {"schema": "aten::triu(Tensor self, int diagonal=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tril_out(const at::Tensor & self, int64_t diagonal, at::Tensor & out); // {"schema": "aten::tril.out(Tensor self, int diagonal=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tril(const at::Tensor & self, int64_t diagonal); // {"schema": "aten::tril(Tensor self, int diagonal=0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor tril_indices(int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor triu_indices(int64_t row, int64_t col, int64_t offset, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor trace(const at::Tensor & self); // {"schema": "aten::trace(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor trace_backward(const at::Tensor & grad, c10::SymIntArrayRef sizes); // {"schema": "aten::trace_backward(Tensor grad, SymInt[] sizes) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & ne_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::ne.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ne(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ne.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ne_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::ne.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ne(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ne.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ne_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ne_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ne_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ne_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & not_equal_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::not_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor not_equal(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::not_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & not_equal_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::not_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor not_equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::not_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & not_equal_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::not_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & not_equal_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::not_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & eq_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::eq.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor eq(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::eq.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & eq_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::eq.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor eq(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::eq.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ge_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::ge.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ge(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ge.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ge_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::ge.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ge(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ge.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & ge_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::ge_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ge_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::ge_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & greater_equal_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::greater_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater_equal(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_equal_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::greater_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater_equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_equal_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & greater_equal_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & le_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::le.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor le(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::le.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & le_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::le.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor le(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::le.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & le_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::le_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & le_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::le_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & less_equal_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::less_equal.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less_equal(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less_equal.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_equal_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::less_equal.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less_equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less_equal.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_equal_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less_equal_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & less_equal_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less_equal_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & gt_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::gt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gt(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::gt.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gt_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::gt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gt(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gt.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & gt_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::gt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & gt_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::gt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & greater_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::greater.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::greater.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor greater(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & greater_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::greater_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & greater_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::greater_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & lt_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::lt.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lt(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::lt.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lt_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::lt.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lt(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lt.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lt_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::lt_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lt_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::lt_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & less_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::less.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::less.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor less(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & less_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::less_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & less_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::less_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & take_out(const at::Tensor & self, const at::Tensor & index, at::Tensor & out); // {"schema": "aten::take.out(Tensor self, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor take(const at::Tensor & self, const at::Tensor & index); // {"schema": "aten::take(Tensor self, Tensor index) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & take_along_dim_out(const at::Tensor & self, const at::Tensor & indices, ::std::optional dim, at::Tensor & out); // {"schema": "aten::take_along_dim.out(Tensor self, Tensor indices, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor take_along_dim(const at::Tensor & self, const at::Tensor & indices, ::std::optional dim); // {"schema": "aten::take_along_dim(Tensor self, Tensor indices, int? dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & index_select_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, at::Tensor & out); // {"schema": "aten::index_select.out(Tensor self, int dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor index_select(const at::Tensor & self, int64_t dim, const at::Tensor & index); // {"schema": "aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & index_select_out(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, at::Tensor & out); // {"schema": "aten::index_select.dimname_out(Tensor self, Dimname dim, Tensor index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor index_select(const at::Tensor & self, at::Dimname dim, const at::Tensor & index); // {"schema": "aten::index_select.dimname(Tensor self, Dimname dim, Tensor index) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor index_select_backward(const at::Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const at::Tensor & index); // {"schema": "aten::index_select_backward(Tensor grad, SymInt[] self_sizes, int dim, Tensor index) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & masked_select_out(const at::Tensor & self, const at::Tensor & mask, at::Tensor & out); // {"schema": "aten::masked_select.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor masked_select(const at::Tensor & self, const at::Tensor & mask); // {"schema": "aten::masked_select(Tensor self, Tensor mask) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor masked_select_backward(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & mask); // {"schema": "aten::masked_select_backward(Tensor grad, Tensor input, Tensor mask) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nonzero_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::nonzero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nonzero(const at::Tensor & self); // {"schema": "aten::nonzero(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & nonzero_static_out(const at::Tensor & self, c10::SymInt size, int64_t fill_value, at::Tensor & out); // {"schema": "aten::nonzero_static.out(Tensor self, *, SymInt size, int fill_value=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nonzero_static(const at::Tensor & self, c10::SymInt size, int64_t fill_value); // {"schema": "aten::nonzero_static(Tensor self, *, SymInt size, int fill_value=-1) -> Tensor", "dispatch": "True", "default": "False"} +::std::vector nonzero_numpy(const at::Tensor & self); // {"schema": "aten::nonzero_numpy(Tensor self) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor argwhere(const at::Tensor & self); // {"schema": "aten::argwhere(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & gather_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out); // {"schema": "aten::gather.out(Tensor self, int dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor gather(const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad); // {"schema": "aten::gather(Tensor self, int dim, Tensor index, *, bool sparse_grad=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor gather_backward(const at::Tensor & grad, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad); // {"schema": "aten::gather_backward(Tensor grad, Tensor self, int dim, Tensor index, bool sparse_grad) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & gather_out(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad, at::Tensor & out); // {"schema": "aten::gather.dimname_out(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor gather(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad); // {"schema": "aten::gather.dimname(Tensor self, Dimname dim, Tensor index, *, bool sparse_grad=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _gather_sparse_backward(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & grad); // {"schema": "aten::_gather_sparse_backward(Tensor self, int dim, Tensor index, Tensor grad) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & addcmul_out(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::addcmul.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addcmul(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addcmul_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcmul_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & addcdiv_out(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::addcdiv.out(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor addcdiv(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & addcdiv_(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value); // {"schema": "aten::addcdiv_(Tensor(a!) self, Tensor tensor1, Tensor tensor2, *, Scalar value=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor cross_entropy_loss(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, double label_smoothing); // {"schema": "aten::cross_entropy_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, float label_smoothing=0.0) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple triangular_solve_out(const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular, at::Tensor & X, at::Tensor & M); // {"schema": "aten::triangular_solve.X(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False, *, Tensor(a!) X, Tensor(b!) M) -> (Tensor(a!) solution, Tensor(b!) cloned_coefficient)", "dispatch": "True", "default": "False"} +::std::tuple triangular_solve(const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular); // {"schema": "aten::triangular_solve(Tensor self, Tensor A, bool upper=True, bool transpose=False, bool unitriangular=False) -> (Tensor solution, Tensor cloned_coefficient)", "dispatch": "True", "default": "True"} +void _linalg_check_errors(const at::Tensor & info, c10::string_view api_name, bool is_matrix); // {"schema": "aten::_linalg_check_errors(Tensor info, str api_name, *, bool is_matrix) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & linalg_solve_triangular_out(const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular, at::Tensor & out); // {"schema": "aten::linalg_solve_triangular.out(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor linalg_solve_triangular(const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular); // {"schema": "aten::linalg_solve_triangular(Tensor self, Tensor B, *, bool upper, bool left=True, bool unitriangular=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor linalg_vander(const at::Tensor & x, ::std::optional N); // {"schema": "aten::linalg_vander(Tensor x, *, SymInt? N=None) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple svd_out(const at::Tensor & self, bool some, bool compute_uv, at::Tensor & U, at::Tensor & S, at::Tensor & V); // {"schema": "aten::svd.U(Tensor self, bool some=True, bool compute_uv=True, *, Tensor(a!) U, Tensor(b!) S, Tensor(c!) V) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) V)", "dispatch": "False", "default": "True"} +::std::tuple svd(const at::Tensor & self, bool some, bool compute_uv); // {"schema": "aten::svd(Tensor self, bool some=True, bool compute_uv=True) -> (Tensor U, Tensor S, Tensor V)", "dispatch": "False", "default": "True"} +at::Tensor swapaxes(const at::Tensor & self, int64_t axis0, int64_t axis1); // {"schema": "aten::swapaxes(Tensor(a) self, int axis0, int axis1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor & swapaxes_(at::Tensor & self, int64_t axis0, int64_t axis1); // {"schema": "aten::swapaxes_(Tensor(a!) self, int axis0, int axis1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor swapdims(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::swapdims(Tensor(a) self, int dim0, int dim1) -> Tensor(a)", "dispatch": "False", "default": "True"} +at::Tensor & swapdims_(at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::swapdims_(Tensor(a!) self, int dim0, int dim1) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & cholesky_out(const at::Tensor & self, bool upper, at::Tensor & out); // {"schema": "aten::cholesky.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor cholesky(const at::Tensor & self, bool upper); // {"schema": "aten::cholesky(Tensor self, bool upper=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & cholesky_solve_out(const at::Tensor & self, const at::Tensor & input2, bool upper, at::Tensor & out); // {"schema": "aten::cholesky_solve.out(Tensor self, Tensor input2, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor cholesky_solve(const at::Tensor & self, const at::Tensor & input2, bool upper); // {"schema": "aten::cholesky_solve(Tensor self, Tensor input2, bool upper=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _cholesky_solve_helper(const at::Tensor & self, const at::Tensor & A, bool upper); // {"schema": "aten::_cholesky_solve_helper(Tensor self, Tensor A, bool upper) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor cholesky_inverse(const at::Tensor & self, bool upper); // {"schema": "aten::cholesky_inverse(Tensor self, bool upper=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & cholesky_inverse_out(const at::Tensor & self, bool upper, at::Tensor & out); // {"schema": "aten::cholesky_inverse.out(Tensor self, bool upper=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple qr_out(const at::Tensor & self, bool some, at::Tensor & Q, at::Tensor & R); // {"schema": "aten::qr.Q(Tensor self, bool some=True, *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)", "dispatch": "False", "default": "True"} +::std::tuple qr(const at::Tensor & self, bool some); // {"schema": "aten::qr(Tensor self, bool some=True) -> (Tensor Q, Tensor R)", "dispatch": "False", "default": "True"} +::std::tuple geqrf_out(const at::Tensor & self, at::Tensor & a, at::Tensor & tau); // {"schema": "aten::geqrf.a(Tensor self, *, Tensor(a!) a, Tensor(b!) tau) -> (Tensor(a!) a, Tensor(b!) tau)", "dispatch": "True", "default": "False"} +::std::tuple geqrf(const at::Tensor & self); // {"schema": "aten::geqrf(Tensor self) -> (Tensor a, Tensor tau)", "dispatch": "True", "default": "False"} +at::Tensor orgqr(const at::Tensor & self, const at::Tensor & input2); // {"schema": "aten::orgqr(Tensor self, Tensor input2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & orgqr_out(const at::Tensor & self, const at::Tensor & input2, at::Tensor & out); // {"schema": "aten::orgqr.out(Tensor self, Tensor input2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & ormqr_out(const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose, at::Tensor & out); // {"schema": "aten::ormqr.out(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor ormqr(const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose); // {"schema": "aten::ormqr(Tensor self, Tensor input2, Tensor input3, bool left=True, bool transpose=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _lu_with_info(const at::Tensor & self, bool pivot, bool check_errors); // {"schema": "aten::_lu_with_info(Tensor self, bool pivot=True, bool check_errors=True) -> (Tensor LU, Tensor pivots, Tensor info)", "dispatch": "False", "default": "True"} +at::Tensor & lu_solve_out(const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots, at::Tensor & out); // {"schema": "aten::lu_solve.out(Tensor self, Tensor LU_data, Tensor LU_pivots, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor lu_solve(const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots); // {"schema": "aten::lu_solve(Tensor self, Tensor LU_data, Tensor LU_pivots) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple lu_unpack(const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots); // {"schema": "aten::lu_unpack(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True) -> (Tensor P, Tensor L, Tensor U)", "dispatch": "True", "default": "True"} +::std::tuple lu_unpack_out(const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots, at::Tensor & P, at::Tensor & L, at::Tensor & U); // {"schema": "aten::lu_unpack.out(Tensor LU_data, Tensor LU_pivots, bool unpack_data=True, bool unpack_pivots=True, *, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)", "dispatch": "True", "default": "False"} +at::Tensor & multinomial_out(const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator, at::Tensor & out); // {"schema": "aten::multinomial.out(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multinomial(const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator); // {"schema": "aten::multinomial(Tensor self, SymInt num_samples, bool replacement=False, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & lgamma_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::lgamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & lgamma_(at::Tensor & self); // {"schema": "aten::lgamma_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor lgamma(const at::Tensor & self); // {"schema": "aten::lgamma(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & digamma_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor digamma(const at::Tensor & self); // {"schema": "aten::digamma(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & polygamma_out(int64_t n, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor polygamma(int64_t n, const at::Tensor & self); // {"schema": "aten::polygamma(int n, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & polygamma_(at::Tensor & self, int64_t n); // {"schema": "aten::polygamma_(Tensor(a!) self, int n) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor erfinv(const at::Tensor & self); // {"schema": "aten::erfinv(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & erfinv_(at::Tensor & self); // {"schema": "aten::erfinv_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & erfinv_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor i0(const at::Tensor & self); // {"schema": "aten::i0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & i0_(at::Tensor & self); // {"schema": "aten::i0_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & i0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sign(const at::Tensor & self); // {"schema": "aten::sign(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sign_(at::Tensor & self); // {"schema": "aten::sign_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sign_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::sign.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor signbit(const at::Tensor & self); // {"schema": "aten::signbit(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & signbit_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::signbit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor dist(const at::Tensor & self, const at::Tensor & other, const at::Scalar & p); // {"schema": "aten::dist(Tensor self, Tensor other, Scalar p=2) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & atan2_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::atan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & atan2_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::atan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor atan2(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::atan2(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor arctan2(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::arctan2(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & arctan2_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::arctan2.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & arctan2_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::arctan2_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight, at::Tensor & out); // {"schema": "aten::lerp.Scalar_out(Tensor self, Tensor end, Scalar weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & lerp_out(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight, at::Tensor & out); // {"schema": "aten::lerp.Tensor_out(Tensor self, Tensor end, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight); // {"schema": "aten::lerp.Scalar(Tensor self, Tensor end, Scalar weight) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor lerp(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight); // {"schema": "aten::lerp.Tensor(Tensor self, Tensor end, Tensor weight) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & histc_out(const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max, at::Tensor & out); // {"schema": "aten::histc.out(Tensor self, int bins=100, Scalar min=0, Scalar max=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor histc(const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max); // {"schema": "aten::histc(Tensor self, int bins=100, Scalar min=0, Scalar max=0) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple histogram_out(const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges); // {"schema": "aten::histogram.bins_tensor_out(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)", "dispatch": "True", "default": "False"} +::std::tuple histogram(const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density); // {"schema": "aten::histogram.bins_tensor(Tensor self, Tensor bins, *, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)", "dispatch": "True", "default": "False"} +::std::tuple histogram_out(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & hist, at::Tensor & bin_edges); // {"schema": "aten::histogram.bin_ct_out(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) hist, Tensor(b!) bin_edges) -> (Tensor(a!) hist, Tensor(b!) bin_edges)", "dispatch": "True", "default": "False"} +::std::tuple histogram(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogram.bin_ct(Tensor self, int bins=100, *, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor bin_edges)", "dispatch": "True", "default": "False"} +::std::vector _histogramdd_bin_edges(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::_histogramdd_bin_edges(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor[]", "dispatch": "True", "default": "False"} +at::Tensor _histogramdd_from_bin_cts(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::_histogramdd_from_bin_cts(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _histogramdd_from_bin_tensors(const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density); // {"schema": "aten::_histogramdd_from_bin_tensors(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple> histogramdd(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogramdd(Tensor self, int[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"} +::std::tuple> histogramdd(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogramdd.int_bins(Tensor self, int bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"} +::std::tuple> histogramdd(const at::Tensor & self, at::TensorList bins, ::std::optional> range, const ::std::optional & weight, bool density); // {"schema": "aten::histogramdd.TensorList_bins(Tensor self, Tensor[] bins, float[]? range=None, Tensor? weight=None, bool density=False) -> (Tensor hist, Tensor[] bin_edges)", "dispatch": "False", "default": "True"} +at::Tensor & fmod_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::fmod.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor fmod(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::fmod.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmod_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::fmod_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & fmod_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::fmod.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fmod(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmod.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmod_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmod_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hypot_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::hypot.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hypot(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::hypot(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hypot_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::hypot_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & igamma_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::igamma.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor igamma(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igamma(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & igamma_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igamma_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & igammac_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::igammac.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor igammac(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igammac(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & igammac_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::igammac_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & nextafter_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::nextafter.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nextafter(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::nextafter(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & nextafter_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::nextafter_(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & remainder_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::remainder.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor remainder(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::remainder.Scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & remainder_(at::Tensor & self, const at::Scalar & other); // {"schema": "aten::remainder_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & remainder_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::remainder.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor remainder(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::remainder.Tensor(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & remainder_(at::Tensor & self, const at::Tensor & other); // {"schema": "aten::remainder_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor remainder(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::remainder.Scalar_Tensor(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor min(const at::Tensor & self); // {"schema": "aten::min(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & min_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::min.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fmin(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmin(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmin_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::fmin.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max(const at::Tensor & self); // {"schema": "aten::max(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor fmax(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::fmax(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fmax_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::fmax.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor maximum(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::maximum(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & maximum_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::maximum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::max.other(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & max_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::max.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & max_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::max.unary_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor minimum(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::minimum(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & minimum_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::minimum.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & min_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::min.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor min(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::min.other(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor quantile(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::quantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & quantile_out(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::quantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor quantile(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::quantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & quantile_out(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::quantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nanquantile(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::nanquantile(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nanquantile_out(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::nanquantile.out(Tensor self, Tensor q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nanquantile(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation); // {"schema": "aten::nanquantile.scalar(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & nanquantile_out(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation, at::Tensor & out); // {"schema": "aten::nanquantile.scalar_out(Tensor self, float q, int? dim=None, bool keepdim=False, *, str interpolation='linear', Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.values(Tensor self, int dim=-1, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.values_stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple sort(const at::Tensor & self, int64_t dim, bool descending); // {"schema": "aten::sort(Tensor self, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple sort(const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending); // {"schema": "aten::sort.stable(Tensor self, *, bool? stable, int dim=-1, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.dimname_values(Tensor self, Dimname dim, bool descending=False, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +::std::tuple sort_out(const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::sort.dimname_values_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "False", "default": "True"} +::std::tuple sort(const at::Tensor & self, at::Dimname dim, bool descending); // {"schema": "aten::sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +::std::tuple sort(const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending); // {"schema": "aten::sort.dimname_stable(Tensor self, *, bool? stable, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)", "dispatch": "False", "default": "True"} +at::Tensor & msort_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor msort(const at::Tensor & self); // {"schema": "aten::msort(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor argsort(const at::Tensor & self, int64_t dim, bool descending); // {"schema": "aten::argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor argsort(const at::Tensor & self, bool stable, int64_t dim, bool descending); // {"schema": "aten::argsort.stable(Tensor self, *, bool stable, int dim=-1, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & argsort_out(const at::Tensor & self, bool stable, int64_t dim, bool descending, at::Tensor & out); // {"schema": "aten::argsort.stable_out(Tensor self, *, bool stable, int dim=-1, bool descending=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor argsort(const at::Tensor & self, at::Dimname dim, bool descending); // {"schema": "aten::argsort.dimname(Tensor self, Dimname dim, bool descending=False) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple topk_out(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted, at::Tensor & values, at::Tensor & indices); // {"schema": "aten::topk.values(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices)", "dispatch": "True", "default": "False"} +::std::tuple topk(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted); // {"schema": "aten::topk(Tensor self, SymInt k, int dim=-1, bool largest=True, bool sorted=True) -> (Tensor values, Tensor indices)", "dispatch": "True", "default": "True"} +at::Tensor all(const at::Tensor & self); // {"schema": "aten::all(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & all_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::all.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor any(const at::Tensor & self); // {"schema": "aten::any(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & any_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::any.all_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & renorm_out(const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm, at::Tensor & out); // {"schema": "aten::renorm.out(Tensor self, Scalar p, int dim, Scalar maxnorm, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor renorm(const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm); // {"schema": "aten::renorm(Tensor self, Scalar p, int dim, Scalar maxnorm) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & renorm_(at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm); // {"schema": "aten::renorm_(Tensor(a!) self, Scalar p, int dim, Scalar maxnorm) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor unfold(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step); // {"schema": "aten::unfold(Tensor(a) self, int dimension, int size, int step) -> Tensor(a)", "dispatch": "True", "default": "False"} +at::Tensor unfold_backward(const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step); // {"schema": "aten::unfold_backward(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step) -> Tensor", "dispatch": "True", "default": "False"} +bool equal(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::equal(Tensor self, Tensor other) -> bool", "dispatch": "True", "default": "False"} +at::Tensor & pow_out(const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::pow.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor pow(const at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & pow_out(const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::pow.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor pow(const at::Scalar & self, const at::Tensor & exponent); // {"schema": "aten::pow.Scalar(Scalar self, Tensor exponent) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & pow_out(const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out); // {"schema": "aten::pow.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor pow(const at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & pow_(at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::pow_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & pow_(at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::pow_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & float_power_out(const at::Tensor & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::float_power.Tensor_Tensor_out(Tensor self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor float_power(const at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::float_power.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & float_power_out(const at::Scalar & self, const at::Tensor & exponent, at::Tensor & out); // {"schema": "aten::float_power.Scalar_out(Scalar self, Tensor exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor float_power(const at::Scalar & self, const at::Tensor & exponent); // {"schema": "aten::float_power.Scalar(Scalar self, Tensor exponent) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & float_power_out(const at::Tensor & self, const at::Scalar & exponent, at::Tensor & out); // {"schema": "aten::float_power.Tensor_Scalar_out(Tensor self, Scalar exponent, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor float_power(const at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::float_power.Tensor_Scalar(Tensor self, Scalar exponent) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & float_power_(at::Tensor & self, const at::Scalar & exponent); // {"schema": "aten::float_power_.Scalar(Tensor(a!) self, Scalar exponent) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & float_power_(at::Tensor & self, const at::Tensor & exponent); // {"schema": "aten::float_power_.Tensor(Tensor(a!) self, Tensor exponent) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & normal_(at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal_functional(const at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::normal_functional(Tensor self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & normal_out(const at::Tensor & mean, double std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.Tensor_float_out(Tensor mean, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal(const at::Tensor & mean, double std, ::std::optional generator); // {"schema": "aten::normal.Tensor_float(Tensor mean, float std=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & normal_out(double mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.float_Tensor_out(float mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal(double mean, const at::Tensor & std, ::std::optional generator); // {"schema": "aten::normal.float_Tensor(float mean, Tensor std, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & normal_out(const at::Tensor & mean, const at::Tensor & std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.Tensor_Tensor_out(Tensor mean, Tensor std, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor normal(const at::Tensor & mean, const at::Tensor & std, ::std::optional generator); // {"schema": "aten::normal.Tensor_Tensor(Tensor mean, Tensor std, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor normal(double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::normal.float_float(float mean, float std, SymInt[] size, *, Generator? generator=None, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & normal_out(double mean, double std, c10::SymIntArrayRef size, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.float_float_out(float mean, float std, SymInt[] size, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor alias(const at::Tensor & self); // {"schema": "aten::alias(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +void _amp_foreach_non_finite_check_and_unscale_(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale_(Tensor(a!)[] self, Tensor(b!) found_inf, Tensor inv_scale) -> ()", "dispatch": "True", "default": "False"} +at::Tensor & _amp_update_scale_(at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); // {"schema": "aten::_amp_update_scale_(Tensor(a!) self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::vector _foreach_add(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_add.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_add(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_add(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_add.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_add_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_add(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add.Tensor(Tensor[] self, Tensor other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_add_(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_foreach_add_.Tensor(Tensor(a!)[] self, Tensor other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sub(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_sub.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sub_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_sub_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sub(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_sub.List(Tensor[] self, Tensor[] other, *, Scalar alpha=1) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sub_(at::TensorList self, at::TensorList other, const at::Scalar & alpha); // {"schema": "aten::_foreach_sub_.List(Tensor(a!)[] self, Tensor[] other, *, Scalar alpha=1) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sub(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_sub.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sub_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_sub_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_mul.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_mul_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_mul.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_mul_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_mul.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_mul_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_mul(at::TensorList self, const at::Tensor & other); // {"schema": "aten::_foreach_mul.Tensor(Tensor[] self, Tensor other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_mul_(at::TensorList self, const at::Tensor & other); // {"schema": "aten::_foreach_mul_.Tensor(Tensor(a!)[] self, Tensor other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_div.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_div_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_div.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_div_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_div.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_div_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_div(at::TensorList self, const at::Tensor & other); // {"schema": "aten::_foreach_div.Tensor(Tensor[] self, Tensor other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_div_(at::TensorList self, const at::Tensor & other); // {"schema": "aten::_foreach_div_.Tensor(Tensor(a!)[] self, Tensor other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_max(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_max.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_max_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_max(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_max.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_max_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_max(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_max.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_max_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_min(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_min.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_clamp_min_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_min(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_min.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_clamp_min_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_clamp_min(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_min.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_clamp_min_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_maximum(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_maximum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_maximum_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_maximum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_maximum(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_maximum.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_maximum_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_maximum_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_maximum(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_maximum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_maximum_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_maximum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_minimum(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_minimum.Scalar(Tensor[] self, Scalar scalar) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_minimum_(at::TensorList self, const at::Scalar & scalar); // {"schema": "aten::_foreach_minimum_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_minimum(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_minimum.List(Tensor[] self, Tensor[] other) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_minimum_(at::TensorList self, at::TensorList other); // {"schema": "aten::_foreach_minimum_.List(Tensor(a!)[] self, Tensor[] other) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_minimum(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_minimum.ScalarList(Tensor[] self, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_minimum_(at::TensorList self, at::ArrayRef scalars); // {"schema": "aten::_foreach_minimum_.ScalarList(Tensor(a!)[] self, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcdiv(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value); // {"schema": "aten::_foreach_addcdiv.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcdiv(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcdiv.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcdiv(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": "aten::_foreach_addcdiv.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value); // {"schema": "aten::_foreach_addcdiv_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcdiv_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": "aten::_foreach_addcdiv_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcmul(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value); // {"schema": "aten::_foreach_addcmul.Scalar(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcmul(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcmul.ScalarList(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_addcmul(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": "aten::_foreach_addcmul.Tensor(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_addcmul_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value); // {"schema": "aten::_foreach_addcmul_.Scalar(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars); // {"schema": "aten::_foreach_addcmul_.ScalarList(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars); // {"schema": "aten::_foreach_addcmul_.Tensor(Tensor(a!)[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_abs(at::TensorList self); // {"schema": "aten::_foreach_abs(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_abs_(at::TensorList self); // {"schema": "aten::_foreach_abs_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_acos(at::TensorList self); // {"schema": "aten::_foreach_acos(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_acos_(at::TensorList self); // {"schema": "aten::_foreach_acos_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_asin(at::TensorList self); // {"schema": "aten::_foreach_asin(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_asin_(at::TensorList self); // {"schema": "aten::_foreach_asin_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_atan(at::TensorList self); // {"schema": "aten::_foreach_atan(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_atan_(at::TensorList self); // {"schema": "aten::_foreach_atan_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_ceil(at::TensorList self); // {"schema": "aten::_foreach_ceil(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_ceil_(at::TensorList self); // {"schema": "aten::_foreach_ceil_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_cos(at::TensorList self); // {"schema": "aten::_foreach_cos(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_cos_(at::TensorList self); // {"schema": "aten::_foreach_cos_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_cosh(at::TensorList self); // {"schema": "aten::_foreach_cosh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_cosh_(at::TensorList self); // {"schema": "aten::_foreach_cosh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_erf(at::TensorList self); // {"schema": "aten::_foreach_erf(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_erf_(at::TensorList self); // {"schema": "aten::_foreach_erf_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_erfc(at::TensorList self); // {"schema": "aten::_foreach_erfc(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_erfc_(at::TensorList self); // {"schema": "aten::_foreach_erfc_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_exp(at::TensorList self); // {"schema": "aten::_foreach_exp(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_exp_(at::TensorList self); // {"schema": "aten::_foreach_exp_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_expm1(at::TensorList self); // {"schema": "aten::_foreach_expm1(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_expm1_(at::TensorList self); // {"schema": "aten::_foreach_expm1_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_floor(at::TensorList self); // {"schema": "aten::_foreach_floor(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_floor_(at::TensorList self); // {"schema": "aten::_foreach_floor_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_frac(at::TensorList self); // {"schema": "aten::_foreach_frac(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_frac_(at::TensorList self); // {"schema": "aten::_foreach_frac_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lerp(at::TensorList self, at::TensorList tensors1, at::TensorList weights); // {"schema": "aten::_foreach_lerp.List(Tensor[] self, Tensor[] tensors1, Tensor[] weights) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lerp_(at::TensorList self, at::TensorList tensors1, at::TensorList weights); // {"schema": "aten::_foreach_lerp_.List(Tensor(a!)[] self, Tensor[] tensors1, Tensor[] weights) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lerp(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight); // {"schema": "aten::_foreach_lerp.Scalar(Tensor[] self, Tensor[] tensors1, Scalar weight) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lerp_(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight); // {"schema": "aten::_foreach_lerp_.Scalar(Tensor(a!)[] self, Tensor[] tensors1, Scalar weight) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lerp(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight); // {"schema": "aten::_foreach_lerp.ScalarList(Tensor[] self, Tensor[] tensors1, Scalar[] weight) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lerp_(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight); // {"schema": "aten::_foreach_lerp_.ScalarList(Tensor(a!)[] self, Tensor[] tensors1, Scalar[] weight) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_lgamma(at::TensorList self); // {"schema": "aten::_foreach_lgamma(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_lgamma_(at::TensorList self); // {"schema": "aten::_foreach_lgamma_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log(at::TensorList self); // {"schema": "aten::_foreach_log(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log_(at::TensorList self); // {"schema": "aten::_foreach_log_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log10(at::TensorList self); // {"schema": "aten::_foreach_log10(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log10_(at::TensorList self); // {"schema": "aten::_foreach_log10_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log1p(at::TensorList self); // {"schema": "aten::_foreach_log1p(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log1p_(at::TensorList self); // {"schema": "aten::_foreach_log1p_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_log2(at::TensorList self); // {"schema": "aten::_foreach_log2(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_log2_(at::TensorList self); // {"schema": "aten::_foreach_log2_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_max(at::TensorList self); // {"schema": "aten::_foreach_max(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_neg(at::TensorList self); // {"schema": "aten::_foreach_neg(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_neg_(at::TensorList self); // {"schema": "aten::_foreach_neg_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_norm(at::TensorList self, const at::Scalar & ord, ::std::optional dtype); // {"schema": "aten::_foreach_norm.Scalar(Tensor[] self, Scalar ord=2, ScalarType? dtype=None) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(at::TensorList self, at::TensorList exponent); // {"schema": "aten::_foreach_pow.List(Tensor[] self, Tensor[] exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(at::TensorList self, const at::Scalar & exponent); // {"schema": "aten::_foreach_pow.Scalar(Tensor[] self, Scalar exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(at::TensorList self, at::ArrayRef exponent); // {"schema": "aten::_foreach_pow.ScalarList(Tensor[] self, Scalar[] exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector _foreach_pow(const at::Scalar & self, at::TensorList exponent); // {"schema": "aten::_foreach_pow.ScalarAndTensor(Scalar self, Tensor[] exponent) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_pow_(at::TensorList self, at::TensorList exponent); // {"schema": "aten::_foreach_pow_.List(Tensor(a!)[] self, Tensor[] exponent) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_(at::TensorList self, const at::Scalar & exponent); // {"schema": "aten::_foreach_pow_.Scalar(Tensor(a!)[] self, Scalar exponent) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_(at::TensorList self, at::ArrayRef exponent); // {"schema": "aten::_foreach_pow_.ScalarList(Tensor(a!)[] self, Scalar[] exponent) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_reciprocal(at::TensorList self); // {"schema": "aten::_foreach_reciprocal(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_reciprocal_(at::TensorList self); // {"schema": "aten::_foreach_reciprocal_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_round(at::TensorList self); // {"schema": "aten::_foreach_round(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_round_(at::TensorList self); // {"schema": "aten::_foreach_round_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_rsqrt(at::TensorList self); // {"schema": "aten::_foreach_rsqrt(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_rsqrt_(at::TensorList self); // {"schema": "aten::_foreach_rsqrt_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sigmoid(at::TensorList self); // {"schema": "aten::_foreach_sigmoid(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sigmoid_(at::TensorList self); // {"schema": "aten::_foreach_sigmoid_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sign(at::TensorList self); // {"schema": "aten::_foreach_sign(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sign_(at::TensorList self); // {"schema": "aten::_foreach_sign_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sin(at::TensorList self); // {"schema": "aten::_foreach_sin(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sin_(at::TensorList self); // {"schema": "aten::_foreach_sin_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sinh(at::TensorList self); // {"schema": "aten::_foreach_sinh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sinh_(at::TensorList self); // {"schema": "aten::_foreach_sinh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_sqrt(at::TensorList self); // {"schema": "aten::_foreach_sqrt(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_sqrt_(at::TensorList self); // {"schema": "aten::_foreach_sqrt_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_tan(at::TensorList self); // {"schema": "aten::_foreach_tan(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_tan_(at::TensorList self); // {"schema": "aten::_foreach_tan_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_tanh(at::TensorList self); // {"schema": "aten::_foreach_tanh(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_tanh_(at::TensorList self); // {"schema": "aten::_foreach_tanh_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_trunc(at::TensorList self); // {"schema": "aten::_foreach_trunc(Tensor[] self) -> Tensor[]", "dispatch": "True", "default": "True"} +void _foreach_trunc_(at::TensorList self); // {"schema": "aten::_foreach_trunc_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +void _foreach_zero_(at::TensorList self); // {"schema": "aten::_foreach_zero_(Tensor(a!)[] self) -> ()", "dispatch": "True", "default": "True"} +void _foreach_copy_(at::TensorList self, at::TensorList src, bool non_blocking); // {"schema": "aten::_foreach_copy_(Tensor(a!)[] self, Tensor[] src, bool non_blocking=False) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_copy(at::TensorList self, at::TensorList src, bool non_blocking); // {"schema": "aten::_foreach_copy(Tensor[] self, Tensor[] src, bool non_blocking=False) -> Tensor[] self_out", "dispatch": "True", "default": "True"} +at::Tensor bucketize(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right); // {"schema": "aten::bucketize.Tensor(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & bucketize_out(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out); // {"schema": "aten::bucketize.Tensor_out(Tensor self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor bucketize(const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right); // {"schema": "aten::bucketize.Scalar(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor searchsorted(const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter); // {"schema": "aten::searchsorted.Tensor(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & searchsorted_out(const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out); // {"schema": "aten::searchsorted.Tensor_out(Tensor sorted_sequence, Tensor self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor searchsorted(const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter); // {"schema": "aten::searchsorted.Scalar(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & searchsorted_out(const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter, at::Tensor & out); // {"schema": "aten::searchsorted.Scalar_out(Tensor sorted_sequence, Scalar self, *, bool out_int32=False, bool right=False, str? side=None, Tensor? sorter=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _convert_indices_from_coo_to_csr(const at::Tensor & self, int64_t size, bool out_int32); // {"schema": "aten::_convert_indices_from_coo_to_csr(Tensor self, int size, *, bool out_int32=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _convert_indices_from_coo_to_csr_out(const at::Tensor & self, int64_t size, bool out_int32, at::Tensor & out); // {"schema": "aten::_convert_indices_from_coo_to_csr.out(Tensor self, int size, *, bool out_int32=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _convert_indices_from_csr_to_coo(const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose); // {"schema": "aten::_convert_indices_from_csr_to_coo(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _convert_indices_from_csr_to_coo_out(const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose, at::Tensor & out); // {"schema": "aten::_convert_indices_from_csr_to_coo.out(Tensor crow_indices, Tensor col_indices, *, bool out_int32=False, bool transpose=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & mse_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out); // {"schema": "aten::mse_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mse_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::mse_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & mse_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::mse_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mse_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::mse_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor l1_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::l1_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & multi_margin_loss_out(const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & out); // {"schema": "aten::multi_margin_loss.out(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multi_margin_loss(const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::multi_margin_loss(Tensor self, Tensor target, Scalar p=1, Scalar margin=1, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & multi_margin_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::multi_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multi_margin_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction); // {"schema": "aten::multi_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, Scalar p, Scalar margin, Tensor? weight=None, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & multilabel_margin_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out); // {"schema": "aten::multilabel_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor multilabel_margin_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::multilabel_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple multilabel_margin_loss_forward_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & output, at::Tensor & is_target); // {"schema": "aten::multilabel_margin_loss_forward.output(Tensor self, Tensor target, int reduction, *, Tensor(a!) output, Tensor(b!) is_target) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple multilabel_margin_loss_forward(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::multilabel_margin_loss_forward(Tensor self, Tensor target, int reduction) -> (Tensor output, Tensor is_target)", "dispatch": "True", "default": "False"} +at::Tensor & multilabel_margin_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target, at::Tensor & grad_input); // {"schema": "aten::multilabel_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor multilabel_margin_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target); // {"schema": "aten::multilabel_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, Tensor is_target) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & nll_loss_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out); // {"schema": "aten::nll_loss.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nll_loss_nd(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss_nd(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor nll_loss(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple nll_loss_forward_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight); // {"schema": "aten::nll_loss_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple nll_loss_forward(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)", "dispatch": "True", "default": "True"} +at::Tensor & nll_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input); // {"schema": "aten::nll_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nll_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight); // {"schema": "aten::nll_loss_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & nll_loss2d_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & out); // {"schema": "aten::nll_loss2d.out(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nll_loss2d(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss2d(Tensor self, Tensor target, Tensor? weight=None, int reduction=Mean, SymInt ignore_index=-100) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple nll_loss2d_forward_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, at::Tensor & output, at::Tensor & total_weight); // {"schema": "aten::nll_loss2d_forward.output(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, *, Tensor(a!) output, Tensor(b!) total_weight) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index); // {"schema": "aten::nll_loss2d_forward(Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index) -> (Tensor output, Tensor total_weight)", "dispatch": "True", "default": "False"} +at::Tensor & nll_loss2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight, at::Tensor & grad_input); // {"schema": "aten::nll_loss2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight); // {"schema": "aten::nll_loss2d_backward(Tensor grad_output, Tensor self, Tensor target, Tensor? weight, int reduction, SymInt ignore_index, Tensor total_weight) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & smooth_l1_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & out); // {"schema": "aten::smooth_l1_loss.out(Tensor self, Tensor target, int reduction=Mean, float beta=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor smooth_l1_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta); // {"schema": "aten::smooth_l1_loss(Tensor self, Tensor target, int reduction=Mean, float beta=1.0) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & smooth_l1_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta, at::Tensor & grad_input); // {"schema": "aten::smooth_l1_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor smooth_l1_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta); // {"schema": "aten::smooth_l1_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float beta) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & huber_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & out); // {"schema": "aten::huber_loss.out(Tensor self, Tensor target, int reduction=Mean, float delta=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor huber_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta); // {"schema": "aten::huber_loss(Tensor self, Tensor target, int reduction=Mean, float delta=1.0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & huber_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta, at::Tensor & grad_input); // {"schema": "aten::huber_loss_backward.out(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor huber_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta); // {"schema": "aten::huber_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction, float delta) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & soft_margin_loss_out(const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & out); // {"schema": "aten::soft_margin_loss.out(Tensor self, Tensor target, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor soft_margin_loss(const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::soft_margin_loss(Tensor self, Tensor target, int reduction=Mean) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & soft_margin_loss_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, at::Tensor & grad_input); // {"schema": "aten::soft_margin_loss_backward.grad_input(Tensor grad_output, Tensor self, Tensor target, int reduction, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor soft_margin_loss_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction); // {"schema": "aten::soft_margin_loss_backward(Tensor grad_output, Tensor self, Tensor target, int reduction) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & elu_out(const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, at::Tensor & out); // {"schema": "aten::elu.out(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor elu(const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale); // {"schema": "aten::elu(Tensor self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & elu_backward_out(const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result, at::Tensor & grad_input); // {"schema": "aten::elu_backward.grad_input(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor elu_backward(const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result); // {"schema": "aten::elu_backward(Tensor grad_output, Scalar alpha, Scalar scale, Scalar input_scale, bool is_result, Tensor self_or_result) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & elu_(at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale); // {"schema": "aten::elu_(Tensor(a!) self, Scalar alpha=1, Scalar scale=1, Scalar input_scale=1) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & glu_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::glu.out(Tensor self, int dim=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor glu(const at::Tensor & self, int64_t dim); // {"schema": "aten::glu(Tensor self, int dim=-1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & glu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, at::Tensor & grad_input); // {"schema": "aten::glu_backward.grad_input(Tensor grad_output, Tensor self, int dim, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor glu_backward(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim); // {"schema": "aten::glu_backward(Tensor grad_output, Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor glu_jvp(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim); // {"schema": "aten::glu_jvp(Tensor glu, Tensor x, Tensor dx, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor glu_backward_jvp(const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim); // {"schema": "aten::glu_backward_jvp(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardsigmoid_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::hardsigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardsigmoid(const at::Tensor & self); // {"schema": "aten::hardsigmoid(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hardsigmoid_(at::Tensor & self); // {"schema": "aten::hardsigmoid_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hardsigmoid_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); // {"schema": "aten::hardsigmoid_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardsigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::hardsigmoid_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & hardtanh_out(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & out); // {"schema": "aten::hardtanh.out(Tensor self, Scalar min_val=-1, Scalar max_val=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardtanh(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); // {"schema": "aten::hardtanh(Tensor self, Scalar min_val=-1, Scalar max_val=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardtanh_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val, at::Tensor & grad_input); // {"schema": "aten::hardtanh_backward.grad_input(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardtanh_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); // {"schema": "aten::hardtanh_backward(Tensor grad_output, Tensor self, Scalar min_val, Scalar max_val) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardtanh_(at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val); // {"schema": "aten::hardtanh_(Tensor(a!) self, Scalar min_val=-1, Scalar max_val=1) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & hardswish_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardswish(const at::Tensor & self); // {"schema": "aten::hardswish(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & hardswish_(at::Tensor & self); // {"schema": "aten::hardswish_(Tensor(a!) self) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::hardswish_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & leaky_relu_out(const at::Tensor & self, const at::Scalar & negative_slope, at::Tensor & out); // {"schema": "aten::leaky_relu.out(Tensor self, Scalar negative_slope=0.01, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor leaky_relu(const at::Tensor & self, const at::Scalar & negative_slope); // {"schema": "aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & leaky_relu_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result, at::Tensor & grad_input); // {"schema": "aten::leaky_relu_backward.grad_input(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor leaky_relu_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result); // {"schema": "aten::leaky_relu_backward(Tensor grad_output, Tensor self, Scalar negative_slope, bool self_is_result) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & leaky_relu_(at::Tensor & self, const at::Scalar & negative_slope); // {"schema": "aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & log_sigmoid_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::log_sigmoid.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor log_sigmoid(const at::Tensor & self); // {"schema": "aten::log_sigmoid(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple log_sigmoid_forward_out(const at::Tensor & self, at::Tensor & output, at::Tensor & buffer); // {"schema": "aten::log_sigmoid_forward.output(Tensor self, *, Tensor(a!) output, Tensor(b!) buffer) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple log_sigmoid_forward(const at::Tensor & self); // {"schema": "aten::log_sigmoid_forward(Tensor self) -> (Tensor output, Tensor buffer)", "dispatch": "True", "default": "False"} +at::Tensor & log_sigmoid_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer, at::Tensor & grad_input); // {"schema": "aten::log_sigmoid_backward.grad_input(Tensor grad_output, Tensor self, Tensor buffer, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor log_sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer); // {"schema": "aten::log_sigmoid_backward(Tensor grad_output, Tensor self, Tensor buffer) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & rrelu_with_noise_out(const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator, at::Tensor & out); // {"schema": "aten::rrelu_with_noise.out(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor rrelu_with_noise(const at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_with_noise(Tensor self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor rrelu_with_noise_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result); // {"schema": "aten::rrelu_with_noise_backward(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & rrelu_with_noise_(at::Tensor & self, at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_with_noise_(Tensor(a!) self, Tensor(b!) noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & softplus_out(const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & out); // {"schema": "aten::softplus.out(Tensor self, Scalar beta=1, Scalar threshold=20, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softplus(const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold); // {"schema": "aten::softplus(Tensor self, Scalar beta=1, Scalar threshold=20) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & softplus_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold, at::Tensor & grad_input); // {"schema": "aten::softplus_backward.grad_input(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softplus_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold); // {"schema": "aten::softplus_backward(Tensor grad_output, Tensor self, Scalar beta, Scalar threshold) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & softshrink_out(const at::Tensor & self, const at::Scalar & lambd, at::Tensor & out); // {"schema": "aten::softshrink.out(Tensor self, Scalar lambd=0.5, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softshrink(const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::softshrink(Tensor self, Scalar lambd=0.5) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & softshrink_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd, at::Tensor & grad_input); // {"schema": "aten::softshrink_backward.grad_input(Tensor grad_output, Tensor self, Scalar lambd, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor softshrink_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd); // {"schema": "aten::softshrink_backward(Tensor grad_output, Tensor self, Scalar lambd) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_avg_pool2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_avg_pool2d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor mkldnn_adaptive_avg_pool2d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::mkldnn_adaptive_avg_pool2d(Tensor self, int[2] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & mkldnn_adaptive_avg_pool2d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); // {"schema": "aten::mkldnn_adaptive_avg_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor mkldnn_adaptive_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::mkldnn_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _adaptive_avg_pool2d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::_adaptive_avg_pool2d(Tensor self, SymInt[2] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _adaptive_avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::_adaptive_avg_pool2d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & adaptive_avg_pool3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_avg_pool3d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _adaptive_avg_pool3d(const at::Tensor & self, c10::SymIntArrayRef output_size); // {"schema": "aten::_adaptive_avg_pool3d(Tensor self, SymInt[3] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & adaptive_avg_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & grad_input); // {"schema": "aten::adaptive_avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _adaptive_avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self); // {"schema": "aten::_adaptive_avg_pool3d_backward(Tensor grad_output, Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple adaptive_max_pool2d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::adaptive_max_pool2d.out(Tensor self, int[2] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple adaptive_max_pool2d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool2d(Tensor self, int[2] output_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::adaptive_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); // {"schema": "aten::adaptive_max_pool2d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple adaptive_max_pool3d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::adaptive_max_pool3d.out(Tensor self, int[3] output_size, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple adaptive_max_pool3d(const at::Tensor & self, at::IntArrayRef output_size); // {"schema": "aten::adaptive_max_pool3d(Tensor self, int[3] output_size) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_max_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::adaptive_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor adaptive_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices); // {"schema": "aten::adaptive_max_pool3d_backward(Tensor grad_output, Tensor self, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); // {"schema": "aten::avg_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); // {"schema": "aten::avg_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & out); // {"schema": "aten::avg_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool3d(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, bool ceil_mode=False, bool count_include_pad=True, int? divisor_override=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override, at::Tensor & grad_input); // {"schema": "aten::avg_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor avg_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override); // {"schema": "aten::avg_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, bool ceil_mode, bool count_include_pad, int? divisor_override) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple fractional_max_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices); // {"schema": "aten::fractional_max_pool2d.output(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple fractional_max_pool2d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples); // {"schema": "aten::fractional_max_pool2d(Tensor self, int[2] kernel_size, int[2] output_size, Tensor random_samples) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & fractional_max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::fractional_max_pool2d_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fractional_max_pool2d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices); // {"schema": "aten::fractional_max_pool2d_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] output_size, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple fractional_max_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples, at::Tensor & output, at::Tensor & indices); // {"schema": "aten::fractional_max_pool3d.output(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples, *, Tensor(a!) output, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple fractional_max_pool3d(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples); // {"schema": "aten::fractional_max_pool3d(Tensor self, int[3] kernel_size, int[3] output_size, Tensor random_samples) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & fractional_max_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::fractional_max_pool3d_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor fractional_max_pool3d_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices); // {"schema": "aten::fractional_max_pool3d_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] output_size, Tensor indices) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple max_pool2d_with_indices_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::max_pool2d_with_indices.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple max_pool2d_with_indices(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "True"} +at::Tensor & max_pool2d_with_indices_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::max_pool2d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_pool2d_with_indices_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices); // {"schema": "aten::max_pool2d_with_indices_backward(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride, int[2] padding, int[2] dilation, bool ceil_mode, Tensor indices) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple max_pool3d_with_indices_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out, at::Tensor & indices); // {"schema": "aten::max_pool3d_with_indices.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out, Tensor(b!) indices) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "False"} +::std::tuple max_pool3d_with_indices(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode); // {"schema": "aten::max_pool3d_with_indices(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor & max_pool3d_with_indices_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices, at::Tensor & grad_input); // {"schema": "aten::max_pool3d_with_indices_backward.grad_input(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_pool3d_with_indices_backward(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices); // {"schema": "aten::max_pool3d_with_indices_backward(Tensor grad_output, Tensor self, int[3] kernel_size, int[3] stride, int[3] padding, int[3] dilation, bool ceil_mode, Tensor indices) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & max_unpool2d_out(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::max_unpool2d.out(Tensor self, Tensor indices, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_unpool2d(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size); // {"schema": "aten::max_unpool2d(Tensor self, Tensor indices, SymInt[2] output_size) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & max_unpool3d_out(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding, at::Tensor & out); // {"schema": "aten::max_unpool3d.out(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor max_unpool3d(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding); // {"schema": "aten::max_unpool3d(Tensor self, Tensor indices, SymInt[3] output_size, int[3] stride, int[3] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & reflection_pad1d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::reflection_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad1d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad1d(Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reflection_pad1d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::reflection_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad1d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reflection_pad2d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::reflection_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad2d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad2d(Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & reflection_pad2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::reflection_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad2d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & reflection_pad3d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::reflection_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad3d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad3d(Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & reflection_pad3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::reflection_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor reflection_pad3d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::reflection_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad1d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::replication_pad1d.out(Tensor self, SymInt[2] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad1d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad1d(Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad1d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::replication_pad1d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[2] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad1d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad1d_backward(Tensor grad_output, Tensor self, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad2d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::replication_pad2d.out(Tensor self, SymInt[4] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad2d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad2d(Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::replication_pad2d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[4] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad2d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad2d_backward(Tensor grad_output, Tensor self, SymInt[4] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & replication_pad3d_out(const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::replication_pad3d.out(Tensor self, SymInt[6] padding, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad3d(const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad3d(Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & replication_pad3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding, at::Tensor & grad_input); // {"schema": "aten::replication_pad3d_backward.grad_input(Tensor grad_output, Tensor self, SymInt[6] padding, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor replication_pad3d_backward(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding); // {"schema": "aten::replication_pad3d_backward(Tensor grad_output, Tensor self, SymInt[6] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _pad_circular(const at::Tensor & self, c10::SymIntArrayRef pad); // {"schema": "aten::_pad_circular(Tensor self, SymInt[] pad) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _pad_enum(const at::Tensor & self, c10::SymIntArrayRef pad, int64_t mode, ::std::optional value); // {"schema": "aten::_pad_enum(Tensor self, SymInt[] pad, int mode, float? value=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor pad(const at::Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode, ::std::optional value); // {"schema": "aten::pad(Tensor self, SymInt[] pad, str mode=\"constant\", float? value=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_linear1d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_linear1d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_bilinear2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_bilinear2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_bilinear2d_aa(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::_upsample_bilinear2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_trilinear3d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_trilinear3d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_bicubic2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::upsample_bicubic2d.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_bicubic2d_aa(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors); // {"schema": "aten::_upsample_bicubic2d_aa.vec(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_nearest1d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::upsample_nearest1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_nearest_exact1d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::_upsample_nearest_exact1d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_nearest2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::upsample_nearest2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_nearest_exact2d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::_upsample_nearest_exact2d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor upsample_nearest3d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::upsample_nearest3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _upsample_nearest_exact3d(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors); // {"schema": "aten::_upsample_nearest_exact3d.vec(Tensor input, SymInt[]? output_size, float[]? scale_factors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & upsample_linear1d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales, at::Tensor & out); // {"schema": "aten::upsample_linear1d.out(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_linear1d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales); // {"schema": "aten::upsample_linear1d(Tensor self, SymInt[1] output_size, bool align_corners, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_linear1d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales, at::Tensor & grad_input); // {"schema": "aten::upsample_linear1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_linear1d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales); // {"schema": "aten::upsample_linear1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, bool align_corners, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bilinear2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_bilinear2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bilinear2d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bilinear2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bilinear2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_bilinear2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bilinear2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bilinear2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bilinear2d_aa_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_bilinear2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bilinear2d_aa(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bilinear2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bilinear2d_aa_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_bilinear2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bilinear2d_aa_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bilinear2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bicubic2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_bicubic2d.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bicubic2d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bicubic2d(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bicubic2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_bicubic2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_bicubic2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_bicubic2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bicubic2d_aa_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_bicubic2d_aa.out(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bicubic2d_aa(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bicubic2d_aa(Tensor self, SymInt[2] output_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _upsample_bicubic2d_aa_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_bicubic2d_aa_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _upsample_bicubic2d_aa_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_bicubic2d_aa_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, bool align_corners, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_trilinear3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_trilinear3d.out(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_trilinear3d(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_trilinear3d(Tensor self, SymInt[3] output_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_trilinear3d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_trilinear3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_trilinear3d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_trilinear3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, bool align_corners, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest1d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out); // {"schema": "aten::upsample_nearest1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact1d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales, at::Tensor & out); // {"schema": "aten::_upsample_nearest_exact1d.out(Tensor self, SymInt[1] output_size, float? scales=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest1d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales); // {"schema": "aten::upsample_nearest1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact1d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales); // {"schema": "aten::_upsample_nearest_exact1d(Tensor self, SymInt[1] output_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest1d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input); // {"schema": "aten::upsample_nearest1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact1d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales, at::Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact1d_backward.grad_input(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest1d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales); // {"schema": "aten::upsample_nearest1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact1d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales); // {"schema": "aten::_upsample_nearest_exact1d_backward(Tensor grad_output, SymInt[1] output_size, SymInt[3] input_size, float? scales=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_nearest2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_nearest_exact2d.out(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest2d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact2d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact2d(Tensor self, SymInt[2] output_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_nearest2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact2d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact2d_backward.grad_input(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact2d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact2d_backward(Tensor grad_output, SymInt[2] output_size, SymInt[4] input_size, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::upsample_nearest3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & out); // {"schema": "aten::_upsample_nearest_exact3d.out(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest3d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact3d(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact3d(Tensor self, SymInt[3] output_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest3d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::upsample_nearest3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & _upsample_nearest_exact3d_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w, at::Tensor & grad_input); // {"schema": "aten::_upsample_nearest_exact3d_backward.grad_input(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor upsample_nearest3d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::upsample_nearest3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _upsample_nearest_exact3d_backward(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w); // {"schema": "aten::_upsample_nearest_exact3d_backward(Tensor grad_output, SymInt[3] output_size, SymInt[5] input_size, float? scales_d=None, float? scales_h=None, float? scales_w=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sigmoid_backward_out(const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input); // {"schema": "aten::sigmoid_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor sigmoid_backward(const at::Tensor & grad_output, const at::Tensor & output); // {"schema": "aten::sigmoid_backward(Tensor grad_output, Tensor output) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & logit_backward_out(const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps, at::Tensor & grad_input); // {"schema": "aten::logit_backward.grad_input(Tensor grad_output, Tensor self, float? eps=None, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor logit_backward(const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps); // {"schema": "aten::logit_backward(Tensor grad_output, Tensor self, float? eps=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tanh_backward_out(const at::Tensor & grad_output, const at::Tensor & output, at::Tensor & grad_input); // {"schema": "aten::tanh_backward.grad_input(Tensor grad_output, Tensor output, *, Tensor(a!) grad_input) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor tanh_backward(const at::Tensor & grad_output, const at::Tensor & output); // {"schema": "aten::tanh_backward(Tensor grad_output, Tensor output) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_transpose2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_transpose2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_transpose2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_transpose2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] output_padding=0, SymInt[2] dilation=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_transpose3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_transpose3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_transpose3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_transpose3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] output_padding=0, SymInt[3] dilation=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & thnn_conv2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::thnn_conv2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor thnn_conv2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::thnn_conv2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & _slow_conv2d_forward_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output); // {"schema": "aten::_slow_conv2d_forward.output(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) output) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _slow_conv2d_forward(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::_slow_conv2d_forward(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _slow_conv2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & grad_input, at::Tensor & grad_weight, at::Tensor & grad_bias); // {"schema": "aten::_slow_conv2d_backward.grad_input(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, *, Tensor(a!) grad_input, Tensor(b!) grad_weight, Tensor(c!) grad_bias) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "False"} +::std::tuple _slow_conv2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask); // {"schema": "aten::_slow_conv2d_backward.output_mask(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)", "dispatch": "True", "default": "False"} +at::Tensor & _conv_depthwise2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::_conv_depthwise2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _conv_depthwise2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::_conv_depthwise2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias, SymInt[2] stride, SymInt[2] padding, SymInt[2] dilation) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor conv_depthwise3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::conv_depthwise3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & slow_conv3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & out); // {"schema": "aten::slow_conv3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor slow_conv3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::slow_conv3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & slow_conv3d_forward_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, at::Tensor & output); // {"schema": "aten::slow_conv3d_forward.output(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, *, Tensor(a!) output) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor slow_conv3d_forward(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding); // {"schema": "aten::slow_conv3d_forward(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_dilated2d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_dilated2d(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor slow_conv_dilated3d(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation); // {"schema": "aten::slow_conv_dilated3d(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & col2im_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out); // {"schema": "aten::col2im.out(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor col2im(const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride); // {"schema": "aten::col2im(Tensor self, SymInt[2] output_size, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor column_stack(at::TensorList tensors); // {"schema": "aten::column_stack(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & column_stack_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::column_stack.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & im2col_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride, at::Tensor & out); // {"schema": "aten::im2col.out(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor im2col(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride); // {"schema": "aten::im2col(Tensor self, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor isfinite(const at::Tensor & self); // {"schema": "aten::isfinite(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor isinf(const at::Tensor & self); // {"schema": "aten::isinf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +void record_stream(at::Tensor & self, at::Stream s); // {"schema": "aten::record_stream(Tensor(a!) self, Stream s) -> ()", "dispatch": "True", "default": "False"} +at::Tensor isposinf(const at::Tensor & self); // {"schema": "aten::isposinf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isposinf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isposinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor isneginf(const at::Tensor & self); // {"schema": "aten::isneginf(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isneginf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isneginf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _add_batch_dim(const at::Tensor & self, int64_t batch_dim, int64_t level); // {"schema": "aten::_add_batch_dim(Tensor self, int batch_dim, int level) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _remove_batch_dim(const at::Tensor & self, int64_t level, c10::SymInt batch_size, int64_t out_dim); // {"schema": "aten::_remove_batch_dim(Tensor self, int level, SymInt batch_size, int out_dim) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_entr(const at::Tensor & self); // {"schema": "aten::special_entr(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_entr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_entr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_ndtri(const at::Tensor & self); // {"schema": "aten::special_ndtri(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_ndtri_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_ndtri.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_log_ndtr(const at::Tensor & self); // {"schema": "aten::special_log_ndtr(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_log_ndtr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_log_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_expm1(const at::Tensor & self); // {"schema": "aten::special_expm1(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_expm1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_expm1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_exp2(const at::Tensor & self); // {"schema": "aten::special_exp2(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_exp2_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_exp2.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_psi(const at::Tensor & self); // {"schema": "aten::special_psi(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_psi_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_psi.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_digamma(const at::Tensor & self); // {"schema": "aten::special_digamma(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_digamma_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_digamma.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_gammaln(const at::Tensor & self); // {"schema": "aten::special_gammaln(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_gammaln_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_gammaln.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_erf(const at::Tensor & self); // {"schema": "aten::special_erf(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_erf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_erfc(const at::Tensor & self); // {"schema": "aten::special_erfc(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_erfc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erfc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_erfcx(const at::Tensor & self); // {"schema": "aten::special_erfcx(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_erfcx_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erfcx.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_erfinv(const at::Tensor & self); // {"schema": "aten::special_erfinv(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_erfinv_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_erfinv.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_ndtr(const at::Tensor & self); // {"schema": "aten::special_ndtr(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_ndtr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_ndtr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_xlog1py(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_xlog1py(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_xlog1py(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::special_xlog1py.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_xlog1py(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::special_xlog1py.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_xlog1py_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlog1py.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_xlog1py_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlog1py.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_xlog1py_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::special_xlog1py.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_xlogy(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_xlogy(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_xlogy(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::special_xlogy.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_xlogy(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::special_xlogy.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_xlogy_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlogy.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & special_xlogy_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_xlogy.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & special_xlogy_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::special_xlogy.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_zeta(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_zeta(Tensor self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_zeta(const at::Scalar & self, const at::Tensor & other); // {"schema": "aten::special_zeta.self_scalar(Scalar self, Tensor other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_zeta(const at::Tensor & self, const at::Scalar & other); // {"schema": "aten::special_zeta.other_scalar(Tensor self, Scalar other) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_zeta_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_zeta.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_zeta_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_zeta.self_scalar_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_zeta_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::special_zeta.other_scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_i0(const at::Tensor & self); // {"schema": "aten::special_i0(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_i0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_i0e(const at::Tensor & self); // {"schema": "aten::special_i0e(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_i0e_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i0e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_i1(const at::Tensor & self); // {"schema": "aten::special_i1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_i1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_i1e(const at::Tensor & self); // {"schema": "aten::special_i1e(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_i1e_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_i1e.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_logit(const at::Tensor & self, ::std::optional eps); // {"schema": "aten::special_logit(Tensor self, float? eps=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_logit_out(const at::Tensor & self, ::std::optional eps, at::Tensor & out); // {"schema": "aten::special_logit.out(Tensor self, float? eps=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_polygamma(int64_t n, const at::Tensor & self); // {"schema": "aten::special_polygamma(int n, Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_polygamma_out(int64_t n, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_polygamma.out(int n, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_logsumexp(const at::Tensor & self, at::IntArrayRef dim, bool keepdim); // {"schema": "aten::special_logsumexp(Tensor self, int[1] dim, bool keepdim=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_logsumexp_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, at::Tensor & out); // {"schema": "aten::special_logsumexp.out(Tensor self, int[1] dim, bool keepdim=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_expit(const at::Tensor & self); // {"schema": "aten::special_expit(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_expit_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_expit.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_sinc(const at::Tensor & self); // {"schema": "aten::special_sinc(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_sinc_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_sinc.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_round(const at::Tensor & self, int64_t decimals); // {"schema": "aten::special_round(Tensor self, *, int decimals=0) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_round_out(const at::Tensor & self, int64_t decimals, at::Tensor & out); // {"schema": "aten::special_round.out(Tensor self, *, int decimals=0, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_log1p(const at::Tensor & self); // {"schema": "aten::special_log1p(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_log1p_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_log1p.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_log_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::special_log_softmax(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_gammainc_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_gammainc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_gammainc(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_gammainc(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_gammaincc_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::special_gammaincc.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_gammaincc(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::special_gammaincc(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor special_multigammaln(const at::Tensor & self, int64_t p); // {"schema": "aten::special_multigammaln(Tensor self, int p) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & special_multigammaln_out(const at::Tensor & self, int64_t p, at::Tensor & out); // {"schema": "aten::special_multigammaln.out(Tensor self, int p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor special_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::special_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fft_fft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_fft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_fft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_fft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ifft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_ifft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ifft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ifft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_rfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_rfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_rfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_rfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_irfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_irfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_irfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_irfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_hfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_hfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_hfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_hfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ihfft(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm); // {"schema": "aten::fft_ihfft(Tensor self, SymInt? n=None, int dim=-1, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ihfft_out(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ihfft.out(Tensor self, SymInt? n=None, int dim=-1, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_fft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_fft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_fft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_fft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ifft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ifft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ifft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ifft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_rfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_rfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_rfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_rfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_irfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_irfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_irfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_irfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_hfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_hfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_hfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_hfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ihfft2(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ihfft2(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ihfft2_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ihfft2.out(Tensor self, SymInt[1]? s=None, int[1] dim=[-2,-1], str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_fftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_fftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_fftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_fftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ifftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ifftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ifftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ifftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_rfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_rfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_rfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_rfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_irfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_irfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_irfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_irfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_hfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_hfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_hfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_hfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_ihfftn(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm); // {"schema": "aten::fft_ihfftn(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & fft_ihfftn_out(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm, at::Tensor & out); // {"schema": "aten::fft_ihfftn.out(Tensor self, SymInt[1]? s=None, int[1]? dim=None, str? norm=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor fft_fftfreq(int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::fft_fftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fft_fftfreq_out(int64_t n, double d, at::Tensor & out); // {"schema": "aten::fft_fftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor fft_rfftfreq(int64_t n, double d, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::fft_rfftfreq(int n, float d=1.0, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & fft_rfftfreq_out(int64_t n, double d, at::Tensor & out); // {"schema": "aten::fft_rfftfreq.out(int n, float d=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor fft_fftshift(const at::Tensor & self, at::OptionalIntArrayRef dim); // {"schema": "aten::fft_fftshift(Tensor self, int[1]? dim=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor fft_ifftshift(const at::Tensor & self, at::OptionalIntArrayRef dim); // {"schema": "aten::fft_ifftshift(Tensor self, int[1]? dim=None) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple linalg_cholesky_ex(const at::Tensor & self, bool upper, bool check_errors); // {"schema": "aten::linalg_cholesky_ex(Tensor self, *, bool upper=False, bool check_errors=False) -> (Tensor L, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_cholesky_ex_out(const at::Tensor & self, bool upper, bool check_errors, at::Tensor & L, at::Tensor & info); // {"schema": "aten::linalg_cholesky_ex.L(Tensor self, *, bool upper=False, bool check_errors=False, Tensor(a!) L, Tensor(b!) info) -> (Tensor(a!) L, Tensor(b!) info)", "dispatch": "True", "default": "False"} +at::Tensor linalg_cholesky(const at::Tensor & self, bool upper); // {"schema": "aten::linalg_cholesky(Tensor self, *, bool upper=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_cholesky_out(const at::Tensor & self, bool upper, at::Tensor & out); // {"schema": "aten::linalg_cholesky.out(Tensor self, *, bool upper=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_cross(const at::Tensor & self, const at::Tensor & other, int64_t dim); // {"schema": "aten::linalg_cross(Tensor self, Tensor other, *, int dim=-1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_cross_out(const at::Tensor & self, const at::Tensor & other, int64_t dim, at::Tensor & out); // {"schema": "aten::linalg_cross.out(Tensor self, Tensor other, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple linalg_lu_factor(const at::Tensor & A, bool pivot); // {"schema": "aten::linalg_lu_factor(Tensor A, *, bool pivot=True) -> (Tensor LU, Tensor pivots)", "dispatch": "False", "default": "True"} +::std::tuple linalg_lu_factor_out(const at::Tensor & A, bool pivot, at::Tensor & LU, at::Tensor & pivots); // {"schema": "aten::linalg_lu_factor.out(Tensor A, *, bool pivot=True, Tensor(a!) LU, Tensor(b!) pivots) -> (Tensor(a!) LU, Tensor(b!) pivots)", "dispatch": "False", "default": "True"} +::std::tuple linalg_lu_factor_ex(const at::Tensor & A, bool pivot, bool check_errors); // {"schema": "aten::linalg_lu_factor_ex(Tensor A, *, bool pivot=True, bool check_errors=False) -> (Tensor LU, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_lu_factor_ex_out(const at::Tensor & A, bool pivot, bool check_errors, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info); // {"schema": "aten::linalg_lu_factor_ex.out(Tensor A, *, bool pivot=True, bool check_errors=False, Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LU, Tensor(b!) pivots, Tensor(c!) info)", "dispatch": "True", "default": "False"} +::std::tuple linalg_lu(const at::Tensor & A, bool pivot); // {"schema": "aten::linalg_lu(Tensor A, *, bool pivot=True) -> (Tensor P, Tensor L, Tensor U)", "dispatch": "True", "default": "True"} +::std::tuple linalg_lu_out(const at::Tensor & A, bool pivot, at::Tensor & P, at::Tensor & L, at::Tensor & U); // {"schema": "aten::linalg_lu.out(Tensor A, *, bool pivot=True, Tensor(a!) P, Tensor(b!) L, Tensor(c!) U) -> (Tensor(a!) P, Tensor(b!) L, Tensor(c!) U)", "dispatch": "True", "default": "False"} +at::Tensor linalg_lu_solve(const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint); // {"schema": "aten::linalg_lu_solve(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_lu_solve_out(const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint, at::Tensor & out); // {"schema": "aten::linalg_lu_solve.out(Tensor LU, Tensor pivots, Tensor B, *, bool left=True, bool adjoint=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple _linalg_det(const at::Tensor & A); // {"schema": "aten::_linalg_det(Tensor A) -> (Tensor result, Tensor LU, Tensor pivots)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_det_out(const at::Tensor & A, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots); // {"schema": "aten::_linalg_det.result(Tensor A, *, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots)", "dispatch": "True", "default": "False"} +at::Tensor linalg_det(const at::Tensor & A); // {"schema": "aten::linalg_det(Tensor A) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_det_out(const at::Tensor & A, at::Tensor & out); // {"schema": "aten::linalg_det.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor det(const at::Tensor & self); // {"schema": "aten::det(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple linalg_ldl_factor_ex(const at::Tensor & self, bool hermitian, bool check_errors); // {"schema": "aten::linalg_ldl_factor_ex(Tensor self, *, bool hermitian=False, bool check_errors=False) -> (Tensor LD, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_ldl_factor_ex_out(const at::Tensor & self, bool hermitian, bool check_errors, at::Tensor & LD, at::Tensor & pivots, at::Tensor & info); // {"schema": "aten::linalg_ldl_factor_ex.out(Tensor self, *, bool hermitian=False, bool check_errors=False, Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info) -> (Tensor(a!) LD, Tensor(b!) pivots, Tensor(c!) info)", "dispatch": "True", "default": "False"} +::std::tuple linalg_ldl_factor(const at::Tensor & self, bool hermitian); // {"schema": "aten::linalg_ldl_factor(Tensor self, *, bool hermitian=False) -> (Tensor LD, Tensor pivots)", "dispatch": "False", "default": "True"} +::std::tuple linalg_ldl_factor_out(const at::Tensor & self, bool hermitian, at::Tensor & LD, at::Tensor & pivots); // {"schema": "aten::linalg_ldl_factor.out(Tensor self, *, bool hermitian=False, Tensor(a!) LD, Tensor(b!) pivots) -> (Tensor(a!) LD, Tensor(b!) pivots)", "dispatch": "False", "default": "True"} +at::Tensor linalg_ldl_solve(const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian); // {"schema": "aten::linalg_ldl_solve(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_ldl_solve_out(const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_ldl_solve.out(Tensor LD, Tensor pivots, Tensor B, *, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple linalg_lstsq(const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver); // {"schema": "aten::linalg_lstsq(Tensor self, Tensor b, float? rcond=None, *, str? driver=None) -> (Tensor solution, Tensor residuals, Tensor rank, Tensor singular_values)", "dispatch": "True", "default": "True"} +::std::tuple linalg_lstsq_out(const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver, at::Tensor & solution, at::Tensor & residuals, at::Tensor & rank, at::Tensor & singular_values); // {"schema": "aten::linalg_lstsq.out(Tensor self, Tensor b, float? rcond=None, *, str? driver=None, Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values) -> (Tensor(a!) solution, Tensor(b!) residuals, Tensor(c!) rank, Tensor(d!) singular_values)", "dispatch": "True", "default": "False"} +at::Tensor linalg_matmul(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::linalg_matmul(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matmul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::linalg_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_vecdot(const at::Tensor & x, const at::Tensor & y, int64_t dim); // {"schema": "aten::linalg_vecdot(Tensor x, Tensor y, *, int dim=-1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_vecdot_out(const at::Tensor & x, const at::Tensor & y, int64_t dim, at::Tensor & out); // {"schema": "aten::linalg_vecdot.out(Tensor x, Tensor y, *, int dim=-1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_exp(const at::Tensor & self); // {"schema": "aten::linalg_matrix_exp(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _linalg_slogdet(const at::Tensor & A); // {"schema": "aten::_linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet, Tensor LU, Tensor pivots)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_slogdet_out(const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet, at::Tensor & LU, at::Tensor & pivots); // {"schema": "aten::_linalg_slogdet.sign(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots) -> (Tensor(a!) sign, Tensor(b!) logabsdet, Tensor(c!) LU, Tensor(d!) pivots)", "dispatch": "True", "default": "False"} +::std::tuple linalg_slogdet(const at::Tensor & A); // {"schema": "aten::linalg_slogdet(Tensor A) -> (Tensor sign, Tensor logabsdet)", "dispatch": "False", "default": "True"} +::std::tuple linalg_slogdet_out(const at::Tensor & A, at::Tensor & sign, at::Tensor & logabsdet); // {"schema": "aten::linalg_slogdet.out(Tensor A, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)", "dispatch": "False", "default": "True"} +::std::tuple slogdet(const at::Tensor & self); // {"schema": "aten::slogdet(Tensor self) -> (Tensor sign, Tensor logabsdet)", "dispatch": "False", "default": "True"} +::std::tuple slogdet_out(const at::Tensor & self, at::Tensor & sign, at::Tensor & logabsdet); // {"schema": "aten::slogdet.out(Tensor self, *, Tensor(a!) sign, Tensor(b!) logabsdet) -> (Tensor(a!) sign, Tensor(b!) logabsdet)", "dispatch": "False", "default": "True"} +at::Tensor logdet(const at::Tensor & self); // {"schema": "aten::logdet(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +::std::tuple linalg_eig(const at::Tensor & self); // {"schema": "aten::linalg_eig(Tensor self) -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "True", "default": "False"} +::std::tuple linalg_eig_out(const at::Tensor & self, at::Tensor & eigenvalues, at::Tensor & eigenvectors); // {"schema": "aten::linalg_eig.out(Tensor self, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "True", "default": "False"} +at::Tensor _linalg_eigvals(const at::Tensor & self); // {"schema": "aten::_linalg_eigvals(Tensor self) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor linalg_eigvals(const at::Tensor & self); // {"schema": "aten::linalg_eigvals(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_eigvals_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::linalg_eigvals.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple _linalg_eigh(const at::Tensor & A, c10::string_view UPLO, bool compute_v); // {"schema": "aten::_linalg_eigh(Tensor A, str UPLO=\"L\", bool compute_v=True) -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_eigh_out(const at::Tensor & A, c10::string_view UPLO, bool compute_v, at::Tensor & eigenvalues, at::Tensor & eigenvectors); // {"schema": "aten::_linalg_eigh.eigenvalues(Tensor A, str UPLO=\"L\", bool compute_v=True, *, Tensor(a!) eigenvalues, Tensor(b!) eigenvectors) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "True", "default": "False"} +::std::tuple linalg_eigh(const at::Tensor & self, c10::string_view UPLO); // {"schema": "aten::linalg_eigh(Tensor self, str UPLO=\"L\") -> (Tensor eigenvalues, Tensor eigenvectors)", "dispatch": "False", "default": "True"} +::std::tuple linalg_eigh_out(const at::Tensor & self, c10::string_view UPLO, at::Tensor & eigvals, at::Tensor & eigvecs); // {"schema": "aten::linalg_eigh.eigvals(Tensor self, str UPLO=\"L\", *, Tensor(a!) eigvals, Tensor(b!) eigvecs) -> (Tensor(a!) eigenvalues, Tensor(b!) eigenvectors)", "dispatch": "False", "default": "True"} +at::Tensor linalg_eigvalsh(const at::Tensor & self, c10::string_view UPLO); // {"schema": "aten::linalg_eigvalsh(Tensor self, str UPLO=\"L\") -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_eigvalsh_out(const at::Tensor & self, c10::string_view UPLO, at::Tensor & out); // {"schema": "aten::linalg_eigvalsh.out(Tensor self, str UPLO=\"L\", *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_householder_product(const at::Tensor & input, const at::Tensor & tau); // {"schema": "aten::linalg_householder_product(Tensor input, Tensor tau) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & linalg_householder_product_out(const at::Tensor & input, const at::Tensor & tau, at::Tensor & out); // {"schema": "aten::linalg_householder_product.out(Tensor input, Tensor tau, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +::std::tuple linalg_inv_ex(const at::Tensor & A, bool check_errors); // {"schema": "aten::linalg_inv_ex(Tensor A, *, bool check_errors=False) -> (Tensor inverse, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple linalg_inv_ex_out(const at::Tensor & A, bool check_errors, at::Tensor & inverse, at::Tensor & info); // {"schema": "aten::linalg_inv_ex.inverse(Tensor A, *, bool check_errors=False, Tensor(a!) inverse, Tensor(b!) info) -> (Tensor(a!) inverse, Tensor(b!) info)", "dispatch": "True", "default": "False"} +at::Tensor linalg_inv(const at::Tensor & A); // {"schema": "aten::linalg_inv(Tensor A) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_inv_out(const at::Tensor & A, at::Tensor & out); // {"schema": "aten::linalg_inv.out(Tensor A, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor inverse(const at::Tensor & self); // {"schema": "aten::inverse(Tensor self) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & inverse_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::inverse.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor inner(const at::Tensor & self, const at::Tensor & other); // {"schema": "aten::inner(Tensor self, Tensor other) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & inner_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::inner.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor outer(const at::Tensor & self, const at::Tensor & vec2); // {"schema": "aten::outer(Tensor self, Tensor vec2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & outer_out(const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out); // {"schema": "aten::outer.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor ger(const at::Tensor & self, const at::Tensor & vec2); // {"schema": "aten::ger(Tensor self, Tensor vec2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & ger_out(const at::Tensor & self, const at::Tensor & vec2, at::Tensor & out); // {"schema": "aten::ger.out(Tensor self, Tensor vec2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_norm(const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_norm(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor linalg_norm(const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_norm.ord_str(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_norm_out(const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_norm.out(Tensor self, Scalar? ord=None, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & linalg_norm_out(const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_norm.ord_str_out(Tensor self, str ord, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_vector_norm(const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_vector_norm(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_vector_norm_out(const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_vector_norm.out(Tensor self, Scalar ord=2, int[1]? dim=None, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor linalg_matrix_norm(const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_matrix_norm(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_norm_out(const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_matrix_norm.out(Tensor self, Scalar ord, int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_norm(const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype); // {"schema": "aten::linalg_matrix_norm.str_ord(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_norm_out(const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::linalg_matrix_norm.str_ord_out(Tensor self, str ord='fro', int[] dim=[-2,-1], bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple _linalg_svd(const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver); // {"schema": "aten::_linalg_svd(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_svd_out(const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh); // {"schema": "aten::_linalg_svd.U(Tensor A, bool full_matrices=False, bool compute_uv=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)", "dispatch": "True", "default": "False"} +::std::tuple linalg_svd(const at::Tensor & A, bool full_matrices, ::std::optional driver); // {"schema": "aten::linalg_svd(Tensor A, bool full_matrices=True, *, str? driver=None) -> (Tensor U, Tensor S, Tensor Vh)", "dispatch": "False", "default": "True"} +::std::tuple linalg_svd_out(const at::Tensor & A, bool full_matrices, ::std::optional driver, at::Tensor & U, at::Tensor & S, at::Tensor & Vh); // {"schema": "aten::linalg_svd.U(Tensor A, bool full_matrices=True, *, str? driver=None, Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh) -> (Tensor(a!) U, Tensor(b!) S, Tensor(c!) Vh)", "dispatch": "False", "default": "True"} +at::Tensor linalg_svdvals(const at::Tensor & A, ::std::optional driver); // {"schema": "aten::linalg_svdvals(Tensor A, *, str? driver=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_svdvals_out(const at::Tensor & A, ::std::optional driver, at::Tensor & out); // {"schema": "aten::linalg_svdvals.out(Tensor A, *, str? driver=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_cond(const at::Tensor & self, const ::std::optional & p); // {"schema": "aten::linalg_cond(Tensor self, Scalar? p=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_cond_out(const at::Tensor & self, const ::std::optional & p, at::Tensor & out); // {"schema": "aten::linalg_cond.out(Tensor self, Scalar? p=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_cond(const at::Tensor & self, c10::string_view p); // {"schema": "aten::linalg_cond.p_str(Tensor self, str p) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_cond_out(const at::Tensor & self, c10::string_view p, at::Tensor & out); // {"schema": "aten::linalg_cond.p_str_out(Tensor self, str p, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian); // {"schema": "aten::linalg_pinv.atol_rtol_tensor(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.atol_rtol_tensor_out(Tensor self, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian); // {"schema": "aten::linalg_pinv.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, double rcond, bool hermitian); // {"schema": "aten::linalg_pinv(Tensor self, float rcond, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor linalg_pinv(const at::Tensor & self, const at::Tensor & rcond, bool hermitian); // {"schema": "aten::linalg_pinv.rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, double rcond, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.out(Tensor self, float rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor & linalg_pinv_out(const at::Tensor & self, const at::Tensor & rcond, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_pinv.out_rcond_tensor(Tensor self, Tensor rcond, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple _linalg_solve_ex(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors); // {"schema": "aten::_linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor LU, Tensor pivots, Tensor info)", "dispatch": "True", "default": "True"} +::std::tuple _linalg_solve_ex_out(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & LU, at::Tensor & pivots, at::Tensor & info); // {"schema": "aten::_linalg_solve_ex.result(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info) -> (Tensor(a!) result, Tensor(b!) LU, Tensor(c!) pivots, Tensor(d!) info)", "dispatch": "True", "default": "False"} +::std::tuple linalg_solve_ex(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors); // {"schema": "aten::linalg_solve_ex(Tensor A, Tensor B, *, bool left=True, bool check_errors=False) -> (Tensor result, Tensor info)", "dispatch": "False", "default": "True"} +::std::tuple linalg_solve_ex_out(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors, at::Tensor & result, at::Tensor & info); // {"schema": "aten::linalg_solve_ex.out(Tensor A, Tensor B, *, bool left=True, bool check_errors=False, Tensor(a!) result, Tensor(b!) info) -> (Tensor(a!) result, Tensor(b!) info)", "dispatch": "False", "default": "True"} +at::Tensor linalg_solve(const at::Tensor & A, const at::Tensor & B, bool left); // {"schema": "aten::linalg_solve(Tensor A, Tensor B, *, bool left=True) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _spsolve(const at::Tensor & A, const at::Tensor & B, bool left); // {"schema": "aten::_spsolve(Tensor A, Tensor B, *, bool left=True) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & linalg_solve_out(const at::Tensor & A, const at::Tensor & B, bool left, at::Tensor & out); // {"schema": "aten::linalg_solve.out(Tensor A, Tensor B, *, bool left=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_tensorinv(const at::Tensor & self, int64_t ind); // {"schema": "aten::linalg_tensorinv(Tensor self, int ind=2) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_tensorinv_out(const at::Tensor & self, int64_t ind, at::Tensor & out); // {"schema": "aten::linalg_tensorinv.out(Tensor self, int ind=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_tensorsolve(const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims); // {"schema": "aten::linalg_tensorsolve(Tensor self, Tensor other, int[]? dims=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_tensorsolve_out(const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims, at::Tensor & out); // {"schema": "aten::linalg_tensorsolve.out(Tensor self, Tensor other, int[]? dims=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +::std::tuple linalg_qr(const at::Tensor & A, c10::string_view mode); // {"schema": "aten::linalg_qr(Tensor A, str mode='reduced') -> (Tensor Q, Tensor R)", "dispatch": "True", "default": "True"} +::std::tuple linalg_qr_out(const at::Tensor & A, c10::string_view mode, at::Tensor & Q, at::Tensor & R); // {"schema": "aten::linalg_qr.out(Tensor A, str mode='reduced', *, Tensor(a!) Q, Tensor(b!) R) -> (Tensor(a!) Q, Tensor(b!) R)", "dispatch": "True", "default": "False"} +at::Tensor linalg_matrix_power(const at::Tensor & self, int64_t n); // {"schema": "aten::linalg_matrix_power(Tensor self, int n) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_power_out(const at::Tensor & self, int64_t n, at::Tensor & out); // {"schema": "aten::linalg_matrix_power.out(Tensor self, int n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.atol_rtol_tensor(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.atol_rtol_tensor_out(Tensor input, *, Tensor? atol=None, Tensor? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.atol_rtol_float(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.atol_rtol_float_out(Tensor self, *, float? atol=None, float? rtol=None, bool hermitian=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & self, double tol, bool hermitian); // {"schema": "aten::linalg_matrix_rank(Tensor self, float tol, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & self, double tol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.out(Tensor self, float tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_matrix_rank(const at::Tensor & input, const at::Tensor & tol, bool hermitian); // {"schema": "aten::linalg_matrix_rank.tol_tensor(Tensor input, Tensor tol, bool hermitian=False) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_matrix_rank_out(const at::Tensor & input, const at::Tensor & tol, bool hermitian, at::Tensor & out); // {"schema": "aten::linalg_matrix_rank.out_tol_tensor(Tensor input, Tensor tol, bool hermitian=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor linalg_multi_dot(at::TensorList tensors); // {"schema": "aten::linalg_multi_dot(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor & linalg_multi_dot_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::linalg_multi_dot.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "False", "default": "True"} +at::Tensor nested_to_padded_tensor(const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size); // {"schema": "aten::nested_to_padded_tensor(Tensor self, float padding, int[]? output_size=None) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_serialization_subcmul(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha); // {"schema": "aten::_test_serialization_subcmul(Tensor self, Tensor other, Scalar alpha=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_parallel_materialize(const at::Tensor & self, int64_t num_parallel, bool skip_first); // {"schema": "aten::_test_parallel_materialize(Tensor self, int num_parallel, bool skip_first=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_optional_intlist(const at::Tensor & values, at::OptionalIntArrayRef addends); // {"schema": "aten::_test_optional_intlist(Tensor values, int[]? addends) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _test_optional_filled_intlist(const at::Tensor & values, at::OptionalIntArrayRef addends); // {"schema": "aten::_test_optional_filled_intlist(Tensor values, int[2]? addends) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _test_optional_floatlist(const at::Tensor & values, ::std::optional> addends); // {"schema": "aten::_test_optional_floatlist(Tensor values, float[]? addends) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _test_string_default(const at::Tensor & dummy, c10::string_view a, c10::string_view b); // {"schema": "aten::_test_string_default(Tensor dummy, str a=\"\\\"'\\\\\", str b='\"\\'\\\\') -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_ambiguous_defaults(const at::Tensor & dummy, int64_t a, int64_t b); // {"schema": "aten::_test_ambiguous_defaults.a(Tensor dummy, int a=1, int b=1) -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_ambiguous_defaults(const at::Tensor & dummy, int64_t a, c10::string_view b); // {"schema": "aten::_test_ambiguous_defaults.b(Tensor dummy, int a=2, str b=\"2\") -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor _test_warn_in_autograd(const at::Tensor & self); // {"schema": "aten::_test_warn_in_autograd(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch(const at::Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch.fullcoverage(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch(const at::Tensor & self, bool b); // {"schema": "aten::_test_autograd_multiple_dispatch.ntonly(Tensor self, bool b) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch_view(const at::Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch_view(Tensor(a) self) -> Tensor(a)", "dispatch": "True", "default": "True"} +at::Tensor _test_autograd_multiple_dispatch_view_copy(const at::Tensor & self); // {"schema": "aten::_test_autograd_multiple_dispatch_view_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor segment_reduce(const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial); // {"schema": "aten::segment_reduce(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _segment_reduce_backward(const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & offsets, int64_t axis, const ::std::optional & initial); // {"schema": "aten::_segment_reduce_backward(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor pad_sequence(at::TensorList sequences, bool batch_first, double padding_value, c10::string_view padding_side); // {"schema": "aten::pad_sequence(Tensor[] sequences, bool batch_first=False, float padding_value=0.0, str padding_side=\"right\") -> Tensor", "dispatch": "False", "default": "True"} +at::Tensor flatten_dense_tensors(at::TensorList tensors); // {"schema": "aten::flatten_dense_tensors(Tensor[] tensors) -> Tensor", "dispatch": "False", "default": "True"} +::std::vector unflatten_dense_tensors(const at::Tensor & flat, at::TensorList tensors); // {"schema": "aten::unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]", "dispatch": "False", "default": "True"} +at::Tensor _nested_tensor_from_tensor_list(at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory); // {"schema": "aten::_nested_tensor_from_tensor_list(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _fw_primal_copy(const at::Tensor & self, int64_t level); // {"schema": "aten::_fw_primal_copy(Tensor self, int level) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _make_dual_copy(const at::Tensor & primal, const at::Tensor & tangent, int64_t level); // {"schema": "aten::_make_dual_copy(Tensor primal, Tensor tangent, int level) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor view_as_real_copy(const at::Tensor & self); // {"schema": "aten::view_as_real_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor view_as_complex_copy(const at::Tensor & self); // {"schema": "aten::view_as_complex_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _conj_copy(const at::Tensor & self); // {"schema": "aten::_conj_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _neg_view_copy(const at::Tensor & self); // {"schema": "aten::_neg_view_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor as_strided_copy(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset); // {"schema": "aten::as_strided_copy(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _sparse_broadcast_to_copy(const at::Tensor & self, at::IntArrayRef size); // {"schema": "aten::_sparse_broadcast_to_copy(Tensor self, int[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor diagonal_copy(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2); // {"schema": "aten::diagonal_copy(Tensor self, int offset=0, int dim1=0, int dim2=1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor expand_copy(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit); // {"schema": "aten::expand_copy(Tensor self, SymInt[] size, *, bool implicit=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor permute_copy(const at::Tensor & self, at::IntArrayRef dims); // {"schema": "aten::permute_copy(Tensor self, int[] dims) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _reshape_alias_copy(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::_reshape_alias_copy(Tensor self, SymInt[] size, SymInt[] stride) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor select_copy(const at::Tensor & self, int64_t dim, c10::SymInt index); // {"schema": "aten::select_copy.int(Tensor self, int dim, SymInt index) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor detach_copy(const at::Tensor & self); // {"schema": "aten::detach_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor slice_copy(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step); // {"schema": "aten::slice_copy.Tensor(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1) -> Tensor", "dispatch": "True", "default": "True"} +::std::vector split_copy(const at::Tensor & self, c10::SymInt split_size, int64_t dim); // {"schema": "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +::std::vector split_with_sizes_copy(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim); // {"schema": "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +at::Tensor squeeze_copy(const at::Tensor & self); // {"schema": "aten::squeeze_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor squeeze_copy(const at::Tensor & self, int64_t dim); // {"schema": "aten::squeeze_copy.dim(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor squeeze_copy(const at::Tensor & self, at::IntArrayRef dim); // {"schema": "aten::squeeze_copy.dims(Tensor self, int[] dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor t_copy(const at::Tensor & self); // {"schema": "aten::t_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor transpose_copy(const at::Tensor & self, int64_t dim0, int64_t dim1); // {"schema": "aten::transpose_copy.int(Tensor self, int dim0, int dim1) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor unsqueeze_copy(const at::Tensor & self, int64_t dim); // {"schema": "aten::unsqueeze_copy(Tensor self, int dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _indices_copy(const at::Tensor & self); // {"schema": "aten::_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _values_copy(const at::Tensor & self); // {"schema": "aten::_values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor indices_copy(const at::Tensor & self); // {"schema": "aten::indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor values_copy(const at::Tensor & self); // {"schema": "aten::values_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor crow_indices_copy(const at::Tensor & self); // {"schema": "aten::crow_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor col_indices_copy(const at::Tensor & self); // {"schema": "aten::col_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor ccol_indices_copy(const at::Tensor & self); // {"schema": "aten::ccol_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor row_indices_copy(const at::Tensor & self); // {"schema": "aten::row_indices_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +::std::vector unbind_copy(const at::Tensor & self, int64_t dim); // {"schema": "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]", "dispatch": "True", "default": "True"} +void unbind_copy_out(const at::Tensor & self, int64_t dim, at::TensorList out); // {"schema": "aten::unbind_copy.int_out(Tensor self, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void split_copy_out(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out); // {"schema": "aten::split_copy.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void split_with_sizes_copy_out(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out); // {"schema": "aten::split_with_sizes_copy.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor view_copy(const at::Tensor & self, c10::SymIntArrayRef size); // {"schema": "aten::view_copy(Tensor self, SymInt[] size) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor view_copy(const at::Tensor & self, at::ScalarType dtype); // {"schema": "aten::view_copy.dtype(Tensor self, ScalarType dtype) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor unfold_copy(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step); // {"schema": "aten::unfold_copy(Tensor self, int dimension, int size, int step) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor alias_copy(const at::Tensor & self); // {"schema": "aten::alias_copy(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor to_padded_tensor(const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size); // {"schema": "aten::to_padded_tensor(Tensor self, float padding, SymInt[]? output_size=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _jagged_to_padded_dense_forward(const at::Tensor & values, at::TensorList offsets, c10::SymIntArrayRef max_lengths, double padding_value); // {"schema": "aten::_jagged_to_padded_dense_forward(Tensor values, Tensor[] offsets, SymInt[] max_lengths, float padding_value=0.0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _padded_dense_to_jagged_forward(const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L); // {"schema": "aten::_padded_dense_to_jagged_forward(Tensor dense, Tensor[] offsets, SymInt? total_L=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_from_padded_tensor(const at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, ::std::optional sum_S); // {"schema": "aten::_nested_from_padded_tensor(Tensor padded, Tensor offsets, Tensor dummy, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, SymInt? sum_S=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _nested_tensor_softmax_with_shape(const at::Tensor & self, const at::Tensor & query); // {"schema": "aten::_nested_tensor_softmax_with_shape(Tensor self, Tensor query) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor _safe_softmax(const at::Tensor & self, int64_t dim, ::std::optional dtype); // {"schema": "aten::_safe_softmax(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor _transformer_encoder_layer_fwd(const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type); // {"schema": "aten::_transformer_encoder_layer_fwd(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None) -> Tensor", "dispatch": "True", "default": "False"} +::std::tuple _native_multi_head_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type); // {"schema": "aten::_native_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +at::Tensor scaled_dot_product_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa); // {"schema": "aten::scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> Tensor", "dispatch": "False", "default": "True"} +int64_t _fused_sdp_choice(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa); // {"schema": "aten::_fused_sdp_choice(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, *, float? scale=None, bool enable_gqa=False) -> int", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_attention_math(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale, bool enable_gqa); // {"schema": "aten::_scaled_dot_product_attention_math(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None, bool enable_gqa=False) -> (Tensor, Tensor)", "dispatch": "False", "default": "True"} +::std::tuple _scaled_dot_product_attention_math_for_mps(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_attention_math_for_mps(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None, float dropout_p=0.0, bool is_causal=False, Tensor? dropout_mask=None, *, float? scale=None) -> (Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_flash_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_flash_attention_for_cpu(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_for_cpu(Tensor query, Tensor key, Tensor value, float dropout_p=0.0, bool is_causal=False, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor output, Tensor logsumexp)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_fused_attention_overrideable(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_fused_attention_overrideable(Tensor query, Tensor key, Tensor value, Tensor? attn_bias=None, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "True"} +::std::tuple _scaled_dot_product_flash_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_flash_attention_for_cpu_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_flash_attention_for_cpu_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, float dropout_p, bool is_causal, *, Tensor? attn_mask=None, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_fused_attention_overrideable_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_fused_attention_overrideable_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor attn_bias, bool[4] grad_input_mask, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor philox_seed, Tensor philox_offset, *, float? scale=None) -> (Tensor grad_query, Tensor grad_key, Tensor grad_value, Tensor grad_attn_bias)", "dispatch": "True", "default": "True"} +::std::tuple _scaled_dot_product_efficient_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_efficient_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, *, float? scale=None) -> (Tensor output, Tensor log_sumexp, Tensor philox_seed, Tensor philox_offset)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_efficient_attention_backward(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array grad_input_mask, bool is_causal, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor attn_bias, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, float dropout_p, bool[4] grad_input_mask, bool is_causal=False, *, float? scale=None) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_cudnn_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_cudnn_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +::std::tuple _scaled_dot_product_cudnn_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale); // {"schema": "aten::_scaled_dot_product_cudnn_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, Tensor attn_bias, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, *, float? scale=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _flash_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right, const ::std::optional & seqused_k, const ::std::optional & alibi_slopes); // {"schema": "aten::_flash_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, bool return_debug_mask, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None, Tensor? seqused_k=None, Tensor? alibi_slopes=None) -> (Tensor output, Tensor softmax_logsumexp, Tensor rng_state, Tensor unused, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +::std::tuple _flash_attention_backward(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right); // {"schema": "aten::_flash_attention_backward(Tensor grad_out, Tensor query, Tensor key, Tensor value, Tensor out, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, float dropout_p, bool is_causal, Tensor rng_state, Tensor unused, *, float? scale=None, SymInt? window_size_left=None, SymInt? window_size_right=None) -> (Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _efficient_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, ::std::optional max_seqlen_q, ::std::optional max_seqlen_k, double dropout_p, int64_t custom_mask_type, bool compute_log_sumexp, ::std::optional scale, const ::std::optional & seqlen_k, ::std::optional window_size); // {"schema": "aten::_efficient_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt? max_seqlen_q, SymInt? max_seqlen_k, float dropout_p, int custom_mask_type, bool compute_log_sumexp=False, *, float? scale=None, Tensor? seqlen_k=None, int? window_size=None) -> (Tensor output, Tensor logsumexp, Tensor philox_seed, Tensor philox_offset, SymInt max_seqlen_batch_q, SymInt max_seqlen_batch_k)", "dispatch": "True", "default": "False"} +::std::tuple _efficient_attention_backward(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale, ::std::optional num_splits_key, ::std::optional window_size, bool shared_storage_dqdkdv); // {"schema": "aten::_efficient_attention_backward(Tensor grad_out_, Tensor query, Tensor key, Tensor value, Tensor? bias, Tensor out, Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, SymInt max_seqlen_q, SymInt max_seqlen_k, Tensor logsumexp, float dropout_p, Tensor philox_seed, Tensor philox_offset, int custom_mask_type, bool bias_requires_grad, *, float? scale=None, int? num_splits_key=None, int? window_size=None, bool shared_storage_dqdkdv=False) -> (Tensor, Tensor, Tensor, Tensor)", "dispatch": "True", "default": "False"} +::std::tuple _cudnn_attention_forward(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, bool compute_log_sumexp, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale); // {"schema": "aten::_cudnn_attention_forward(Tensor query, Tensor key, Tensor value, Tensor? attn_bias, Tensor? cum_seq_q, Tensor? cum_seq_k, SymInt max_q, SymInt max_k, bool compute_log_sumexp, float dropout_p=0.0, bool is_causal=False, bool return_debug_mask=False, *, float? scale=None) -> (Tensor output, Tensor logsumexp, Tensor cum_seq_q, Tensor cum_seq_k, SymInt max_q, SymInt max_k, Tensor philox_seed, Tensor philox_offset, Tensor debug_attn_mask)", "dispatch": "True", "default": "False"} +at::Tensor _triton_scaled_dot_attention(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p); // {"schema": "aten::_triton_scaled_dot_attention(Tensor q, Tensor k, Tensor v, float dropout_p=0.0) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor & _fill_mem_eff_dropout_mask_(at::Tensor & self, double dropout_p, int64_t seed, int64_t offset); // {"schema": "aten::_fill_mem_eff_dropout_mask_(Tensor(a!) self, float dropout_p, int seed, int offset) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _triton_multi_head_attention(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask); // {"schema": "aten::_triton_multi_head_attention(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None) -> Tensor", "dispatch": "True", "default": "False"} +at::Tensor special_airy_ai(const at::Tensor & x); // {"schema": "aten::special_airy_ai(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_airy_ai_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_airy_ai.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_j0(const at::Tensor & self); // {"schema": "aten::special_bessel_j0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_j0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_j0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_j1(const at::Tensor & self); // {"schema": "aten::special_bessel_j1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_j1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_j1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_y0(const at::Tensor & self); // {"schema": "aten::special_bessel_y0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_y0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_y0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_bessel_y1(const at::Tensor & self); // {"schema": "aten::special_bessel_y1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_bessel_y1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_bessel_y1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_chebyshev_polynomial_t(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_t(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_t(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_t_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_u(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_u(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_u(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_u_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_v(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_v(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_v(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_v_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_w(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_w(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_chebyshev_polynomial_w(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_chebyshev_polynomial_w_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_h(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_h(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_h(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_h.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_h(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_hermite_polynomial_h.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_h_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_hermite_polynomial_h_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_h_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_h.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_he(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_he(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_he(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_hermite_polynomial_he.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_hermite_polynomial_he(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_hermite_polynomial_he.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_he_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_hermite_polynomial_he_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_hermite_polynomial_he_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_hermite_polynomial_he.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_laguerre_polynomial_l(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_laguerre_polynomial_l(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_laguerre_polynomial_l(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_laguerre_polynomial_l.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_laguerre_polynomial_l(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_laguerre_polynomial_l.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_laguerre_polynomial_l_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_laguerre_polynomial_l_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_laguerre_polynomial_l_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_laguerre_polynomial_l.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_legendre_polynomial_p(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_legendre_polynomial_p(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_legendre_polynomial_p(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_legendre_polynomial_p.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_legendre_polynomial_p(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_legendre_polynomial_p.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_legendre_polynomial_p_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_legendre_polynomial_p_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_legendre_polynomial_p_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_legendre_polynomial_p.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_modified_bessel_i0(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_i0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_i0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_i0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_modified_bessel_i1(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_i1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_i1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_i1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_modified_bessel_k0(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_k0(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_k0_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_k0.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_modified_bessel_k1(const at::Tensor & self); // {"schema": "aten::special_modified_bessel_k1(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_modified_bessel_k1_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::special_modified_bessel_k1.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_scaled_modified_bessel_k0(const at::Tensor & x); // {"schema": "aten::special_scaled_modified_bessel_k0(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_scaled_modified_bessel_k0_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_scaled_modified_bessel_k0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_scaled_modified_bessel_k1(const at::Tensor & x); // {"schema": "aten::special_scaled_modified_bessel_k1(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_scaled_modified_bessel_k1_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_scaled_modified_bessel_k1.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor special_shifted_chebyshev_polynomial_t(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_t(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_t(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_t_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_t_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_t.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_u(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_u(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_u(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_u_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_u_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_u.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_v(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_v(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_v(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_v_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_v_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_v.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_w(const at::Tensor & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w(Tensor x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_w(const at::Scalar & x, const at::Tensor & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.x_scalar(Scalar x, Tensor n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor special_shifted_chebyshev_polynomial_w(const at::Tensor & x, const at::Scalar & n); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.n_scalar(Tensor x, Scalar n) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.out(Tensor x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor & special_shifted_chebyshev_polynomial_w_out(const at::Scalar & x, const at::Tensor & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.x_scalar_out(Scalar x, Tensor n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & special_shifted_chebyshev_polynomial_w_out(const at::Tensor & x, const at::Scalar & n, at::Tensor & out); // {"schema": "aten::special_shifted_chebyshev_polynomial_w.n_scalar_out(Tensor x, Scalar n, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor special_spherical_bessel_j0(const at::Tensor & x); // {"schema": "aten::special_spherical_bessel_j0(Tensor x) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & special_spherical_bessel_j0_out(const at::Tensor & x, at::Tensor & out); // {"schema": "aten::special_spherical_bessel_j0.out(Tensor x, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "False"} +at::Tensor _foobar(const at::Tensor & self, bool arg1, bool arg2, bool arg3); // {"schema": "aten::_foobar(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True) -> Tensor", "dispatch": "True", "default": "False"} +void _fused_adam_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adam_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adamw_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adamw_(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_sgd_(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_sgd_(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adagrad_(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad_(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _fused_adagrad_(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad_.tensor_lr(Tensor(a!)[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> ()", "dispatch": "True", "default": "False"} +void _propagate_xla_data(const at::Tensor & input, const at::Tensor & output); // {"schema": "aten::_propagate_xla_data(Tensor input, Tensor output) -> ()", "dispatch": "False", "default": "True"} +at::Tensor & _new_zeros_with_same_feature_meta_out(const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims, at::Tensor & out); // {"schema": "aten::_new_zeros_with_same_feature_meta.out(Tensor self, Tensor other, *, int self_num_batch_dims=0, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _cudnn_ctc_loss_out(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_cudnn_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank, bool deterministic, bool zero_infinity, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _cudnn_rnn_flatten_weight_out(at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional, at::Tensor & out); // {"schema": "aten::_cudnn_rnn_flatten_weight.out(Tensor[] weight_arr, int weight_stride0, SymInt input_size, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, bool bidirectional, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _cudnn_rnn_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4); // {"schema": "aten::_cudnn_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor? weight_buf, Tensor hx, Tensor? cx, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"} +void _cudnn_rnn_backward_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3); // {"schema": "aten::_cudnn_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, SymInt hidden_size, SymInt proj_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, SymInt[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & _cudnn_init_dropout_state_out(double dropout, bool train, int64_t dropout_seed, at::Tensor & out); // {"schema": "aten::_cudnn_init_dropout_state.out(float dropout, bool train, int dropout_seed, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _fused_dropout_out(const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_fused_dropout.out(Tensor self, float p, Generator? generator=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _masked_scale_out(const at::Tensor & self, const at::Tensor & mask, double scale, at::Tensor & out); // {"schema": "aten::_masked_scale.out(Tensor self, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple native_dropout_out(const at::Tensor & input, double p, ::std::optional train, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::native_dropout.out(Tensor input, float p, bool? train, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & native_dropout_backward_out(const at::Tensor & grad_output, const at::Tensor & mask, double scale, at::Tensor & out); // {"schema": "aten::native_dropout_backward.out(Tensor grad_output, Tensor mask, float scale, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _conj_physical_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_conj_physical.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & avg_pool1d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, at::Tensor & out); // {"schema": "aten::avg_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, bool ceil_mode=False, bool count_include_pad=True, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & adaptive_avg_pool1d_out(const at::Tensor & self, at::IntArrayRef output_size, at::Tensor & out); // {"schema": "aten::adaptive_avg_pool1d.out(Tensor self, int[1] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _add_relu_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::_add_relu.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & add_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::add.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & affine_grid_generator_out(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners, at::Tensor & out); // {"schema": "aten::affine_grid_generator.out(Tensor theta, SymInt[] size, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_functorch_fallback_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::_test_functorch_fallback.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bartlett_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::bartlett_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bartlett_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::bartlett_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_batch_norm_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point, at::Tensor & out); // {"schema": "aten::quantized_batch_norm.out(Tensor input, Tensor? weight, Tensor? bias, Tensor mean, Tensor var, float eps, float output_scale, int output_zero_point, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bernoulli_out(const at::Tensor & self, const at::Tensor & p, ::std::optional generator, at::Tensor & out); // {"schema": "aten::bernoulli.Tensor_out(Tensor self, Tensor p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor bernoulli(const at::Tensor & self, const at::Tensor & p, ::std::optional generator); // {"schema": "aten::bernoulli.Tensor(Tensor self, Tensor p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & bernoulli_out(const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out); // {"schema": "aten::bernoulli.float_out(Tensor self, float p=0.5, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & binary_cross_entropy_with_logits_out(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction, at::Tensor & out); // {"schema": "aten::binary_cross_entropy_with_logits.out(Tensor self, Tensor target, Tensor? weight=None, Tensor? pos_weight=None, int reduction=Mean, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bincount_out(const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength, at::Tensor & out); // {"schema": "aten::bincount.out(Tensor self, Tensor? weights=None, SymInt minlength=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & blackman_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::blackman_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & blackman_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::blackman_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & block_diag_out(at::TensorList tensors, at::Tensor & out); // {"schema": "aten::block_diag.out(Tensor[] tensors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & constant_pad_nd_out(const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::constant_pad_nd.out(Tensor self, SymInt[] pad, Scalar value=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & convolution_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::convolution_backward.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[]? bias_sizes, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & convolution_overrideable_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::convolution_overrideable.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple convolution_backward_overrideable_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::convolution_backward_overrideable.out(Tensor grad_output, Tensor input, Tensor weight, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & _convolution_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, at::Tensor & out); // {"schema": "aten::_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, bool transposed, SymInt[] output_padding, SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & conv_tbc_out(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad, at::Tensor & out); // {"schema": "aten::conv_tbc.out(Tensor self, Tensor weight, Tensor bias, int pad=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & copy_out(const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out); // {"schema": "aten::copy.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _copy_from_out(const at::Tensor & self, const at::Tensor & dst, bool non_blocking, at::Tensor & out); // {"schema": "aten::_copy_from.out(Tensor self, Tensor dst, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _copy_from_and_resize_out(const at::Tensor & self, const at::Tensor & dst, at::Tensor & out); // {"schema": "aten::_copy_from_and_resize.out(Tensor self, Tensor dst, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & count_nonzero_out(const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::count_nonzero.dim_IntList_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & count_nonzero_out(const at::Tensor & self, ::std::optional dim, at::Tensor & out); // {"schema": "aten::count_nonzero.out(Tensor self, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_affine_grid_generator_out(const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out); // {"schema": "aten::cudnn_affine_grid_generator.out(Tensor theta, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_affine_grid_generator_backward_out(const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W, at::Tensor & out); // {"schema": "aten::cudnn_affine_grid_generator_backward.out(Tensor grad, int N, int C, int H, int W, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple cudnn_batch_norm_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::cudnn_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +::std::tuple cudnn_batch_norm_backward_out(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::cudnn_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, Tensor reserveSpace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_convolution_transpose_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, at::Tensor & out); // {"schema": "aten::cudnn_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, bool allow_tf32, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mps_convolution_transpose_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::_mps_convolution_transpose.out(Tensor self, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mps_convolution_transpose_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::mps_convolution_transpose_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_convolution_relu_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::cudnn_convolution_relu.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_convolution_add_relu_out(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::cudnn_convolution_add_relu.out(Tensor self, Tensor weight, Tensor z, Scalar? alpha, Tensor? bias, SymInt[] stride, SymInt[] padding, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & cudnn_grid_sampler_out(const at::Tensor & self, const at::Tensor & grid, at::Tensor & out); // {"schema": "aten::cudnn_grid_sampler.out(Tensor self, Tensor grid, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple cudnn_grid_sampler_backward_out(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::cudnn_grid_sampler_backward.out(Tensor self, Tensor grid, Tensor grad_output, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _ctc_loss_out(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_ctc_loss.out(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _ctc_loss_out(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_ctc_loss.Tensor_out(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank=0, bool zero_infinity=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _ctc_loss_backward_out(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity, at::Tensor & out); // {"schema": "aten::_ctc_loss_backward.out(Tensor grad, Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, Tensor neg_log_likelihood, Tensor log_alpha, int blank, bool zero_infinity=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diag_embed_out(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diag_embed.out(Tensor self, int offset=0, int dim1=-2, int dim2=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diagonal_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diagonal_backward.out(Tensor grad_output, SymInt[] input_sizes, int offset, int dim1, int dim2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::div.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & div_out(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode, at::Tensor & out); // {"schema": "aten::div.Scalar_mode_out(Tensor self, Scalar other, *, str? rounding_mode, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & embedding_out(const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse, at::Tensor & out); // {"schema": "aten::embedding.out(Tensor weight, Tensor indices, SymInt padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & embedding_dense_backward_out(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, at::Tensor & out); // {"schema": "aten::embedding_dense_backward.out(Tensor grad_output, Tensor indices, SymInt num_weights, SymInt padding_idx, bool scale_grad_by_freq, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & embedding_renorm_out(const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type, at::Tensor & out); // {"schema": "aten::embedding_renorm.out(Tensor self, Tensor indices, float max_norm, float norm_type, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor embedding_renorm(const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type); // {"schema": "aten::embedding_renorm(Tensor self, Tensor indices, float max_norm, float norm_type) -> Tensor", "dispatch": "True", "default": "True"} +::std::tuple _embedding_bag_forward_only_out(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::_embedding_bag_forward_only.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +::std::tuple _embedding_bag_out(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::_embedding_bag.out(Tensor weight, Tensor indices, Tensor offsets, bool scale_grad_by_freq=False, int mode=0, bool sparse=False, Tensor? per_sample_weights=None, bool include_last_offset=False, int padding_idx=-1, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +at::Tensor & _embedding_bag_dense_backward_out(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx, at::Tensor & out); // {"schema": "aten::_embedding_bag_dense_backward.out(Tensor grad, Tensor indices, Tensor offset2bag, Tensor bag_size, Tensor maximum_indices, SymInt num_weights, bool scale_grad_by_freq, int mode, Tensor? per_sample_weights, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _embedding_bag_per_sample_weights_backward_out(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx, at::Tensor & out); // {"schema": "aten::_embedding_bag_per_sample_weights_backward.out(Tensor grad, Tensor weight, Tensor indices, Tensor offsets, Tensor offset2bag, int mode, int padding_idx=-1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_out(at::IntArrayRef size, ::std::optional names, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty.names_out(int[] size, *, Dimname[]? names, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_permuted_out(c10::SymIntArrayRef size, at::IntArrayRef physical_layout, at::Tensor & out); // {"schema": "aten::empty_permuted.out(SymInt[] size, int[] physical_layout, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_empty_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::new_empty.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_empty_strided_out(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::new_empty_strided.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_full_out(const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, at::Tensor & out); // {"schema": "aten::new_full.out(Tensor self, SymInt[] size, Scalar fill_value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_zeros_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::new_zeros.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & new_ones_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::new_ones.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _empty_affine_quantized_out(c10::SymIntArrayRef size, double scale, int64_t zero_point, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::_empty_affine_quantized.out(SymInt[] size, *, float scale=1, int zero_point=0, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _empty_per_channel_affine_quantized_out(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::_empty_per_channel_affine_quantized.out(SymInt[] size, *, Tensor scales, Tensor zero_points, int axis, MemoryFormat? memory_format=contiguous_format, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & resize_out(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format, const at::Tensor & out); // {"schema": "aten::resize.out(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resize(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format); // {"schema": "aten::resize(Tensor self, SymInt[] size, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +const at::Tensor & _resize_output_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device, const at::Tensor & out); // {"schema": "aten::_resize_output.out(Tensor self, SymInt[] size, Device device, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _resize_output(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device); // {"schema": "aten::_resize_output(Tensor self, SymInt[] size, Device device) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & empty_quantized_out(at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty_quantized.out(int[] size, Tensor qtensor, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::empty_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & empty_strided_out(c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::empty_strided.out(SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & fill_out(const at::Tensor & self, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::fill.Scalar_out(Tensor self, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & fill_out(const at::Tensor & self, const at::Tensor & value, at::Tensor & out); // {"schema": "aten::fill.Tensor_out(Tensor self, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & floor_divide_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::floor_divide.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & full_out(at::IntArrayRef size, const at::Scalar & fill_value, ::std::optional names, at::Tensor & out); // {"schema": "aten::full.names_out(int[] size, Scalar fill_value, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & full_like_out(const at::Tensor & self, const at::Scalar & fill_value, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::full_like.out(Tensor self, Scalar fill_value, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & from_file_out(c10::string_view filename, ::std::optional shared, ::std::optional size, at::Tensor & out); // {"schema": "aten::from_file.out(str filename, bool? shared=None, int? size=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & grid_sampler_2d_out(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out); // {"schema": "aten::grid_sampler_2d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple grid_sampler_2d_backward_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::grid_sampler_2d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _grid_sampler_2d_cpu_fallback_out(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out); // {"schema": "aten::_grid_sampler_2d_cpu_fallback.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & grid_sampler_3d_out(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, at::Tensor & out); // {"schema": "aten::grid_sampler_3d.out(Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple grid_sampler_3d_backward_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::grid_sampler_3d_backward.out(Tensor grad_output, Tensor input, Tensor grid, int interpolation_mode, int padding_mode, bool align_corners, bool[2] output_mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & hann_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::hann_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hann_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::hann_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::hamming_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::hamming_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, bool periodic, double alpha, at::Tensor & out); // {"schema": "aten::hamming_window.periodic_alpha_out(int window_length, bool periodic, float alpha, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hamming_window_out(int64_t window_length, bool periodic, double alpha, double beta, at::Tensor & out); // {"schema": "aten::hamming_window.periodic_alpha_beta_out(int window_length, bool periodic, float alpha, float beta, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & kaiser_window_out(int64_t window_length, at::Tensor & out); // {"schema": "aten::kaiser_window.out(int window_length, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & kaiser_window_out(int64_t window_length, bool periodic, at::Tensor & out); // {"schema": "aten::kaiser_window.periodic_out(int window_length, bool periodic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & kaiser_window_out(int64_t window_length, bool periodic, double beta, at::Tensor & out); // {"schema": "aten::kaiser_window.beta_out(int window_length, bool periodic, float beta, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple native_group_norm_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_group_norm.out(Tensor input, Tensor? weight, Tensor? bias, SymInt N, SymInt C, SymInt HxW, int group, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple native_group_norm_backward_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_group_norm_backward.out(Tensor grad_out, Tensor input, Tensor mean, Tensor rstd, Tensor? weight, SymInt N, SymInt C, SymInt HxW, int group, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & index_put_out(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, at::Tensor & out); // {"schema": "aten::index_put.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _index_put_impl_out(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe, at::Tensor & out); // {"schema": "aten::_index_put_impl.out(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _index_put_impl(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe); // {"schema": "aten::_index_put_impl(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & isnan_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isnan.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple native_layer_norm_out(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_layer_norm.out(Tensor input, SymInt[] normalized_shape, Tensor? weight, Tensor? bias, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple native_layer_norm_backward_out(const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_layer_norm_backward.out(Tensor grad_out, Tensor input, SymInt[] normalized_shape, Tensor mean, Tensor rstd, Tensor? weight, Tensor? bias, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple linear_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_linear_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, at::Tensor & out); // {"schema": "aten::mkldnn_linear.out(Tensor self, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_linear_backward_input_out(at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight, at::Tensor & out); // {"schema": "aten::mkldnn_linear_backward_input.out(int[] input_size, Tensor grad_output, Tensor weight, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_linear_backward_weights_out(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::mkldnn_linear_backward_weights.out(Tensor grad_output, Tensor input, Tensor weight, bool bias_defined, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_linear_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::mkldnn_linear_backward.out(Tensor self, Tensor grad_output, Tensor weight, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple matmul_backward_out(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::matmul_backward.out(Tensor grad, Tensor self, Tensor other, bool[2] mask, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _aminmax_out(const at::Tensor & self, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_aminmax.out(Tensor self, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _aminmax_out(const at::Tensor & self, int64_t dim, bool keepdim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_aminmax.dim_out(Tensor self, int dim, bool keepdim=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::max_pool2d_backward.out(Tensor grad_output, Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool2d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_max_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::mkldnn_max_pool3d_backward.out(Tensor grad_output, Tensor output, Tensor input, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_max_pool1d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::quantized_max_pool1d.out(Tensor self, int[1] kernel_size, int[1] stride=[], int[1] padding=0, int[1] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_max_pool2d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::quantized_max_pool2d.out(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantized_max_pool3d_out(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, at::Tensor & out); // {"schema": "aten::quantized_max_pool3d.out(Tensor self, int[3] kernel_size, int[3] stride=[], int[3] padding=0, int[3] dilation=1, bool ceil_mode=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & median_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::median.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & nanmedian_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::nanmedian.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mps_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::_mps_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mps_convolution_backward_out(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::mps_convolution_backward.out(Tensor self, Tensor grad_output, Tensor weight, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::mkldnn_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_rnn_layer_out(const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::mkldnn_rnn_layer.out(Tensor input, Tensor weight0, Tensor weight1, Tensor weight2, Tensor weight3, Tensor hx_, Tensor cx_, bool reverse, int[] batch_sizes, int mode, int hidden_size, int num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +::std::tuple mkldnn_rnn_layer_backward_out(const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5, at::Tensor & out6); // {"schema": "aten::mkldnn_rnn_layer_backward.out(Tensor input, Tensor weight1, Tensor weight2, Tensor weight3, Tensor weight4, Tensor hx_, Tensor cx_tmp, Tensor output, Tensor hy_, Tensor cy_, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, bool reverse, int mode, int hidden_size, int num_layers, bool has_biases, bool train, bool bidirectional, int[] batch_sizes, bool batch_first, Tensor workspace, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5, Tensor(g!) out6) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!), Tensor(g!))", "dispatch": "True", "default": "True"} +::std::tuple miopen_batch_norm_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::miopen_batch_norm.out(Tensor input, Tensor weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float exponential_average_factor, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple miopen_batch_norm_backward_out(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::miopen_batch_norm_backward.out(Tensor input, Tensor grad_output, Tensor weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_var, float epsilon, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & miopen_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out); // {"schema": "aten::miopen_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & miopen_convolution_transpose_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out); // {"schema": "aten::miopen_convolution_transpose.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] output_padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & miopen_depthwise_convolution_out(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, at::Tensor & out); // {"schema": "aten::miopen_depthwise_convolution.out(Tensor self, Tensor weight, Tensor? bias, SymInt[] padding, SymInt[] stride, SymInt[] dilation, SymInt groups, bool benchmark, bool deterministic, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple miopen_rnn_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4); // {"schema": "aten::miopen_rnn.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor hx, Tensor? cx, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"} +void miopen_rnn_backward_out(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3); // {"schema": "aten::miopen_rnn_backward.out(Tensor input, Tensor[] weight, int weight_stride0, Tensor weight_buf, Tensor hx, Tensor? cx, Tensor output, Tensor? grad_output, Tensor? grad_hy, Tensor? grad_cy, int mode, int hidden_size, int num_layers, bool batch_first, float dropout, bool train, bool bidirectional, int[] batch_sizes, Tensor? dropout_state, Tensor reserve, bool[4] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!)[] out3) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_sparse_matmul_out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::_sparse_sparse_matmul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mul_out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _native_batch_norm_legit_functional(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, bool training, double momentum, double eps); // {"schema": "aten::_native_batch_norm_legit_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out)", "dispatch": "True", "default": "True"} +::std::tuple _native_batch_norm_legit_no_training_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_native_batch_norm_legit_no_training.out(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_stats_out(const at::Tensor & input, double eps, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_stats.out(Tensor input, float eps, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_gather_stats_out(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_gather_stats.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, int count, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_gather_stats_with_counts_out(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_gather_stats_with_counts.out(Tensor input, Tensor mean, Tensor invstd, Tensor? running_mean, Tensor? running_var, float momentum, float eps, Tensor counts, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple native_batch_norm_backward_out(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::native_batch_norm_backward.out(Tensor grad_out, Tensor input, Tensor? weight, Tensor? running_mean, Tensor? running_var, Tensor? save_mean, Tensor? save_invstd, bool train, float eps, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_backward_reduce_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::batch_norm_backward_reduce.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, bool input_g, bool weight_g, bool bias_g, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +at::Tensor & batch_norm_backward_elemt_out(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count, at::Tensor & out); // {"schema": "aten::batch_norm_backward_elemt.out(Tensor grad_out, Tensor input, Tensor mean, Tensor invstd, Tensor? weight, Tensor sum_dy, Tensor sum_dy_xmu, Tensor count, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple batch_norm_update_stats_out(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::batch_norm_update_stats.out(Tensor input, Tensor? running_mean, Tensor? running_var, float momentum, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _nnpack_spatial_convolution_out(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::_nnpack_spatial_convolution.out(Tensor input, Tensor weight, Tensor? bias, SymInt[2] padding, SymInt[2] stride=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ones_out(at::IntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::ones.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ones_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::ones_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _euclidean_dist_out(const at::Tensor & x1, const at::Tensor & x2, at::Tensor & out); // {"schema": "aten::_euclidean_dist.out(Tensor x1, Tensor x2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _cdist_forward_out(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode, at::Tensor & out); // {"schema": "aten::_cdist_forward.out(Tensor x1, Tensor x2, float p, int? compute_mode, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _cdist_backward_out(const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist, at::Tensor & out); // {"schema": "aten::_cdist_backward.out(Tensor grad, Tensor x1, Tensor x2, float p, Tensor cdist, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _pdist_forward_out(const at::Tensor & self, double p, at::Tensor & out); // {"schema": "aten::_pdist_forward.out(Tensor self, float p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _pdist_backward_out(const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist, at::Tensor & out); // {"schema": "aten::_pdist_backward.out(Tensor grad, Tensor self, float p, Tensor pdist, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & pixel_shuffle_out(const at::Tensor & self, int64_t upscale_factor, at::Tensor & out); // {"schema": "aten::pixel_shuffle.out(Tensor self, int upscale_factor, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & pixel_unshuffle_out(const at::Tensor & self, int64_t downscale_factor, at::Tensor & out); // {"schema": "aten::pixel_unshuffle.out(Tensor self, int downscale_factor, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & channel_shuffle_out(const at::Tensor & self, c10::SymInt groups, at::Tensor & out); // {"schema": "aten::channel_shuffle.out(Tensor self, SymInt groups, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _pin_memory_out(const at::Tensor & self, ::std::optional device, at::Tensor & out); // {"schema": "aten::_pin_memory.out(Tensor self, Device? device=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & scalar_tensor_out(const at::Scalar & s, at::Tensor & out); // {"schema": "aten::scalar_tensor.out(Scalar s, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::rand.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_out(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out); // {"schema": "aten::rand.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rand_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::rand_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, c10::SymInt high, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.out(Tensor self, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, const at::Tensor & high, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.Tensor_out(Tensor self, Tensor high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randint_like_out(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randint_like.low_dtype_out(Tensor self, SymInt low, SymInt high, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::randn.names_out(SymInt[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randn_out(c10::SymIntArrayRef size, ::std::optional generator, ::std::optional names, at::Tensor & out); // {"schema": "aten::randn.generator_with_names_out(SymInt[] size, *, Generator? generator, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & randn_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::randn_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & repeat_out(const at::Tensor & self, c10::SymIntArrayRef repeats, at::Tensor & out); // {"schema": "aten::repeat.out(Tensor self, SymInt[] repeats, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & repeat_interleave_out(const at::Tensor & repeats, ::std::optional output_size, at::Tensor & out); // {"schema": "aten::repeat_interleave.Tensor_out(Tensor repeats, *, SymInt? output_size=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mkldnn_reshape_out(const at::Tensor & self, at::IntArrayRef shape, at::Tensor & out); // {"schema": "aten::_mkldnn_reshape.out(Tensor self, int[] shape, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & relu_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::relu.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & select_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index, at::Tensor & out); // {"schema": "aten::select_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & celu_out(const at::Tensor & self, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::celu.out(Tensor self, Scalar alpha=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slice_backward_out(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step, at::Tensor & out); // {"schema": "aten::slice_backward.out(Tensor grad_output, SymInt[] input_sizes, int dim, SymInt start, SymInt end, SymInt step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slice_scatter_out(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step, at::Tensor & out); // {"schema": "aten::slice_scatter.out(Tensor self, Tensor src, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & select_scatter_out(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index, at::Tensor & out); // {"schema": "aten::select_scatter.out(Tensor self, Tensor src, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diagonal_scatter_out(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diagonal_scatter.out(Tensor self, Tensor src, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & as_strided_scatter_out(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); // {"schema": "aten::as_strided_scatter.out(Tensor self, Tensor src, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void unsafe_split_out(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out); // {"schema": "aten::unsafe_split.Tensor_out(Tensor self, SymInt split_size, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void unsafe_split_with_sizes_out(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out); // {"schema": "aten::unsafe_split_with_sizes.out(Tensor self, SymInt[] split_sizes, int dim=0, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & sum_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::sum.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple std_mean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::std_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & prod_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::prod.out(Tensor self, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _mkldnn_transpose_out(const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out); // {"schema": "aten::_mkldnn_transpose.out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & flip_out(const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::flip.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & roll_out(const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::roll.out(Tensor self, SymInt[1] shifts, int[1] dims=[], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rot90_out(const at::Tensor & self, int64_t k, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::rot90.out(Tensor self, int k=1, int[] dims=[0,1], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _transform_bias_rescale_qkv_out(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_transform_bias_rescale_qkv.out(Tensor qkv, Tensor qkv_bias, int num_heads, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_from_mask_out(const at::Tensor & t, const at::Tensor & mask, bool mask_check, at::Tensor & out); // {"schema": "aten::_nested_tensor_from_mask.out(Tensor t, Tensor mask, bool mask_check=True, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_from_padded_out(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213, at::Tensor & out); // {"schema": "aten::_nested_from_padded.out(Tensor padded, Tensor cpu_nested_shape_example, bool fuse_transform_0213=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_size_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_tensor_size.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_strides_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_tensor_strides.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_storage_offsets_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_tensor_storage_offsets.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_from_padded_and_nested_example_out(const at::Tensor & padded, const at::Tensor & nt_example, at::Tensor & out); // {"schema": "aten::_nested_from_padded_and_nested_example.out(Tensor padded, Tensor nt_example, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_view_from_buffer_copy_out(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets, at::Tensor & out); // {"schema": "aten::_nested_view_from_buffer_copy.out(Tensor self, Tensor nested_size, Tensor nested_strides, Tensor offsets, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_view_from_jagged_copy_out(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, at::Tensor & out); // {"schema": "aten::_nested_view_from_jagged_copy.out(Tensor self, Tensor offsets, Tensor dummy, Tensor? lengths=None, int ragged_idx=1, Tensor? min_seqlen=None, Tensor? max_seqlen=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_get_values_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_nested_get_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _trilinear_out(const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim, at::Tensor & out); // {"schema": "aten::_trilinear.out(Tensor i1, Tensor i2, Tensor i3, int[] expand1, int[] expand2, int[] expand3, int[] sumdim, int unroll_dim=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _unique_out(const at::Tensor & self, bool sorted, bool return_inverse, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_unique.out(Tensor self, bool sorted=True, bool return_inverse=False, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple unique_dim_out(const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::unique_dim.out(Tensor self, int dim, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple unique_consecutive_out(const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::unique_consecutive.out(Tensor self, bool return_inverse=False, bool return_counts=False, int? dim=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple unique_dim_consecutive_out(const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::unique_dim_consecutive.out(Tensor self, int dim, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple _unique2_out(const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_unique2.out(Tensor self, bool sorted=True, bool return_inverse=False, bool return_counts=False, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & _unsafe_view_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::_unsafe_view.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple var_mean_out(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::var_mean.correction_out(Tensor self, int[1]? dim=None, *, Scalar? correction=None, bool keepdim=False, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _weight_norm_interface_out(const at::Tensor & v, const at::Tensor & g, int64_t dim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_weight_norm_interface.out(Tensor v, Tensor g, int dim=0, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _weight_norm_interface_backward_out(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_weight_norm_interface_backward.out(Tensor grad_w, Tensor saved_v, Tensor saved_g, Tensor saved_norms, int dim, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & zeros_out(at::IntArrayRef size, ::std::optional names, at::Tensor & out); // {"schema": "aten::zeros.names_out(int[] size, *, Dimname[]? names, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _efficientzerotensor_out(c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::_efficientzerotensor.out(SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & zeros_like_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::zeros_like.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _standard_gamma_grad_out(const at::Tensor & self, const at::Tensor & output, at::Tensor & out); // {"schema": "aten::_standard_gamma_grad.out(Tensor self, Tensor output, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _standard_gamma_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::_standard_gamma.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _dirichlet_grad_out(const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total, at::Tensor & out); // {"schema": "aten::_dirichlet_grad.out(Tensor x, Tensor alpha, Tensor total, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sample_dirichlet_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::_sample_dirichlet.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & poisson_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::poisson.out(Tensor self, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & binomial_out(const at::Tensor & count, const at::Tensor & prob, ::std::optional generator, at::Tensor & out); // {"schema": "aten::binomial.out(Tensor count, Tensor prob, Generator? generator=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & native_norm_out(const at::Tensor & self, const at::Scalar & p, at::Tensor & out); // {"schema": "aten::native_norm.out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & native_norm_out(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::native_norm.ScalarOpt_dim_dtype_out(Tensor self, Scalar? p, int[1] dim, bool keepdim, ScalarType? dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _batch_norm_with_update_functional(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps); // {"schema": "aten::_batch_norm_with_update_functional(Tensor input, Tensor? weight, Tensor? bias, Tensor running_mean, Tensor running_var, float momentum, float eps) -> (Tensor, Tensor, Tensor, Tensor, Tensor running_mean_out, Tensor running_var_out)", "dispatch": "True", "default": "True"} +::std::tuple _batch_norm_no_update_out(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3); // {"schema": "aten::_batch_norm_no_update.out(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, float momentum, float eps, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!))", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_sum_out(const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::_sparse_sum.dim_out(Tensor self, int[1] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_sum_backward_out(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::_sparse_sum_backward.out(Tensor grad, Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_csr_sum_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::_sparse_csr_sum.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_csr_prod_out(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::_sparse_csr_prod.dim_dtype_out(Tensor self, int[1] dim, bool keepdim=False, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_sparse_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_sparse_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_log_softmax_out(const at::Tensor & self, int64_t dim, bool half_to_float, at::Tensor & out); // {"schema": "aten::_sparse_log_softmax.out(Tensor self, int dim, bool half_to_float, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_log_softmax_backward_data_out(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_sparse_log_softmax_backward_data.out(Tensor grad_output, Tensor output, int dim, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _spdiags_out(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout, at::Tensor & out); // {"schema": "aten::_spdiags.out(Tensor diagonals, Tensor offsets, int[] shape, Layout? layout=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::norm.ScalarOpt_dtype_out(Tensor self, Scalar? p, *, ScalarType dtype, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & norm_out(const at::Tensor & self, const at::Scalar & p, at::Tensor & out); // {"schema": "aten::norm.Scalar_out(Tensor self, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & clone_out(const at::Tensor & self, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::clone.out(Tensor self, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & resize_as_out(const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format, const at::Tensor & out); // {"schema": "aten::resize_as.out(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resize_as(const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format); // {"schema": "aten::resize_as(Tensor self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor", "dispatch": "True", "default": "True"} +const at::Tensor & resize_as_sparse_out(const at::Tensor & self, const at::Tensor & the_template, const at::Tensor & out); // {"schema": "aten::resize_as_sparse.out(Tensor self, Tensor the_template, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor resize_as_sparse(const at::Tensor & self, const at::Tensor & the_template); // {"schema": "aten::resize_as_sparse(Tensor self, Tensor the_template) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & zero_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::zero.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor zero(const at::Tensor & self); // {"schema": "aten::zero(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sub_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::sub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rsub_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::rsub.Tensor_out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & rsub_out(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::rsub.Scalar_out(Tensor self, Scalar other, Scalar alpha=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_addmm_out(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, at::Tensor & out); // {"schema": "aten::_sparse_addmm.out(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & sparse_coo_tensor_out(at::IntArrayRef size, at::Tensor & out); // {"schema": "aten::sparse_coo_tensor.size_out(int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_coo_tensor_with_dims_out(int64_t sparse_dim, int64_t dense_dim, at::IntArrayRef size, at::Tensor & out); // {"schema": "aten::_sparse_coo_tensor_with_dims.out(int sparse_dim, int dense_dim, int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_coo_tensor_with_dims_and_tensors_out(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional is_coalesced, at::Tensor & out); // {"schema": "aten::_sparse_coo_tensor_with_dims_and_tensors.out(int sparse_dim, int dense_dim, SymInt[] size, Tensor indices, Tensor values, *, bool? is_coalesced=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +const at::Tensor & sparse_resize_out(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out); // {"schema": "aten::sparse_resize.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor sparse_resize(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor", "dispatch": "True", "default": "True"} +const at::Tensor & sparse_resize_and_clear_out(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim, const at::Tensor & out); // {"schema": "aten::sparse_resize_and_clear.out(Tensor self, int[] size, int sparse_dim, int dense_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor sparse_resize_and_clear(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim); // {"schema": "aten::sparse_resize_and_clear(Tensor self, int[] size, int sparse_dim, int dense_dim) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & sparse_mask_out(const at::Tensor & self, const at::Tensor & mask, at::Tensor & out); // {"schema": "aten::sparse_mask.out(Tensor self, Tensor mask, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_mask_projection_out(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches, at::Tensor & out); // {"schema": "aten::_sparse_mask_projection.out(Tensor self, Tensor mask, bool accumulate_matches=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_dense_out(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad, at::Tensor & out); // {"schema": "aten::_to_dense.out(Tensor self, ScalarType? dtype=None, bool? masked_grad=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _coalesce_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_coalesce.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _coalesced_out(const at::Tensor & self, bool coalesced, at::Tensor & out); // {"schema": "aten::_coalesced.out(Tensor self, bool coalesced, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor _coalesced(const at::Tensor & self, bool coalesced); // {"schema": "aten::_coalesced(Tensor self, bool coalesced) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & copy_sparse_to_sparse_out(const at::Tensor & self, const at::Tensor & src, bool non_blocking, at::Tensor & out); // {"schema": "aten::copy_sparse_to_sparse.out(Tensor self, Tensor src, bool non_blocking=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor copy_sparse_to_sparse(const at::Tensor & self, const at::Tensor & src, bool non_blocking); // {"schema": "aten::copy_sparse_to_sparse(Tensor self, Tensor src, bool non_blocking=False) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_out(const at::Tensor & self, int64_t sparse_dim, at::Tensor & out); // {"schema": "aten::_to_sparse.sparse_dim_out(Tensor self, int sparse_dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_out(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse.out(Tensor self, *, Layout? layout=None, int[2]? blocksize=None, int? dense_dim=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_csr_out(const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_csr.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_csc_out(const at::Tensor & self, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_csc.out(Tensor self, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_bsr_out(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_bsr.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _to_sparse_bsc_out(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim, at::Tensor & out); // {"schema": "aten::_to_sparse_bsc.out(Tensor self, int[2] blocksize, int? dense_dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & to_mkldnn_out(const at::Tensor & self, ::std::optional dtype, at::Tensor & out); // {"schema": "aten::to_mkldnn.out(Tensor self, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_reorder_conv2d_weight_out(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out); // {"schema": "aten::mkldnn_reorder_conv2d_weight.out(Tensor self, SymInt[2] padding=0, SymInt[2] stride=1, SymInt[2] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_reorder_conv3d_weight_out(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size, at::Tensor & out); // {"schema": "aten::mkldnn_reorder_conv3d_weight.out(Tensor self, SymInt[3] padding=0, SymInt[3] stride=1, SymInt[3] dilation=1, SymInt groups=1, SymInt[]? input_size=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_tensor_dynamic_out(const at::Tensor & self, at::ScalarType dtype, bool reduce_range, at::Tensor & out); // {"schema": "aten::quantize_per_tensor_dynamic.out(Tensor self, ScalarType dtype, bool reduce_range, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_tensor_out(const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::quantize_per_tensor.out(Tensor self, float scale, int zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_tensor_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::quantize_per_tensor.tensor_qparams_out(Tensor self, Tensor scale, Tensor zero_point, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void quantize_per_tensor_out(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype, at::TensorList out); // {"schema": "aten::quantize_per_tensor.tensors_out(Tensor[] tensors, Tensor scales, Tensor zero_points, ScalarType dtype, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & quantize_per_channel_out(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::quantize_per_channel.out(Tensor self, Tensor scales, Tensor zero_points, int axis, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & dequantize_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::dequantize.self_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void dequantize_out(at::TensorList tensors, at::TensorList out); // {"schema": "aten::dequantize.tensors_out(Tensor[] tensors, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & q_per_channel_scales_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::q_per_channel_scales.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & q_per_channel_zero_points_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::q_per_channel_zero_points.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & int_repr_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::int_repr.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _make_per_tensor_quantized_tensor_out(const at::Tensor & self, double scale, int64_t zero_point, at::Tensor & out); // {"schema": "aten::_make_per_tensor_quantized_tensor.out(Tensor self, float scale, int zero_point, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _make_per_channel_quantized_tensor_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, at::Tensor & out); // {"schema": "aten::_make_per_channel_quantized_tensor.out(Tensor self, Tensor scale, Tensor zero_point, int axis, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple fake_quantize_per_tensor_affine_cachemask_out(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::fake_quantize_per_tensor_affine_cachemask.out(Tensor self, float scale, int zero_point, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams.out(Tensor self, Tensor scale, Tensor zero_point, Tensor fake_quant_enabled, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _fake_quantize_learnable_per_tensor_affine_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out); // {"schema": "aten::_fake_quantize_learnable_per_tensor_affine.out(Tensor self, Tensor scale, Tensor zero_point, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple fake_quantize_per_channel_affine_cachemask_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::fake_quantize_per_channel_affine_cachemask.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _fake_quantize_learnable_per_channel_affine_out(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor, at::Tensor & out); // {"schema": "aten::_fake_quantize_learnable_per_channel_affine.out(Tensor self, Tensor scale, Tensor zero_point, int axis, int quant_min, int quant_max, float grad_factor=1.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _fused_moving_avg_obs_fq_helper_out(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, at::Tensor & running_min, at::Tensor & running_max, at::Tensor & scale, at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_fused_moving_avg_obs_fq_helper.out(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor(a!) running_min, Tensor(b!) running_max, Tensor(c!) scale, Tensor(d!) zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False, *, Tensor(e!) out0, Tensor(f!) out1) -> (Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "True"} +::std::tuple _fused_moving_avg_obs_fq_helper_functional(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, const at::Tensor & running_min, const at::Tensor & running_max, const at::Tensor & scale, const at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant); // {"schema": "aten::_fused_moving_avg_obs_fq_helper_functional(Tensor self, Tensor observer_on, Tensor fake_quant_on, Tensor running_min, Tensor running_max, Tensor scale, Tensor zero_point, float averaging_const, int quant_min, int quant_max, int ch_axis, bool per_row_fake_quant=False, bool symmetric_quant=False) -> (Tensor output, Tensor mask, Tensor running_min_out, Tensor running_max_out, Tensor scale_out, Tensor zero_point_out)", "dispatch": "True", "default": "True"} +at::Tensor & _to_copy_out(const at::Tensor & self, bool non_blocking, ::std::optional memory_format, at::Tensor & out); // {"schema": "aten::_to_copy.out(Tensor self, *, bool non_blocking=False, MemoryFormat? memory_format=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _lstm_mps_out(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4, at::Tensor & out5); // {"schema": "aten::_lstm_mps.out(Tensor input, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4, Tensor(f!) out5) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!), Tensor(f!))", "dispatch": "True", "default": "True"} +void lstm_mps_backward_out(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::TensorList out1, at::TensorList out2); // {"schema": "aten::lstm_mps_backward.out(Tensor? grad_y, Tensor? grad_hy, Tensor? grad_cy, Tensor z_state, Tensor cell_state_fwd, Tensor input, Tensor layersOutputs, Tensor[] hx, Tensor[] params, bool has_biases, int num_layers, float dropout, bool train, bool bidirectional, bool batch_first, *, Tensor(a!) out0, Tensor(b!)[] out1, Tensor(c!)[] out2) -> ()", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_lstm_cell_out(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_thnn_fused_lstm_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor cx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_lstm_cell_backward_impl_out(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_thnn_fused_lstm_cell_backward_impl.out(Tensor? grad_hy, Tensor? grad_cy, Tensor cx, Tensor cy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_gru_cell_out(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_thnn_fused_gru_cell.out(Tensor input_gates, Tensor hidden_gates, Tensor hx, Tensor? input_bias=None, Tensor? hidden_bias=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +::std::tuple _thnn_fused_gru_cell_backward_out(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::Tensor & out3, at::Tensor & out4); // {"schema": "aten::_thnn_fused_gru_cell_backward.out(Tensor grad_hy, Tensor workspace, bool has_bias, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2, Tensor(d!) out3, Tensor(e!) out4) -> (Tensor(a!), Tensor(b!), Tensor(c!), Tensor(d!), Tensor(e!))", "dispatch": "True", "default": "True"} +::std::tuple _pack_padded_sequence_out(const at::Tensor & input, const at::Tensor & lengths, bool batch_first, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_pack_padded_sequence.out(Tensor input, Tensor lengths, bool batch_first, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, at::Storage source, at::Tensor & out); // {"schema": "aten::set.source_Storage_out(Tensor self, Storage source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self, at::Storage source); // {"schema": "aten::set.source_Storage(Tensor self, Storage source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::set.source_Storage_storage_offset_out(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[], *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride); // {"schema": "aten::set.source_Storage_storage_offset(Tensor self, Storage source, SymInt storage_offset, SymInt[] size, SymInt[] stride=[]) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, const at::Tensor & source, at::Tensor & out); // {"schema": "aten::set.source_Tensor_out(Tensor self, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self, const at::Tensor & source); // {"schema": "aten::set.source_Tensor(Tensor self, Tensor source) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & set_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::set.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor set(const at::Tensor & self); // {"schema": "aten::set(Tensor self) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & lift_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::lift.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & lift_fresh_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::lift_fresh_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & masked_fill_out(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::masked_fill.Scalar_out(Tensor self, Tensor mask, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & masked_fill_out(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value, at::Tensor & out); // {"schema": "aten::masked_fill.Tensor_out(Tensor self, Tensor mask, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & masked_scatter_out(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source, at::Tensor & out); // {"schema": "aten::masked_scatter.out(Tensor self, Tensor mask, Tensor source, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _masked_softmax_out(const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type, at::Tensor & out); // {"schema": "aten::_masked_softmax.out(Tensor self, Tensor mask, int? dim=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _masked_softmax_backward_out(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim, at::Tensor & out); // {"schema": "aten::_masked_softmax_backward.out(Tensor grad_output, Tensor output, Tensor mask, int? dim=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & put_out(const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate, at::Tensor & out); // {"schema": "aten::put.out(Tensor self, Tensor index, Tensor source, bool accumulate=False, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, at::Tensor & out); // {"schema": "aten::index_fill.int_Scalar_out(Tensor self, int dim, Tensor index, Scalar value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & index_fill_out(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value, at::Tensor & out); // {"schema": "aten::index_fill.int_Tensor_out(Tensor self, int dim, Tensor index, Tensor value, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_and_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_and.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_or_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_or.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_xor_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_xor.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __lshift___out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::__lshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __lshift___out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::__lshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_left_shift_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_left_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __rshift___out(const at::Tensor & self, const at::Scalar & other, at::Tensor & out); // {"schema": "aten::__rshift__.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & __rshift___out(const at::Tensor & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::__rshift__.Tensor_out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & bitwise_right_shift_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::bitwise_right_shift.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & random_out(const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator, at::Tensor & out); // {"schema": "aten::random.from_out(Tensor self, int from, int? to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor random(const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator); // {"schema": "aten::random.from(Tensor self, int from, int? to, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & random_out(const at::Tensor & self, int64_t to, ::std::optional generator, at::Tensor & out); // {"schema": "aten::random.to_out(Tensor self, int to, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor random(const at::Tensor & self, int64_t to, ::std::optional generator); // {"schema": "aten::random.to(Tensor self, int to, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & random_out(const at::Tensor & self, ::std::optional generator, at::Tensor & out); // {"schema": "aten::random.out(Tensor self, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor random(const at::Tensor & self, ::std::optional generator); // {"schema": "aten::random(Tensor self, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & uniform_out(const at::Tensor & self, double from, double to, ::std::optional generator, at::Tensor & out); // {"schema": "aten::uniform.out(Tensor self, float from=0, float to=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor uniform(const at::Tensor & self, double from, double to, ::std::optional generator); // {"schema": "aten::uniform(Tensor self, float from=0, float to=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & cauchy_out(const at::Tensor & self, double median, double sigma, ::std::optional generator, at::Tensor & out); // {"schema": "aten::cauchy.out(Tensor self, float median=0, float sigma=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor cauchy(const at::Tensor & self, double median, double sigma, ::std::optional generator); // {"schema": "aten::cauchy(Tensor self, float median=0, float sigma=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & log_normal_out(const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::log_normal.out(Tensor self, float mean=1, float std=2, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor log_normal(const at::Tensor & self, double mean, double std, ::std::optional generator); // {"schema": "aten::log_normal(Tensor self, float mean=1, float std=2, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & exponential_out(const at::Tensor & self, double lambd, ::std::optional generator, at::Tensor & out); // {"schema": "aten::exponential.out(Tensor self, float lambd=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor exponential(const at::Tensor & self, double lambd, ::std::optional generator); // {"schema": "aten::exponential(Tensor self, float lambd=1, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & geometric_out(const at::Tensor & self, double p, ::std::optional generator, at::Tensor & out); // {"schema": "aten::geometric.out(Tensor self, float p, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor geometric(const at::Tensor & self, double p, ::std::optional generator); // {"schema": "aten::geometric(Tensor self, float p, *, Generator? generator=None) -> Tensor", "dispatch": "True", "default": "True"} +at::Tensor & tril_indices_out(int64_t row, int64_t col, int64_t offset, at::Tensor & out); // {"schema": "aten::tril_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & triu_indices_out(int64_t row, int64_t col, int64_t offset, at::Tensor & out); // {"schema": "aten::triu_indices.out(int row, int col, int offset=0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & trace_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::trace.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _cholesky_solve_helper_out(const at::Tensor & self, const at::Tensor & A, bool upper, at::Tensor & out); // {"schema": "aten::_cholesky_solve_helper.out(Tensor self, Tensor A, bool upper, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & dist_out(const at::Tensor & self, const at::Tensor & other, const at::Scalar & p, at::Tensor & out); // {"schema": "aten::dist.out(Tensor self, Tensor other, Scalar p=2, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void _histogramdd_bin_edges_out(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::TensorList out); // {"schema": "aten::_histogramdd_bin_edges.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & _histogramdd_from_bin_cts_out(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::Tensor & out); // {"schema": "aten::_histogramdd_from_bin_cts.out(Tensor self, int[] bins, *, float[]? range=None, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _histogramdd_from_bin_tensors_out(const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density, at::Tensor & out); // {"schema": "aten::_histogramdd_from_bin_tensors.out(Tensor self, Tensor[] bins, *, Tensor? weight=None, bool density=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & remainder_out(const at::Scalar & self, const at::Tensor & other, at::Tensor & out); // {"schema": "aten::remainder.Scalar_Tensor_out(Scalar self, Tensor other, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & unfold_backward_out(const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step, at::Tensor & out); // {"schema": "aten::unfold_backward.out(Tensor grad_in, SymInt[] input_sizes, int dim, int size, int step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & normal_out(const at::Tensor & self, double mean, double std, ::std::optional generator, at::Tensor & out); // {"schema": "aten::normal.out(Tensor self, float mean=0, float std=1, *, Generator? generator=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void _amp_foreach_non_finite_check_and_unscale_out(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale.out(Tensor[] self, Tensor(b!) found_inf, Tensor inv_scale, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,at::Tensor> _amp_foreach_non_finite_check_and_unscale(at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale); // {"schema": "aten::_amp_foreach_non_finite_check_and_unscale(Tensor[] self, Tensor found_inf, Tensor inv_scale) -> (Tensor[] self_out, Tensor found_inf_out)", "dispatch": "True", "default": "True"} +at::Tensor & _amp_update_scale_out(const at::Tensor & self, at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval, at::Tensor & out); // {"schema": "aten::_amp_update_scale.out(Tensor self, Tensor(b!) growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _amp_update_scale(const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval); // {"schema": "aten::_amp_update_scale(Tensor self, Tensor growth_tracker, Tensor found_inf, float scale_growth_factor, float scale_backoff_factor, int growth_interval) -> (Tensor, Tensor growth_tracker_out)", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_add.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out); // {"schema": "aten::_foreach_add.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_add.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_add_out(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha, at::TensorList out); // {"schema": "aten::_foreach_add.Tensor_out(Tensor[] self, Tensor other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sub_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_sub.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sub_out(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out); // {"schema": "aten::_foreach_sub.List_out(Tensor[] self, Tensor[] other, *, Scalar alpha=1, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sub_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_sub.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_mul.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_mul.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_mul.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_mul_out(at::TensorList self, const at::Tensor & other, at::TensorList out); // {"schema": "aten::_foreach_mul.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_div.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_div.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_div.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_div_out(at::TensorList self, const at::Tensor & other, at::TensorList out); // {"schema": "aten::_foreach_div.Tensor_out(Tensor[] self, Tensor other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_clamp_max.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_clamp_max.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_max_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_clamp_max.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_clamp_min.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_clamp_min.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_clamp_min_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_clamp_min.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_maximum_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_maximum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_maximum_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_maximum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_maximum_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_maximum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_minimum_out(at::TensorList self, const at::Scalar & scalar, at::TensorList out); // {"schema": "aten::_foreach_minimum.Scalar_out(Tensor[] self, Scalar scalar, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_minimum_out(at::TensorList self, at::TensorList other, at::TensorList out); // {"schema": "aten::_foreach_minimum.List_out(Tensor[] self, Tensor[] other, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_minimum_out(at::TensorList self, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_minimum.ScalarList_out(Tensor[] self, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out); // {"schema": "aten::_foreach_addcdiv.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_addcdiv.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcdiv_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out); // {"schema": "aten::_foreach_addcdiv.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out); // {"schema": "aten::_foreach_addcmul.Scalar_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar value=1, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out); // {"schema": "aten::_foreach_addcmul.ScalarList_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Scalar[] scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_addcmul_out(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out); // {"schema": "aten::_foreach_addcmul.Tensor_out(Tensor[] self, Tensor[] tensor1, Tensor[] tensor2, Tensor scalars, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_abs_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_abs.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_acos_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_acos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_asin_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_asin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_atan_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_atan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_ceil_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_ceil.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_cos_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_cos.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_cosh_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_cosh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_erf_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_erf.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_erfc_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_erfc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_exp_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_exp.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_expm1_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_expm1.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_floor_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_floor.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_frac_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_frac.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lerp_out(at::TensorList self, at::TensorList tensors1, at::TensorList weights, at::TensorList out); // {"schema": "aten::_foreach_lerp.List_out(Tensor[] self, Tensor[] tensors1, Tensor[] weights, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lerp_out(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight, at::TensorList out); // {"schema": "aten::_foreach_lerp.Scalar_out(Tensor[] self, Tensor[] tensors1, Scalar weight, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lerp_out(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight, at::TensorList out); // {"schema": "aten::_foreach_lerp.ScalarList_out(Tensor[] self, Tensor[] tensors1, Scalar[] weight, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_lgamma_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_lgamma.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log10_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log10.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log1p_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log1p.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_log2_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_log2.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_max_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_max.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_neg_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_neg.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_norm_out(at::TensorList self, const at::Scalar & ord, ::std::optional dtype, at::TensorList out); // {"schema": "aten::_foreach_norm.Scalar_out(Tensor[] self, Scalar ord=2, ScalarType? dtype=None, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_out(at::TensorList self, at::TensorList exponent, at::TensorList out); // {"schema": "aten::_foreach_pow.List_out(Tensor[] self, Tensor[] exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_out(at::TensorList self, const at::Scalar & exponent, at::TensorList out); // {"schema": "aten::_foreach_pow.Scalar_out(Tensor[] self, Scalar exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_pow_out(at::TensorList self, at::ArrayRef exponent, at::TensorList out); // {"schema": "aten::_foreach_pow.ScalarList_out(Tensor[] self, Scalar[] exponent, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_reciprocal_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_reciprocal.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_round_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_round.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_rsqrt_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_rsqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sigmoid_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sigmoid.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sign_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sign.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sin_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sin.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sinh_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sinh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_sqrt_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_sqrt.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_tan_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_tan.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_tanh_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_tanh.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_trunc_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_trunc.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +void _foreach_zero_out(at::TensorList self, at::TensorList out); // {"schema": "aten::_foreach_zero.out(Tensor[] self, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::vector _foreach_zero(at::TensorList self); // {"schema": "aten::_foreach_zero(Tensor[] self) -> Tensor[] self_out", "dispatch": "True", "default": "True"} +void _foreach_copy_out(at::TensorList self, at::TensorList src, bool non_blocking, at::TensorList out); // {"schema": "aten::_foreach_copy.out(Tensor[] self, Tensor[] src, bool non_blocking=False, *, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +at::Tensor & bucketize_out(const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right, at::Tensor & out); // {"schema": "aten::bucketize.Scalar_out(Scalar self, Tensor boundaries, *, bool out_int32=False, bool right=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & glu_jvp_out(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim, at::Tensor & out); // {"schema": "aten::glu_jvp.out(Tensor glu, Tensor x, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & glu_backward_jvp_out(const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim, at::Tensor & out); // {"schema": "aten::glu_backward_jvp.out(Tensor grad_x, Tensor grad_glu, Tensor x, Tensor dgrad_glu, Tensor dx, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & hardswish_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::hardswish_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple rrelu_with_noise_functional(const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator); // {"schema": "aten::rrelu_with_noise_functional(Tensor self, Tensor noise, Scalar lower=0.125, Scalar upper=0.3333333333333333, bool training=False, Generator? generator=None) -> (Tensor, Tensor noise_out)", "dispatch": "True", "default": "True"} +at::Tensor & rrelu_with_noise_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result, at::Tensor & out); // {"schema": "aten::rrelu_with_noise_backward.out(Tensor grad_output, Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, bool self_is_result, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & mkldnn_adaptive_avg_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::mkldnn_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool2d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool2d.out(Tensor self, SymInt[2] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool2d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool3d_out(const at::Tensor & self, c10::SymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool3d.out(Tensor self, SymInt[3] output_size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _adaptive_avg_pool3d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_adaptive_avg_pool3d_backward.out(Tensor grad_output, Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & upsample_bilinear2d_out(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors, at::Tensor & out); // {"schema": "aten::upsample_bilinear2d.vec_out(Tensor input, SymInt[]? output_size, bool align_corners, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & upsample_nearest2d_out(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors, at::Tensor & out); // {"schema": "aten::upsample_nearest2d.vec_out(Tensor input, SymInt[]? output_size, float[]? scale_factors, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _slow_conv2d_backward_out(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2); // {"schema": "aten::_slow_conv2d_backward.output_mask_out(Tensor grad_output, Tensor self, Tensor weight, SymInt[2] kernel_size, SymInt[2] stride, SymInt[2] padding, bool[3] output_mask, *, Tensor(a!) out0, Tensor(b!) out1, Tensor(c!) out2) -> (Tensor(a!), Tensor(b!), Tensor(c!))", "dispatch": "True", "default": "True"} +at::Tensor & conv_depthwise3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::conv_depthwise3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias, SymInt[3] stride, SymInt[3] padding, SymInt[3] dilation, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_dilated2d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_dilated2d.out(Tensor self, Tensor weight, SymInt[2] kernel_size, Tensor? bias=None, SymInt[2] stride=1, SymInt[2] padding=0, SymInt[2] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slow_conv_dilated3d_out(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, at::Tensor & out); // {"schema": "aten::slow_conv_dilated3d.out(Tensor self, Tensor weight, SymInt[3] kernel_size, Tensor? bias=None, SymInt[3] stride=1, SymInt[3] padding=0, SymInt[3] dilation=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & isinf_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::isinf.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & linalg_matrix_exp_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::linalg_matrix_exp.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_optional_intlist_out(const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out); // {"schema": "aten::_test_optional_intlist.out(Tensor values, int[]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_optional_filled_intlist_out(const at::Tensor & values, at::OptionalIntArrayRef addends, at::Tensor & out); // {"schema": "aten::_test_optional_filled_intlist.out(Tensor values, int[2]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_optional_floatlist_out(const at::Tensor & values, ::std::optional> addends, at::Tensor & out); // {"schema": "aten::_test_optional_floatlist.out(Tensor values, float[]? addends, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_warn_in_autograd_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_test_warn_in_autograd.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_autograd_multiple_dispatch_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_test_autograd_multiple_dispatch.fullcoverage_out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _test_autograd_multiple_dispatch_view_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_test_autograd_multiple_dispatch_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & segment_reduce_out(const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial, at::Tensor & out); // {"schema": "aten::segment_reduce.out(Tensor data, str reduce, *, Tensor? lengths=None, Tensor? indices=None, Tensor? offsets=None, int axis=0, bool unsafe=False, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _segment_reduce_backward_out(const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & offsets, int64_t axis, const ::std::optional & initial, at::Tensor & out); // {"schema": "aten::_segment_reduce_backward.out(Tensor grad, Tensor output, Tensor data, str reduce, *, Tensor? lengths=None, Tensor? offsets=None, int axis=0, Scalar? initial=None, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _nested_tensor_from_tensor_list_out(at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, at::Tensor & out); // {"schema": "aten::_nested_tensor_from_tensor_list.out(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _fw_primal_copy_out(const at::Tensor & self, int64_t level, at::Tensor & out); // {"schema": "aten::_fw_primal_copy.out(Tensor self, int level, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _make_dual_copy_out(const at::Tensor & primal, const at::Tensor & tangent, int64_t level, at::Tensor & out); // {"schema": "aten::_make_dual_copy.out(Tensor primal, Tensor tangent, int level, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_as_real_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::view_as_real_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_as_complex_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::view_as_complex_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _conj_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_conj_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _neg_view_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_neg_view_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & as_strided_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset, at::Tensor & out); // {"schema": "aten::as_strided_copy.out(Tensor self, SymInt[] size, SymInt[] stride, SymInt? storage_offset=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _sparse_broadcast_to_copy_out(const at::Tensor & self, at::IntArrayRef size, at::Tensor & out); // {"schema": "aten::_sparse_broadcast_to_copy.out(Tensor self, int[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & diagonal_copy_out(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2, at::Tensor & out); // {"schema": "aten::diagonal_copy.out(Tensor self, int offset=0, int dim1=0, int dim2=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & expand_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit, at::Tensor & out); // {"schema": "aten::expand_copy.out(Tensor self, SymInt[] size, *, bool implicit=False, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & permute_copy_out(const at::Tensor & self, at::IntArrayRef dims, at::Tensor & out); // {"schema": "aten::permute_copy.out(Tensor self, int[] dims, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _reshape_alias_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, at::Tensor & out); // {"schema": "aten::_reshape_alias_copy.out(Tensor self, SymInt[] size, SymInt[] stride, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & select_copy_out(const at::Tensor & self, int64_t dim, c10::SymInt index, at::Tensor & out); // {"schema": "aten::select_copy.int_out(Tensor self, int dim, SymInt index, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & detach_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::detach_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & slice_copy_out(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step, at::Tensor & out); // {"schema": "aten::slice_copy.Tensor_out(Tensor self, int dim=0, SymInt? start=None, SymInt? end=None, SymInt step=1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::squeeze_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_copy_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::squeeze_copy.dim_out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & squeeze_copy_out(const at::Tensor & self, at::IntArrayRef dim, at::Tensor & out); // {"schema": "aten::squeeze_copy.dims_out(Tensor self, int[] dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & t_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::t_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & transpose_copy_out(const at::Tensor & self, int64_t dim0, int64_t dim1, at::Tensor & out); // {"schema": "aten::transpose_copy.int_out(Tensor self, int dim0, int dim1, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & unsqueeze_copy_out(const at::Tensor & self, int64_t dim, at::Tensor & out); // {"schema": "aten::unsqueeze_copy.out(Tensor self, int dim, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _values_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::_values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & values_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::values_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & crow_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::crow_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & col_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::col_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & ccol_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::ccol_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & row_indices_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::row_indices_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_copy_out(const at::Tensor & self, c10::SymIntArrayRef size, at::Tensor & out); // {"schema": "aten::view_copy.out(Tensor self, SymInt[] size, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & view_copy_out(const at::Tensor & self, at::ScalarType dtype, at::Tensor & out); // {"schema": "aten::view_copy.dtype_out(Tensor self, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & unfold_copy_out(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step, at::Tensor & out); // {"schema": "aten::unfold_copy.out(Tensor self, int dimension, int size, int step, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & alias_copy_out(const at::Tensor & self, at::Tensor & out); // {"schema": "aten::alias_copy.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & to_padded_tensor_out(const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size, at::Tensor & out); // {"schema": "aten::to_padded_tensor.out(Tensor self, float padding, SymInt[]? output_size=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _transformer_encoder_layer_fwd_out(const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type, at::Tensor & out); // {"schema": "aten::_transformer_encoder_layer_fwd.out(Tensor src, int embed_dim, int num_heads, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, bool use_gelu, bool norm_first, float eps, Tensor norm_weight_1, Tensor norm_bias_1, Tensor norm_weight_2, Tensor norm_bias_2, Tensor ffn_weight_1, Tensor ffn_bias_1, Tensor ffn_weight_2, Tensor ffn_bias_2, Tensor? mask=None, int? mask_type=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +::std::tuple _native_multi_head_attention_out(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type, at::Tensor & out0, at::Tensor & out1); // {"schema": "aten::_native_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, bool need_weights=True, bool average_attn_weights=True, int? mask_type=None, *, Tensor(a!) out0, Tensor(b!) out1) -> (Tensor(a!), Tensor(b!))", "dispatch": "True", "default": "True"} +at::Tensor & _triton_scaled_dot_attention_out(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p, at::Tensor & out); // {"schema": "aten::_triton_scaled_dot_attention.out(Tensor q, Tensor k, Tensor v, float dropout_p=0.0, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _triton_multi_head_attention_out(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, at::Tensor & out); // {"schema": "aten::_triton_multi_head_attention.out(Tensor query, Tensor key, Tensor value, int embed_dim, int num_head, Tensor qkv_weight, Tensor qkv_bias, Tensor proj_weight, Tensor proj_bias, Tensor? mask=None, *, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +at::Tensor & _foobar_out(const at::Tensor & self, bool arg1, bool arg2, bool arg3, at::Tensor & out); // {"schema": "aten::_foobar.out(Tensor self, bool arg1=True, bool arg2=True, *, bool arg3=True, Tensor(a!) out) -> Tensor(a!)", "dispatch": "True", "default": "True"} +void _fused_adam_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adam.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_adam_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adam.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adam.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_adamw_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adamw.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, float lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_adamw_out(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adamw.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] exp_avgs, Tensor(d!)[] exp_avg_sqs, Tensor(e!)[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adamw.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] exp_avgs, Tensor[] exp_avg_sqs, Tensor[] max_exp_avg_sqs, Tensor[] state_steps, *, Tensor lr, float beta1, float beta2, float weight_decay, float eps, bool amsgrad, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] exp_avgs_out, Tensor[] exp_avg_sqs_out, Tensor[] max_exp_avg_sqs_out)", "dispatch": "True", "default": "True"} +void _fused_sgd_out(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_sgd.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, float lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)", "dispatch": "True", "default": "True"} +void _fused_sgd_out(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_sgd.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_sgd.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] momentum_buffer_list, *, float weight_decay, float momentum, Tensor lr, float dampening, bool nesterov, bool maximize, bool is_first_step, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] momentum_buffer_list_out)", "dispatch": "True", "default": "True"} +void _fused_adagrad_out(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adagrad.out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor(d!)[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector> _fused_adagrad(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, float lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out, Tensor[] state_steps_out)", "dispatch": "True", "default": "True"} +void _fused_adagrad_out(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out); // {"schema": "aten::_fused_adagrad.tensor_lr_out(Tensor[] self, Tensor(b!)[] grads, Tensor(c!)[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None, Tensor(a!)[] out) -> ()", "dispatch": "True", "default": "True"} +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_adagrad(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf); // {"schema": "aten::_fused_adagrad.tensor_lr(Tensor[] self, Tensor[] grads, Tensor[] state_sums, Tensor[] state_steps, *, Tensor lr, float lr_decay, float weight_decay, float eps, bool maximize, Tensor? grad_scale=None, Tensor? found_inf=None) -> (Tensor[] self_out, Tensor[] grads_out, Tensor[] state_sums_out)", "dispatch": "True", "default": "True"} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/SDPBackend.h b/phivenv/Lib/site-packages/torch/include/ATen/SDPBackend.h new file mode 100644 index 0000000000000000000000000000000000000000..543806f996460e5f83bfffaabdbe9fe1e2aebcac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/SDPBackend.h @@ -0,0 +1,16 @@ +#pragma once +#include + +namespace at { + +constexpr int32_t num_sdp_backends = 5; +enum class SDPBackend { + error = -1, + math = 0, + flash_attention = 1, + efficient_attention = 2, + cudnn_attention = 3, + overrideable = 4 +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/SavedTensorHooks.h b/phivenv/Lib/site-packages/torch/include/ATen/SavedTensorHooks.h new file mode 100644 index 0000000000000000000000000000000000000000..7cb6da7b1fc53c50f34602eca4ed008e3021fc88 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/SavedTensorHooks.h @@ -0,0 +1,68 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include + +namespace at { + +namespace impl { + +struct TORCH_API SavedTensorDefaultHooksTLS { + // PyObject is defined in c10/util/python_stub.h + std::stack> stack; + + // See NOTE: [Disabling SavedTensorDefaultHooks] for context + // NOTE: [disabled_error_message invariant] + // disabled_error_message is nullopt IFF Saved Tensor hooks is enabled + // We did this for efficiency (so we didn't have to keep a separate bool + // around) + std::optional disabled_error_message; + + // See NOTE: [Deferring tensor pack/unpack hooks until runtime] + bool is_tracing = false; +}; + +} // namespace impl + +struct TORCH_API SavedTensorDefaultHooks { + static void push_hooks( + c10::SafePyObject pack_hook, + c10::SafePyObject unpack_hook); + static std::pair pop_hooks(); + static std::optional> + get_hooks(bool ignore_is_tracing = false); + static void lazy_initialize(); + + static const impl::SavedTensorDefaultHooksTLS& get_tls_state(); + static void set_tls_state(const impl::SavedTensorDefaultHooksTLS& tls); + + // NOTE: [Disabling SavedTensorDefaultHooks] + // A developer of a PyTorch feature may choose to disable SavedTensorDefault + // hooks, especially if their feature does not work with it. If they are + // disabled, then the following will raise an error: + // - Attempting to push_hooks + // - calling disable(message) with a non-zero stack (hooks) size + static void disable( + const std::string& error_message, + const bool fail_if_non_empty = true); + static void enable(); + static bool is_enabled(); + static const std::optional& get_disabled_error_message(); + + // NOTE: [Deferring tensor pack/unpack hooks until runtime] + // To preserve eager semantics of pack/unpack hooks firing only once per saved + // variable, Dynamo/AOTAutograd need to defer hook firing until runtime. Using + // disable() would loud error at trace time, and pushing a no-op hook would + // fail when the traced code is wrapped in a disable_saved_tensors_hooks ctx. + // To do so, we disable these hooks during tracing. See + // https://github.com/pytorch/pytorch/issues/113263. + static bool set_tracing(bool is_tracing); +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Scalar.h b/phivenv/Lib/site-packages/torch/include/ATen/Scalar.h new file mode 100644 index 0000000000000000000000000000000000000000..6dec39dd3c32cef073fec4891ab16a71c58e8077 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Scalar.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ScalarOps.h b/phivenv/Lib/site-packages/torch/include/ATen/ScalarOps.h new file mode 100644 index 0000000000000000000000000000000000000000..7c1b1306673b7564e5bb441ca7d8554c04e4dcd1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ScalarOps.h @@ -0,0 +1,53 @@ +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::detail { +// When filling a number to 1-element CPU tensor, we want to skip +// everything but manipulate data ptr directly. +// Ideally this fast pass should be implemented in TensorIterator, +// but we also want to skip compute_types which in not avoidable +// in TensorIterator for now. +Tensor& scalar_fill(Tensor& self, const Scalar& value); +TORCH_API Tensor scalar_tensor_static( + const Scalar& s, + std::optional dtype_opt, + std::optional device_opt); +} // namespace at::detail + +// This is in the c10 namespace because we use ADL to find the functions in it. +namespace c10 { + +// FIXME: this should be (and was) Scalar::toTensor, but there is currently no +// way to implement this without going through Derived Types (which are not part +// of core). +inline at::Tensor scalar_to_tensor( + const Scalar& s, + const Device device = at::kCPU) { + // This is the fast track we have for CPU scalar tensors. + if (device == at::kCPU) { + return at::detail::scalar_tensor_static(s, s.type(), at::kCPU); + } + return at::scalar_tensor(s, at::device(device).dtype(s.type())); +} + +} // namespace c10 + +namespace at::native { + +inline Tensor wrapped_scalar_tensor( + const Scalar& scalar, + const Device device = at::kCPU) { + auto tensor = scalar_to_tensor(scalar, device); + tensor.unsafeGetTensorImpl()->set_wrapped_number(true); + return tensor; +} + +} // namespace at::native diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ScalarType.h b/phivenv/Lib/site-packages/torch/include/ATen/ScalarType.h new file mode 100644 index 0000000000000000000000000000000000000000..022ca42e17b1bf75d00c0a9df76b59ac3a600cdd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ScalarType.h @@ -0,0 +1,4 @@ +#pragma once +#include // for BC reasons +#include +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/SequenceNumber.h b/phivenv/Lib/site-packages/torch/include/ATen/SequenceNumber.h new file mode 100644 index 0000000000000000000000000000000000000000..fc9ef214608d0e9858f414924ff8c1a56a38f7e9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/SequenceNumber.h @@ -0,0 +1,13 @@ +#pragma once + +#include +#include + +// A simple thread local enumeration, used to link forward and backward pass +// ops and is used by autograd and observers framework +namespace at::sequence_number { + +TORCH_API uint64_t peek(); +TORCH_API uint64_t get_and_increment(); + +} // namespace at::sequence_number diff --git a/phivenv/Lib/site-packages/torch/include/ATen/SmallVector.h b/phivenv/Lib/site-packages/torch/include/ATen/SmallVector.h new file mode 100644 index 0000000000000000000000000000000000000000..aa8a0558f671095399ca3926d72adb68d460429b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/SmallVector.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/SparseCsrTensorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/SparseCsrTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..5e257ca0ec6f3d3d78e74a63b9b9623316f80117 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/SparseCsrTensorImpl.h @@ -0,0 +1,206 @@ +#pragma once + +#include +#include +#include +#include +namespace at { + +// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for +// denoting the data: `crow_indices_`, `col_indices_` and `values_`. +// The `crow_indices_` tensor is a integer tensor of shape `(size(0) + 1)` +// that represents the compressed row indices of the CSR tensor. The +// `col_indices_` tensor is an integer tensor of shape `(nnz())` +// that explicitly stores the column indices of each value of the sparse +// tensor. The `values_` tensor can be of any pytorch-supported data type +// and has shape `(nnz())`. +// +// Since the main advantage of the CSR format over the COO format is speed of +// computation, care must be taken to facilitate smooth interfacing of +// these data structures with optimized libraries such as MKL and MAGMA. +// Since the MKL interface for pytorch currently uses indexing with int32 +// type, it is important to make sure that the `crow_indices` and `col_indices` +// are of type int32 when calling MKL routines such as SPMM or SPMV. +// +// If not calling MKL, it should be alright to use 64 bit integer tensors +// for indexing. +struct TORCH_API SparseCsrTensorImpl : public TensorImpl { + Tensor crow_indices_; + Tensor col_indices_; + Tensor values_; + Layout layout_; + + public: + explicit SparseCsrTensorImpl( + at::DispatchKeySet, + at::Device device, + Layout layout, + const caffe2::TypeMeta); + + void resize_(int64_t nnz, IntArrayRef size); + void resize_and_clear_( + int64_t sparse_dim, + int64_t dense_dim, + IntArrayRef size); + void resize_as_sparse_compressed_tensor_(const Tensor& src); + void set_member_tensors( + const Tensor& crow_indices, + const Tensor& col_indices, + const Tensor& values, + c10::SymIntArrayRef size); + void set_member_tensors( + const Tensor& crow_indices, + const Tensor& col_indices, + const Tensor& values, + IntArrayRef size); + const Tensor& compressed_indices() const { + return crow_indices_; + } + const Tensor& plain_indices() const { + return col_indices_; + } + const Tensor& values() const { + return values_; + } + int64_t nnz() { + return col_indices_.size(-1); + } + + inline int64_t batch_dim() const noexcept { + return crow_indices_.dim() - 1; + } + + inline int64_t sparse_dim() const noexcept { + return 2; + } + + inline int64_t dense_dim() const noexcept { + return values_.dim() - batch_dim() - block_dim() - 1; + } + + private: + inline int64_t block_dim() const noexcept { + return (layout_ == kSparseBsr || layout_ == kSparseBsc ? 2 : 0); + } + + protected: + IntArrayRef strides_custom() const override; + SymIntArrayRef sym_strides_custom() const override; + bool is_contiguous_custom(MemoryFormat) const override; + + public: + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; + Layout layout_impl() const override { + return layout_; + } + void set_layout(Layout layout) { + switch (layout) { + case kSparseCsr: + case kSparseCsc: + case kSparseBsr: + case kSparseBsc: + layout_ = layout; + break; + default: + TORCH_CHECK(false, "unsupported layout ", layout); + } + } + + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); + c10::impl::PyInterpreter&& interpreter = nullptr; + if (mode_stack_len > 0 && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + const auto& cur_torch_dispatch_mode_state = + c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); + interpreter = cur_torch_dispatch_mode_state->pyinterpreter(); + } else if ( + key_set_.has(DispatchKey::Python) && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + interpreter = pyobj_slot_.load_pyobj_interpreter(); + } else { + // otherwise just copy the SparseTensorImpl and not the PyObject. + auto impl = c10::make_intrusive( + key_set(), device(), layout_impl(), dtype()); + copy_tensor_metadata( + /*src_sparse_impl=*/this, + /*dest_sparse_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + auto r = interpreter->detach(this); + r->set_version_counter(std::forward(version_counter)); + r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return r; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + version_counter, allow_tensor_metadata_change); + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + std::move(version_counter), allow_tensor_metadata_change); + } + + private: + explicit SparseCsrTensorImpl( + at::DispatchKeySet key_set, + const caffe2::TypeMeta data_type, + at::Tensor crow_indices, + at::Tensor col_indices, + at::Tensor values, + at::Layout layout); + + const char* tensorimpl_type_name() const override; + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const SparseCsrTensorImpl* src_sparse_impl, + SparseCsrTensorImpl* dest_sparse_impl, + c10::VariableVersion version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_sparse_impl, + dest_sparse_impl, + std::move(version_counter), + allow_tensor_metadata_change); + + // Sparse-specific fields + dest_sparse_impl->crow_indices_ = src_sparse_impl->compressed_indices(); + dest_sparse_impl->col_indices_ = src_sparse_impl->plain_indices(); + dest_sparse_impl->values_ = src_sparse_impl->values(); + dest_sparse_impl->layout_ = src_sparse_impl->layout_impl(); + } +}; +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/SparseCsrTensorUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/SparseCsrTensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..2e518d3f2aba57ab09dc8973c32cb34442494f1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/SparseCsrTensorUtils.h @@ -0,0 +1,454 @@ +#pragma once + +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#include +#else +#include +#include +#endif + +#define AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS(LAYOUT, NAME, ...) \ + [&] { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseCsc: \ + case kSparseBsr: \ + case kSparseBsc: \ + return __VA_ARGS__(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, ROW_DIM_ACTION, COLUMN_DIM_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseBsr: \ + return (ROW_DIM_ACTION)(); \ + case kSparseCsc: \ + case kSparseBsc: \ + return (COLUMN_DIM_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_PLAIN_SPARSE_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, NO_BLOCK_ACTION, BLOCK_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseCsc: \ + return (NO_BLOCK_ACTION)(); \ + case kSparseBsr: \ + case kSparseBsc: \ + return (BLOCK_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_ROW_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, ROW_DIM_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseBsr: \ + return (ROW_DIM_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse row compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_COL_COMPRESSED_LAYOUTS( \ + LAYOUT, NAME, COL_DIM_ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsc: \ + case kSparseBsc: \ + return (COL_DIM_ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse column compressed tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_COMPRESSED_NONBLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseCsr: \ + case kSparseCsc: \ + return (ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed (non-block) tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_COMPRESSED_BLOCK_LAYOUTS(LAYOUT, NAME, ACTION) \ + [&]() { \ + const auto& the_layout = LAYOUT; \ + switch (the_layout) { \ + case kSparseBsr: \ + case kSparseBsc: \ + return (ACTION)(); \ + default: \ + TORCH_CHECK( \ + false, \ + NAME, \ + " expected sparse compressed block tensor layout but got ", \ + the_layout); \ + } \ + }() + +#define AT_DISPATCH_SPARSE_VALUE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH( \ + TYPE, \ + NAME, \ + AT_DISPATCH_CASE_ALL_TYPES_AND_COMPLEX_AND4( \ + kComplexHalf, kHalf, kBool, kBFloat16, __VA_ARGS__)) + +namespace at::sparse_csr { + +// Implements RAII object to manage checking sparse tensor invariants: +class CheckSparseTensorInvariants { + bool old_state; + + public: + CheckSparseTensorInvariants(bool state) + : old_state(at::globalContext().checkSparseTensorInvariants()) { + at::globalContext().setCheckSparseTensorInvariants(state); + } + CheckSparseTensorInvariants(CheckSparseTensorInvariants&& other) = delete; + CheckSparseTensorInvariants(const CheckSparseTensorInvariants&) = delete; + CheckSparseTensorInvariants& operator=(const CheckSparseTensorInvariants&) = + delete; + CheckSparseTensorInvariants& operator=(CheckSparseTensorInvariants&&) = + delete; + + ~CheckSparseTensorInvariants() { + at::globalContext().setCheckSparseTensorInvariants(old_state); + } +}; + +using SparseCsrTensor = Tensor; + +inline bool is_sparse_compressed(const Layout& layout) { + switch (layout) { + case kSparseCsr: + case kSparseCsc: + case kSparseBsr: + case kSparseBsc: + return true; + default:; + } + return false; +} + +inline bool is_sparse_compressed(const Tensor& self) { + return is_sparse_compressed(self.layout()); +} + +inline SparseCsrTensorImpl* get_sparse_csr_impl(const SparseCsrTensor& self) { + AT_DISPATCH_ALL_SPARSE_COMPRESSED_LAYOUTS( + self.layout(), "get_sparse_csr_impl", [&] {}); + return static_cast(self.unsafeGetTensorImpl()); +} + +inline std::string layoutToString( + Layout layout, + bool upper = false, + bool lower = false) { + switch (layout) { + case kSparseCsr: + return (upper ? "CSR" : (lower ? "csr" : "Csr")); + case kSparseCsc: + return (upper ? "CSC" : (lower ? "csc" : "Csc")); + case kSparseBsr: + return (upper ? "BSR" : (lower ? "bsr" : "Bsr")); + case kSparseBsc: + return (upper ? "BSC" : (lower ? "bsc" : "Bsc")); + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + +inline bool isCompressedRow(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, "isCompressedRow", [&] { return true; }, [&] { return false; }); +} + +inline bool isCompressedColumn(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, + "isCompressedColumn", + [&] { return false; }, + [&] { return true; }); +} + +inline std::string compressedIndicesName(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, + "compressedIndicesName", + [&] { return "crow_indices"; }, + [&] { return "ccol_indices"; }); +} + +inline std::string plainIndicesName(Layout layout) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + layout, + "plainIndicesName", + [&] { return "col_indices"; }, + [&] { return "row_indices"; }); +} + +inline std::string compressedDimName(Layout layout) { + switch (layout) { + case kSparseCsr: + return "row"; + case kSparseCsc: + return "column"; + case kSparseBsr: + return "row block"; + case kSparseBsc: + return "column block"; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + +inline std::string plainDimName(Layout layout) { + switch (layout) { + case kSparseCsr: + return "column"; + case kSparseCsc: + return "row"; + case kSparseBsr: + return "column block"; + case kSparseBsc: + return "row block"; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return ""; + } +} + +inline size_t rowDimension(Layout layout, IntArrayRef size) { + return size.size() - (isCompressedRow(layout) ? 2 : 1); +} + +inline size_t columnDimension(Layout layout, IntArrayRef size) { + return size.size() - (isCompressedColumn(layout) ? 2 : 1); +} + +inline size_t compressedDimension( + Layout layout, + IntArrayRef size, + size_t dense_ndim = 0) { + return size.size() - dense_ndim - (isCompressedRow(layout) ? 2 : 1); +} + +inline size_t plainDimension( + Layout layout, + IntArrayRef size, + size_t dense_ndim = 0) { + return size.size() - dense_ndim - (isCompressedRow(layout) ? 1 : 2); +} + +inline int64_t numBatchDimensions(Tensor const& self) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + self.layout(), + "numBatchDimensions", + [&self] { return self.crow_indices().dim() - 1; }, + [&self] { return self.ccol_indices().dim() - 1; }); +} + +inline std::pair getCompressedPlainIndices(Tensor const& self) { + return AT_DISPATCH_ROW_SPARSE_COMPRESSED_LAYOUTS( + self.layout(), + "getCompressedPlainIndices", + [&self] { + return std::make_pair(self.crow_indices(), self.col_indices()); + }, + [&self] { + return std::make_pair(self.ccol_indices(), self.row_indices()); + }); +} + +inline ScalarType getIndexDtype(Tensor const& self) { + switch (self.layout()) { + case kSparseCsr: + case kSparseBsr: + return self.crow_indices().scalar_type(); + case kSparseCsc: + case kSparseBsc: + return self.ccol_indices().scalar_type(); + case kSparse: + return self._indices().scalar_type(); + default: + return ScalarType::Long; + } +} + +inline Layout flip_compressed_layout(Layout layout) { + switch (layout) { + case kSparseCsr: + return kSparseCsc; + case kSparseCsc: + return kSparseCsr; + case kSparseBsr: + return kSparseBsc; + case kSparseBsc: + return kSparseBsr; + default: + TORCH_CHECK(false, "Not a sparse compressed layout:", layout); + return kSparseCsr; + } +} + +inline DimVector getBlockSize(Tensor const& self) { + int64_t n_batch = numBatchDimensions(self); + return at::DimVector(self.values().sizes().slice(n_batch + 1, 2)); +} + +inline at::OptionalArray getSymIntBlockSize(Tensor const& self) { + if (self.layout() == at::kSparseBsr || self.layout() == at::kSparseBsc) { + int64_t n_batch = numBatchDimensions(self); + return self.values().sym_sizes().slice(n_batch + 1, 2).vec(); + } else { + return {}; + } +} + +template +inline bool only_sparse_compressed_binary_op_trivial_cases( + const Tensor& self, + const Tensor& other, + const Scalar& alpha, + Tensor& out, + const binary_op_t& binary_op, + const binary_op_out_t& binary_op_out) { + // Only sparse compressed! Just like the name says :) + TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(self)); + TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(other)); + TORCH_INTERNAL_ASSERT(at::sparse_csr::is_sparse_compressed(out)); + + // Bypass BLAS if there are matches in (self, other, out) + if (self.is_same(out) && self.is_same(other)) { + binary_op_out(self.values(), other.values(), alpha); + return true; + } + if (self.is_same(other)) { + auto [compressed_indices, plain_indices] = + at::sparse_csr::getCompressedPlainIndices(self); + static_cast(out.unsafeGetTensorImpl()) + ->set_member_tensors( + compressed_indices, + plain_indices, + binary_op(self.values(), other.values(), alpha), + self.sizes()); + return true; + } + return false; +} + +inline bool only_sparse_compressed_add_trivial_cases( + const Tensor& self, + const Tensor& other, + const Scalar& alpha, + Tensor& out) { + return only_sparse_compressed_binary_op_trivial_cases( + self, + other, + alpha, + out, + [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) { + return v1.add(v2, alpha); + }, + [](const Tensor& v1, const Tensor& v2, const Scalar& alpha) { + return v1.add_(v2, alpha); + }); +} + +inline Tensor to_type(const Tensor& input, ScalarType dtype) { + auto [compressed_indices, plain_indices] = + at::sparse_csr::getCompressedPlainIndices(input); + return at::_sparse_compressed_tensor_unsafe( + compressed_indices, + plain_indices, + std::move(input.values()).to(dtype), + input.sizes(), + dtype, + input.layout(), + input.device(), + input.options().pinned_memory_opt()); +} + +template +inline std::tuple create_acc_buffer( + TensorOptions option, + ScalarType type, + int64_t nnz = -1) { + Tensor new_values, new_values_acc; + constexpr bool need_acc = !std::is_same_v; + bool is_integral = at::isIntegralType(type, /*includeBool=*/true); + if constexpr (need_acc) { + auto acc_dtype = CppTypeToScalarType::value; + new_values_acc = at::empty({}, option.dtype(acc_dtype)); + new_values = is_integral ? new_values_acc : at::empty({}, option); + } else { + new_values = new_values_acc = at::empty({}, option); + } + if (nnz != -1) { + return std::make_tuple( + new_values.resize_(nnz), new_values_acc.resize_(nnz)); + } else { + return std::make_tuple(new_values, new_values_acc); + } +} + +inline void copy_from_acc_buffer(Tensor& new_values, Tensor& new_values_acc) { + if (!new_values_acc.is_same(new_values)) { + new_values.copy_(new_values_acc); + } +} + +} // namespace at::sparse_csr diff --git a/phivenv/Lib/site-packages/torch/include/ATen/SparseTensorImpl.h b/phivenv/Lib/site-packages/torch/include/ATen/SparseTensorImpl.h new file mode 100644 index 0000000000000000000000000000000000000000..9401784dbd10537ef39199db461cebea534a8ca5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/SparseTensorImpl.h @@ -0,0 +1,421 @@ +#pragma once + +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#include +#endif + +namespace at { +struct TORCH_API SparseTensorImpl : public TensorImpl { + // Stored in COO format, indices + values. + + // INVARIANTS: + // sparse_dim: range [0, len(shape)]; sparse_dim + dense_dim = len(shape) + // dense_dim : range [0, len(shape)]; sparse_dim + dense_dim = len(shape) + // _indices.shape: dimensionality: 2, shape: (sparse_dim, nnz) + // _values.shape: dimensionality: 1 + dense_dim. shape: (nnz, + // shape[sparse_dim:]) + + int64_t sparse_dim_ = 0; // number of sparse dimensions + int64_t dense_dim_ = 0; // number of dense dimensions + + Tensor indices_; // always a LongTensor + Tensor values_; + + // A sparse tensor is 'coalesced' if every index occurs at most once in + // the indices tensor, and the indices are in sorted order. (This means + // that it is very easy to convert a coalesced tensor to CSR format: you + // need only compute CSR format indices.) + // + // Most math operations can only be performed on coalesced sparse tensors, + // because many algorithms proceed by merging two sorted lists (of indices). + bool coalesced_ = false; + + // compute_numel with integer multiplication overflow check, see gh-57542 + void refresh_numel() { + TensorImpl::safe_refresh_numel(); + } + + public: + // Public for now... + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta); + + void release_resources() override; + + int64_t nnz() const { + return values_.size(0); + } + + c10::SymInt sym_nnz() const { + return values_.sym_size(0); + } + int64_t sparse_dim() const { + return sparse_dim_; + } + int64_t dense_dim() const { + return dense_dim_; + } + bool coalesced() const { + return coalesced_; + } + Tensor indices() const { + return indices_; + } + Tensor values() const { + return values_; + } + + void set_size(int64_t dim, int64_t new_size) override; + void set_stride(int64_t dim, int64_t new_stride) override; + void set_storage_offset(int64_t storage_offset) override; + +#ifdef DEBUG + bool has_storage() const override; +#endif + + // WARNING: This function does NOT preserve invariants of sparse_dim/dense_dim + // with respect to indices and values + void raw_resize_(int64_t sparse_dim, int64_t dense_dim, IntArrayRef size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "raw_resize_ ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "raw_resize_ called on tensor with symbolic shape") + set_sizes_and_strides(size, std::vector(size.size())); + sparse_dim_ = sparse_dim; + dense_dim_ = dense_dim; + refresh_numel(); + } + + // NOTE: This function preserves invariants of sparse_dim/dense_dim with + // respect to indices and values. + // + // NOTE: This function supports the following cases: + // 1. When we keep the number of dense dimensions unchanged, and NOT shrinking + // the size of any of the dense dimensions. + // 2. When we keep the number of sparse dimensions unchanged, and NOT + // shrinking the size of any of the sparse dimensions. + // 3. When the sparse tensor has zero nnz, in which case we are free to change + // the shapes of both its sparse and dense dimensions. + // + // This function DOESN'T support (and will throw an error) the following + // cases: + // 1. When we attempt to change the number of sparse dimensions on a non-empty + // sparse tensor (such an operation will invalidate the indices stored). + // 2. When we attempt to change the number of dense dimensions on a non-empty + // sparse tensor (such an operation will behave differently from an equivalent + // dense tensor's resize method, and for API consistency we don't support it). + // 3. When we attempt to shrink the size of any of the dense dimensions on a + // non-empty sparse tensor (such an operation will behave differently from an + // equivalent dense tensor's resize method, and for API consistency we don't + // support it). + // 4. When we attempt to shrink the size of any of the sparse dimensions on a + // non-empty sparse tensor (this could make some of the stored indices + // out-of-bound and thus unsafe). + template + void _resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "resize_ ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "resize_ called on tensor with symbolic shape") + TORCH_CHECK( + sparse_dim + dense_dim == static_cast(size.size()), + "number of dimensions must be sparse_dim (", + sparse_dim, + ") + dense_dim (", + dense_dim, + "), but got ", + size.size()); + if (nnz() > 0) { + [[maybe_unused]] auto constexpr alt_options_msg = + "You could try the following options:\n\ +1. If you need an empty sparse tensor of this size, call `x = torch.sparse_coo_tensor(size)`.\n\ +2. If you need to resize this tensor, you have the following options:\n\ + 1. For both sparse and dense dimensions, keep the number of them constant and the size of them non-shrinking, and then try the same call again.\n\ + 2. Or, create a new sparse tensor with the correct indices and values from this sparse tensor."; + + TORCH_CHECK( + sparse_dim == sparse_dim_, + "changing the number of sparse dimensions (from ", + sparse_dim_, + " to ", + sparse_dim, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + + TORCH_CHECK( + dense_dim == dense_dim_, + "changing the number of dense dimensions (from ", + dense_dim_, + " to ", + dense_dim, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + + bool shrinking_sparse_dims = false; + bool shrinking_dense_dim = false; + auto sparse_size_original = generic_sizes().slice(0, sparse_dim); + auto sparse_size_new = size.slice(0, sparse_dim); + for (const auto i : c10::irange(sparse_dim)) { + if (sparse_size_new[i] < sparse_size_original[i]) { + shrinking_sparse_dims = true; + break; + } + } + auto dense_size_original = generic_sizes().slice(sparse_dim); + auto dense_size_new = size.slice(sparse_dim); + for (const auto i : c10::irange(dense_dim)) { + if (dense_size_new[i] < dense_size_original[i]) { + shrinking_dense_dim = true; + break; + } + } + + TORCH_CHECK( + !shrinking_sparse_dims, + "shrinking the size of sparse dimensions (from ", + sparse_size_original, + " to ", + sparse_size_new, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + + TORCH_CHECK( + !shrinking_dense_dim, + "shrinking the size of dense dimensions (from ", + dense_size_original, + " to ", + dense_size_new, + ") on a non-empty sparse tensor is not supported.\n", + alt_options_msg); + } + + auto sizes_and_strides = generic_sizes(); + const bool size_equals_sizes = std::equal( + size.begin(), + size.end(), + sizes_and_strides.begin(), + sizes_and_strides.end()); + if ((!size_equals_sizes) || (sparse_dim != sparse_dim_) || + (dense_dim != dense_dim_)) { + auto nnz = at::symint::sizes(values())[0]; + std::vector values_size = {nnz}; + auto dense_size = size.slice(sparse_dim); + values_size.insert( + values_size.end(), dense_size.begin(), dense_size.end()); + at::symint::resize_(values_, values_size); + at::symint::resize_(indices_, {T(sparse_dim), nnz}); + } + + if (!size_equals_sizes) { + set_sizes_and_strides(size, std::vector(size.size())); + } + sparse_dim_ = sparse_dim; + dense_dim_ = dense_dim; + refresh_numel(); + } + + void resize_(int64_t sparse_dim, int64_t dense_dim, ArrayRef size) { + return _resize_(sparse_dim, dense_dim, size); + } + + void resize_( + int64_t sparse_dim, + int64_t dense_dim, + ArrayRef size) { + return _resize_(sparse_dim, dense_dim, size); + } + + // NOTE: this function will resize the sparse tensor and also set `indices` + // and `values` to empty. + void resize_and_clear_( + int64_t sparse_dim, + int64_t dense_dim, + IntArrayRef size) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "resize_and_clear_ ", + err_msg_tensor_metadata_change_not_allowed); + TORCH_CHECK( + !has_symbolic_sizes_strides_, + "resize_and_clear_ called on tensor with symbolic shape") + TORCH_CHECK( + sparse_dim + dense_dim == static_cast(size.size()), + "number of dimensions must be sparse_dim (", + sparse_dim, + ") + dense_dim (", + dense_dim, + "), but got ", + size.size()); + + set_sizes_and_strides(size, std::vector(size.size())); + sparse_dim_ = sparse_dim; + dense_dim_ = dense_dim; + + auto empty_indices = at::empty({sparse_dim, 0}, indices().options()); + std::vector values_size = {0}; + auto dense_size = sizes().slice(sparse_dim); + values_size.insert(values_size.end(), dense_size.begin(), dense_size.end()); + auto empty_values = at::empty(values_size, values().options()); + set_indices_and_values_unsafe(empty_indices, empty_values); + refresh_numel(); + } + + void set_coalesced(bool coalesced) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_coalesced ", + err_msg_tensor_metadata_change_not_allowed); + coalesced_ = coalesced; + } + + // NOTE: this function is only used internally and not exposed to Python + // frontend + void set_nnz_and_narrow(int64_t new_nnz) { + TORCH_CHECK( + allow_tensor_metadata_change(), + "set_nnz_and_narrow ", + err_msg_tensor_metadata_change_not_allowed); + AT_ASSERT(new_nnz <= nnz()); + indices_ = indices_.narrow(1, 0, new_nnz); + values_ = values_.narrow(0, 0, new_nnz); + if (new_nnz < 2) { + coalesced_ = true; + } + } + + // Takes indices and values and directly puts them into the sparse tensor, no + // copy. NOTE: this function is unsafe because it doesn't check whether any + // indices are out of boundaries of `sizes`, so it should ONLY be used where + // we know that the indices are guaranteed to be within bounds. This used to + // be called THSTensor_(_move) NB: This used to be able to avoid a refcount + // bump, but I was too lazy to make it happen + void set_indices_and_values_unsafe( + const Tensor& indices, + const Tensor& values); + + template + c10::intrusive_ptr shallow_copy_and_detach_core( + VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const { + const auto mode_stack_len = c10::impl::TorchDispatchModeTLS::stack_len(); + c10::impl::PyInterpreter&& interpreter = nullptr; + if (mode_stack_len > 0 && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + const auto& cur_torch_dispatch_mode_state = + c10::impl::TorchDispatchModeTLS::get_stack_at(mode_stack_len - 1); + interpreter = cur_torch_dispatch_mode_state->pyinterpreter(); + } else if ( + key_set_.has(DispatchKey::Python) && + !c10::impl::tls_is_dispatch_key_excluded(DispatchKey::Python)) { + interpreter = pyobj_slot_.load_pyobj_interpreter(); + } else { + // otherwise just copy the SparseTensorImpl and not the PyObject. + auto impl = c10::make_intrusive(key_set(), dtype()); + copy_tensor_metadata( + /*src_sparse_impl=*/this, + /*dest_sparse_impl=*/impl.get(), + /*version_counter=*/version_counter, + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change); + impl->refresh_numel(); + return impl; + } + auto r = interpreter->detach(this); + r->set_version_counter(std::forward(version_counter)); + r->set_allow_tensor_metadata_change(allow_tensor_metadata_change); + return r; + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + const c10::VariableVersion& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + version_counter, allow_tensor_metadata_change); + } + + /** + * Return a TensorImpl that is a shallow-copy of this TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, + * see NOTE [ TensorImpl Shallow-Copying ]. + */ + c10::intrusive_ptr shallow_copy_and_detach( + c10::VariableVersion&& version_counter, + bool allow_tensor_metadata_change) const override { + return shallow_copy_and_detach_core( + std::move(version_counter), allow_tensor_metadata_change); + } + + /** + * Shallow-copies data from another TensorImpl into this TensorImpl. + * + * For why this function doesn't check this TensorImpl's + * `allow_tensor_metadata_change_`, see NOTE [ TensorImpl Shallow-Copying ]. + */ + void shallow_copy_from(const c10::intrusive_ptr& impl) override { + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); + auto sparse_impl = static_cast(impl.get()); + copy_tensor_metadata( + /*src_sparse_impl=*/sparse_impl, + /*dest_sparse_impl=*/this, + /*version_counter=*/version_counter(), + /*allow_tensor_metadata_change=*/allow_tensor_metadata_change()); + refresh_numel(); + } + + private: + explicit SparseTensorImpl( + at::DispatchKeySet, + const caffe2::TypeMeta, + at::Tensor indices, + at::Tensor values); + + /** + * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / + * storage_offset) from one TensorImpl to another TensorImpl. + * + * For usage of `version_counter` and `allow_tensor_metadata_change`, see NOTE + * [ TensorImpl Shallow-Copying ]. + */ + static void copy_tensor_metadata( + const SparseTensorImpl* src_sparse_impl, + SparseTensorImpl* dest_sparse_impl, + c10::VariableVersion version_counter, + bool allow_tensor_metadata_change) { + TensorImpl::copy_tensor_metadata( + src_sparse_impl, + dest_sparse_impl, + std::move(version_counter), + allow_tensor_metadata_change); + + // Sparse-specific fields + dest_sparse_impl->sparse_dim_ = src_sparse_impl->sparse_dim(); + dest_sparse_impl->dense_dim_ = src_sparse_impl->dense_dim(); + dest_sparse_impl->indices_ = src_sparse_impl->indices(); + dest_sparse_impl->values_ = src_sparse_impl->values(); + dest_sparse_impl->coalesced_ = src_sparse_impl->coalesced(); + } + + const char* tensorimpl_type_name() const override; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Storage.h b/phivenv/Lib/site-packages/torch/include/ATen/Storage.h new file mode 100644 index 0000000000000000000000000000000000000000..458e195f3067cdcb0249a4c074f95e8244c94b1f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Storage.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/StorageUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/StorageUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..71b168db8ef472bae4b71b8634f0440ed846aab6 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/StorageUtils.h @@ -0,0 +1,49 @@ +#pragma once + +#include +#include +#include + +namespace at { + +class TensorBase; + +// Here we define a series of utils to create/manipulate ATen backed +// c10 storage implementations. + +/** + * Create a new shared memory storage impl managed by file descriptor + * + * @param size size in bytes + */ +C10_EXPORT c10::intrusive_ptr new_shm_fd_storage(size_t size); + +/** + * Copy src to dst + * Caller must guarantee the validness of the storage objects + * during the entire copy process, esp. when it's async. + * + * This can probably live in c10 namespace later if needed, + * but for now keep it in at to keep implementation simple. + * + * @param dst dst tensor + * @param src src tensor + * @param non_blocking (default false) whether this operation blocks caller + */ +C10_EXPORT void storage_copy( + c10::Storage& dst, + const c10::Storage& src, + bool non_blocking = false); + +/** + * In place change the storage to shm based. + * + * This is only applicable to CPU tensors not already shared. + * Otherwise, it's a no op to mirror the THP tensor behavior: + * https://pytorch.org/docs/stable/generated/torch.Tensor.share_memory_.html + * + * @param t a tensor + */ +C10_EXPORT void share_memory_(TensorBase& t); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Tensor.h b/phivenv/Lib/site-packages/torch/include/ATen/Tensor.h new file mode 100644 index 0000000000000000000000000000000000000000..c07eba1b83cb467d861f6d220b9b834c63e3cadd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Tensor.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorAccessor.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorAccessor.h new file mode 100644 index 0000000000000000000000000000000000000000..9a8b7f6fedfc56702f9b3509cb1af6b9380c87e1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorAccessor.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorGeometry.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorGeometry.h new file mode 100644 index 0000000000000000000000000000000000000000..8966b1be1195eb9e329b70c89c3f3656b5f07be5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorGeometry.h @@ -0,0 +1,154 @@ +#pragma once + +#include +#include + +namespace at { + +// Return if the tensor geometry represented by `sizes` and `strides` is +// contiguous Although we cache is_contiguous in tensor now, this is till useful +// because it allows checking if a particular geometry is contiguous without +// explicitly constructing a tensor, e.g., when you want to choose a kernel +// strategy based on whether a subgeometry is contiguous. +TORCH_API bool geometry_is_contiguous(IntArrayRef sizes, IntArrayRef strides); + +struct TORCH_API TensorGeometry { + TensorGeometry() = default; + + explicit TensorGeometry(c10::SymIntArrayRef sizes) + : sizes_(sizes.vec()), + strides_(sizes.size()), + has_symbolic_sizes_strides_( + !c10::asIntArrayRefSlowOpt(sizes).has_value()) { + int64_t dim = static_cast(sizes.size()); + c10::SymInt expected_stride = 1; + for (int64_t i = dim - 1; i >= 0; i--) { + strides_[i] = expected_stride; + expected_stride *= sizes_[i]; + } + numel_ = expected_stride; + } + + explicit TensorGeometry(const TensorBase& t) + : sizes_(t.sym_sizes().vec()), + strides_(t.sym_strides().vec()), + storage_offset_(t.sym_storage_offset()), + numel_(t.sym_numel()), + has_symbolic_sizes_strides_( + t.unsafeGetTensorImpl()->has_symbolic_sizes_strides()) {} + + explicit TensorGeometry( + std::vector sizes, + std::vector strides, + at::SymInt storage_offset) + : sizes_(std::move(sizes)), + strides_(std::move(strides)), + storage_offset_(std::move(storage_offset)) { + recompute(); + } + + // true if the tensor is contiguous + bool is_contiguous() const; + + int64_t dim() const { + return static_cast(sizes_.size()); + } + + int64_t size(int64_t dim) const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + dim = c10::maybe_wrap_dim(dim, this->dim()); + return sizes_.at(static_cast(dim)).as_int_unchecked(); + } + c10::IntArrayRef sizes() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return c10::asIntArrayRefUnchecked(sizes_); + } + int64_t stride(int64_t dim) const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + dim = c10::maybe_wrap_dim(dim, this->dim()); + return strides_.at(static_cast(dim)).as_int_unchecked(); + } + c10::IntArrayRef strides() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return c10::asIntArrayRefUnchecked(strides_); + } + int64_t storage_offset() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return storage_offset_.as_int_unchecked(); + } + int64_t numel() const { + TORCH_INTERNAL_ASSERT(!has_symbolic_sizes_strides_); + return numel_.as_int_unchecked(); + } + + c10::SymInt sym_size(int64_t dim) const { + dim = c10::maybe_wrap_dim(dim, this->dim()); + return sizes_.at(static_cast(dim)); + } + c10::SymIntArrayRef sym_sizes() const { + return sizes_; + } + c10::SymInt sym_stride(int64_t dim) const { + dim = c10::maybe_wrap_dim(dim, this->dim()); + return strides_.at(static_cast(dim)); + } + c10::SymIntArrayRef sym_strides() const { + return strides_; + } + c10::SymInt sym_storage_offset() const { + return storage_offset_; + } + c10::SymInt sym_numel() const { + return numel_; + } + + TensorGeometry transpose(int64_t dim0, int64_t dim1) { + TensorGeometry r = *this; // copy + TORCH_CHECK( + dim0 < dim(), + "transpose: dim0=", + dim0, + " out of range (dim=", + dim(), + ")") + TORCH_CHECK( + dim1 < dim(), + "transpose: dim1=", + dim1, + " out of range (dim=", + dim(), + ")") + std::swap(r.sizes_[dim0], r.sizes_[dim1]); + std::swap(r.strides_[dim0], r.strides_[dim1]); + return r; + } + + std::vector& mutable_sizes() { + return sizes_; + } + std::vector& mutable_strides() { + return strides_; + } + c10::SymInt& mutable_storage_offset() { + return storage_offset_; + } + void recompute() { + // recalculate numel after a change + c10::SymInt numel = 1; + for (const auto& i : sizes_) { + numel = numel * i; + } + numel_ = std::move(numel); + has_symbolic_sizes_strides_ = + !c10::asIntArrayRefSlowOpt(sizes_).has_value(); + } + + private: + std::vector sizes_; + std::vector strides_; + c10::SymInt storage_offset_; + c10::SymInt numel_; + bool has_symbolic_sizes_strides_{false}; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorIndexing.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorIndexing.h new file mode 100644 index 0000000000000000000000000000000000000000..1a2880eda93381eee42fbc89b197812a31f6d516 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorIndexing.h @@ -0,0 +1,742 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#include +#else +#include +#include +#include +#include +#endif + +#include + +#include + +namespace at::indexing { + +constexpr int64_t INDEX_MIN = c10::SymInt::min_representable_int(); +constexpr int64_t INDEX_MAX = -(INDEX_MIN + 1); + +enum class TensorIndexType { None, Ellipsis, SymInt, Boolean, Slice, Tensor }; + +constexpr std::nullopt_t None = std::nullopt; + +struct TORCH_API EllipsisIndexType final { + EllipsisIndexType() = default; +}; +TORCH_API extern const EllipsisIndexType Ellipsis; + +struct TORCH_API Slice final { + public: + Slice( + std::optional start_index = std::nullopt, + std::optional stop_index = std::nullopt, + std::optional step_index = std::nullopt) { + if (!step_index.has_value()) { + step_ = c10::SymInt(1); + } else { + step_ = std::move(step_index).value(); + } + + TORCH_CHECK_VALUE( + step_.sym_ne(0).expect_true(__FILE__, __LINE__), + "slice step cannot be zero"); + + if (!start_index.has_value()) { + start_ = c10::SymInt(step_ < 0 ? INDEX_MAX : 0); + } else { + start_ = std::move(start_index).value(); + } + + if (!stop_index.has_value()) { + stop_ = c10::SymInt(step_ < 0 ? INDEX_MIN : INDEX_MAX); + } else { + stop_ = std::move(stop_index).value(); + } + } + + inline c10::SymInt start() const { + return start_; + } + + inline c10::SymInt stop() const { + return stop_; + } + + inline c10::SymInt step() const { + return step_; + } + + private: + c10::SymInt start_; + c10::SymInt stop_; + c10::SymInt step_; +}; + +TORCH_API std::ostream& operator<<(std::ostream& stream, const Slice& slice); + +// `at::indexing::TensorIndex` is used for converting C++ tensor indices such as +// `{None, "...", Ellipsis, 0, true, Slice(1, None, 2), torch::tensor({1, 2})}` +// into its equivalent `std::vector`, so that further tensor +// indexing operations can be performed using the supplied indices. +// +// There is one-to-one correspondence between Python and C++ tensor index types: +// Python | C++ +// ----------------------------------------------------- +// `None` | `at::indexing::None` +// `Ellipsis` | `at::indexing::Ellipsis` +// `...` | `"..."` +// `123` | `123` +// `True` / `False` | `true` / `false` +// `:` | `Slice()` / `Slice(None, None)` +// `::` | `Slice()` / `Slice(None, None, None)` +// `1:` | `Slice(1, None)` +// `1::` | `Slice(1, None, None)` +// `:3` | `Slice(None, 3)` +// `:3:` | `Slice(None, 3, None)` +// `::2` | `Slice(None, None, 2)` +// `1:3` | `Slice(1, 3)` +// `1::2` | `Slice(1, None, 2)` +// `:3:2` | `Slice(None, 3, 2)` +// `1:3:2` | `Slice(1, 3, 2)` +// `torch.tensor([1, 2])`) | `torch::tensor({1, 2})` +struct TORCH_API TensorIndex final { + // Case 1: `at::indexing::None` + TensorIndex(std::nullopt_t) : type_(TensorIndexType::None) {} + + // Case 2: "..." / `at::indexing::Ellipsis` + TensorIndex(at::indexing::EllipsisIndexType) + : type_(TensorIndexType::Ellipsis) {} + TensorIndex(const char* str) : TensorIndex(at::indexing::Ellipsis) { + TORCH_CHECK_VALUE( + strcmp(str, "...") == 0, + "Expected \"...\" to represent an ellipsis index, but got \"", + str, + "\""); + } + + // Case 3: (Sym) Integer value + TensorIndex(SymInt integer) + : integer_(std::move(integer)), type_(TensorIndexType::SymInt) {} + TensorIndex(int64_t integer) : TensorIndex(SymInt(integer)) {} + TensorIndex(int integer) : TensorIndex(SymInt(integer)) {} + + // Case 4: Boolean value + template >> + TensorIndex(T boolean) : boolean_(boolean), type_(TensorIndexType::Boolean) {} + + // Case 5: Slice represented in `at::indexing::Slice` form + TensorIndex(Slice slice) + : slice_(std::move(slice)), type_(TensorIndexType::Slice) {} + + // Case 6: Tensor value + TensorIndex(Tensor tensor) + : tensor_(std::move(tensor)), type_(TensorIndexType::Tensor) {} + + inline bool is_none() const { + return type_ == TensorIndexType::None; + } + + inline bool is_ellipsis() const { + return type_ == TensorIndexType::Ellipsis; + } + + inline bool is_integer() const { + return type_ == TensorIndexType::SymInt; + } + + inline SymInt integer() const { + return integer_; + } + + inline bool is_boolean() const { + return type_ == TensorIndexType::Boolean; + } + + inline bool boolean() const { + return boolean_; + } + + inline bool is_slice() const { + return type_ == TensorIndexType::Slice; + } + + inline const Slice& slice() const { + return slice_; + } + + inline bool is_tensor() const { + return type_ == TensorIndexType::Tensor; + } + + inline const Tensor& tensor() const { + return tensor_; + } + + private: + SymInt integer_ = 0; + bool boolean_ = false; + Slice slice_; + Tensor tensor_; + TensorIndexType type_; +}; + +TORCH_API std::ostream& operator<<( + std::ostream& stream, + const TensorIndex& tensor_index); +TORCH_API std::ostream& operator<<( + std::ostream& stream, + const std::vector& tensor_indices); + +namespace impl { +inline Tensor applySlice( + const Tensor& self, + int64_t dim, + c10::SymInt start, + c10::SymInt stop, + c10::SymInt step, + bool disable_slice_optimization, + const at::Device& self_device, + const std::optional& self_sizes) { + // TODO: implement negative step + TORCH_CHECK_VALUE( + step.sym_gt(0).expect_true(__FILE__, __LINE__), + "step must be greater than zero"); + + // See NOTE [nested tensor size for indexing] + if (self_sizes.has_value()) { + // Skip this optimization if we are tracing, as the trace may be polymorphic + // over the shape of the `self` tensor, and we still want to record + // the slice. + SymInt length = (self_device == at::kCPU || self_device == at::kCUDA) + ? (*self_sizes)[dim] + : self.sym_size(dim); + if (!disable_slice_optimization && + TORCH_STATICALLY_KNOWN_TRUE(start.sym_eq(0)) && + TORCH_STATICALLY_KNOWN_TRUE(length.sym_eq(stop)) && step == 1) { + return self; + } + } + return self.slice_symint( + dim, std::move(start), std::move(stop), std::move(step)); +} + +inline Tensor applySelect( + const Tensor& self, + int64_t dim, + SymInt index, + int64_t real_dim, + const at::Device& /*self_device*/, + const std::optional& self_sizes) { + // See NOTE [nested tensor size for indexing] + if (self_sizes.has_value()) { + auto maybe_index = index.maybe_as_int(); + if (maybe_index.has_value()) { + TORCH_CHECK_INDEX( + !(maybe_index.value() == 0 && dim == 0 && self_sizes->empty()), + "invalid index of a 0-dim tensor. ", + "Use `tensor.item()` in Python or `tensor.item()` in C++ to convert a 0-dim tensor to a number"); + } + + auto size = (*self_sizes)[dim]; + // Note: `size >= -index` is not equivalent to `size > -1 - index` if index + // is INT64_MIN For std::numeric_limits::min() result of unary + // minus is undefined by the standard but in practice is equal to self. On + // the other hand, indexing wraping is valid for all negative int64_t + // values, as x[INT64_MIN] is the same as x[INT64_MAX] + TORCH_CHECK_INDEX( + size.sym_gt(-1 - index) + .sym_and(size.sym_gt(index)) + .expect_true(__FILE__, __LINE__), + "index ", + index, + " is out of bounds for dimension ", + real_dim, + " with size ", + size); + } + + // if the index is negative, do not normalize it because that would fix the + // index on the current tensor size in the tracer. aten::select also works on + // negative indices + return self.select_symint(dim, std::move(index)); +} + +inline Tensor boolToIndexingTensorCPUOrCUDA(const Tensor& self, bool value) { + // booleans add a dimension of size 1. true indexes this dimension as if 0:, + // false as empty. + if (value) { + return at::empty({1}, self.options().dtype(kLong)).fill_(0.); + } else { + return at::empty({0}, self.options().dtype(kLong)); + } +} + +inline Tensor boolToIndexingTensorNonNativeDeviceType( + const Tensor& self, + bool value) { + // booleans add a dimension of size 1. true indexes this dimension as if 0:, + // false as empty. + if (value) { + return at::zeros({1}, self.options().dtype(kLong)); + } else { + return at::empty({0}, self.options().dtype(kLong)); + } +} + +inline Tensor boolToIndexingTensor( + const Tensor& self, + bool value, + const at::Device& self_device) { + if (self_device == at::kCPU || self_device == at::kCUDA) { + return boolToIndexingTensorCPUOrCUDA(self, value); + } else { + return boolToIndexingTensorNonNativeDeviceType(self, value); + } +} + +inline Tensor scalarToTensorNonNativeDeviceType( + const Scalar& v, + const TensorOptions& options) { + return at::scalar_tensor(v, options); +} + +inline void recordTensorIndex( + const Tensor& tensor, + std::vector& outIndices, + int64_t* dim_ptr) { + // TODO: check scalarType + outIndices.resize(*dim_ptr + 1); + outIndices[*dim_ptr] = tensor; + (*dim_ptr)++; +} + +inline c10::List<::std::optional> typeConvertIndices( + const Tensor& /*self*/, + std::vector&& indices) { + c10::List<::std::optional> converted_inds; + converted_inds.reserve(indices.size()); + for (auto&& i : std::move(indices)) { + converted_inds.push_back(std::move(i)); + } + return converted_inds; +} + +// NOTE: Why do we mirror instead of replace the `count_specified_dimensions` +// function in torch/csrc/autograd/python_variable_indexing.cpp? It's because +// `count_specified_dimensions` is on the hot path of Python tensor multi-dim +// indexing (i.e. it's called by `applySlicing` which is called by +// `THPVariable_getitem` / `THPVariable_setitem` when handling indexing of more +// than one dimension). If we were to merge the Python/C++ +// `count_specified_dimensions` function, on the Python side we would have to +// construct a `std::vector` container to be consumed by the C++ +// `count_specified_dimensions` function, which adds 100s of nanoseconds +// overhead and is undesirable. +inline int64_t count_specified_dimensions( + const ArrayRef& indices) { + // Count the number of indexed dimensions (everything but ellipsis and None) + int64_t count = 0; + for (auto& obj : indices) { + if (obj.is_tensor()) { + auto& tensor = obj.tensor(); + if (tensor.scalar_type() == kByte || tensor.scalar_type() == kBool) { + count += tensor.dim(); + } else { + count++; + } + } else if (!obj.is_none() && !obj.is_ellipsis() && !obj.is_boolean()) { + count++; + } + } + return count; +} +} // namespace impl + +// NOTE: Many functions below are only for consumption from Python indexing +// implementation, they include: +// +// - `Tensor scalarToTensor(...)` +// - `IntArrayRef slicePrefix1sSize(...)` +// - `void copy_to(...)` +// - `Tensor handleDimInMultiDimIndexing(...)` +// - `Tensor dispatch_index(...)` +// - `Tensor dispatch_index_put_(...)` +// - `Tensor get_item(...)` +// - `void set_item(...)` +// +// The rest of the functions are in `at::indexing::impl` namespace, signifying +// that they shouldn't be used from Python indexing implementation. +inline Tensor scalarToTensor( + const Scalar& v, + const TensorOptions& options, + const at::Device& self_device) { + if (self_device == at::kCPU && !v.isSymbolic()) { + return at::detail::scalar_tensor_static( + v, + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + options.dtype_opt()->toScalarType(), + self_device); + } else { + return impl::scalarToTensorNonNativeDeviceType(v, options); + } +} + +// To match numpy semantics: +// As a special case for backwards compatibility, +// strip away unit dimensions from the left of 'src' +inline SymIntArrayRef slicePrefix1sSize(const SymIntArrayRef& sizes) { + size_t first_non1_src = sizes.size(); + for (const auto i : c10::irange(sizes.size())) { + // Unbacked SymInt has different behavior, but this is sound because + // failing to slice will only ever cause an error, not divergent + // behavior + if (!sizes[i].has_hint() || sizes[i] != 1) { + first_non1_src = i; + break; + } + } + + return sizes.slice(first_non1_src); +} + +inline void copy_to(const Tensor& dst, const Tensor& src) { + if (dst.sym_sizes().equals(src.sym_sizes())) { + // A shortcut to avoid generating hard-coded constant sizes during tracing. + // This is not a perfect solution: when src & dst have different shapes, + // constants will still appear. Users can workaround that case by + // dst[index..] = src.reshape(..) + dst.copy_(src); + return; + } else if (src.dim() == 0 && src.device().type() == at::kCPU) { + dst.fill_(src); + return; + } + auto src_view = src.view_symint(slicePrefix1sSize(src.sym_sizes())); + c10::MaybeOwned b_src = expand_inplace(dst, src_view, "setitem"); + dst.copy_(*b_src); +} + +// See NOTE [ Setting `disable_slice_optimization` when calling C++ tensor +// indexing functions from Python ] +inline Tensor handleDimInMultiDimIndexing( + const Tensor& prev_dim_result, + const Tensor& original_tensor, + const TensorIndex& index, + int64_t* dim_ptr, + int64_t* specified_dims_ptr, + int64_t real_dim, + std::vector& outIndices, + bool disable_slice_optimization, + const at::Device& original_tensor_device, + const std::optional& prev_dim_result_sizes) { + if (index.is_integer()) { + return impl::applySelect( + prev_dim_result, + *dim_ptr, + index.integer(), + real_dim, + original_tensor_device, + prev_dim_result_sizes); + } else if (index.is_slice()) { + Tensor result = impl::applySlice( + prev_dim_result, + *dim_ptr, + index.slice().start(), + index.slice().stop(), + index.slice().step(), + /*disable_slice_optimization=*/disable_slice_optimization, + original_tensor_device, + prev_dim_result_sizes); + (*dim_ptr)++; + return result; + } else if (index.is_ellipsis()) { + (*dim_ptr) += original_tensor.dim() - (*specified_dims_ptr); + return prev_dim_result; + } else if (index.is_none()) { + Tensor result = prev_dim_result.unsqueeze(*dim_ptr); + (*dim_ptr)++; + return result; + } else if (index.is_boolean()) { + Tensor result = prev_dim_result.unsqueeze(*dim_ptr); + impl::recordTensorIndex( + impl::boolToIndexingTensor( + result, index.boolean(), original_tensor_device), + outIndices, + dim_ptr); + return result; + } else if (index.is_tensor()) { + Tensor result = prev_dim_result; + const Tensor& tensor = index.tensor(); + auto scalar_type = tensor.scalar_type(); + if (tensor.dim() == 0 && + at::isIntegralType(scalar_type, /*includeBool=*/true)) { + if (scalar_type != at::kByte && scalar_type != at::kBool) { + result = impl::applySelect( + result, + *dim_ptr, + tensor.item(), + real_dim, + original_tensor_device, + prev_dim_result_sizes); + } else { + result = result.unsqueeze(*dim_ptr); + if (scalar_type == at::kBool) { + impl::recordTensorIndex( + impl::boolToIndexingTensor( + result, tensor.item() != 0, original_tensor_device), + outIndices, + dim_ptr); + } else { + impl::recordTensorIndex( + impl::boolToIndexingTensor( + result, tensor.item() != 0, original_tensor_device), + outIndices, + dim_ptr); + } + } + } else { + impl::recordTensorIndex(tensor, outIndices, dim_ptr); + } + return result; + } else { + TORCH_INTERNAL_ASSERT(false, "Invalid TensorIndex type"); + } +} + +namespace impl { +// This mirrors `applySlicing` in +// torch/csrc/autograd/python_variable_indexing.cpp +inline Tensor applySlicing( + const Tensor& self, + const ArrayRef& indices, + std::vector& outIndices, + bool disable_slice_optimization, + const at::Device& self_device, + const std::optional& self_sizes) { + int64_t dim = 0; + int64_t specified_dims = impl::count_specified_dimensions(indices); + + // See NOTE [nested tensor size for indexing] + if (self_sizes.has_value()) { + TORCH_CHECK_INDEX( + specified_dims <= (int64_t)self_sizes->size(), + "too many indices for tensor of dimension ", + (int)self_sizes->size()); + } + + Tensor result = self; + for (const auto i : c10::irange(indices.size())) { + auto& obj = indices[i]; + // See NOTE [nested tensor size for indexing] + std::optional result_sizes = result.is_nested() + ? std::optional(std::nullopt) + : std::optional(result.sym_sizes()); + result = handleDimInMultiDimIndexing( + /*prev_dim_result=*/result, + /*original_tensor=*/self, + /*index=*/obj, + /*dim_ptr=*/&dim, + /*specified_dims_ptr=*/&specified_dims, + /*real_dim=*/static_cast(i), + /*outIndices=*/outIndices, + /*disable_slice_optimization=*/disable_slice_optimization, + /*original_tensor_device=*/self_device, + /*prev_dim_result_sizes=*/result_sizes); + } + return result; +} +} // namespace impl + +inline Tensor dispatch_index( + const Tensor& self, + std::vector&& indices) { + return self.index(impl::typeConvertIndices(self, std::move(indices))); +} + +inline Tensor dispatch_index_put_( + Tensor& self, + std::vector&& indices, + const Tensor& value) { + return self.index_put_( + impl::typeConvertIndices(self, std::move(indices)), value); +} + +// NOTE [ Setting `disable_slice_optimization` when calling C++ tensor indexing +// functions from Python ] +// +// Question: When should we set `disable_slice_optimization` to `true` when +// calling C++ tensor indexing functions from Python indexing code? +// +// Answer: What "slice optimization" means: when we have a slicing expression +// like `x[0:5, 0]`, where the sliced tensor was of size 5 in dimension 0, we +// would skip dispatching the actual slice call as an optimization. However, +// here are the cases where we DON'T want this optimization: +// +// 1. When we are doing 1-D slicing (e.g. `tensor[:]`). +// Reason: we always return a shallow copy for expressions such as +// `tensor[:]` / `tensor[...]` / `tensor[:, :]`. (Note that for `tensor[:, +// :]`, we return an alias of `tensor` by doing the following: +// ``` +// Tensor sliced = impl::applySlicing(self, indices, tensorIndices, +// disable_slice_optimization, self_device, self_sizes); if +// (tensorIndices.empty()) { +// if (sliced.is_same(self)) { +// // ensure we return a shallow copy for things like x[...] +// sliced = at::alias(sliced); +// } +// return sliced; +// } +// ```) +// 2. When we are doing JIT tracing. +// Reason: JIT tracing needs the `self.slice(...)` call to properly trace the +// slice operation. + +// This mirrors `THPVariable_getitem` in +// torch/csrc/autograd/python_variable_indexing.cpp See NOTE [ Setting +// `disable_slice_optimization` when calling C++ tensor indexing functions from +// Python ] +inline Tensor get_item( + const Tensor& self, + const ArrayRef& indices, + bool disable_slice_optimization = false) { + at::Device self_device = self.device(); + // NOTE [nested tensor size for indexing] + // nested tensor does not have a size (yet) so for now we represent its size + // as null may need to be changed after we reach a better solution for nested + // tensor size + std::optional self_sizes = self.is_nested() + ? std::optional(std::nullopt) + : std::optional(self.sym_sizes()); + + // handle simple types: integers, slices, none, ellipsis, bool + if (indices.size() == 1) { + const TensorIndex& index = indices[0]; + if (index.is_integer()) { + return impl::applySelect( + self, 0, index.integer(), 0, self_device, self_sizes); + } else if (index.is_slice()) { + return impl::applySlice( + self, + 0, + index.slice().start(), + index.slice().stop(), + index.slice().step(), + /*disable_slice_optimization=*/true, + self_device, + self_sizes); + } else if (index.is_none()) { + return self.unsqueeze(0); + } else if (index.is_ellipsis()) { + return at::alias(self); + } else if (index.is_boolean()) { + Tensor result = self.unsqueeze(0); + return dispatch_index( + result, + std::vector{impl::boolToIndexingTensor( + result, index.boolean(), self_device)}); + } + } + + std::vector tensorIndices; + Tensor sliced = impl::applySlicing( + self, + indices, + tensorIndices, + disable_slice_optimization, + self_device, + self_sizes); + if (tensorIndices.empty()) { + if (sliced.is_same(self)) { + // ensure we return a shallow copy for things like x[...] + sliced = at::alias(sliced); + } + return sliced; + } + + // indexing by tensors ("advanced" indexing) + return dispatch_index(sliced, std::move(tensorIndices)); +} + +// This mirrors `THPVariable_setitem` in +// torch/csrc/autograd/python_variable_indexing.cpp for "the assigned value is a +// Tensor" case See NOTE [ Setting `disable_slice_optimization` when calling C++ +// tensor indexing functions from Python ] +inline void set_item( + const Tensor& self, + const ArrayRef& indices, + const Tensor& value, + bool disable_slice_optimization = false) { + at::Device self_device = self.device(); + SymIntArrayRef self_sizes = self.sym_sizes(); + + // handle simple types: integers, slices, ellipsis, bool + if (indices.size() == 1) { + const TensorIndex& index = indices[0]; + if (index.is_boolean() && !index.boolean()) { + // do nothing for false (technically we should check the size, but we + // don't have real 0-sized shapes. + return; + } else if (index.is_ellipsis()) { + copy_to(self, value); + return; + } else if (index.is_none() || (index.is_boolean() && index.boolean())) { + copy_to(self.unsqueeze(0), value); + return; + } else if (index.is_integer()) { + copy_to( + impl::applySelect( + self, 0, index.integer(), 0, self_device, self_sizes), + value); + return; + } else if (index.is_slice()) { + copy_to( + impl::applySlice( + self, + 0, + index.slice().start(), + index.slice().stop(), + index.slice().step(), + /*disable_slice_optimization=*/disable_slice_optimization, + self_device, + self_sizes), + value); + return; + } + } + + std::vector tensorIndices; + Tensor sliced = impl::applySlicing( + self, + indices, + tensorIndices, + disable_slice_optimization, + self_device, + self_sizes); + if (tensorIndices.empty()) { + copy_to(sliced, value); + return; + } + + SymIntArrayRef valueSizes = value.sym_sizes(); + SymIntArrayRef slicedValueSizes = slicePrefix1sSize(valueSizes); + Tensor valuesSliced; + if (!valueSizes.equals(slicedValueSizes)) { + valuesSliced = value.view_symint(slicedValueSizes); + } else { + valuesSliced = value; + } + dispatch_index_put_(sliced, std::move(tensorIndices), valuesSliced); + return; +} + +} // namespace at::indexing diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorIterator.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorIterator.h new file mode 100644 index 0000000000000000000000000000000000000000..4e93e93356b1e5639465da44e89893a104411789 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorIterator.h @@ -0,0 +1,1034 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace at { +class Tensor; +class OptionalTensorRef; +using NameVector = SmallVector; +} // namespace at + +// TensorIterator is a helper class for element-wise operations, such as +// arithmetic, comparisons, and trigonometric functions. It handles +// broadcasting and type conversions of operands. +// +// This is inspired by NumPy's Array Iterator API (NpyIter). +// +// The files Loops.h and Loops.cuh provide functions to build kernels that +// use TensorIterator. +// +// Example: +// +// auto iter = TensorIteratorConfig() +// .add_output(output) +// .add_input(input) +// .build() +// +// [MyKernel.cpp / MyKernel.cu] +// cpu_kernel(iter, [](float a, float b) { +// return a + b; +// }); +// +// gpu_kernel(iter, []GPU_LAMBDA(float a, float b) -> float { +// return a + b; +// }); +// +// Note [Order of Construction] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// When setting up the tensor iterator configuration, the output Tensors +// have to be added first via +// TensorIteratorConfig::add_owned_output(at::Tensor). After adding all outputs, +// the inputs can be added via +// TensorIteratorConfig::add_owned_input(at::Tensor). +// Adding another output after inputs have been added will rise an exception. +// +// Note [Common Dtype Computation] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Some operations have a natural notion of a "common dtype" or +// "computation dtype" where all inputs are cast to one dtype, the +// operation is performed, and then the results are cast to all outputs. +// +// TensorIterator infers a common dtype if all inputs have the same dtype, +// and it computes one using type promotion rules on its inputs if +// promote_inputs_to_common_dtype_ is true. Attempting to query +// a common dtype otherwise will throw an exception. +// +// Note that the outputs are not considered when computing a common dtype. + +namespace at { + +namespace internal { +// This parameter is heuristically chosen to determine the minimum number of +// work that warrants parallelism. For example, when summing an array, it is +// deemed inefficient to parallelise over arrays shorter than 32768. Further, +// no parallel algorithm (such as parallel_reduce) should split work into +// smaller than GRAIN_SIZE chunks. +constexpr int64_t GRAIN_SIZE = 32768; + +// Storage for a non-owning Tensor, without needing to include Tensor.h +class TORCH_API OpaqueOptionalTensorRef { + alignas(alignof(TensorBase)) std::array data_{}; + + public: + OpaqueOptionalTensorRef(); + OpaqueOptionalTensorRef(const OpaqueOptionalTensorRef&) = default; + OpaqueOptionalTensorRef& operator=(const OpaqueOptionalTensorRef&) = default; + OpaqueOptionalTensorRef(OpaqueOptionalTensorRef&&) noexcept = default; + OpaqueOptionalTensorRef& operator=(OpaqueOptionalTensorRef&&) noexcept = + default; + ~OpaqueOptionalTensorRef(); + + OptionalTensorRef* get() { + return reinterpret_cast(data_.data()); + } + const OptionalTensorRef* get() const { + return reinterpret_cast(data_.data()); + } + + OptionalTensorRef& operator*() { + return *get(); + } + const OptionalTensorRef& operator*() const { + return *get(); + } + OptionalTensorRef* operator->() { + return get(); + } + const OptionalTensorRef* operator->() const { + return get(); + } + + const Tensor& getTensor() const; +}; +} // namespace internal + +struct TORCH_API OperandInfo { + using StrideVector = SmallVector; + OperandInfo() = default; + C10_ALWAYS_INLINE explicit OperandInfo(c10::MaybeOwned&& t) { + if (t->defined()) { + device = t->device(); + target_dtype = t->scalar_type(); + current_dtype = target_dtype; + } + tensor(std::move(t)); + validate(); + } + + C10_ALWAYS_INLINE OperandInfo(const OperandInfo&) = default; + C10_ALWAYS_INLINE OperandInfo& operator=(const OperandInfo&) = default; + C10_ALWAYS_INLINE OperandInfo(OperandInfo&&) noexcept = default; + C10_ALWAYS_INLINE OperandInfo& operator=(OperandInfo&&) noexcept = default; + C10_ALWAYS_INLINE ~OperandInfo() = default; + + /// The data pointer. This may be different from tensor->data_ptr() if the + /// iterator is split. + void* data = nullptr; + + /// Stride after broadcasting. The stride is in bytes, not number of elements. + StrideVector stride_bytes; + + /// The desired device and type for the operand. For inputs, this specifies + /// that the input should be converted to this type if necessary. For outputs, + /// this specifies which type to allocate. target_dtype and device are + /// initialized with the dtype and device of the tensor but during type + /// promotion target_dtype value can become different from tensor's dtype + /// also, during type promotion target_dtype and device can be set for an + /// undefined tensor so that tensor can be properly constructed later. + std::optional device = std::nullopt; + ScalarType target_dtype = ScalarType::Undefined; + // Caches dtype of the tensor, because scalar_type is an expensive operation + // If dtype of the tensor is changed (e.g. as a result of type promotion or in + // allocate_outputs), this + // value should be changed too. + ScalarType current_dtype = ScalarType::Undefined; + + bool is_device_defined() const { + return device.has_value(); + } + bool is_type_defined() const { + return target_dtype != ScalarType::Undefined; + } + TensorOptions options() const { + return TensorOptions(target_dtype).device(device); + } + + bool is_output = false; + + // will_resize is only for output tensor. + // 1) Functional call(like torch.add(self, other)): output tensor is + // undefined, and pytorch creates a new tensor by using common shape + // and computed stride in TensorIterator; + // 2) Inplace call(like torch.add_(self, other)): output tensor is same + // with input tensor, and can't to modify tensor's size and stride; + // 3) Op call with output(like torch.add(self, other, out = output)): + // output tensor is defined, but tensor shape maybe different with common + // shape. If tensor shape is not same with common shape, this output + // tensor will be resized by using common shape and computed stride in + // TensorIterator. Otherwise can't modify tensor's size and stride. + bool will_resize = false; + + bool is_read_write = false; + + bool is_const = false; + + void validate() { + TORCH_CHECK( + !tensor_base_->defined() || tensor_base_->layout() == kStrided, + "unsupported tensor layout: ", + tensor_base_->layout()); + } + + /// The tensor operand. Note that the strides, data pointer, and + /// other attributes may differ due to dimension reordering and + /// coalescing. + const Tensor& tensor() const { + return tensor_storage_.getTensor(); + } + const TensorBase& tensor_base() const { + return *tensor_base_; + } + void tensor(c10::MaybeOwned&& tensor); + + // Save the original tensor operand in cases when an output is modified + // (e.g. if dtype is changed) + const Tensor& original_tensor() const { + return original_tensor_storage_.getTensor(); + } + const TensorBase& original_tensor_base() const { + return *original_tensor_base_; + } + + // Set tensor to a new value, and store the old tensor value in + // original_tensor Should only ever be called once for the lifetime of an + // operand + void exchange_tensor(c10::MaybeOwned&& new_tensor); + + // Move original_tensor back into tensor, exchange_tensor must have been + // called before + void restore_original_tensor(); + + private: + c10::MaybeOwned tensor_base_; + c10::MaybeOwned original_tensor_base_ = + c10::MaybeOwned::owned(std::in_place); + + // We store TensorBase visibly in the header to allow inline access. + // However, we sometimes need a genuine `const Tensor &` for the + // TensorIterator API. So, we also store a non-owning `Tensor` + // object in these `_storage_` variables. + internal::OpaqueOptionalTensorRef tensor_storage_; + internal::OpaqueOptionalTensorRef original_tensor_storage_; +}; + +struct SplitUntil32Bit; + +enum class FastSetupType : uint8_t { + NONE, + CONTIGUOUS, + CHANNELS_LAST, + NON_OVERLAPPING_DENSE +}; + +class TensorIteratorConfig; +struct TensorIterator; + +struct TORCH_API TensorIteratorBase : public impl::MetaBase { + using DimMask = std::bitset<64>; + using PtrVector = SmallVector; + using StrideVector = SmallVector; + + void build(TensorIteratorConfig&); + + // The inner-loop function operates on the fastest moving dimension. It + // implements element-wise operations in terms of 1-d strided tensors. + // + // Arguments: + // data: data pointers for each operand (length `ntensors`) + // strides: stride for each operand (length `ntensors`) + // size: size of inner loop + // + // The `size` often matches shape[0], but may be smaller due to + // parallelization of the inner loop. + using loop2d_t = c10::function_ref< + void(char** data, const int64_t* strides, int64_t size0, int64_t size1)>; + + using loop_subiter_t = c10::function_ref; + + void foreach_reduced_elt(loop_subiter_t loop, bool parallelize = true); + + int ndim() const { + return static_cast(shape_.size()); + } + IntArrayRef shape() const { + return shape_; + } + int64_t numel() const; + int ntensors() const { + return static_cast(operands_.size()); + } + int noutputs() const { + return num_outputs_; + } + int ninputs() const { + return ntensors() - noutputs(); + } + IntArrayRef view_offsets() const { + return view_offsets_; + } + + /// number of elements in the output operand. this is the same as numel() for + /// operations that are not reductions. + int64_t num_output_elements() const; + + /// number of reduced dimensions in a reduction operation + int num_reduce_dims() const; + + /// 1-dimensional iteration and no buffering or type conversion + bool is_trivial_1d() const; + /// Reducible to 1-dimensional and all operands are contiguous + bool is_contiguous() const; + bool is_dim_reduced(int dim) const; + + /// Accessors for each operand + IntArrayRef strides(int64_t arg) const { + return operands_[arg].stride_bytes; + } + void* data_ptr(int64_t arg) const; + ScalarType dtype(int64_t arg = 0) const { + return operands_[arg].current_dtype; + } + ScalarType common_dtype() const { + TORCH_INTERNAL_ASSERT( + common_dtype_ != ScalarType::Undefined, + "Queried for invalid common dtype!"); + return common_dtype_; + } + ScalarType input_dtype(int64_t arg = 0) const { + return operands_[num_outputs_ + arg].current_dtype; + } + Device device(int64_t arg = 0) const { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return operands_[arg].device.value(); + } + c10::DeviceType device_type(int64_t arg = 0) const { + return device(arg).type(); + } + int64_t element_size(int64_t arg) const { + return static_cast(elementSize(dtype(arg))); + } + bool is_scalar(int64_t arg) const; + bool is_cpu_scalar(int64_t arg) const; + + const TensorBase& tensor_base(int64_t arg) const { + return operands_[arg].tensor_base(); + } + const Tensor& tensor(int64_t arg) const { + return operands_[arg].tensor(); + } + + const TensorBase& output_base(int64_t arg = 0) const { + AT_ASSERT(arg < num_outputs_); + return tensor_base(arg); + } + + const Tensor& output(int64_t arg = 0) const { + AT_ASSERT(arg < num_outputs_); + return tensor(arg); + } + + const TensorBase& input_base(int64_t arg = 0) const { + AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); + return tensor_base(num_outputs_ + arg); + } + const Tensor& input(int64_t arg = 0) const { + AT_ASSERT(arg >= 0 && arg < ntensors() - num_outputs_); + return tensor(num_outputs_ + arg); + } + + // Copies from temporary outputs back to the original outputs + // NOTE: only used on CPU + void cast_outputs(); + + /// Removes an operand from this iterator + void remove_operand(int64_t arg); + /// Shrinks an iterated dimension + void narrow(int dim, int64_t start, int64_t size); + /// Narrows every dim after and including `start_dim` to size one. + void select_all_keeping_dim(int start_dim, IntArrayRef starts); + /// Replaces the data pointer for the operand at index `arg`. + /// The new pointer should have the same sizes, strides and dtype as the + /// original + void unsafe_replace_operand(int64_t arg, void* data); + + /// Splits this TensorIterator into two iterators. Together they iterate over + /// the entire operation. Used by `with_32bit_indexing()`. + std::unique_ptr split(int dim); + + /// Returns the dimension with the largest extent: (size[dim]-1) * stride[dim] + int get_dim_to_split() const; + + template + T scalar_value(int64_t arg) { + auto& op = operands_[arg]; + return c10::fetch_and_cast(op.tensor_base().scalar_type(), op.data); + } + + /// Return scalar value from original_tensor_base if it is defined. When + /// common_dtype is Half, casting scalar input to common_dtype might overflow. + /// If the scalar is aleady given in the type of Half, then return scalar + /// value from tensor_base. + template + T original_scalar_value(int64_t arg) { + auto& original_tensor_base = operands_[arg].original_tensor_base(); + if (original_tensor_base.defined()) { + TORCH_INTERNAL_ASSERT( + original_tensor_base.scalar_type() != common_dtype()); + return c10::fetch_and_cast( + original_tensor_base.scalar_type(), + original_tensor_base.const_data_ptr()); + } else { + return scalar_value(arg); + } + } + + private: + template + auto loop_2d_from_1d(const loop1d_t& loop) { + return + [loop, ntensor = ntensors()]( + char** base, const int64_t* strides, int64_t size0, int64_t size1) { + PtrVector data(base, base + ntensor); + const int64_t* outer_strides = &strides[ntensor]; + for (const auto i : c10::irange(size1)) { + if (i > 0) { + for (const auto arg : c10::irange(ntensor)) { + data[arg] += outer_strides[arg]; + } + } + loop(data.data(), strides, size0); + } + }; + } + + public: + template < + typename loop1d_t, + std::enable_if_t< + std::is_convertible_v< + loop1d_t, + c10::function_ref< + void(char**, const int64_t* strides, int64_t size)>>, + int> = 0> + void for_each(loop1d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE) { + for_each(loop_2d_from_1d(loop), grain_size); + } + + void for_each(loop2d_t loop, int64_t grain_size = at::internal::GRAIN_SIZE); + + void parallel_reduce(loop2d_t loop); + + template < + typename loop1d_t, + std::enable_if_t< + std::is_convertible_v< + loop1d_t, + c10::function_ref< + void(char**, const int64_t* strides, int64_t size)>>, + int> = 0> + void serial_for_each(loop1d_t loop, Range range) { + serial_for_each(loop_2d_from_1d(loop), range); + } + + void serial_for_each(loop2d_t loop, Range range) const; + + /// Create a strides array for a Tensor with shape of this iterator. The + /// parameter `element_size` specifies the size of Tensor's data type in + /// bytes (e.g. `4` for `float`) + StrideVector compatible_stride(int64_t element_size) const; + + /// Inverts the re-ordering done by reorder_dimensions. This can only be + /// called *before* coalesce_dimensions() is called. + DimVector invert_perm(IntArrayRef input) const; + + /// Reapply same re-ordering as it is done by reorder_dimensions. This can + /// only be called *before* coalesce_dimensions() is called. + DimVector apply_perm_and_mul(IntArrayRef input, int mul) const; + + /// Helper functions for CPU iteration + StrideVector get_dim_strides(int dim) const; + StrideVector get_strides() const; + StrideVector get_inner_strides() const { + return get_dim_strides(0); + } + PtrVector get_base_ptrs() const; + + // Helper functions for advanced stride manipulations (e.g. torch.flip) + void _unsafe_set_arg_strides(const int64_t arg, IntArrayRef strides) { + operands_[arg].stride_bytes = strides; + } + void _unsafe_set_arg_data(const int64_t arg, void* data) { + operands_[arg].data = data; + } + + // Helper functions for custom device, custom device can get OperandInfo and + // NameVector in their side. + const OperandInfo& operand(int arg = 0) const { + return operands_[arg]; + } + OperandInfo& operand(int arg = 0) { + return operands_[arg]; + } + NameVector& get_dim_names() { + return names_; + } + const NameVector& get_dim_names() const { + return names_; + } + + /// true if the stride computation can use 32-bit arithmetic. Used by GPU + /// kernels + bool can_use_32bit_indexing() const; + + /// An "iteratable" object that recursively splits this iterator into + /// sub-iterators that can use 32-bit indexing. + SplitUntil32Bit with_32bit_indexing() const; + + /// If the kernel should accumulate into the output. Only relevant for CUDA + /// reductions. + bool should_accumulate() const { + return accumulate_; + } + + /// Whether this iterator produces the actual output, + /// as opposed to something that will be accumulated further. Only relevant + /// for CUDA reductions. + bool is_final_output() const { + return final_output_; + } + + bool has_contiguous_first_dim() const { + if (ndim() == 0) { + return true; + } + + int num_tensors = ntensors(); + for (const auto i : c10::irange(num_tensors)) { + if (strides(i)[0] != element_size(i)) { + return false; + } + } + return true; + } + + void set_output_raw_strided( + int64_t output_idx, + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options, + DimnameList names) override; + +#define TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, maybestatic) \ + maybestatic void methodname( \ + TensorBase&& out, const TensorBase& a, const TensorBase& b) = delete; \ + maybestatic void methodname( \ + const TensorBase& out, TensorBase&& a, const TensorBase& b) = delete; \ + maybestatic void methodname( \ + const TensorBase& out, const TensorBase& a, TensorBase&& b) = delete; \ + maybestatic void methodname( \ + TensorBase&& out, TensorBase&& a, const TensorBase& b) = delete; \ + maybestatic void methodname( \ + TensorBase&& out, const TensorBase& a, TensorBase&& b) = delete; \ + maybestatic void methodname( \ + const TensorBase& out, TensorBase&& a, TensorBase&& b) = delete; \ + maybestatic void methodname( \ + TensorBase&& out, TensorBase&& a, TensorBase&& b) = delete; + +#define TORCH_DISALLOW_TEMPORARIES(methodname) \ + TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, ) + + void build_binary_float_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_borrowing_binary_float_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_float_op) + void build_binary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_borrowing_binary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_binary_op) + void build_unary_float_op(const TensorBase& out, const TensorBase& a); + void build_borrowing_unary_float_op( + const TensorBase& out, + const TensorBase& a); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_float_op) + void build_unary_op(const TensorBase& out, const TensorBase& a); + // Odd special case needed for pow. Has to borrow the output because + // it's a structured kernel, but the argument is potentially a copy. + void build_output_borrowing_argument_owning_unary_op( + const TensorBase& out, + const TensorBase& a); + void build_borrowing_unary_op(const TensorBase& out, const TensorBase& a); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_op) + void build_borrowing_unary_force_boolean_op( + const TensorBase& out, + const TensorBase& a); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_unary_force_boolean_op) + void build_comparison_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_borrowing_comparison_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(build_borrowing_comparison_op) + // Another special case: we need to own the second argument for comparison + // ops. + void build_borrowing_except_last_argument_comparison_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + void build_ternary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b, + const TensorBase& c); + +#undef TORCH_DISALLOW_TEMPORARIES + protected: + // Mutable reference as it moves tensors out of TensorIteratorConfig + void populate_operands(TensorIteratorConfig&); + void mark_outputs(); + void mark_resize_outputs(const TensorIteratorConfig&); + void compute_mem_overlaps(const TensorIteratorConfig&); + void compute_shape(const TensorIteratorConfig&); + void compute_strides(const TensorIteratorConfig&); + void reorder_dimensions(); + void permute_dimensions(IntArrayRef perm); + void compute_types(const TensorIteratorConfig&); + ScalarType compute_common_dtype(); + void allocate_or_resize_outputs(); + bool fast_set_up(const TensorIteratorConfig&); + FastSetupType compute_fast_setup_type(const TensorIteratorConfig&); + void compute_names(const TensorIteratorConfig&); + void propagate_names_to_outputs(); + void coalesce_dimensions(); + + protected: + /// Records the "computation" shape of the output tensor. The computation + /// shape is different from the regular shape in a few ways: + /// + /// - The shape may be permuted (via permute_dimensions) so that we + /// process the dimensions in the most computationally efficient order + /// (rather than the logical order given to us by the users.) + /// - The shape may have adjacent dimensions collapsed (via + /// coalesce_dimensions) so that we minimize the number of + /// dimensions we have to explicitly iterate over. For example, + /// a pointwise operation on a contiguous tensor "computationally" + /// consists of only a single dimension. + /// + /// In other words, the computation shape is the output shape as it + /// actually matters for implementing the kernel, but not necessarily the + /// output shape that the user will see in the end. + /// + /// The lifecycle of mutations to shape_ in TensorIterator: + /// - declare_static_shape() sets an initial shape explicitly + /// provided by user, otherwise + /// - compute_shape() computes the true (non-computational) shape + /// specified by the user. + /// - reorder_dimensions() reorders dimensions to improve coalescing. + /// - coalesce_dimensions() then coalesces adjacent dimensions when + /// possible. + /// + /// The shape may also be further modified if we create sub-TensorIterators, + /// e.g., via narrow or select_all_keeping_dim. + DimVector shape_; + + /// Temporarily records the permutation computed by reorder_dimensions. + /// This permutation maps the computation output dimension (dim) to + /// the original true output dimension (perm_[dim]). It is used by + /// invert_perm to undo the permutation. After coalesce_dimensions is + /// called, the permutation is no longer valid (as, in general, there + /// is no permutation that will make computation dimensions to + /// output dimensions); methods that manipulate perm_ are obligated + /// to test that !has_coalesced_dimensions + DimVector perm_; + + /// Has coalesce_dimensions() (or any moral equivalent, e.g., fast_build()) + /// been called? This is SOLELY used to check validity of perm_. + bool has_coalesced_dimensions_ = false; + + /// Whether iteration must be fixed. This disables dimension permuting and + /// also changes how for_each divides work among threads. + bool enforce_linear_iteration_ = false; + + /// The index offsets into the original tensors for each dimension. + /// This is only non-zero when you narrow() a TensorIterator (e.g., + /// when you make sub-TensorIterators). + DimVector view_offsets_; + + /// The computed names of the output tensor. Computed by compute_names() + NameVector names_; + + /// The operands of the TensorIterator: both the inputs and outputs. The + /// outputs MUST come first in the operands_ list. There is always an + /// operand for each output of the TensorIterator, even if TensorIterator + /// will ultimately be responsible for allocating the output; in those + /// cases, tensor is simply undefined (and will be populated later + /// during build()). + /// + /// This list is initially populated prior to build(), but build() mutates + /// OperandInfo to populate more information. + SmallVector operands_; + + /// Number of outputs in operands_ (the length of the outputs prefix + /// in operands_). + int num_outputs_ = 0; + + /// Whether or not all operands have the same shape and are 1d+. Having all + /// the same shape affects whether or not the iterator is eligible for fast + /// setup. + bool all_ops_same_shape_ = false; + /// Whether or not all operands are 0d, this affects type promotion + bool all_ops_are_scalars_ = false; + + /// The "computation" dtype of TensorIterator, specifying what the dtype + /// we will do the internal computation in TensorIterator. Typically, + /// this matches the dtype of the output tensors, but not always! + ScalarType common_dtype_ = ScalarType::Undefined; + + /// This is currently defined as kCPU, or the device of the first non-CPU + /// tensor argument. See TensorIteratorBase::compute_types for details. + Device common_device_ = kCPU; + + /// Set by split(), see should_accumulate() and is_final_output() + bool accumulate_ = false; + bool final_output_ = true; + + // From TensorIteratorConfig + bool is_reduction_ = false; + + /// Set by populate_operands(), says if we're handling meta tensors + bool is_meta_ = false; +}; + +struct TORCH_API TensorIterator final : public TensorIteratorBase { + TensorIterator() : TensorIteratorBase() {} + // Slicing is OK, TensorIterator guaranteed NOT to have any fields + TensorIterator(const TensorIteratorBase& iter) : TensorIteratorBase(iter) {} + +#define TORCH_DISALLOW_TEMPORARIES(methodname) \ + TORCH_DISALLOW_TEMPORARIES_IMPL(methodname, static) + + static TensorIterator binary_float_op( + TensorBase& out, + const TensorBase& a, + const TensorBase& b); + static TensorIterator binary_op( + TensorBase& out, + const TensorBase& a, + const TensorBase& b); + static TensorIterator borrowing_binary_op( + const TensorBase& out, + const TensorBase& a, + const TensorBase& b); + TORCH_DISALLOW_TEMPORARIES(borrowing_binary_op) + static TensorIterator comparison_op( + TensorBase& out, + const TensorBase& a, + const TensorBase& b); + static TensorIterator unary_op(TensorBase& out, const TensorBase& a); + static TensorIterator unary_float_op(TensorBase& out, const TensorBase& a); + static TensorIterator nullary_op(TensorBase& out); + static TensorIterator borrowing_nullary_op(const TensorBase& out); + static TensorIterator borrowing_nullary_op(TensorBase&& out) = delete; + static TensorIterator reduce_op(TensorBase& out, const TensorBase& a); + static TensorIterator reduce_op( + TensorBase& out1, + TensorBase& out2, + const TensorBase& a); +#undef TORCH_DISALLOW_TEMPORARIES +#undef TORCH_DISALLOW_TEMPORARIES_IMPL + + const Tensor& maybe_get_output(int64_t output_idx) override; + void set_output_raw_strided( + int64_t output_idx, + IntArrayRef sizes, + IntArrayRef strides, + TensorOptions options, + DimnameList names) override; +}; + +class TORCH_API TensorIteratorConfig final { + public: + friend struct TensorIteratorBase; + friend struct TensorIterator; + + TensorIteratorConfig() = default; + + C10_DISABLE_COPY_AND_ASSIGN(TensorIteratorConfig); + TensorIteratorConfig(TensorIteratorConfig&&) = default; + TensorIteratorConfig& operator=(TensorIteratorConfig&&) = default; + ~TensorIteratorConfig() = default; + + /// Construction + // Stores input/output Tensors without incrementing the reference count. + // Important: the outputs have to be added before the inputs. + TensorIteratorConfig& add_output(const TensorBase& output) { + return add_borrowed_output(output); + } + TensorIteratorConfig& add_input(const TensorBase& input) { + return add_borrowed_input(input); + } + TensorIteratorConfig& add_const_input(const TensorBase& input) { + return add_borrowed_const_input(input); + } + + // Borrowing from temporaries is unlikely to go well. + TensorIteratorConfig& add_output(TensorBase&& output) = delete; + TensorIteratorConfig& add_input(TensorBase&& input) = delete; + TensorIteratorConfig& add_const_input(TensorBase&& input) = delete; + + // Stores input/output Tensors while incrementing the reference count. + // Note that add_{in,out}put are nearly always what you + // want, and the exception (adding an unnamed temporary) won't + // compile. + TensorIteratorConfig& add_owned_output(const TensorBase& output); + TensorIteratorConfig& add_owned_input(const TensorBase& input); + TensorIteratorConfig& add_owned_const_input(const TensorBase& input); + + // Advanced API: stores input/output Tensors without incrementing + // the reference count. The caller must ensure that these Tensors + // live at least as long as this TensorIteratorConfig and any + // TensorIteratorBase built from this TensorIteratorConfig. + // Important: the outputs have to be added before the inputs. + TensorIteratorConfig& add_borrowed_output(const TensorBase& output); + TensorIteratorConfig& add_borrowed_input(const TensorBase& input); + TensorIteratorConfig& add_borrowed_const_input(const TensorBase& input); + + // Borrowing from temporaries is unlikely to go well. + TensorIteratorConfig& add_borrowed_output(TensorBase&& output) = delete; + TensorIteratorConfig& add_borrowed_input(TensorBase&& input) = delete; + TensorIteratorConfig& add_borrowed_const_input(TensorBase&& input) = delete; + + // Sets the check_mem_overlap_ flag, which is true by default. + // If true, inputs are checked for partial overlap with the outputs and + // outputs are checked for internal overlap (e.g. broadcasted views). An error + // is raised if unacceptable overlap is detected. + // If you're migrating an existing operator to using TensorIterator, please + // consider if the previous implementation checked memory overlap. If it did + // not, and if the operator is idempotent (for example, Tensor.fill_(0)), then + // checking memory overlap is BC-breaking. Please don't check memory overlap + // in that case. + TensorIteratorConfig& set_check_mem_overlap(bool check_mem_overlap) { + check_mem_overlap_ = check_mem_overlap; + return *this; + } + + // Sets the check_all_same_dtype_ flag, which is true by default + // If true, checks that all inputs and defined outputs have the same dtype + // Setting either of promote_inputs_to_common_dtype_ + // or cast_common_dtype_to_outputs_ to true will set + // check_all_same_dtype_ to false. + TensorIteratorConfig& check_all_same_dtype(const bool _check_all_same_dtype) { + check_all_same_dtype_ = _check_all_same_dtype; + return *this; + } + + // Sets the check_all_same_device_ flag, which is true by default + // If true, all operands must be on the same device, with the possible + // exception of CPU scalars, which can be passed to some CUDA kernels + // as kernel arguments. + TensorIteratorConfig& check_all_same_device( + const bool _check_all_same_device) { + check_all_same_device_ = _check_all_same_device; + return *this; + } + + // Sets the enforce_safe_casting_to_output_ flag, which is false by default + // If true, the iterator's "common dtype" must be computable + // (see the [Common Dtype Computation] note) and + // canCast(common dtype, output dtype) must be true for all outputs. + TensorIteratorConfig& enforce_safe_casting_to_output( + const bool _enforce_safe_casting_to_output) { + enforce_safe_casting_to_output_ = _enforce_safe_casting_to_output; + return *this; + } + + // Sets the enforce_linear_iteration_ flag, which is false by default. + // If true, iteration goes in the same order as a C-contiguous tensor + // is layed out in memory. i.e. last dimension iterates fastest. + // + // This iteration order can be less efficient and may even prevent + // vectorization. So only use if the correctness of your kernel depends on it. + TensorIteratorConfig& enforce_linear_iteration( + const bool _enforce_linear_iteration = true) { + enforce_linear_iteration_ = _enforce_linear_iteration; + return *this; + } + + // Sets the promote_inputs_to_common_dtype_ flag, which is false by default + // If true, the iterator's "common dtype" is always computed (see the + // [Common Dtype Computation] note) and, on the CPU, temporary copies of + // the inputs in the common dtype are passed as the actual inputs to + // the operation. + // Setting this flag to true sets check_all_same_dtype_ to false. + TensorIteratorConfig& promote_inputs_to_common_dtype( + const bool _promote_inputs_to_common_dtype) { + promote_inputs_to_common_dtype_ = _promote_inputs_to_common_dtype; + if (_promote_inputs_to_common_dtype) { + check_all_same_dtype_ = false; + } + return *this; + } + + // Sets the promote_integer_inputs_to_float_ flag, which is false by default + // NOTE: If set to true, the promote_inputs_to_common_dtype_ must also be + // true. If true, if the iterator's "common dtype" is an integral type + // (including bool) + // then it is changed to the default float scalar type. + TensorIteratorConfig& promote_integer_inputs_to_float( + const bool _promote_integer_inputs_to_float) { + promote_integer_inputs_to_float_ = _promote_integer_inputs_to_float; + TORCH_INTERNAL_ASSERT( + !promote_integer_inputs_to_float_ || promote_inputs_to_common_dtype_); + return *this; + } + + TensorIteratorConfig& is_reduction(const bool _is_reduction) { + is_reduction_ = _is_reduction; + return *this; + } + + TensorIteratorConfig& allow_cpu_scalars(const bool _allow_cpu_scalars) { + allow_cpu_scalars_ = _allow_cpu_scalars; + return *this; + } + + // Sets the cast_common_dtype_to_outputs_ flag, which is false by default + // If true, the iterator's "common dtype" must be computatable + // (see the [Common Dtype Computation] note) and, on the CPU, temporary + // copies of the outputs are passed as the actual output to the operation. + // These temporaries are then copied to the original outputs after + // the operation is performed (see cast_outputs()). + // Setting this flag to true sets check_all_same_dtype_ to false. + TensorIteratorConfig& cast_common_dtype_to_outputs( + const bool _cast_common_dtype_to_outputs) { + cast_common_dtype_to_outputs_ = _cast_common_dtype_to_outputs; + if (_cast_common_dtype_to_outputs) { + check_all_same_dtype_ = false; + } + return *this; + } + + TensorIteratorConfig& resize_outputs(bool resize_outputs) { + resize_outputs_ = resize_outputs; + return *this; + } + + // Bypass output dtype/device computation and fix the dtype/device as + // specified here. + TensorIteratorConfig& declare_static_dtype_and_device( + ScalarType dtype, + Device device); + TensorIteratorConfig& declare_static_dtype(ScalarType dtype); + TensorIteratorConfig& declare_static_device(Device device); + TensorIteratorConfig& declare_static_shape(IntArrayRef shape); + TensorIteratorConfig& declare_static_shape( + IntArrayRef shape, + IntArrayRef squash_dims); + + // It would be better if this was && qualified, but this would be at the cost + // of a lot of boilerplate above + TensorIterator build() { + TensorIterator iter; + iter.build(*this); + return iter; + } + + private: + bool is_tensor_const(size_t idx); + + SmallVector, 4> tensors_; + int num_outputs_ = 0; + int num_inputs_ = 0; + + std::optional static_shape_ = std::nullopt; + std::optional static_dtype_ = std::nullopt; + std::optional static_device_ = std::nullopt; + bool check_mem_overlap_ = true; + bool allow_cpu_scalars_ = false; + bool is_reduction_ = false; + bool resize_outputs_ = true; + bool check_all_same_dtype_ = true; + bool check_all_same_device_ = true; + bool enforce_safe_casting_to_output_ = false; + bool enforce_linear_iteration_ = false; + bool promote_inputs_to_common_dtype_ = false; + bool promote_integer_inputs_to_float_ = false; + bool cast_common_dtype_to_outputs_ = false; + + SmallVector const_tensor_indices_; +}; + +/// A container-like struct that acts as if it contains splits of a +/// TensorIterator that can use 32-bit indexing. Taken together the splits cover +/// the original TensorIterator. +struct TORCH_API SplitUntil32Bit { + // NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) + struct TORCH_API iterator { + iterator() = default; + iterator(const TensorIteratorBase& iter); + iterator(iterator&&) = default; + iterator& operator=(iterator&&) = default; + ~iterator() = default; + + // Guaranteed to be a TensorIterator proper! + TensorIterator& operator*() const; + iterator& operator++(); + bool operator==(const iterator& other) const { + // two iterators are equal if they are the same object or they're both + // empty + return this == &other || (vec.empty() && other.vec.empty()); + } + // needed for C++11 range-based for loop + bool operator!=(const iterator& other) const { + return !(*this == other); + } + + /// stack of TensorIterators to be split + std::vector> vec; + }; + + SplitUntil32Bit(const TensorIteratorBase& iter) : iter(iter) {} + + iterator begin() const; + iterator end() const; + + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const TensorIteratorBase& iter; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorIteratorInternal.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorIteratorInternal.h new file mode 100644 index 0000000000000000000000000000000000000000..ed8c9674a5b5530f23f5448bf4b00bc10af3467d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorIteratorInternal.h @@ -0,0 +1,72 @@ +#pragma once +#include +#include +#include + +namespace at { + +struct DimCounter { + DimCounter(IntArrayRef shape, Range range); + + void increment(const std::array& step); + bool is_done() const; + std::array max_2d_step() const; + + IntArrayRef shape; + Range range; + c10::SmallBuffer values; + int64_t offset; +}; + +namespace internal { + +inline void get_data_ptrs( + char** ptrs, + ArrayRef base, + IntArrayRef strides, + IntArrayRef counter) { + const auto ntensors = base.size(); + const auto ndim = counter.size(); + std::copy(base.begin(), base.end(), ptrs); + for (const auto dim : c10::irange(ndim)) { + int64_t value = counter[dim]; + for (const auto arg : c10::irange(ntensors)) { + ptrs[arg] += value * strides[dim * ntensors + arg]; + } + } +} + +inline void serial_for_each( + IntArrayRef shape, + IntArrayRef strides, + char** base_ptrs, + size_t ntensors, + typename TensorIteratorBase::loop2d_t loop, + Range range) { + const auto ndim = shape.size(); + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + strides.size() == ntensors * std::max(size_t{2}, ndim)); + + if (ndim <= 1) { + if (range.begin == 0) { + loop(base_ptrs, strides.data(), range.size(), 1); + } else { + c10::SmallBuffer ptrs(ntensors); + get_data_ptrs(ptrs.data(), {base_ptrs, ntensors}, strides, {range.begin}); + loop(ptrs.data(), strides.data(), range.size(), 1); + } + } else { + c10::SmallBuffer ptrs(ntensors); + auto counter = DimCounter(shape, range); + while (!counter.is_done()) { + get_data_ptrs( + ptrs.data(), {base_ptrs, ntensors}, strides, counter.values); + auto step = counter.max_2d_step(); + loop(ptrs.data(), strides.data(), step[0], step[1]); + counter.increment(step); + } + } +} + +} // namespace internal +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorMeta.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorMeta.h new file mode 100644 index 0000000000000000000000000000000000000000..8c3ba35e4922b0bf026f26bc8da4036fb6515a14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorMeta.h @@ -0,0 +1,137 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { + +class Tensor; + +namespace impl { + +// Use this to define the prototype for a meta function. There are two +// versions; one that takes one argument (just the operator name), or FUNC2 +// variant that takes two arguments (operator name and overload name). +// +// Example usage: +// +// TORCH_META_FUNC2(add, Tensor) ( +// const Tensor& self, const Tensor& other +// ) { +// ... compute sizes and options ... +// set_output(sizes, options); +// } +// +#define TORCH_META_FUNC(name) void structured_##name::meta +#define TORCH_META_FUNC2(name, overload) \ + void structured_##name##_##overload::meta + +// These are versions of TORCH_META_FUNC(2) that include a precompute_out struct +// as a return value. They should be used when the kernel in question has +// precomputed values declared in native_functions.yaml and the corresponding +// implementation should return an instance of the aforementioned struct. +#define TORCH_PRECOMPUTE_META_FUNC(name) \ + structured_##name::meta_return_ty structured_##name::meta +#define TORCH_PRECOMPUTE_META_FUNC2(name, overload) \ + structured_##name##_##overload::meta_return_ty \ + structured_##name##_##overload::meta + +// Use this to create a precompute struct in a meta function. +#define TORCH_PRECOMPUTE_STRUCT(name) structured_##name::precompute_out<> +#define TORCH_PRECOMPUTE_STRUCT2(name, overload) \ + structured_##name##_##overload::precompute_out<> + +// Use this to define the prototype for an implementation. This takes only +// one argument, which is the name of the dispatch key entry you're +// implementing. +// +// Example usage: +// +// TORCH_IMPL_FUNC(add_cpu) ( +// Tensor& result, const Tensor& self, const Tensor& other +// ) { +// ... do the actual implementation ... +// } +// +#define TORCH_IMPL_FUNC(name) void structured_##name::impl + +// Base class for all structured kernel classes. The set_output virtual +// method is varied depending whether or not the operator is +// functional/out/inplace, and could also be specialized for CPU/CUDA/etc +// (although presently it isn't). +// +// A notable subclass of this interface is TensorIteratorBase. +struct TORCH_API MetaBase { + MetaBase() = default; + MetaBase(const MetaBase&) = default; + MetaBase& operator=(const MetaBase&) = default; + MetaBase(MetaBase&&) noexcept = default; + MetaBase& operator=(MetaBase&&) noexcept = default; + virtual const Tensor& maybe_get_output(int64_t output_idx) = 0; + + // Note: [set_output_*] + // See: https://github.com/pytorch/pytorch/issues/69813 + // Whenever defining the output properties in the META function of a + // structured kernel (what was usually done with `set_output`), use one of + // these 3 variants, instead. In order to decide which variant to use, check + // the following decision tree: + // + // - Can the kernel you are going to implement support output tensors + // with arbitrary strides? + // | + // -- YES: `set_output_raw_strided` + // | + // -- NO: Should the output tensor strides be contiguous? + // | + // -- YES: `set_output_contiguous` + // | + // -- NO: `set_output_strided` + // + // Use this function whenever the kernel requires specific strides for the + // output. If `strides` does not match the given output strides, proxy outputs + // will be created and passed to the IMPL function. + virtual void set_output_strided( + int64_t output_idx [[maybe_unused]], + IntArrayRef sizes [[maybe_unused]], + IntArrayRef strides [[maybe_unused]], + TensorOptions options [[maybe_unused]], + DimnameList names [[maybe_unused]] = {}) { + TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented."); + } + + // Use this function whenever the kernel knows how to handle arbitrary strided + // outputs. This function has the same behavior as the old `set_output`: it + // will only re-stride if the given output was resized. + virtual void set_output_raw_strided( + int64_t output_idx [[maybe_unused]], + IntArrayRef sizes [[maybe_unused]], + IntArrayRef strides_hint [[maybe_unused]], + TensorOptions options [[maybe_unused]], + DimnameList names [[maybe_unused]] = {}) { + TORCH_INTERNAL_ASSERT(false, "set_output_strided not implemented."); + } + + // Use this function if the kernel requires contiguous strides. + // Alias for `set_output_strided`, but with contiguous strides. + void set_output_contiguous( + int64_t output_idx, + IntArrayRef sizes, + TensorOptions options, + DimnameList names = {}) { + auto strides = c10::contiguous_strides(sizes); + set_output_strided(output_idx, sizes, strides, options, names); + } + + // Returns a reference to an undefined tensor if there is no presupplied + // output + const Tensor& maybe_get_output() { + return maybe_get_output(0); + } + virtual ~MetaBase() = default; +}; + +} // namespace impl + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorNames.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorNames.h new file mode 100644 index 0000000000000000000000000000000000000000..adbbd1b16a57c14e4267adb85c5932c54e603234 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorNames.h @@ -0,0 +1,75 @@ +#pragma once + +#include + +namespace at::namedinference { + +// TensorName and TensorNames are wrappers around Dimname and DimnameList +// that contain helper functions to make writing name inference rules easier. +// +// A TensorName represents a Dimname associated with some DimnameList (from a +// Tensor). This encapsulates all the information that is needed to check if +// names *match* and to *unify* names. +// +// Definition: Two names in two tensors *match* if they are equal, or if at +// least one of them is a wildcard that can be *refined* to the other name. +// +// Definition: unify(name, other) fails if the names do not match. Otherwise, +// it returns the most refined of name and other. +// +// Here is an example of checking if two names match. +// tensor: Tensor[A, None] +// other: Tensor[A] +// +// Let's say we wish to check if tensor.names[-1] matches other.names[-1]. +// None (in tensor) cannot match A (in other) because if the None were refined +// to A, `tensor` would have duplicate names [A, A]. Therefore we need to check +// tensor.names [A, None] for the existence of A. +struct TORCH_API TensorName { + explicit TensorName(ArrayRef origin, int origin_idx) + : origin_(origin), + name_(origin[maybe_wrap_dim( + origin_idx, + static_cast(origin.size()))]), + origin_idx_(origin_idx) {} + + // op_name is only used for error reporting. + const TensorName& unify(const TensorName& other, const char* op_name) const; + Dimname toDimname() const; + + private: + ArrayRef origin_; + Dimname name_; + int origin_idx_; // A named tensor can have at most 64 dims. + + TORCH_API friend std::ostream& operator<<( + std::ostream& out, + const TensorName& tensorname); +}; + +using TensorNameVec = SmallVector; + +struct TORCH_API TensorNames { + explicit TensorNames(ArrayRef names); + + // Create TensorNames from names[start:end]. Each individual TensorName stores + // `names`, NOT names[start:end], because the original tensor's names are + // `names`. + explicit TensorNames(ArrayRef names, int64_t start, int64_t end); + + // op_name is only used for error reporting. + TensorNames& unifyFromRightInplace( + const TensorNames& other, + const char* op_name = "unify"); + void checkUnique(const char* op_name) const; + + void append(TensorName name); + std::vector toDimnameVec() const; + + private: + explicit TensorNames(TensorNameVec&& names) : names_(std::move(names)) {} + + TensorNameVec names_; +}; + +} // namespace at::namedinference diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorOperators.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorOperators.h new file mode 100644 index 0000000000000000000000000000000000000000..096f7777dc538364b47ad0f9a4c2954efc11efbc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorOperators.h @@ -0,0 +1,51 @@ +#pragma once + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at { + +#define AT_FORALL_BINARY_OPS(_) \ + _(+, x.add(y), y.add(x)) \ + _(*, x.mul(y), y.mul(x)) \ + _(-, \ + x.sub(y), \ + ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).sub_(y)) \ + _(/, \ + x.div(y), \ + ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).div_(y)) \ + _(%, \ + x.remainder(y), \ + ::at::empty_like(y, at::MemoryFormat::Preserve).fill_(x).remainder_(y)) \ + _(&, x.bitwise_and(y), y.bitwise_and(x)) \ + _(|, x.bitwise_or(y), y.bitwise_or(x)) \ + _(^, x.bitwise_xor(y), y.bitwise_xor(x)) \ + _(<, x.lt(y), y.gt(x)) \ + _(<=, x.le(y), y.ge(x)) \ + _(>, x.gt(y), y.lt(x)) \ + _(>=, x.ge(y), y.le(x)) \ + _(==, x.eq(y), y.eq(x)) \ + _(!=, x.ne(y), y.ne(x)) + +#define DEFINE_OPERATOR(op, body, reverse_scalar_body) \ + inline Tensor operator op(const Tensor& x, const Tensor& y) { \ + return body; \ + } \ + inline Tensor operator op(const Tensor& x, const Scalar& y) { \ + return body; \ + } \ + inline Tensor operator op(const Scalar& x, const Tensor& y) { \ + return reverse_scalar_body; \ + } + +AT_FORALL_BINARY_OPS(DEFINE_OPERATOR) +#undef DEFINE_OPERATOR +#undef AT_FORALL_BINARY_OPS + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorOptions.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorOptions.h new file mode 100644 index 0000000000000000000000000000000000000000..0ff746e88800b39930641108b39a92661a5d8257 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorOptions.h @@ -0,0 +1,2 @@ +#pragma once +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..7cf9dfc65497f09969f8a61b3ee4fb227db7456d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorSubclassLikeUtils.h @@ -0,0 +1,88 @@ +#pragma once +#include +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at { + +// Note [Tensor-subclass-like Tensors] +// Tensor-subclass-like is defined as: +// - a Tensor subclass (via __torch_dispatch__ in Python or extending +// TensorImpl in C++) +// - anything else that shares the same perils as Tensor subclasses. +// For example, many Tensor subclasses do not have storage and meta Tensors +// do not have storage either, so meta Tensors belong here. +// +// We should ensure that PyTorch internals supports Tensor-subclass-like +// objects. In particular, Tensor-subclass-like objects struggle with two +// classes of operations that are problematic for Tensor subclasses: +// 1. Because some Tensor subclasses do not have storage, .item() or +// .data_ptr() calls are not good. +// 2. Certain in-place operations can eliminate the typing of the Tensor +// subclass. For example: +// >>> torch.zeros(input.sizes(), grad.options()).diag().copy_(input) +// If input is a Tensor subclass, then the above ends up either erroring out +// or returning a regular non-Tensor-subclass Tensor! + +constexpr auto kFunctorchWrappedTensors = DispatchKeySet( + {DispatchKey::FuncTorchGradWrapper, + DispatchKey::FuncTorchBatched, + DispatchKey::Functionalize}); + +constexpr auto kTensorSubclassLike = + kFunctorchWrappedTensors | + DispatchKeySet( + {// WARNING: DO NOT put combined backend component + functionality keys + // here, you will incorrectly always match on the functionality key + // no matter the backend component + DispatchKey::Batched, + DispatchKey::Sparse, + DispatchKey::SparseCsr, + DispatchKey::Python}) | + DispatchKeySet(BackendComponent::MetaBit); + +inline bool isTensorSubclassLike(const Tensor& tensor) { + if (c10::impl::dispatch_mode_enabled()) + return true; + auto key_set = tensor.unsafeGetTensorImpl()->key_set(); + return !(key_set & kTensorSubclassLike).empty(); +} + +inline bool areAnyTensorSubclassLike(TensorList tensors) { + if (c10::impl::dispatch_mode_enabled()) + return true; + return std::any_of(tensors.begin(), tensors.end(), isTensorSubclassLike); +} + +inline bool areAnyOptionalTensorSubclassLike( + const c10::List>& tensors) { + if (c10::impl::dispatch_mode_enabled()) + return true; + return std::any_of( + tensors.begin(), + tensors.end(), + [](const std::optional& opt_tensor) { + return ( + opt_tensor.has_value() && isTensorSubclassLike(opt_tensor.value())); + }); +} + +// Helper function to deal testing truthfulness of a scalar tensor +// in a Composite Compliant manner. +// NOTE: This function expects a scalar tensor of boolean dtype. +// Eg. +// Non-Composite Compliant Pattern : (t == 0).all().item() +// Composite Compliant Patter : is_salar_tensor_true((t == 0).all()) +inline bool is_scalar_tensor_true(const Tensor& t) { + TORCH_INTERNAL_ASSERT(t.dim() == 0) + TORCH_INTERNAL_ASSERT(t.scalar_type() == kBool) + return at::equal(t, t.new_ones({}, t.options())); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TensorUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/TensorUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..6030f8138047d4b9121a0e4d3944edfc4daf63be --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TensorUtils.h @@ -0,0 +1,190 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +// These functions are NOT in Utils.h, because this file has a dep on Tensor.h + +#define TORCH_CHECK_TENSOR_ALL(cond, ...) \ + TORCH_CHECK((cond)._is_all_true().item(), __VA_ARGS__); + +namespace at { + +// The following are utility functions for checking that arguments +// make sense. These are particularly useful for native functions, +// which do NO argument checking by default. + +struct TORCH_API TensorArg { + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const Tensor& tensor; + const char* name; + int pos; // 1-indexed + TensorArg(const Tensor& tensor, const char* name, int pos) + : tensor(tensor), name(name), pos(pos) {} + // Try to mitigate any possibility of dangling reference to temporaries. + // NOLINTNEXTLINE(cppcoreguidelines-rvalue-reference-param-not-moved) + TensorArg(Tensor&& tensor, const char* name, int pos) = delete; + const Tensor* operator->() const { + return &tensor; + } + const Tensor& operator*() const { + return tensor; + } +}; + +struct TORCH_API TensorGeometryArg { + TensorGeometry tensor; + const char* name; + int pos; // 1-indexed + /* implicit */ TensorGeometryArg(TensorArg arg) + : tensor(TensorGeometry{arg.tensor}), name(arg.name), pos(arg.pos) {} + TensorGeometryArg(TensorGeometry tensor, const char* name, int pos) + : tensor(std::move(tensor)), name(name), pos(pos) {} + const TensorGeometry* operator->() const { + return &tensor; + } + const TensorGeometry& operator*() const { + return tensor; + } +}; + +// A string describing which function did checks on its input +// arguments. +// TODO: Consider generalizing this into a call stack. +using CheckedFrom = const char*; + +// The undefined convention: singular operators assume their arguments +// are defined, but functions which take multiple tensors will +// implicitly filter out undefined tensors (to make it easier to perform +// tests which should apply if the tensor is defined, and should not +// otherwise.) +// +// NB: This means that the n-ary operators take lists of TensorArg, +// not TensorGeometryArg, because the Tensor to TensorGeometry +// conversion will blow up if you have undefined tensors. + +TORCH_API std::ostream& operator<<( + std::ostream& out, + const TensorGeometryArg& t); +TORCH_API void checkDim( + CheckedFrom c, + const Tensor& tensor, + const char* name, + int pos, // 1-indexed + int64_t dim); +TORCH_API void checkDim(CheckedFrom c, const TensorGeometryArg& t, int64_t dim); +// NB: this is an inclusive-exclusive range +TORCH_API void checkDimRange( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t dim_start, + int64_t dim_end); +TORCH_API void checkSameDim( + CheckedFrom c, + const TensorGeometryArg& t1, + const TensorGeometryArg& t2); +TORCH_API void checkContiguous(CheckedFrom c, const TensorGeometryArg& t); +TORCH_API void checkAllContiguous(CheckedFrom c, at::ArrayRef ts); +TORCH_API void checkSize( + CheckedFrom c, + const TensorGeometryArg& t, + IntArrayRef sizes); +TORCH_API void checkSize_symint( + CheckedFrom c, + const TensorGeometryArg& t, + c10::SymIntArrayRef sizes); +TORCH_API void checkSize( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t dim, + int64_t size); +TORCH_API void checkSize_symint( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t dim, + const c10::SymInt& size); +TORCH_API void checkNumel( + CheckedFrom c, + const TensorGeometryArg& t, + int64_t numel); +TORCH_API void checkSameNumel( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& t2); +TORCH_API void checkAllSameNumel(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkScalarType(CheckedFrom c, const TensorArg& t, ScalarType s); +TORCH_API void checkScalarTypes( + CheckedFrom c, + const TensorArg& t, + at::ArrayRef l); +TORCH_API void checkSameGPU( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& t2); +TORCH_API void checkAllSameGPU(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkSameType( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& t2); +TORCH_API void checkAllSameType(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkSameSize( + CheckedFrom c, + const TensorArg& t1, + const TensorArg& t2); +TORCH_API void checkAllSameSize(CheckedFrom c, ArrayRef tensors); +TORCH_API void checkDefined(CheckedFrom c, const TensorArg& t); +TORCH_API void checkAllDefined(CheckedFrom c, at::ArrayRef t); + +// FixMe: does TensorArg slow things down? +TORCH_API void checkBackend( + CheckedFrom c, + at::ArrayRef t, + at::Backend backend); + +TORCH_API void checkDeviceType( + CheckedFrom c, + at::ArrayRef tensors, + at::DeviceType device_type); + +TORCH_API void checkLayout(CheckedFrom c, const Tensor& t, Layout layout); + +TORCH_API void checkLayout( + CheckedFrom c, + at::ArrayRef tensors, + at::Layout layout); + +// Methods for getting data_ptr if tensor is defined +TORCH_API void* maybe_data_ptr(const Tensor& tensor); +TORCH_API void* maybe_data_ptr(const TensorArg& tensor); + +TORCH_API void check_dim_size( + const Tensor& tensor, + int64_t dim, + int64_t dim_size, + int64_t size); + +namespace detail { +TORCH_API std::vector defaultStrides(IntArrayRef sizes); + +TORCH_API std::optional> computeStride( + IntArrayRef oldshape, + IntArrayRef oldstride, + IntArrayRef newshape); + +TORCH_API std::optional computeStride( + c10::SymIntArrayRef oldshape, + c10::SymIntArrayRef oldstride, + c10::SymIntArrayRef newshape); + +TORCH_API std::optional computeStride( + IntArrayRef oldshape, + IntArrayRef oldstride, + const DimVector& newshape); + +} // namespace detail +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h b/phivenv/Lib/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h new file mode 100644 index 0000000000000000000000000000000000000000..efa50483c44c8ce7df838d1a86ece780e0d5d3a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ThreadLocalPythonObjects.h @@ -0,0 +1,21 @@ +#pragma once + +#include +#include +#include + +namespace at::impl { + +struct TORCH_API ThreadLocalPythonObjects { + static void set(const std::string& key, std::shared_ptr value); + static const std::shared_ptr& get(const std::string& key); + static bool contains(const std::string& key); + + static const ThreadLocalPythonObjects& get_state(); + static void set_state(ThreadLocalPythonObjects state); + + private: + std::unordered_map> obj_dict_; +}; + +} // namespace at::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ThreadLocalState.h b/phivenv/Lib/site-packages/torch/include/ATen/ThreadLocalState.h new file mode 100644 index 0000000000000000000000000000000000000000..84b903ffe708c68859a6b263023a62b4e8236426 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ThreadLocalState.h @@ -0,0 +1,124 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace at { + +// Thread local state contains values that are preserved across +// thread boundaries (e.g. at::launch/JIT fork, autograd). +// Note at::parallel_for doesn't preserve TLS across thread boundaries. +class TORCH_API ThreadLocalState { + public: + // Saves the thread local variables' values and + // returns them as a ThreadLocalState + ThreadLocalState(); + + // set_grad_mode - force the value of the grad mode TLS in + // the current state object. This is used for example in the + // autograd engine. + void set_grad_mode(bool enabled); + + // set_multithreading_enabled - force the value of the multithreadinmaximum + // threads TLS in + // the current state object. This is used for example in the + // autograd engine. + void set_multithreading_enabled(bool enabled); + + // Sets thread local variables in the current thread, + // according to the thread boundary specified + static void setThreadLocalState(const ThreadLocalState& state); + + private: + c10::impl::LocalDispatchKeySet dispatch_key_; + + // ThreadLocalDebugInfo does not change after being created + // with DebugInfoGuard + std::shared_ptr debug_info_; + + // RecordFunction TLS + RecordFunctionTLS rf_tls_; + + // TLS for out-of-tree functorch + // See NOTE [functorch TLS in pytorch/pytorch] for why this needs to be a + // pointer (spoiler alert: it's due to the indirection) + // This needs to be a shared_ptr instead of a unique_ptr because + // ThreadLocalState is copy-able and does indeed get copied. Maybe we can + // consider adding an explicit copy constructor for ThreadLocalState in the + // future but I didn't want to add one just for this. + std::shared_ptr functorch_tls_; + + // TLS for AutogradModes + AutogradState autograd_tls_; + + // TLS for enable_torch_dispatch_mode + c10::impl::TorchDispatchModeTLS torch_dispatch_mode_state_; + + // TLS for enable_python_dispatcher + c10::impl::PyInterpreter* python_dispatcher_state_; + + // TLS for __torch_function__ (mode and disable_torch_function) + at::impl::PythonTorchFunctionTLS python_torch_function_state_; + + // TLS for saved tensors default hooks + at::impl::SavedTensorDefaultHooksTLS saved_tensors_default_hooks_state_; + + bool functionalization_reapply_views_state_; + + // TLS for arbitrary python objects that is registered via hooks + at::impl::ThreadLocalPythonObjects saved_objects_; + +#if !defined(CAFFE2_IS_XPLAT_BUILD) && !defined(C10_MOBILE) && \ + !defined(BUILD_LITE_INTERPRETER) + // TLS for autocast dtypes + std::array + autocast_dtypes_{}; +#endif + + friend class ThreadLocalStateGuard; +}; + +// Guard to set and reset the thread local state +class TORCH_API ThreadLocalStateGuard { + public: + explicit ThreadLocalStateGuard(const ThreadLocalState& state) + : prev_state_(ThreadLocalState()) { + // set the given state across the thread boundary + ThreadLocalState::setThreadLocalState(state); + } + ThreadLocalStateGuard(ThreadLocalStateGuard&& other) = delete; + ThreadLocalStateGuard(const ThreadLocalStateGuard&) = delete; + ThreadLocalStateGuard& operator=(const ThreadLocalStateGuard&) = delete; + ThreadLocalStateGuard& operator=(ThreadLocalStateGuard&&) = delete; + + ~ThreadLocalStateGuard() { + // restore previously set variables + ThreadLocalState::setThreadLocalState(prev_state_); + } + + private: + // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members) + const ThreadLocalState prev_state_; +}; + +template +auto wrapPropagateTLSState(T callback) { + return [tls_state = ThreadLocalState(), + callback = std::move(callback)](auto&&... args) { + ThreadLocalStateGuard g(tls_state); + // Propagate value returned by callback(). + return callback(std::forward(args)...); + }; +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TracerMode.h b/phivenv/Lib/site-packages/torch/include/ATen/TracerMode.h new file mode 100644 index 0000000000000000000000000000000000000000..2ed0fc0048a0ada85e3539bfd23cf69dd9df83e5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TracerMode.h @@ -0,0 +1,132 @@ +#pragma once + +#include +#include +#include + +// NOTE [Tracing Mode Switches] +// +// Historically, tracing function was controlled by two switches: +// +// - `AutoDispatchBelowADInplaceOrView` guard +// +// Tracing function used to be script-generated inside `VariableType_*.cpp` +// kernels, sharing the same `Autograd` dispatch key with autograd function. +// Therefore, before tracing function was moved out of VariableType, +// `AutoDispatchBelowADInplaceOrView` guard can also disable tracing as a +// side effect of disabling `Autograd` dispatching. +// +// - `setTracingState()` API in `torch/csrc/jit/frontend/tracer.h` +// +// It stores tracing data in a `TracingState` object in TLS. If the +// `TracingState` object in TLS is `null`, then tracing is paused. +// +// The `TracingState` object is created in `tracer::trace()` - the main +// entrance of tracing function. It's temporarily set to `null` inside +// generated VariableType (now TraceType) to bypass tracing for intermediate +// ops (ops being called by other ops). After the intermediate op call +// finishes it's set back to the original `TracingState` object. +// +// The `TracingState` obect in TLS can also be read/written via its Python +// binding in `python_tracer.cpp`, and `get/setTracingState()` C++ APIs, +// which are also exposed as `TORCH_API`. +// +// Two new switches were introduced since tracing function was moved out of +// VariableType: +// +// - `tracer::impl::set_dispatch_enabled()` API +// +// Unlike the special `Autograd` dispatch key which is included in dispatch +// key set by default, `Tracer` dispatch key is off by default. The +// dispatching switch can be toggled via this new API. +// +// - `tracer::impl::NoTracerDispatchMode` guard +// +// It's used to cover the old semantics of `AutoDispatchBelowADInplaceOrView` +// after tracing was moved out of VariableType. +// +// Before tracing function was moved out of VariableType, tracing was enabled +// when the following conditions are satisfied: +// +// 1) `TracingState` object in TLS != null; +// - Either inside the execution scope of `tracer::trace()`, or +// - Eagerly called `setTracingState()` with non-null object. +// 2) Not inside `AutoDispatchBelowADInplaceOrView` scope; +// +// After: +// +// 1) `TracingState` object in TLS != null; +// 2) Has called `tracer::impl::set_dispatch_enabled(true)`; +// 3) Not inside `tracer::impl::NonDispatchGuard` scope; +// +// [TODOs] +// +// - `setTracingState()` v.s. `tracer::impl::set_dispatch_enabled()` +// +// Currently `set_dispatch_enabled()` is set/unset inside `setTracingState()` +// to keep the semantics exactly the same as before - it's confusing to keep +// both switches, though. We should consider simplifying/limiting the exposed +// `setTracingState()` Python/C++ APIs (and other APIs calling it) so that +// these two can be unified. +// +// - `AutoDispatchBelowADInplaceOrView` v.s. +// `tracer::impl::NoTracerDispatchMode` +// +// We don't need to always set both guards together to keep semantics +// unchanged. For the follow use cases of `AutoDispatchBelowADInplaceOrView` +// we don't need set the new tracer guard: +// +// * Script-generated VariableType kernels. The guard is not necessary as +// tracing is already disabled explicitly by `setTracingState(null)` in +// generated TraceType kernels - we could keep it as is or use the new guard +// instead. +// +// * Custom ops. Will be handled by fallback kernel for `Tracer`. +// +// * Functions that are not likely to be called in tracing context (no python +// binding / not an operator), e.g.: all mobile forward() wrappers, test +// binaries, and etc. +// +// * Where new threads are spawned, e.g.: ATen/native/ConvolutionMM2d.cpp. +// It's not necessary as tracing is off by default. +// +// For the rest of cases we might need have both: +// +// * Functions that might be reachable from eager mode python (especially +// factory methods), e.g.: +// `internal_new_from_data()` in `torch/csrc/utils/tensor_new.cpp`. +// Without the new guard it will add `aten::empty` to the traced graph. +// +// * Some manually maintained functions, e.g.: +// `torch/csrc/autograd/VariableTypeManual.cpp`. +// Set the new guard if it's not obvious whether `setTracingState(null)` +// has been called before it reaches the `AutoDispatchBelowADInplaceOrView` +// guard. +// +// We might need tweak the usage of the new guard to optimize/fix things. +// It should only affect the correctness of tracing function, because the +// guard is essentially no-op when the master `setTracingState()` switch is +// off. + +// TODO: move this from `at::` to `jit::torch::` after +// `aten/src/ATen/cpp_custom_type_hack.h` is removed. + +namespace at::tracer::impl { + +inline bool is_dispatch_enabled() { + return c10::impl::tls_is_dispatch_key_included(at::DispatchKey::Tracer) && + !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer); +} + +inline void set_dispatch_enabled(bool enabled) { + TORCH_INTERNAL_ASSERT( + !c10::impl::tls_is_dispatch_key_excluded(at::DispatchKey::Tracer), + "Cannot enable tracing within the scope of NoTracerDispatchMode!"); + c10::impl::tls_set_dispatch_key_included(at::DispatchKey::Tracer, enabled); +} + +struct NoTracerDispatchMode { + c10::impl::ExcludeDispatchKeyGuard guard_{at::DispatchKey::Tracer}; +}; + +} // namespace at::tracer::impl diff --git a/phivenv/Lib/site-packages/torch/include/ATen/TypeDefault.h b/phivenv/Lib/site-packages/torch/include/ATen/TypeDefault.h new file mode 100644 index 0000000000000000000000000000000000000000..f0d178cd3eeefa71bdf4ca43fb269b0c4bf3d4c5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/TypeDefault.h @@ -0,0 +1,26 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +struct Storage; +} + +namespace at { + +class Tensor; +using TensorList = ArrayRef; + +class Context; +struct Generator; + +struct Quantizer; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Utils.h b/phivenv/Lib/site-packages/torch/include/ATen/Utils.h new file mode 100644 index 0000000000000000000000000000000000000000..f61ad74b296e867a6deea9d380d25530ba6b6c68 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Utils.h @@ -0,0 +1,138 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#define AT_DISALLOW_COPY_AND_ASSIGN(TypeName) \ + TypeName(const TypeName&) = delete; \ + void operator=(const TypeName&) = delete + +namespace at { + +TORCH_API int _crash_if_asan(int); + +// Converts a TensorList (i.e. ArrayRef to vector of TensorImpl*) +// NB: This is ONLY used by legacy TH bindings, and ONLY used by cat. +// Once cat is ported entirely to ATen this can be deleted! +inline std::vector checked_dense_tensor_list_unwrap( + ArrayRef tensors, + const char* name, + int pos, + c10::DeviceType device_type, + ScalarType scalar_type) { + std::vector unwrapped; + unwrapped.reserve(tensors.size()); + for (const auto i : c10::irange(tensors.size())) { + const auto& expr = tensors[i]; + if (expr.layout() != Layout::Strided) { + TORCH_CHECK( + false, + "Expected dense tensor but got ", + expr.layout(), + " for sequence element ", + i, + " in sequence argument at position #", + pos, + " '", + name, + "'"); + } + if (expr.device().type() != device_type) { + TORCH_CHECK( + false, + "Expected object of device type ", + device_type, + " but got device type ", + expr.device().type(), + " for sequence element ", + i, + " in sequence argument at position #", + pos, + " '", + name, + "'"); + } + if (expr.scalar_type() != scalar_type) { + TORCH_CHECK( + false, + "Expected object of scalar type ", + scalar_type, + " but got scalar type ", + expr.scalar_type(), + " for sequence element ", + i, + " in sequence argument at position #", + pos, + " '", + name, + "'"); + } + unwrapped.emplace_back(expr.unsafeGetTensorImpl()); + } + return unwrapped; +} + +template +std::array check_intlist( + ArrayRef list, + const char* name, + int pos) { + if (list.empty()) { + // TODO: is this necessary? We used to treat nullptr-vs-not in IntList + // differently with strides as a way of faking optional. + list = {}; + } + auto res = std::array(); + if (list.size() == 1 && N > 1) { + res.fill(list[0]); + return res; + } + if (list.size() != N) { + TORCH_CHECK( + false, + "Expected a list of ", + N, + " ints but got ", + list.size(), + " for argument #", + pos, + " '", + name, + "'"); + } + std::copy_n(list.begin(), N, res.begin()); + return res; +} + +using at::detail::check_size_nonnegative; + +namespace detail { + +template +TORCH_API Tensor tensor_cpu(ArrayRef values, const TensorOptions& options); + +template +TORCH_API Tensor +tensor_backend(ArrayRef values, const TensorOptions& options); + +template +TORCH_API Tensor +tensor_complex_cpu(ArrayRef values, const TensorOptions& options); + +template +TORCH_API Tensor +tensor_complex_backend(ArrayRef values, const TensorOptions& options); +} // namespace detail + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/Version.h b/phivenv/Lib/site-packages/torch/include/ATen/Version.h new file mode 100644 index 0000000000000000000000000000000000000000..e73a907744b5129c2bca29f441b37bb36d4ee82e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/Version.h @@ -0,0 +1,18 @@ +#include + +namespace at { + +/// Returns a detailed string describing the configuration PyTorch. +TORCH_API std::string show_config(); + +TORCH_API std::string get_mkl_version(); + +TORCH_API std::string get_mkldnn_version(); + +TORCH_API std::string get_openmp_version(); + +TORCH_API std::string get_cxx_flags(); + +TORCH_API std::string get_cpu_capability(); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h b/phivenv/Lib/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h new file mode 100644 index 0000000000000000000000000000000000000000..3be5df5372d460bb5b8832fd37deb762cf620ec2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/VmapGeneratedPlumbing.h @@ -0,0 +1,28093 @@ + +#pragma once +#include +#include + +namespace at { namespace functorch { + +template +at::Tensor _cast_Byte_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Byte::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Char_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Char::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Double_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Double::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Float_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Float::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Int_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Int::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Long_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Long::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Short_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Short::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cast_Half_generated_plumbing(const at::Tensor & self, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_cast_Half::call(self, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _backward_generated_plumbing(const at::Tensor & self, at::TensorList inputs, const ::std::optional & gradient, ::std::optional retain_graph, bool create_graph) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(inputs, cur_level) && !isBatchedAtLevel(gradient, cur_level)) { + return at::_ops::_backward::call(self, inputs, gradient, retain_graph, create_graph); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional gradient_value; + std::optional gradient_bdim; + if (gradient) { + std::tie(gradient_value, gradient_bdim) = unwrapTensorAtLevel(gradient.value(), cur_level); + } + batch_rule(self_value, self_bdim, inputs, gradient_value, gradient_bdim, retain_graph, create_graph); +} +template +void set_data_generated_plumbing(at::Tensor & self, const at::Tensor & new_data) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(new_data, cur_level)) { + return at::_ops::set_data::call(self, new_data); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [new_data_value, new_data_bdim] = unwrapTensorAtLevel(new_data, cur_level); + batch_rule(self_value, self_bdim, new_data_value, new_data_bdim); +} +template +at::Tensor data_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::data::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & requires_grad__generated_plumbing(at::Tensor & self, bool requires_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::requires_grad_::call(self, requires_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, requires_grad); + return self; +} +template +void retain_grad_generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::retain_grad::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); +} +template +at::Tensor _fw_primal_generated_plumbing(const at::Tensor & self, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fw_primal::call(self, level); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _make_dual_generated_plumbing(const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(primal, cur_level) && !isBatchedAtLevel(tangent, cur_level)) { + return at::_ops::_make_dual::call(primal, tangent, level); + } + auto [primal_value, primal_bdim] = unwrapTensorAtLevel(primal, cur_level); + auto [tangent_value, tangent_bdim] = unwrapTensorAtLevel(tangent, cur_level); + auto results = batch_rule(primal_value, primal_bdim, tangent_value, tangent_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _unpack_dual_generated_plumbing(const at::Tensor & dual, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dual, cur_level)) { + return at::_ops::_unpack_dual::call(dual, level); + } + auto [dual_value, dual_bdim] = unwrapTensorAtLevel(dual, cur_level); + auto results = batch_rule(dual_value, dual_bdim, level); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _new_zeros_with_same_feature_meta_generated_plumbing(const at::Tensor & self, const at::Tensor & other, int64_t self_num_batch_dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_new_zeros_with_same_feature_meta::call(self, other, self_num_batch_dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, self_num_batch_dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rename_generated_plumbing(const at::Tensor & self, ::std::optional names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rename::call(self, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, names); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor align_to_generated_plumbing(const at::Tensor & self, at::DimnameList names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::align_to::call(self, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, names); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor align_to_ellipsis_idx_generated_plumbing(const at::Tensor & self, at::DimnameList order, int64_t ellipsis_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::align_to_ellipsis_idx::call(self, order, ellipsis_idx); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, order, ellipsis_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor align_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::align_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector align_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::align_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _assert_async_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_assert_async::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); +} +template +void _assert_async_msg_generated_plumbing(const at::Tensor & self, c10::string_view assert_msg) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_assert_async_msg::call(self, assert_msg); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, assert_msg); +} +template +at::Tensor _functional_assert_scalar_generated_plumbing(const at::Scalar & self, c10::string_view assert_msg, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dep_token, cur_level)) { + return at::_ops::_functional_assert_scalar::call(self, assert_msg, dep_token); + } + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(self, assert_msg, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _functional_assert_async_msg_generated_plumbing(const at::Tensor & self, c10::string_view assert_msg, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(dep_token, cur_level)) { + return at::_ops::_functional_assert_async_msg::call(self, assert_msg, dep_token); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(self_value, self_bdim, assert_msg, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _assert_tensor_metadata_generated_plumbing(const at::Tensor & a, at::OptionalSymIntArrayRef size, at::OptionalSymIntArrayRef stride, ::std::optional dtype, ::std::optional device, ::std::optional layout) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(a, cur_level)) { + return at::_ops::_assert_tensor_metadata::call(a, size, stride, dtype, device, layout); + } + auto [a_value, a_bdim] = unwrapTensorAtLevel(a, cur_level); + batch_rule(a_value, a_bdim, size, stride, dtype, device, layout); +} +template +at::Tensor _functional_sym_constrain_range_generated_plumbing(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dep_token, cur_level)) { + return at::_ops::_functional_sym_constrain_range::call(size, min, max, dep_token); + } + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(size, min, max, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _functional_sym_constrain_range_for_size_generated_plumbing(const at::Scalar & size, ::std::optional min, ::std::optional max, const at::Tensor & dep_token) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dep_token, cur_level)) { + return at::_ops::_functional_sym_constrain_range_for_size::call(size, min, max, dep_token); + } + auto [dep_token_value, dep_token_bdim] = unwrapTensorAtLevel(dep_token, cur_level); + auto results = batch_rule(size, min, max, dep_token_value, dep_token_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor refine_names_generated_plumbing(const at::Tensor & self, at::DimnameList names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::refine_names::call(self, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, names); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _cudnn_ctc_loss_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level)) { + return at::_ops::_cudnn_ctc_loss::call(log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, blank, deterministic, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _cudnn_ctc_loss_Tensor_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool deterministic, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && !isBatchedAtLevel(target_lengths, cur_level)) { + return at::_ops::_cudnn_ctc_loss_Tensor::call(log_probs, targets, input_lengths, target_lengths, blank, deterministic, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, cur_level); + auto [target_lengths_value, target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, blank, deterministic, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _cudnn_rnn_flatten_weight_generated_plumbing(at::TensorList weight_arr, int64_t weight_stride0, c10::SymInt input_size, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight_arr, cur_level)) { + return at::_ops::_cudnn_rnn_flatten_weight::call(weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + } + + auto results = batch_rule(weight_arr, weight_stride0, input_size, mode, hidden_size, proj_size, num_layers, batch_first, bidirectional); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _cudnn_rnn_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const ::std::optional & weight_buf, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(dropout_state, cur_level)) { + return at::_ops::_cudnn_rnn::call(input, weight, weight_stride0, weight_buf, hx, cx, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional weight_buf_value; + std::optional weight_buf_bdim; + if (weight_buf) { + std::tie(weight_buf_value, weight_buf_bdim) = unwrapTensorAtLevel(weight_buf.value(), cur_level); + } + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple> _cudnn_rnn_backward_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && !isBatchedAtLevel(reserve, cur_level)) { + return at::_ops::_cudnn_rnn_backward::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _fused_dropout_generated_plumbing(const at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fused_dropout::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, generator); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _masked_scale_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, double scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_masked_scale::call(self, mask, scale); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, scale); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_dropout_generated_plumbing(const at::Tensor & input, double p, ::std::optional train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::native_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor native_dropout_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & mask, double scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::native_dropout_backward::call(grad_output, mask, scale); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, mask_value, mask_bdim, scale); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _sobol_engine_draw_generated_plumbing(const at::Tensor & quasi, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(quasi, cur_level) && !isBatchedAtLevel(sobolstate, cur_level)) { + return at::_ops::_sobol_engine_draw::call(quasi, n, sobolstate, dimension, num_generated, dtype); + } + auto [quasi_value, quasi_bdim] = unwrapTensorAtLevel(quasi, cur_level); + auto [sobolstate_value, sobolstate_bdim] = unwrapTensorAtLevel(sobolstate, cur_level); + auto results = batch_rule(quasi_value, quasi_bdim, n, sobolstate_value, sobolstate_bdim, dimension, num_generated, dtype); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor & _sobol_engine_ff__generated_plumbing(at::Tensor & self, int64_t n, const at::Tensor & sobolstate, int64_t dimension, int64_t num_generated) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(sobolstate, cur_level)) { + return at::_ops::_sobol_engine_ff_::call(self, n, sobolstate, dimension, num_generated); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [sobolstate_value, sobolstate_bdim] = unwrapTensorAtLevel(sobolstate, cur_level); + batch_rule(self_value, self_bdim, n, sobolstate_value, sobolstate_bdim, dimension, num_generated); + return self; +} +template +at::Tensor & _sobol_engine_scramble__generated_plumbing(at::Tensor & self, const at::Tensor & ltm, int64_t dimension) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(ltm, cur_level)) { + return at::_ops::_sobol_engine_scramble_::call(self, ltm, dimension); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [ltm_value, ltm_bdim] = unwrapTensorAtLevel(ltm, cur_level); + batch_rule(self_value, self_bdim, ltm_value, ltm_bdim, dimension); + return self; +} +template +at::Tensor & _sobol_engine_initialize_state__generated_plumbing(at::Tensor & self, int64_t dimension) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sobol_engine_initialize_state_::call(self, dimension); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dimension); + return self; +} +template +at::Tensor _reshape_from_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & shape) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(shape, cur_level)) { + return at::_ops::_reshape_from_tensor::call(self, shape); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [shape_value, shape_bdim] = unwrapTensorAtLevel(shape, cur_level); + auto results = batch_rule(self_value, self_bdim, shape_value, shape_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _shape_as_tensor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_shape_as_tensor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dropout_generated_plumbing(const at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor feature_dropout_generated_plumbing(const at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::feature_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & feature_dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::feature_dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor alpha_dropout_generated_plumbing(const at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::alpha_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & alpha_dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::alpha_dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor feature_alpha_dropout_generated_plumbing(const at::Tensor & input, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::feature_alpha_dropout::call(input, p, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, p, train); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & feature_alpha_dropout__generated_plumbing(at::Tensor & self, double p, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::feature_alpha_dropout_::call(self, p, train); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, train); + return self; +} +template +at::Tensor abs_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::abs::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & abs__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::abs_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor absolute_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::absolute::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & absolute__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::absolute_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor angle_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::angle::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_real_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_as_real::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_complex_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_as_complex::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sgn_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sgn::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sgn__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sgn_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor chalf_generated_plumbing(const at::Tensor & self, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::chalf::call(self, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor real_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::real::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor imag_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::imag::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _conj_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_conj::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conj_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::conj::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _conj_physical_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_conj_physical::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conj_physical_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::conj_physical::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & conj_physical__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::conj_physical_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor resolve_conj_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::resolve_conj::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor resolve_neg_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::resolve_neg::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _neg_view_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_neg_view::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor acos_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acos::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & acos__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acos_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arccos_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccos::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arccos__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccos_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor avg_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool1d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adaptive_avg_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_avg_pool1d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple adaptive_max_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_max_pool1d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor add_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::add_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & add__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::add__Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor _add_relu_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_add_relu_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _add_relu__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_add_relu__Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor _add_relu_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_add_relu_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _add_relu__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_add_relu__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, alpha); + return self; +} +template +at::Tensor add_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::add_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & add__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::add__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, alpha); + return self; +} +template +at::Tensor addmv_generated_plumbing(const at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat, cur_level) && !isBatchedAtLevel(vec, cur_level)) { + return at::_ops::addmv::call(self, mat, vec, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat_value, mat_bdim] = unwrapTensorAtLevel(mat, cur_level); + auto [vec_value, vec_bdim] = unwrapTensorAtLevel(vec, cur_level); + auto results = batch_rule(self_value, self_bdim, mat_value, mat_bdim, vec_value, vec_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addmv__generated_plumbing(at::Tensor & self, const at::Tensor & mat, const at::Tensor & vec, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat, cur_level) && !isBatchedAtLevel(vec, cur_level)) { + return at::_ops::addmv_::call(self, mat, vec, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat_value, mat_bdim] = unwrapTensorAtLevel(mat, cur_level); + auto [vec_value, vec_bdim] = unwrapTensorAtLevel(vec, cur_level); + batch_rule(self_value, self_bdim, mat_value, mat_bdim, vec_value, vec_bdim, beta, alpha); + return self; +} +template +at::Tensor addr_generated_plumbing(const at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec1, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::addr::call(self, vec1, vec2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec1_value, vec1_bdim] = unwrapTensorAtLevel(vec1, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + auto results = batch_rule(self_value, self_bdim, vec1_value, vec1_bdim, vec2_value, vec2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addr__generated_plumbing(at::Tensor & self, const at::Tensor & vec1, const at::Tensor & vec2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec1, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::addr_::call(self, vec1, vec2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec1_value, vec1_bdim] = unwrapTensorAtLevel(vec1, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + batch_rule(self_value, self_bdim, vec1_value, vec1_bdim, vec2_value, vec2_bdim, beta, alpha); + return self; +} +template +at::Tensor affine_grid_generator_generated_plumbing(const at::Tensor & theta, c10::SymIntArrayRef size, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(theta, cur_level)) { + return at::_ops::affine_grid_generator::call(theta, size, align_corners); + } + auto [theta_value, theta_bdim] = unwrapTensorAtLevel(theta, cur_level); + auto results = batch_rule(theta_value, theta_bdim, size, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor affine_grid_generator_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef size, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level)) { + return at::_ops::affine_grid_generator_backward::call(grad, size, align_corners); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(grad_value, grad_bdim, size, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _is_all_true_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_is_all_true::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _is_any_true_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_is_any_true::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_check_tensor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_check_tensor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_functorch_fallback_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_test_functorch_fallback::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor all_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::all_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor all_dims_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::all_dims::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor all_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::all_dimname::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::any_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_dims_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::any_dims::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::any_dimname::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dim_arange_generated_plumbing(const at::Tensor & like, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(like, cur_level)) { + return at::_ops::_dim_arange::call(like, dim); + } + auto [like_value, like_bdim] = unwrapTensorAtLevel(like, cur_level); + auto results = batch_rule(like_value, like_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argmax_generated_plumbing(const at::Tensor & self, ::std::optional dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argmax::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argmin_generated_plumbing(const at::Tensor & self, ::std::optional dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argmin::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor acosh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acosh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & acosh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::acosh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arccosh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccosh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arccosh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arccosh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor asinh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::asinh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & asinh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::asinh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arcsinh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsinh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arcsinh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsinh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor atanh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atanh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & atanh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atanh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arctanh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctanh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arctanh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctanh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor as_strided_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::as_strided::call(self, size, stride, storage_offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride, storage_offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor asin_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::asin::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & asin__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::asin_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arcsin_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsin::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arcsin__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arcsin_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor atan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & atan__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atan_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor arctan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arctan__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::arctan_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor atleast_1d_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atleast_1d::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector atleast_1d_Sequence_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::atleast_1d_Sequence::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor atleast_2d_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atleast_2d::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector atleast_2d_Sequence_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::atleast_2d_Sequence::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor atleast_3d_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::atleast_3d::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector atleast_3d_Sequence_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::atleast_3d_Sequence::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor baddbmm_generated_plumbing(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::baddbmm::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + auto results = batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & baddbmm__generated_plumbing(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::baddbmm_::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + return self; +} +template +at::Tensor baddbmm_dtype_generated_plumbing(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::baddbmm_dtype::call(self, batch1, batch2, out_dtype, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + auto results = batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, out_dtype, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor batch_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::batch_norm::call(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, momentum, eps, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_batch_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & var, double eps, double output_scale, int64_t output_zero_point) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(var, cur_level)) { + return at::_ops::quantized_batch_norm::call(input, weight, bias, mean, var, eps, output_scale, output_zero_point); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [var_value, var_bdim] = unwrapTensorAtLevel(var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, mean_value, mean_bdim, var_value, var_bdim, eps, output_scale, output_zero_point); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _batch_norm_impl_index_backward_generated_plumbing(int64_t impl_index, const at::Tensor & input, const at::Tensor & grad_output, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var_transform, bool train, double eps, ::std::array output_mask, const at::Tensor & reservedSpace) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var_transform, cur_level) && !isBatchedAtLevel(reservedSpace, cur_level)) { + return at::_ops::_batch_norm_impl_index_backward::call(impl_index, input, grad_output, weight, running_mean, running_var, save_mean, save_var_transform, train, eps, output_mask, reservedSpace); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [reservedSpace_value, reservedSpace_bdim] = unwrapTensorAtLevel(reservedSpace, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_transform_value; + std::optional save_var_transform_bdim; + if (save_var_transform) { + std::tie(save_var_transform_value, save_var_transform_bdim) = unwrapTensorAtLevel(save_var_transform.value(), cur_level); + } + auto results = batch_rule(impl_index, input_value, input_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_transform_value, save_var_transform_bdim, train, eps, output_mask, reservedSpace_value, reservedSpace_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor bernoulli_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bernoulli::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bernoulli__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(p, cur_level)) { + return at::_ops::bernoulli__Tensor::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [p_value, p_bdim] = unwrapTensorAtLevel(p, cur_level); + batch_rule(self_value, self_bdim, p_value, p_bdim, generator); + return self; +} +template +at::Tensor & bernoulli__float_generated_plumbing(at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bernoulli__float::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, generator); + return self; +} +template +at::Tensor bernoulli_p_generated_plumbing(const at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bernoulli_p::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bilinear_generated_plumbing(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & weight, const ::std::optional & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input1, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::bilinear::call(input1, input2, weight, bias); + } + auto [input1_value, input1_bdim] = unwrapTensorAtLevel(input1, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input1_value, input1_bdim, input2_value, input2_bdim, weight_value, weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binary_cross_entropy_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::binary_cross_entropy::call(self, target, weight, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binary_cross_entropy_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::binary_cross_entropy_backward::call(grad_output, self, target, weight, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binary_cross_entropy_with_logits_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, const ::std::optional & pos_weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(pos_weight, cur_level)) { + return at::_ops::binary_cross_entropy_with_logits::call(self, target, weight, pos_weight, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional pos_weight_value; + std::optional pos_weight_bdim; + if (pos_weight) { + std::tie(pos_weight_value, pos_weight_bdim) = unwrapTensorAtLevel(pos_weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, pos_weight_value, pos_weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bincount_generated_plumbing(const at::Tensor & self, const ::std::optional & weights, c10::SymInt minlength) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weights, cur_level)) { + return at::_ops::bincount::call(self, weights, minlength); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weights_value; + std::optional weights_bdim; + if (weights) { + std::tie(weights_value, weights_bdim) = unwrapTensorAtLevel(weights.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weights_value, weights_bdim, minlength); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_not_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_not::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_not__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_not_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor copysign_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::copysign_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & copysign__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::copysign__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor copysign_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::copysign_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & copysign__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::copysign__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor _lazy_clone_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_lazy_clone::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logical_not_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logical_not::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_not__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logical_not_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor logical_xor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_xor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_xor__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_xor_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor logical_and_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_and::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_and__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_and_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor logical_or_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_or::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logical_or__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logical_or_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::bmm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bmm_dtype_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::bmm_dtype::call(self, mat2, out_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector broadcast_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::broadcast_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor broadcast_to_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::broadcast_to::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_broadcast_to_generated_plumbing(const at::Tensor & self, at::IntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_broadcast_to::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cat_generated_plumbing(const at::ITensorListRef & tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::cat::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cat_names_generated_plumbing(at::TensorList tensors, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::cat_names::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concat_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concat::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concat_names_generated_plumbing(at::TensorList tensors, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concat_names::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concatenate_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concatenate::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor concatenate_names_generated_plumbing(at::TensorList tensors, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::concatenate_names::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor block_diag_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::block_diag::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ceil_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ceil::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ceil__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ceil_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor chain_matmul_generated_plumbing(at::TensorList matrices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(matrices, cur_level)) { + return at::_ops::chain_matmul::call(matrices); + } + + auto results = batch_rule(matrices); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unsafe_chunk_generated_plumbing(const at::Tensor & self, int64_t chunks, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsafe_chunk::call(self, chunks, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, chunks, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector chunk_generated_plumbing(const at::Tensor & self, int64_t chunks, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::chunk::call(self, chunks, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, chunks, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector tensor_split_sections_generated_plumbing(const at::Tensor & self, c10::SymInt sections, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tensor_split_sections::call(self, sections, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sections, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector tensor_split_indices_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tensor_split_indices::call(self, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector tensor_split_tensor_indices_or_sections_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor_indices_or_sections, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor_indices_or_sections, cur_level)) { + return at::_ops::tensor_split_tensor_indices_or_sections::call(self, tensor_indices_or_sections, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor_indices_or_sections_value, tensor_indices_or_sections_bdim] = unwrapTensorAtLevel(tensor_indices_or_sections, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor_indices_or_sections_value, tensor_indices_or_sections_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min, max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_Tensor_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp_Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, min_value, min_bdim, max_value, max_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clamp__generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, min, max); + return self; +} +template +at::Tensor & clamp__Tensor_generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp__Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + batch_rule(self_value, self_bdim, min_value, min_bdim, max_value, max_bdim); + return self; +} +template +at::Tensor clamp_max_generated_plumbing(const at::Tensor & self, const at::Scalar & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_max::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_max_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp_max_Tensor::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [max_value, max_bdim] = unwrapTensorAtLevel(max, cur_level); + auto results = batch_rule(self_value, self_bdim, max_value, max_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clamp_max__generated_plumbing(at::Tensor & self, const at::Scalar & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_max_::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, max); + return self; +} +template +at::Tensor & clamp_max__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clamp_max__Tensor::call(self, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [max_value, max_bdim] = unwrapTensorAtLevel(max, cur_level); + batch_rule(self_value, self_bdim, max_value, max_bdim); + return self; +} +template +at::Tensor clamp_min_generated_plumbing(const at::Tensor & self, const at::Scalar & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_min::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clamp_min_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level)) { + return at::_ops::clamp_min_Tensor::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [min_value, min_bdim] = unwrapTensorAtLevel(min, cur_level); + auto results = batch_rule(self_value, self_bdim, min_value, min_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clamp_min__generated_plumbing(at::Tensor & self, const at::Scalar & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clamp_min_::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, min); + return self; +} +template +at::Tensor & clamp_min__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & min) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level)) { + return at::_ops::clamp_min__Tensor::call(self, min); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [min_value, min_bdim] = unwrapTensorAtLevel(min, cur_level); + batch_rule(self_value, self_bdim, min_value, min_bdim); + return self; +} +template +at::Tensor clip_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clip::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min, max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clip_Tensor_generated_plumbing(const at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clip_Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, min_value, min_bdim, max_value, max_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & clip__generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clip_::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, min, max); + return self; +} +template +at::Tensor & clip__Tensor_generated_plumbing(at::Tensor & self, const ::std::optional & min, const ::std::optional & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(min, cur_level) && !isBatchedAtLevel(max, cur_level)) { + return at::_ops::clip__Tensor::call(self, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional min_value; + std::optional min_bdim; + if (min) { + std::tie(min_value, min_bdim) = unwrapTensorAtLevel(min.value(), cur_level); + } + std::optional max_value; + std::optional max_bdim; + if (max) { + std::tie(max_value, max_bdim) = unwrapTensorAtLevel(max.value(), cur_level); + } + batch_rule(self_value, self_bdim, min_value, min_bdim, max_value, max_bdim); + return self; +} +template +at::Tensor complex_generated_plumbing(const at::Tensor & real, const at::Tensor & imag) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(real, cur_level) && !isBatchedAtLevel(imag, cur_level)) { + return at::_ops::complex::call(real, imag); + } + auto [real_value, real_bdim] = unwrapTensorAtLevel(real, cur_level); + auto [imag_value, imag_bdim] = unwrapTensorAtLevel(imag, cur_level); + auto results = batch_rule(real_value, real_bdim, imag_value, imag_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor polar_generated_plumbing(const at::Tensor & abs, const at::Tensor & angle) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(abs, cur_level) && !isBatchedAtLevel(angle, cur_level)) { + return at::_ops::polar::call(abs, angle); + } + auto [abs_value, abs_bdim] = unwrapTensorAtLevel(abs, cur_level); + auto [angle_value, angle_bdim] = unwrapTensorAtLevel(angle, cur_level); + auto results = batch_rule(abs_value, abs_bdim, angle_value, angle_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor constant_pad_nd_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::constant_pad_nd::call(self, pad, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor contiguous_generated_plumbing(const at::Tensor & self, at::MemoryFormat memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::contiguous::call(self, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor convolution_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::convolution::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple convolution_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, at::OptionalSymIntArrayRef bias_sizes, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::convolution_backward::call(grad_output, input, weight, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, weight_value, weight_bdim, bias_sizes, stride, padding, dilation, transposed, output_padding, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor convolution_overrideable_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::convolution_overrideable::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple convolution_backward_overrideable_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::convolution_backward_overrideable::call(grad_output, input, weight, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, weight_value, weight_bdim, stride, padding, dilation, transposed, output_padding, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _convolution_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled, bool allow_tf32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_convolution::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled, allow_tf32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convolution_deprecated_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, c10::SymInt groups, bool benchmark, bool deterministic, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_convolution_deprecated::call(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, transposed, output_padding, groups, benchmark, deterministic, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convolution_mode_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_convolution_mode::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _convolution_double_backward_generated_plumbing(const ::std::optional & ggI, const ::std::optional & ggW, const ::std::optional & ggb, const at::Tensor & gO, const at::Tensor & weight, const at::Tensor & self, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, bool transposed, c10::SymIntArrayRef output_padding, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ggI, cur_level) && !isBatchedAtLevel(ggW, cur_level) && !isBatchedAtLevel(ggb, cur_level) && !isBatchedAtLevel(gO, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convolution_double_backward::call(ggI, ggW, ggb, gO, weight, self, stride, padding, dilation, transposed, output_padding, groups, output_mask); + } + auto [gO_value, gO_bdim] = unwrapTensorAtLevel(gO, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional ggI_value; + std::optional ggI_bdim; + if (ggI) { + std::tie(ggI_value, ggI_bdim) = unwrapTensorAtLevel(ggI.value(), cur_level); + } + std::optional ggW_value; + std::optional ggW_bdim; + if (ggW) { + std::tie(ggW_value, ggW_bdim) = unwrapTensorAtLevel(ggW.value(), cur_level); + } + std::optional ggb_value; + std::optional ggb_bdim; + if (ggb) { + std::tie(ggb_value, ggb_bdim) = unwrapTensorAtLevel(ggb.value(), cur_level); + } + auto results = batch_rule(ggI_value, ggI_bdim, ggW_value, ggW_bdim, ggb_value, ggb_bdim, gO_value, gO_bdim, weight_value, weight_bdim, self_value, self_bdim, stride, padding, dilation, transposed, output_padding, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor conv1d_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv1d::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv2d_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv2d::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv3d_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv3d::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv1d_padding_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv1d_padding::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv2d_padding_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv2d_padding::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv3d_padding_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::string_view padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv3d_padding::call(input, weight, bias, stride, padding, dilation, groups); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_tbc_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & bias, int64_t pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_tbc::call(self, weight, bias, pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, pad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple conv_tbc_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & input, const at::Tensor & weight, const at::Tensor & bias, int64_t pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_tbc_backward::call(self, input, weight, bias, pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(self_value, self_bdim, input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, pad); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor conv_transpose1d_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_transpose1d::call(input, weight, bias, stride, padding, output_padding, groups, dilation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, output_padding, groups, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_transpose2d_input_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_transpose2d_input::call(input, weight, bias, stride, padding, output_padding, groups, dilation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, output_padding, groups, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_transpose3d_input_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymInt groups, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_transpose3d_input::call(input, weight, bias, stride, padding, output_padding, groups, dilation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, output_padding, groups, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor copy_generated_plumbing(const at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & copy__generated_plumbing(at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy_::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return self; +} +template +at::Tensor _copy_from_generated_plumbing(const at::Tensor & self, const at::Tensor & dst, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(dst, cur_level)) { + return at::_ops::_copy_from::call(self, dst, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [dst_value, dst_bdim] = unwrapTensorAtLevel(dst, cur_level); + auto results = batch_rule(self_value, self_bdim, dst_value, dst_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _copy_from_and_resize_generated_plumbing(const at::Tensor & self, const at::Tensor & dst) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(dst, cur_level)) { + return at::_ops::_copy_from_and_resize::call(self, dst); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [dst_value, dst_bdim] = unwrapTensorAtLevel(dst, cur_level); + auto results = batch_rule(self_value, self_bdim, dst_value, dst_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cos_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cos::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cos__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cos_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor cosh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cosh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cosh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cosh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor cosine_embedding_loss_generated_plumbing(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input1, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::cosine_embedding_loss::call(input1, input2, target, margin, reduction); + } + auto [input1_value, input1_bdim] = unwrapTensorAtLevel(input1, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(input1_value, input1_bdim, input2_value, input2_bdim, target_value, target_bdim, margin, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor count_nonzero_dim_IntList_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::count_nonzero_dim_IntList::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor count_nonzero_generated_plumbing(const at::Tensor & self, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::count_nonzero::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cov_generated_plumbing(const at::Tensor & self, int64_t correction, const ::std::optional & fweights, const ::std::optional & aweights) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(fweights, cur_level) && !isBatchedAtLevel(aweights, cur_level)) { + return at::_ops::cov::call(self, correction, fweights, aweights); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional fweights_value; + std::optional fweights_bdim; + if (fweights) { + std::tie(fweights_value, fweights_bdim) = unwrapTensorAtLevel(fweights.value(), cur_level); + } + std::optional aweights_value; + std::optional aweights_bdim; + if (aweights) { + std::tie(aweights_value, aweights_bdim) = unwrapTensorAtLevel(aweights.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, correction, fweights_value, fweights_bdim, aweights_value, aweights_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor corrcoef_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::corrcoef::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_affine_grid_generator_generated_plumbing(const at::Tensor & theta, int64_t N, int64_t C, int64_t H, int64_t W) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(theta, cur_level)) { + return at::_ops::cudnn_affine_grid_generator::call(theta, N, C, H, W); + } + auto [theta_value, theta_bdim] = unwrapTensorAtLevel(theta, cur_level); + auto results = batch_rule(theta_value, theta_bdim, N, C, H, W); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_affine_grid_generator_backward_generated_plumbing(const at::Tensor & grad, int64_t N, int64_t C, int64_t H, int64_t W) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level)) { + return at::_ops::cudnn_affine_grid_generator_backward::call(grad, N, C, H, W); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(grad_value, grad_bdim, N, C, H, W); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple cudnn_batch_norm_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::cudnn_batch_norm::call(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, exponential_average_factor, epsilon); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple cudnn_batch_norm_backward_generated_plumbing(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon, const at::Tensor & reserveSpace) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var, cur_level) && !isBatchedAtLevel(reserveSpace, cur_level)) { + return at::_ops::cudnn_batch_norm_backward::call(input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon, reserveSpace); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [reserveSpace_value, reserveSpace_bdim] = unwrapTensorAtLevel(reserveSpace, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_value; + std::optional save_var_bdim; + if (save_var) { + std::tie(save_var_value, save_var_bdim) = unwrapTensorAtLevel(save_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_value, save_var_bdim, epsilon, reserveSpace_value, reserveSpace_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor cudnn_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::cudnn_convolution::call(self, weight, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_convolution_transpose_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic, bool allow_tf32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::cudnn_convolution_transpose::call(self, weight, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, padding, output_padding, stride, dilation, groups, benchmark, deterministic, allow_tf32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mps_convolution_transpose_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_mps_convolution_transpose::call(self, weight, padding, output_padding, stride, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, padding, output_padding, stride, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mps_convolution_transpose_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mps_convolution_transpose_backward::call(self, grad_output, weight, padding, output_padding, stride, dilation, groups, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, padding, output_padding, stride, dilation, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor cudnn_convolution_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::cudnn_convolution_relu::call(self, weight, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_convolution_add_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(z, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::cudnn_convolution_add_relu::call(self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [z_value, z_bdim] = unwrapTensorAtLevel(z, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, z_value, z_bdim, alpha, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cudnn_grid_sampler_generated_plumbing(const at::Tensor & self, const at::Tensor & grid) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::cudnn_grid_sampler::call(self, grid); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(self_value, self_bdim, grid_value, grid_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple cudnn_grid_sampler_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grid, const at::Tensor & grad_output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grid, cur_level) && !isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::cudnn_grid_sampler_backward::call(self, grid, grad_output); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(self_value, self_bdim, grid_value, grid_bdim, grad_output_value, grad_output_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple cummax_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cummax::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple cummax_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cummax_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _cummax_helper_generated_plumbing(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(values, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_cummax_helper::call(self, values, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + batch_rule(self_value, self_bdim, values_value, values_bdim, indices_value, indices_bdim, dim); +} +template +::std::tuple cummin_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cummin::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple cummin_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cummin_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _cummin_helper_generated_plumbing(const at::Tensor & self, at::Tensor & values, at::Tensor & indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(values, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_cummin_helper::call(self, values, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + batch_rule(self_value, self_bdim, values_value, values_bdim, indices_value, indices_bdim, dim); +} +template +at::Tensor cummaxmin_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & indices, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::cummaxmin_backward::call(grad, input, indices, dim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, indices_value, indices_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cumprod_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumprod__generated_plumbing(at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod_::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumprod_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod_dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumprod__dimname_generated_plumbing(at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumprod__dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumprod_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, int64_t dim, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::cumprod_backward::call(grad, input, dim, output); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, dim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cumsum_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumsum__generated_plumbing(at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum_::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumsum_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum_dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & cumsum__dimname_generated_plumbing(at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cumsum__dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, dtype); + return self; +} +template +at::Tensor cumulative_trapezoid_x_generated_plumbing(const at::Tensor & y, const at::Tensor & x, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level) && !isBatchedAtLevel(x, cur_level)) { + return at::_ops::cumulative_trapezoid_x::call(y, x, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(y_value, y_bdim, x_value, x_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cumulative_trapezoid_dx_generated_plumbing(const at::Tensor & y, const at::Scalar & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level)) { + return at::_ops::cumulative_trapezoid_dx::call(y, dx, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(y_value, y_bdim, dx, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ctc_loss_IntList_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, int64_t reduction, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level)) { + return at::_ops::ctc_loss_IntList::call(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, blank, reduction, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ctc_loss_Tensor_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, int64_t reduction, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && !isBatchedAtLevel(target_lengths, cur_level)) { + return at::_ops::ctc_loss_Tensor::call(log_probs, targets, input_lengths, target_lengths, blank, reduction, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, cur_level); + auto [target_lengths_value, target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, blank, reduction, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _ctc_loss_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level)) { + return at::_ops::_ctc_loss::call(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, blank, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _ctc_loss_Tensor_generated_plumbing(const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && !isBatchedAtLevel(target_lengths, cur_level)) { + return at::_ops::_ctc_loss_Tensor::call(log_probs, targets, input_lengths, target_lengths, blank, zero_infinity); + } + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, cur_level); + auto [target_lengths_value, target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto results = batch_rule(log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, blank, zero_infinity); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _ctc_loss_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, at::IntArrayRef input_lengths, at::IntArrayRef target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(neg_log_likelihood, cur_level) && !isBatchedAtLevel(log_alpha, cur_level)) { + return at::_ops::_ctc_loss_backward::call(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [neg_log_likelihood_value, neg_log_likelihood_bdim] = unwrapTensorAtLevel(neg_log_likelihood, cur_level); + auto [log_alpha_value, log_alpha_bdim] = unwrapTensorAtLevel(log_alpha, cur_level); + auto results = batch_rule(grad_value, grad_bdim, log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths, target_lengths, neg_log_likelihood_value, neg_log_likelihood_bdim, log_alpha_value, log_alpha_bdim, blank, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _ctc_loss_backward_Tensor_generated_plumbing(const at::Tensor & grad, const at::Tensor & log_probs, const at::Tensor & targets, const at::Tensor & input_lengths, const at::Tensor & target_lengths, const at::Tensor & neg_log_likelihood, const at::Tensor & log_alpha, int64_t blank, bool zero_infinity) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(log_probs, cur_level) && !isBatchedAtLevel(targets, cur_level) && !isBatchedAtLevel(input_lengths, cur_level) && !isBatchedAtLevel(target_lengths, cur_level) && !isBatchedAtLevel(neg_log_likelihood, cur_level) && !isBatchedAtLevel(log_alpha, cur_level)) { + return at::_ops::_ctc_loss_backward_Tensor::call(grad, log_probs, targets, input_lengths, target_lengths, neg_log_likelihood, log_alpha, blank, zero_infinity); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [log_probs_value, log_probs_bdim] = unwrapTensorAtLevel(log_probs, cur_level); + auto [targets_value, targets_bdim] = unwrapTensorAtLevel(targets, cur_level); + auto [input_lengths_value, input_lengths_bdim] = unwrapTensorAtLevel(input_lengths, cur_level); + auto [target_lengths_value, target_lengths_bdim] = unwrapTensorAtLevel(target_lengths, cur_level); + auto [neg_log_likelihood_value, neg_log_likelihood_bdim] = unwrapTensorAtLevel(neg_log_likelihood, cur_level); + auto [log_alpha_value, log_alpha_bdim] = unwrapTensorAtLevel(log_alpha, cur_level); + auto results = batch_rule(grad_value, grad_bdim, log_probs_value, log_probs_bdim, targets_value, targets_bdim, input_lengths_value, input_lengths_bdim, target_lengths_value, target_lengths_bdim, neg_log_likelihood_value, neg_log_likelihood_bdim, log_alpha_value, log_alpha_bdim, blank, zero_infinity); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diag_embed_generated_plumbing(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diag_embed::call(self, offset, dim1, dim2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagflat_generated_plumbing(const at::Tensor & self, int64_t offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagflat::call(self, offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_generated_plumbing(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagonal::call(self, offset, dim1, dim2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_diagonal_generated_plumbing(const at::Tensor & A, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_diagonal::call(A, offset, dim1, dim2); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname outdim, at::Dimname dim1, at::Dimname dim2, int64_t offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagonal_Dimname::call(self, outdim, dim1, dim2, offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, outdim, dim1, dim2, offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::diagonal_backward::call(grad_output, input_sizes, offset, dim1, dim2); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_sizes, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fill_diagonal__generated_plumbing(at::Tensor & self, const at::Scalar & fill_value, bool wrap) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fill_diagonal_::call(self, fill_value, wrap); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, fill_value, wrap); + return self; +} +template +at::Tensor diff_generated_plumbing(const at::Tensor & self, int64_t n, int64_t dim, const ::std::optional & prepend, const ::std::optional & append) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(prepend, cur_level) && !isBatchedAtLevel(append, cur_level)) { + return at::_ops::diff::call(self, n, dim, prepend, append); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional prepend_value; + std::optional prepend_bdim; + if (prepend) { + std::tie(prepend_value, prepend_bdim) = unwrapTensorAtLevel(prepend.value(), cur_level); + } + std::optional append_value; + std::optional append_bdim; + if (append) { + std::tie(append_value, append_bdim) = unwrapTensorAtLevel(append.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n, dim, prepend_value, prepend_bdim, append_value, append_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalarint_generated_plumbing(const at::Tensor & self, const ::std::optional & spacing, ::std::optional dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalarint::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalararray_generated_plumbing(const at::Tensor & self, const at::Scalar & spacing, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalararray::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_array::call(self, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalarrayint_generated_plumbing(const at::Tensor & self, at::ArrayRef spacing, ::std::optional dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalarrayint::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_scalarrayarray_generated_plumbing(const at::Tensor & self, at::ArrayRef spacing, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gradient_scalarrayarray::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_tensorarrayint_generated_plumbing(const at::Tensor & self, at::TensorList spacing, ::std::optional dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(spacing, cur_level)) { + return at::_ops::gradient_tensorarrayint::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector gradient_tensorarray_generated_plumbing(const at::Tensor & self, at::TensorList spacing, at::IntArrayRef dim, int64_t edge_order) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(spacing, cur_level)) { + return at::_ops::gradient_tensorarray::call(self, spacing, dim, edge_order); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, spacing, dim, edge_order); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor div_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor div_Tensor_mode_generated_plumbing(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div_Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Tensor_mode_generated_plumbing(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::div__Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return self; +} +template +at::Tensor div_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor div_Scalar_mode_generated_plumbing(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div_Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & div__Scalar_mode_generated_plumbing(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::div__Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, rounding_mode); + return self; +} +template +at::Tensor divide_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor divide_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor divide_Tensor_mode_generated_plumbing(const at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide_Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Tensor_mode_generated_plumbing(at::Tensor & self, const at::Tensor & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::divide__Tensor_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, rounding_mode); + return self; +} +template +at::Tensor divide_Scalar_mode_generated_plumbing(const at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide_Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, rounding_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & divide__Scalar_mode_generated_plumbing(at::Tensor & self, const at::Scalar & other, ::std::optional rounding_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::divide__Scalar_mode::call(self, other, rounding_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, rounding_mode); + return self; +} +template +at::Tensor true_divide_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::true_divide_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & true_divide__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::true_divide__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor true_divide_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::true_divide_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & true_divide__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::true_divide__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor dot_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor, cur_level)) { + return at::_ops::dot::call(self, tensor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor_value, tensor_bdim] = unwrapTensorAtLevel(tensor, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor_value, tensor_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor vdot_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::vdot::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor einsum_generated_plumbing(c10::string_view equation, at::TensorList tensors, at::OptionalIntArrayRef path) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::einsum::call(equation, tensors, path); + } + + auto results = batch_rule(equation, tensors, path); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding::call(weight, indices, padding_idx, scale_grad_by_freq, sparse); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, padding_idx, scale_grad_by_freq, sparse); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq, bool sparse) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_backward::call(grad, indices, num_weights, padding_idx, scale_grad_by_freq, sparse); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, num_weights, padding_idx, scale_grad_by_freq, sparse); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_dense_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & indices, c10::SymInt num_weights, c10::SymInt padding_idx, bool scale_grad_by_freq) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_dense_backward::call(grad_output, indices, num_weights, padding_idx, scale_grad_by_freq); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, indices_value, indices_bdim, num_weights, padding_idx, scale_grad_by_freq); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & embedding_renorm__generated_plumbing(at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_renorm_::call(self, indices, max_norm, norm_type); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + batch_rule(self_value, self_bdim, indices_value, indices_bdim, max_norm, norm_type); + return self; +} +template +at::Tensor embedding_sparse_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, int64_t num_weights, int64_t padding_idx, bool scale_grad_by_freq) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_sparse_backward::call(grad, indices, num_weights, padding_idx, scale_grad_by_freq); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, num_weights, padding_idx, scale_grad_by_freq); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _embedding_bag_forward_only_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_forward_only::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset, padding_idx); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _rowwise_prune_generated_plumbing(const at::Tensor & weight, const at::Tensor & mask, at::ScalarType compressed_indices_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_rowwise_prune::call(weight, mask, compressed_indices_dtype); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(weight_value, weight_bdim, mask_value, mask_bdim, compressed_indices_dtype); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor row_stack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::row_stack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple embedding_bag_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::embedding_bag::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple embedding_bag_padding_idx_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, ::std::optional padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::embedding_bag_padding_idx::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset, padding_idx); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _embedding_bag_generated_plumbing(const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, bool include_last_offset, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag::call(weight, indices, offsets, scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset, padding_idx); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, include_last_offset, padding_idx); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor _embedding_bag_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, bool sparse, const ::std::optional & per_sample_weights, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(offset2bag, cur_level) && !isBatchedAtLevel(bag_size, cur_level) && !isBatchedAtLevel(maximum_indices, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_backward::call(grad, indices, offsets, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto [bag_size_value, bag_size_bdim] = unwrapTensorAtLevel(bag_size, cur_level); + auto [maximum_indices_value, maximum_indices_bdim] = unwrapTensorAtLevel(maximum_indices, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, offset2bag_value, offset2bag_bdim, bag_size_value, bag_size_bdim, maximum_indices_value, maximum_indices_bdim, num_weights, scale_grad_by_freq, mode, sparse, per_sample_weights_value, per_sample_weights_bdim, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _embedding_bag_sparse_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, const at::Tensor & bag_size, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(offset2bag, cur_level) && !isBatchedAtLevel(bag_size, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_sparse_backward::call(grad, indices, offsets, offset2bag, bag_size, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto [bag_size_value, bag_size_bdim] = unwrapTensorAtLevel(bag_size, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, offset2bag_value, offset2bag_bdim, bag_size_value, bag_size_bdim, num_weights, scale_grad_by_freq, mode, per_sample_weights_value, per_sample_weights_bdim, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _embedding_bag_dense_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & indices, const at::Tensor & offset2bag, const at::Tensor & bag_size, const at::Tensor & maximum_indices, c10::SymInt num_weights, bool scale_grad_by_freq, int64_t mode, const ::std::optional & per_sample_weights, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offset2bag, cur_level) && !isBatchedAtLevel(bag_size, cur_level) && !isBatchedAtLevel(maximum_indices, cur_level) && !isBatchedAtLevel(per_sample_weights, cur_level)) { + return at::_ops::_embedding_bag_dense_backward::call(grad, indices, offset2bag, bag_size, maximum_indices, num_weights, scale_grad_by_freq, mode, per_sample_weights, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto [bag_size_value, bag_size_bdim] = unwrapTensorAtLevel(bag_size, cur_level); + auto [maximum_indices_value, maximum_indices_bdim] = unwrapTensorAtLevel(maximum_indices, cur_level); + std::optional per_sample_weights_value; + std::optional per_sample_weights_bdim; + if (per_sample_weights) { + std::tie(per_sample_weights_value, per_sample_weights_bdim) = unwrapTensorAtLevel(per_sample_weights.value(), cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, indices_value, indices_bdim, offset2bag_value, offset2bag_bdim, bag_size_value, bag_size_bdim, maximum_indices_value, maximum_indices_bdim, num_weights, scale_grad_by_freq, mode, per_sample_weights_value, per_sample_weights_bdim, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _embedding_bag_per_sample_weights_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & weight, const at::Tensor & indices, const at::Tensor & offsets, const at::Tensor & offset2bag, int64_t mode, int64_t padding_idx) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(offset2bag, cur_level)) { + return at::_ops::_embedding_bag_per_sample_weights_backward::call(grad, weight, indices, offsets, offset2bag, mode, padding_idx); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [offset2bag_value, offset2bag_bdim] = unwrapTensorAtLevel(offset2bag, cur_level); + auto results = batch_rule(grad_value, grad_bdim, weight_value, weight_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, offset2bag_value, offset2bag_bdim, mode, padding_idx); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_empty_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_empty::call(self, size, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_empty_strided_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_empty_strided::call(self, size, stride, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_full_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_full::call(self, size, fill_value, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, fill_value, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_zeros_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_zeros::call(self, size, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor new_ones_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::new_ones::call(self, size, dtype, layout, device, pin_memory); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _empty_per_channel_affine_quantized_generated_plumbing(c10::SymIntArrayRef size, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level)) { + return at::_ops::_empty_per_channel_affine_quantized::call(size, scales, zero_points, axis, dtype, layout, device, pin_memory, memory_format); + } + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto [zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + auto results = batch_rule(size, scales_value, scales_bdim, zero_points_value, zero_points_bdim, axis, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +const at::Tensor & _resize_output__generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_resize_output_::call(self, size, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, size, device); + return self; +} +template +at::Tensor empty_quantized_generated_plumbing(at::IntArrayRef size, const at::Tensor & qtensor, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(qtensor, cur_level)) { + return at::_ops::empty_quantized::call(size, qtensor, dtype, layout, device, pin_memory, memory_format); + } + auto [qtensor_value, qtensor_bdim] = unwrapTensorAtLevel(qtensor, cur_level); + auto results = batch_rule(size, qtensor_value, qtensor_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor empty_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::empty_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor erf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & erf__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erf_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor erfc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & erfc__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfc_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor exp_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & exp__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor exp2_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp2::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & exp2__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exp2_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor expm1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expm1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & expm1__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expm1_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor expand_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expand::call(self, size, implicit); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, implicit); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor expand_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::expand_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_using_ints_generated_plumbing(const at::Tensor & self, int64_t start_dim, int64_t end_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_using_ints::call(self, start_dim, end_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, start_dim, end_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_named_out_dim_generated_plumbing(const at::Tensor & self, int64_t start_dim, int64_t end_dim, at::Dimname out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_named_out_dim::call(self, start_dim, end_dim, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, start_dim, end_dim, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_using_names_generated_plumbing(const at::Tensor & self, at::Dimname start_dim, at::Dimname end_dim, at::Dimname out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_using_names::call(self, start_dim, end_dim, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, start_dim, end_dim, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_DimnameList_generated_plumbing(const at::Tensor & self, at::DimnameList dims, at::Dimname out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flatten_DimnameList::call(self, dims, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unflatten_int_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymIntArrayRef sizes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unflatten_int::call(self, dim, sizes); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, sizes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unflatten_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, c10::SymIntArrayRef sizes, at::DimnameList names) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unflatten_Dimname::call(self, dim, sizes, names); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, sizes, names); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fill_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fill_Scalar::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fill_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::fill_Tensor::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fill__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fill__Scalar::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, value); + return self; +} +template +at::Tensor & fill__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::fill__Tensor::call(self, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor floor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & floor__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor floor_divide_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::floor_divide::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & floor_divide__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::floor_divide__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor floor_divide_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor_divide_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & floor_divide__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::floor_divide__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor frac_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frac::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & frac__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frac_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor full_like_generated_plumbing(const at::Tensor & self, const at::Scalar & fill_value, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::full_like::call(self, fill_value, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, fill_value, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gcd_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gcd::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & gcd__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gcd_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor lcm_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::lcm::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & lcm__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::lcm_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor grid_sampler_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor grid_sampler_2d_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_2d::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple grid_sampler_2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_2d_backward::call(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _grid_sampler_2d_cpu_fallback_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::_grid_sampler_2d_cpu_fallback::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _grid_sampler_2d_cpu_fallback_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::_grid_sampler_2d_cpu_fallback_backward::call(grad_output, input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor grid_sampler_3d_generated_plumbing(const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_3d::call(input, grid, interpolation_mode, padding_mode, align_corners); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple grid_sampler_3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & grid, int64_t interpolation_mode, int64_t padding_mode, bool align_corners, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grid, cur_level)) { + return at::_ops::grid_sampler_3d_backward::call(grad_output, input, grid, interpolation_mode, padding_mode, align_corners, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grid_value, grid_bdim] = unwrapTensorAtLevel(grid, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, grid_value, grid_bdim, interpolation_mode, padding_mode, align_corners, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor hinge_embedding_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, double margin, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::hinge_embedding_loss::call(self, target, margin, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, margin, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor group_norm_generated_plumbing(const at::Tensor & input, int64_t num_groups, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::group_norm::call(input, num_groups, weight, bias, eps, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, num_groups, weight_value, weight_bdim, bias_value, bias_bdim, eps, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_group_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::native_group_norm::call(input, weight, bias, N, C, HxW, group, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, N, C, HxW, group, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple native_group_norm_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, c10::SymInt N, c10::SymInt C, c10::SymInt HxW, int64_t group, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(rstd, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::native_group_norm_backward::call(grad_out, input, mean, rstd, weight, N, C, HxW, group, output_mask); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, mean_value, mean_bdim, rstd_value, rstd_bdim, weight_value, weight_bdim, N, C, HxW, group, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _fft_r2c_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, bool onesided) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fft_r2c::call(self, dim, normalization, onesided); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, normalization, onesided); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fft_c2r_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, int64_t normalization, c10::SymInt last_dim_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fft_c2r::call(self, dim, normalization, last_dim_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, normalization, last_dim_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fft_c2c_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef dim, int64_t normalization, bool forward) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fft_c2c::call(self, dim, normalization, forward); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, normalization, forward); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _validate_compressed_sparse_indices_generated_plumbing(bool is_crow, const at::Tensor & compressed_idx, const at::Tensor & plain_idx, int64_t cdim, int64_t dim, int64_t nnz) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_idx, cur_level) && !isBatchedAtLevel(plain_idx, cur_level)) { + return at::_ops::_validate_compressed_sparse_indices::call(is_crow, compressed_idx, plain_idx, cdim, dim, nnz); + } + auto [compressed_idx_value, compressed_idx_bdim] = unwrapTensorAtLevel(compressed_idx, cur_level); + auto [plain_idx_value, plain_idx_bdim] = unwrapTensorAtLevel(plain_idx, cur_level); + batch_rule(is_crow, compressed_idx_value, compressed_idx_bdim, plain_idx_value, plain_idx_bdim, cdim, dim, nnz); +} +template +at::Tensor index_Tensor_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::index_Tensor::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_index_Tensor_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_unsafe_index_Tensor::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_masked_index_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Scalar & fill) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::_unsafe_masked_index::call(self, mask, indices, fill); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, indices, fill); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_masked_index_put_accumulate_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const c10::List<::std::optional> & indices, const at::Tensor & values) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_unsafe_masked_index_put_accumulate::call(self, mask, indices, values); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, indices, values_value, values_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_copy__generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy_::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return self; +} +template +at::Tensor index_copy_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_copy__dimname_generated_plumbing(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy__dimname::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return self; +} +template +at::Tensor index_copy_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_copy_dimname::call(self, dim, index, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_put__generated_plumbing(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::index_put_::call(self, indices, values, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate); + return self; +} +template +at::Tensor index_put_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::index_put::call(self, indices, values, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _unsafe_index_put_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_unsafe_index_put::call(self, indices, values, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _index_put_impl__generated_plumbing(at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_index_put_impl_::call(self, indices, values, accumulate, unsafe); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate, unsafe); + return self; +} +template +at::Tensor instance_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool use_input_stats, double momentum, double eps, bool cudnn_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::instance_norm::call(input, weight, bias, running_mean, running_var, use_input_stats, momentum, eps, cudnn_enabled); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, use_input_stats, momentum, eps, cudnn_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isclose_generated_plumbing(const at::Tensor & self, const at::Tensor & other, double rtol, double atol, bool equal_nan) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::isclose::call(self, other, rtol, atol, equal_nan); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, rtol, atol, equal_nan); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isin_Tensor_Tensor_generated_plumbing(const at::Tensor & elements, const at::Tensor & test_elements, bool assume_unique, bool invert) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(elements, cur_level) && !isBatchedAtLevel(test_elements, cur_level)) { + return at::_ops::isin_Tensor_Tensor::call(elements, test_elements, assume_unique, invert); + } + auto [elements_value, elements_bdim] = unwrapTensorAtLevel(elements, cur_level); + auto [test_elements_value, test_elements_bdim] = unwrapTensorAtLevel(test_elements, cur_level); + auto results = batch_rule(elements_value, elements_bdim, test_elements_value, test_elements_bdim, assume_unique, invert); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isin_Tensor_Scalar_generated_plumbing(const at::Tensor & elements, const at::Scalar & test_element, bool assume_unique, bool invert) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(elements, cur_level)) { + return at::_ops::isin_Tensor_Scalar::call(elements, test_element, assume_unique, invert); + } + auto [elements_value, elements_bdim] = unwrapTensorAtLevel(elements, cur_level); + auto results = batch_rule(elements_value, elements_bdim, test_element, assume_unique, invert); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isin_Scalar_Tensor_generated_plumbing(const at::Scalar & element, const at::Tensor & test_elements, bool assume_unique, bool invert) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(test_elements, cur_level)) { + return at::_ops::isin_Scalar_Tensor::call(element, test_elements, assume_unique, invert); + } + auto [test_elements_value, test_elements_bdim] = unwrapTensorAtLevel(test_elements, cur_level); + auto results = batch_rule(element, test_elements_value, test_elements_bdim, assume_unique, invert); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isnan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isnan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isreal_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isreal::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor kl_div_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction, bool log_target) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::kl_div::call(self, target, reduction, log_target); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction, log_target); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor kron_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::kron::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple kthvalue_generated_plumbing(const at::Tensor & self, c10::SymInt k, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::kthvalue::call(self, k, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, k, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple kthvalue_dimname_generated_plumbing(const at::Tensor & self, c10::SymInt k, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::kthvalue_dimname::call(self, k, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, k, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor layer_norm_generated_plumbing(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps, bool cudnn_enable) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::layer_norm::call(input, normalized_shape, weight, bias, eps, cudnn_enable); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, normalized_shape, weight_value, weight_bdim, bias_value, bias_bdim, eps, cudnn_enable); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_layer_norm_generated_plumbing(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, const ::std::optional & bias, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::native_layer_norm::call(input, normalized_shape, weight, bias, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, normalized_shape, weight_value, weight_bdim, bias_value, bias_bdim, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple native_layer_norm_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const at::Tensor & mean, const at::Tensor & rstd, const ::std::optional & weight, const ::std::optional & bias, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(rstd, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::native_layer_norm_backward::call(grad_out, input, normalized_shape, mean, rstd, weight, bias, output_mask); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [rstd_value, rstd_bdim] = unwrapTensorAtLevel(rstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, normalized_shape, mean_value, mean_bdim, rstd_value, rstd_bdim, weight_value, weight_bdim, bias_value, bias_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor rms_norm_generated_plumbing(const at::Tensor & input, c10::SymIntArrayRef normalized_shape, const ::std::optional & weight, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::rms_norm::call(input, normalized_shape, weight, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, normalized_shape, weight_value, weight_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fused_rms_norm_generated_plumbing(const at::Tensor & input, int64_t normalized_shape_ndim, const at::Tensor & weight, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_fused_rms_norm::call(input, normalized_shape_ndim, weight, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(input_value, input_bdim, normalized_shape_ndim, weight_value, weight_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nan_to_num_generated_plumbing(const at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nan_to_num::call(self, nan, posinf, neginf); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, nan, posinf, neginf); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & nan_to_num__generated_plumbing(at::Tensor & self, ::std::optional nan, ::std::optional posinf, ::std::optional neginf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nan_to_num_::call(self, nan, posinf, neginf); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, nan, posinf, neginf); + return self; +} +template +at::Tensor linear_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::linear::call(input, weight, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linear_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::linear_backward::call(self, grad_output, weight, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor mkldnn_linear_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::mkldnn_linear::call(self, weight, bias); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_linear_backward_input_generated_plumbing(at::IntArrayRef input_size, const at::Tensor & grad_output, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mkldnn_linear_backward_input::call(input_size, grad_output, weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(input_size, grad_output_value, grad_output_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mkldnn_linear_backward_weights_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & input, const at::Tensor & weight, bool bias_defined) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mkldnn_linear_backward_weights::call(grad_output, input, weight, bias_defined); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_value, input_bdim, weight_value, weight_bdim, bias_defined); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple mkldnn_linear_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mkldnn_linear_backward::call(self, grad_output, weight, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _cslt_compress_generated_plumbing(const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_cslt_compress::call(input); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cslt_sparse_mm_generated_plumbing(const at::Tensor & compressed_A, const at::Tensor & dense_B, const ::std::optional & bias, const ::std::optional & alpha, ::std::optional out_dtype, bool transpose_result, int64_t alg_id, int64_t split_k, int64_t split_k_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_A, cur_level) && !isBatchedAtLevel(dense_B, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(alpha, cur_level)) { + return at::_ops::_cslt_sparse_mm::call(compressed_A, dense_B, bias, alpha, out_dtype, transpose_result, alg_id, split_k, split_k_mode); + } + auto [compressed_A_value, compressed_A_bdim] = unwrapTensorAtLevel(compressed_A, cur_level); + auto [dense_B_value, dense_B_bdim] = unwrapTensorAtLevel(dense_B, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional alpha_value; + std::optional alpha_bdim; + if (alpha) { + std::tie(alpha_value, alpha_bdim) = unwrapTensorAtLevel(alpha.value(), cur_level); + } + auto results = batch_rule(compressed_A_value, compressed_A_bdim, dense_B_value, dense_B_bdim, bias_value, bias_bdim, alpha_value, alpha_bdim, out_dtype, transpose_result, alg_id, split_k, split_k_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _sparse_semi_structured_tile_generated_plumbing(const at::Tensor & input, c10::string_view algorithm, bool use_cutlass) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_sparse_semi_structured_tile::call(input, algorithm, use_cutlass); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, algorithm, use_cutlass); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _sparse_semi_structured_apply_generated_plumbing(const at::Tensor & input, const at::Tensor & thread_masks) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(thread_masks, cur_level)) { + return at::_ops::_sparse_semi_structured_apply::call(input, thread_masks); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [thread_masks_value, thread_masks_bdim] = unwrapTensorAtLevel(thread_masks, cur_level); + auto results = batch_rule(input_value, input_bdim, thread_masks_value, thread_masks_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _sparse_semi_structured_apply_dense_generated_plumbing(const at::Tensor & input, const at::Tensor & thread_masks) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(thread_masks, cur_level)) { + return at::_ops::_sparse_semi_structured_apply_dense::call(input, thread_masks); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [thread_masks_value, thread_masks_bdim] = unwrapTensorAtLevel(thread_masks, cur_level); + auto results = batch_rule(input_value, input_bdim, thread_masks_value, thread_masks_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_semi_structured_linear_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & meta, const ::std::optional & bias, ::std::optional activation, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(meta, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_sparse_semi_structured_linear::call(input, weight, meta, bias, activation, out_dtype); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [meta_value, meta_bdim] = unwrapTensorAtLevel(meta, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, meta_value, meta_bdim, bias_value, bias_bdim, activation, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_semi_structured_mm_generated_plumbing(const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat1_meta, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_sparse_semi_structured_mm::call(mat1, mat1_meta, mat2, out_dtype); + } + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat1_meta_value, mat1_meta_bdim] = unwrapTensorAtLevel(mat1_meta, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(mat1_value, mat1_bdim, mat1_meta_value, mat1_meta_bdim, mat2_value, mat2_bdim, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_semi_structured_addmm_generated_plumbing(const at::Tensor & input, const at::Tensor & mat1, const at::Tensor & mat1_meta, const at::Tensor & mat2, const at::Scalar & alpha, const at::Scalar & beta, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat1_meta, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_sparse_semi_structured_addmm::call(input, mat1, mat1_meta, mat2, alpha, beta, out_dtype); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat1_meta_value, mat1_meta_bdim] = unwrapTensorAtLevel(mat1_meta, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(input_value, input_bdim, mat1_value, mat1_bdim, mat1_meta_value, mat1_meta_bdim, mat2_value, mat2_bdim, alpha, beta, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mixed_dtypes_linear_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & scale, const ::std::optional & bias, ::std::optional activation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_mixed_dtypes_linear::call(input, weight, scale, bias, activation); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, scale_value, scale_bdim, bias_value, bias_bdim, activation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_int8_weight_fp32_activation_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(packed, cur_level) && !isBatchedAtLevel(col_offsets, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_int8_weight_fp32_activation::call(input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [packed_value, packed_bdim] = unwrapTensorAtLevel(packed, cur_level); + auto [col_offsets_value, col_offsets_bdim] = unwrapTensorAtLevel(col_offsets, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, packed_value, packed_bdim, col_offsets_value, col_offsets_bdim, weight_scale, weight_zero_point, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_int8_weight_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const at::Tensor & packed, const at::Tensor & col_offsets, const at::Scalar & weight_scale, const at::Scalar & weight_zero_point, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(packed, cur_level) && !isBatchedAtLevel(col_offsets, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_int8_weight::call(input, weight, packed, col_offsets, weight_scale, weight_zero_point, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [packed_value, packed_bdim] = unwrapTensorAtLevel(packed, cur_level); + auto [col_offsets_value, col_offsets_bdim] = unwrapTensorAtLevel(col_offsets, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, packed_value, packed_bdim, col_offsets_value, col_offsets_bdim, weight_scale, weight_zero_point, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_pack_gemm_matrix_fp16_generated_plumbing(const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::fbgemm_pack_gemm_matrix_fp16::call(input); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _wrapped_linear_prepack_generated_plumbing(const at::Tensor & weight, const at::Tensor & weight_scale, const at::Tensor & weight_zero_point, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_scale, cur_level) && !isBatchedAtLevel(weight_zero_point, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_wrapped_linear_prepack::call(weight, weight_scale, weight_zero_point, bias); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [weight_scale_value, weight_scale_bdim] = unwrapTensorAtLevel(weight_scale, cur_level); + auto [weight_zero_point_value, weight_zero_point_bdim] = unwrapTensorAtLevel(weight_zero_point, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(weight_value, weight_bdim, weight_scale_value, weight_scale_bdim, weight_zero_point_value, weight_zero_point_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _wrapped_quantized_linear_prepacked_generated_plumbing(const at::Tensor & input, const at::Tensor & input_scale, const at::Tensor & input_zero_point, const at::Tensor & packed_weight, const at::Tensor & output_scale, const at::Tensor & output_zero_point, int64_t out_channel) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(input_scale, cur_level) && !isBatchedAtLevel(input_zero_point, cur_level) && !isBatchedAtLevel(packed_weight, cur_level) && !isBatchedAtLevel(output_scale, cur_level) && !isBatchedAtLevel(output_zero_point, cur_level)) { + return at::_ops::_wrapped_quantized_linear_prepacked::call(input, input_scale, input_zero_point, packed_weight, output_scale, output_zero_point, out_channel); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [input_scale_value, input_scale_bdim] = unwrapTensorAtLevel(input_scale, cur_level); + auto [input_zero_point_value, input_zero_point_bdim] = unwrapTensorAtLevel(input_zero_point, cur_level); + auto [packed_weight_value, packed_weight_bdim] = unwrapTensorAtLevel(packed_weight, cur_level); + auto [output_scale_value, output_scale_bdim] = unwrapTensorAtLevel(output_scale, cur_level); + auto [output_zero_point_value, output_zero_point_bdim] = unwrapTensorAtLevel(output_zero_point, cur_level); + auto results = batch_rule(input_value, input_bdim, input_scale_value, input_scale_bdim, input_zero_point_value, input_zero_point_bdim, packed_weight_value, packed_weight_bdim, output_scale_value, output_scale_bdim, output_zero_point_value, output_zero_point_bdim, out_channel); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_fp16_weight_fp32_activation_generated_plumbing(const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(packed_weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_fp16_weight_fp32_activation::call(input, packed_weight, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [packed_weight_value, packed_weight_bdim] = unwrapTensorAtLevel(packed_weight, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(input_value, input_bdim, packed_weight_value, packed_weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_linear_fp16_weight_generated_plumbing(const at::Tensor & input, const at::Tensor & packed_weight, const at::Tensor & bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(packed_weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::fbgemm_linear_fp16_weight::call(input, packed_weight, bias); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [packed_weight_value, packed_weight_bdim] = unwrapTensorAtLevel(packed_weight, cur_level); + auto [bias_value, bias_bdim] = unwrapTensorAtLevel(bias, cur_level); + auto results = batch_rule(input_value, input_bdim, packed_weight_value, packed_weight_bdim, bias_value, bias_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_pack_quantized_matrix_generated_plumbing(const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::fbgemm_pack_quantized_matrix::call(input); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fbgemm_pack_quantized_matrix_KN_generated_plumbing(const at::Tensor & input, int64_t K, int64_t N) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::fbgemm_pack_quantized_matrix_KN::call(input, K, N); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, K, N); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ldexp_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ldexp_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ldexp__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ldexp_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor linspace_Tensor_Tensor_generated_plumbing(const at::Tensor & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return at::_ops::linspace_Tensor_Tensor::call(start, end, steps, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(start_value, start_bdim, end_value, end_bdim, steps, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linspace_Tensor_Scalar_generated_plumbing(const at::Tensor & start, const at::Scalar & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level)) { + return at::_ops::linspace_Tensor_Scalar::call(start, end, steps, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto results = batch_rule(start_value, start_bdim, end, steps, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linspace_Scalar_Tensor_generated_plumbing(const at::Scalar & start, const at::Tensor & end, int64_t steps, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(end, cur_level)) { + return at::_ops::linspace_Scalar_Tensor::call(start, end, steps, dtype, layout, device, pin_memory); + } + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(start, end_value, end_bdim, steps, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor log10_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log10::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log10__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log10_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor log1p_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log1p::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log1p__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log1p_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor log2_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log2::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & log2__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log2_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor logaddexp_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logaddexp::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logaddexp2_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::logaddexp2::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor xlogy_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::xlogy_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor xlogy_Scalar_Self_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::xlogy_Scalar_Self::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor xlogy_Scalar_Other_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::xlogy_Scalar_Other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & xlogy__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::xlogy__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor & xlogy__Scalar_Other_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::xlogy__Scalar_Other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor logspace_Tensor_Tensor_generated_plumbing(const at::Tensor & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return at::_ops::logspace_Tensor_Tensor::call(start, end, steps, base, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(start_value, start_bdim, end_value, end_bdim, steps, base, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logspace_Tensor_Scalar_generated_plumbing(const at::Tensor & start, const at::Scalar & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(start, cur_level)) { + return at::_ops::logspace_Tensor_Scalar::call(start, end, steps, base, dtype, layout, device, pin_memory); + } + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto results = batch_rule(start_value, start_bdim, end, steps, base, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logspace_Scalar_Tensor_generated_plumbing(const at::Scalar & start, const at::Tensor & end, int64_t steps, double base, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(end, cur_level)) { + return at::_ops::logspace_Scalar_Tensor::call(start, end, steps, base, dtype, layout, device, pin_memory); + } + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(start, end_value, end_bdim, steps, base, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_softmax_int::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _log_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_log_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _log_softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_log_softmax_backward_data::call(grad_output, output, dim, input_dtype); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, input_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _logcumsumexp_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_logcumsumexp::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logcumsumexp_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logcumsumexp::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logcumsumexp_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logcumsumexp_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logsumexp_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logsumexp::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logsumexp_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logsumexp_names::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor margin_ranking_loss_generated_plumbing(const at::Tensor & input1, const at::Tensor & input2, const at::Tensor & target, double margin, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input1, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::margin_ranking_loss::call(input1, input2, target, margin, reduction); + } + auto [input1_value, input1_bdim] = unwrapTensorAtLevel(input1, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(input1_value, input1_bdim, input2_value, input2_bdim, target_value, target_bdim, margin, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::matmul::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple matmul_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & other, ::std::array mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::matmul_backward::call(grad, self, other, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, other_value, other_bdim, mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor matrix_power_generated_plumbing(const at::Tensor & self, int64_t n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::matrix_power::call(self, n); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matrix_exp_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::matrix_exp::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matrix_exp_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad, cur_level)) { + return at::_ops::matrix_exp_backward::call(self, grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_value, grad_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _aminmax_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_aminmax::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _aminmax_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_aminmax_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple aminmax_generated_plumbing(const at::Tensor & self, ::std::optional dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::aminmax::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _compute_linear_combination_generated_plumbing(const at::Tensor & input, const at::Tensor & coefficients) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(coefficients, cur_level)) { + return at::_ops::_compute_linear_combination::call(input, coefficients); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [coefficients_value, coefficients_bdim] = unwrapTensorAtLevel(coefficients, cur_level); + auto results = batch_rule(input_value, input_bdim, coefficients_value, coefficients_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple max_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor value_selecting_reduction_backward_generated_plumbing(const at::Tensor & grad, int64_t dim, const at::Tensor & indices, c10::SymIntArrayRef sizes, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::value_selecting_reduction_backward::call(grad, dim, indices, sizes, keepdim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_value, grad_bdim, dim, indices_value, indices_bdim, sizes, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor amax_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::amax::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_pool1d_with_indices_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool1d_with_indices::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor max_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool1d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool2d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool2d_backward::call(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_max_pool2d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::mkldnn_max_pool2d_backward::call(grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, input_value, input_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_max_pool3d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_max_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & input, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::mkldnn_max_pool3d_backward::call(grad_output, output, input, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, input_value, input_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_max_pool1d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantized_max_pool1d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantized_max_pool2d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantized_max_pool3d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool3d::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mean_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mean::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mean_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mean_dim::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mean_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mean_names_dim::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nanmean_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmean::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor median_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::median::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple median_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::median_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple median_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::median_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor nanmedian_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmedian::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple nanmedian_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmedian_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple nanmedian_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanmedian_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple min_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::min_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple min_names_dim_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::min_names_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor amin_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::amin::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mps_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_mps_convolution::call(self, weight, bias, padding, stride, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mps_convolution_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_output, const at::Tensor & weight, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::mps_convolution_backward::call(self, grad_output, weight, padding, stride, dilation, groups, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, padding, stride, dilation, groups, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor mkldnn_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::mkldnn_convolution::call(self, weight, bias, padding, stride, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mkldnn_rnn_layer_generated_plumbing(const at::Tensor & input, const at::Tensor & weight0, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & hx_, const at::Tensor & cx_, bool reverse, at::IntArrayRef batch_sizes, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool bidirectional, bool batch_first, bool train) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight0, cur_level) && !isBatchedAtLevel(weight1, cur_level) && !isBatchedAtLevel(weight2, cur_level) && !isBatchedAtLevel(weight3, cur_level) && !isBatchedAtLevel(hx_, cur_level) && !isBatchedAtLevel(cx_, cur_level)) { + return at::_ops::mkldnn_rnn_layer::call(input, weight0, weight1, weight2, weight3, hx_, cx_, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight0_value, weight0_bdim] = unwrapTensorAtLevel(weight0, cur_level); + auto [weight1_value, weight1_bdim] = unwrapTensorAtLevel(weight1, cur_level); + auto [weight2_value, weight2_bdim] = unwrapTensorAtLevel(weight2, cur_level); + auto [weight3_value, weight3_bdim] = unwrapTensorAtLevel(weight3, cur_level); + auto [hx__value, hx__bdim] = unwrapTensorAtLevel(hx_, cur_level); + auto [cx__value, cx__bdim] = unwrapTensorAtLevel(cx_, cur_level); + auto results = batch_rule(input_value, input_bdim, weight0_value, weight0_bdim, weight1_value, weight1_bdim, weight2_value, weight2_bdim, weight3_value, weight3_bdim, hx__value, hx__bdim, cx__value, cx__bdim, reverse, batch_sizes, mode, hidden_size, num_layers, has_biases, bidirectional, batch_first, train); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple mkldnn_rnn_layer_backward_generated_plumbing(const at::Tensor & input, const at::Tensor & weight1, const at::Tensor & weight2, const at::Tensor & weight3, const at::Tensor & weight4, const at::Tensor & hx_, const at::Tensor & cx_tmp, const at::Tensor & output, const at::Tensor & hy_, const at::Tensor & cy_, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, bool reverse, int64_t mode, int64_t hidden_size, int64_t num_layers, bool has_biases, bool train, bool bidirectional, at::IntArrayRef batch_sizes, bool batch_first, const at::Tensor & workspace) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight1, cur_level) && !isBatchedAtLevel(weight2, cur_level) && !isBatchedAtLevel(weight3, cur_level) && !isBatchedAtLevel(weight4, cur_level) && !isBatchedAtLevel(hx_, cur_level) && !isBatchedAtLevel(cx_tmp, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(hy_, cur_level) && !isBatchedAtLevel(cy_, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::mkldnn_rnn_layer_backward::call(input, weight1, weight2, weight3, weight4, hx_, cx_tmp, output, hy_, cy_, grad_output, grad_hy, grad_cy, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight1_value, weight1_bdim] = unwrapTensorAtLevel(weight1, cur_level); + auto [weight2_value, weight2_bdim] = unwrapTensorAtLevel(weight2, cur_level); + auto [weight3_value, weight3_bdim] = unwrapTensorAtLevel(weight3, cur_level); + auto [weight4_value, weight4_bdim] = unwrapTensorAtLevel(weight4, cur_level); + auto [hx__value, hx__bdim] = unwrapTensorAtLevel(hx_, cur_level); + auto [cx_tmp_value, cx_tmp_bdim] = unwrapTensorAtLevel(cx_tmp, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [hy__value, hy__bdim] = unwrapTensorAtLevel(hy_, cur_level); + auto [cy__value, cy__bdim] = unwrapTensorAtLevel(cy_, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight1_value, weight1_bdim, weight2_value, weight2_bdim, weight3_value, weight3_bdim, weight4_value, weight4_bdim, hx__value, hx__bdim, cx_tmp_value, cx_tmp_bdim, output_value, output_bdim, hy__value, hy__bdim, cy__value, cy__bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, reverse, mode, hidden_size, num_layers, has_biases, train, bidirectional, batch_sizes, batch_first, workspace_value, workspace_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level), makeBatched(std::get<12>(results), std::get<13>(results), cur_level)); +} +template +::std::tuple miopen_batch_norm_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double exponential_average_factor, double epsilon) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::miopen_batch_norm::call(input, weight, bias, running_mean, running_var, training, exponential_average_factor, epsilon); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, exponential_average_factor, epsilon); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple miopen_batch_norm_backward_generated_plumbing(const at::Tensor & input, const at::Tensor & grad_output, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, double epsilon) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var, cur_level)) { + return at::_ops::miopen_batch_norm_backward::call(input, grad_output, weight, running_mean, running_var, save_mean, save_var, epsilon); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_value; + std::optional save_var_bdim; + if (save_var) { + std::tie(save_var_value, save_var_bdim) = unwrapTensorAtLevel(save_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, grad_output_value, grad_output_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_value, save_var_bdim, epsilon); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor miopen_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution::call(self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups, benchmark, deterministic); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_convolution_transpose_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution_transpose::call(self, weight, bias, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, output_padding, stride, dilation, groups, benchmark, deterministic); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_depthwise_convolution_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, bool benchmark, bool deterministic) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_depthwise_convolution::call(self, weight, bias, padding, stride, dilation, groups, benchmark, deterministic); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride, dilation, groups, benchmark, deterministic); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_convolution_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution_relu::call(self, weight, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor miopen_convolution_add_relu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, const at::Tensor & z, const ::std::optional & alpha, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(z, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::miopen_convolution_add_relu::call(self, weight, z, alpha, bias, stride, padding, dilation, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [z_value, z_bdim] = unwrapTensorAtLevel(z, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, z_value, z_bdim, alpha, bias_value, bias_bdim, stride, padding, dilation, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple miopen_rnn_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & hx, const ::std::optional & cx, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(dropout_state, cur_level)) { + return at::_ops::miopen_rnn::call(input, weight, weight_stride0, hx, cx, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, hx_value, hx_bdim, cx_value, cx_bdim, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple> miopen_rnn_backward_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && !isBatchedAtLevel(reserve, cur_level)) { + return at::_ops::miopen_rnn_backward::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::mm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mm_dtype_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, at::ScalarType out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::mm_dtype::call(self, mat2, out_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _int_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_int_mm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_weight_to_int4pack_generated_plumbing(const at::Tensor & self, int64_t innerKTiles) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convert_weight_to_int4pack::call(self, innerKTiles); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, innerKTiles); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int4pack_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(qScaleAndZeros, cur_level)) { + return at::_ops::_weight_int4pack_mm::call(self, mat2, qGroupSize, qScaleAndZeros); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [qScaleAndZeros_value, qScaleAndZeros_bdim] = unwrapTensorAtLevel(qScaleAndZeros, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, qGroupSize, qScaleAndZeros_value, qScaleAndZeros_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int4pack_mm_with_scales_and_zeros_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScale, const at::Tensor & qZeros) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(qScale, cur_level) && !isBatchedAtLevel(qZeros, cur_level)) { + return at::_ops::_weight_int4pack_mm_with_scales_and_zeros::call(self, mat2, qGroupSize, qScale, qZeros); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [qScale_value, qScale_bdim] = unwrapTensorAtLevel(qScale, cur_level); + auto [qZeros_value, qZeros_bdim] = unwrapTensorAtLevel(qZeros, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, qGroupSize, qScale_value, qScale_bdim, qZeros_value, qZeros_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_weight_to_int4pack_for_cpu_generated_plumbing(const at::Tensor & self, int64_t innerKTiles) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convert_weight_to_int4pack_for_cpu::call(self, innerKTiles); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, innerKTiles); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int4pack_mm_for_cpu_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, int64_t qGroupSize, const at::Tensor & qScaleAndZeros) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(qScaleAndZeros, cur_level)) { + return at::_ops::_weight_int4pack_mm_for_cpu::call(self, mat2, qGroupSize, qScaleAndZeros); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [qScaleAndZeros_value, qScaleAndZeros_bdim] = unwrapTensorAtLevel(qScaleAndZeros, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, qGroupSize, qScaleAndZeros_value, qScaleAndZeros_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dyn_quant_pack_4bit_weight_generated_plumbing(const at::Tensor & weights, const at::Tensor & scales_zeros, const ::std::optional & bias, int64_t block_size, int64_t in_features, int64_t out_features) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weights, cur_level) && !isBatchedAtLevel(scales_zeros, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_dyn_quant_pack_4bit_weight::call(weights, scales_zeros, bias, block_size, in_features, out_features); + } + auto [weights_value, weights_bdim] = unwrapTensorAtLevel(weights, cur_level); + auto [scales_zeros_value, scales_zeros_bdim] = unwrapTensorAtLevel(scales_zeros, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(weights_value, weights_bdim, scales_zeros_value, scales_zeros_bdim, bias_value, bias_bdim, block_size, in_features, out_features); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dyn_quant_matmul_4bit_generated_plumbing(const at::Tensor & inp, const at::Tensor & packed_weights, int64_t block_size, int64_t in_features, int64_t out_features) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(inp, cur_level) && !isBatchedAtLevel(packed_weights, cur_level)) { + return at::_ops::_dyn_quant_matmul_4bit::call(inp, packed_weights, block_size, in_features, out_features); + } + auto [inp_value, inp_bdim] = unwrapTensorAtLevel(inp, cur_level); + auto [packed_weights_value, packed_weights_bdim] = unwrapTensorAtLevel(packed_weights, cur_level); + auto results = batch_rule(inp_value, inp_bdim, packed_weights_value, packed_weights_bdim, block_size, in_features, out_features); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_int8pack_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(scales, cur_level)) { + return at::_ops::_weight_int8pack_mm::call(self, mat2, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scales_value, scales_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_mm_generated_plumbing(const at::Tensor & sparse, const at::Tensor & dense) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sparse, cur_level) && !isBatchedAtLevel(dense, cur_level)) { + return at::_ops::_sparse_mm::call(sparse, dense); + } + auto [sparse_value, sparse_bdim] = unwrapTensorAtLevel(sparse, cur_level); + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(sparse_value, sparse_bdim, dense_value, dense_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_mm_reduce_generated_plumbing(const at::Tensor & sparse, const at::Tensor & dense, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sparse, cur_level) && !isBatchedAtLevel(dense, cur_level)) { + return at::_ops::_sparse_mm_reduce::call(sparse, dense, reduce); + } + auto [sparse_value, sparse_bdim] = unwrapTensorAtLevel(sparse, cur_level); + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(sparse_value, sparse_bdim, dense_value, dense_bdim, reduce); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sparse_matmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_sparse_sparse_matmul::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple mode_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mode::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple mode_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mode_dimname::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor mul_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::mul_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mul__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::mul__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor mul_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mul_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mul__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mul__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor multiply_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::multiply_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & multiply__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::multiply__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor multiply_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::multiply_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & multiply__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::multiply__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor mv_generated_plumbing(const at::Tensor & self, const at::Tensor & vec) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec, cur_level)) { + return at::_ops::mv::call(self, vec); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec_value, vec_bdim] = unwrapTensorAtLevel(vec, cur_level); + auto results = batch_rule(self_value, self_bdim, vec_value, vec_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mvlgamma_generated_plumbing(const at::Tensor & self, int64_t p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mvlgamma::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mvlgamma__generated_plumbing(at::Tensor & self, int64_t p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mvlgamma_::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p); + return self; +} +template +at::Tensor narrow_copy_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::narrow_copy::call(self, dim, start, length); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, length); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor narrow_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt start, c10::SymInt length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::narrow::call(self, dim, start, length); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, length); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor narrow_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & start, c10::SymInt length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(start, cur_level)) { + return at::_ops::narrow_Tensor::call(self, dim, start, length); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [start_value, start_bdim] = unwrapTensorAtLevel(start, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start_value, start_bdim, length); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple native_batch_norm_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, bool training, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::native_batch_norm::call(input, weight, bias, running_mean, running_var, training, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _native_batch_norm_legit_no_training_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_native_batch_norm_legit_no_training::call(input, weight, bias, running_mean, running_var, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [running_mean_value, running_mean_bdim] = unwrapTensorAtLevel(running_mean, cur_level); + auto [running_var_value, running_var_bdim] = unwrapTensorAtLevel(running_var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _native_batch_norm_legit_no_stats_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, bool training, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_native_batch_norm_legit_no_stats::call(input, weight, bias, training, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, training, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple batch_norm_stats_generated_plumbing(const at::Tensor & input, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::batch_norm_stats::call(input, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor batch_norm_elemt_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & mean, const at::Tensor & invstd, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level)) { + return at::_ops::batch_norm_elemt::call(input, weight, bias, mean, invstd, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple batch_norm_gather_stats_generated_plumbing(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, int64_t count) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::batch_norm_gather_stats::call(input, mean, invstd, running_mean, running_var, momentum, eps, count); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps, count); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple batch_norm_gather_stats_with_counts_generated_plumbing(const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps, const at::Tensor & counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(counts, cur_level)) { + return at::_ops::batch_norm_gather_stats_with_counts::call(input, mean, invstd, running_mean, running_var, momentum, eps, counts); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + auto [counts_value, counts_bdim] = unwrapTensorAtLevel(counts, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps, counts_value, counts_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple native_batch_norm_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const ::std::optional & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_invstd, bool train, double eps, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_invstd, cur_level)) { + return at::_ops::native_batch_norm_backward::call(grad_out, input, weight, running_mean, running_var, save_mean, save_invstd, train, eps, output_mask); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_invstd_value; + std::optional save_invstd_bdim; + if (save_invstd) { + std::tie(save_invstd_value, save_invstd_bdim) = unwrapTensorAtLevel(save_invstd.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_invstd_value, save_invstd_bdim, train, eps, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple batch_norm_backward_reduce_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, bool input_g, bool weight_g, bool bias_g) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::batch_norm_backward_reduce::call(grad_out, input, mean, invstd, weight, input_g, weight_g, bias_g); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, weight_value, weight_bdim, input_g, weight_g, bias_g); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor batch_norm_backward_elemt_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & mean, const at::Tensor & invstd, const ::std::optional & weight, const at::Tensor & sum_dy, const at::Tensor & sum_dy_xmu, const at::Tensor & count) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(invstd, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(sum_dy, cur_level) && !isBatchedAtLevel(sum_dy_xmu, cur_level) && !isBatchedAtLevel(count, cur_level)) { + return at::_ops::batch_norm_backward_elemt::call(grad_out, input, mean, invstd, weight, sum_dy, sum_dy_xmu, count); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [invstd_value, invstd_bdim] = unwrapTensorAtLevel(invstd, cur_level); + auto [sum_dy_value, sum_dy_bdim] = unwrapTensorAtLevel(sum_dy, cur_level); + auto [sum_dy_xmu_value, sum_dy_xmu_bdim] = unwrapTensorAtLevel(sum_dy_xmu, cur_level); + auto [count_value, count_bdim] = unwrapTensorAtLevel(count, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, mean_value, mean_bdim, invstd_value, invstd_bdim, weight_value, weight_bdim, sum_dy_value, sum_dy_bdim, sum_dy_xmu_value, sum_dy_xmu_bdim, count_value, count_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple batch_norm_update_stats_generated_plumbing(const at::Tensor & input, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::batch_norm_update_stats::call(input, running_mean, running_var, momentum); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _nnpack_spatial_convolution_generated_plumbing(const at::Tensor & input, const at::Tensor & weight, const ::std::optional & bias, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_nnpack_spatial_convolution::call(input, weight, bias, padding, stride); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, padding, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ones_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ones_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pairwise_distance_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, double p, double eps, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::pairwise_distance::call(x1, x2, p, eps, keepdim); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, p, eps, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cdist_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::cdist::call(x1, x2, p, compute_mode); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, p, compute_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _euclidean_dist_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::_euclidean_dist::call(x1, x2); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cdist_forward_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, double p, ::std::optional compute_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::_cdist_forward::call(x1, x2, p, compute_mode); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, p, compute_mode); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cdist_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & x1, const at::Tensor & x2, double p, const at::Tensor & cdist) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level) && !isBatchedAtLevel(cdist, cur_level)) { + return at::_ops::_cdist_backward::call(grad, x1, x2, p, cdist); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto [cdist_value, cdist_bdim] = unwrapTensorAtLevel(cdist, cur_level); + auto results = batch_rule(grad_value, grad_bdim, x1_value, x1_bdim, x2_value, x2_bdim, p, cdist_value, cdist_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pdist_generated_plumbing(const at::Tensor & self, double p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pdist::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pdist_forward_generated_plumbing(const at::Tensor & self, double p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pdist_forward::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pdist_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, double p, const at::Tensor & pdist) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(pdist, cur_level)) { + return at::_ops::_pdist_backward::call(grad, self, p, pdist); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [pdist_value, pdist_bdim] = unwrapTensorAtLevel(pdist, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, p, pdist_value, pdist_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cosine_similarity_generated_plumbing(const at::Tensor & x1, const at::Tensor & x2, int64_t dim, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x1, cur_level) && !isBatchedAtLevel(x2, cur_level)) { + return at::_ops::cosine_similarity::call(x1, x2, dim, eps); + } + auto [x1_value, x1_bdim] = unwrapTensorAtLevel(x1, cur_level); + auto [x2_value, x2_bdim] = unwrapTensorAtLevel(x2, cur_level); + auto results = batch_rule(x1_value, x1_bdim, x2_value, x2_bdim, dim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor permute_generated_plumbing(const at::Tensor & self, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::permute::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor movedim_intlist_generated_plumbing(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::movedim_intlist::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor movedim_int_generated_plumbing(const at::Tensor & self, int64_t source, int64_t destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::movedim_int::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor moveaxis_intlist_generated_plumbing(const at::Tensor & self, at::IntArrayRef source, at::IntArrayRef destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::moveaxis_intlist::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor moveaxis_int_generated_plumbing(const at::Tensor & self, int64_t source, int64_t destination) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::moveaxis_int::call(self, source, destination); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, destination); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor numpy_T_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::numpy_T::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor matrix_H_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::matrix_H::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mT_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mT::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mH_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mH::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adjoint_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adjoint::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pixel_shuffle_generated_plumbing(const at::Tensor & self, int64_t upscale_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pixel_shuffle::call(self, upscale_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upscale_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pixel_unshuffle_generated_plumbing(const at::Tensor & self, int64_t downscale_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pixel_unshuffle::call(self, downscale_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, downscale_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor channel_shuffle_generated_plumbing(const at::Tensor & self, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::channel_shuffle::call(self, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor native_channel_shuffle_generated_plumbing(const at::Tensor & self, c10::SymInt groups) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::native_channel_shuffle::call(self, groups); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, groups); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pin_memory_generated_plumbing(const at::Tensor & self, ::std::optional device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pin_memory::call(self, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, device); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pin_memory_generated_plumbing(const at::Tensor & self, ::std::optional device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pin_memory::call(self, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, device); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pinverse_generated_plumbing(const at::Tensor & self, double rcond) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pinverse::call(self, rcond); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, rcond); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor poisson_nll_loss_generated_plumbing(const at::Tensor & input, const at::Tensor & target, bool log_input, bool full, double eps, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::poisson_nll_loss::call(input, target, log_input, full, eps, reduction); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(input_value, input_bdim, target_value, target_bdim, log_input, full, eps, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rad2deg_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rad2deg::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & rad2deg__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rad2deg_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor deg2rad_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::deg2rad::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & deg2rad__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::deg2rad_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor rand_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rand_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_generated_plumbing(const at::Tensor & self, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randint_like::call(self, high, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, high, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(high, cur_level)) { + return at::_ops::randint_like_Tensor::call(self, high, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [high_value, high_bdim] = unwrapTensorAtLevel(high, cur_level); + auto results = batch_rule(self_value, self_bdim, high_value, high_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randint_like_low_dtype_generated_plumbing(const at::Tensor & self, c10::SymInt low, c10::SymInt high, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randint_like_low_dtype::call(self, low, high, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, low, high, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor randn_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::randn_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ravel_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ravel::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reciprocal_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reciprocal::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & reciprocal__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reciprocal_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor neg_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::neg::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & neg__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::neg_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor negative_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::negative::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & negative__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::negative_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor repeat_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef repeats) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::repeat::call(self, repeats); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, repeats); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor repeat_interleave_Tensor_generated_plumbing(const at::Tensor & repeats, ::std::optional output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(repeats, cur_level)) { + return at::_ops::repeat_interleave_Tensor::call(repeats, output_size); + } + auto [repeats_value, repeats_bdim] = unwrapTensorAtLevel(repeats, cur_level); + auto results = batch_rule(repeats_value, repeats_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor repeat_interleave_self_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & repeats, ::std::optional dim, ::std::optional output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(repeats, cur_level)) { + return at::_ops::repeat_interleave_self_Tensor::call(self, repeats, dim, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [repeats_value, repeats_bdim] = unwrapTensorAtLevel(repeats, cur_level); + auto results = batch_rule(self_value, self_bdim, repeats_value, repeats_bdim, dim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor repeat_interleave_self_int_generated_plumbing(const at::Tensor & self, c10::SymInt repeats, ::std::optional dim, ::std::optional output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::repeat_interleave_self_int::call(self, repeats, dim, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, repeats, dim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reshape_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef shape) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reshape::call(self, shape); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, shape); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _reshape_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_reshape_copy::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _reshape_alias_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_reshape_alias::call(self, size, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mkldnn_reshape_generated_plumbing(const at::Tensor & self, at::IntArrayRef shape) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_mkldnn_reshape::call(self, shape); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, shape); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reshape_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::reshape_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor round_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & round__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor round_decimals_generated_plumbing(const at::Tensor & self, int64_t decimals) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round_decimals::call(self, decimals); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, decimals); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & round__decimals_generated_plumbing(at::Tensor & self, int64_t decimals) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::round__decimals::call(self, decimals); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, decimals); + return self; +} +template +at::Tensor rrelu_generated_plumbing(const at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rrelu::call(self, lower, upper, training, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lower, upper, training, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & rrelu__generated_plumbing(at::Tensor & self, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rrelu_::call(self, lower, upper, training, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, lower, upper, training, generator); + return self; +} +template +at::Tensor relu_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & relu__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor relu6_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu6::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & relu6__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::relu6_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor prelu_generated_plumbing(const at::Tensor & self, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::prelu::call(self, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _prelu_kernel_generated_plumbing(const at::Tensor & self, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_prelu_kernel::call(self, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _prelu_kernel_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_prelu_kernel_backward::call(grad_output, self, weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, weight_value, weight_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor & gelu__generated_plumbing(at::Tensor & self, c10::string_view approximate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gelu_::call(self, approximate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, approximate); + return self; +} +template +at::Tensor gelu_generated_plumbing(const at::Tensor & self, c10::string_view approximate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gelu::call(self, approximate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, approximate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gelu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::string_view approximate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::gelu_backward::call(grad_output, self, approximate); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, approximate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor infinitely_differentiable_gelu_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::infinitely_differentiable_gelu_backward::call(grad, self); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardshrink_generated_plumbing(const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardshrink::call(self, lambd); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardshrink_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardshrink_backward::call(grad_out, self, lambd); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rsqrt_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rsqrt::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & rsqrt__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rsqrt_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor select_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, int64_t index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::select_Dimname::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_int_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::select_int::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::select_backward::call(grad_output, input_sizes, dim, index); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_sizes, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_select_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_select_backward::call(grad_output, self, dim, index); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor selu_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::selu::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & selu__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::selu_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor celu_generated_plumbing(const at::Tensor & self, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::celu::call(self, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & celu__generated_plumbing(at::Tensor & self, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::celu_::call(self, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, alpha); + return self; +} +template +at::Tensor silu_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::silu::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & silu__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::silu_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor silu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::silu_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mish_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mish::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & mish__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mish_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor mish_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::mish_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sigmoid_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sigmoid::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sigmoid__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sigmoid_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor logit_generated_plumbing(const at::Tensor & self, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logit::call(self, eps); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & logit__generated_plumbing(at::Tensor & self, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logit_::call(self, eps); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, eps); + return self; +} +template +at::Tensor sin_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sin::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sin__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sin_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sinc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sinc__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinc_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sinh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sinh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sinh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor detach_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::detach::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::slice_Tensor::call(self, dim, start, end, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef input_sizes, int64_t dim, c10::SymInt start, c10::SymInt end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::slice_backward::call(grad_output, input_sizes, dim, start, end, step); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, input_sizes, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_inverse_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::slice_inverse::call(self, src, dim, start, end, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::slice_scatter::call(self, src, dim, start, end, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::select_scatter::call(self, src, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::diagonal_scatter::call(self, src, offset, dim1, dim2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor as_strided_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & src, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::as_strided_scatter::call(self, src, size, stride, storage_offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, size, stride, storage_offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor smm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::smm::call(self, mat2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softmax_int::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, at::ScalarType input_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_softmax_backward_data::call(grad_output, output, dim, input_dtype); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, input_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unsafe_split_Tensor_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsafe_split_Tensor::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_Tensor_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_Tensor::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_sizes_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_sizes::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unsafe_split_with_sizes_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsafe_split_with_sizes::call(self, split_sizes, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_sizes, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_with_sizes_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_with_sizes::call(self, split_sizes, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_sizes, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector hsplit_int_generated_plumbing(const at::Tensor & self, int64_t sections) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hsplit_int::call(self, sections); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sections); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector hsplit_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hsplit_array::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector vsplit_int_generated_plumbing(const at::Tensor & self, int64_t sections) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::vsplit_int::call(self, sections); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sections); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector vsplit_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::vsplit_array::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector dsplit_int_generated_plumbing(const at::Tensor & self, int64_t sections) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dsplit_int::call(self, sections); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sections); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector dsplit_array_generated_plumbing(const at::Tensor & self, at::IntArrayRef indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dsplit_array::call(self, indices); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, indices); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_dim_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_dim::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_dims_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_dims::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sspaddmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::sspaddmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _chunk_cat_generated_plumbing(at::TensorList tensors, int64_t dim, int64_t num_chunks) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::_chunk_cat::call(tensors, dim, num_chunks); + } + + auto results = batch_rule(tensors, dim, num_chunks); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor stack_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::stack::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _stack_generated_plumbing(at::TensorList tensors, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::_stack::call(tensors, dim); + } + + auto results = batch_rule(tensors, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hstack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::hstack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor vstack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::vstack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dstack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::dstack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor stft_generated_plumbing(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(window, cur_level)) { + return at::_ops::stft::call(self, n_fft, hop_length, win_length, window, normalized, onesided, return_complex, align_to_window); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional window_value; + std::optional window_bdim; + if (window) { + std::tie(window_value, window_bdim) = unwrapTensorAtLevel(window.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n_fft, hop_length, win_length, window_value, window_bdim, normalized, onesided, return_complex, align_to_window); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor stft_center_generated_plumbing(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, c10::string_view pad_mode, bool normalized, ::std::optional onesided, ::std::optional return_complex, ::std::optional align_to_window) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(window, cur_level)) { + return at::_ops::stft_center::call(self, n_fft, hop_length, win_length, window, center, pad_mode, normalized, onesided, return_complex, align_to_window); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional window_value; + std::optional window_bdim; + if (window) { + std::tie(window_value, window_bdim) = unwrapTensorAtLevel(window.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n_fft, hop_length, win_length, window_value, window_bdim, center, pad_mode, normalized, onesided, return_complex, align_to_window); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor istft_generated_plumbing(const at::Tensor & self, int64_t n_fft, ::std::optional hop_length, ::std::optional win_length, const ::std::optional & window, bool center, bool normalized, ::std::optional onesided, ::std::optional length, bool return_complex) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(window, cur_level)) { + return at::_ops::istft::call(self, n_fft, hop_length, win_length, window, center, normalized, onesided, length, return_complex); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional window_value; + std::optional window_bdim; + if (window) { + std::tie(window_value, window_bdim) = unwrapTensorAtLevel(window.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, n_fft, hop_length, win_length, window_value, window_bdim, center, normalized, onesided, length, return_complex); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_dim_IntList_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum_dim_IntList::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_dim_DimnameList_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum_dim_DimnameList::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_sum_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_sum_backward::call(grad, self, dim, keepdim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nansum_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nansum::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sum_to_size_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sum_to_size::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sqrt_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sqrt::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sqrt__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sqrt_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor square_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::square::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & square__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::square_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor std_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor std_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor std_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple std_mean_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple std_mean_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_mean_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor std_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor std_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::std_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor prod_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::prod::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor prod_dim_int_generated_plumbing(const at::Tensor & self, int64_t dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::prod_dim_int::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor prod_dim_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::prod_dim_Dimname::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor t_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::t::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor tan_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tan::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & tan__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tan_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor tanh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tanh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & tanh__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tanh_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor tensordot_generated_plumbing(const at::Tensor & self, const at::Tensor & other, at::IntArrayRef dims_self, at::IntArrayRef dims_other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::tensordot::call(self, other, dims_self, dims_other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dims_self, dims_other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor threshold_generated_plumbing(const at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::threshold::call(self, threshold, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, threshold, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & threshold__generated_plumbing(at::Tensor & self, const at::Scalar & threshold, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::threshold_::call(self, threshold, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, threshold, value); + return self; +} +template +at::Tensor threshold_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & threshold) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::threshold_backward::call(grad_output, self, threshold); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, threshold); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor tile_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tile::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor transpose_int_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::transpose_int::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor transpose_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim0, at::Dimname dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::transpose_Dimname::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _mkldnn_transpose_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_mkldnn_transpose::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _mkldnn_transpose__generated_plumbing(at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_mkldnn_transpose_::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim0, dim1); + return self; +} +template +at::Tensor one_hot_generated_plumbing(const at::Tensor & self, int64_t num_classes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::one_hot::call(self, num_classes); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, num_classes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flip_generated_plumbing(const at::Tensor & self, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flip::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fliplr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fliplr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flipud_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::flipud::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor roll_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef shifts, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::roll::call(self, shifts, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, shifts, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rot90_generated_plumbing(const at::Tensor & self, int64_t k, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rot90::call(self, k, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, k, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapezoid_x_generated_plumbing(const at::Tensor & y, const at::Tensor & x, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level) && !isBatchedAtLevel(x, cur_level)) { + return at::_ops::trapezoid_x::call(y, x, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(y_value, y_bdim, x_value, x_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapezoid_dx_generated_plumbing(const at::Tensor & y, const at::Scalar & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level)) { + return at::_ops::trapezoid_dx::call(y, dx, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(y_value, y_bdim, dx, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapz_x_generated_plumbing(const at::Tensor & y, const at::Tensor & x, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level) && !isBatchedAtLevel(x, cur_level)) { + return at::_ops::trapz_x::call(y, x, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(y_value, y_bdim, x_value, x_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trapz_dx_generated_plumbing(const at::Tensor & y, double dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(y, cur_level)) { + return at::_ops::trapz_dx::call(y, dx, dim); + } + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(y_value, y_bdim, dx, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _transform_bias_rescale_qkv_generated_plumbing(const at::Tensor & qkv, const at::Tensor & qkv_bias, int64_t num_heads) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(qkv, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level)) { + return at::_ops::_transform_bias_rescale_qkv::call(qkv, qkv_bias, num_heads); + } + auto [qkv_value, qkv_bdim] = unwrapTensorAtLevel(qkv, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto results = batch_rule(qkv_value, qkv_bdim, qkv_bias_value, qkv_bias_bdim, num_heads); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _nested_tensor_from_mask_generated_plumbing(const at::Tensor & t, const at::Tensor & mask, bool mask_check) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(t, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_nested_tensor_from_mask::call(t, mask, mask_check); + } + auto [t_value, t_bdim] = unwrapTensorAtLevel(t, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(t_value, t_bdim, mask_value, mask_bdim, mask_check); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_from_padded_generated_plumbing(const at::Tensor & padded, const at::Tensor & cpu_nested_shape_example, bool fuse_transform_0213) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(padded, cur_level) && !isBatchedAtLevel(cpu_nested_shape_example, cur_level)) { + return at::_ops::_nested_from_padded::call(padded, cpu_nested_shape_example, fuse_transform_0213); + } + auto [padded_value, padded_bdim] = unwrapTensorAtLevel(padded, cur_level); + auto [cpu_nested_shape_example_value, cpu_nested_shape_example_bdim] = unwrapTensorAtLevel(cpu_nested_shape_example, cur_level); + auto results = batch_rule(padded_value, padded_bdim, cpu_nested_shape_example_value, cpu_nested_shape_example_bdim, fuse_transform_0213); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_size_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_tensor_size::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_strides_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_tensor_strides::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_storage_offsets_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_tensor_storage_offsets::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_from_padded_and_nested_example_generated_plumbing(const at::Tensor & padded, const at::Tensor & nt_example) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(padded, cur_level) && !isBatchedAtLevel(nt_example, cur_level)) { + return at::_ops::_nested_from_padded_and_nested_example::call(padded, nt_example); + } + auto [padded_value, padded_bdim] = unwrapTensorAtLevel(padded, cur_level); + auto [nt_example_value, nt_example_bdim] = unwrapTensorAtLevel(nt_example, cur_level); + auto results = batch_rule(padded_value, padded_bdim, nt_example_value, nt_example_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_buffer_generated_plumbing(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(nested_size, cur_level) && !isBatchedAtLevel(nested_strides, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_nested_view_from_buffer::call(self, nested_size, nested_strides, offsets); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [nested_size_value, nested_size_bdim] = unwrapTensorAtLevel(nested_size, cur_level); + auto [nested_strides_value, nested_strides_bdim] = unwrapTensorAtLevel(nested_strides, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto results = batch_rule(self_value, self_bdim, nested_size_value, nested_size_bdim, nested_strides_value, nested_strides_bdim, offsets_value, offsets_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_buffer_copy_generated_plumbing(const at::Tensor & self, const at::Tensor & nested_size, const at::Tensor & nested_strides, const at::Tensor & offsets) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(nested_size, cur_level) && !isBatchedAtLevel(nested_strides, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_nested_view_from_buffer_copy::call(self, nested_size, nested_strides, offsets); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [nested_size_value, nested_size_bdim] = unwrapTensorAtLevel(nested_size, cur_level); + auto [nested_strides_value, nested_strides_bdim] = unwrapTensorAtLevel(nested_strides, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto results = batch_rule(self_value, self_bdim, nested_size_value, nested_size_bdim, nested_strides_value, nested_strides_bdim, offsets_value, offsets_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_jagged_generated_plumbing(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(dummy, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(min_seqlen, cur_level) && !isBatchedAtLevel(max_seqlen, cur_level)) { + return at::_ops::_nested_view_from_jagged::call(self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional min_seqlen_value; + std::optional min_seqlen_bdim; + if (min_seqlen) { + std::tie(min_seqlen_value, min_seqlen_bdim) = unwrapTensorAtLevel(min_seqlen.value(), cur_level); + } + std::optional max_seqlen_value; + std::optional max_seqlen_bdim; + if (max_seqlen) { + std::tie(max_seqlen_value, max_seqlen_bdim) = unwrapTensorAtLevel(max_seqlen.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, offsets_value, offsets_bdim, dummy_value, dummy_bdim, lengths_value, lengths_bdim, ragged_idx, min_seqlen_value, min_seqlen_bdim, max_seqlen_value, max_seqlen_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_view_from_jagged_copy_generated_plumbing(const at::Tensor & self, const at::Tensor & offsets, const at::Tensor & dummy, const ::std::optional & lengths, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(dummy, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(min_seqlen, cur_level) && !isBatchedAtLevel(max_seqlen, cur_level)) { + return at::_ops::_nested_view_from_jagged_copy::call(self, offsets, dummy, lengths, ragged_idx, min_seqlen, max_seqlen); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional min_seqlen_value; + std::optional min_seqlen_bdim; + if (min_seqlen) { + std::tie(min_seqlen_value, min_seqlen_bdim) = unwrapTensorAtLevel(min_seqlen.value(), cur_level); + } + std::optional max_seqlen_value; + std::optional max_seqlen_bdim; + if (max_seqlen) { + std::tie(max_seqlen_value, max_seqlen_bdim) = unwrapTensorAtLevel(max_seqlen.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, offsets_value, offsets_bdim, dummy_value, dummy_bdim, lengths_value, lengths_bdim, ragged_idx, min_seqlen_value, min_seqlen_bdim, max_seqlen_value, max_seqlen_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_values_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_values::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_values_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_values_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_offsets_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_offsets::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_lengths_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_lengths::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_min_seqlen_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_min_seqlen::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_max_seqlen_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_nested_get_max_seqlen::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_get_jagged_dummy_generated_plumbing(const at::Tensor & any) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(any, cur_level)) { + return at::_ops::_nested_get_jagged_dummy::call(any); + } + auto [any_value, any_bdim] = unwrapTensorAtLevel(any, cur_level); + auto results = batch_rule(any_value, any_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _nested_compute_contiguous_strides_offsets_generated_plumbing(const at::Tensor & nested_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(nested_size, cur_level)) { + return at::_ops::_nested_compute_contiguous_strides_offsets::call(nested_size); + } + auto [nested_size_value, nested_size_bdim] = unwrapTensorAtLevel(nested_size, cur_level); + auto results = batch_rule(nested_size_value, nested_size_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _trilinear_generated_plumbing(const at::Tensor & i1, const at::Tensor & i2, const at::Tensor & i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(i1, cur_level) && !isBatchedAtLevel(i2, cur_level) && !isBatchedAtLevel(i3, cur_level)) { + return at::_ops::_trilinear::call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); + } + auto [i1_value, i1_bdim] = unwrapTensorAtLevel(i1, cur_level); + auto [i2_value, i2_bdim] = unwrapTensorAtLevel(i2, cur_level); + auto [i3_value, i3_bdim] = unwrapTensorAtLevel(i3, cur_level); + auto results = batch_rule(i1_value, i1_bdim, i2_value, i2_bdim, i3_value, i3_bdim, expand1, expand2, expand3, sumdim, unroll_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor triplet_margin_loss_generated_plumbing(const at::Tensor & anchor, const at::Tensor & positive, const at::Tensor & negative, double margin, double p, double eps, bool swap, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(anchor, cur_level) && !isBatchedAtLevel(positive, cur_level) && !isBatchedAtLevel(negative, cur_level)) { + return at::_ops::triplet_margin_loss::call(anchor, positive, negative, margin, p, eps, swap, reduction); + } + auto [anchor_value, anchor_bdim] = unwrapTensorAtLevel(anchor, cur_level); + auto [positive_value, positive_bdim] = unwrapTensorAtLevel(positive, cur_level); + auto [negative_value, negative_bdim] = unwrapTensorAtLevel(negative, cur_level); + auto results = batch_rule(anchor_value, anchor_bdim, positive_value, positive_bdim, negative_value, negative_bdim, margin, p, eps, swap, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trunc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::trunc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & trunc__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::trunc_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor fix_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fix::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fix__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fix_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor type_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::type_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _unique_generated_plumbing(const at::Tensor & self, bool sorted, bool return_inverse) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_unique::call(self, sorted, return_inverse); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sorted, return_inverse); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple unique_dim_generated_plumbing(const at::Tensor & self, int64_t dim, bool sorted, bool return_inverse, bool return_counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unique_dim::call(self, dim, sorted, return_inverse, return_counts); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, sorted, return_inverse, return_counts); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple unique_consecutive_generated_plumbing(const at::Tensor & self, bool return_inverse, bool return_counts, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unique_consecutive::call(self, return_inverse, return_counts, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, return_inverse, return_counts, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple unique_dim_consecutive_generated_plumbing(const at::Tensor & self, int64_t dim, bool return_inverse, bool return_counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unique_dim_consecutive::call(self, dim, return_inverse, return_counts); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, return_inverse, return_counts); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _unique2_generated_plumbing(const at::Tensor & self, bool sorted, bool return_inverse, bool return_counts) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_unique2::call(self, sorted, return_inverse, return_counts); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sorted, return_inverse, return_counts); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _unsafe_view_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_unsafe_view::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unsqueeze_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsqueeze::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor vander_generated_plumbing(const at::Tensor & x, ::std::optional N, bool increasing) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::vander::call(x, N, increasing); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, N, increasing); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor var_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple var_mean_generated_plumbing(const at::Tensor & self, bool unbiased) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean::call(self, unbiased); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, unbiased); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_dim_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_correction_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean_correction::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_names_dim_generated_plumbing(const at::Tensor & self, at::DimnameList dim, bool unbiased, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean_names_dim::call(self, dim, unbiased, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, unbiased, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple var_mean_correction_names_generated_plumbing(const at::Tensor & self, at::DimnameList dim, const ::std::optional & correction, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::var_mean_correction_names::call(self, dim, correction, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, correction, keepdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor view_as_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::view_as::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor where_self_generated_plumbing(const at::Tensor & condition, const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::where_self::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor where_ScalarSelf_generated_plumbing(const at::Tensor & condition, const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::where_ScalarSelf::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor where_ScalarOther_generated_plumbing(const at::Tensor & condition, const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::where_ScalarOther::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor where_Scalar_generated_plumbing(const at::Tensor & condition, const at::Scalar & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level)) { + return at::_ops::where_Scalar::call(condition, self, other); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto results = batch_rule(condition_value, condition_bdim, self, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector where_generated_plumbing(const at::Tensor & condition) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(condition, cur_level)) { + return at::_ops::where::call(condition); + } + auto [condition_value, condition_bdim] = unwrapTensorAtLevel(condition, cur_level); + auto results = batch_rule(condition_value, condition_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_except_dim_generated_plumbing(const at::Tensor & v, int64_t pow, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(v, cur_level)) { + return at::_ops::norm_except_dim::call(v, pow, dim); + } + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, cur_level); + auto results = batch_rule(v_value, v_bdim, pow, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _weight_norm_generated_plumbing(const at::Tensor & v, const at::Tensor & g, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(v, cur_level) && !isBatchedAtLevel(g, cur_level)) { + return at::_ops::_weight_norm::call(v, g, dim); + } + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, cur_level); + auto [g_value, g_bdim] = unwrapTensorAtLevel(g, cur_level); + auto results = batch_rule(v_value, v_bdim, g_value, g_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _weight_norm_interface_generated_plumbing(const at::Tensor & v, const at::Tensor & g, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(v, cur_level) && !isBatchedAtLevel(g, cur_level)) { + return at::_ops::_weight_norm_interface::call(v, g, dim); + } + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, cur_level); + auto [g_value, g_bdim] = unwrapTensorAtLevel(g, cur_level); + auto results = batch_rule(v_value, v_bdim, g_value, g_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _weight_norm_interface_backward_generated_plumbing(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_w, cur_level) && !isBatchedAtLevel(saved_v, cur_level) && !isBatchedAtLevel(saved_g, cur_level) && !isBatchedAtLevel(saved_norms, cur_level)) { + return at::_ops::_weight_norm_interface_backward::call(grad_w, saved_v, saved_g, saved_norms, dim); + } + auto [grad_w_value, grad_w_bdim] = unwrapTensorAtLevel(grad_w, cur_level); + auto [saved_v_value, saved_v_bdim] = unwrapTensorAtLevel(saved_v, cur_level); + auto [saved_g_value, saved_g_bdim] = unwrapTensorAtLevel(saved_g, cur_level); + auto [saved_norms_value, saved_norms_bdim] = unwrapTensorAtLevel(saved_norms, cur_level); + auto results = batch_rule(grad_w_value, grad_w_bdim, saved_v_value, saved_v_bdim, saved_g_value, saved_g_bdim, saved_norms_value, saved_norms_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _weight_norm_differentiable_backward_generated_plumbing(const at::Tensor & grad_w, const at::Tensor & saved_v, const at::Tensor & saved_g, const at::Tensor & saved_norms, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_w, cur_level) && !isBatchedAtLevel(saved_v, cur_level) && !isBatchedAtLevel(saved_g, cur_level) && !isBatchedAtLevel(saved_norms, cur_level)) { + return at::_ops::_weight_norm_differentiable_backward::call(grad_w, saved_v, saved_g, saved_norms, dim); + } + auto [grad_w_value, grad_w_bdim] = unwrapTensorAtLevel(grad_w, cur_level); + auto [saved_v_value, saved_v_bdim] = unwrapTensorAtLevel(saved_v, cur_level); + auto [saved_g_value, saved_g_bdim] = unwrapTensorAtLevel(saved_g, cur_level); + auto [saved_norms_value, saved_norms_bdim] = unwrapTensorAtLevel(saved_norms, cur_level); + auto results = batch_rule(grad_w_value, grad_w_bdim, saved_v_value, saved_v_bdim, saved_g_value, saved_g_bdim, saved_norms_value, saved_norms_bdim, dim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor zeros_like_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::zeros_like::call(self, dtype, layout, device, pin_memory, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _standard_gamma_grad_generated_plumbing(const at::Tensor & self, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_standard_gamma_grad::call(self, output); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(self_value, self_bdim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _standard_gamma_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_standard_gamma::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _dirichlet_grad_generated_plumbing(const at::Tensor & x, const at::Tensor & alpha, const at::Tensor & total) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(alpha, cur_level) && !isBatchedAtLevel(total, cur_level)) { + return at::_ops::_dirichlet_grad::call(x, alpha, total); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [alpha_value, alpha_bdim] = unwrapTensorAtLevel(alpha, cur_level); + auto [total_value, total_bdim] = unwrapTensorAtLevel(total, cur_level); + auto results = batch_rule(x_value, x_bdim, alpha_value, alpha_bdim, total_value, total_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sample_dirichlet_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sample_dirichlet::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor poisson_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::poisson::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor binomial_generated_plumbing(const at::Tensor & count, const at::Tensor & prob, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(count, cur_level) && !isBatchedAtLevel(prob, cur_level)) { + return at::_ops::binomial::call(count, prob, generator); + } + auto [count_value, count_bdim] = unwrapTensorAtLevel(count, cur_level); + auto [prob_value, prob_bdim] = unwrapTensorAtLevel(prob, cur_level); + auto results = batch_rule(count_value, count_bdim, prob_value, prob_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor native_norm_generated_plumbing(const at::Tensor & self, const at::Scalar & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::native_norm::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor native_norm_ScalarOpt_dim_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::native_norm_ScalarOpt_dim_dtype::call(self, p, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _batch_norm_no_update_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const ::std::optional & running_mean, const ::std::optional & running_var, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_batch_norm_no_update::call(input, weight, bias, running_mean, running_var, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple batch_norm_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & input, const at::Tensor & weight, const ::std::optional & running_mean, const ::std::optional & running_var, const ::std::optional & save_mean, const ::std::optional & save_var, bool update, double eps, ::std::array output_mask, const at::Tensor & reserve) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level) && !isBatchedAtLevel(save_mean, cur_level) && !isBatchedAtLevel(save_var, cur_level) && !isBatchedAtLevel(reserve, cur_level)) { + return at::_ops::batch_norm_backward::call(grad_out, input, weight, running_mean, running_var, save_mean, save_var, update, eps, output_mask, reserve); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + std::optional running_mean_value; + std::optional running_mean_bdim; + if (running_mean) { + std::tie(running_mean_value, running_mean_bdim) = unwrapTensorAtLevel(running_mean.value(), cur_level); + } + std::optional running_var_value; + std::optional running_var_bdim; + if (running_var) { + std::tie(running_var_value, running_var_bdim) = unwrapTensorAtLevel(running_var.value(), cur_level); + } + std::optional save_mean_value; + std::optional save_mean_bdim; + if (save_mean) { + std::tie(save_mean_value, save_mean_bdim) = unwrapTensorAtLevel(save_mean.value(), cur_level); + } + std::optional save_var_value; + std::optional save_var_bdim; + if (save_var) { + std::tie(save_var_value, save_var_bdim) = unwrapTensorAtLevel(save_var.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, input_value, input_bdim, weight_value, weight_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, save_mean_value, save_mean_bdim, save_var_value, save_var_bdim, update, eps, output_mask, reserve_value, reserve_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _sparse_sum_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_dtype::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_dim_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_dim::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_dim_dtype_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_dim_dtype::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_sum_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_sum_backward::call(grad, self, dim); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_csr_sum_dim_dtype_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_csr_sum_dim_dtype::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_csr_prod_dim_dtype_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_csr_prod_dim_dtype::call(self, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax_int::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_softmax_backward_data::call(grad_output, output, dim, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_int_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax_int::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax_Dimname::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, bool half_to_float) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax::call(self, dim, half_to_float); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, half_to_float); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_log_softmax_backward_data_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, int64_t dim, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_log_softmax_backward_data::call(grad_output, output, dim, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, dim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _spdiags_generated_plumbing(const at::Tensor & diagonals, const at::Tensor & offsets, at::IntArrayRef shape, ::std::optional layout) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(diagonals, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_spdiags::call(diagonals, offsets, shape, layout); + } + auto [diagonals_value, diagonals_bdim] = unwrapTensorAtLevel(diagonals, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto results = batch_rule(diagonals_value, diagonals_bdim, offsets_value, offsets_bdim, shape, layout); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_ScalarOpt_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_ScalarOpt_dtype::call(self, p, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_Scalar::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_ScalarOpt_dim_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_ScalarOpt_dim_dtype::call(self, p, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_ScalarOpt_dim_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_ScalarOpt_dim::call(self, p, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_names_ScalarOpt_dim_dtype_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_names_ScalarOpt_dim_dtype::call(self, p, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor norm_names_ScalarOpt_dim_generated_plumbing(const at::Tensor & self, const ::std::optional & p, at::DimnameList dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::norm_names_ScalarOpt_dim::call(self, p, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple frexp_Tensor_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frexp_Tensor::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor frobenius_norm_dim_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::frobenius_norm_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nuclear_norm_generated_plumbing(const at::Tensor & self, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nuclear_norm::call(self, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nuclear_norm_dim_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nuclear_norm_dim::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor clone_generated_plumbing(const at::Tensor & self, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::clone::call(self, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor positive_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::positive::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +const at::Tensor & resize_as_sparse__generated_plumbing(const at::Tensor & self, const at::Tensor & the_template) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(the_template, cur_level)) { + return at::_ops::resize_as_sparse_::call(self, the_template); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [the_template_value, the_template_bdim] = unwrapTensorAtLevel(the_template, cur_level); + batch_rule(self_value, self_bdim, the_template_value, the_template_bdim); + return self; +} +template +at::Tensor & zero__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::zero_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sub_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::sub_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sub__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::sub__Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor sub_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sub_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sub__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sub__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, alpha); + return self; +} +template +at::Tensor subtract_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::subtract_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & subtract__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::subtract__Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return self; +} +template +at::Tensor subtract_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::subtract_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & subtract__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::subtract__Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other, alpha); + return self; +} +template +at::Tensor rsub_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::rsub_Tensor::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor heaviside_generated_plumbing(const at::Tensor & self, const at::Tensor & values) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::heaviside::call(self, values); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, values_value, values_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & heaviside__generated_plumbing(at::Tensor & self, const at::Tensor & values) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::heaviside_::call(self, values); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(self_value, self_bdim, values_value, values_bdim); + return self; +} +template +at::Tensor rsub_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::rsub_Scalar::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_addmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_sparse_addmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_sampled_addmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::sparse_sampled_addmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _sparse_mm_reduce_impl_generated_plumbing(const at::Tensor & self, const at::Tensor & other, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_sparse_mm_reduce_impl::call(self, other, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, reduce); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _sparse_mm_reduce_impl_backward_generated_plumbing(const at::Tensor & self, const at::Tensor & grad_out, const at::Tensor & weight, c10::string_view reduce, const at::Tensor & arg_out, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(arg_out, cur_level)) { + return at::_ops::_sparse_mm_reduce_impl_backward::call(self, grad_out, weight, reduce, arg_out, output_mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto [arg_out_value, arg_out_bdim] = unwrapTensorAtLevel(arg_out, cur_level); + auto results = batch_rule(self_value, self_bdim, grad_out_value, grad_out_bdim, weight_value, weight_bdim, reduce, arg_out_value, arg_out_bdim, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor addmm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::addmm::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor addmm_dtype_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, at::ScalarType out_dtype, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::addmm_dtype::call(self, mat1, mat2, out_dtype, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, out_dtype, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addmm__generated_plumbing(at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::addmm_::call(self, mat1, mat2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha); + return self; +} +template +at::Tensor _addmm_activation_generated_plumbing(const at::Tensor & self, const at::Tensor & mat1, const at::Tensor & mat2, const at::Scalar & beta, const at::Scalar & alpha, bool use_gelu) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::_addmm_activation::call(self, mat1, mat2, beta, alpha, use_gelu); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(self_value, self_bdim, mat1_value, mat1_bdim, mat2_value, mat2_bdim, beta, alpha, use_gelu); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _scaled_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(scale_a, cur_level) && !isBatchedAtLevel(scale_b, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(scale_result, cur_level)) { + return at::_ops::_scaled_mm::call(self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [scale_a_value, scale_a_bdim] = unwrapTensorAtLevel(scale_a, cur_level); + auto [scale_b_value, scale_b_bdim] = unwrapTensorAtLevel(scale_b, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional scale_result_value; + std::optional scale_result_bdim; + if (scale_result) { + std::tie(scale_result_value, scale_result_bdim) = unwrapTensorAtLevel(scale_result.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scale_a_value, scale_a_bdim, scale_b_value, scale_b_bdim, bias_value, bias_bdim, scale_result_value, scale_result_bdim, out_dtype, use_fast_accum); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _scaled_grouped_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const at::Tensor & scale_a, const at::Tensor & scale_b, const ::std::optional & offs, const ::std::optional & bias, const ::std::optional & scale_result, ::std::optional out_dtype, bool use_fast_accum) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(scale_a, cur_level) && !isBatchedAtLevel(scale_b, cur_level) && !isBatchedAtLevel(offs, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(scale_result, cur_level)) { + return at::_ops::_scaled_grouped_mm::call(self, mat2, scale_a, scale_b, offs, bias, scale_result, out_dtype, use_fast_accum); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto [scale_a_value, scale_a_bdim] = unwrapTensorAtLevel(scale_a, cur_level); + auto [scale_b_value, scale_b_bdim] = unwrapTensorAtLevel(scale_b, cur_level); + std::optional offs_value; + std::optional offs_bdim; + if (offs) { + std::tie(offs_value, offs_bdim) = unwrapTensorAtLevel(offs.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional scale_result_value; + std::optional scale_result_bdim; + if (scale_result) { + std::tie(scale_result_value, scale_result_bdim) = unwrapTensorAtLevel(scale_result.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, scale_a_value, scale_a_bdim, scale_b_value, scale_b_bdim, offs_value, offs_bdim, bias_value, bias_bdim, scale_result_value, scale_result_bdim, out_dtype, use_fast_accum); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _grouped_mm_generated_plumbing(const at::Tensor & self, const at::Tensor & mat2, const ::std::optional & offs, const ::std::optional & bias, ::std::optional out_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mat2, cur_level) && !isBatchedAtLevel(offs, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_grouped_mm::call(self, mat2, offs, bias, out_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + std::optional offs_value; + std::optional offs_bdim; + if (offs) { + std::tie(offs_value, offs_bdim) = unwrapTensorAtLevel(offs.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, mat2_value, mat2_bdim, offs_value, offs_bdim, bias_value, bias_bdim, out_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_compressed_tensor_comp_plain_value_size_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_compressed_tensor_comp_plain_value_size::call(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csr_tensor_crow_col_value_size_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csr_tensor_crow_col_value_size::call(crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csc_tensor_ccol_row_value_size_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csc_tensor_ccol_row_value_size::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_bsr_tensor_crow_col_value_size_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsr_tensor_crow_col_value_size::call(crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_bsc_tensor_ccol_row_value_size_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsc_tensor_ccol_row_value_size::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_compressed_tensor_comp_plain_value_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_compressed_tensor_comp_plain_value::call(compressed_indices, plain_indices, values, dtype, layout, device, pin_memory); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csr_tensor_crow_col_value_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csr_tensor_crow_col_value::call(crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_csc_tensor_ccol_row_value_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_csc_tensor_ccol_row_value::call(ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_bsr_tensor_crow_col_value_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsr_tensor_crow_col_value::call(crow_indices, col_indices, values, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_bsc_tensor_ccol_row_value_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_bsc_tensor_ccol_row_value::call(ccol_indices, row_indices, values, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_compressed_tensor_unsafe_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_compressed_tensor_unsafe::call(compressed_indices, plain_indices, values, size, dtype, layout, device, pin_memory); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_csr_tensor_unsafe_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_csr_tensor_unsafe::call(crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_csc_tensor_unsafe_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_csc_tensor_unsafe::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_bsr_tensor_unsafe_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_bsr_tensor_unsafe::call(crow_indices, col_indices, values, size, dtype, layout, device, pin_memory); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_bsc_tensor_unsafe_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_bsc_tensor_unsafe::call(ccol_indices, row_indices, values, size, dtype, layout, device, pin_memory); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_coo_tensor_indices_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_coo_tensor_indices::call(indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(indices_value, indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory, is_coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_coo_tensor_indices_size_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::sparse_coo_tensor_indices_size::call(indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(indices_value, indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory, is_coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_coo_tensor_unsafe_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, c10::SymIntArrayRef size, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_coo_tensor_unsafe::call(indices, values, size, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(indices_value, indices_bdim, values_value, values_bdim, size, dtype, layout, device, pin_memory, is_coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _validate_sparse_coo_tensor_args_generated_plumbing(const at::Tensor & indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional is_coalesced, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_coo_tensor_args::call(indices, values, size, is_coalesced, check_pinning); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(indices_value, indices_bdim, values_value, values_bdim, size, is_coalesced, check_pinning); +} +template +void _validate_sparse_compressed_tensor_args_generated_plumbing(const at::Tensor & compressed_indices, const at::Tensor & plain_indices, const at::Tensor & values, at::IntArrayRef size, at::Layout layout, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(compressed_indices, cur_level) && !isBatchedAtLevel(plain_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_compressed_tensor_args::call(compressed_indices, plain_indices, values, size, layout, check_pinning); + } + auto [compressed_indices_value, compressed_indices_bdim] = unwrapTensorAtLevel(compressed_indices, cur_level); + auto [plain_indices_value, plain_indices_bdim] = unwrapTensorAtLevel(plain_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(compressed_indices_value, compressed_indices_bdim, plain_indices_value, plain_indices_bdim, values_value, values_bdim, size, layout, check_pinning); +} +template +void _validate_sparse_csr_tensor_args_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_csr_tensor_args::call(crow_indices, col_indices, values, size, check_pinning); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +void _validate_sparse_csc_tensor_args_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_csc_tensor_args::call(ccol_indices, row_indices, values, size, check_pinning); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +void _validate_sparse_bsr_tensor_args_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_bsr_tensor_args::call(crow_indices, col_indices, values, size, check_pinning); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +void _validate_sparse_bsc_tensor_args_generated_plumbing(const at::Tensor & ccol_indices, const at::Tensor & row_indices, const at::Tensor & values, at::IntArrayRef size, ::std::optional check_pinning) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(ccol_indices, cur_level) && !isBatchedAtLevel(row_indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_validate_sparse_bsc_tensor_args::call(ccol_indices, row_indices, values, size, check_pinning); + } + auto [ccol_indices_value, ccol_indices_bdim] = unwrapTensorAtLevel(ccol_indices, cur_level); + auto [row_indices_value, row_indices_bdim] = unwrapTensorAtLevel(row_indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + batch_rule(ccol_indices_value, ccol_indices_bdim, row_indices_value, row_indices_bdim, values_value, values_bdim, size, check_pinning); +} +template +at::Tensor _sparse_coo_tensor_with_dims_and_tensors_generated_plumbing(int64_t sparse_dim, int64_t dense_dim, c10::SymIntArrayRef size, const at::Tensor & indices, const at::Tensor & values, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, ::std::optional is_coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_sparse_coo_tensor_with_dims_and_tensors::call(sparse_dim, dense_dim, size, indices, values, dtype, layout, device, pin_memory, is_coalesced); + } + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(sparse_dim, dense_dim, size, indices_value, indices_bdim, values_value, values_bdim, dtype, layout, device, pin_memory, is_coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +const at::Tensor & sparse_resize__generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize_::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return self; +} +template +const at::Tensor & sparse_resize_and_clear__generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize_and_clear_::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return self; +} +template +at::Tensor sparse_mask_generated_plumbing(const at::Tensor & self, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::sparse_mask::call(self, mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_mask_projection_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, bool accumulate_matches) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_sparse_mask_projection::call(self, mask, accumulate_matches); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, accumulate_matches); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _to_cpu_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::_to_cpu::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dense_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_dense::call(self, dtype, masked_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, masked_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_dense_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional masked_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_dense::call(self, dtype, masked_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, masked_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dense_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, ::std::optional masked_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::to_dense_backward::call(grad, input, masked_grad); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, masked_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor coalesce_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::coalesce::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _coalesce_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_coalesce::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _values_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_values::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _coalesced__generated_plumbing(at::Tensor & self, bool coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_coalesced_::call(self, coalesced); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, coalesced); + return self; +} +template +at::Tensor indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor values_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::values::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor crow_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::crow_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor col_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::col_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ccol_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ccol_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor row_indices_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::row_indices::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hspmm_generated_plumbing(const at::Tensor & mat1, const at::Tensor & mat2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mat1, cur_level) && !isBatchedAtLevel(mat2, cur_level)) { + return at::_ops::hspmm::call(mat1, mat2); + } + auto [mat1_value, mat1_bdim] = unwrapTensorAtLevel(mat1, cur_level); + auto [mat2_value, mat2_bdim] = unwrapTensorAtLevel(mat2, cur_level); + auto results = batch_rule(mat1_value, mat1_bdim, mat2_value, mat2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & copy_sparse_to_sparse__generated_plumbing(at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy_sparse_to_sparse_::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return self; +} +template +::std::vector unbind_int_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unbind_int::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unbind_Dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unbind_Dimname::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_sparse_dim_generated_plumbing(const at::Tensor & self, int64_t sparse_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_sparse_dim::call(self, sparse_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sparse_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_sparse_dim_generated_plumbing(const at::Tensor & self, int64_t sparse_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_sparse_dim::call(self, sparse_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, sparse_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_generated_plumbing(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse::call(self, layout, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, layout, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_generated_plumbing(const at::Tensor & self, ::std::optional layout, at::OptionalIntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse::call(self, layout, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, layout, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_csr_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_csr::call(self, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_csr_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_csr::call(self, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_csc_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_csc::call(self, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_csc_generated_plumbing(const at::Tensor & self, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_csc::call(self, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_bsr_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_bsr::call(self, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_bsr_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_bsr::call(self, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_sparse_bsc_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_sparse_bsc::call(self, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_sparse_bsc_generated_plumbing(const at::Tensor & self, at::IntArrayRef blocksize, ::std::optional dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_sparse_bsc::call(self, blocksize, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, blocksize, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _to_sparse_semi_structured_generated_plumbing(const at::Tensor & dense) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dense, cur_level)) { + return at::_ops::_to_sparse_semi_structured::call(dense); + } + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(dense_value, dense_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor to_mkldnn_generated_plumbing(const at::Tensor & self, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_mkldnn::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_reorder_conv2d_weight_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_reorder_conv2d_weight::call(self, padding, stride, dilation, groups, input_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, stride, dilation, groups, input_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_reorder_conv3d_weight_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding, c10::SymIntArrayRef stride, c10::SymIntArrayRef dilation, c10::SymInt groups, at::OptionalSymIntArrayRef input_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_reorder_conv3d_weight::call(self, padding, stride, dilation, groups, input_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, stride, dilation, groups, input_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_mkldnn_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level)) { + return at::_ops::to_mkldnn_backward::call(grad, input); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_tensor_dynamic_generated_plumbing(const at::Tensor & self, at::ScalarType dtype, bool reduce_range) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantize_per_tensor_dynamic::call(self, dtype, reduce_range); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, reduce_range); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_tensor_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantize_per_tensor::call(self, scale, zero_point, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_tensor_tensor_qparams_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::quantize_per_tensor_tensor_qparams::call(self, scale, zero_point, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector quantize_per_tensor_tensors_generated_plumbing(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level) && !isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level)) { + return at::_ops::quantize_per_tensor_tensors::call(tensors, scales, zero_points, dtype); + } + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto [zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + auto results = batch_rule(tensors, scales_value, scales_bdim, zero_points_value, zero_points_bdim, dtype); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantize_per_channel_generated_plumbing(const at::Tensor & self, const at::Tensor & scales, const at::Tensor & zero_points, int64_t axis, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level)) { + return at::_ops::quantize_per_channel::call(self, scales, zero_points, axis, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto [zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + auto results = batch_rule(self_value, self_bdim, scales_value, scales_bdim, zero_points_value, zero_points_bdim, axis, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dequantize_self_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::dequantize_self::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector dequantize_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::dequantize_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor q_per_channel_scales_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::q_per_channel_scales::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor q_per_channel_zero_points_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::q_per_channel_zero_points::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor int_repr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::int_repr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _make_per_tensor_quantized_tensor_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_make_per_tensor_quantized_tensor::call(self, scale, zero_point); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _make_per_channel_quantized_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_make_per_channel_quantized_tensor::call(self, scale, zero_point, axis); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fake_quantize_per_tensor_affine_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine::call(self, scale, zero_point, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point, quant_min, quant_max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fake_quantize_per_tensor_affine_tensor_qparams_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine_tensor_qparams::call(self, scale, zero_point, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, quant_min, quant_max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fake_quantize_per_tensor_affine_cachemask_generated_plumbing(const at::Tensor & self, double scale, int64_t zero_point, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask::call(self, scale, zero_point, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, scale, zero_point, quant_min, quant_max); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _fake_quantize_per_tensor_affine_cachemask_tensor_qparams_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, const at::Tensor & fake_quant_enabled, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level) && !isBatchedAtLevel(fake_quant_enabled, cur_level)) { + return at::_ops::_fake_quantize_per_tensor_affine_cachemask_tensor_qparams::call(self, scale, zero_point, fake_quant_enabled, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto [fake_quant_enabled_value, fake_quant_enabled_bdim] = unwrapTensorAtLevel(fake_quant_enabled, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, fake_quant_enabled_value, fake_quant_enabled_bdim, quant_min, quant_max); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fake_quantize_per_tensor_affine_cachemask_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::fake_quantize_per_tensor_affine_cachemask_backward::call(grad, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_value, grad_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fake_quantize_learnable_per_tensor_affine_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine::call(self, scale, zero_point, quant_min, quant_max, grad_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, quant_min, quant_max, grad_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _fake_quantize_learnable_per_tensor_affine_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_tensor_affine_backward::call(grad, self, scale, zero_point, quant_min, quant_max, grad_factor); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, quant_min, quant_max, grad_factor); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor fake_quantize_per_channel_affine_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::fake_quantize_per_channel_affine::call(self, scale, zero_point, axis, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis, quant_min, quant_max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fake_quantize_per_channel_affine_cachemask_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::fake_quantize_per_channel_affine_cachemask::call(self, scale, zero_point, axis, quant_min, quant_max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis, quant_min, quant_max); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fake_quantize_per_channel_affine_cachemask_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::fake_quantize_per_channel_affine_cachemask_backward::call(grad, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_value, grad_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fake_quantize_learnable_per_channel_affine_generated_plumbing(const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_channel_affine::call(self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis, quant_min, quant_max, grad_factor); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _fake_quantize_learnable_per_channel_affine_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, const at::Tensor & scale, const at::Tensor & zero_point, int64_t axis, int64_t quant_min, int64_t quant_max, double grad_factor) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fake_quantize_learnable_per_channel_affine_backward::call(grad, self, scale, zero_point, axis, quant_min, quant_max, grad_factor); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, axis, quant_min, quant_max, grad_factor); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _saturate_weight_to_fp16_generated_plumbing(const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_saturate_weight_to_fp16::call(weight); + } + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple choose_qparams_optimized_generated_plumbing(const at::Tensor & input, int64_t numel, int64_t n_bins, double ratio, int64_t bit_width) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::choose_qparams_optimized::call(input, numel, n_bins, ratio, bit_width); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, numel, n_bins, ratio, bit_width); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _autocast_to_reduced_precision_generated_plumbing(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled, at::ScalarType cuda_dtype, at::ScalarType cpu_dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_autocast_to_reduced_precision::call(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _autocast_to_full_precision_generated_plumbing(const at::Tensor & self, bool cuda_enabled, bool cpu_enabled) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_autocast_to_full_precision::call(self, cuda_enabled, cpu_enabled); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, cuda_enabled, cpu_enabled); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _to_copy_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_to_copy::call(self, dtype, layout, device, pin_memory, non_blocking, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, non_blocking, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dtype_layout_generated_plumbing(const at::Tensor & self, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_dtype_layout::call(self, dtype, layout, device, pin_memory, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, layout, device, pin_memory, non_blocking, copy, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_device_generated_plumbing(const at::Tensor & self, at::Device device, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_device::call(self, device, dtype, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, device, dtype, non_blocking, copy, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_dtype::call(self, dtype, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype, non_blocking, copy, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_other_generated_plumbing(const at::Tensor & self, const at::Tensor & other, bool non_blocking, bool copy, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::to_other::call(self, other, non_blocking, copy, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, non_blocking, copy, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector meshgrid_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::meshgrid::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector meshgrid_indexing_generated_plumbing(at::TensorList tensors, c10::string_view indexing) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::meshgrid_indexing::call(tensors, indexing); + } + + auto results = batch_rule(tensors, indexing); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cartesian_prod_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::cartesian_prod::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor combinations_generated_plumbing(const at::Tensor & self, int64_t r, bool with_replacement) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::combinations::call(self, r, with_replacement); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, r, with_replacement); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _lstm_mps_generated_plumbing(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::_lstm_mps::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level)); +} +template +::std::tuple,::std::vector> lstm_mps_backward_generated_plumbing(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_y, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(z_state, cur_level) && !isBatchedAtLevel(cell_state_fwd, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(layersOutputs, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::lstm_mps_backward::call(grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [z_state_value, z_state_bdim] = unwrapTensorAtLevel(z_state, cur_level); + auto [cell_state_fwd_value, cell_state_fwd_bdim] = unwrapTensorAtLevel(cell_state_fwd, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [layersOutputs_value, layersOutputs_bdim] = unwrapTensorAtLevel(layersOutputs, cur_level); + std::optional grad_y_value; + std::optional grad_y_bdim; + if (grad_y) { + std::tie(grad_y_value, grad_y_bdim) = unwrapTensorAtLevel(grad_y.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(grad_y_value, grad_y_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, z_state_value, z_state_bdim, cell_state_fwd_value, cell_state_fwd_bdim, input_value, input_bdim, layersOutputs_value, layersOutputs_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _thnn_fused_lstm_cell_generated_plumbing(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & cx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && !isBatchedAtLevel(hidden_bias, cur_level)) { + return at::_ops::_thnn_fused_lstm_cell::call(input_gates, hidden_gates, cx, input_bias, hidden_bias); + } + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto [hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, cx_value, cx_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _thnn_fused_lstm_cell_backward_impl_generated_plumbing(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(cy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::_thnn_fused_lstm_cell_backward_impl::call(grad_hy, grad_cy, cx, cy, workspace, has_bias); + } + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + auto [cy_value, cy_bdim] = unwrapTensorAtLevel(cy, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, cx_value, cx_bdim, cy_value, cy_bdim, workspace_value, workspace_bdim, has_bias); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _thnn_fused_lstm_cell_backward_generated_plumbing(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & cx, const at::Tensor & cy, const at::Tensor & workspace, bool has_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(cy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::_thnn_fused_lstm_cell_backward::call(grad_hy, grad_cy, cx, cy, workspace, has_bias); + } + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + auto [cy_value, cy_bdim] = unwrapTensorAtLevel(cy, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, cx_value, cx_bdim, cy_value, cy_bdim, workspace_value, workspace_bdim, has_bias); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _thnn_differentiable_lstm_cell_backward_generated_plumbing(const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const ::std::optional & input_bias, const ::std::optional & hidden_bias, const at::Tensor & cx, const at::Tensor & cy) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && !isBatchedAtLevel(hidden_bias, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(cy, cur_level)) { + return at::_ops::_thnn_differentiable_lstm_cell_backward::call(grad_hy, grad_cy, input_gates, hidden_gates, input_bias, hidden_bias, cx, cy); + } + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto [hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [cx_value, cx_bdim] = unwrapTensorAtLevel(cx, cur_level); + auto [cy_value, cy_bdim] = unwrapTensorAtLevel(cy, cur_level); + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim, cx_value, cx_bdim, cy_value, cy_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _thnn_fused_gru_cell_generated_plumbing(const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && !isBatchedAtLevel(hidden_bias, cur_level)) { + return at::_ops::_thnn_fused_gru_cell::call(input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto [hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, hx_value, hx_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _thnn_fused_gru_cell_backward_generated_plumbing(const at::Tensor & grad_hy, const at::Tensor & workspace, bool has_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(workspace, cur_level)) { + return at::_ops::_thnn_fused_gru_cell_backward::call(grad_hy, workspace, has_bias); + } + auto [grad_hy_value, grad_hy_bdim] = unwrapTensorAtLevel(grad_hy, cur_level); + auto [workspace_value, workspace_bdim] = unwrapTensorAtLevel(workspace, cur_level); + auto results = batch_rule(grad_hy_value, grad_hy_bdim, workspace_value, workspace_bdim, has_bias); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _thnn_differentiable_gru_cell_backward_generated_plumbing(const at::Tensor & grad_hy, const at::Tensor & input_gates, const at::Tensor & hidden_gates, const at::Tensor & hx, const ::std::optional & input_bias, const ::std::optional & hidden_bias) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(input_gates, cur_level) && !isBatchedAtLevel(hidden_gates, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(input_bias, cur_level) && !isBatchedAtLevel(hidden_bias, cur_level)) { + return at::_ops::_thnn_differentiable_gru_cell_backward::call(grad_hy, input_gates, hidden_gates, hx, input_bias, hidden_bias); + } + auto [grad_hy_value, grad_hy_bdim] = unwrapTensorAtLevel(grad_hy, cur_level); + auto [input_gates_value, input_gates_bdim] = unwrapTensorAtLevel(input_gates, cur_level); + auto [hidden_gates_value, hidden_gates_bdim] = unwrapTensorAtLevel(hidden_gates, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + std::optional input_bias_value; + std::optional input_bias_bdim; + if (input_bias) { + std::tie(input_bias_value, input_bias_bdim) = unwrapTensorAtLevel(input_bias.value(), cur_level); + } + std::optional hidden_bias_value; + std::optional hidden_bias_bdim; + if (hidden_bias) { + std::tie(hidden_bias_value, hidden_bias_bdim) = unwrapTensorAtLevel(hidden_bias.value(), cur_level); + } + auto results = batch_rule(grad_hy_value, grad_hy_bdim, input_gates_value, input_gates_bdim, hidden_gates_value, hidden_gates_bdim, hx_value, hx_bdim, input_bias_value, input_bias_bdim, hidden_bias_value, hidden_bias_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple lstm_input_generated_plumbing(const at::Tensor & input, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::lstm_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple lstm_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::lstm_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple gru_input_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::gru_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple gru_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::gru_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_tanh_input_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_tanh_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_tanh_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_tanh_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_relu_input_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_relu_input::call(input, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple rnn_relu_data_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, const at::Tensor & hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level)) { + return at::_ops::rnn_relu_data::call(data, batch_sizes, hx, params, has_biases, num_layers, dropout, train, bidirectional); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, hx_value, hx_bdim, params, has_biases, num_layers, dropout, train, bidirectional); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple lstm_cell_generated_plumbing(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::lstm_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor gru_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::gru_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rnn_tanh_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::rnn_tanh_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rnn_relu_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const ::std::optional & b_ih, const ::std::optional & b_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level)) { + return at::_ops::rnn_relu_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + std::optional b_ih_value; + std::optional b_ih_bdim; + if (b_ih) { + std::tie(b_ih_value, b_ih_bdim) = unwrapTensorAtLevel(b_ih.value(), cur_level); + } + std::optional b_hh_value; + std::optional b_hh_bdim; + if (b_hh) { + std::tie(b_hh_value, b_hh_bdim) = unwrapTensorAtLevel(b_hh.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple quantized_lstm_cell_generated_plumbing(const at::Tensor & input, at::TensorList hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return at::_ops::quantized_lstm_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor quantized_gru_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return at::_ops::quantized_gru_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_rnn_relu_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return at::_ops::quantized_rnn_relu_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantized_rnn_tanh_cell_generated_plumbing(const at::Tensor & input, const at::Tensor & hx, const at::Tensor & w_ih, const at::Tensor & w_hh, const at::Tensor & b_ih, const at::Tensor & b_hh, const at::Tensor & packed_ih, const at::Tensor & packed_hh, const at::Tensor & col_offsets_ih, const at::Tensor & col_offsets_hh, const at::Scalar & scale_ih, const at::Scalar & scale_hh, const at::Scalar & zero_point_ih, const at::Scalar & zero_point_hh) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(w_ih, cur_level) && !isBatchedAtLevel(w_hh, cur_level) && !isBatchedAtLevel(b_ih, cur_level) && !isBatchedAtLevel(b_hh, cur_level) && !isBatchedAtLevel(packed_ih, cur_level) && !isBatchedAtLevel(packed_hh, cur_level) && !isBatchedAtLevel(col_offsets_ih, cur_level) && !isBatchedAtLevel(col_offsets_hh, cur_level)) { + return at::_ops::quantized_rnn_tanh_cell::call(input, hx, w_ih, w_hh, b_ih, b_hh, packed_ih, packed_hh, col_offsets_ih, col_offsets_hh, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [w_ih_value, w_ih_bdim] = unwrapTensorAtLevel(w_ih, cur_level); + auto [w_hh_value, w_hh_bdim] = unwrapTensorAtLevel(w_hh, cur_level); + auto [b_ih_value, b_ih_bdim] = unwrapTensorAtLevel(b_ih, cur_level); + auto [b_hh_value, b_hh_bdim] = unwrapTensorAtLevel(b_hh, cur_level); + auto [packed_ih_value, packed_ih_bdim] = unwrapTensorAtLevel(packed_ih, cur_level); + auto [packed_hh_value, packed_hh_bdim] = unwrapTensorAtLevel(packed_hh, cur_level); + auto [col_offsets_ih_value, col_offsets_ih_bdim] = unwrapTensorAtLevel(col_offsets_ih, cur_level); + auto [col_offsets_hh_value, col_offsets_hh_bdim] = unwrapTensorAtLevel(col_offsets_hh, cur_level); + auto results = batch_rule(input_value, input_bdim, hx_value, hx_bdim, w_ih_value, w_ih_bdim, w_hh_value, w_hh_bdim, b_ih_value, b_ih_bdim, b_hh_value, b_hh_bdim, packed_ih_value, packed_ih_bdim, packed_hh_value, packed_hh_bdim, col_offsets_ih_value, col_offsets_ih_bdim, col_offsets_hh_value, col_offsets_hh_bdim, scale_ih, scale_hh, zero_point_ih, zero_point_hh); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _pack_padded_sequence_generated_plumbing(const at::Tensor & input, const at::Tensor & lengths, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(lengths, cur_level)) { + return at::_ops::_pack_padded_sequence::call(input, lengths, batch_first); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [lengths_value, lengths_bdim] = unwrapTensorAtLevel(lengths, cur_level); + auto results = batch_rule(input_value, input_bdim, lengths_value, lengths_bdim, batch_first); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _pack_padded_sequence_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef input_size, const at::Tensor & batch_sizes, bool batch_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level)) { + return at::_ops::_pack_padded_sequence_backward::call(grad, input_size, batch_sizes, batch_first); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_size, batch_sizes_value, batch_sizes_bdim, batch_first); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _pad_packed_sequence_generated_plumbing(const at::Tensor & data, const at::Tensor & batch_sizes, bool batch_first, const at::Scalar & padding_value, int64_t total_length) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(batch_sizes, cur_level)) { + return at::_ops::_pad_packed_sequence::call(data, batch_sizes, batch_first, padding_value, total_length); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + auto [batch_sizes_value, batch_sizes_bdim] = unwrapTensorAtLevel(batch_sizes, cur_level); + auto results = batch_rule(data_value, data_bdim, batch_sizes_value, batch_sizes_bdim, batch_first, padding_value, total_length); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor lift_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lift::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lift_fresh_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lift_fresh::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lift_fresh_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lift_fresh_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & masked_fill__Scalar_generated_plumbing(at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_fill__Scalar::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + batch_rule(self_value, self_bdim, mask_value, mask_bdim, value); + return self; +} +template +at::Tensor masked_fill_Scalar_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_fill_Scalar::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & masked_fill__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::masked_fill__Tensor::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, mask_value, mask_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor masked_fill_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::masked_fill_Tensor::call(self, mask, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & masked_scatter__generated_plumbing(at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::masked_scatter_::call(self, mask, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, mask_value, mask_bdim, source_value, source_bdim); + return self; +} +template +at::Tensor masked_scatter_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::masked_scatter::call(self, mask, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor masked_scatter_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & mask, c10::SymIntArrayRef sizes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_scatter_backward::call(grad_output, mask, sizes); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, mask_value, mask_bdim, sizes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _masked_softmax_generated_plumbing(const at::Tensor & self, const at::Tensor & mask, ::std::optional dim, ::std::optional mask_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_masked_softmax::call(self, mask, dim, mask_type); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim, dim, mask_type); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _masked_softmax_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output, const at::Tensor & mask, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_masked_softmax_backward::call(grad_output, output, mask, dim); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim, mask_value, mask_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_dtype::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & put__generated_plumbing(at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::put_::call(self, index, source, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, index_value, index_bdim, source_value, source_bdim, accumulate); + return self; +} +template +at::Tensor put_generated_plumbing(const at::Tensor & self, const at::Tensor & index, const at::Tensor & source, bool accumulate) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::put::call(self, index, source, accumulate); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, index_value, index_bdim, source_value, source_bdim, accumulate); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_add__generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_add_::call(self, dim, index, source, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, alpha); + return self; +} +template +at::Tensor index_add_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_add::call(self, dim, index, source, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_add_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & source, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_add_dimname::call(self, dim, index, source, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_reduce__generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_reduce_::call(self, dim, index, source, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, reduce, include_self); + return self; +} +template +at::Tensor index_reduce_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & source, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::index_reduce::call(self, dim, index, source, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, source_value, source_bdim, reduce, include_self); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_fill__int_Scalar_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill__int_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return self; +} +template +at::Tensor index_fill_int_Scalar_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill_int_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_fill__int_Tensor_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill__int_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor index_fill_int_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill_int_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & index_fill__Dimname_Scalar_generated_plumbing(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill__Dimname_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return self; +} +template +at::Tensor & index_fill__Dimname_Tensor_generated_plumbing(at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill__Dimname_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value_value, value_bdim); + return self; +} +template +at::Tensor index_fill_Dimname_Scalar_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_fill_Dimname_Scalar::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_fill_Dimname_Tensor_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::index_fill_Dimname_Tensor::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value_value, value_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_src_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_src::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter__src_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter__src::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return self; +} +template +at::Tensor scatter_value_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter_value::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter__value_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter__value::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return self; +} +template +at::Tensor scatter_reduce_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_reduce::call(self, dim, index, src, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter__reduce_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter__reduce::call(self, dim, index, src, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce); + return self; +} +template +at::Tensor scatter_value_reduce_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter_value_reduce::call(self, dim, index, value, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value, reduce); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter__value_reduce_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Scalar & value, c10::string_view reduce) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter__value_reduce::call(self, dim, index, value, reduce); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value, reduce); + return self; +} +template +at::Tensor scatter_dimname_src_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_dimname_src::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_dimname_value_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::scatter_dimname_value::call(self, dim, index, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_add_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_add::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter_add__generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_add_::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return self; +} +template +at::Tensor scatter_add_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, const at::Tensor & src) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_add_dimname::call(self, dim, index, src); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor scatter_reduce_two_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_reduce_two::call(self, dim, index, src, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce, include_self); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & scatter_reduce__two_generated_plumbing(at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & src, c10::string_view reduce, bool include_self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::scatter_reduce__two::call(self, dim, index, src, reduce, include_self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + batch_rule(self_value, self_bdim, dim, index_value, index_bdim, src_value, src_bdim, reduce, include_self); + return self; +} +template +at::Tensor & eq__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::eq__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & eq__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::eq__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_and_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_and_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_and_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_and_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_and_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_and_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_and__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_and__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & bitwise_and__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_and__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __and___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__and___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __and___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__and___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __iand___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__iand___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __iand___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__iand___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_or_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_or_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_or_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_or_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_or_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_or_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_or__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_or__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & bitwise_or__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_or__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __or___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__or___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __or___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__or___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __ior___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__ior___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __ior___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__ior___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_xor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_xor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_xor_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_xor_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bitwise_xor_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_xor_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_xor__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_xor__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & bitwise_xor__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_xor__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __xor___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__xor___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __xor___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__xor___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __ixor___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__ixor___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __ixor___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__ixor___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor __lshift___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__lshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __lshift___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__lshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __ilshift___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__ilshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __ilshift___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__ilshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_left_shift_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_left_shift_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_left_shift__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_left_shift__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_left_shift_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_left_shift_Tensor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_left_shift__Tensor_Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_left_shift__Tensor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor bitwise_left_shift_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_left_shift_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __rshift___Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__rshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor __rshift___Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__rshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & __irshift___Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::__irshift___Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & __irshift___Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::__irshift___Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_right_shift_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_right_shift_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_right_shift__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_right_shift__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor bitwise_right_shift_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_right_shift_Tensor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & bitwise_right_shift__Tensor_Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::bitwise_right_shift__Tensor_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor bitwise_right_shift_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::bitwise_right_shift_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & tril__generated_plumbing(at::Tensor & self, int64_t diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tril_::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, diagonal); + return self; +} +template +at::Tensor & triu__generated_plumbing(at::Tensor & self, int64_t diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::triu_::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, diagonal); + return self; +} +template +at::Tensor & digamma__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::digamma_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor & lerp__Scalar_generated_plumbing(at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return at::_ops::lerp__Scalar::call(self, end, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + batch_rule(self_value, self_bdim, end_value, end_bdim, weight); + return self; +} +template +at::Tensor & lerp__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::lerp__Tensor::call(self, end, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + batch_rule(self_value, self_bdim, end_value, end_bdim, weight_value, weight_bdim); + return self; +} +template +at::Tensor & addbmm__generated_plumbing(at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::addbmm_::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + return self; +} +template +at::Tensor addbmm_generated_plumbing(const at::Tensor & self, const at::Tensor & batch1, const at::Tensor & batch2, const at::Scalar & beta, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(batch1, cur_level) && !isBatchedAtLevel(batch2, cur_level)) { + return at::_ops::addbmm::call(self, batch1, batch2, beta, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [batch1_value, batch1_bdim] = unwrapTensorAtLevel(batch1, cur_level); + auto [batch2_value, batch2_bdim] = unwrapTensorAtLevel(batch2, cur_level); + auto results = batch_rule(self_value, self_bdim, batch1_value, batch1_bdim, batch2_value, batch2_bdim, beta, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & random__from_generated_plumbing(at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random__from::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, from, to, generator); + return self; +} +template +at::Tensor & random__to_generated_plumbing(at::Tensor & self, int64_t to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random__to::call(self, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, to, generator); + return self; +} +template +at::Tensor & random__generated_plumbing(at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random_::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, generator); + return self; +} +template +at::Tensor & uniform__generated_plumbing(at::Tensor & self, double from, double to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::uniform_::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, from, to, generator); + return self; +} +template +at::Tensor & cauchy__generated_plumbing(at::Tensor & self, double median, double sigma, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cauchy_::call(self, median, sigma, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, median, sigma, generator); + return self; +} +template +at::Tensor & log_normal__generated_plumbing(at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_normal_::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, mean, std, generator); + return self; +} +template +at::Tensor & exponential__generated_plumbing(at::Tensor & self, double lambd, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exponential_::call(self, lambd, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, lambd, generator); + return self; +} +template +at::Tensor & geometric__generated_plumbing(at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::geometric_::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, generator); + return self; +} +template +at::Tensor diag_generated_plumbing(const at::Tensor & self, int64_t diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diag::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, diagonal); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cross_generated_plumbing(const at::Tensor & self, const at::Tensor & other, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::cross::call(self, other, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor triu_generated_plumbing(const at::Tensor & self, int64_t diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::triu::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, diagonal); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor tril_generated_plumbing(const at::Tensor & self, int64_t diagonal) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::tril::call(self, diagonal); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, diagonal); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trace_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::trace::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor trace_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef sizes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level)) { + return at::_ops::trace_backward::call(grad, sizes); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(grad_value, grad_bdim, sizes); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ne_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ne_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ne_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ne_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ne__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ne__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & ne__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ne__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor not_equal_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::not_equal_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor not_equal_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::not_equal_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & not_equal__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::not_equal__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & not_equal__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::not_equal__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor eq_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::eq_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor eq_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::eq_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ge_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ge_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ge_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ge_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & ge__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ge__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & ge__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::ge__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor greater_equal_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater_equal_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor greater_equal_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater_equal_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & greater_equal__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater_equal__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & greater_equal__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater_equal__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor le_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::le_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor le_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::le_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & le__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::le__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & le__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::le__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor less_equal_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less_equal_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor less_equal_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less_equal_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & less_equal__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less_equal__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & less_equal__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less_equal__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor gt_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gt_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gt_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gt_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & gt__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::gt__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & gt__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::gt__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor greater_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor greater_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & greater__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::greater__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & greater__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::greater__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor lt_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lt_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lt_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::lt_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & lt__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lt__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & lt__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::lt__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor less_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor less_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & less__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::less__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor & less__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::less__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor take_generated_plumbing(const at::Tensor & self, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::take::call(self, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor take_along_dim_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, ::std::optional dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::take_along_dim::call(self, indices, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_select_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_select::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_select_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_select_dimname::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor index_select_backward_generated_plumbing(const at::Tensor & grad, c10::SymIntArrayRef self_sizes, int64_t dim, const at::Tensor & index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::index_select_backward::call(grad, self_sizes, dim, index); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_sizes, dim, index_value, index_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor masked_select_generated_plumbing(const at::Tensor & self, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_select::call(self, mask); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(self_value, self_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor masked_select_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & input, const at::Tensor & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::masked_select_backward::call(grad, input, mask); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [mask_value, mask_bdim] = unwrapTensorAtLevel(mask, cur_level); + auto results = batch_rule(grad_value, grad_bdim, input_value, input_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nonzero_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nonzero::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nonzero_static_generated_plumbing(const at::Tensor & self, c10::SymInt size, int64_t fill_value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nonzero_static::call(self, size, fill_value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, fill_value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector nonzero_numpy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nonzero_numpy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argwhere_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argwhere::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gather_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::gather::call(self, dim, index, sparse_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, sparse_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gather_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & self, int64_t dim, const at::Tensor & index, bool sparse_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::gather_backward::call(grad, self, dim, index, sparse_grad); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(grad_value, grad_bdim, self_value, self_bdim, dim, index_value, index_bdim, sparse_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor gather_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, const at::Tensor & index, bool sparse_grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level)) { + return at::_ops::gather_dimname::call(self, dim, index, sparse_grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, sparse_grad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _gather_sparse_backward_generated_plumbing(const at::Tensor & self, int64_t dim, const at::Tensor & index, const at::Tensor & grad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(index, cur_level) && !isBatchedAtLevel(grad, cur_level)) { + return at::_ops::_gather_sparse_backward::call(self, dim, index, grad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [index_value, index_bdim] = unwrapTensorAtLevel(index, cur_level); + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index_value, index_bdim, grad_value, grad_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor addcmul_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcmul::call(self, tensor1, tensor2, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addcmul__generated_plumbing(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcmul_::call(self, tensor1, tensor2, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return self; +} +template +at::Tensor addcdiv_generated_plumbing(const at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcdiv::call(self, tensor1, tensor2, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + auto results = batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & addcdiv__generated_plumbing(at::Tensor & self, const at::Tensor & tensor1, const at::Tensor & tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::addcdiv_::call(self, tensor1, tensor2, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [tensor1_value, tensor1_bdim] = unwrapTensorAtLevel(tensor1, cur_level); + auto [tensor2_value, tensor2_bdim] = unwrapTensorAtLevel(tensor2, cur_level); + batch_rule(self_value, self_bdim, tensor1_value, tensor1_bdim, tensor2_value, tensor2_bdim, value); + return self; +} +template +at::Tensor cross_entropy_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, double label_smoothing) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::cross_entropy_loss::call(self, target, weight, reduction, ignore_index, label_smoothing); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index, label_smoothing); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple triangular_solve_generated_plumbing(const at::Tensor & self, const at::Tensor & A, bool upper, bool transpose, bool unitriangular) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(A, cur_level)) { + return at::_ops::triangular_solve::call(self, A, upper, transpose, unitriangular); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(self_value, self_bdim, A_value, A_bdim, upper, transpose, unitriangular); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _linalg_check_errors_generated_plumbing(const at::Tensor & info, c10::string_view api_name, bool is_matrix) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(info, cur_level)) { + return at::_ops::_linalg_check_errors::call(info, api_name, is_matrix); + } + auto [info_value, info_bdim] = unwrapTensorAtLevel(info, cur_level); + batch_rule(info_value, info_bdim, api_name, is_matrix); +} +template +at::Tensor linalg_solve_triangular_generated_plumbing(const at::Tensor & self, const at::Tensor & B, bool upper, bool left, bool unitriangular) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_solve_triangular::call(self, B, upper, left, unitriangular); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(self_value, self_bdim, B_value, B_bdim, upper, left, unitriangular); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_vander_generated_plumbing(const at::Tensor & x, ::std::optional N) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::linalg_vander::call(x, N); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, N); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple svd_generated_plumbing(const at::Tensor & self, bool some, bool compute_uv) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::svd::call(self, some, compute_uv); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, some, compute_uv); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor swapaxes_generated_plumbing(const at::Tensor & self, int64_t axis0, int64_t axis1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::swapaxes::call(self, axis0, axis1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, axis0, axis1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor swapdims_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::swapdims::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cholesky_generated_plumbing(const at::Tensor & self, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cholesky::call(self, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cholesky_solve_generated_plumbing(const at::Tensor & self, const at::Tensor & input2, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input2, cur_level)) { + return at::_ops::cholesky_solve::call(self, input2, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto results = batch_rule(self_value, self_bdim, input2_value, input2_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _cholesky_solve_helper_generated_plumbing(const at::Tensor & self, const at::Tensor & A, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(A, cur_level)) { + return at::_ops::_cholesky_solve_helper::call(self, A, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(self_value, self_bdim, A_value, A_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cholesky_inverse_generated_plumbing(const at::Tensor & self, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cholesky_inverse::call(self, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple qr_generated_plumbing(const at::Tensor & self, bool some) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::qr::call(self, some); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, some); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple geqrf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::geqrf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor orgqr_generated_plumbing(const at::Tensor & self, const at::Tensor & input2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input2, cur_level)) { + return at::_ops::orgqr::call(self, input2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto results = batch_rule(self_value, self_bdim, input2_value, input2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ormqr_generated_plumbing(const at::Tensor & self, const at::Tensor & input2, const at::Tensor & input3, bool left, bool transpose) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(input2, cur_level) && !isBatchedAtLevel(input3, cur_level)) { + return at::_ops::ormqr::call(self, input2, input3, left, transpose); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [input2_value, input2_bdim] = unwrapTensorAtLevel(input2, cur_level); + auto [input3_value, input3_bdim] = unwrapTensorAtLevel(input3, cur_level); + auto results = batch_rule(self_value, self_bdim, input2_value, input2_bdim, input3_value, input3_bdim, left, transpose); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _lu_with_info_generated_plumbing(const at::Tensor & self, bool pivot, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_lu_with_info::call(self, pivot, check_errors); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pivot, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor lu_solve_generated_plumbing(const at::Tensor & self, const at::Tensor & LU_data, const at::Tensor & LU_pivots) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(LU_data, cur_level) && !isBatchedAtLevel(LU_pivots, cur_level)) { + return at::_ops::lu_solve::call(self, LU_data, LU_pivots); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [LU_data_value, LU_data_bdim] = unwrapTensorAtLevel(LU_data, cur_level); + auto [LU_pivots_value, LU_pivots_bdim] = unwrapTensorAtLevel(LU_pivots, cur_level); + auto results = batch_rule(self_value, self_bdim, LU_data_value, LU_data_bdim, LU_pivots_value, LU_pivots_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple lu_unpack_generated_plumbing(const at::Tensor & LU_data, const at::Tensor & LU_pivots, bool unpack_data, bool unpack_pivots) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(LU_data, cur_level) && !isBatchedAtLevel(LU_pivots, cur_level)) { + return at::_ops::lu_unpack::call(LU_data, LU_pivots, unpack_data, unpack_pivots); + } + auto [LU_data_value, LU_data_bdim] = unwrapTensorAtLevel(LU_data, cur_level); + auto [LU_pivots_value, LU_pivots_bdim] = unwrapTensorAtLevel(LU_pivots, cur_level); + auto results = batch_rule(LU_data_value, LU_data_bdim, LU_pivots_value, LU_pivots_bdim, unpack_data, unpack_pivots); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor multinomial_generated_plumbing(const at::Tensor & self, c10::SymInt num_samples, bool replacement, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::multinomial::call(self, num_samples, replacement, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, num_samples, replacement, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & lgamma__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lgamma_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor lgamma_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::lgamma::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor digamma_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::digamma::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor polygamma_generated_plumbing(int64_t n, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::polygamma::call(n, self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(n, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & polygamma__generated_plumbing(at::Tensor & self, int64_t n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::polygamma_::call(self, n); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, n); + return self; +} +template +at::Tensor erfinv_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfinv::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & erfinv__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::erfinv_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor i0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::i0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & i0__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::i0_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor sign_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sign::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & sign__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sign_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor signbit_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::signbit::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor dist_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::dist::call(self, other, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & atan2__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::atan2_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor atan2_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::atan2::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor arctan2_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::arctan2::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & arctan2__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::arctan2_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor lerp_Scalar_generated_plumbing(const at::Tensor & self, const at::Tensor & end, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level)) { + return at::_ops::lerp_Scalar::call(self, end, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto results = batch_rule(self_value, self_bdim, end_value, end_bdim, weight); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor lerp_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & end, const at::Tensor & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(end, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::lerp_Tensor::call(self, end, weight); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [end_value, end_bdim] = unwrapTensorAtLevel(end, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(self_value, self_bdim, end_value, end_bdim, weight_value, weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor histc_generated_plumbing(const at::Tensor & self, int64_t bins, const at::Scalar & min, const at::Scalar & max) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::histc::call(self, bins, min, max); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, bins, min, max); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple histogram_bins_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & bins, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(bins, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogram_bins_tensor::call(self, bins, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [bins_value, bins_bdim] = unwrapTensorAtLevel(bins, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins_value, bins_bdim, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple histogram_bin_ct_generated_plumbing(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogram_bin_ct::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::vector _histogramdd_bin_edges_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_histogramdd_bin_edges::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _histogramdd_from_bin_cts_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_histogramdd_from_bin_cts::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _histogramdd_from_bin_tensors_generated_plumbing(const at::Tensor & self, at::TensorList bins, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(bins, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_histogramdd_from_bin_tensors::call(self, bins, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, weight_value, weight_bdim, density); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple> histogramdd_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogramdd::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple> histogramdd_int_bins_generated_plumbing(const at::Tensor & self, int64_t bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogramdd_int_bins::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple> histogramdd_TensorList_bins_generated_plumbing(const at::Tensor & self, at::TensorList bins, ::std::optional> range, const ::std::optional & weight, bool density) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(bins, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::histogramdd_TensorList_bins::call(self, bins, range, weight, density); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fmod_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fmod_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fmod__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fmod__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor fmod_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmod_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & fmod__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmod__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor hypot_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::hypot::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hypot__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::hypot_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor igamma_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::igamma::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & igamma__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::igamma_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor igammac_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::igammac::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & igammac__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::igammac_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor nextafter_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::nextafter::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & nextafter__generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::nextafter_::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor remainder_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::remainder_Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & remainder__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::remainder__Scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, other); + return self; +} +template +at::Tensor remainder_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::remainder_Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & remainder__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::remainder__Tensor::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self_value, self_bdim, other_value, other_bdim); + return self; +} +template +at::Tensor remainder_Scalar_Tensor_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::remainder_Scalar_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor min_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::min::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fmin_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmin::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fmax_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::fmax::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor maximum_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::maximum::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_other_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::max_other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor minimum_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::minimum::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor min_other_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::min_other::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantile_generated_plumbing(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(q, cur_level)) { + return at::_ops::quantile::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [q_value, q_bdim] = unwrapTensorAtLevel(q, cur_level); + auto results = batch_rule(self_value, self_bdim, q_value, q_bdim, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor quantile_scalar_generated_plumbing(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::quantile_scalar::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, q, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nanquantile_generated_plumbing(const at::Tensor & self, const at::Tensor & q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(q, cur_level)) { + return at::_ops::nanquantile::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [q_value, q_bdim] = unwrapTensorAtLevel(q, cur_level); + auto results = batch_rule(self_value, self_bdim, q_value, q_bdim, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nanquantile_scalar_generated_plumbing(const at::Tensor & self, double q, ::std::optional dim, bool keepdim, c10::string_view interpolation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nanquantile_scalar::call(self, q, dim, keepdim, interpolation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, q, dim, keepdim, interpolation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple sort_generated_plumbing(const at::Tensor & self, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple sort_stable_generated_plumbing(const at::Tensor & self, ::std::optional stable, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort_stable::call(self, stable, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, stable, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple sort_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort_dimname::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple sort_dimname_stable_generated_plumbing(const at::Tensor & self, ::std::optional stable, at::Dimname dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sort_dimname_stable::call(self, stable, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, stable, dim, descending); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor msort_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::msort::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argsort_generated_plumbing(const at::Tensor & self, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argsort::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argsort_stable_generated_plumbing(const at::Tensor & self, bool stable, int64_t dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argsort_stable::call(self, stable, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, stable, dim, descending); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor argsort_dimname_generated_plumbing(const at::Tensor & self, at::Dimname dim, bool descending) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::argsort_dimname::call(self, dim, descending); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, descending); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple topk_generated_plumbing(const at::Tensor & self, c10::SymInt k, int64_t dim, bool largest, bool sorted) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::topk::call(self, k, dim, largest, sorted); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, k, dim, largest, sorted); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor all_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::all::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor any_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::any::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor renorm_generated_plumbing(const at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::renorm::call(self, p, dim, maxnorm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, dim, maxnorm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & renorm__generated_plumbing(at::Tensor & self, const at::Scalar & p, int64_t dim, const at::Scalar & maxnorm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::renorm_::call(self, p, dim, maxnorm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, p, dim, maxnorm); + return self; +} +template +at::Tensor unfold_generated_plumbing(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unfold::call(self, dimension, size, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dimension, size, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unfold_backward_generated_plumbing(const at::Tensor & grad_in, c10::SymIntArrayRef input_sizes, int64_t dim, int64_t size, int64_t step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_in, cur_level)) { + return at::_ops::unfold_backward::call(grad_in, input_sizes, dim, size, step); + } + auto [grad_in_value, grad_in_bdim] = unwrapTensorAtLevel(grad_in, cur_level); + auto results = batch_rule(grad_in_value, grad_in_bdim, input_sizes, dim, size, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pow_Tensor_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::pow_Tensor_Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pow_Scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::pow_Scalar::call(self, exponent); + } + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pow_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pow_Tensor_Scalar::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & pow__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pow__Scalar::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, exponent); + return self; +} +template +at::Tensor & pow__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::pow__Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return self; +} +template +at::Tensor float_power_Tensor_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::float_power_Tensor_Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor float_power_Scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::float_power_Scalar::call(self, exponent); + } + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + auto results = batch_rule(self, exponent_value, exponent_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor float_power_Tensor_Scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::float_power_Tensor_Scalar::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, exponent); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & float_power__Scalar_generated_plumbing(at::Tensor & self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::float_power__Scalar::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, exponent); + return self; +} +template +at::Tensor & float_power__Tensor_generated_plumbing(at::Tensor & self, const at::Tensor & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::float_power__Tensor::call(self, exponent); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [exponent_value, exponent_bdim] = unwrapTensorAtLevel(exponent, cur_level); + batch_rule(self_value, self_bdim, exponent_value, exponent_bdim); + return self; +} +template +at::Tensor & normal__generated_plumbing(at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::normal_::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, mean, std, generator); + return self; +} +template +at::Tensor normal_functional_generated_plumbing(const at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::normal_functional::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, mean, std, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor normal_Tensor_float_generated_plumbing(const at::Tensor & mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mean, cur_level)) { + return at::_ops::normal_Tensor_float::call(mean, std, generator); + } + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto results = batch_rule(mean_value, mean_bdim, std, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor normal_float_Tensor_generated_plumbing(double mean, const at::Tensor & std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(std, cur_level)) { + return at::_ops::normal_float_Tensor::call(mean, std, generator); + } + auto [std_value, std_bdim] = unwrapTensorAtLevel(std, cur_level); + auto results = batch_rule(mean, std_value, std_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor normal_Tensor_Tensor_generated_plumbing(const at::Tensor & mean, const at::Tensor & std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(mean, cur_level) && !isBatchedAtLevel(std, cur_level)) { + return at::_ops::normal_Tensor_Tensor::call(mean, std, generator); + } + auto [mean_value, mean_bdim] = unwrapTensorAtLevel(mean, cur_level); + auto [std_value, std_bdim] = unwrapTensorAtLevel(std, cur_level); + auto results = batch_rule(mean_value, mean_bdim, std_value, std_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor alias_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::alias::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _amp_foreach_non_finite_check_and_unscale__generated_plumbing(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(inv_scale, cur_level)) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_::call(self, found_inf, inv_scale); + } + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto [inv_scale_value, inv_scale_bdim] = unwrapTensorAtLevel(inv_scale, cur_level); + batch_rule(self, found_inf_value, found_inf_bdim, inv_scale_value, inv_scale_bdim); +} +template +::std::vector _foreach_add_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_add__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_add_List_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_add_List::call(self, other, alpha); + } + + auto results = batch_rule(self, other, alpha); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_add__List_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_add__List::call(self, other, alpha); + } + + batch_rule(self, other, alpha); +} +template +::std::vector _foreach_add_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_add__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_add__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_add_Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_add_Tensor::call(self, other, alpha); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim, alpha); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_add__Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_add__Tensor::call(self, other, alpha); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, alpha); +} +template +::std::vector _foreach_sub_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sub__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_sub_List_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_sub_List::call(self, other, alpha); + } + + auto results = batch_rule(self, other, alpha); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sub__List_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_sub__List::call(self, other, alpha); + } + + batch_rule(self, other, alpha); +} +template +::std::vector _foreach_sub_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sub__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sub__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_mul_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_mul_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_mul_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_mul__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_mul_Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_mul__Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_mul__Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim); +} +template +::std::vector _foreach_div_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_div_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_div_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_div__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_div_Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div_Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_div__Tensor_generated_plumbing(at::TensorList self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_div__Tensor::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim); +} +template +::std::vector _foreach_clamp_max_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_max__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_clamp_max_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_max_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_max__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_max__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_clamp_max_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_max__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_max__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_clamp_min_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_min__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_clamp_min_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_min_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_min__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_clamp_min__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_clamp_min_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_clamp_min__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_clamp_min__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_maximum_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_maximum_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_maximum__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_maximum__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_maximum_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_maximum_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_maximum__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_maximum__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_maximum_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_maximum_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_maximum__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_maximum__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_minimum_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum_Scalar::call(self, scalar); + } + + auto results = batch_rule(self, scalar); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_minimum__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & scalar) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum__Scalar::call(self, scalar); + } + + batch_rule(self, scalar); +} +template +::std::vector _foreach_minimum_List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_minimum_List::call(self, other); + } + + auto results = batch_rule(self, other); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_minimum__List_generated_plumbing(at::TensorList self, at::TensorList other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_foreach_minimum__List::call(self, other); + } + + batch_rule(self, other); +} +template +::std::vector _foreach_minimum_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum_ScalarList::call(self, scalars); + } + + auto results = batch_rule(self, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_minimum__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_minimum__ScalarList::call(self, scalars); + } + + batch_rule(self, scalars); +} +template +::std::vector _foreach_addcdiv_Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv_Scalar::call(self, tensor1, tensor2, value); + } + + auto results = batch_rule(self, tensor1, tensor2, value); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcdiv_ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv_ScalarList::call(self, tensor1, tensor2, scalars); + } + + auto results = batch_rule(self, tensor1, tensor2, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcdiv_Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcdiv_Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + auto results = batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_addcdiv__Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv__Scalar::call(self, tensor1, tensor2, value); + } + + batch_rule(self, tensor1, tensor2, value); +} +template +void _foreach_addcdiv__ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcdiv__ScalarList::call(self, tensor1, tensor2, scalars); + } + + batch_rule(self, tensor1, tensor2, scalars); +} +template +void _foreach_addcdiv__Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcdiv__Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); +} +template +::std::vector _foreach_addcmul_Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul_Scalar::call(self, tensor1, tensor2, value); + } + + auto results = batch_rule(self, tensor1, tensor2, value); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcmul_ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul_ScalarList::call(self, tensor1, tensor2, scalars); + } + + auto results = batch_rule(self, tensor1, tensor2, scalars); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_addcmul_Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcmul_Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + auto results = batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_addcmul__Scalar_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul__Scalar::call(self, tensor1, tensor2, value); + } + + batch_rule(self, tensor1, tensor2, value); +} +template +void _foreach_addcmul__ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level)) { + return at::_ops::_foreach_addcmul__ScalarList::call(self, tensor1, tensor2, scalars); + } + + batch_rule(self, tensor1, tensor2, scalars); +} +template +void _foreach_addcmul__Tensor_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level)) { + return at::_ops::_foreach_addcmul__Tensor::call(self, tensor1, tensor2, scalars); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim); +} +template +::std::vector _foreach_abs_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_abs::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_abs__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_abs_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_acos_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_acos::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_acos__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_acos_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_asin_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_asin::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_asin__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_asin_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_atan_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_atan::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_atan__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_atan_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_ceil_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_ceil::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_ceil__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_ceil_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_cos_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cos::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_cos__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cos_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_cosh_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cosh::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_cosh__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_cosh_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_erf_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erf::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_erf__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erf_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_erfc_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erfc::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_erfc__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_erfc_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_exp_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_exp::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_exp__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_exp_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_expm1_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_expm1::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_expm1__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_expm1_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_floor_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_floor::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_floor__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_floor_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_frac_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_frac::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_frac__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_frac_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_lerp_List_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(weights, cur_level)) { + return at::_ops::_foreach_lerp_List::call(self, tensors1, weights); + } + + auto results = batch_rule(self, tensors1, weights); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_lerp__List_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::TensorList weights) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(weights, cur_level)) { + return at::_ops::_foreach_lerp__List::call(self, tensors1, weights); + } + + batch_rule(self, tensors1, weights); +} +template +::std::vector _foreach_lerp_Scalar_generated_plumbing(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp_Scalar::call(self, tensors1, weight); + } + + auto results = batch_rule(self, tensors1, weight); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_lerp__Scalar_generated_plumbing(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp__Scalar::call(self, tensors1, weight); + } + + batch_rule(self, tensors1, weight); +} +template +::std::vector _foreach_lerp_ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp_ScalarList::call(self, tensors1, weight); + } + + auto results = batch_rule(self, tensors1, weight); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_lerp__ScalarList_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level)) { + return at::_ops::_foreach_lerp__ScalarList::call(self, tensors1, weight); + } + + batch_rule(self, tensors1, weight); +} +template +::std::vector _foreach_lgamma_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_lgamma::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_lgamma__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_lgamma_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log10_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log10::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log10__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log10_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log1p_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log1p::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log1p__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log1p_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_log2_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log2::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_log2__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_log2_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_max_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_max::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_neg_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_neg::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_neg__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_neg_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_norm_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & ord, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_norm_Scalar::call(self, ord, dtype); + } + + auto results = batch_rule(self, ord, dtype); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_List_generated_plumbing(at::TensorList self, at::TensorList exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::_foreach_pow_List::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_Scalar_generated_plumbing(at::TensorList self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_pow_Scalar::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_pow_ScalarList::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector _foreach_pow_ScalarAndTensor_generated_plumbing(const at::Scalar & self, at::TensorList exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::_foreach_pow_ScalarAndTensor::call(self, exponent); + } + + auto results = batch_rule(self, exponent); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_pow__List_generated_plumbing(at::TensorList self, at::TensorList exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level)) { + return at::_ops::_foreach_pow__List::call(self, exponent); + } + + batch_rule(self, exponent); +} +template +void _foreach_pow__Scalar_generated_plumbing(at::TensorList self, const at::Scalar & exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_pow__Scalar::call(self, exponent); + } + + batch_rule(self, exponent); +} +template +void _foreach_pow__ScalarList_generated_plumbing(at::TensorList self, at::ArrayRef exponent) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_pow__ScalarList::call(self, exponent); + } + + batch_rule(self, exponent); +} +template +::std::vector _foreach_reciprocal_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_reciprocal::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_reciprocal__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_reciprocal_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_round_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_round::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_round__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_round_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_rsqrt_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_rsqrt::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_rsqrt__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_rsqrt_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sigmoid_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sigmoid::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sigmoid__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sigmoid_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sign_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sign::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sign__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sign_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sin_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sin::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sin__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sin_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sinh_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sinh::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sinh__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sinh_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_sqrt_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sqrt::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_sqrt__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_sqrt_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_tan_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tan::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_tan__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tan_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_tanh_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tanh::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_tanh__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_tanh_::call(self); + } + + batch_rule(self); +} +template +::std::vector _foreach_trunc_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_trunc::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_trunc__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_trunc_::call(self); + } + + batch_rule(self); +} +template +void _foreach_zero__generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_zero_::call(self); + } + + batch_rule(self); +} +template +void _foreach_copy__generated_plumbing(at::TensorList self, at::TensorList src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::_foreach_copy_::call(self, src, non_blocking); + } + + batch_rule(self, src, non_blocking); +} +template +::std::vector _foreach_copy_generated_plumbing(at::TensorList self, at::TensorList src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::_foreach_copy::call(self, src, non_blocking); + } + + auto results = batch_rule(self, src, non_blocking); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bucketize_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & boundaries, bool out_int32, bool right) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(boundaries, cur_level)) { + return at::_ops::bucketize_Tensor::call(self, boundaries, out_int32, right); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [boundaries_value, boundaries_bdim] = unwrapTensorAtLevel(boundaries, cur_level); + auto results = batch_rule(self_value, self_bdim, boundaries_value, boundaries_bdim, out_int32, right); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor bucketize_Scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & boundaries, bool out_int32, bool right) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(boundaries, cur_level)) { + return at::_ops::bucketize_Scalar::call(self, boundaries, out_int32, right); + } + auto [boundaries_value, boundaries_bdim] = unwrapTensorAtLevel(boundaries, cur_level); + auto results = batch_rule(self, boundaries_value, boundaries_bdim, out_int32, right); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor searchsorted_Tensor_generated_plumbing(const at::Tensor & sorted_sequence, const at::Tensor & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sorted_sequence, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(sorter, cur_level)) { + return at::_ops::searchsorted_Tensor::call(sorted_sequence, self, out_int32, right, side, sorter); + } + auto [sorted_sequence_value, sorted_sequence_bdim] = unwrapTensorAtLevel(sorted_sequence, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional sorter_value; + std::optional sorter_bdim; + if (sorter) { + std::tie(sorter_value, sorter_bdim) = unwrapTensorAtLevel(sorter.value(), cur_level); + } + auto results = batch_rule(sorted_sequence_value, sorted_sequence_bdim, self_value, self_bdim, out_int32, right, side, sorter_value, sorter_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor searchsorted_Scalar_generated_plumbing(const at::Tensor & sorted_sequence, const at::Scalar & self, bool out_int32, bool right, ::std::optional side, const ::std::optional & sorter) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sorted_sequence, cur_level) && !isBatchedAtLevel(sorter, cur_level)) { + return at::_ops::searchsorted_Scalar::call(sorted_sequence, self, out_int32, right, side, sorter); + } + auto [sorted_sequence_value, sorted_sequence_bdim] = unwrapTensorAtLevel(sorted_sequence, cur_level); + std::optional sorter_value; + std::optional sorter_bdim; + if (sorter) { + std::tie(sorter_value, sorter_bdim) = unwrapTensorAtLevel(sorter.value(), cur_level); + } + auto results = batch_rule(sorted_sequence_value, sorted_sequence_bdim, self, out_int32, right, side, sorter_value, sorter_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_indices_from_coo_to_csr_generated_plumbing(const at::Tensor & self, int64_t size, bool out_int32) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_convert_indices_from_coo_to_csr::call(self, size, out_int32); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, out_int32); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _convert_indices_from_csr_to_coo_generated_plumbing(const at::Tensor & crow_indices, const at::Tensor & col_indices, bool out_int32, bool transpose) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(crow_indices, cur_level) && !isBatchedAtLevel(col_indices, cur_level)) { + return at::_ops::_convert_indices_from_csr_to_coo::call(crow_indices, col_indices, out_int32, transpose); + } + auto [crow_indices_value, crow_indices_bdim] = unwrapTensorAtLevel(crow_indices, cur_level); + auto [col_indices_value, col_indices_bdim] = unwrapTensorAtLevel(col_indices, cur_level); + auto results = batch_rule(crow_indices_value, crow_indices_bdim, col_indices_value, col_indices_bdim, out_int32, transpose); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mse_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::mse_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mse_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::mse_loss_backward::call(grad_output, self, target, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor l1_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::l1_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor multi_margin_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::multi_margin_loss::call(self, target, p, margin, weight, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, p, margin, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor multi_margin_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const at::Scalar & p, const at::Scalar & margin, const ::std::optional & weight, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::multi_margin_loss_backward::call(grad_output, self, target, p, margin, weight, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, p, margin, weight_value, weight_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor multilabel_margin_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::multilabel_margin_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple multilabel_margin_loss_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::multilabel_margin_loss_forward::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor multilabel_margin_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, const at::Tensor & is_target) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(is_target, cur_level)) { + return at::_ops::multilabel_margin_loss_backward::call(grad_output, self, target, reduction, is_target); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto [is_target_value, is_target_bdim] = unwrapTensorAtLevel(is_target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction, is_target_value, is_target_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nll_loss_nd_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss_nd::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nll_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple nll_loss_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss_forward::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor nll_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(total_weight, cur_level)) { + return at::_ops::nll_loss_backward::call(grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto [total_weight_value, total_weight_bdim] = unwrapTensorAtLevel(total_weight, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index, total_weight_value, total_weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nll_loss2d_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss2d::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple nll_loss2d_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::nll_loss2d_forward::call(self, target, weight, reduction, ignore_index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor nll_loss2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const ::std::optional & weight, int64_t reduction, c10::SymInt ignore_index, const at::Tensor & total_weight) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(total_weight, cur_level)) { + return at::_ops::nll_loss2d_backward::call(grad_output, self, target, weight, reduction, ignore_index, total_weight); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto [total_weight_value, total_weight_bdim] = unwrapTensorAtLevel(total_weight, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, weight_value, weight_bdim, reduction, ignore_index, total_weight_value, total_weight_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor smooth_l1_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::smooth_l1_loss::call(self, target, reduction, beta); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction, beta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor smooth_l1_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double beta) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::smooth_l1_loss_backward::call(grad_output, self, target, reduction, beta); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction, beta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor huber_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::huber_loss::call(self, target, reduction, delta); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction, delta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor huber_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction, double delta) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::huber_loss_backward::call(grad_output, self, target, reduction, delta); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction, delta); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor soft_margin_loss_generated_plumbing(const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::soft_margin_loss::call(self, target, reduction); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor soft_margin_loss_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, int64_t reduction) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(target, cur_level)) { + return at::_ops::soft_margin_loss_backward::call(grad_output, self, target, reduction); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [target_value, target_bdim] = unwrapTensorAtLevel(target, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, target_value, target_bdim, reduction); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor elu_generated_plumbing(const at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::elu::call(self, alpha, scale, input_scale); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, alpha, scale, input_scale); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor elu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale, bool is_result, const at::Tensor & self_or_result) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self_or_result, cur_level)) { + return at::_ops::elu_backward::call(grad_output, alpha, scale, input_scale, is_result, self_or_result); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_or_result_value, self_or_result_bdim] = unwrapTensorAtLevel(self_or_result, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, alpha, scale, input_scale, is_result, self_or_result_value, self_or_result_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & elu__generated_plumbing(at::Tensor & self, const at::Scalar & alpha, const at::Scalar & scale, const at::Scalar & input_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::elu_::call(self, alpha, scale, input_scale); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, alpha, scale, input_scale); + return self; +} +template +at::Tensor glu_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::glu::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor glu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::glu_backward::call(grad_output, self, dim); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor glu_jvp_generated_plumbing(const at::Tensor & glu, const at::Tensor & x, const at::Tensor & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(glu, cur_level) && !isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(dx, cur_level)) { + return at::_ops::glu_jvp::call(glu, x, dx, dim); + } + auto [glu_value, glu_bdim] = unwrapTensorAtLevel(glu, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [dx_value, dx_bdim] = unwrapTensorAtLevel(dx, cur_level); + auto results = batch_rule(glu_value, glu_bdim, x_value, x_bdim, dx_value, dx_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor glu_backward_jvp_generated_plumbing(const at::Tensor & grad_x, const at::Tensor & grad_glu, const at::Tensor & x, const at::Tensor & dgrad_glu, const at::Tensor & dx, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_x, cur_level) && !isBatchedAtLevel(grad_glu, cur_level) && !isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(dgrad_glu, cur_level) && !isBatchedAtLevel(dx, cur_level)) { + return at::_ops::glu_backward_jvp::call(grad_x, grad_glu, x, dgrad_glu, dx, dim); + } + auto [grad_x_value, grad_x_bdim] = unwrapTensorAtLevel(grad_x, cur_level); + auto [grad_glu_value, grad_glu_bdim] = unwrapTensorAtLevel(grad_glu, cur_level); + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [dgrad_glu_value, dgrad_glu_bdim] = unwrapTensorAtLevel(dgrad_glu, cur_level); + auto [dx_value, dx_bdim] = unwrapTensorAtLevel(dx, cur_level); + auto results = batch_rule(grad_x_value, grad_x_bdim, grad_glu_value, grad_glu_bdim, x_value, x_bdim, dgrad_glu_value, dgrad_glu_bdim, dx_value, dx_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardsigmoid_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardsigmoid::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hardsigmoid__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardsigmoid_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor hardsigmoid_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardsigmoid_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardtanh_generated_plumbing(const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardtanh::call(self, min_val, max_val); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, min_val, max_val); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor hardtanh_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardtanh_backward::call(grad_output, self, min_val, max_val); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, min_val, max_val); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hardtanh__generated_plumbing(at::Tensor & self, const at::Scalar & min_val, const at::Scalar & max_val) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardtanh_::call(self, min_val, max_val); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, min_val, max_val); + return self; +} +template +at::Tensor hardswish_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardswish::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & hardswish__generated_plumbing(at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardswish_::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim); + return self; +} +template +at::Tensor hardswish_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::hardswish_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor leaky_relu_generated_plumbing(const at::Tensor & self, const at::Scalar & negative_slope) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::leaky_relu::call(self, negative_slope); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, negative_slope); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor leaky_relu_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & negative_slope, bool self_is_result) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::leaky_relu_backward::call(grad_output, self, negative_slope, self_is_result); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, negative_slope, self_is_result); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & leaky_relu__generated_plumbing(at::Tensor & self, const at::Scalar & negative_slope) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::leaky_relu_::call(self, negative_slope); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, negative_slope); + return self; +} +template +at::Tensor log_sigmoid_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_sigmoid::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple log_sigmoid_forward_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_sigmoid_forward::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor log_sigmoid_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & buffer) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(buffer, cur_level)) { + return at::_ops::log_sigmoid_backward::call(grad_output, self, buffer); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [buffer_value, buffer_bdim] = unwrapTensorAtLevel(buffer, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, buffer_value, buffer_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor rrelu_with_noise_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, bool self_is_result) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(noise, cur_level)) { + return at::_ops::rrelu_with_noise_backward::call(grad_output, self, noise, lower, upper, training, self_is_result); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, noise_value, noise_bdim, lower, upper, training, self_is_result); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softplus_generated_plumbing(const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softplus::call(self, beta, threshold); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, beta, threshold); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softplus_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & beta, const at::Scalar & threshold) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::softplus_backward::call(grad_output, self, beta, threshold); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, beta, threshold); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softshrink_generated_plumbing(const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::softshrink::call(self, lambd); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor softshrink_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Scalar & lambd) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::softshrink_backward::call(grad_output, self, lambd); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, lambd); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adaptive_avg_pool2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_avg_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_adaptive_avg_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_adaptive_avg_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor mkldnn_adaptive_avg_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::mkldnn_adaptive_avg_pool2d_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool2d_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor adaptive_avg_pool3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_avg_pool3d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool3d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _adaptive_avg_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::_adaptive_avg_pool3d_backward::call(grad_output, self); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple adaptive_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_max_pool2d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor adaptive_max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::adaptive_max_pool2d_backward::call(grad_output, self, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple adaptive_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::adaptive_max_pool3d::call(self, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor adaptive_max_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::adaptive_max_pool3d_backward::call(grad_output, self, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool2d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool2d_backward::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool3d::call(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor avg_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, bool ceil_mode, bool count_include_pad, ::std::optional divisor_override) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::avg_pool3d_backward::call(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fractional_max_pool2d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(random_samples, cur_level)) { + return at::_ops::fractional_max_pool2d::call(self, kernel_size, output_size, random_samples); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [random_samples_value, random_samples_bdim] = unwrapTensorAtLevel(random_samples, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, output_size, random_samples_value, random_samples_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fractional_max_pool2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::fractional_max_pool2d_backward::call(grad_output, self, kernel_size, output_size, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, output_size, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple fractional_max_pool3d_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & random_samples) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(random_samples, cur_level)) { + return at::_ops::fractional_max_pool3d::call(self, kernel_size, output_size, random_samples); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [random_samples_value, random_samples_bdim] = unwrapTensorAtLevel(random_samples, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, output_size, random_samples_value, random_samples_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor fractional_max_pool3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef output_size, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::fractional_max_pool3d_backward::call(grad_output, self, kernel_size, output_size, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, output_size, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_pool2d_with_indices_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool2d_with_indices::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor max_pool2d_with_indices_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_pool2d_with_indices_backward::call(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple max_pool3d_with_indices_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::max_pool3d_with_indices::call(self, kernel_size, stride, padding, dilation, ceil_mode); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor max_pool3d_with_indices_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, const at::Tensor & indices) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_pool3d_with_indices_backward::call(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, kernel_size, stride, padding, dilation, ceil_mode, indices_value, indices_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_unpool2d_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_unpool2d::call(self, indices, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor max_unpool3d_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, c10::SymIntArrayRef output_size, at::IntArrayRef stride, at::IntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::max_unpool3d::call(self, indices, output_size, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, output_size, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad1d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad1d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad1d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad2d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad2d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad3d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor reflection_pad3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::reflection_pad3d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad1d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad1d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad1d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad2d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad2d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad2d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad3d::call(self, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor replication_pad3d_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::replication_pad3d_backward::call(grad_output, self, padding); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pad_circular_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pad_circular::call(self, pad); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _pad_enum_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad, int64_t mode, ::std::optional value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_pad_enum::call(self, pad, mode, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad, mode, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pad_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef pad, c10::string_view mode, ::std::optional value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::pad::call(self, pad, mode, value); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, pad, mode, value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_linear1d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_linear1d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bilinear2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_bilinear2d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bilinear2d_aa_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_bilinear2d_aa_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_trilinear3d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_trilinear3d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bicubic2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_bicubic2d_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bicubic2d_aa_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, bool align_corners, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_bicubic2d_aa_vec::call(input, output_size, align_corners, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, align_corners, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest1d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_nearest1d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact1d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_nearest_exact1d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_nearest2d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact2d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_nearest_exact2d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest3d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::upsample_nearest3d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact3d_vec_generated_plumbing(const at::Tensor & input, at::OptionalSymIntArrayRef output_size, ::std::optional> scale_factors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level)) { + return at::_ops::_upsample_nearest_exact3d_vec::call(input, output_size, scale_factors); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto results = batch_rule(input_value, input_bdim, output_size, scale_factors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_linear1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_linear1d::call(self, output_size, align_corners, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_linear1d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_linear1d_backward::call(grad_output, output_size, input_size, align_corners, scales); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bilinear2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_bilinear2d::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bilinear2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_bilinear2d_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bilinear2d_aa_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_bilinear2d_aa::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bilinear2d_aa_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_bilinear2d_aa_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bicubic2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_bicubic2d::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_bicubic2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_bicubic2d_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bicubic2d_aa_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_bicubic2d_aa::call(self, output_size, align_corners, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_bicubic2d_aa_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_bicubic2d_aa_backward::call(grad_output, output_size, input_size, align_corners, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_trilinear3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_trilinear3d::call(self, output_size, align_corners, scales_d, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, align_corners, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_trilinear3d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, bool align_corners, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_trilinear3d_backward::call(grad_output, output_size, input_size, align_corners, scales_d, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, align_corners, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_nearest1d::call(self, output_size, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact1d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_nearest_exact1d::call(self, output_size, scales); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest1d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_nearest1d_backward::call(grad_output, output_size, input_size, scales); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact1d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_nearest_exact1d_backward::call(grad_output, output_size, input_size, scales); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_nearest2d::call(self, output_size, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact2d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_nearest_exact2d::call(self, output_size, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_nearest2d_backward::call(grad_output, output_size, input_size, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact2d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_nearest_exact2d_backward::call(grad_output, output_size, input_size, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::upsample_nearest3d::call(self, output_size, scales_d, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact3d_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_upsample_nearest_exact3d::call(self, output_size, scales_d, scales_h, scales_w); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor upsample_nearest3d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::upsample_nearest3d_backward::call(grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _upsample_nearest_exact3d_backward_generated_plumbing(const at::Tensor & grad_output, c10::SymIntArrayRef output_size, c10::SymIntArrayRef input_size, ::std::optional scales_d, ::std::optional scales_h, ::std::optional scales_w) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level)) { + return at::_ops::_upsample_nearest_exact3d_backward::call(grad_output, output_size, input_size, scales_d, scales_h, scales_w); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_size, input_size, scales_d, scales_h, scales_w); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sigmoid_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::sigmoid_backward::call(grad_output, output); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor logit_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level)) { + return at::_ops::logit_backward::call(grad_output, self, eps); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor tanh_backward_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::tanh_backward::call(grad_output, output); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, output_value, output_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_transpose2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_transpose2d::call(self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, output_padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_transpose3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef output_padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_transpose3d::call(self, weight, kernel_size, bias, stride, padding, output_padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, output_padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor thnn_conv2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::thnn_conv2d::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _slow_conv2d_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_slow_conv2d_forward::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _slow_conv2d_backward_output_mask_generated_plumbing(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, ::std::array output_mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level)) { + return at::_ops::_slow_conv2d_backward_output_mask::call(grad_output, self, weight, kernel_size, stride, padding, output_mask); + } + auto [grad_output_value, grad_output_bdim] = unwrapTensorAtLevel(grad_output, cur_level); + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + auto results = batch_rule(grad_output_value, grad_output_bdim, self_value, self_bdim, weight_value, weight_bdim, kernel_size, stride, padding, output_mask); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor _conv_depthwise2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::_conv_depthwise2d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor conv_depthwise3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::conv_depthwise3d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv3d::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv3d_forward_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv3d_forward::call(self, weight, kernel_size, bias, stride, padding); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_dilated2d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_dilated2d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slow_conv_dilated3d_generated_plumbing(const at::Tensor & self, const at::Tensor & weight, c10::SymIntArrayRef kernel_size, const ::std::optional & bias, c10::SymIntArrayRef stride, c10::SymIntArrayRef padding, c10::SymIntArrayRef dilation) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level)) { + return at::_ops::slow_conv_dilated3d::call(self, weight, kernel_size, bias, stride, padding, dilation); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [weight_value, weight_bdim] = unwrapTensorAtLevel(weight, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, weight_value, weight_bdim, kernel_size, bias_value, bias_bdim, stride, padding, dilation); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor col2im_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef output_size, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::col2im::call(self, output_size, kernel_size, dilation, padding, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, output_size, kernel_size, dilation, padding, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor column_stack_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::column_stack::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor im2col_generated_plumbing(const at::Tensor & self, at::IntArrayRef kernel_size, at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::im2col::call(self, kernel_size, dilation, padding, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, kernel_size, dilation, padding, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isfinite_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isfinite::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isinf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isinf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void record_stream_generated_plumbing(at::Tensor & self, at::Stream s) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::record_stream::call(self, s); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, s); +} +template +at::Tensor isposinf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isposinf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor isneginf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::isneginf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _add_batch_dim_generated_plumbing(const at::Tensor & self, int64_t batch_dim, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_add_batch_dim::call(self, batch_dim, level); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, batch_dim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _remove_batch_dim_generated_plumbing(const at::Tensor & self, int64_t level, c10::SymInt batch_size, int64_t out_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_remove_batch_dim::call(self, level, batch_size, out_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, level, batch_size, out_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_entr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_entr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_ndtri_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_ndtri::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_log_ndtr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_log_ndtr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_expm1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_expm1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_exp2_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_exp2::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_psi_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_psi::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_digamma_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_digamma::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_gammaln_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_gammaln::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erf_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erf::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erfc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erfc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erfcx_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erfcx::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_erfinv_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_erfinv::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_ndtr_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_ndtr::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlog1py_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlog1py::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlog1py_self_scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlog1py_self_scalar::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlog1py_other_scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_xlog1py_other_scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlogy_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlogy::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlogy_self_scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_xlogy_self_scalar::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_xlogy_other_scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_xlogy_other_scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_zeta_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_zeta::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_zeta_self_scalar_generated_plumbing(const at::Scalar & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_zeta_self_scalar::call(self, other); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_zeta_other_scalar_generated_plumbing(const at::Tensor & self, const at::Scalar & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_zeta_other_scalar::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, other); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i0e_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i0e::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_i1e_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_i1e::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_logit_generated_plumbing(const at::Tensor & self, ::std::optional eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_logit::call(self, eps); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, eps); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_polygamma_generated_plumbing(int64_t n, const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_polygamma::call(n, self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(n, self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_logsumexp_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim, bool keepdim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_logsumexp::call(self, dim, keepdim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, keepdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_expit_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_expit::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_sinc_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_sinc::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_round_generated_plumbing(const at::Tensor & self, int64_t decimals) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_round::call(self, decimals); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, decimals); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_log1p_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_log1p::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_log_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_log_softmax::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_gammainc_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_gammainc::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_gammaincc_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::special_gammaincc::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_multigammaln_generated_plumbing(const at::Tensor & self, int64_t p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_multigammaln::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_softmax::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_rfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_rfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_irfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_irfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_hfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_hfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ihfft_generated_plumbing(const at::Tensor & self, ::std::optional n, int64_t dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ihfft::call(self, n, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_rfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_rfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_irfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_irfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_hfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_hfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ihfft2_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::IntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ihfft2::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_rfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_rfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_irfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_irfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_hfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_hfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ihfftn_generated_plumbing(const at::Tensor & self, at::OptionalSymIntArrayRef s, at::OptionalIntArrayRef dim, ::std::optional norm) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ihfftn::call(self, s, dim, norm); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, s, dim, norm); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_fftshift_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_fftshift::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor fft_ifftshift_generated_plumbing(const at::Tensor & self, at::OptionalIntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::fft_ifftshift::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_cholesky_ex_generated_plumbing(const at::Tensor & self, bool upper, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cholesky_ex::call(self, upper, check_errors); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_cholesky_generated_plumbing(const at::Tensor & self, bool upper) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cholesky::call(self, upper); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, upper); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_cross_generated_plumbing(const at::Tensor & self, const at::Tensor & other, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::linalg_cross::call(self, other, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_lu_factor_generated_plumbing(const at::Tensor & A, bool pivot) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_lu_factor::call(A, pivot); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, pivot); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple linalg_lu_factor_ex_generated_plumbing(const at::Tensor & A, bool pivot, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_lu_factor_ex::call(A, pivot, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, pivot, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple linalg_lu_generated_plumbing(const at::Tensor & A, bool pivot) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_lu::call(A, pivot); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, pivot); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor linalg_lu_solve_generated_plumbing(const at::Tensor & LU, const at::Tensor & pivots, const at::Tensor & B, bool left, bool adjoint) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(LU, cur_level) && !isBatchedAtLevel(pivots, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_lu_solve::call(LU, pivots, B, left, adjoint); + } + auto [LU_value, LU_bdim] = unwrapTensorAtLevel(LU, cur_level); + auto [pivots_value, pivots_bdim] = unwrapTensorAtLevel(pivots, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(LU_value, LU_bdim, pivots_value, pivots_bdim, B_value, B_bdim, left, adjoint); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_det_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_det::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor linalg_det_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_det::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor det_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::det::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_ldl_factor_ex_generated_plumbing(const at::Tensor & self, bool hermitian, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_ldl_factor_ex::call(self, hermitian, check_errors); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, hermitian, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple linalg_ldl_factor_generated_plumbing(const at::Tensor & self, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_ldl_factor::call(self, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, hermitian); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_ldl_solve_generated_plumbing(const at::Tensor & LD, const at::Tensor & pivots, const at::Tensor & B, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(LD, cur_level) && !isBatchedAtLevel(pivots, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_ldl_solve::call(LD, pivots, B, hermitian); + } + auto [LD_value, LD_bdim] = unwrapTensorAtLevel(LD, cur_level); + auto [pivots_value, pivots_bdim] = unwrapTensorAtLevel(pivots, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(LD_value, LD_bdim, pivots_value, pivots_bdim, B_value, B_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_lstsq_generated_plumbing(const at::Tensor & self, const at::Tensor & b, ::std::optional rcond, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(b, cur_level)) { + return at::_ops::linalg_lstsq::call(self, b, rcond, driver); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [b_value, b_bdim] = unwrapTensorAtLevel(b, cur_level); + auto results = batch_rule(self_value, self_bdim, b_value, b_bdim, rcond, driver); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor linalg_matmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::linalg_matmul::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_vecdot_generated_plumbing(const at::Tensor & x, const at::Tensor & y, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(y, cur_level)) { + return at::_ops::linalg_vecdot::call(x, y, dim); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [y_value, y_bdim] = unwrapTensorAtLevel(y, cur_level); + auto results = batch_rule(x_value, x_bdim, y_value, y_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_exp_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_exp::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_slogdet_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_slogdet::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple linalg_slogdet_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_slogdet::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple slogdet_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::slogdet::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor logdet_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::logdet::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_eig_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eig::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor _linalg_eigvals_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_linalg_eigvals::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_eigvals_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eigvals::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_eigh_generated_plumbing(const at::Tensor & A, c10::string_view UPLO, bool compute_v) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_eigh::call(A, UPLO, compute_v); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, UPLO, compute_v); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple linalg_eigh_generated_plumbing(const at::Tensor & self, c10::string_view UPLO) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eigh::call(self, UPLO); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, UPLO); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_eigvalsh_generated_plumbing(const at::Tensor & self, c10::string_view UPLO) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_eigvalsh::call(self, UPLO); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, UPLO); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_householder_product_generated_plumbing(const at::Tensor & input, const at::Tensor & tau) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(tau, cur_level)) { + return at::_ops::linalg_householder_product::call(input, tau); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [tau_value, tau_bdim] = unwrapTensorAtLevel(tau, cur_level); + auto results = batch_rule(input_value, input_bdim, tau_value, tau_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_inv_ex_generated_plumbing(const at::Tensor & A, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_inv_ex::call(A, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_inv_generated_plumbing(const at::Tensor & A) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_inv::call(A); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor inverse_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::inverse::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor inner_generated_plumbing(const at::Tensor & self, const at::Tensor & other) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::inner::call(self, other); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor outer_generated_plumbing(const at::Tensor & self, const at::Tensor & vec2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::outer::call(self, vec2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + auto results = batch_rule(self_value, self_bdim, vec2_value, vec2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ger_generated_plumbing(const at::Tensor & self, const at::Tensor & vec2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(vec2, cur_level)) { + return at::_ops::ger::call(self, vec2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [vec2_value, vec2_bdim] = unwrapTensorAtLevel(vec2, cur_level); + auto results = batch_rule(self_value, self_bdim, vec2_value, vec2_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_norm_generated_plumbing(const at::Tensor & self, const ::std::optional & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_norm::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_norm_ord_str_generated_plumbing(const at::Tensor & self, c10::string_view ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_norm_ord_str::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_vector_norm_generated_plumbing(const at::Tensor & self, const at::Scalar & ord, at::OptionalIntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_vector_norm::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_norm_generated_plumbing(const at::Tensor & self, const at::Scalar & ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_norm::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_norm_str_ord_generated_plumbing(const at::Tensor & self, c10::string_view ord, at::IntArrayRef dim, bool keepdim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_norm_str_ord::call(self, ord, dim, keepdim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ord, dim, keepdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_svd_generated_plumbing(const at::Tensor & A, bool full_matrices, bool compute_uv, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::_linalg_svd::call(A, full_matrices, compute_uv, driver); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, full_matrices, compute_uv, driver); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple linalg_svd_generated_plumbing(const at::Tensor & A, bool full_matrices, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_svd::call(A, full_matrices, driver); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, full_matrices, driver); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +at::Tensor linalg_svdvals_generated_plumbing(const at::Tensor & A, ::std::optional driver) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_svdvals::call(A, driver); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, driver); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_cond_generated_plumbing(const at::Tensor & self, const ::std::optional & p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cond::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_cond_p_str_generated_plumbing(const at::Tensor & self, c10::string_view p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_cond_p_str::call(self, p); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_pinv_atol_rtol_tensor_generated_plumbing(const at::Tensor & self, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(atol, cur_level) && !isBatchedAtLevel(rtol, cur_level)) { + return at::_ops::linalg_pinv_atol_rtol_tensor::call(self, atol, rtol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional atol_value; + std::optional atol_bdim; + if (atol) { + std::tie(atol_value, atol_bdim) = unwrapTensorAtLevel(atol.value(), cur_level); + } + std::optional rtol_value; + std::optional rtol_bdim; + if (rtol) { + std::tie(rtol_value, rtol_bdim) = unwrapTensorAtLevel(rtol.value(), cur_level); + } + auto results = batch_rule(self_value, self_bdim, atol_value, atol_bdim, rtol_value, rtol_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_pinv_atol_rtol_float_generated_plumbing(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_pinv_atol_rtol_float::call(self, atol, rtol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, atol, rtol, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_pinv_generated_plumbing(const at::Tensor & self, double rcond, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_pinv::call(self, rcond, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, rcond, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_pinv_rcond_tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & rcond, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(rcond, cur_level)) { + return at::_ops::linalg_pinv_rcond_tensor::call(self, rcond, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [rcond_value, rcond_bdim] = unwrapTensorAtLevel(rcond, cur_level); + auto results = batch_rule(self_value, self_bdim, rcond_value, rcond_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _linalg_solve_ex_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::_linalg_solve_ex::call(A, B, left, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple linalg_solve_ex_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left, bool check_errors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_solve_ex::call(A, B, left, check_errors); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left, check_errors); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_solve_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::linalg_solve::call(A, B, left); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _spsolve_generated_plumbing(const at::Tensor & A, const at::Tensor & B, bool left) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level) && !isBatchedAtLevel(B, cur_level)) { + return at::_ops::_spsolve::call(A, B, left); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto [B_value, B_bdim] = unwrapTensorAtLevel(B, cur_level); + auto results = batch_rule(A_value, A_bdim, B_value, B_bdim, left); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_tensorinv_generated_plumbing(const at::Tensor & self, int64_t ind) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_tensorinv::call(self, ind); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, ind); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_tensorsolve_generated_plumbing(const at::Tensor & self, const at::Tensor & other, at::OptionalIntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::linalg_tensorsolve::call(self, other, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple linalg_qr_generated_plumbing(const at::Tensor & A, c10::string_view mode) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(A, cur_level)) { + return at::_ops::linalg_qr::call(A, mode); + } + auto [A_value, A_bdim] = unwrapTensorAtLevel(A, cur_level); + auto results = batch_rule(A_value, A_bdim, mode); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor linalg_matrix_power_generated_plumbing(const at::Tensor & self, int64_t n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_power::call(self, n); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_atol_rtol_tensor_generated_plumbing(const at::Tensor & input, const ::std::optional & atol, const ::std::optional & rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(atol, cur_level) && !isBatchedAtLevel(rtol, cur_level)) { + return at::_ops::linalg_matrix_rank_atol_rtol_tensor::call(input, atol, rtol, hermitian); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + std::optional atol_value; + std::optional atol_bdim; + if (atol) { + std::tie(atol_value, atol_bdim) = unwrapTensorAtLevel(atol.value(), cur_level); + } + std::optional rtol_value; + std::optional rtol_bdim; + if (rtol) { + std::tie(rtol_value, rtol_bdim) = unwrapTensorAtLevel(rtol.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, atol_value, atol_bdim, rtol_value, rtol_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_atol_rtol_float_generated_plumbing(const at::Tensor & self, ::std::optional atol, ::std::optional rtol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_rank_atol_rtol_float::call(self, atol, rtol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, atol, rtol, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_generated_plumbing(const at::Tensor & self, double tol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::linalg_matrix_rank::call(self, tol, hermitian); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, tol, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_matrix_rank_tol_tensor_generated_plumbing(const at::Tensor & input, const at::Tensor & tol, bool hermitian) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(tol, cur_level)) { + return at::_ops::linalg_matrix_rank_tol_tensor::call(input, tol, hermitian); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [tol_value, tol_bdim] = unwrapTensorAtLevel(tol, cur_level); + auto results = batch_rule(input_value, input_bdim, tol_value, tol_bdim, hermitian); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor linalg_multi_dot_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::linalg_multi_dot::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor nested_to_padded_tensor_generated_plumbing(const at::Tensor & self, double padding, at::OptionalIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::nested_to_padded_tensor::call(self, padding, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_serialization_subcmul_generated_plumbing(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level)) { + return at::_ops::_test_serialization_subcmul::call(self, other, alpha); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + auto results = batch_rule(self_value, self_bdim, other_value, other_bdim, alpha); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_parallel_materialize_generated_plumbing(const at::Tensor & self, int64_t num_parallel, bool skip_first) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_parallel_materialize::call(self, num_parallel, skip_first); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, num_parallel, skip_first); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_optional_intlist_generated_plumbing(const at::Tensor & values, at::OptionalIntArrayRef addends) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level)) { + return at::_ops::_test_optional_intlist::call(values, addends); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(values_value, values_bdim, addends); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_optional_filled_intlist_generated_plumbing(const at::Tensor & values, at::OptionalIntArrayRef addends) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level)) { + return at::_ops::_test_optional_filled_intlist::call(values, addends); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(values_value, values_bdim, addends); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_optional_floatlist_generated_plumbing(const at::Tensor & values, ::std::optional> addends) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level)) { + return at::_ops::_test_optional_floatlist::call(values, addends); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(values_value, values_bdim, addends); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_string_default_generated_plumbing(const at::Tensor & dummy, c10::string_view a, c10::string_view b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dummy, cur_level)) { + return at::_ops::_test_string_default::call(dummy, a, b); + } + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + auto results = batch_rule(dummy_value, dummy_bdim, a, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_ambiguous_defaults_a_generated_plumbing(const at::Tensor & dummy, int64_t a, int64_t b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dummy, cur_level)) { + return at::_ops::_test_ambiguous_defaults_a::call(dummy, a, b); + } + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + auto results = batch_rule(dummy_value, dummy_bdim, a, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_ambiguous_defaults_b_generated_plumbing(const at::Tensor & dummy, int64_t a, c10::string_view b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dummy, cur_level)) { + return at::_ops::_test_ambiguous_defaults_b::call(dummy, a, b); + } + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + auto results = batch_rule(dummy_value, dummy_bdim, a, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_warn_in_autograd_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_warn_in_autograd::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_fullcoverage_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_fullcoverage::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_ntonly_generated_plumbing(const at::Tensor & self, bool b) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_ntonly::call(self, b); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, b); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_view_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_view::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _test_autograd_multiple_dispatch_view_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_test_autograd_multiple_dispatch_view_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor segment_reduce_generated_plumbing(const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & indices, const ::std::optional & offsets, int64_t axis, bool unsafe, const ::std::optional & initial) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::segment_reduce::call(data, reduce, lengths, indices, offsets, axis, unsafe, initial); + } + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional indices_value; + std::optional indices_bdim; + if (indices) { + std::tie(indices_value, indices_bdim) = unwrapTensorAtLevel(indices.value(), cur_level); + } + std::optional offsets_value; + std::optional offsets_bdim; + if (offsets) { + std::tie(offsets_value, offsets_bdim) = unwrapTensorAtLevel(offsets.value(), cur_level); + } + auto results = batch_rule(data_value, data_bdim, reduce, lengths_value, lengths_bdim, indices_value, indices_bdim, offsets_value, offsets_bdim, axis, unsafe, initial); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _segment_reduce_backward_generated_plumbing(const at::Tensor & grad, const at::Tensor & output, const at::Tensor & data, c10::string_view reduce, const ::std::optional & lengths, const ::std::optional & offsets, int64_t axis, const ::std::optional & initial) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(data, cur_level) && !isBatchedAtLevel(lengths, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_segment_reduce_backward::call(grad, output, data, reduce, lengths, offsets, axis, initial); + } + auto [grad_value, grad_bdim] = unwrapTensorAtLevel(grad, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [data_value, data_bdim] = unwrapTensorAtLevel(data, cur_level); + std::optional lengths_value; + std::optional lengths_bdim; + if (lengths) { + std::tie(lengths_value, lengths_bdim) = unwrapTensorAtLevel(lengths.value(), cur_level); + } + std::optional offsets_value; + std::optional offsets_bdim; + if (offsets) { + std::tie(offsets_value, offsets_bdim) = unwrapTensorAtLevel(offsets.value(), cur_level); + } + auto results = batch_rule(grad_value, grad_bdim, output_value, output_bdim, data_value, data_bdim, reduce, lengths_value, lengths_bdim, offsets_value, offsets_bdim, axis, initial); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor pad_sequence_generated_plumbing(at::TensorList sequences, bool batch_first, double padding_value, c10::string_view padding_side) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(sequences, cur_level)) { + return at::_ops::pad_sequence::call(sequences, batch_first, padding_value, padding_side); + } + + auto results = batch_rule(sequences, batch_first, padding_value, padding_side); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor flatten_dense_tensors_generated_plumbing(at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::flatten_dense_tensors::call(tensors); + } + + auto results = batch_rule(tensors); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unflatten_dense_tensors_generated_plumbing(const at::Tensor & flat, at::TensorList tensors) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(flat, cur_level) && !isBatchedAtLevel(tensors, cur_level)) { + return at::_ops::unflatten_dense_tensors::call(flat, tensors); + } + auto [flat_value, flat_bdim] = unwrapTensorAtLevel(flat, cur_level); + auto results = batch_rule(flat_value, flat_bdim, tensors); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_from_tensor_list_generated_plumbing(at::TensorList list, ::std::optional dtype, ::std::optional layout, ::std::optional device, ::std::optional pin_memory) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(list, cur_level)) { + return at::_ops::_nested_tensor_from_tensor_list::call(list, dtype, layout, device, pin_memory); + } + + auto results = batch_rule(list, dtype, layout, device, pin_memory); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _fw_primal_copy_generated_plumbing(const at::Tensor & self, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fw_primal_copy::call(self, level); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _make_dual_copy_generated_plumbing(const at::Tensor & primal, const at::Tensor & tangent, int64_t level) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(primal, cur_level) && !isBatchedAtLevel(tangent, cur_level)) { + return at::_ops::_make_dual_copy::call(primal, tangent, level); + } + auto [primal_value, primal_bdim] = unwrapTensorAtLevel(primal, cur_level); + auto [tangent_value, tangent_bdim] = unwrapTensorAtLevel(tangent, cur_level); + auto results = batch_rule(primal_value, primal_bdim, tangent_value, tangent_bdim, level); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_real_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_as_real_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_as_complex_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_as_complex_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _conj_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_conj_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _neg_view_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_neg_view_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor as_strided_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride, ::std::optional storage_offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::as_strided_copy::call(self, size, stride, storage_offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride, storage_offset); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _sparse_broadcast_to_copy_generated_plumbing(const at::Tensor & self, at::IntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_sparse_broadcast_to_copy::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor diagonal_copy_generated_plumbing(const at::Tensor & self, int64_t offset, int64_t dim1, int64_t dim2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::diagonal_copy::call(self, offset, dim1, dim2); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, offset, dim1, dim2); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor expand_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, bool implicit) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::expand_copy::call(self, size, implicit); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, implicit); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor permute_copy_generated_plumbing(const at::Tensor & self, at::IntArrayRef dims) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::permute_copy::call(self, dims); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dims); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _reshape_alias_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_reshape_alias_copy::call(self, size, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor select_copy_int_generated_plumbing(const at::Tensor & self, int64_t dim, c10::SymInt index) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::select_copy_int::call(self, dim, index); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, index); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor detach_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::detach_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor slice_copy_Tensor_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional start, ::std::optional end, c10::SymInt step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::slice_copy_Tensor::call(self, dim, start, end, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, start, end, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_copy_Tensor_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_copy_Tensor::call(self, split_size, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_size, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector split_with_sizes_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::split_with_sizes_copy::call(self, split_sizes, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, split_sizes, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_copy_dim_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_copy_dim::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor squeeze_copy_dims_generated_plumbing(const at::Tensor & self, at::IntArrayRef dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::squeeze_copy_dims::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor t_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::t_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor transpose_copy_int_generated_plumbing(const at::Tensor & self, int64_t dim0, int64_t dim1) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::transpose_copy_int::call(self, dim0, dim1); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim0, dim1); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unsqueeze_copy_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unsqueeze_copy::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _values_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_values_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor values_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::values_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor crow_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::crow_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor col_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::col_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor ccol_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::ccol_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor row_indices_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::row_indices_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::vector unbind_copy_int_generated_plumbing(const at::Tensor & self, int64_t dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unbind_copy_int::call(self, dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void unbind_copy_int_out_generated_plumbing(const at::Tensor & self, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::unbind_copy_int_out::call(self, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dim, out); +} +template +void split_copy_Tensor_out_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::split_copy_Tensor_out::call(self, split_size, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_size, dim, out); +} +template +void split_with_sizes_copy_out_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::split_with_sizes_copy_out::call(self, split_sizes, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_sizes, dim, out); +} +template +at::Tensor view_copy_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_copy::call(self, size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor view_copy_dtype_generated_plumbing(const at::Tensor & self, at::ScalarType dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::view_copy_dtype::call(self, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor unfold_copy_generated_plumbing(const at::Tensor & self, int64_t dimension, int64_t size, int64_t step) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::unfold_copy::call(self, dimension, size, step); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dimension, size, step); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor alias_copy_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::alias_copy::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor to_padded_tensor_generated_plumbing(const at::Tensor & self, double padding, at::OptionalSymIntArrayRef output_size) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::to_padded_tensor::call(self, padding, output_size); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, padding, output_size); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _jagged_to_padded_dense_forward_generated_plumbing(const at::Tensor & values, at::TensorList offsets, c10::SymIntArrayRef max_lengths, double padding_value) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(values, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_jagged_to_padded_dense_forward::call(values, offsets, max_lengths, padding_value); + } + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(values_value, values_bdim, offsets, max_lengths, padding_value); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _padded_dense_to_jagged_forward_generated_plumbing(const at::Tensor & dense, at::TensorList offsets, ::std::optional total_L) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(dense, cur_level) && !isBatchedAtLevel(offsets, cur_level)) { + return at::_ops::_padded_dense_to_jagged_forward::call(dense, offsets, total_L); + } + auto [dense_value, dense_bdim] = unwrapTensorAtLevel(dense, cur_level); + auto results = batch_rule(dense_value, dense_bdim, offsets, total_L); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_from_padded_tensor_generated_plumbing(const at::Tensor & padded, const at::Tensor & offsets, const at::Tensor & dummy, int64_t ragged_idx, const ::std::optional & min_seqlen, const ::std::optional & max_seqlen, ::std::optional sum_S) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(padded, cur_level) && !isBatchedAtLevel(offsets, cur_level) && !isBatchedAtLevel(dummy, cur_level) && !isBatchedAtLevel(min_seqlen, cur_level) && !isBatchedAtLevel(max_seqlen, cur_level)) { + return at::_ops::_nested_from_padded_tensor::call(padded, offsets, dummy, ragged_idx, min_seqlen, max_seqlen, sum_S); + } + auto [padded_value, padded_bdim] = unwrapTensorAtLevel(padded, cur_level); + auto [offsets_value, offsets_bdim] = unwrapTensorAtLevel(offsets, cur_level); + auto [dummy_value, dummy_bdim] = unwrapTensorAtLevel(dummy, cur_level); + std::optional min_seqlen_value; + std::optional min_seqlen_bdim; + if (min_seqlen) { + std::tie(min_seqlen_value, min_seqlen_bdim) = unwrapTensorAtLevel(min_seqlen.value(), cur_level); + } + std::optional max_seqlen_value; + std::optional max_seqlen_bdim; + if (max_seqlen) { + std::tie(max_seqlen_value, max_seqlen_bdim) = unwrapTensorAtLevel(max_seqlen.value(), cur_level); + } + auto results = batch_rule(padded_value, padded_bdim, offsets_value, offsets_bdim, dummy_value, dummy_bdim, ragged_idx, min_seqlen_value, min_seqlen_bdim, max_seqlen_value, max_seqlen_bdim, sum_S); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _nested_tensor_softmax_with_shape_generated_plumbing(const at::Tensor & self, const at::Tensor & query) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(query, cur_level)) { + return at::_ops::_nested_tensor_softmax_with_shape::call(self, query); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto results = batch_rule(self_value, self_bdim, query_value, query_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _safe_softmax_generated_plumbing(const at::Tensor & self, int64_t dim, ::std::optional dtype) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_safe_softmax::call(self, dim, dtype); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, dim, dtype); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _transformer_encoder_layer_fwd_generated_plumbing(const at::Tensor & src, int64_t embed_dim, int64_t num_heads, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, bool use_gelu, bool norm_first, double eps, const at::Tensor & norm_weight_1, const at::Tensor & norm_bias_1, const at::Tensor & norm_weight_2, const at::Tensor & norm_bias_2, const at::Tensor & ffn_weight_1, const at::Tensor & ffn_bias_1, const at::Tensor & ffn_weight_2, const at::Tensor & ffn_bias_2, const ::std::optional & mask, ::std::optional mask_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(src, cur_level) && !isBatchedAtLevel(qkv_weight, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level) && !isBatchedAtLevel(proj_weight, cur_level) && !isBatchedAtLevel(proj_bias, cur_level) && !isBatchedAtLevel(norm_weight_1, cur_level) && !isBatchedAtLevel(norm_bias_1, cur_level) && !isBatchedAtLevel(norm_weight_2, cur_level) && !isBatchedAtLevel(norm_bias_2, cur_level) && !isBatchedAtLevel(ffn_weight_1, cur_level) && !isBatchedAtLevel(ffn_bias_1, cur_level) && !isBatchedAtLevel(ffn_weight_2, cur_level) && !isBatchedAtLevel(ffn_bias_2, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_transformer_encoder_layer_fwd::call(src, embed_dim, num_heads, qkv_weight, qkv_bias, proj_weight, proj_bias, use_gelu, norm_first, eps, norm_weight_1, norm_bias_1, norm_weight_2, norm_bias_2, ffn_weight_1, ffn_bias_1, ffn_weight_2, ffn_bias_2, mask, mask_type); + } + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto [qkv_weight_value, qkv_weight_bdim] = unwrapTensorAtLevel(qkv_weight, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto [proj_weight_value, proj_weight_bdim] = unwrapTensorAtLevel(proj_weight, cur_level); + auto [proj_bias_value, proj_bias_bdim] = unwrapTensorAtLevel(proj_bias, cur_level); + auto [norm_weight_1_value, norm_weight_1_bdim] = unwrapTensorAtLevel(norm_weight_1, cur_level); + auto [norm_bias_1_value, norm_bias_1_bdim] = unwrapTensorAtLevel(norm_bias_1, cur_level); + auto [norm_weight_2_value, norm_weight_2_bdim] = unwrapTensorAtLevel(norm_weight_2, cur_level); + auto [norm_bias_2_value, norm_bias_2_bdim] = unwrapTensorAtLevel(norm_bias_2, cur_level); + auto [ffn_weight_1_value, ffn_weight_1_bdim] = unwrapTensorAtLevel(ffn_weight_1, cur_level); + auto [ffn_bias_1_value, ffn_bias_1_bdim] = unwrapTensorAtLevel(ffn_bias_1, cur_level); + auto [ffn_weight_2_value, ffn_weight_2_bdim] = unwrapTensorAtLevel(ffn_weight_2, cur_level); + auto [ffn_bias_2_value, ffn_bias_2_bdim] = unwrapTensorAtLevel(ffn_bias_2, cur_level); + std::optional mask_value; + std::optional mask_bdim; + if (mask) { + std::tie(mask_value, mask_bdim) = unwrapTensorAtLevel(mask.value(), cur_level); + } + auto results = batch_rule(src_value, src_bdim, embed_dim, num_heads, qkv_weight_value, qkv_weight_bdim, qkv_bias_value, qkv_bias_bdim, proj_weight_value, proj_weight_bdim, proj_bias_value, proj_bias_bdim, use_gelu, norm_first, eps, norm_weight_1_value, norm_weight_1_bdim, norm_bias_1_value, norm_bias_1_bdim, norm_weight_2_value, norm_weight_2_bdim, norm_bias_2_value, norm_bias_2_bdim, ffn_weight_1_value, ffn_weight_1_bdim, ffn_bias_1_value, ffn_bias_1_bdim, ffn_weight_2_value, ffn_weight_2_bdim, ffn_bias_2_value, ffn_bias_2_bdim, mask_value, mask_bdim, mask_type); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _native_multi_head_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask, bool need_weights, bool average_attn_weights, ::std::optional mask_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(qkv_weight, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level) && !isBatchedAtLevel(proj_weight, cur_level) && !isBatchedAtLevel(proj_bias, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_native_multi_head_attention::call(query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask, need_weights, average_attn_weights, mask_type); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [qkv_weight_value, qkv_weight_bdim] = unwrapTensorAtLevel(qkv_weight, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto [proj_weight_value, proj_weight_bdim] = unwrapTensorAtLevel(proj_weight, cur_level); + auto [proj_bias_value, proj_bias_bdim] = unwrapTensorAtLevel(proj_bias, cur_level); + std::optional mask_value; + std::optional mask_bdim; + if (mask) { + std::tie(mask_value, mask_bdim) = unwrapTensorAtLevel(mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, embed_dim, num_head, qkv_weight_value, qkv_weight_bdim, qkv_bias_value, qkv_bias_bdim, proj_weight_value, proj_weight_bdim, proj_bias_value, proj_bias_bdim, mask_value, mask_bdim, need_weights, average_attn_weights, mask_type); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +at::Tensor scaled_dot_product_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, ::std::optional scale, bool enable_gqa) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level)) { + return at::_ops::scaled_dot_product_attention::call(query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_mask_value, attn_mask_bdim, dropout_p, is_causal, scale, enable_gqa); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +::std::tuple _scaled_dot_product_attention_math_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale, bool enable_gqa) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level) && !isBatchedAtLevel(dropout_mask, cur_level)) { + return at::_ops::_scaled_dot_product_attention_math::call(query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale, enable_gqa); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + std::optional dropout_mask_value; + std::optional dropout_mask_bdim; + if (dropout_mask) { + std::tie(dropout_mask_value, dropout_mask_bdim) = unwrapTensorAtLevel(dropout_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_mask_value, attn_mask_bdim, dropout_p, is_causal, dropout_mask_value, dropout_mask_bdim, scale, enable_gqa); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_attention_math_for_mps_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_mask, double dropout_p, bool is_causal, const ::std::optional & dropout_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level) && !isBatchedAtLevel(dropout_mask, cur_level)) { + return at::_ops::_scaled_dot_product_attention_math_for_mps::call(query, key, value, attn_mask, dropout_p, is_causal, dropout_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + std::optional dropout_mask_value; + std::optional dropout_mask_bdim; + if (dropout_mask) { + std::tie(dropout_mask_value, dropout_mask_bdim) = unwrapTensorAtLevel(dropout_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_mask_value, attn_mask_bdim, dropout_p, is_causal, dropout_mask_value, dropout_mask_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention::call(query, key, value, dropout_p, is_causal, return_debug_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, dropout_p, is_causal, return_debug_mask, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), std::get<8>(results), std::get<9>(results), makeBatched(std::get<10>(results), std::get<11>(results), cur_level), makeBatched(std::get<12>(results), std::get<13>(results), cur_level), makeBatched(std::get<14>(results), std::get<15>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_for_cpu_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_mask, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu::call(query, key, value, dropout_p, is_causal, attn_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, dropout_p, is_causal, attn_mask_value, attn_mask_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_flash_attention_for_cpu_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, double dropout_p, bool is_causal, const ::std::optional & attn_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(attn_mask, cur_level)) { + return at::_ops::_scaled_dot_product_flash_attention_for_cpu_backward::call(grad_out, query, key, value, out, logsumexp, dropout_p, is_causal, attn_mask, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + std::optional attn_mask_value; + std::optional attn_mask_bdim; + if (attn_mask) { + std::tie(attn_mask_value, attn_mask_bdim) = unwrapTensorAtLevel(attn_mask.value(), cur_level); + } + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, dropout_p, is_causal, attn_mask_value, attn_mask_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_fused_attention_overrideable_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, ::std::array grad_input_mask, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & philox_seed, const at::Tensor & philox_offset, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_scaled_dot_product_fused_attention_overrideable_backward::call(grad_out, query, key, value, attn_bias, grad_input_mask, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, philox_seed, philox_offset, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [attn_bias_value, attn_bias_bdim] = unwrapTensorAtLevel(attn_bias, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, grad_input_mask, out_value, out_bdim, logsumexp_value, logsumexp_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_efficient_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level)) { + return at::_ops::_scaled_dot_product_efficient_attention::call(query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_bias_value; + std::optional attn_bias_bdim; + if (attn_bias) { + std::tie(attn_bias_value, attn_bias_bdim) = unwrapTensorAtLevel(attn_bias.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, compute_log_sumexp, dropout_p, is_causal, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_efficient_attention_backward_generated_plumbing(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & attn_bias, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, double dropout_p, ::std::array grad_input_mask, bool is_causal, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out_, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_scaled_dot_product_efficient_attention_backward::call(grad_out_, query, key, value, attn_bias, out, logsumexp, philox_seed, philox_offset, dropout_p, grad_input_mask, is_causal, scale); + } + auto [grad_out__value, grad_out__bdim] = unwrapTensorAtLevel(grad_out_, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [attn_bias_value, attn_bias_bdim] = unwrapTensorAtLevel(attn_bias, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto results = batch_rule(grad_out__value, grad_out__bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, dropout_p, grad_input_mask, is_causal, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_cudnn_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & attn_bias, bool compute_log_sumexp, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(attn_bias, cur_level)) { + return at::_ops::_scaled_dot_product_cudnn_attention::call(query, key, value, attn_bias, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional attn_bias_value; + std::optional attn_bias_bdim; + if (attn_bias) { + std::tie(attn_bias_value, attn_bias_bdim) = unwrapTensorAtLevel(attn_bias.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, attn_bias_value, attn_bias_bdim, compute_log_sumexp, dropout_p, is_causal, return_debug_mask, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), std::get<8>(results), std::get<9>(results), makeBatched(std::get<10>(results), std::get<11>(results), cur_level), makeBatched(std::get<12>(results), std::get<13>(results), cur_level), makeBatched(std::get<14>(results), std::get<15>(results), cur_level)); +} +template +::std::tuple _scaled_dot_product_cudnn_attention_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & philox_seed, const at::Tensor & philox_offset, const at::Tensor & attn_bias, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, ::std::optional scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level) && !isBatchedAtLevel(attn_bias, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level)) { + return at::_ops::_scaled_dot_product_cudnn_attention_backward::call(grad_out, query, key, value, out, logsumexp, philox_seed, philox_offset, attn_bias, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, scale); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + auto [attn_bias_value, attn_bias_bdim] = unwrapTensorAtLevel(attn_bias, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, attn_bias_value, attn_bias_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, scale); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _flash_attention_forward_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & cum_seq_q, const ::std::optional & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, bool return_debug_mask, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right, const ::std::optional & seqused_k, const ::std::optional & alibi_slopes) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(seqused_k, cur_level) && !isBatchedAtLevel(alibi_slopes, cur_level)) { + return at::_ops::_flash_attention_forward::call(query, key, value, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left, window_size_right, seqused_k, alibi_slopes); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + std::optional cum_seq_q_value; + std::optional cum_seq_q_bdim; + if (cum_seq_q) { + std::tie(cum_seq_q_value, cum_seq_q_bdim) = unwrapTensorAtLevel(cum_seq_q.value(), cur_level); + } + std::optional cum_seq_k_value; + std::optional cum_seq_k_bdim; + if (cum_seq_k) { + std::tie(cum_seq_k_value, cum_seq_k_bdim) = unwrapTensorAtLevel(cum_seq_k.value(), cur_level); + } + std::optional seqused_k_value; + std::optional seqused_k_bdim; + if (seqused_k) { + std::tie(seqused_k_value, seqused_k_bdim) = unwrapTensorAtLevel(seqused_k.value(), cur_level); + } + std::optional alibi_slopes_value; + std::optional alibi_slopes_bdim; + if (alibi_slopes) { + std::tie(alibi_slopes_value, alibi_slopes_bdim) = unwrapTensorAtLevel(alibi_slopes.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, return_debug_mask, scale, window_size_left, window_size_right, seqused_k_value, seqused_k_bdim, alibi_slopes_value, alibi_slopes_bdim); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +::std::tuple _flash_attention_backward_generated_plumbing(const at::Tensor & grad_out, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const at::Tensor & out, const at::Tensor & logsumexp, const at::Tensor & cum_seq_q, const at::Tensor & cum_seq_k, c10::SymInt max_q, c10::SymInt max_k, double dropout_p, bool is_causal, const at::Tensor & rng_state, const at::Tensor & unused, ::std::optional scale, ::std::optional window_size_left, ::std::optional window_size_right) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(cum_seq_q, cur_level) && !isBatchedAtLevel(cum_seq_k, cur_level) && !isBatchedAtLevel(rng_state, cur_level) && !isBatchedAtLevel(unused, cur_level)) { + return at::_ops::_flash_attention_backward::call(grad_out, query, key, value, out, logsumexp, cum_seq_q, cum_seq_k, max_q, max_k, dropout_p, is_causal, rng_state, unused, scale, window_size_left, window_size_right); + } + auto [grad_out_value, grad_out_bdim] = unwrapTensorAtLevel(grad_out, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [cum_seq_q_value, cum_seq_q_bdim] = unwrapTensorAtLevel(cum_seq_q, cur_level); + auto [cum_seq_k_value, cum_seq_k_bdim] = unwrapTensorAtLevel(cum_seq_k, cur_level); + auto [rng_state_value, rng_state_bdim] = unwrapTensorAtLevel(rng_state, cur_level); + auto [unused_value, unused_bdim] = unwrapTensorAtLevel(unused, cur_level); + auto results = batch_rule(grad_out_value, grad_out_bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, out_value, out_bdim, logsumexp_value, logsumexp_bdim, cum_seq_q_value, cum_seq_q_bdim, cum_seq_k_value, cum_seq_k_bdim, max_q, max_k, dropout_p, is_causal, rng_state_value, rng_state_bdim, unused_value, unused_bdim, scale, window_size_left, window_size_right); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +::std::tuple _efficient_attention_backward_generated_plumbing(const at::Tensor & grad_out_, const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, const ::std::optional & bias, const at::Tensor & out, const ::std::optional & cu_seqlens_q, const ::std::optional & cu_seqlens_k, c10::SymInt max_seqlen_q, c10::SymInt max_seqlen_k, const at::Tensor & logsumexp, double dropout_p, const at::Tensor & philox_seed, const at::Tensor & philox_offset, int64_t custom_mask_type, bool bias_requires_grad, ::std::optional scale, ::std::optional num_splits_key, ::std::optional window_size, bool shared_storage_dqdkdv) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_out_, cur_level) && !isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(out, cur_level) && !isBatchedAtLevel(cu_seqlens_q, cur_level) && !isBatchedAtLevel(cu_seqlens_k, cur_level) && !isBatchedAtLevel(logsumexp, cur_level) && !isBatchedAtLevel(philox_seed, cur_level) && !isBatchedAtLevel(philox_offset, cur_level)) { + return at::_ops::_efficient_attention_backward::call(grad_out_, query, key, value, bias, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, logsumexp, dropout_p, philox_seed, philox_offset, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + } + auto [grad_out__value, grad_out__bdim] = unwrapTensorAtLevel(grad_out_, cur_level); + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [out_value, out_bdim] = unwrapTensorAtLevel(out, cur_level); + auto [logsumexp_value, logsumexp_bdim] = unwrapTensorAtLevel(logsumexp, cur_level); + auto [philox_seed_value, philox_seed_bdim] = unwrapTensorAtLevel(philox_seed, cur_level); + auto [philox_offset_value, philox_offset_bdim] = unwrapTensorAtLevel(philox_offset, cur_level); + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + std::optional cu_seqlens_q_value; + std::optional cu_seqlens_q_bdim; + if (cu_seqlens_q) { + std::tie(cu_seqlens_q_value, cu_seqlens_q_bdim) = unwrapTensorAtLevel(cu_seqlens_q.value(), cur_level); + } + std::optional cu_seqlens_k_value; + std::optional cu_seqlens_k_bdim; + if (cu_seqlens_k) { + std::tie(cu_seqlens_k_value, cu_seqlens_k_bdim) = unwrapTensorAtLevel(cu_seqlens_k.value(), cur_level); + } + auto results = batch_rule(grad_out__value, grad_out__bdim, query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, bias_value, bias_bdim, out_value, out_bdim, cu_seqlens_q_value, cu_seqlens_q_bdim, cu_seqlens_k_value, cu_seqlens_k_bdim, max_seqlen_q, max_seqlen_k, logsumexp_value, logsumexp_bdim, dropout_p, philox_seed_value, philox_seed_bdim, philox_offset_value, philox_offset_bdim, custom_mask_type, bias_requires_grad, scale, num_splits_key, window_size, shared_storage_dqdkdv); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +at::Tensor _triton_scaled_dot_attention_generated_plumbing(const at::Tensor & q, const at::Tensor & k, const at::Tensor & v, double dropout_p) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(q, cur_level) && !isBatchedAtLevel(k, cur_level) && !isBatchedAtLevel(v, cur_level)) { + return at::_ops::_triton_scaled_dot_attention::call(q, k, v, dropout_p); + } + auto [q_value, q_bdim] = unwrapTensorAtLevel(q, cur_level); + auto [k_value, k_bdim] = unwrapTensorAtLevel(k, cur_level); + auto [v_value, v_bdim] = unwrapTensorAtLevel(v, cur_level); + auto results = batch_rule(q_value, q_bdim, k_value, k_bdim, v_value, v_bdim, dropout_p); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor & _fill_mem_eff_dropout_mask__generated_plumbing(at::Tensor & self, double dropout_p, int64_t seed, int64_t offset) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_inplace_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_fill_mem_eff_dropout_mask_::call(self, dropout_p, seed, offset); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, dropout_p, seed, offset); + return self; +} +template +at::Tensor _triton_multi_head_attention_generated_plumbing(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value, int64_t embed_dim, int64_t num_head, const at::Tensor & qkv_weight, const at::Tensor & qkv_bias, const at::Tensor & proj_weight, const at::Tensor & proj_bias, const ::std::optional & mask) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(query, cur_level) && !isBatchedAtLevel(key, cur_level) && !isBatchedAtLevel(value, cur_level) && !isBatchedAtLevel(qkv_weight, cur_level) && !isBatchedAtLevel(qkv_bias, cur_level) && !isBatchedAtLevel(proj_weight, cur_level) && !isBatchedAtLevel(proj_bias, cur_level) && !isBatchedAtLevel(mask, cur_level)) { + return at::_ops::_triton_multi_head_attention::call(query, key, value, embed_dim, num_head, qkv_weight, qkv_bias, proj_weight, proj_bias, mask); + } + auto [query_value, query_bdim] = unwrapTensorAtLevel(query, cur_level); + auto [key_value, key_bdim] = unwrapTensorAtLevel(key, cur_level); + auto [value_value, value_bdim] = unwrapTensorAtLevel(value, cur_level); + auto [qkv_weight_value, qkv_weight_bdim] = unwrapTensorAtLevel(qkv_weight, cur_level); + auto [qkv_bias_value, qkv_bias_bdim] = unwrapTensorAtLevel(qkv_bias, cur_level); + auto [proj_weight_value, proj_weight_bdim] = unwrapTensorAtLevel(proj_weight, cur_level); + auto [proj_bias_value, proj_bias_bdim] = unwrapTensorAtLevel(proj_bias, cur_level); + std::optional mask_value; + std::optional mask_bdim; + if (mask) { + std::tie(mask_value, mask_bdim) = unwrapTensorAtLevel(mask.value(), cur_level); + } + auto results = batch_rule(query_value, query_bdim, key_value, key_bdim, value_value, value_bdim, embed_dim, num_head, qkv_weight_value, qkv_weight_bdim, qkv_bias_value, qkv_bias_bdim, proj_weight_value, proj_weight_bdim, proj_bias_value, proj_bias_bdim, mask_value, mask_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_airy_ai_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_airy_ai::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_j0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_j0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_j1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_j1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_y0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_y0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_bessel_y1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_bessel_y1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_t_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_t::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_t_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_t_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_t_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_chebyshev_polynomial_t_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_u_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_u::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_u_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_u_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_u_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_chebyshev_polynomial_u_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_v_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_v::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_v_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_v_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_v_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_chebyshev_polynomial_v_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_w_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_w::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_w_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_chebyshev_polynomial_w_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_chebyshev_polynomial_w_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_chebyshev_polynomial_w_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_h_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_h::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_h_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_h_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_h_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_hermite_polynomial_h_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_he_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_he::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_he_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_hermite_polynomial_he_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_hermite_polynomial_he_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_hermite_polynomial_he_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_laguerre_polynomial_l_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_laguerre_polynomial_l::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_laguerre_polynomial_l_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_laguerre_polynomial_l_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_laguerre_polynomial_l_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_laguerre_polynomial_l_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_legendre_polynomial_p_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_legendre_polynomial_p::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_legendre_polynomial_p_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_legendre_polynomial_p_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_legendre_polynomial_p_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_legendre_polynomial_p_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_i0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_i0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_i1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_i1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_k0_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_k0::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_modified_bessel_k1_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::special_modified_bessel_k1::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_scaled_modified_bessel_k0_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_scaled_modified_bessel_k0::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_scaled_modified_bessel_k1_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_scaled_modified_bessel_k1::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_t_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_t::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_t_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_t_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_t_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_t_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_u_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_u::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_u_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_u_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_u_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_u_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_v_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_v::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_v_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_v_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_v_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_v_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_w_generated_plumbing(const at::Tensor & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level) && !isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_w::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x_value, x_bdim, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_w_x_scalar_generated_plumbing(const at::Scalar & x, const at::Tensor & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(n, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_w_x_scalar::call(x, n); + } + auto [n_value, n_bdim] = unwrapTensorAtLevel(n, cur_level); + auto results = batch_rule(x, n_value, n_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_shifted_chebyshev_polynomial_w_n_scalar_generated_plumbing(const at::Tensor & x, const at::Scalar & n) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_shifted_chebyshev_polynomial_w_n_scalar::call(x, n); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim, n); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor special_spherical_bessel_j0_generated_plumbing(const at::Tensor & x) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(x, cur_level)) { + return at::_ops::special_spherical_bessel_j0::call(x); + } + auto [x_value, x_bdim] = unwrapTensorAtLevel(x, cur_level); + auto results = batch_rule(x_value, x_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _foobar_generated_plumbing(const at::Tensor & self, bool arg1, bool arg2, bool arg3) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foobar::call(self, arg1, arg2, arg3); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, arg1, arg2, arg3); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _fused_adam__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam_::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adam__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam__tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adamw__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw_::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adamw__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw__tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_sgd__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd_::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_sgd__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd__tensor_lr::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr_value, lr_bdim, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adagrad__generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad_::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _fused_adagrad__tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad__tensor_lr::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr_value, lr_bdim, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); +} +template +void _propagate_xla_data_generated_plumbing(const at::Tensor & input, const at::Tensor & output) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(output, cur_level)) { + return at::_ops::_propagate_xla_data::call(input, output); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + batch_rule(input_value, input_bdim, output_value, output_bdim); +} +template +void _cudnn_rnn_backward_out_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, c10::SymInt hidden_size, c10::SymInt proj_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, c10::SymIntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && !isBatchedAtLevel(reserve, cur_level) && !isBatchedAtLevel(out0, cur_level) && !isBatchedAtLevel(out1, cur_level) && !isBatchedAtLevel(out2, cur_level) && !isBatchedAtLevel(out3, cur_level)) { + return at::_ops::_cudnn_rnn_backward_out::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + auto [out0_value, out0_bdim] = unwrapTensorAtLevel(out0, cur_level); + auto [out1_value, out1_bdim] = unwrapTensorAtLevel(out1, cur_level); + auto [out2_value, out2_bdim] = unwrapTensorAtLevel(out2, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, proj_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask, out0_value, out0_bdim, out1_value, out1_bdim, out2_value, out2_bdim, out3); +} +template +at::Tensor bernoulli_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(p, cur_level)) { + return at::_ops::bernoulli_Tensor::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [p_value, p_bdim] = unwrapTensorAtLevel(p, cur_level); + auto results = batch_rule(self_value, self_bdim, p_value, p_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor embedding_renorm_generated_plumbing(const at::Tensor & self, const at::Tensor & indices, double max_norm, double norm_type) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level)) { + return at::_ops::embedding_renorm::call(self, indices, max_norm, norm_type); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [indices_value, indices_bdim] = unwrapTensorAtLevel(indices, cur_level); + auto results = batch_rule(self_value, self_bdim, indices_value, indices_bdim, max_norm, norm_type); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor resize_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::resize::call(self, size, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _resize_output_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef size, at::Device device) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_resize_output::call(self, size, device); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, device); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _index_put_impl_generated_plumbing(const at::Tensor & self, const c10::List<::std::optional> & indices, const at::Tensor & values, bool accumulate, bool unsafe) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(indices, cur_level) && !isBatchedAtLevel(values, cur_level)) { + return at::_ops::_index_put_impl::call(self, indices, values, accumulate, unsafe); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [values_value, values_bdim] = unwrapTensorAtLevel(values, cur_level); + auto results = batch_rule(self_value, self_bdim, indices, values_value, values_bdim, accumulate, unsafe); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void miopen_rnn_backward_out_generated_plumbing(const at::Tensor & input, at::TensorList weight, int64_t weight_stride0, const at::Tensor & weight_buf, const at::Tensor & hx, const ::std::optional & cx, const at::Tensor & output, const ::std::optional & grad_output, const ::std::optional & grad_hy, const ::std::optional & grad_cy, int64_t mode, int64_t hidden_size, int64_t num_layers, bool batch_first, double dropout, bool train, bool bidirectional, at::IntArrayRef batch_sizes, const ::std::optional & dropout_state, const at::Tensor & reserve, ::std::array output_mask, at::Tensor & out0, at::Tensor & out1, at::Tensor & out2, at::TensorList out3) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(weight_buf, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(cx, cur_level) && !isBatchedAtLevel(output, cur_level) && !isBatchedAtLevel(grad_output, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(dropout_state, cur_level) && !isBatchedAtLevel(reserve, cur_level) && !isBatchedAtLevel(out0, cur_level) && !isBatchedAtLevel(out1, cur_level) && !isBatchedAtLevel(out2, cur_level) && !isBatchedAtLevel(out3, cur_level)) { + return at::_ops::miopen_rnn_backward_out::call(input, weight, weight_stride0, weight_buf, hx, cx, output, grad_output, grad_hy, grad_cy, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state, reserve, output_mask, out0, out1, out2, out3); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [weight_buf_value, weight_buf_bdim] = unwrapTensorAtLevel(weight_buf, cur_level); + auto [hx_value, hx_bdim] = unwrapTensorAtLevel(hx, cur_level); + auto [output_value, output_bdim] = unwrapTensorAtLevel(output, cur_level); + auto [reserve_value, reserve_bdim] = unwrapTensorAtLevel(reserve, cur_level); + auto [out0_value, out0_bdim] = unwrapTensorAtLevel(out0, cur_level); + auto [out1_value, out1_bdim] = unwrapTensorAtLevel(out1, cur_level); + auto [out2_value, out2_bdim] = unwrapTensorAtLevel(out2, cur_level); + std::optional cx_value; + std::optional cx_bdim; + if (cx) { + std::tie(cx_value, cx_bdim) = unwrapTensorAtLevel(cx.value(), cur_level); + } + std::optional grad_output_value; + std::optional grad_output_bdim; + if (grad_output) { + std::tie(grad_output_value, grad_output_bdim) = unwrapTensorAtLevel(grad_output.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + std::optional dropout_state_value; + std::optional dropout_state_bdim; + if (dropout_state) { + std::tie(dropout_state_value, dropout_state_bdim) = unwrapTensorAtLevel(dropout_state.value(), cur_level); + } + batch_rule(input_value, input_bdim, weight, weight_stride0, weight_buf_value, weight_buf_bdim, hx_value, hx_bdim, cx_value, cx_bdim, output_value, output_bdim, grad_output_value, grad_output_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, mode, hidden_size, num_layers, batch_first, dropout, train, bidirectional, batch_sizes, dropout_state_value, dropout_state_bdim, reserve_value, reserve_bdim, output_mask, out0_value, out0_bdim, out1_value, out1_bdim, out2_value, out2_bdim, out3); +} +template +::std::tuple _native_batch_norm_legit_functional_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, bool training, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_native_batch_norm_legit_functional::call(input, weight, bias, running_mean, running_var, training, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [running_mean_value, running_mean_bdim] = unwrapTensorAtLevel(running_mean, cur_level); + auto [running_var_value, running_var_bdim] = unwrapTensorAtLevel(running_var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, training, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void unsafe_split_Tensor_out_generated_plumbing(const at::Tensor & self, c10::SymInt split_size, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::unsafe_split_Tensor_out::call(self, split_size, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_size, dim, out); +} +template +void unsafe_split_with_sizes_out_generated_plumbing(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::unsafe_split_with_sizes_out::call(self, split_sizes, dim, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + batch_rule(self_value, self_bdim, split_sizes, dim, out); +} +template +::std::tuple _batch_norm_with_update_functional_generated_plumbing(const at::Tensor & input, const ::std::optional & weight, const ::std::optional & bias, const at::Tensor & running_mean, const at::Tensor & running_var, double momentum, double eps) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(bias, cur_level) && !isBatchedAtLevel(running_mean, cur_level) && !isBatchedAtLevel(running_var, cur_level)) { + return at::_ops::_batch_norm_with_update_functional::call(input, weight, bias, running_mean, running_var, momentum, eps); + } + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [running_mean_value, running_mean_bdim] = unwrapTensorAtLevel(running_mean, cur_level); + auto [running_var_value, running_var_bdim] = unwrapTensorAtLevel(running_var, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + std::optional bias_value; + std::optional bias_bdim; + if (bias) { + std::tie(bias_value, bias_bdim) = unwrapTensorAtLevel(bias.value(), cur_level); + } + auto results = batch_rule(input_value, input_bdim, weight_value, weight_bdim, bias_value, bias_bdim, running_mean_value, running_mean_bdim, running_var_value, running_var_bdim, momentum, eps); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level)); +} +template +at::Tensor resize_as_generated_plumbing(const at::Tensor & self, const at::Tensor & the_template, ::std::optional memory_format) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(the_template, cur_level)) { + return at::_ops::resize_as::call(self, the_template, memory_format); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [the_template_value, the_template_bdim] = unwrapTensorAtLevel(the_template, cur_level); + auto results = batch_rule(self_value, self_bdim, the_template_value, the_template_bdim, memory_format); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor resize_as_sparse_generated_plumbing(const at::Tensor & self, const at::Tensor & the_template) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(the_template, cur_level)) { + return at::_ops::resize_as_sparse::call(self, the_template); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [the_template_value, the_template_bdim] = unwrapTensorAtLevel(the_template, cur_level); + auto results = batch_rule(self_value, self_bdim, the_template_value, the_template_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor zero_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::zero::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_resize_generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor sparse_resize_and_clear_generated_plumbing(const at::Tensor & self, at::IntArrayRef size, int64_t sparse_dim, int64_t dense_dim) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::sparse_resize_and_clear::call(self, size, sparse_dim, dense_dim); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, size, sparse_dim, dense_dim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor _coalesced_generated_plumbing(const at::Tensor & self, bool coalesced) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_coalesced::call(self, coalesced); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, coalesced); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor copy_sparse_to_sparse_generated_plumbing(const at::Tensor & self, const at::Tensor & src, bool non_blocking) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level)) { + return at::_ops::copy_sparse_to_sparse::call(self, src, non_blocking); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [src_value, src_bdim] = unwrapTensorAtLevel(src, cur_level); + auto results = batch_rule(self_value, self_bdim, src_value, src_bdim, non_blocking); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void quantize_per_tensor_tensors_out_generated_plumbing(at::TensorList tensors, const at::Tensor & scales, const at::Tensor & zero_points, at::ScalarType dtype, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level) && !isBatchedAtLevel(scales, cur_level) && !isBatchedAtLevel(zero_points, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::quantize_per_tensor_tensors_out::call(tensors, scales, zero_points, dtype, out); + } + auto [scales_value, scales_bdim] = unwrapTensorAtLevel(scales, cur_level); + auto [zero_points_value, zero_points_bdim] = unwrapTensorAtLevel(zero_points, cur_level); + batch_rule(tensors, scales_value, scales_bdim, zero_points_value, zero_points_bdim, dtype, out); +} +template +void dequantize_tensors_out_generated_plumbing(at::TensorList tensors, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(tensors, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::dequantize_tensors_out::call(tensors, out); + } + + batch_rule(tensors, out); +} +template +::std::tuple _fused_moving_avg_obs_fq_helper_functional_generated_plumbing(const at::Tensor & self, const at::Tensor & observer_on, const at::Tensor & fake_quant_on, const at::Tensor & running_min, const at::Tensor & running_max, const at::Tensor & scale, const at::Tensor & zero_point, double averaging_const, int64_t quant_min, int64_t quant_max, int64_t ch_axis, bool per_row_fake_quant, bool symmetric_quant) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(observer_on, cur_level) && !isBatchedAtLevel(fake_quant_on, cur_level) && !isBatchedAtLevel(running_min, cur_level) && !isBatchedAtLevel(running_max, cur_level) && !isBatchedAtLevel(scale, cur_level) && !isBatchedAtLevel(zero_point, cur_level)) { + return at::_ops::_fused_moving_avg_obs_fq_helper_functional::call(self, observer_on, fake_quant_on, running_min, running_max, scale, zero_point, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [observer_on_value, observer_on_bdim] = unwrapTensorAtLevel(observer_on, cur_level); + auto [fake_quant_on_value, fake_quant_on_bdim] = unwrapTensorAtLevel(fake_quant_on, cur_level); + auto [running_min_value, running_min_bdim] = unwrapTensorAtLevel(running_min, cur_level); + auto [running_max_value, running_max_bdim] = unwrapTensorAtLevel(running_max, cur_level); + auto [scale_value, scale_bdim] = unwrapTensorAtLevel(scale, cur_level); + auto [zero_point_value, zero_point_bdim] = unwrapTensorAtLevel(zero_point, cur_level); + auto results = batch_rule(self_value, self_bdim, observer_on_value, observer_on_bdim, fake_quant_on_value, fake_quant_on_bdim, running_min_value, running_min_bdim, running_max_value, running_max_bdim, scale_value, scale_bdim, zero_point_value, zero_point_bdim, averaging_const, quant_min, quant_max, ch_axis, per_row_fake_quant, symmetric_quant); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level), makeBatched(std::get<4>(results), std::get<5>(results), cur_level), makeBatched(std::get<6>(results), std::get<7>(results), cur_level), makeBatched(std::get<8>(results), std::get<9>(results), cur_level), makeBatched(std::get<10>(results), std::get<11>(results), cur_level)); +} +template +void lstm_mps_backward_out_generated_plumbing(const ::std::optional & grad_y, const ::std::optional & grad_hy, const ::std::optional & grad_cy, const at::Tensor & z_state, const at::Tensor & cell_state_fwd, const at::Tensor & input, const at::Tensor & layersOutputs, at::TensorList hx, at::TensorList params, bool has_biases, int64_t num_layers, double dropout, bool train, bool bidirectional, bool batch_first, at::Tensor & out0, at::TensorList out1, at::TensorList out2) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(grad_y, cur_level) && !isBatchedAtLevel(grad_hy, cur_level) && !isBatchedAtLevel(grad_cy, cur_level) && !isBatchedAtLevel(z_state, cur_level) && !isBatchedAtLevel(cell_state_fwd, cur_level) && !isBatchedAtLevel(input, cur_level) && !isBatchedAtLevel(layersOutputs, cur_level) && !isBatchedAtLevel(hx, cur_level) && !isBatchedAtLevel(params, cur_level) && !isBatchedAtLevel(out0, cur_level) && !isBatchedAtLevel(out1, cur_level) && !isBatchedAtLevel(out2, cur_level)) { + return at::_ops::lstm_mps_backward_out::call(grad_y, grad_hy, grad_cy, z_state, cell_state_fwd, input, layersOutputs, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0, out1, out2); + } + auto [z_state_value, z_state_bdim] = unwrapTensorAtLevel(z_state, cur_level); + auto [cell_state_fwd_value, cell_state_fwd_bdim] = unwrapTensorAtLevel(cell_state_fwd, cur_level); + auto [input_value, input_bdim] = unwrapTensorAtLevel(input, cur_level); + auto [layersOutputs_value, layersOutputs_bdim] = unwrapTensorAtLevel(layersOutputs, cur_level); + auto [out0_value, out0_bdim] = unwrapTensorAtLevel(out0, cur_level); + std::optional grad_y_value; + std::optional grad_y_bdim; + if (grad_y) { + std::tie(grad_y_value, grad_y_bdim) = unwrapTensorAtLevel(grad_y.value(), cur_level); + } + std::optional grad_hy_value; + std::optional grad_hy_bdim; + if (grad_hy) { + std::tie(grad_hy_value, grad_hy_bdim) = unwrapTensorAtLevel(grad_hy.value(), cur_level); + } + std::optional grad_cy_value; + std::optional grad_cy_bdim; + if (grad_cy) { + std::tie(grad_cy_value, grad_cy_bdim) = unwrapTensorAtLevel(grad_cy.value(), cur_level); + } + batch_rule(grad_y_value, grad_y_bdim, grad_hy_value, grad_hy_bdim, grad_cy_value, grad_cy_bdim, z_state_value, z_state_bdim, cell_state_fwd_value, cell_state_fwd_bdim, input_value, input_bdim, layersOutputs_value, layersOutputs_bdim, hx, params, has_biases, num_layers, dropout, train, bidirectional, batch_first, out0_value, out0_bdim, out1, out2); +} +template +at::Tensor set_source_Storage_generated_plumbing(const at::Tensor & self, at::Storage source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::set_source_Storage::call(self, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor set_source_Storage_storage_offset_generated_plumbing(const at::Tensor & self, at::Storage source, c10::SymInt storage_offset, c10::SymIntArrayRef size, c10::SymIntArrayRef stride) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::set_source_Storage_storage_offset::call(self, source, storage_offset, size, stride); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, source, storage_offset, size, stride); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor set_source_Tensor_generated_plumbing(const at::Tensor & self, const at::Tensor & source) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(source, cur_level)) { + return at::_ops::set_source_Tensor::call(self, source); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [source_value, source_bdim] = unwrapTensorAtLevel(source, cur_level); + auto results = batch_rule(self_value, self_bdim, source_value, source_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor set_generated_plumbing(const at::Tensor & self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::set::call(self); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor random_from_generated_plumbing(const at::Tensor & self, int64_t from, ::std::optional to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random_from::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, from, to, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor random_to_generated_plumbing(const at::Tensor & self, int64_t to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random_to::call(self, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, to, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor random_generated_plumbing(const at::Tensor & self, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::random::call(self, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor uniform_generated_plumbing(const at::Tensor & self, double from, double to, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::uniform::call(self, from, to, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, from, to, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor cauchy_generated_plumbing(const at::Tensor & self, double median, double sigma, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::cauchy::call(self, median, sigma, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, median, sigma, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor log_normal_generated_plumbing(const at::Tensor & self, double mean, double std, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::log_normal::call(self, mean, std, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, mean, std, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor exponential_generated_plumbing(const at::Tensor & self, double lambd, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::exponential::call(self, lambd, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, lambd, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +at::Tensor geometric_generated_plumbing(const at::Tensor & self, double p, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::geometric::call(self, p, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto results = batch_rule(self_value, self_bdim, p, generator); + return makeBatched(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _histogramdd_bin_edges_out_generated_plumbing(const at::Tensor & self, at::IntArrayRef bins, ::std::optional> range, const ::std::optional & weight, bool density, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(weight, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_histogramdd_bin_edges_out::call(self, bins, range, weight, density, out); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + std::optional weight_value; + std::optional weight_bdim; + if (weight) { + std::tie(weight_value, weight_bdim) = unwrapTensorAtLevel(weight.value(), cur_level); + } + batch_rule(self_value, self_bdim, bins, range, weight_value, weight_bdim, density, out); +} +template +void _amp_foreach_non_finite_check_and_unscale_out_generated_plumbing(at::TensorList self, at::Tensor & found_inf, const at::Tensor & inv_scale, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(inv_scale, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale_out::call(self, found_inf, inv_scale, out); + } + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto [inv_scale_value, inv_scale_bdim] = unwrapTensorAtLevel(inv_scale, cur_level); + batch_rule(self, found_inf_value, found_inf_bdim, inv_scale_value, inv_scale_bdim, out); +} +template +::std::tuple<::std::vector,at::Tensor> _amp_foreach_non_finite_check_and_unscale_generated_plumbing(at::TensorList self, const at::Tensor & found_inf, const at::Tensor & inv_scale) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(inv_scale, cur_level)) { + return at::_ops::_amp_foreach_non_finite_check_and_unscale::call(self, found_inf, inv_scale); + } + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto [inv_scale_value, inv_scale_bdim] = unwrapTensorAtLevel(inv_scale, cur_level); + auto results = batch_rule(self, found_inf_value, found_inf_bdim, inv_scale_value, inv_scale_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +::std::tuple _amp_update_scale_generated_plumbing(const at::Tensor & self, const at::Tensor & growth_tracker, const at::Tensor & found_inf, double scale_growth_factor, double scale_backoff_factor, int64_t growth_interval) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(growth_tracker, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_amp_update_scale::call(self, growth_tracker, found_inf, scale_growth_factor, scale_backoff_factor, growth_interval); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [growth_tracker_value, growth_tracker_bdim] = unwrapTensorAtLevel(growth_tracker, cur_level); + auto [found_inf_value, found_inf_bdim] = unwrapTensorAtLevel(found_inf, cur_level); + auto results = batch_rule(self_value, self_bdim, growth_tracker_value, growth_tracker_bdim, found_inf_value, found_inf_bdim, scale_growth_factor, scale_backoff_factor, growth_interval); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _foreach_add_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_add_List_out_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_List_out::call(self, other, alpha, out); + } + + batch_rule(self, other, alpha, out); +} +template +void _foreach_add_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_add_Tensor_out_generated_plumbing(at::TensorList self, const at::Tensor & other, const at::Scalar & alpha, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_add_Tensor_out::call(self, other, alpha, out); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, alpha, out); +} +template +void _foreach_sub_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sub_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_sub_List_out_generated_plumbing(at::TensorList self, at::TensorList other, const at::Scalar & alpha, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sub_List_out::call(self, other, alpha, out); + } + + batch_rule(self, other, alpha, out); +} +template +void _foreach_sub_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sub_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_mul_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_mul_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_mul_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_mul_Tensor_out_generated_plumbing(at::TensorList self, const at::Tensor & other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_mul_Tensor_out::call(self, other, out); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, out); +} +template +void _foreach_div_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_div_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_div_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_div_Tensor_out_generated_plumbing(at::TensorList self, const at::Tensor & other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_div_Tensor_out::call(self, other, out); + } + auto [other_value, other_bdim] = unwrapTensorAtLevel(other, cur_level); + batch_rule(self, other_value, other_bdim, out); +} +template +void _foreach_clamp_max_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_max_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_clamp_max_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_max_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_clamp_max_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_max_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_clamp_min_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_min_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_clamp_min_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_min_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_clamp_min_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_clamp_min_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_maximum_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_maximum_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_maximum_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_maximum_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_maximum_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_maximum_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_minimum_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & scalar, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_minimum_Scalar_out::call(self, scalar, out); + } + + batch_rule(self, scalar, out); +} +template +void _foreach_minimum_List_out_generated_plumbing(at::TensorList self, at::TensorList other, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(other, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_minimum_List_out::call(self, other, out); + } + + batch_rule(self, other, out); +} +template +void _foreach_minimum_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_minimum_ScalarList_out::call(self, scalars, out); + } + + batch_rule(self, scalars, out); +} +template +void _foreach_addcdiv_Scalar_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcdiv_Scalar_out::call(self, tensor1, tensor2, value, out); + } + + batch_rule(self, tensor1, tensor2, value, out); +} +template +void _foreach_addcdiv_ScalarList_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcdiv_ScalarList_out::call(self, tensor1, tensor2, scalars, out); + } + + batch_rule(self, tensor1, tensor2, scalars, out); +} +template +void _foreach_addcdiv_Tensor_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcdiv_Tensor_out::call(self, tensor1, tensor2, scalars, out); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim, out); +} +template +void _foreach_addcmul_Scalar_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Scalar & value, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcmul_Scalar_out::call(self, tensor1, tensor2, value, out); + } + + batch_rule(self, tensor1, tensor2, value, out); +} +template +void _foreach_addcmul_ScalarList_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, at::ArrayRef scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcmul_ScalarList_out::call(self, tensor1, tensor2, scalars, out); + } + + batch_rule(self, tensor1, tensor2, scalars, out); +} +template +void _foreach_addcmul_Tensor_out_generated_plumbing(at::TensorList self, at::TensorList tensor1, at::TensorList tensor2, const at::Tensor & scalars, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensor1, cur_level) && !isBatchedAtLevel(tensor2, cur_level) && !isBatchedAtLevel(scalars, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_addcmul_Tensor_out::call(self, tensor1, tensor2, scalars, out); + } + auto [scalars_value, scalars_bdim] = unwrapTensorAtLevel(scalars, cur_level); + batch_rule(self, tensor1, tensor2, scalars_value, scalars_bdim, out); +} +template +void _foreach_abs_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_abs_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_acos_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_acos_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_asin_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_asin_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_atan_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_atan_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_ceil_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_ceil_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_cos_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_cos_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_cosh_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_cosh_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_erf_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_erf_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_erfc_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_erfc_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_exp_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_exp_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_expm1_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_expm1_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_floor_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_floor_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_frac_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_frac_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_lerp_List_out_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::TensorList weights, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(weights, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lerp_List_out::call(self, tensors1, weights, out); + } + + batch_rule(self, tensors1, weights, out); +} +template +void _foreach_lerp_Scalar_out_generated_plumbing(at::TensorList self, at::TensorList tensors1, const at::Scalar & weight, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lerp_Scalar_out::call(self, tensors1, weight, out); + } + + batch_rule(self, tensors1, weight, out); +} +template +void _foreach_lerp_ScalarList_out_generated_plumbing(at::TensorList self, at::TensorList tensors1, at::ArrayRef weight, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(tensors1, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lerp_ScalarList_out::call(self, tensors1, weight, out); + } + + batch_rule(self, tensors1, weight, out); +} +template +void _foreach_lgamma_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_lgamma_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_log_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log10_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_log10_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log1p_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_log1p_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_log2_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_log2_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_max_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_max_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_neg_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_neg_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_norm_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & ord, ::std::optional dtype, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_norm_Scalar_out::call(self, ord, dtype, out); + } + + batch_rule(self, ord, dtype, out); +} +template +void _foreach_pow_List_out_generated_plumbing(at::TensorList self, at::TensorList exponent, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(exponent, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_pow_List_out::call(self, exponent, out); + } + + batch_rule(self, exponent, out); +} +template +void _foreach_pow_Scalar_out_generated_plumbing(at::TensorList self, const at::Scalar & exponent, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_pow_Scalar_out::call(self, exponent, out); + } + + batch_rule(self, exponent, out); +} +template +void _foreach_pow_ScalarList_out_generated_plumbing(at::TensorList self, at::ArrayRef exponent, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_pow_ScalarList_out::call(self, exponent, out); + } + + batch_rule(self, exponent, out); +} +template +void _foreach_reciprocal_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_reciprocal_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_round_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_round_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_rsqrt_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_rsqrt_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sigmoid_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sigmoid_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sign_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sign_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sin_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sin_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sinh_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sinh_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_sqrt_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_sqrt_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_tan_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_tan_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_tanh_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_tanh_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_trunc_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_trunc_out::call(self, out); + } + + batch_rule(self, out); +} +template +void _foreach_zero_out_generated_plumbing(at::TensorList self, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_zero_out::call(self, out); + } + + batch_rule(self, out); +} +template +::std::vector _foreach_zero_generated_plumbing(at::TensorList self) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level)) { + return at::_ops::_foreach_zero::call(self); + } + + auto results = batch_rule(self); + return makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level); +} +template +void _foreach_copy_out_generated_plumbing(at::TensorList self, at::TensorList src, bool non_blocking, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(src, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_foreach_copy_out::call(self, src, non_blocking, out); + } + + batch_rule(self, src, non_blocking, out); +} +template +::std::tuple rrelu_with_noise_functional_generated_plumbing(const at::Tensor & self, const at::Tensor & noise, const at::Scalar & lower, const at::Scalar & upper, bool training, ::std::optional generator) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(noise, cur_level)) { + return at::_ops::rrelu_with_noise_functional::call(self, noise, lower, upper, training, generator); + } + auto [self_value, self_bdim] = unwrapTensorAtLevel(self, cur_level); + auto [noise_value, noise_bdim] = unwrapTensorAtLevel(noise, cur_level); + auto results = batch_rule(self_value, self_bdim, noise_value, noise_bdim, lower, upper, training, generator); + return std::make_tuple(makeBatched(std::get<0>(results), std::get<1>(results), cur_level), makeBatched(std::get<2>(results), std::get<3>(results), cur_level)); +} +template +void _fused_adam_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adam_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_adam_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adam_tensor_lr_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adam_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adam_tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_adamw_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adamw_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, double lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_adamw_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adamw_tensor_lr_out::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector,::std::vector> _fused_adamw_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList exp_avgs, at::TensorList exp_avg_sqs, at::TensorList max_exp_avg_sqs, at::TensorList state_steps, const at::Tensor & lr, double beta1, double beta2, double weight_decay, double eps, bool amsgrad, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(exp_avgs, cur_level) && !isBatchedAtLevel(exp_avg_sqs, cur_level) && !isBatchedAtLevel(max_exp_avg_sqs, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adamw_tensor_lr::call(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, lr_value, lr_bdim, beta1, beta2, weight_decay, eps, amsgrad, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level), makeBatchedVector(std::get<8>(results), std::get<9>(results), cur_level)); +} +template +void _fused_sgd_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_sgd_out::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, double lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +void _fused_sgd_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_sgd_tensor_lr_out::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr_value, lr_bdim, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_sgd_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList momentum_buffer_list, double weight_decay, double momentum, const at::Tensor & lr, double dampening, bool nesterov, bool maximize, bool is_first_step, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(momentum_buffer_list, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_sgd_tensor_lr::call(self, grads, momentum_buffer_list, weight_decay, momentum, lr, dampening, nesterov, maximize, is_first_step, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, momentum_buffer_list, weight_decay, momentum, lr_value, lr_bdim, dampening, nesterov, maximize, is_first_step, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} +template +void _fused_adagrad_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adagrad_out::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector,::std::vector> _fused_adagrad_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, double lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level), makeBatchedVector(std::get<6>(results), std::get<7>(results), cur_level)); +} +template +void _fused_adagrad_tensor_lr_out_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf, at::TensorList out) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing_no_returns"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level) && !isBatchedAtLevel(out, cur_level)) { + return at::_ops::_fused_adagrad_tensor_lr_out::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf, out); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + batch_rule(self, grads, state_sums, state_steps, lr_value, lr_bdim, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim, out); +} +template +::std::tuple<::std::vector,::std::vector,::std::vector> _fused_adagrad_tensor_lr_generated_plumbing(at::TensorList self, at::TensorList grads, at::TensorList state_sums, at::TensorList state_steps, const at::Tensor & lr, double lr_decay, double weight_decay, double eps, bool maximize, const ::std::optional & grad_scale, const ::std::optional & found_inf) { + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::FuncTorchBatched); + auto maybe_layer = maybeCurrentDynamicLayer(); + vmap_check_escaped(maybe_layer, "gen_vmap_plumbing"); + int64_t cur_level = maybe_layer->layerId(); + if (!isBatchedAtLevel(self, cur_level) && !isBatchedAtLevel(grads, cur_level) && !isBatchedAtLevel(state_sums, cur_level) && !isBatchedAtLevel(state_steps, cur_level) && !isBatchedAtLevel(lr, cur_level) && !isBatchedAtLevel(grad_scale, cur_level) && !isBatchedAtLevel(found_inf, cur_level)) { + return at::_ops::_fused_adagrad_tensor_lr::call(self, grads, state_sums, state_steps, lr, lr_decay, weight_decay, eps, maximize, grad_scale, found_inf); + } + auto [lr_value, lr_bdim] = unwrapTensorAtLevel(lr, cur_level); + std::optional grad_scale_value; + std::optional grad_scale_bdim; + if (grad_scale) { + std::tie(grad_scale_value, grad_scale_bdim) = unwrapTensorAtLevel(grad_scale.value(), cur_level); + } + std::optional found_inf_value; + std::optional found_inf_bdim; + if (found_inf) { + std::tie(found_inf_value, found_inf_bdim) = unwrapTensorAtLevel(found_inf.value(), cur_level); + } + auto results = batch_rule(self, grads, state_sums, state_steps, lr_value, lr_bdim, lr_decay, weight_decay, eps, maximize, grad_scale_value, grad_scale_bdim, found_inf_value, found_inf_bdim); + return std::make_tuple(makeBatchedVector(std::get<0>(results), std::get<1>(results), cur_level), makeBatchedVector(std::get<2>(results), std::get<3>(results), cur_level), makeBatchedVector(std::get<4>(results), std::get<5>(results), cur_level)); +} + +}} // namespace at::functorch diff --git a/phivenv/Lib/site-packages/torch/include/ATen/WrapDimUtils.h b/phivenv/Lib/site-packages/torch/include/ATen/WrapDimUtils.h new file mode 100644 index 0000000000000000000000000000000000000000..9d3009506fb89e0e293ab29432c6bc617f1e7336 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/WrapDimUtils.h @@ -0,0 +1,156 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + +// if dim_post_expr is 0 and wrap_scalar is true, then dim must be in the +// range [-1, 0]. This is a special case for scalar tensors and manifests in +// e.g. torch.sum(scalar_tensor, 0) Otherwise, dim should be in the range +// [-dim_post_expr, dim_post_expr-1]. +using c10::maybe_wrap_dim; + +inline int64_t maybe_wrap_dim(int64_t dim, TensorImpl* tensor) { + return maybe_wrap_dim(dim, tensor->dim()); +} + +inline int64_t maybe_wrap_dim(int64_t dim, TensorList tensors) { + if (tensors.empty()) { + // can't wrap empty TensorList; rely on underlying implementation to throw + // error if necessary. + return dim; + } + return maybe_wrap_dim(dim, tensors[0].dim()); +} + +inline int64_t maybe_wrap_dim( + int64_t dim, + const std::vector>& tensor_sizes) { + if (tensor_sizes.empty()) { + // can't wrap empty list; rely on underlying implementation to throw error + // if necessary + return dim; + } + return maybe_wrap_dim(dim, static_cast(tensor_sizes[0].size())); +} + +// Given an array of dimensions `dims` of length `ndims`, this function "Wraps" +// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be +// specified using negative indices. +// +// Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will +// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for +// dimensions not in the range [-dim_post_expr, dim_post_expr). +inline void maybe_wrap_dims_n( + int64_t* dims, + int64_t ndims, + int64_t dim_post_expr, + bool wrap_scalars = true) { + if (dim_post_expr <= 0) { + if (wrap_scalars) { + dim_post_expr = 1; // this will make range [-1, 0] + } else { + TORCH_CHECK_INDEX( + ndims == 0, + "Dimension specified as ", + dims[0], + " but tensor has no dimensions"); + return; + } + } + int64_t min = -dim_post_expr; + int64_t max = dim_post_expr - 1; + for (const auto i : c10::irange(ndims)) { + auto& dim = dims[i]; + if (dim < min || dim > max) { + TORCH_CHECK_INDEX( + false, + "Dimension out of range (expected to be in range of [", + min, + ", ", + max, + "], but got ", + dim, + ")"); + } + if (dim < 0) + dim += dim_post_expr; + } +} + +// Given a contiguous container of dimensions `dims`, this function "Wraps" +// each dim in-place for a tensor of rank `dim_post_expr`, allowing dims to be +// specified using negative indices. +// +// Additionally, if `wrap_scalar` is true then scalar tensors with rank 0, will +// allow dimensions in the range [-1, 0]. Otherwise, an IndexError is raised for +// dimensions not in the range [-dim_post_expr, dim_post_expr). +template +inline void maybe_wrap_dims( + Container& dims, + int64_t dim_post_expr, + bool wrap_scalars = true) { + return maybe_wrap_dims_n( + dims.data(), dims.size(), dim_post_expr, wrap_scalars); +} + +// previously, size [0] tensors were the only possible empty tensors; thus, it +// wasn't possible to cat empty tensors unless all the other tensors were +// 1-dimensional, so we allowed these tensors to be "skipped" (both for wrap +// dimension behavior and dimension size checking). We maintain this behavior +// for backwards compatibility, but only for this specific size (i.e. other +// empty sizes are not skipped). +inline int64_t legacy_cat_wrap_dim( + int64_t dim, + const std::vector>& tensor_sizes) { + for (auto& sizes : tensor_sizes) { + if (sizes.size() == 1 && sizes[0] == 0) { + continue; + } + return maybe_wrap_dim(dim, static_cast(sizes.size())); + } + return dim; +} + +inline int64_t legacy_cat_wrap_dim_symint( + int64_t dim, + const std::vector>& tensor_sizes) { + for (auto& sizes : tensor_sizes) { + if (sizes.size() == 1) { + if (TORCH_GUARD_SIZE_OBLIVIOUS(sizes[0].sym_eq(0))) { + continue; + } + } + return maybe_wrap_dim(dim, static_cast(sizes.size())); + } + return dim; +} + +inline int64_t legacy_cat_wrap_dim( + int64_t dim, + const MaterializedITensorListRef& tensors) { + for (const Tensor& tensor : tensors) { + if (tensor.dim() == 1) { + if (TORCH_GUARD_SIZE_OBLIVIOUS(tensor.sym_sizes()[0].sym_eq(0))) { + continue; + } + } + return maybe_wrap_dim(dim, tensor.dim()); + } + return dim; +} + +// wrap negative dims in a vector +inline void wrap_all_dims( + std::vector& dims_to_wrap, + int64_t tensor_total_dims) { + for (const auto i : c10::irange(dims_to_wrap.size())) { + dims_to_wrap[i] = maybe_wrap_dim(dims_to_wrap[i], tensor_total_dims); + } +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/WrapDimUtilsMulti.h b/phivenv/Lib/site-packages/torch/include/ATen/WrapDimUtilsMulti.h new file mode 100644 index 0000000000000000000000000000000000000000..d4eeba8e879721afd43977ea8ae16aec747b2e26 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/WrapDimUtilsMulti.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include +#include +#include + +namespace at { + +// This is in an extra file to work around strange interaction of +// bitset on Windows with operator overloading + +constexpr size_t dim_bitset_size = 64; + +inline std::bitset dim_list_to_bitset( + OptionalIntArrayRef opt_dims, + size_t ndims) { + TORCH_CHECK( + ndims <= dim_bitset_size, + "only tensors with up to ", + dim_bitset_size, + " dims are supported"); + std::bitset seen; + if (opt_dims.has_value()) { + auto dims = opt_dims.value(); + for (const auto i : c10::irange(dims.size())) { + size_t dim = maybe_wrap_dim(dims[i], static_cast(ndims)); + TORCH_CHECK( + !seen[dim], + "dim ", + dim, + " appears multiple times in the list of dims"); + seen[dim] = true; + } + } else { + for (size_t dim = 0; dim < ndims; dim++) { + seen[dim] = true; + } + } + return seen; +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/autocast_mode.h b/phivenv/Lib/site-packages/torch/include/ATen/autocast_mode.h new file mode 100644 index 0000000000000000000000000000000000000000..1ab7305bd09b5856c477b5b1f217e436e8ab0283 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/autocast_mode.h @@ -0,0 +1,971 @@ +#pragma once + +#include +#include +#include +#include + +#include +#include + +namespace at::autocast { + +TORCH_API bool is_autocast_enabled(at::DeviceType device_type); +TORCH_API void set_autocast_enabled(at::DeviceType device_type, bool enabled); +TORCH_API at::ScalarType get_autocast_dtype(at::DeviceType device_type); +TORCH_API void set_autocast_dtype( + at::DeviceType device_type, + at::ScalarType dtype); +TORCH_API void clear_cache(); +TORCH_API int increment_nesting(); +TORCH_API int decrement_nesting(); +TORCH_API bool is_autocast_cache_enabled(); +TORCH_API void set_autocast_cache_enabled(bool enabled); + +// deprecated CUDA-specific autocast APIs +C10_DEPRECATED_MESSAGE( + "at::autocast::is_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.") +TORCH_API inline bool is_enabled() { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "() is deprecated. Please use at::autocast::is_autocast_enabled(at::kCUDA) instead.") + return is_autocast_enabled(at::kCUDA); +} +C10_DEPRECATED_MESSAGE( + "at::autocast::set_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.") +TORCH_API inline void set_enabled(bool enabled) { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(at::kCUDA, enabled) instead.") + set_autocast_enabled(at::kCUDA, enabled); +} +C10_DEPRECATED_MESSAGE( + "at::autocast::get_autocast_gpu_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.") +TORCH_API inline at::ScalarType get_autocast_gpu_dtype() { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "() is deprecated. Please use at::autocast::get_autocast_dtype(at::kCUDA) instead.") + return get_autocast_dtype(at::kCUDA); +} +C10_DEPRECATED_MESSAGE( + "at::autocast::set_autocast_gpu_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.") +TORCH_API inline void set_autocast_gpu_dtype(at::ScalarType dtype) { + TORCH_WARN_DEPRECATION( + "at::autocast::", + __func__, + "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(at::kCUDA, dtype) instead.") + set_autocast_dtype(at::kCUDA, dtype); +} + +#define DECLARE_DEPRECATED_AUTOCAST_APIS(name, device_type) \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::is_" #name \ + "_enabled() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \ + ") instead.") \ + TORCH_API inline bool is_##name##_enabled() { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "() is deprecated. Please use at::autocast::is_autocast_enabled(" #device_type \ + ") instead.") \ + return is_autocast_enabled(device_type); \ + } \ + \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::set_" #name \ + "_enabled(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \ + ", enabled) instead.") \ + TORCH_API inline void set_##name##_enabled(bool enabled) { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "(enabled) is deprecated. Please use at::autocast::set_autocast_enabled(" #device_type \ + ", enabled) instead.") \ + set_autocast_enabled(device_type, enabled); \ + } \ + \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::get_autocast_" #name \ + "_dtype() is deprecated. Please use at::autocast::get_autocast_dtype(" #device_type \ + ") instead.") \ + TORCH_API inline at::ScalarType get_autocast_##name##_dtype() { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "() is deprecated. Please at::autocast::get_autocast_dtype(" #device_type \ + ") instead.") \ + return get_autocast_dtype(device_type); \ + } \ + \ + C10_DEPRECATED_MESSAGE( \ + "at::autocast::set_autocast_" #name \ + "_dtype(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \ + ", dtype) instead.") \ + TORCH_API inline void set_autocast_##name##_dtype(at::ScalarType dtype) { \ + TORCH_WARN_DEPRECATION( \ + "at::autocast::", \ + __func__, \ + "(dtype) is deprecated. Please use at::autocast::set_autocast_dtype(" #device_type \ + ", dtype) instead.") \ + set_autocast_dtype(device_type, dtype); \ + } + +#define AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(_) \ + _(cpu, at::kCPU) \ + _(mtia, at::kMTIA) \ + _(xpu, at::kXPU) \ + _(xla, at::kXLA) \ + _(hpu, at::kHPU) \ + _(ipu, at::kIPU) \ + _(privateuseone, at::kPrivateUse1) + +// deprecated other backend specific autocast APIs +// NOLINTNEXTLINE(misc-use-internal-linkage) +AT_FORALL_DEPRECATED_AUTOCAST_BACKENDS(DECLARE_DEPRECATED_AUTOCAST_APIS) + +const std::array _AUTOCAST_SUPPORTED_DEVICES{ + at::kCPU, + at::kCUDA, + at::kMTIA, + at::kMAIA, + at::kXPU, + at::kIPU, + at::kHPU, + at::kXLA, + at::kPrivateUse1, + at::kMPS}; + +namespace { +inline bool is_autocast_eligible( + const Tensor& tensor, + c10::DeviceType device_type) { + switch (device_type) { + case c10::DeviceType::CUDA: + return (tensor.is_cuda() || tensor.is_xla()) && + tensor.is_floating_point(); + case c10::DeviceType::CPU: + return (tensor.is_cpu() || tensor.is_mkldnn()) && + tensor.is_floating_point(); + case c10::DeviceType::MTIA: + return tensor.is_mtia() && tensor.is_floating_point(); + case c10::DeviceType::MAIA: + return tensor.is_maia() && tensor.is_floating_point(); + case c10::DeviceType::XPU: + return tensor.is_xpu() && tensor.is_floating_point(); + case c10::DeviceType::IPU: + return tensor.is_ipu() && tensor.is_floating_point(); + case c10::DeviceType::HPU: + return tensor.is_hpu() && tensor.is_floating_point(); + case c10::DeviceType::XLA: + return tensor.is_xla() && tensor.is_floating_point(); + case c10::DeviceType::PrivateUse1: + return tensor.is_privateuseone() && tensor.is_floating_point(); + case c10::DeviceType::MPS: + return tensor.is_mps() && tensor.is_floating_point(); + default: + return false; + } +} +} // namespace + +inline DispatchKey get_autocast_dispatch_key_from_device_type( + c10::DeviceType device_type) { + switch (device_type) { + case c10::DeviceType::CUDA: + return DispatchKey::Autocast; + case c10::DeviceType::CPU: + return DispatchKey::AutocastCPU; + case c10::DeviceType::MTIA: + return DispatchKey::AutocastMTIA; + case c10::DeviceType::MAIA: + return DispatchKey::AutocastMAIA; + case c10::DeviceType::XPU: + return DispatchKey::AutocastXPU; + case c10::DeviceType::IPU: + return DispatchKey::AutocastIPU; + case c10::DeviceType::HPU: + return DispatchKey::AutocastHPU; + case c10::DeviceType::XLA: + return DispatchKey::AutocastXLA; + case c10::DeviceType::PrivateUse1: + return DispatchKey::AutocastPrivateUse1; + case c10::DeviceType::MPS: + return DispatchKey::AutocastMPS; + default: + TORCH_CHECK( + false, + "unknown device type for autocast in get_autocast_dispatch_key_from_device_type"); + } +} + +inline bool is_autocast_available(c10::DeviceType device_type) { + if (std::find( + _AUTOCAST_SUPPORTED_DEVICES.begin(), + _AUTOCAST_SUPPORTED_DEVICES.end(), + device_type) != _AUTOCAST_SUPPORTED_DEVICES.end()) { + return true; + } else { + return false; + } +} + +inline at::ScalarType get_lower_precision_fp_from_device_type( + c10::DeviceType device_type) { + if (is_autocast_available(device_type)) { + return get_autocast_dtype(device_type); + } else { + TORCH_CHECK( + false, + "unknown device type for autocast in get_lower_precision_fp_from_device_type"); + } +} + +/******************************************************************** +Logic to extract the promote type from any Tensor or TensorList args. +********************************************************************/ + +// Overload to catch Tensor args. +// If nextArg is floating-point, compare its scalar_type with our +// current best guess for the promote type, and update if necessary. +inline at::ScalarType prioritize( + at::ScalarType current, + const Tensor& nextArg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + if (current == at::kDouble) { + TORCH_CHECK(false, "promote type is double in at::autocast::prioritize"); + return current; + } + at::ScalarType lower_precision_fp = + get_lower_precision_fp_from_device_type(device_type); + if (is_autocast_eligible(nextArg, device_type)) { + auto next = nextArg.scalar_type(); + if (next == at::kDouble) { + return current; // ignores double tensors + } else if (current == at::kFloat || next == at::kFloat) { + return at::kFloat; // prioritizes float over lower_precision_fp + } else if (current == lower_precision_fp && next == lower_precision_fp) { + return lower_precision_fp; + } else { + TORCH_CHECK( + false, "Unexpected floating ScalarType in at::autocast::prioritize"); + return current; + } + } else { + return current; + } +} + +// Overload to catch TensorList args (for e.g. cat, stack). +// Reuses the overload above to process each Tensor in the list. +inline at::ScalarType prioritize( + at::ScalarType current, + const TensorList& list, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + for (const auto& tensor : list) { + current = prioritize(current, tensor, device_type); + } + return current; +} + +inline at::ScalarType prioritize( + at::ScalarType current, + const ITensorListRef& list, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + for (const auto& tensor : list) { + current = prioritize(current, tensor, device_type); + } + return current; +} + +// Template to catch non-Tensor args (no-op that returns current best guess) +template +inline at::ScalarType prioritize( + at::ScalarType current, + T nextArg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + return current; +} + +// Overload for the tail case. +inline at::ScalarType promote_type( + at::ScalarType current, + c10::DeviceType device_type) { + return current; +} + +// Unpack args and determine if incoming lower_precision_fp tensors need to be +// promoted to float32. Non-Tensor arguments are ignored. +template +inline at::ScalarType promote_type( + at::ScalarType current, + c10::DeviceType device_type, + Arg0 arg0, + Args... args) { + auto new_current = prioritize(current, arg0, device_type); + return promote_type(new_current, device_type, args...); +} + +/**************************************************** +Logic to apply cached casting to any Tensor argument. +****************************************************/ +inline bool is_eligible( + const Tensor& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + return ( + arg.defined() && is_autocast_eligible(arg, device_type) && + (arg.scalar_type() != at::kDouble)); +} + +// Overload to catch Tensor args +TORCH_API Tensor cached_cast( + at::ScalarType to_type, + const Tensor& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA); + +// Overload to process std::optional +inline std::optional cached_cast( + at::ScalarType to_type, + const std::optional& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + if (arg.has_value()) { + return cached_cast(to_type, *arg, device_type); + } else { + return std::nullopt; + } +} + +// Overload to process TensorLists +inline std::vector cached_cast( + at::ScalarType to_type, + const TensorList& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + std::vector vec; + vec.reserve(arg.size()); + for (const auto& t : arg) { + vec.emplace_back(cached_cast(to_type, t, device_type)); + } + return vec; +} + +inline std::vector cached_cast( + at::ScalarType to_type, + const ITensorListRef& arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + std::vector vec; + vec.reserve(arg.size()); + for (const auto& t : arg) { + vec.emplace_back(cached_cast(to_type, t, device_type)); + } + return vec; +} + +// Template to catch non-Tensor args. +template +inline T cached_cast( + at::ScalarType to_type, + T arg, + c10::DeviceType device_type = c10::DeviceType::CUDA) { + return arg; +} + +/******************************************************* +Logic to flip an output dtype flag. +Keep it simple for now by assuming only one such flag is +present in the argument list. If I ever need a function +with more than flag I'll figure out something else. +The policy is: +If the user has explicity specified a dtype, respect it. +Otherwise, set it to the autocast type. +********************************************************/ + +// Overload to catch dtype flags +std::optional inline set_opt_dtype( + at::ScalarType to_type, + const std::optional& dtype) { + return dtype.has_value() ? dtype : to_type; +} + +// Template to catch other args +template +inline T set_opt_dtype(at::ScalarType to_type, T arg) { + return arg; +} + +template +inline bool firstarg_is_eligible( + c10::DeviceType device_type, + const Tensor& arg, + Args... args) { + return is_eligible(arg, device_type); +} + +template +inline at::ScalarType type_from_firstarg( + c10::DeviceType device_type, + at::ScalarType to_type, + const Tensor& arg, + Args... args) { + return (is_eligible(arg, device_type) ? to_type : arg.scalar_type()); +} + +// Policies correspond to op categories that need code-divergent handling. +// Wrapper templates below are specialized based on a policy template parameter. +enum class CastPolicy : uint8_t { + lower_precision_fp = 0, // Cast all inputs to lower_precision_fp before + // running the op. Currently, lower_precision_fp is + // fp16 for AutocastCUDA, and is defined by user + // (default bf16) for AutocastCPU or other device. + fp32, // Cast all inputs to at::kFloat before running the op. + fp32_set_opt_dtype, // Treats functions (like softmax) that + // 1. we'd like to run in fp32 and + // 2. have a std::optional arg that controls + // the output type. + // fp32_set_opt_dtype wrappers' policy is: if the output + // type is already set, don't touch it, otherwise, set + // it to at::kFloat. + fp32_append_dtype, // Treats functions (like norm) that + // 1. we'd like to run in fp32 and + // 2. have some overloads that accept an output type and + // other overloads that don't. + // fp32_append_dtype wrappers wrap the overloads that don't + // have an output dtype. + // The wrapper policy is: append at::kFloat to the args, + // and redispatch to the type-aware overload. + promote, // Run in the widest dtype among several args. +}; + +/******************************************************************************************************** +Templates to provide wrapper functions + +I'm copying the pattern used in core/boxing/impl/WrapFunctionIntoFunctor.h to +extract args and return type. (see also +https://stackoverflow.com/questions/46533698/how-to-deduce-argument-list-from-function-pointer) + +This strategy uses an exterior "WrapFunction" that extracts arguments on behalf +of (in my case several specializations of) an interior "WrapFunction_". +Interior WrapFunction_ specializations are defined for each CastPolicy. +********************************************************************************************************/ + +// Base template for WrapFunction_, which is specialized to contain a "call" +// method each CastPolicy +template < + CastPolicy policy, + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class ArgList> +struct WrapFunction_ {}; + +// CastPolicy::lower_precision_fp General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::lower_precision_fp, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast( + get_lower_precision_fp_from_device_type(device_type), + args, + device_type)...); + } +}; + +// CastPolicy::fp32 General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::fp32, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + return (*F)(cached_cast(at::kFloat, args, device_type)...); + } +}; + +// CastPolicy::fp32_set_opt_dtype General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::fp32_set_opt_dtype, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + if (firstarg_is_eligible(device_type, args...)) { + return (*F)(set_opt_dtype(at::kFloat, args)...); + } else { + // If ineligible, calls F with unaltered args. Does not set opt dtype, + // because setting opt dtype explicitly may interfere with internal + // implicit promotion decisions. + return (*F)(args...); + } + } +}; + +// CastPolicy::fp32_append_dtype General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::fp32_append_dtype, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + at::ScalarType out_type = + type_from_firstarg(device_type, at::kFloat, args...); + return (*F)(args..., out_type); + } +}; + +// CastPolicy::promote General_DeviceType +template < + c10::DeviceType device_type, + class Redispatch, + Redispatch* F, + class Ret, + class... Args> +struct WrapFunction_< + CastPolicy::promote, + device_type, + Redispatch, + F, + Ret, + guts::typelist::typelist> { + static Ret call(Args... args) { + c10::impl::ExcludeDispatchKeyGuard no_autocast( + get_autocast_dispatch_key_from_device_type(device_type)); + auto to_type = promote_type( + get_lower_precision_fp_from_device_type(device_type), + device_type, + args...); + return (*F)(cached_cast(to_type, args, device_type)...); + } +}; + +// Wrapper to infer return_type and parameter_types for WrapFunction_ (imitating +// core/boxing/impl/WrapFunctionIntoFunctor.h) +template < + CastPolicy policy, + c10::DeviceType device_type, + class Registered, // The signature for which we're registering. The + // dispatcher's calling code invokes our registered + // functions with arguments matching Registered, so we + // register WrapFunction_::call methods with a matching + // signature to properly field those arguments. + // guts::function_traits below extracts return_type and + // parameter_types from Registered, which WrapFunction_ + // templates above use to declare their call methods. + class Redispatch, // The signature for the function we're redispatching to. + // In most cases this is the same as Registered, but for + // some ops (for example, ops where we append a dtype) + // it's useful to redispatch to a function with a + // different signature. + Redispatch* F> // The actual function we're redispatching to. +struct WrapFunction final { + using type = WrapFunction_< + policy, + device_type, + Redispatch, + F, + typename guts::function_traits::return_type, + typename guts::function_traits::parameter_types>; +}; + +/***************************************************************************************************************** +This section performs load-time registration for autocast wrappers. + +It's debatable at what level operations should be patched. We'd like casts to +be autograd-exposed and precede autograd history recording, so that for +lower_precision_fp ops, input tensors are saved for backward in +lower_precision_fp rather than fp32. Saving inputs in lower_precision_fp +can significantly reduce a model's memory footprint. + +Option 1 (strawman): Patch only at the level of explicit calls into +cudnn/cublas (cudnn_convolution, etc), because those are the code paths that are +guaranteed to use Tensor Cores, therefore they're the ones that will benefit +most from lower_precision_fp. Potential pitfall: convolutions (and other ops) +are wrapped in several layers of at::* calls. If one of those happens to record +autograd history, then we've lost the opportunity to save inputs in +lower_precision_fp. + +Option 2: Patch the Python-exposed surface of calls, to make 100% sure autograd +history recording can't sneak in ahead of autocast. This mirrors Apex most +closely. + +I think Option 2 is the right answer for all ops, not just convolutions. Option +2 is what I implement here. +*****************************************************************************************************************/ + +/******************************************************************************************************************** +Explicit registration for out-of-place ops + +The stuff below could be codegenned. Ed said +> you are going to have to write the function definition at some point, I +wouldn't try to get clever about it Therefore, for the moment, this is all +copy pasted in from VariableTypeEverything.cpp with appropriate substitutions. +********************************************************************************************************************/ + +} // namespace at::autocast + +#define ADD_NS(RAW_OP) at::RAW_OP + +#define _KERNEL_OVERLOAD_NARG_IMPL(_0, _1, _2, N, ...) N +#define _KERNEL_OVERLOAD_NARG(...) \ + C10_EXPAND_MSVC_WORKAROUND(_KERNEL_OVERLOAD_NARG_IMPL(__VA_ARGS__, 2, 1)) + +// Common cases where registration signature matches redispatch signature +// (that's why SIGNATURE is repeated in the WrapFunction instantiation) +#define KERNEL1(DISPATCHKEY, OP, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP), \ + &::at::autocast::WrapFunction< \ + ::at::autocast::CastPolicy::POLICY, \ + DISPATCHKEY, \ + decltype(ATEN_FN(OP)), \ + decltype(ATEN_FN(OP)), \ + &ATEN_FN(OP)>::type::call); + +#define KERNEL2(DISPATCHKEY, OP, OVERLOAD, POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" #OP "." #OVERLOAD), \ + &::at::autocast::WrapFunction< \ + ::at::autocast::CastPolicy::POLICY, \ + DISPATCHKEY, \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + decltype(ATEN_FN2(OP, OVERLOAD)), \ + &ATEN_FN2(OP, OVERLOAD)>::type::call); + +#define _KERNEL_DISPATCH(DISPATCHKEY, NARG, ...) \ + C10_CONCATENATE(KERNEL, NARG)(DISPATCHKEY, __VA_ARGS__) + +#define _KERNEL_IMPL(DISPATCHKEY, ...) \ + _KERNEL_DISPATCH(DISPATCHKEY, _KERNEL_OVERLOAD_NARG(__VA_ARGS__), __VA_ARGS__) + +// It will dispatch to KERNEL1 or KERNEL2 based on its inputs. +#define KERNEL(DISPATCHKEY, ...) _KERNEL_IMPL(DISPATCHKEY, __VA_ARGS__) + +// Less-common but still useful case: redispatching to a function +// with a new signature (e.g. appending a dtype) +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + DISPATCHKEY, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + m.impl( \ + TORCH_SELECTIVE_NAME("aten::" REGISTER_NAME), \ + &::at::autocast::WrapFunction< \ + ::at::autocast::CastPolicy::POLICY, \ + DISPATCHKEY, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + &REDISPATCH_FUNC>::type::call); + +// KERNEL_CPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCPU +#define KERNEL_CPU(...) KERNEL(c10::DeviceType::CPU, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CPU( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::CPU, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_CUDA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastCUDA +#define KERNEL_CUDA(...) KERNEL(c10::DeviceType::CUDA, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_CUDA( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::CUDA, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_MTIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMTIA +#define KERNEL_MTIA(...) KERNEL(c10::DeviceType::MTIA, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MTIA( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::MTIA, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_MAIA/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMAIA +#define KERNEL_MAIA(...) KERNEL(c10::DeviceType::MAIA, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_MAIA( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::MAIA, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_XPU/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastXPU +#define KERNEL_XPU(...) KERNEL(c10::DeviceType::XPU, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_XPU( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::XPU, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_PRIVATEUSEONE/KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastPrivateUse1 +#define KERNEL_PRIVATEUSEONE(...) \ + KERNEL(c10::DeviceType::PrivateUse1, __VA_ARGS__) + +#define KERNEL_DIFFERENT_REDISPATCH_SIGNATURE_PRIVATEUSEONE( \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) \ + KERNEL_DIFFERENT_REDISPATCH_SIGNATURE( \ + c10::DeviceType::PrivateUse1, \ + REDISPATCH_FUNC, \ + REGISTER_NAME, \ + REGISTER_SIGNATURE, \ + REDISPATCH_SIGNATURE, \ + POLICY) + +// KERNEL_MPS +// registration (OP, POLICY) or (OP, OVERLOAD, POLICY) for AutocastMPS +#define KERNEL_MPS(...) KERNEL(c10::DeviceType::MPS, __VA_ARGS__) + +// Op lists for different policies. +// To make sure other backends can reuse the policy op list. +#define AT_FORALL_LOWER_PRECISION_FP(_) \ + _(_convolution, deprecated) \ + _(_convolution) \ + _(conv1d) \ + _(conv2d) \ + _(conv3d) \ + _(conv_tbc) \ + _(conv_transpose1d) \ + _(conv_transpose2d, input) \ + _(conv_transpose3d, input) \ + _(convolution) \ + _(prelu) \ + _(addmm) \ + _(addmv) \ + _(addr) \ + _(matmul) \ + _(einsum) \ + _(mm) \ + _(mv) \ + _(linalg_vecdot) \ + _(linear) \ + _(addbmm) \ + _(baddbmm) \ + _(bmm) \ + _(chain_matmul) \ + _(linalg_multi_dot) \ + _(_thnn_fused_lstm_cell) \ + _(_thnn_fused_gru_cell) \ + _(lstm_cell) \ + _(gru_cell) \ + _(rnn_tanh_cell) \ + _(rnn_relu_cell) \ + _(_scaled_dot_product_flash_attention) \ + _(scaled_dot_product_attention) + +#define AT_FORALL_FP32(_) \ + _(acos) \ + _(asin) \ + _(cosh) \ + _(erfinv) \ + _(exp) \ + _(expm1) \ + _(log) \ + _(log10) \ + _(log2) \ + _(log1p) \ + _(reciprocal) \ + _(rsqrt) \ + _(sinh) \ + _(tan) \ + _(pow, Tensor_Scalar) \ + _(pow, Tensor_Tensor) \ + _(pow, Scalar) \ + _(softplus) \ + _(layer_norm) \ + _(native_layer_norm) \ + _(group_norm) \ + _(frobenius_norm, dim) \ + _(nuclear_norm) \ + _(nuclear_norm, dim) \ + _(cosine_similarity) \ + _(poisson_nll_loss) \ + _(cosine_embedding_loss) \ + _(nll_loss) \ + _(nll_loss2d) \ + _(hinge_embedding_loss) \ + _(kl_div) \ + _(l1_loss) \ + _(smooth_l1_loss) \ + _(huber_loss) \ + _(mse_loss) \ + _(margin_ranking_loss) \ + _(multilabel_margin_loss) \ + _(soft_margin_loss) \ + _(triplet_margin_loss) \ + _(multi_margin_loss) \ + _(binary_cross_entropy_with_logits) \ + _(dist) \ + _(pdist) \ + _(cdist) \ + _(renorm) \ + _(logsumexp) \ + _(upsample_nearest1d) \ + _(_upsample_nearest_exact1d) \ + _(upsample_nearest2d) \ + _(_upsample_nearest_exact2d) \ + _(upsample_nearest3d) \ + _(_upsample_nearest_exact3d) \ + _(upsample_linear1d) \ + _(upsample_bilinear2d) \ + _(_upsample_bilinear2d_aa) \ + _(upsample_trilinear3d) \ + _(upsample_bicubic2d) \ + _(_upsample_bicubic2d_aa) + +#define AT_FORALL_FP32_SET_OPT_DTYPE(_) \ + _(prod) \ + _(prod, dim_int) \ + _(prod, dim_Dimname) \ + _(softmax, int) \ + _(softmax, Dimname) \ + _(log_softmax, int) \ + _(log_softmax, Dimname) \ + _(cumprod) \ + _(cumprod, dimname) \ + _(cumsum) \ + _(cumsum, dimname) \ + _(linalg_vector_norm) \ + _(linalg_matrix_norm) \ + _(linalg_matrix_norm, str_ord) \ + _(sum) \ + _(sum, dim_IntList) \ + _(sum, dim_DimnameList) + +#define AT_FORALL_DIFFERENT_REDISPATCH_SIGNATURE(_) \ + _(ADD_NS(norm), \ + "norm.Scalar", \ + Tensor(const Tensor&, const Scalar&), \ + Tensor(const Tensor&, const std::optional&, ScalarType), \ + fp32_append_dtype) \ + _(ADD_NS(norm), \ + "norm.ScalarOpt_dim", \ + Tensor(const Tensor&, const std::optional&, IntArrayRef, bool), \ + Tensor( \ + const Tensor&, \ + const std::optional&, \ + IntArrayRef, \ + bool, \ + ScalarType), \ + fp32_append_dtype) \ + _(ADD_NS(norm), \ + "norm.names_ScalarOpt_dim", \ + Tensor(const Tensor&, const std::optional&, DimnameList, bool), \ + Tensor( \ + const Tensor&, \ + const std::optional&, \ + DimnameList, \ + bool, \ + ScalarType), \ + fp32_append_dtype) + +#define AT_FORALL_PROMOTE(_) \ + _(addcdiv) \ + _(addcmul) \ + _(atan2) \ + _(bilinear) \ + _(cross) \ + _(dot) \ + _(vdot) \ + _(grid_sampler) \ + _(index_put) \ + _(tensordot) \ + _(scatter_add) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/ceil_div.h b/phivenv/Lib/site-packages/torch/include/ATen/ceil_div.h new file mode 100644 index 0000000000000000000000000000000000000000..7eb9940e57d8bd97cef964acb8650466d663da17 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/ceil_div.h @@ -0,0 +1,24 @@ +#pragma once +#include +#include + +namespace at { + +/** + Computes ceil(a / b) +*/ +template >> +C10_ALWAYS_INLINE C10_HOST_DEVICE T ceil_div(T a, T b) { + return (a + b - 1) / b; +} + +/** + Computes ceil(a / b) * b; i.e., rounds up `a` to the next highest + multiple of b +*/ +template +C10_ALWAYS_INLINE C10_HOST_DEVICE T round_up(T a, T b) { + return ceil_div(a, b) * b; +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/code_template.h b/phivenv/Lib/site-packages/torch/include/ATen/code_template.h new file mode 100644 index 0000000000000000000000000000000000000000..d1b04345e6799ebd05a54907a63c4b73b0f58792 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/code_template.h @@ -0,0 +1,245 @@ +#pragma once + +#include + +#include +#include +#include +#include + +namespace at::jit { + +// A template environment is a mapping from template variable names, e.g., +// identifier (corresponding to $identifier) to their expansions. +// +// This template environment supports storing strings, numbers and lists +// of strings, and can be chained together (so that lookup proceeds in +// in the top level environment, and then recurses into a parent +// environment if the key is not found.) +struct TemplateEnv { + TemplateEnv() = default; + TemplateEnv(TemplateEnv& parent) : parent(&parent) {} + TemplateEnv(TemplateEnv&&) = delete; + TemplateEnv& operator=(const TemplateEnv& parent) = delete; + TemplateEnv& operator=(TemplateEnv&& parent) = delete; + ~TemplateEnv() = default; + + using string_list = std::vector; + + // Add a string 'v' to the map at key 'k'. + void s(const std::string& k, const std::string& v) { + strings_[k] = v; + lists_.erase(k); + } + + // Add a number 'v' to the map at key 'k' + template + void d(const std::string& k, const T& v) { + strings_[k] = std::to_string(v); + lists_.erase(k); + } + + // Retrieve the string representation of the value stored at 'k' from the map. + // Raises an exception if the key is not found. + const std::string& s(const std::string& k) const { + if (strings_.count(k) == 0) { + if (parent) { + return parent->s(k); + } + notFound(k); + } + return strings_.at(k); + } + + // Store a list of strings 'v' in the map at 'k'. + void v(const std::string& k, const string_list& v) { + lists_[k] = v; + strings_.erase(k); + } + + // Retrieve a list of strings stored at 'k' from the map. + // Raises an exception if the key is not found. + const string_list& v(const std::string& k) const { + if (lists_.count(k) == 0) { + if (parent) { + return parent->v(k); + } + notFound(k); + } + return lists_.at(k); + } + + // Test if a string 'k' is a string (as opposed to a list.) + bool keyIsString(const std::string& k) const { + if (strings_.count(k) > 0) + return true; + if (lists_.count(k) > 0) + return false; + if (parent) + return parent->keyIsString(k); + notFound(k); + } + + private: + [[noreturn]] void notFound(const std::string& k) const { + std::stringstream ss; + ss << "key not found: " << k; + throw std::logic_error(ss.str()); + } + + std::unordered_map strings_; + std::unordered_map lists_; + TemplateEnv* parent{nullptr}; +}; + +/* +# Match $identifier or ${identifier} and replace with the value in env. +# If this identifier is at the beginning of whitespace on a line +# and its value is a list then it is treated as +# block substitution by indenting all lines of all elements. +# If the identifier is on a line starting with non-whitespace and a list +# then it is comma separated. ${,foo} will insert a comma before the list +# if this list is not empty and ${foo,} will insert one after. +*/ +struct CodeTemplate { + /* implicit */ CodeTemplate(std::string t) : template_text(std::move(t)) {} + + std::string format(const TemplateEnv& env) const { + std::stringstream out; + size_t pos = 0; + size_t indent = 0; + bool all_whitespace = true; + while (pos < template_text.size()) { + char c = template_text[pos]; + if (c == '$') { + std::stringstream kss; + bool comma_before = false; + bool comma_after = false; + size_t new_pos = parseKey(pos, kss, comma_before, comma_after); + std::string k = kss.str(); + bool is_string = env.keyIsString(k); + if (all_whitespace) { + if (is_string) + emitStringWithIndents(out, indent, env.s(k)); + else + emitLinesIndented(out, indent, env.v(k)); + } else { + if (is_string) + out << env.s(k); + else + emitCommaSeparatedList(out, env.v(k), comma_before, comma_after); + } + all_whitespace = false; + pos = new_pos; + } else { + out << c; + if (!isspace(c)) + all_whitespace = false; + indent++; + if (c == '\n') { + indent = 0; + all_whitespace = true; + } + pos++; + } + } + return out.str(); + } + + private: + using string_list = std::vector; + char charAt(size_t p) const { + if (p >= template_text.size()) + throw std::logic_error("EOS found in key"); + return template_text[p]; + } + size_t parseKey( + size_t pos, + std::ostream& k, + bool& comma_before, + bool& comma_after) const { + comma_before = false; + comma_after = false; + pos++; + if (charAt(pos) == '{') { + pos++; + if (charAt(pos) == ',') { + comma_before = true; + pos++; + } + pos = parseIdent(pos, k); + if (charAt(pos) == ',') { + comma_after = true; + pos++; + } + if (charAt(pos) != '}') + throw std::logic_error("missing terminating '}'"); + pos++; + return pos; + } else { + return parseIdent(pos, k); + } + } + size_t parseIdent(size_t pos, std::ostream& k) const { + while (pos < template_text.size() && + (isalnum(template_text[pos]) || template_text[pos] == '_')) { + k << template_text[pos]; + pos++; + } + return pos; + } + void emitCommaSeparatedList( + std::ostream& out, + const string_list& strings, + bool comma_before, + bool comma_after) const { + if (comma_before && !strings.empty()) + out << ", "; + for (const auto i : c10::irange(strings.size())) { + if (i > 0) + out << ", "; + out << strings[i]; + } + if (comma_after && !strings.empty()) + out << ", "; + } + // These indentation functions follow the convention that they never emit + // leading or trailing newlines when the input string does not have leading + // or trailing newlines. It's the responsibility of the calling function + // to indent correctly in the context. + void emitIndent(std::ostream& out, size_t indent) const { + for ([[maybe_unused]] const auto i : c10::irange(indent)) { + out << " "; + } + } + void emitStringWithIndents( + std::ostream& out, + size_t indent, + const std::string& str) const { + for (auto c : str) { + out << c; + if (c == '\n') { + emitIndent(out, indent); + } + } + } + void emitLinesIndented( + std::stringstream& out, + size_t indent, + const string_list& strings) const { + for (const auto i : c10::irange(strings.size())) { + if (i > 0) + emitIndent(out, indent); + emitStringWithIndents(out, indent, strings[i]); + if (i + 1 != strings.size()) + out << "\n"; + } + } + std::string template_text; +}; + +static inline std::string format(const std::string& fmt, TemplateEnv& env) { + return CodeTemplate(fmt).format(env); +} + +} // namespace at::jit diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/ATenGeneral.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ATenGeneral.h new file mode 100644 index 0000000000000000000000000000000000000000..8f411e535837a17c272762ccbd2714e15a1466cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ATenGeneral.h @@ -0,0 +1,3 @@ +#pragma once + +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/ATenOpList.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ATenOpList.h new file mode 100644 index 0000000000000000000000000000000000000000..6dfed2b9398544bb43938cdcc8243cfb10d9be32 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ATenOpList.h @@ -0,0 +1,13 @@ +#pragma once + +#include + +namespace c10 { +struct OperatorName; +} + +namespace at { + +// check if an op is a custom op (i.e. did not come from native_functions.yaml) +TORCH_API bool is_custom_op(const c10::OperatorName& opName); +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/ATen_fwd.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ATen_fwd.h new file mode 100644 index 0000000000000000000000000000000000000000..263e339c5bd6c7d4362771bc078ca8d980e042ec --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ATen_fwd.h @@ -0,0 +1,46 @@ +#pragma once +#include + +// Forward declarations of core ATen types used in dispatch functions +namespace c10 { + +template +class List; +template +class IListRef; +class Stream; +class Scalar; +class SymInt; +class SymIntList; +struct Storage; +struct TensorOptions; +template +class ArrayRef; +template +class OptionalArrayRef; + +} // namespace c10 + +namespace at { + +class Tensor; +class OptionalTensorRef; +struct Dimname; +struct Generator; +using TensorList = c10::ArrayRef; +using ITensorListRef = c10::IListRef; +using IOptTensorListRef = c10::IListRef; +using DimnameList = c10::ArrayRef; +using IntArrayRef = c10::ArrayRef; +using OptionalIntArrayRef = c10::OptionalArrayRef; +using OptionalSymIntArrayRef = c10::OptionalArrayRef; + +using c10::Stream; +using c10::Storage; +using c10::QScheme; +using c10::Scalar; +using c10::SymInt; +using c10::SymIntList; +using c10::TensorOptions; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/ATen_pch.h b/phivenv/Lib/site-packages/torch/include/ATen/core/ATen_pch.h new file mode 100644 index 0000000000000000000000000000000000000000..b7a59b565da29953478f26256e064f957f224abc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/ATen_pch.h @@ -0,0 +1,161 @@ +// This global header must not depend on native_functions.yaml or +// incremental builds will be next to useless +#pragma push_macro("TORCH_ASSERT_NO_OPERATORS") +#define TORCH_ASSERT_NO_OPERATORS + +#include + +// This list of headers was generated using a script that finds +// high-impact headers and then manually tweaked to remove OS specific +// or duplicate headers (e.g. and ) and to remove +// "impl" headers (e.g BFloat16-inl.h or complex_math.h in c10). + +// To generate the initial list: +// 1. Build pytorch from scratch with all build caching disabled +// 2. Generate a build trace with ninjatracing (https://github.com/nico/ninjatracing) +// $ ninjatracing /path/to/pytorch/build/.ninja_log > trace_all.json +// 3. Run pch_gen.py from https://github.com/peterbell10/build_analysis/ +// $ python pch_gen.py --threshold .80 --target torch_cpu --build_dir /path/to/pytorch/build --trace trace_all.json +// Where the threshold can be tweaked until c10 and some of ATen +// core are included but TORCH_ASSERT_NO_OPERATORS still passes. + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#pragma pop_macro("TORCH_ASSERT_NO_OPERATORS") diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Array.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Array.h new file mode 100644 index 0000000000000000000000000000000000000000..4dd9415208da48958d0f3cbe11e073f6027b9b98 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Array.h @@ -0,0 +1,48 @@ +#pragma once + +// A fixed-size array type usable from both host and +// device code. + +#include +#include + +namespace at::detail { + +template +struct Array { + // NOLINTNEXTLINE(*c-array*) + T data[size_]; + + C10_HOST_DEVICE T operator[](int i) const { + return data[i]; + } + C10_HOST_DEVICE T& operator[](int i) { + return data[i]; + } +#if defined(USE_ROCM) + C10_HOST_DEVICE Array() = default; + C10_HOST_DEVICE Array(const Array&) = default; + C10_HOST_DEVICE Array& operator=(const Array&) = default; + C10_HOST_DEVICE Array(Array&&) = default; + C10_HOST_DEVICE Array& operator=(Array&&) = default; + C10_HOST_DEVICE ~Array() = default; +#else + Array() = default; + Array(const Array&) = default; + Array& operator=(const Array&) = default; + Array(Array&&) noexcept = default; + Array& operator=(Array&&) noexcept = default; + ~Array() = default; +#endif + static constexpr int size() { + return size_; + } + // Fill the array with x. + C10_HOST_DEVICE Array(T x) { + for (int i = 0; i < size_; i++) { + data[i] = x; + } + } +}; + +} // namespace at::detail diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Backtrace.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Backtrace.h new file mode 100644 index 0000000000000000000000000000000000000000..684825dc2ba32d0dd84284f08591ec0ec314980f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Backtrace.h @@ -0,0 +1,2 @@ +#include +#include diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/CachingHostAllocator.h b/phivenv/Lib/site-packages/torch/include/ATen/core/CachingHostAllocator.h new file mode 100644 index 0000000000000000000000000000000000000000..298f5497acf788e3e5271578bcece05132efbd35 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/CachingHostAllocator.h @@ -0,0 +1,737 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +C10_DIAGNOSTIC_PUSH_AND_IGNORED_IF_DEFINED("-Wunused-parameter") +namespace at { + +using c10::CachingAllocator::Stat; +using c10::CachingAllocator::DurationStat; + +/** + * HostBlock is typically a fundamental memory block used in pinned memory. It + * is likely related to Event and Stream of device runtime. It is probably a + * base struct or interface that can be inherited and extended by each backend. + */ +template +struct HostBlock { + // constructor for search key + HostBlock(size_t size) : size_(size) {} + + HostBlock(size_t size, void* ptr) : size_(size), ptr_(ptr) {} + + std::mutex mutex_; + size_t size_{0}; // block size in bytes + void* ptr_{nullptr}; // memory address + bool allocated_{false}; // in-use flag + size_t event_count_{0}; // number of related events + ska::flat_hash_set streams_; // streams on which the block was used +}; + +template +struct alignas(64) FreeBlockList { + std::mutex mutex_; + std::deque list_; +}; + +namespace { + // Max cached block sizes: (1 << MAX_SIZE_INDEX) bytes + // NOLINTNEXTLINE(misc-definitions-in-headers) + constexpr size_t MAX_SIZE_INDEX = 64; +} + +// Struct containing memory allocator summary statistics for host. +struct TORCH_API HostStats { + // COUNT: allocations requested by client code. Note that active + // count can be extracted by looking at current allocations + Stat allocation; + // COUNT: number of allocated segments from host memory allocation. + Stat segment; + + // SUM: bytes allocated by this memory alocator. Note that active bytes + // can be extracted by looking at current bytes allocated + Stat allocated_bytes; + // SUM: bytes reserved by this memory allocator (both free and used) + Stat reserved_bytes; + + // SUM: time spent in cudaHostAlloc/cudaHostRegister in microseconds + DurationStat host_alloc_time; + + // SUM: time spent in cudaHostFree/cudaHostUnregister in microseconds + DurationStat host_free_time; + + // COUNT: number of times cudaHostAlloc/cudaHostRegister was called because + // the request could not be satisfied from existing free blocks. + int64_t num_host_alloc = 0; // This is derived from segment or timing + + // COUNT: number of times cudaHostFree/cudaHostUnregister was called. + int64_t num_host_free = 0; // This is derived from segment or timing +}; + +// Struct containing memory allocator summary statistics for host, as they +// are staged for reporting. This is a temporary struct that is used to +// avoid locking the allocator while collecting stats. +struct alignas(64) HostStatsStaged { + std::mutex timing_mutex_; + // COUNT: allocations requested by client code resulting in a new segment/block allocation + // LOCK: access to this stat is protected by the allocator's blocks_mutex_ + Stat allocation; + // SUM: bytes within active memory blocks, including blocks that are + // currently in the free list. + // LOCK: access to this stat is protected by the allocator's blocks_mutex_ + Stat allocated_bytes; + // COUNT: number of allocations per bucket + // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_ + std::vector allocation_bucket_stats = std::vector(MAX_SIZE_INDEX); + // SUM: bytes of allocation per bucket + // LOCK: access to this stat is protected by the per bucket free_list_[index].mutex_ + std::vector allocated_bytes_bucket_stats = std::vector(MAX_SIZE_INDEX); + // SUM: time spent in cudaHostAlloc/cudaHostRegister + // LOCK: access to this stat is protected by the timing_mutex_ + DurationStat host_alloc_time; + // SUM: time spent in cudaHostFree/cudaHostUnregister + // LOCK: access to this stat is protected by the timing_mutex_ + DurationStat host_free_time; +}; + +/** + * Note [HostAllocator design] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + * We have three key data structures - the free list which stores blocks that + * are not currently used, the block list which stores all blocks that have been + * allocated, and the event queue which stores runtime events and their + * corresponding blocks. + * + * Each of these are protected by a separate mutex. The key design principles + * are to 1) only hold each mutex for the minimal amount of time possible, 2) + * never do any possible expensive operations (such as CUDA runtime API calls) + * while holding the lock. + * + * There are four public methods: allocate, free, record_event and empty_cache. + * 1) In the allocate path, we first check to see if we can service our + * request from this free list, and otherwise we create a new block with + * allocate_host_memory. + * 2) In the free path, we insert events (if required) into the event queue, + * and if possible insert our block back into the free list. In allocate, we + * first eagerly query events until we find one that is not ready, and insert + * the corresponding block onto the free list if all the events recorded for a + * block are ready. + * 3) In the record_event path, we simply insert the given stream into the set + * of streams tracked by the specified block. This set of streams is then + * consumed in the free path. + * 4) In the empty_cache path, we flush any available blocks into the free + * list. Remove all element of free list, then remove them from block list and + * release the associated pinned memory allocation via free_block. + * + * We generalize the caching host allocator into two parts: interface and + * implementation. For any new backend looking to integrate with host allocator + * and reuse caching mechanism, these two parts are necessary to be specialized. + * + * For the implementation, we provide a CachingHostAllocatorImpl struct + * to abstract the caching mechanism. Any backend needs to provide a customized + * implementation by specializing its own public functions and the related + * runtime functions. Its template parameter S represents runtime Stream, E + * denotes runtime Event, B indicates the fundamental memory block. + * + * For the interface, we provide a CachingHostAllocatorInterface struct as an + * interface. Any backend needs to derive its own host allocator from this + * interface. Its template parameter T refers to an implementation that + * inherited from CachingHostAllocatorImpl. + * + * So this design can share the caching mechanism across each backend, and + * provide flexibility to each backend. A backend can choose to follow this + * implementation or reuse them by extending and overriding them as necessary. + * Taking CUDA as an example, it specializes runtime related functions to reuse + * the caching mechanism. Additionally, it extends the allocator's functionality + * by adding the allocWithCudaHostRegister function to support page-locking the + * memory range used by CUDA. Of course, you can also refer to + * XPUCachingHostAllocator, which is a host caching allocator supported on XPU + * backend, to implement a basic host caching allocator. + * + * Some of the invariants here are less strict than they could be - for example, + * we do not enforce that free(Block* block) => block->event_count == 0. This is + * for compatibility reasons, and we can explore enforcing these in subsequent + * versions. + * + * Note that this caching host allocator does not split larger allocations into + * smaller blocks, unlike the caching device allocator. + * + * In order to gather statistics about caching host allocator while minimally + * impacting performance, we use a HostStatsStaged struct to stage the stats + * before reporting them. This is done to avoid adding new locks to the allocator. + * Collecting stats is carefully done under existing locks, and then the staged + * stats are converted to the final stats when getStats is called. At that time + * we hold the same locks as empty_cache, to ensure the fidelity of the stats. + */ + +template < + typename S, + typename E, + typename B = HostBlock> +struct CachingHostAllocatorImpl { + virtual ~CachingHostAllocatorImpl() { + active_ = false; + if (pinned_use_background_threads()) { + getBackgroundThreadPool()->waitWorkComplete(); + } + } + + public: + // return data_ptr and block pair. + virtual std::pair allocate(size_t size) { + if (size == 0) { + return {nullptr, nullptr}; + } + + // If we are using background threads, we can process events in the + // background. + if (!pinned_use_background_threads()) { + process_events(); + } + + // Round up the allocation to the nearest power of two to improve reuse. + // These power of two sizes are also used to index into the free list. + size_t roundSize = c10::llvm::PowerOf2Ceil(size); + + // First, try to allocate from the free list + auto* block = get_free_block(roundSize); + if (block) { + return {block->ptr_, reinterpret_cast(block)}; + } + + // Check in the recently freed blocks with pending events to see if we + // can reuse them. Call get_free_block again after processing events + if (pinned_use_background_threads()) { + process_events_for_specific_size(roundSize); + block = get_free_block(roundSize); + if (block) { + return {block->ptr_, reinterpret_cast(block)}; + } + + // Launch the background thread and process events in a loop. + static bool background_thread_flag [[maybe_unused]] = [this] { + getBackgroundThreadPool()->run([&]() { + while (active_) { + process_events(); + std::this_thread::sleep_for(std::chrono::microseconds(100)); + } + }); + return true; + }(); + } + + // Slow path: if we can't allocate from the cached free list, we need + // to create a new block. + void* ptr = nullptr; + allocate_host_memory(roundSize, &ptr); + + // Then, create a new block. + block = new B(roundSize, ptr); + block->allocated_ = true; + + add_allocated_block(block); + return {block->ptr_, reinterpret_cast(block)}; + } + + virtual void free(void* ctx) { + if (!ctx) { + return; + } + + // Note: we can assume that free is correctly paired with alloc, and thus we + // do not need to look up the ctx in blocks_. + auto* block = reinterpret_cast(ctx); + + std::optional> events; + { + std::lock_guard g(block->mutex_); + block->allocated_ = false; + if (block->streams_.empty()) { + TORCH_INTERNAL_ASSERT(block->event_count_ == 0); + } else { + events = std::vector(); + events->reserve(block->streams_.size()); + for (auto stream : block->streams_) { + record_stream(events, stream); + } + block->event_count_ += events->size(); + block->streams_.clear(); + } + } + + if (!events) { + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); + stats_.allocation_bucket_stats[index].decrease(1); + stats_.allocated_bytes_bucket_stats[index].decrease(block->size_); + } else { + // restore these events that record by used streams. + std::lock_guard g(events_mutex_); + for (auto&& event : *events) { + events_.emplace_front(std::move(event), block); + } + } + } + + virtual bool record_event(void* ptr, void* ctx, c10::Stream s) { + S stream = S(s); + auto* block = reinterpret_cast(ctx); + + // Note: we need to check if the passed-in `ctx` is valid. This is because + // `record_event` (via `CachingHostAllocator_recordEvent`) can be invoked on + // an arbitrary tensor, and is not guaranteed to correspond to a pinned + // memory allocation. Therefore, we need to check that `ctx` is valid before + // proceeding. + { + std::lock_guard g(blocks_mutex_); + if (blocks_.find(block) != blocks_.end()) { + // Now we know this object is safe to access. + std::lock_guard gb(block->mutex_); + TORCH_INTERNAL_ASSERT(block->allocated_); + block->streams_.insert(stream); + return true; + } + auto it = ptr_to_block_.find(ptr); + if (it != ptr_to_block_.end()) { + block = it->second; + std::lock_guard g(block->mutex_); + TORCH_INTERNAL_ASSERT(block->allocated_); + block->streams_.insert(stream); + return true; + } + } + + return false; + } + + virtual void empty_cache() { + // Flush any available blocks into the free_list. + process_events(); + + // Remove all elements from the free list, remove them from the blocks + // list, and free the associated pinned memory allocation. This requires + // concurrently holding both the free list mutexes and the blocks mutex, and + // is the only function that concurrently holds multiple mutexes. + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + std::vector blocks_to_remove(free_list_[i].list_.begin(), free_list_[i].list_.end()); + free_list_[i].list_.clear(); + + for (auto* block : blocks_to_remove) { + blocks_.erase(block); + ptr_to_block_.erase(block->ptr_); + stats_.allocation.decrease(1); + stats_.allocated_bytes.decrease(block->size_); + free_block(block); + delete block; + } + } + } + + inline size_t size_index(size_t size) { + return c10::llvm::Log2_64_Ceil(size); + } + + virtual bool pinned_use_background_threads() { + return false; + } + + virtual void copy_data(void* dest [[maybe_unused]], const void* src [[maybe_unused]], std::size_t count [[maybe_unused]]) const { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for copy_data"); + } + + HostStats getStats() { + HostStats stats; + + // To keep getStats lightweight we do *not* flush any available blocks + // into the free_list. This may skew the stats a bit. + + auto add_bucket_stats = [](Stat& accumulator, const Stat& other) { + accumulator.allocated += other.allocated; + accumulator.current += other.current; + accumulator.freed += other.freed; + // Since peaks are measured per bucket independently, we add them up + // to estimate the total peak. This is not strictly correct, but it is + // the best approximation we can get after the fact. + accumulator.peak += other.peak; + }; + + // Accurate reading of memory stats requires concurrently holding both the + // free list mutexes and the blocks mutex. Previously, this was only done in + // empty_cache function. + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + // We collect the slow-path stats only once, since they are not collected + // per bucket (we pick index 0 arbitrarily). These are also all the host + // allocations, not taking into account caching and free lists. + if (i == 0) { + stats.segment = stats_.allocation; + stats.reserved_bytes = stats_.allocated_bytes; + stats.num_host_alloc = stats.segment.allocated; + stats.num_host_free = stats.segment.freed; + } + + // Bucket stats need to be merged with the slow-path stats. We do this in + // a best effort manner, since we can't really replay the cached events per bucket. + add_bucket_stats(stats.allocation, stats_.allocation_bucket_stats[i]); + add_bucket_stats(stats.allocated_bytes, stats_.allocated_bytes_bucket_stats[i]); + } + + // Get the timing stats + { + std::lock_guard g(stats_.timing_mutex_); + + stats.host_alloc_time = stats_.host_alloc_time; + stats.host_free_time = stats_.host_free_time; + } + + return stats; + } + + void resetAccumulatedStats() { + // Reseting accumulated memory stats requires concurrently holding both the + // free list mutexes and the blocks mutex. Previously, this was only done in + // empty_cache function. + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + if (i == 0) { + stats_.allocation.reset_accumulated(); + stats_.allocated_bytes.reset_accumulated(); + } + stats_.allocation_bucket_stats[i].reset_accumulated(); + stats_.allocated_bytes_bucket_stats[i].reset_accumulated(); + } + + // Also reset timing stats + { + std::lock_guard g(stats_.timing_mutex_); + stats_.host_alloc_time.reset_accumulated(); + stats_.host_free_time.reset_accumulated(); + } + } + + void resetPeakStats() { + // Reseting peak memory stats requires concurrently holding both the + // free list mutexes and the blocks mutex. Previously, this was only done in + // empty_cache function. + for (size_t i = 0; i < free_list_.size(); ++i) { + std::lock(free_list_[i].mutex_, blocks_mutex_); + std::lock_guard gf(free_list_[i].mutex_, std::adopt_lock); + std::lock_guard gb(blocks_mutex_, std::adopt_lock); + + if (i == 0) { + stats_.allocation.reset_peak(); + stats_.allocated_bytes.reset_peak(); + } + stats_.allocation_bucket_stats[i].reset_peak(); + stats_.allocated_bytes_bucket_stats[i].reset_peak(); + } + + // Also reset timing stats + { + std::lock_guard g(stats_.timing_mutex_); + stats_.host_alloc_time.reset_peak(); + stats_.host_free_time.reset_peak(); + } + } + + private: + virtual void add_allocated_block(B* block) { + std::lock_guard g(blocks_mutex_); + blocks_.insert(block); + stats_.allocation.increase(1); + stats_.allocated_bytes.increase(block->size_); + ptr_to_block_.insert({block->ptr_, block}); + + // Unfortunately, we have to, on the slow path, quickly + // lock the bucket to record the allocation. This should + // be a rare event once the cache is warmed up. + auto size = block->size_; + auto index = size_index(size); + { + std::lock_guard g(free_list_[index].mutex_); + stats_.allocation_bucket_stats[index].increase(1); + stats_.allocated_bytes_bucket_stats[index].increase(size); + } + } + + virtual B* get_free_block(size_t size) { + auto index = size_index(size); + std::lock_guard g(free_list_[index].mutex_); + if (!free_list_[index].list_.empty()) { + B* block = free_list_[index].list_.back(); + free_list_[index].list_.pop_back(); + block->allocated_ = true; + stats_.allocation_bucket_stats[index].increase(1); + stats_.allocated_bytes_bucket_stats[index].increase(size); + return block; + } + return nullptr; + } + + virtual void process_events() { + // process all events until the last unready event, not for specific size. + process_events_for_specific_size(-1); + } + + // If size is -1, process all events from backwards until the last unready + // event. Otherwise, process events for a specific size and on first ready block + // is found, add it to the free list and return. + virtual void process_events_for_specific_size(int64_t size) { + size_t event_count = 0; + size_t max_events = 0; + { + std::lock_guard g(events_mutex_); + max_events = events_.size(); + } + + while (true) { + // Avoid calling cudaEventDestroy while holding a mutex, so move + // intermediate events out of the lock into this object. + // process the last event + std::optional> processed; + { + std::lock_guard g(events_mutex_); + if (!events_.empty()) { + processed = std::move(events_.back()); + events_.pop_back(); + } + } + + if (!processed) { + return; + } + + if (size != -1) { + if (event_count++ > max_events) { + { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + return; + } + if (size != (int64_t)processed->second->size_) { + // if we are processing a specific size, and the size of the block + // doesn't match, we can't use it. + { + std::lock_guard g(events_mutex_); + events_.push_front(std::move(*processed)); + } + continue; + } + } + + // otherwise, query the event + { + // now, see if we can handle this element + auto& event = processed->first; + if (!query_event(event)) { + // push the event onto the back if it's not ready. + { + std::lock_guard g(events_mutex_); + if (size == -1) { + events_.push_back(std::move(*processed)); + return; + } else { + events_.push_front(std::move(*processed)); + continue; + } + } + } + } + + // Process the events. + TORCH_INTERNAL_ASSERT(processed); + auto* block = processed->second; + bool available = false; + { + std::lock_guard g(block->mutex_); + TORCH_INTERNAL_ASSERT(!block->allocated_) + block->event_count_--; + if (block->event_count_ == 0) { + available = true; + } + } + + if (available) { + auto index = size_index(block->size_); + std::lock_guard g(free_list_[index].mutex_); + free_list_[index].list_.push_back(block); + stats_.allocation_bucket_stats[index].decrease(1); + stats_.allocated_bytes_bucket_stats[index].decrease(size); + if (size != -1) { + return; + } + } + } + } + + TaskThreadPool* getBackgroundThreadPool() { + static TaskThreadPool* pool = new TaskThreadPool(1); + return pool; + } + + /* These following functions are runtime-related. */ + + // Allocate page-locked memory on the host. + virtual void allocate_host_memory(size_t size, void** ptr) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "Not implemented for allocate_host_memory"); + } + + // Free block and release the pointer contained in block. + virtual void free_block(B* block) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for free_block"); + } + + // Record an event on stream and store event into events. + virtual void record_stream(std::optional>& events, S stream) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for record_stream"); + } + + // Query event if it is completed. + virtual bool query_event(E& event) { + TORCH_CHECK_NOT_IMPLEMENTED(false, "Not implemented for query_event"); + } + + alignas(64) std::mutex blocks_mutex_; + ska::flat_hash_set blocks_; // block list + ska::flat_hash_map ptr_to_block_; + + // We keep free list as a vector of free lists, one for each power of two + // size. This allows us to quickly find a free block of the right size. + // We use deque to store per size free list and guard the list with its own + // mutex. + alignas(64) std::vector> free_list_ = + std::vector>(MAX_SIZE_INDEX); + + alignas(64) std::mutex events_mutex_; + std::deque> events_; // event queue paired with block + + // Indicates whether the object is active. + // Set to false in the destructor to signal background threads to stop. + std::atomic active_{true}; +protected: + alignas(64) HostStatsStaged stats_; +}; + +struct TORCH_API HostAllocator : public at::Allocator { + // Associates the pinned memory allocation with a stream to track + // dependencies. This ensures the memory won't be reused until the stream's + // operations complete + virtual bool record_event(void* ptr, void* ctx, c10::Stream stream) = 0; + + // Frees all cached pinned memory and returns it to the system, clearing the + // allocator's internal cache + virtual void empty_cache() = 0; + + // Returns comprehensive statistics about the allocator's memory usage, + // allocation patterns, and timing metrics + virtual HostStats get_stats() = 0; + + // Resets the cumulative allocation statistics + virtual void reset_accumulated_stats() = 0; + + // Resets the peak memory usage metrics + virtual void reset_peak_stats() = 0; +}; + +template +struct CachingHostAllocatorInterface : public HostAllocator { + CachingHostAllocatorInterface() : impl_(std::make_unique()) {} + + at::DataPtr allocate(size_t size) override { + auto ptr_and_ctx = impl_->allocate(size); + return { + ptr_and_ctx.first, + ptr_and_ctx.second, + deleteFunc, // Use the template parameter deleter function + at::DeviceType::CPU}; + } + + void free(void* ctx) { + impl_->free(ctx); + } + + bool record_event(void* ptr, void* ctx, c10::Stream stream) override { + return impl_->record_event(ptr, ctx, stream); + } + + void empty_cache() override { + impl_->empty_cache(); + } + + void copy_data(void* dest, const void* src, std::size_t count) + const override { + impl_->copy_data(dest, src, count); + } + + HostStats get_stats() override { + return impl_->getStats(); + } + + void reset_accumulated_stats() override { + impl_->resetAccumulatedStats(); + } + + void reset_peak_stats() override { + impl_->resetPeakStats(); + } + + std::unique_ptr impl_; +}; + +#define DECLARE_HOST_ALLOCATOR(name, impl, deleter, instance) \ + void deleter(void* ptr); \ + struct name final \ + : public at::CachingHostAllocatorInterface {}; \ + static name instance; \ + void deleter(void* ptr) { \ + instance.free(ptr); \ + } + +/** + * Set the host allocator for DeviceType `device_type`. This allocator manages + * pinned memory on the host that can be accessed efficiently by the specified + * device type. Note that this function is not thread-safe. + */ +TORCH_API void setHostAllocator( + at::DeviceType device_type, + at::HostAllocator* allocator, + uint8_t priority = 0); + +TORCH_API at::HostAllocator* getHostAllocator(at::DeviceType device_type); + +template +struct HostAllocatorRegistry { + explicit HostAllocatorRegistry(HostAllocator* allocator) { + at::setHostAllocator(device_type, allocator); + } +}; + +#define REGISTER_HOST_ALLOCATOR(device_type, allocator) \ + namespace { \ + static at::HostAllocatorRegistry \ + g_host_allocator_registry_instance(allocator); \ + } + +} // namespace at +C10_DIAGNOSTIC_POP() diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/CheckMemoryFormat.h b/phivenv/Lib/site-packages/torch/include/ATen/core/CheckMemoryFormat.h new file mode 100644 index 0000000000000000000000000000000000000000..582480aa960ecde4b9e72a8489501f3e3f7e1047 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/CheckMemoryFormat.h @@ -0,0 +1,24 @@ +#include + +namespace c10::impl { + +inline std::optional +check_tensor_options_and_extract_memory_format( + const TensorOptions& options, + std::optional memory_format) { + TORCH_CHECK( + options.requires_grad_opt() != true, + "Operators taking TensorOptions cannot take a TensorOptions with " + "options.requires_grad set as true. This isn't implemented yet."); + TORCH_CHECK( + !(options.has_memory_format() && memory_format.has_value()), + "Cannot set memory_format both in TensorOptions and explicit argument; please delete " + "the redundant setter."); + if (memory_format.has_value()) { + return memory_format; + } else { + return options.memory_format_opt(); + } +} + +} // namespace impl namespace c10 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h b/phivenv/Lib/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h new file mode 100644 index 0000000000000000000000000000000000000000..4bff7ad3ecbbd4f28906651ac21dc25c44e1881d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/DeprecatedTypeProperties.h @@ -0,0 +1,139 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + + +namespace at { + +class Tensor; + +// This class specifies a Backend and a ScalarType. Currently, it primarily +// serves as a replacement return value for Tensor::type(). Previously, +// Tensor::type() returned Type&, but we are changing Type to not be +// dtype-specific. +class TORCH_API DeprecatedTypeProperties { + public: + DeprecatedTypeProperties(Backend backend, ScalarType scalar_type) + : backend_(backend), scalar_type_(scalar_type) {} + + Backend backend() const { + return backend_; + } + + Layout layout() const { + return layout_from_backend(backend_); + } + + bool is_sparse() const { + return layout_from_backend(backend()) == kSparse; + } + + bool is_sparse_csr() const { + return layout_from_backend(backend()) == kSparseCsr; + } + + c10::DeviceType device_type() const { + return backendToDeviceType(backend_); + } + + bool is_cuda() const { + return backendToDeviceType(backend_) == kCUDA; + } + + ScalarType scalarType() const { + return scalar_type_; + } + + caffe2::TypeMeta typeMeta() const { + return scalarTypeToTypeMeta(scalar_type_); + } + + bool operator==(const DeprecatedTypeProperties& other) const { + return backend_ == other.backend() && scalar_type_ == other.scalarType(); + } + + bool operator!=(const DeprecatedTypeProperties& other) const { + return !(*this == other); + } + + std::string toString() const { + std::string base_str; + if (backend_ == Backend::Undefined || scalar_type_ == ScalarType::Undefined) { + base_str = "UndefinedType"; + } else { + base_str = std::string(at::toString(backend_)) + at::toString(scalar_type_) + "Type"; + } + return base_str; + } + + DeprecatedTypeProperties & toBackend(Backend b) const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + b, scalar_type_); + } + + DeprecatedTypeProperties & toScalarType(ScalarType s) const { + return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( + backend_, s); + } + + DeprecatedTypeProperties & cpu() const { + return toBackend(Backend::CPU); + } + + DeprecatedTypeProperties & cuda() const { + return toBackend(Backend::CUDA); + } + + DeprecatedTypeProperties & hip() const { + return toBackend(Backend::HIP); + } + + DeprecatedTypeProperties & privateUser1() const { + return toBackend(Backend::PrivateUse1); + } + + /// Constructs the `TensorOptions` from a type and a `device_index`. + TensorOptions options(int16_t device_index = -1) const { + return TensorOptions().dtype(typeMeta()) + .device(device_type(), static_cast(device_index)) + .layout(layout()); + } + + /// Constructs the `TensorOptions` from a type and a Device. Asserts that + /// the device type matches the device type of the type. + TensorOptions options(std::optional device_opt) const { + if (!device_opt.has_value()) { + return options(-1); + } else { + Device device = device_opt.value(); + AT_ASSERT(device.type() == device_type()); + return options(device.index()); + } + } + + operator TensorOptions() const { + return options(); + } + + int64_t id() const { + return static_cast(backend()) * + static_cast(ScalarType::NumOptions) + + static_cast(scalarType()); + } + + Tensor unsafeTensorFromTH(void * th_pointer, bool retain) const; + Storage unsafeStorageFromTH(void * th_pointer, bool retain) const; + Tensor copy(const Tensor & src, bool non_blocking=false, std::optional to_device={}) const; + + private: + Backend backend_; + ScalarType scalar_type_; +}; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h b/phivenv/Lib/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h new file mode 100644 index 0000000000000000000000000000000000000000..984a9892fee0322219da8baf314d23c3870b91fd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/DeprecatedTypePropertiesRegistry.h @@ -0,0 +1,33 @@ +#pragma once + +// In order to preserve bc, we make DeprecatedTypeProperties instances unique +// just like they are for Type. + +#include +#include +#include + +namespace at { + +class DeprecatedTypeProperties; + +struct TORCH_API DeprecatedTypePropertiesDeleter { + void operator()(DeprecatedTypeProperties * ptr); +}; + +class TORCH_API DeprecatedTypePropertiesRegistry { + public: + DeprecatedTypePropertiesRegistry(); + + DeprecatedTypeProperties& getDeprecatedTypeProperties(Backend p, ScalarType s) const; + +private: + // NOLINTNEXTLINE(*c-array*) + std::unique_ptr registry + [static_cast(Backend::NumOptions)] + [static_cast(ScalarType::NumOptions)]; +}; + +TORCH_API DeprecatedTypePropertiesRegistry& globalDeprecatedTypePropertiesRegistry(); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Dict.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Dict.h new file mode 100644 index 0000000000000000000000000000000000000000..6dde3ad2dd9385d869f9353db7b215bc76247d14 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Dict.h @@ -0,0 +1,396 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +struct IValue; +template class Dict; +struct Type; + +namespace impl { + +using valid_dict_key_types = guts::typelist::typelist< + int64_t, + std::string, + double, + c10::complex, + bool, + at::Tensor +>; +} + +namespace detail { + +struct DictKeyHash { + size_t operator()(const IValue& ivalue) const; +}; + +struct DictKeyEqualTo { + bool operator()(const IValue& lhs, const IValue& rhs) const; +}; + +struct DictImpl final : public c10::intrusive_ptr_target { + using dict_map_type = ska_ordered::order_preserving_flat_hash_map; + struct DictElementTypes final { + TypePtr keyType; + TypePtr valueType; + }; + + explicit DictImpl(dict_map_type dict_, DictElementTypes elementTypes_) + : dict(std::move(dict_)) + , elementTypes(std::move(elementTypes_)) {} + dict_map_type dict; + + DictElementTypes elementTypes; + + intrusive_ptr copy() const; + friend TORCH_API bool operator==(const DictImpl& lhs, const DictImpl& rhs); +}; + +} + +namespace impl { +template class DictIterator; + +/** + * A reference to an entry in the Dict. + * Use the `key()` and `value()` methods to read the element. + */ +template +class DictEntryRef final { +public: + explicit DictEntryRef(Iterator iterator) + : iterator_(std::move(iterator)) {} + + decltype(auto) key() const { + return iterator_->first.template to(); + } + + decltype(auto) value() const { + return iterator_->second.template to(); + } + + template + void setValue(Value_&& value) const { + static_assert(std::is_constructible_v, "Wrong type for the value argument of setValue()"); + iterator_->second = Value(std::forward(value)); + } + ~DictEntryRef() = default; + +private: + // allow copying and moving, but only our friends (i.e. the Dict class) can do + // it. Copying/moving this reference wrapper would be too ambiguous to allow it + // in the public API. + DictEntryRef(const DictEntryRef&) = default; + DictEntryRef& operator=(const DictEntryRef&) = default; + DictEntryRef(DictEntryRef&&) noexcept = default; + DictEntryRef& operator=(DictEntryRef&& rhs) & noexcept = default; + + Iterator iterator_; + friend class DictIterator; + friend class Dict; +}; + +// this wraps map_type::iterator to make sure user code can't rely +// on it being the type of the underlying map. +template +class DictIterator final { +public: + // C++17 friendly std::iterator implementation + using iterator_category = std::forward_iterator_tag; + using value_type = DictEntryRef; + using difference_type = std::ptrdiff_t; + using pointer = value_type*; + using reference = value_type&; + + explicit DictIterator() = default; + ~DictIterator() = default; + + DictIterator(const DictIterator& rhs): entryRef_(rhs.entryRef_) {} + DictIterator(DictIterator&& rhs) noexcept: entryRef_(std::move(rhs.entryRef_)) {} + DictIterator& operator=(const DictIterator& rhs) = default; + DictIterator& operator=(DictIterator&& rhs) noexcept { + entryRef_ = std::move(rhs.entryRef_); + return *this; + } + + DictIterator& operator++() { + ++entryRef_.iterator_; + return *this; + } + + DictIterator operator++(int) { + DictIterator copy(*this); + ++*this; + return copy; + } + + const DictEntryRef& operator*() const { + return entryRef_; + } + + const DictEntryRef* operator->() const { + return &entryRef_; + } + + friend difference_type operator-(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.entryRef_.iterator_ - rhs.entryRef_.iterator_; + } + +private: + explicit DictIterator(Iterator iterator): entryRef_(std::move(iterator)) {} + + const Iterator& get_iterator_() const { + return entryRef_.iterator_; + } + + friend bool operator==(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() == rhs.get_iterator_(); + } + + friend bool operator!=(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() != rhs.get_iterator_(); + } + + friend bool operator<(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() < rhs.get_iterator_(); + } + + friend bool operator<=(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() <= rhs.get_iterator_(); + } + + friend bool operator>(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() > rhs.get_iterator_(); + } + + friend bool operator>=(const DictIterator& lhs, const DictIterator& rhs) { + return lhs.get_iterator_() >= rhs.get_iterator_(); + } + + DictEntryRef entryRef_; + + friend class DictIterator; + friend class Dict; +}; + +template Dict toTypedDict(Dict dict); +template Dict toGenericDict(Dict dict); +} + +/** + * An object of this class stores a map from Key to Value. + * + * This is a pointer type. After a copy, both Dicts + * will share the same storage: + * + * > Dict a; + * > Dict b = a; + * > b.insert(3, "three"); + * > ASSERT("three" == a.at(3)); + * + * We use this class in the PyTorch kernel API because that + * allows us to do optimizations and switch out the underlying + * map implementation without breaking backwards compatibility + * for the kernel API. + */ +template +// NOLINTNEXTLINE(cppcoreguidelines-special-member-functions) +class Dict final { +private: + static_assert((std::is_same_v && std::is_same_v) || guts::typelist::contains::value, "Invalid Key type for Dict. We only support int64_t, double, bool, and string."); + + // impl_ stores the underlying map as a ska_ordered::order_preserving_flat_hash_map. + // We intentionally don't offer conversion from/to + // order_preserving_flat_hash_map, return references to it or something like that, + // because such operations would get expensive if we switch out + // the actual map implementation. + // This is an intrusive_ptr because Dict is a pointer type. + // Invariant: This will never be a nullptr, there will always be a valid + // DictImpl. + c10::intrusive_ptr impl_; + + explicit Dict(c10::intrusive_ptr&& impl); + friend struct IValue; + template friend Dict impl::toTypedDict(Dict); + template friend Dict impl::toGenericDict(Dict); + +public: + using key_type = Key; + using mapped_type = Value; + using size_type = typename detail::DictImpl::dict_map_type::size_type; + using iterator = impl::DictIterator; + + /** + * Creates an empty dict. + */ + explicit Dict(); + + /** + * Create a generic dict with runtime type information. + * This only works for c10::impl::GenericDict and is not part of the public API + * but only supposed to be used internally by PyTorch. + */ + explicit Dict(TypePtr keyType, TypePtr valueType); + + ~Dict() = default; + + Dict(const Dict&) = default; + Dict& operator=(const Dict&) = default; + + /** + * Create a new Dict pointing to a deep copy of the same data. + * The Dict returned is a new dict with separate storage. + * Changes in it are not reflected in the original dict or vice versa. + */ + Dict copy() const; + + /** + * Returns an iterator to the first element of the container. + * If the container is empty, the returned iterator will be equal to end(). + */ + iterator begin() const; + + /** + * Returns an iterator to the element following the last element of the container. + * This element acts as a placeholder; attempting to access it results in undefined behavior. + */ + iterator end() const; + + /** + * Checks if the container has no elements. + */ + bool empty() const; + + /** + * Returns the number of elements in the container. + */ + size_type size() const; + + /** + * Erases all elements from the container. After this call, size() returns zero. + * Invalidates any references, pointers, or iterators referring to contained elements. May also invalidate past-the-end iterators. + */ + void clear() const; + + /** + * Inserts element(s) into the container, if the container doesn't already contain an element with an equivalent key. + * May invalidate any references, pointers, or iterators referring to contained elements. + * + * @return A pair consisting of an iterator to the inserted element (or to the element that prevented the insertion) and a bool denoting whether the insertion took place. + */ + template + std::pair insert(Key_&& key, Value_&& value) const; + + /** + * If an element with the given key already exists, it is overwritten with the given value. + * Otherwise, a new element with the given key and value are inserted. + * May invalidate any references, pointers, or iterators referring to contained elements. + * + * @return The bool component is true if the insertion took place and false if the assignment took place. The iterator component is pointing at the element that was inserted or updated. + */ + template + std::pair insert_or_assign(Key_&& key, Value_&& value) const; + + /** + * Removes the element pointed to by iter. + * May invalidate any references, pointers, or iterators referring to contained elements. + * The iterator iter must be valid and dereferenceable. Thus the end() iterator (which is valid, but is not dereferenceable) cannot be used as a value for iter. + */ + void erase(iterator iter) const; + + /** + * Removes the element with the given key, if it exists. + * May invalidate any references, pointers, or iterators referring to contained elements. + * + * @return The number of elements removed. This is either '1' if an element with the key existed, or '0' if it didn't. + */ + [[nodiscard]] size_t erase(const Key& key) const; + + /** + * Returns the mapped value of the element with key equivalent to key. + * If no such element exists, an exception of type std::out_of_range is thrown. + */ + Value at(const Key& key) const; + + /** + * Finds an element with key equivalent to key. + * + * @return Iterator to an element with key equivalent to key. + * If no such element is found, past-the-end (see end()) iterator is returned. + */ + iterator find(const Key& key) const; + + /** + * Checks if there is an element with key equivalent to key in the container. + * + * @return true if there is such an element, otherwise false. + */ + bool contains(const Key& key) const; + + /** + * Increase the capacity so that at least count elements can be stored without + * having to reallocate or rehash. + */ + void reserve(size_type count) const; + + /** + * Value equality comparison. This function implements Python-like semantics for + * equality: two dicts with the same identity (e.g. same pointer) trivially + * compare equal, otherwise each element is compared for equality. + */ + template + friend bool operator==( + const Dict& lhs, + const Dict& rhs); + template + friend bool operator!=( + const Dict& lhs, + const Dict& rhs); + + /** + * Identity comparison. Returns true if and only if `rhs` represents the same + * Dict object as `this`. + */ + bool is(const Dict& rhs) const; + + // private API for now because the return type will change to TypePtr + // instead of std::optional once types are mandatory. + TypePtr keyType() const; + TypePtr valueType() const; + + // [unsafe set type] + // These functions mutate the tagged type of this dictionary in place. + // There is no checking that the members of the dictionary are instances + // of the new types, nor is there a check that other IValues which + // hold references to this dictionary have the right static type. + // This functionality is used only in the unpickler, where at + // creation type the real type of the dictionary is unknown, but + // then later recovered from the static type information of the + // unpickled object. + void unsafeSetKeyType(TypePtr t); + void unsafeSetValueType(TypePtr t); +}; + +namespace impl { +// GenericDict is how IValue stores dicts. It is, however, not part of the +// public API. Kernels should use Dicts with concrete Key, Value types instead +// (maybe except for some internal prim ops). +using GenericDict = Dict; + +} +} + +namespace torch { + template using Dict = c10::Dict; +} + +#include // IWYU pragma: keep diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Dict_inl.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Dict_inl.h new file mode 100644 index 0000000000000000000000000000000000000000..9ad24da9caef7b6fc88f0f64e442fb5e9ee52575 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Dict_inl.h @@ -0,0 +1,208 @@ +#pragma once + +#include +#include + +namespace c10 { +namespace detail { +inline bool DictKeyEqualTo::operator()(const IValue& lhs, const IValue& rhs) const { + if (lhs.isTensor() && rhs.isTensor()) { + // for tensors, we compare only by identity (following how it's done in Python). + return lhs.is(rhs); + } + // Otherwise, we first compare by identity for efficiency, then by value (see: + // [container equality]) + return _fastEqualsForContainer(lhs, rhs); +} +} + +template decltype(auto) getTypePtr(); +std::string toString(const Type& type); + +namespace impl { + +template +Dict toTypedDict(GenericDict dict) { + TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.keyType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr()), ", ", toString(*getTypePtr()), ">. Key types mismatch."); + TORCH_INTERNAL_ASSERT(*getTypePtr() == *dict.impl_->elementTypes.valueType, "Tried to cast a Dict<", toString(*dict.impl_->elementTypes.keyType), ", ", toString(*dict.impl_->elementTypes.valueType) ,"> to a Dict<", toString(*getTypePtr()), ", ", toString(*getTypePtr()), ">. Value types mismatch."); + + return Dict(std::move(dict.impl_)); +} + +template +GenericDict toGenericDict(Dict dict) { + return GenericDict(std::move(dict.impl_)); +} +} + +namespace detail { + +inline size_t DictKeyHash::operator()(const IValue& ivalue) const { + if (ivalue.isInt()) { + return std::hash()(ivalue.toInt()); + } else if (ivalue.isString()) { + return std::hash()(ivalue.toStringView()); + } else if (ivalue.isDouble()) { + return std::hash()(ivalue.toDouble()); + } else if (ivalue.isComplexDouble()) { + return c10::hash>()(ivalue.toComplexDouble()); + } else if (ivalue.isBool()) { + return std::hash()(ivalue.toBool()); + } else if (ivalue.isTensor()) { + return std::hash()(ivalue.toTensor().unsafeGetTensorImpl()); + } else if (ivalue.isDevice()) { + return std::hash()(ivalue.toDevice()); + } else { + TORCH_CHECK(false, "Can't hash IValues with tag '", ivalue.tagKind(), "'"); + } +} + +inline intrusive_ptr DictImpl::copy() const { + return make_intrusive(dict, elementTypes); +} + +} + +template +Dict::Dict() + :Dict(make_intrusive( + detail::DictImpl::dict_map_type(), + detail::DictImpl::DictElementTypes{getTypePtr(), getTypePtr()})) { + static_assert(!std::is_same_v, "This constructor is not valid for Dict. Please use c10::impl::GenericDict(keyType, valueType) instead."); + static_assert(!std::is_same_v, "This constructor is not valid for Dict<_, IValue>. Please use c10::impl::GenericDict(keyType, valueType) instead."); +} + +template +Dict::Dict(TypePtr keyType, TypePtr valueType) +: Dict(make_intrusive( + detail::DictImpl::dict_map_type(), + detail::DictImpl::DictElementTypes {std::move(keyType), std::move(valueType)})) { + static_assert(std::is_same_v, "This constructor is only valid for c10::impl::GenericDict."); + static_assert(std::is_same_v, "This constructor is only valid for c10::impl::GenericDict."); +} + +template +Dict::Dict(c10::intrusive_ptr&& impl): impl_(std::move(impl)) {} + +template +Dict Dict::copy() const { + return Dict(impl_->copy()); +} + +template +typename Dict::iterator Dict::begin() const { + return iterator{impl_->dict.begin()}; +} + +template +typename Dict::iterator Dict::end() const { + return iterator{impl_->dict.end()}; +} + +template +bool Dict::empty() const { + return impl_->dict.empty(); +} + +template +typename Dict::size_type Dict::size() const { + return impl_->dict.size(); +} + +template +void Dict::clear() const { + impl_->dict.clear(); +} + +template +template +std::pair::iterator, bool> Dict::insert(Key_&& key, Value_&& value) const { + static_assert(std::is_constructible_v, "Wrong type for the key argument of Dict::insert"); + static_assert(std::is_constructible_v, "Wrong type for the value argument of Dict::insert"); + auto inserted = impl_->dict.emplace( + Key(std::forward(key)), + Value(std::forward(value))); + return {iterator{inserted.first}, inserted.second}; +} + +template +template +std::pair::iterator, bool> Dict::insert_or_assign(Key_&& key, Value_&& value) const { + static_assert(std::is_constructible_v, "Wrong type for the key argument of Dict::insert_or_assign"); + static_assert(std::is_constructible_v, "Wrong type for the value argument of Dict::insert_or_assign"); + auto inserted = impl_->dict.insert_or_assign( + Key(std::forward(key)), + Value(std::forward(value))); + return {iterator{inserted.first}, inserted.second}; +} + +template +void Dict::erase(iterator iter) const { + impl_->dict.erase(iter.entryRef_.iterator_); +} + +template +[[nodiscard]] size_t Dict::erase(const Key& key) const { + return impl_->dict.erase(key); +} + +template +Value Dict::at(const Key& key) const { + return impl_->dict.at(key).template to(); +} + +template +typename Dict::iterator Dict::find(const Key& key) const { + return iterator{impl_->dict.find(key)}; +} + +template +bool Dict::contains(const Key& key) const { + return end() != find(key); +} + +template +void Dict::reserve(size_type count) const { + impl_->dict.reserve(count); +} + +template +TypePtr Dict::keyType() const { + return impl_->elementTypes.keyType; +} + +template +TypePtr Dict::valueType() const { + return impl_->elementTypes.valueType; +} +template +void Dict::unsafeSetKeyType(TypePtr t) { + impl_->elementTypes.keyType = std::move(t); +} + +template +void Dict::unsafeSetValueType(TypePtr t) { + impl_->elementTypes.valueType = std::move(t); +} + +template +bool operator==(const Dict& lhs, const Dict& rhs) { + // Dicts with the same identity trivially compare equal. + if (lhs.impl_ == rhs.impl_) { + return true; + } + + // Otherwise compare the values + return *lhs.impl_ == *rhs.impl_; +} + +template +bool operator!=(const Dict& lhs, const Dict& rhs) { + return !(lhs == rhs); +} + +template +bool Dict::is(const Dict& rhs) const { + return this->impl_ == rhs.impl_; +} +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/DimVector.h b/phivenv/Lib/site-packages/torch/include/ATen/core/DimVector.h new file mode 100644 index 0000000000000000000000000000000000000000..9d0318b7e3bd6b6207c9b2e333b6fdf99eaf0585 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/DimVector.h @@ -0,0 +1,13 @@ +#pragma once +#include + +namespace at { + +// Re-declaring 'DimVector' type and size inside 'at' namespace. +// This is done to avoid modifying every use into their 'c10' +// equivalent. + +using c10::kDimVectorStaticSize; +using c10::DimVector; + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/Dimname.h b/phivenv/Lib/site-packages/torch/include/ATen/core/Dimname.h new file mode 100644 index 0000000000000000000000000000000000000000..054816cd35ab3fd854323132a3ba5a79f3bcdd25 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/Dimname.h @@ -0,0 +1,48 @@ +#pragma once + +#include +#include +#include +#include + +namespace at { + +enum class NameType: uint8_t { BASIC, WILDCARD }; + +struct TORCH_API Dimname { + static Dimname fromSymbol(Symbol name); + static Dimname wildcard(); + static bool isValidName(const std::string& name); + + NameType type() const { return type_; } + Symbol symbol() const { return name_; } + + bool isBasic() const { return type_ == NameType::BASIC; } + bool isWildcard() const { return type_ == NameType::WILDCARD; } + + bool matches(Dimname other) const; + std::optional unify(Dimname other) const; + + private: + Dimname(Symbol name) + : name_(name), type_(NameType::BASIC) {} + Dimname(Symbol name, NameType type) + : name_(name), type_(type) {} + + Symbol name_; + NameType type_; +}; + +using DimnameList = c10::ArrayRef; + +TORCH_API std::ostream& operator<<(std::ostream& out, const Dimname& dimname); + +inline bool operator==(const Dimname& lhs, const Dimname& rhs) { + return lhs.symbol() == rhs.symbol(); +} + +inline bool operator!=(const Dimname& lhs, const Dimname& rhs) { + return !(lhs == rhs); +} + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/DistributionsHelper.h b/phivenv/Lib/site-packages/torch/include/ATen/core/DistributionsHelper.h new file mode 100644 index 0000000000000000000000000000000000000000..2b2e71c10e791a047013cefb649677aaa967a4d0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/DistributionsHelper.h @@ -0,0 +1,332 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +/** + * Distributions kernel adapted from THRandom.cpp + * The kernels try to follow std::random distributions signature + * For instance: in ATen + * auto gen = at::detail::createCPUGenerator(); + * at::uniform_real_distribution uniform(0, 1); + * auto sample = uniform(gen.get()); + * + * vs std::random + * + * std::mt19937 gen; + * std::uniform_real_distribution uniform(0, 1); + * auto sample = uniform(gen); + */ + + +namespace at { +namespace { + +/** + * Samples a discrete uniform distribution in the range [base, base+range) of type T + */ +template +struct uniform_int_from_to_distribution { + + C10_HOST_DEVICE inline uniform_int_from_to_distribution(uint64_t range, int64_t base) : range_(range), base_(base) {} + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { +#ifdef FBCODE_CAFFE2 + if (( + std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v) && range_ >= 1ULL << 32) +#else + if (range_ >= 1ULL << 28) // allow approx 5% skew in uniform int generation using % +#endif + { + return transformation::uniform_int_from_to(generator->random64(), range_, base_); + } else { + return transformation::uniform_int_from_to(generator->random(), range_, base_); + } + } + + private: + uint64_t range_; + int64_t base_; +}; + +/** + * Samples a discrete uniform distribution in the range [min_value(int64_t), max_value(int64_t)] + */ +template +struct uniform_int_full_range_distribution { + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + return transformation::uniform_int_full_range(generator->random64()); + } + +}; + +/** + * Samples a discrete uniform distribution in the range [0, max_value(T)] for integral types + * and [0, 2^mantissa] for floating-point types. + */ +template +struct uniform_int_distribution { + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + if constexpr (std::is_same_v || std::is_same_v) { + return transformation::uniform_int(generator->random64()); + } else { + return transformation::uniform_int(generator->random()); + } + } + +}; + +/** + * Samples a uniform distribution in the range [from, to) of type T + */ +template +struct uniform_real_distribution { + + C10_HOST_DEVICE inline uniform_real_distribution(T from, T to) : from_(from), to_(to) { + TORCH_CHECK_IF_NOT_ON_CUDA(from <= to); + TORCH_CHECK_IF_NOT_ON_CUDA(to - from <= std::numeric_limits::max()); + } + + template + C10_HOST_DEVICE inline dist_acctype operator()(RNG generator){ + if constexpr (std::is_same_v) { + return transformation::uniform_real(generator->random64(), from_, to_); + } else { + return transformation::uniform_real(generator->random(), from_, to_); + } + } + + private: + T from_; + T to_; +}; + +// The SFINAE checks introduced in #39816 looks overcomplicated and must revisited +// https://github.com/pytorch/pytorch/issues/40052 +#define DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(member) \ +template \ +struct has_member_##member \ +{ \ + typedef char yes; \ + typedef long no; \ + template static yes test(decltype(&U::member)); \ + template static no test(...); \ + static constexpr bool value = sizeof(test(0)) == sizeof(yes); \ +} + +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_double_normal_sample); +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_double_normal_sample); +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(next_float_normal_sample); +DISTRIBUTION_HELPER_GENERATE_HAS_MEMBER(set_next_float_normal_sample); + +#define DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(TYPE) \ + \ +template ::value && \ + has_member_set_next_##TYPE##_normal_sample::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* generator, ret_type* ret) { \ + if (generator->next_##TYPE##_normal_sample()) { \ + *ret = *(generator->next_##TYPE##_normal_sample()); \ + generator->set_next_##TYPE##_normal_sample(std::optional()); \ + return true; \ + } \ + return false; \ +} \ + \ +template ::value || \ + !has_member_set_next_##TYPE##_normal_sample::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline bool maybe_get_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type* /*ret*/) { \ + return false; \ +} \ + \ +template ::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* generator, ret_type cache) { \ + generator->set_next_##TYPE##_normal_sample(cache); \ +} \ + \ +template ::value \ + ), int> = 0> \ +C10_HOST_DEVICE inline void maybe_set_next_##TYPE##_normal_sample(RNG* /*generator*/, ret_type /*cache*/) { \ +} + +DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(double) +DISTRIBUTION_HELPER_GENERATE_NEXT_NORMAL_METHODS(float) + +/** + * Samples a normal distribution using the Box-Muller method + * Takes mean and standard deviation as inputs + * Note that Box-muller method returns two samples at a time. + * Hence, we cache the "next" sample in the CPUGeneratorImpl class. + */ +template +struct normal_distribution { + + C10_HOST_DEVICE inline normal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in >= 0, "stdv_in must be positive: ", stdv_in); + } + + template + C10_HOST_DEVICE inline dist_acctype operator()(RNG generator){ + dist_acctype ret; + // return cached values if available + if constexpr (std::is_same_v) { + if (maybe_get_next_double_normal_sample(generator, &ret)) { + return transformation::normal(ret, mean, stdv); + } + } else { + if (maybe_get_next_float_normal_sample(generator, &ret)) { + return transformation::normal(ret, mean, stdv); + } + } + // otherwise generate new normal values + uniform_real_distribution uniform(0.0, 1.0); + const dist_acctype u1 = uniform(generator); + const dist_acctype u2 = uniform(generator); + const dist_acctype r = ::sqrt(static_cast(-2.0) * ::log1p(-u2)); + const dist_acctype theta = static_cast(2.0) * c10::pi * u1; + if constexpr (std::is_same_v) { + maybe_set_next_double_normal_sample(generator, r * ::sin(theta)); + } else { + maybe_set_next_float_normal_sample(generator, r * ::sin(theta)); + } + ret = r * ::cos(theta); + return transformation::normal(ret, mean, stdv); + } + + private: + T mean; + T stdv; +}; + +template +struct DiscreteDistributionType { using type = float; }; + +template <> struct DiscreteDistributionType { using type = double; }; + +/** + * Samples a bernoulli distribution given a probability input + */ +template +struct bernoulli_distribution { + + C10_HOST_DEVICE inline bernoulli_distribution(T p_in) : p(p_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(p_in >= 0 && p_in <= 1); + } + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::bernoulli(uniform(generator), p); + } + + private: + T p; +}; + +/** + * Samples a geometric distribution given a probability input + */ +template +struct geometric_distribution { + + C10_HOST_DEVICE inline geometric_distribution(T p_in) : p(p_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(p_in > 0 && p_in < 1); + } + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::geometric(uniform(generator), p); + } + + private: + T p; +}; + +/** + * Samples an exponential distribution given a lambda input + */ +template +struct exponential_distribution { + + C10_HOST_DEVICE inline exponential_distribution(T lambda_in) : lambda(lambda_in) {} + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::exponential(uniform(generator), lambda); + } + + private: + T lambda; +}; + +/** + * Samples a cauchy distribution given median and sigma as inputs + */ +template +struct cauchy_distribution { + + C10_HOST_DEVICE inline cauchy_distribution(T median_in, T sigma_in) : median(median_in), sigma(sigma_in) {} + + template + C10_HOST_DEVICE inline T operator()(RNG generator) { + uniform_real_distribution uniform(0.0, 1.0); + return transformation::cauchy(uniform(generator), median, sigma); + } + + private: + T median; + T sigma; +}; + +/** + * Samples a lognormal distribution + * Takes mean and standard deviation as inputs + * Outputs two samples at a time + */ +template +struct lognormal_distribution { + + C10_HOST_DEVICE inline lognormal_distribution(T mean_in, T stdv_in) : mean(mean_in), stdv(stdv_in) { + TORCH_CHECK_IF_NOT_ON_CUDA(stdv_in > 0); + } + + template + C10_HOST_DEVICE inline T operator()(RNG generator){ + normal_distribution normal(mean, stdv); + return transformation::log_normal(normal(generator)); + } + + private: + T mean; + T stdv; +}; +} +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/alias_info.h b/phivenv/Lib/site-packages/torch/include/ATen/core/alias_info.h new file mode 100644 index 0000000000000000000000000000000000000000..1ed9f501bff2860336183de9b1b809431ef4e4ce --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/alias_info.h @@ -0,0 +1,162 @@ +#pragma once +#include +#include +#include +#include +#include +#include +#include + +namespace c10 { +/** + * class AliasInfo + * + * Data structure to hold aliasing information for an `Argument`. They can be + * nested to represent aliasing information on contained types. + * + * There is a `beforeSet` which describes the aliasing information before the + * operator executes, and an `afterSet` that describes aliasing info + * after execution. + */ +class AliasInfo { + public: + AliasInfo() = default; + AliasInfo(bool is_write, const std::set& before_qual_strings, const std::set& after_qual_strings) : isWrite_(is_write) { + for (const auto& s: before_qual_strings) { + beforeSets_.insert(Symbol::fromQualString(s)); + } + for (const auto& s : after_qual_strings) { + afterSets_.insert(Symbol::fromQualString(s)); + } + } + // Symbol for the set that can alias anything + static Symbol wildcardSet() { + static const Symbol wc = Symbol::fromQualString("alias::*"); + return wc; + } + + void setIsWrite(bool isWrite) { + isWrite_ = isWrite; + } + + bool isWrite() const { + return isWrite_; + } + + void addBeforeSet(Symbol aliasSet) { + beforeSets_.insert(aliasSet); + } + + void addAfterSet(Symbol aliasSet) { + afterSets_.insert(aliasSet); + } + + const std::unordered_set& beforeSets() const { + return beforeSets_; + } + + const std::unordered_set& afterSets() const { + return afterSets_; + } + + Symbol beforeSet() const { + AT_ASSERT(beforeSets_.size() == 1); + return *beforeSets_.begin(); + } + + bool isWildcardBefore() const { + return beforeSets_.count(wildcardSet()) != 0; + } + + bool isWildcardAfter() const { + return afterSets_.count(wildcardSet()) != 0; + } + + // the alias info for the contained types of the type + // e.g. if this is an annotation on List[T], `sets` refers to + // the alias sets that the list may be in + // while containedTypes()[0] refers to the sets that members of the list + // may be in + void addContainedType(AliasInfo aliasInfo) { + containedTypes_.push_back(std::move(aliasInfo)); + } + const std::vector& containedTypes() const { + return containedTypes_; + } + + private: + std::unordered_set beforeSets_; + std::unordered_set afterSets_; + std::vector containedTypes_; + bool isWrite_ = false; +}; + +inline bool operator==(const AliasInfo& lhs, const AliasInfo& rhs) { + return lhs.isWrite() == rhs.isWrite() + && lhs.beforeSets() == rhs.beforeSets() + && lhs.afterSets() == rhs.afterSets() + && lhs.containedTypes() == rhs.containedTypes(); +} + +// this does match the way things are represented in the schema +inline std::ostream& operator<<(std::ostream& out, const AliasInfo& aliasInfo) { + out << "("; + bool first = true; + for (const auto& set : aliasInfo.beforeSets()) { + if (first) { + first = false; + } else { + out << "|"; + } + out << set.toUnqualString(); + } + if (aliasInfo.isWrite()) { + out << "!"; + } + if (aliasInfo.beforeSets() != aliasInfo.afterSets()) { + out << " -> "; + first = true; + for (const auto& set : aliasInfo.afterSets()) { + if (first) { + first = false; + } else { + out << "|"; + } + out << set.toUnqualString(); + } + } + out << ")"; + return out; +} +} // namespace c10 + +namespace std { +template <> + struct hash { + size_t operator()(const c10::AliasInfo& aliasInfo) const { + auto hash = std::hash()(aliasInfo.isWrite()); + + // NOTE: for unordered_set hashes, we couldn't use hash_combine + // because hash_combine is order dependent. Instead, we choose to + // use XOR as the combining function as XOR is commutative. + size_t before_set_hash_seed = 0; + for (auto &e: aliasInfo.beforeSets()) { + auto symbol_hash = std::hash()(e); + before_set_hash_seed = before_set_hash_seed ^ symbol_hash; + } + size_t after_set_hash_seed = 0; + for (auto &e: aliasInfo.afterSets()) { + auto symbol_hash = std::hash()(e); + after_set_hash_seed = after_set_hash_seed ^ symbol_hash; + } + + hash = c10::hash_combine(hash, before_set_hash_seed); + hash = c10::hash_combine(hash, after_set_hash_seed); + for (auto &e: aliasInfo.containedTypes()) { + auto contained_type_hash = std::hash()(e); + hash = c10::hash_combine(hash, contained_type_hash); + } + return hash; + } + }; +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/aten_interned_strings.h b/phivenv/Lib/site-packages/torch/include/ATen/core/aten_interned_strings.h new file mode 100644 index 0000000000000000000000000000000000000000..76839a71af32a2fc40b2f68cc573d1e6784e50b9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/aten_interned_strings.h @@ -0,0 +1,2294 @@ +#pragma once + +// @generated by torchgen/gen.py from aten_interned_strings.h + +#if defined(TORCH_ASSERT_NO_OPERATORS) || defined(TORCH_ASSERT_ONLY_METHOD_OPERATORS) +#error This change adds a dependency on native_functions.yaml, \ + meaning the file will need to be re-compiled every time an operator \ + is changed or added. Consider if including for \ + the c10::Symbol class would be sufficient, or if your change would be \ + better placed in another file. +#endif + +// ATen symbols correspond exactly to operators defined in ATen. Every +// symbol here corresponds exactly to an ATen operation defined in +// native_functions.yaml; attributes are in one-to-one correspondence +// with their ATen name. + +#define FORALL_ATEN_BASE_SYMBOLS(_) \ +_(aten, __and__) \ +_(aten, __iand__) \ +_(aten, __ilshift__) \ +_(aten, __ior__) \ +_(aten, __irshift__) \ +_(aten, __ixor__) \ +_(aten, __lshift__) \ +_(aten, __or__) \ +_(aten, __rshift__) \ +_(aten, __xor__) \ +_(aten, _adaptive_avg_pool2d) \ +_(aten, _adaptive_avg_pool2d_backward) \ +_(aten, _adaptive_avg_pool3d) \ +_(aten, _adaptive_avg_pool3d_backward) \ +_(aten, _add_batch_dim) \ +_(aten, _add_relu) \ +_(aten, _add_relu_) \ +_(aten, _addmm_activation) \ +_(aten, _aminmax) \ +_(aten, _amp_foreach_non_finite_check_and_unscale) \ +_(aten, _amp_foreach_non_finite_check_and_unscale_) \ +_(aten, _amp_update_scale) \ +_(aten, _amp_update_scale_) \ +_(aten, _assert_async) \ +_(aten, _assert_scalar) \ +_(aten, _assert_tensor_metadata) \ +_(aten, _autocast_to_full_precision) \ +_(aten, _autocast_to_reduced_precision) \ +_(aten, _backward) \ +_(aten, _batch_norm_impl_index) \ +_(aten, _batch_norm_impl_index_backward) \ +_(aten, _batch_norm_no_update) \ +_(aten, _batch_norm_with_update) \ +_(aten, _batch_norm_with_update_functional) \ +_(aten, _cast_Byte) \ +_(aten, _cast_Char) \ +_(aten, _cast_Double) \ +_(aten, _cast_Float) \ +_(aten, _cast_Half) \ +_(aten, _cast_Int) \ +_(aten, _cast_Long) \ +_(aten, _cast_Short) \ +_(aten, _cdist_backward) \ +_(aten, _cdist_forward) \ +_(aten, _cholesky_solve_helper) \ +_(aten, _choose_qparams_per_tensor) \ +_(aten, _chunk_cat) \ +_(aten, _coalesce) \ +_(aten, _coalesced) \ +_(aten, _coalesced_) \ +_(aten, _compute_linear_combination) \ +_(aten, _conj) \ +_(aten, _conj_copy) \ +_(aten, _conj_physical) \ +_(aten, _conv_depthwise2d) \ +_(aten, _convert_indices_from_coo_to_csr) \ +_(aten, _convert_indices_from_csr_to_coo) \ +_(aten, _convert_weight_to_int4pack) \ +_(aten, _convert_weight_to_int4pack_for_cpu) \ +_(aten, _convolution) \ +_(aten, _convolution_double_backward) \ +_(aten, _convolution_mode) \ +_(aten, _copy_from) \ +_(aten, _copy_from_and_resize) \ +_(aten, _cslt_compress) \ +_(aten, _cslt_sparse_mm) \ +_(aten, _cslt_sparse_mm_search) \ +_(aten, _ctc_loss) \ +_(aten, _ctc_loss_backward) \ +_(aten, _cudnn_attention_forward) \ +_(aten, _cudnn_ctc_loss) \ +_(aten, _cudnn_init_dropout_state) \ +_(aten, _cudnn_rnn) \ +_(aten, _cudnn_rnn_backward) \ +_(aten, _cudnn_rnn_flatten_weight) \ +_(aten, _cufft_clear_plan_cache) \ +_(aten, _cufft_get_plan_cache_max_size) \ +_(aten, _cufft_get_plan_cache_size) \ +_(aten, _cufft_set_plan_cache_max_size) \ +_(aten, _cummax_helper) \ +_(aten, _cummin_helper) \ +_(aten, _debug_has_internal_overlap) \ +_(aten, _dimI) \ +_(aten, _dimV) \ +_(aten, _dim_arange) \ +_(aten, _dirichlet_grad) \ +_(aten, _dyn_quant_matmul_4bit) \ +_(aten, _dyn_quant_pack_4bit_weight) \ +_(aten, _efficient_attention_backward) \ +_(aten, _efficient_attention_forward) \ +_(aten, _efficientzerotensor) \ +_(aten, _embedding_bag) \ +_(aten, _embedding_bag_backward) \ +_(aten, _embedding_bag_dense_backward) \ +_(aten, _embedding_bag_forward_only) \ +_(aten, _embedding_bag_per_sample_weights_backward) \ +_(aten, _embedding_bag_sparse_backward) \ +_(aten, _empty_affine_quantized) \ +_(aten, _empty_per_channel_affine_quantized) \ +_(aten, _euclidean_dist) \ +_(aten, _fake_quantize_learnable_per_channel_affine) \ +_(aten, _fake_quantize_learnable_per_channel_affine_backward) \ +_(aten, _fake_quantize_learnable_per_tensor_affine) \ +_(aten, _fake_quantize_learnable_per_tensor_affine_backward) \ +_(aten, _fake_quantize_per_tensor_affine_cachemask_tensor_qparams) \ +_(aten, _fft_c2c) \ +_(aten, _fft_c2r) \ +_(aten, _fft_r2c) \ +_(aten, _fill_mem_eff_dropout_mask) \ +_(aten, _fill_mem_eff_dropout_mask_) \ +_(aten, _flash_attention_backward) \ +_(aten, _flash_attention_forward) \ +_(aten, _foobar) \ +_(aten, _foreach_abs) \ +_(aten, _foreach_abs_) \ +_(aten, _foreach_acos) \ +_(aten, _foreach_acos_) \ +_(aten, _foreach_add) \ +_(aten, _foreach_add_) \ +_(aten, _foreach_addcdiv) \ +_(aten, _foreach_addcdiv_) \ +_(aten, _foreach_addcmul) \ +_(aten, _foreach_addcmul_) \ +_(aten, _foreach_asin) \ +_(aten, _foreach_asin_) \ +_(aten, _foreach_atan) \ +_(aten, _foreach_atan_) \ +_(aten, _foreach_ceil) \ +_(aten, _foreach_ceil_) \ +_(aten, _foreach_clamp_max) \ +_(aten, _foreach_clamp_max_) \ +_(aten, _foreach_clamp_min) \ +_(aten, _foreach_clamp_min_) \ +_(aten, _foreach_copy) \ +_(aten, _foreach_copy_) \ +_(aten, _foreach_cos) \ +_(aten, _foreach_cos_) \ +_(aten, _foreach_cosh) \ +_(aten, _foreach_cosh_) \ +_(aten, _foreach_div) \ +_(aten, _foreach_div_) \ +_(aten, _foreach_erf) \ +_(aten, _foreach_erf_) \ +_(aten, _foreach_erfc) \ +_(aten, _foreach_erfc_) \ +_(aten, _foreach_exp) \ +_(aten, _foreach_exp_) \ +_(aten, _foreach_expm1) \ +_(aten, _foreach_expm1_) \ +_(aten, _foreach_floor) \ +_(aten, _foreach_floor_) \ +_(aten, _foreach_frac) \ +_(aten, _foreach_frac_) \ +_(aten, _foreach_lerp) \ +_(aten, _foreach_lerp_) \ +_(aten, _foreach_lgamma) \ +_(aten, _foreach_lgamma_) \ +_(aten, _foreach_log) \ +_(aten, _foreach_log10) \ +_(aten, _foreach_log10_) \ +_(aten, _foreach_log1p) \ +_(aten, _foreach_log1p_) \ +_(aten, _foreach_log2) \ +_(aten, _foreach_log2_) \ +_(aten, _foreach_log_) \ +_(aten, _foreach_max) \ +_(aten, _foreach_maximum) \ +_(aten, _foreach_maximum_) \ +_(aten, _foreach_minimum) \ +_(aten, _foreach_minimum_) \ +_(aten, _foreach_mul) \ +_(aten, _foreach_mul_) \ +_(aten, _foreach_neg) \ +_(aten, _foreach_neg_) \ +_(aten, _foreach_norm) \ +_(aten, _foreach_pow) \ +_(aten, _foreach_pow_) \ +_(aten, _foreach_reciprocal) \ +_(aten, _foreach_reciprocal_) \ +_(aten, _foreach_round) \ +_(aten, _foreach_round_) \ +_(aten, _foreach_rsqrt) \ +_(aten, _foreach_rsqrt_) \ +_(aten, _foreach_sigmoid) \ +_(aten, _foreach_sigmoid_) \ +_(aten, _foreach_sign) \ +_(aten, _foreach_sign_) \ +_(aten, _foreach_sin) \ +_(aten, _foreach_sin_) \ +_(aten, _foreach_sinh) \ +_(aten, _foreach_sinh_) \ +_(aten, _foreach_sqrt) \ +_(aten, _foreach_sqrt_) \ +_(aten, _foreach_sub) \ +_(aten, _foreach_sub_) \ +_(aten, _foreach_tan) \ +_(aten, _foreach_tan_) \ +_(aten, _foreach_tanh) \ +_(aten, _foreach_tanh_) \ +_(aten, _foreach_trunc) \ +_(aten, _foreach_trunc_) \ +_(aten, _foreach_zero) \ +_(aten, _foreach_zero_) \ +_(aten, _functional_assert_async) \ +_(aten, _functional_assert_scalar) \ +_(aten, _functional_sym_constrain_range) \ +_(aten, _functional_sym_constrain_range_for_size) \ +_(aten, _fused_adagrad) \ +_(aten, _fused_adagrad_) \ +_(aten, _fused_adam) \ +_(aten, _fused_adam_) \ +_(aten, _fused_adamw) \ +_(aten, _fused_adamw_) \ +_(aten, _fused_dropout) \ +_(aten, _fused_moving_avg_obs_fq_helper) \ +_(aten, _fused_moving_avg_obs_fq_helper_functional) \ +_(aten, _fused_rms_norm) \ +_(aten, _fused_sdp_choice) \ +_(aten, _fused_sgd) \ +_(aten, _fused_sgd_) \ +_(aten, _fw_primal) \ +_(aten, _fw_primal_copy) \ +_(aten, _gather_sparse_backward) \ +_(aten, _grid_sampler_2d_cpu_fallback) \ +_(aten, _grid_sampler_2d_cpu_fallback_backward) \ +_(aten, _grouped_mm) \ +_(aten, _has_compatible_shallow_copy_type) \ +_(aten, _has_same_storage_numel) \ +_(aten, _histogramdd_bin_edges) \ +_(aten, _histogramdd_from_bin_cts) \ +_(aten, _histogramdd_from_bin_tensors) \ +_(aten, _index_put_impl) \ +_(aten, _index_put_impl_) \ +_(aten, _indices) \ +_(aten, _indices_copy) \ +_(aten, _int_mm) \ +_(aten, _is_all_true) \ +_(aten, _is_any_true) \ +_(aten, _is_zerotensor) \ +_(aten, _jagged_to_padded_dense_forward) \ +_(aten, _lazy_clone) \ +_(aten, _linalg_check_errors) \ +_(aten, _linalg_det) \ +_(aten, _linalg_eigh) \ +_(aten, _linalg_eigvals) \ +_(aten, _linalg_slogdet) \ +_(aten, _linalg_solve_ex) \ +_(aten, _linalg_svd) \ +_(aten, _local_scalar_dense) \ +_(aten, _log_softmax) \ +_(aten, _log_softmax_backward_data) \ +_(aten, _logcumsumexp) \ +_(aten, _lstm_mps) \ +_(aten, _lu_with_info) \ +_(aten, _make_dep_token) \ +_(aten, _make_dual) \ +_(aten, _make_dual_copy) \ +_(aten, _make_per_channel_quantized_tensor) \ +_(aten, _make_per_tensor_quantized_tensor) \ +_(aten, _masked_scale) \ +_(aten, _masked_softmax) \ +_(aten, _masked_softmax_backward) \ +_(aten, _mixed_dtypes_linear) \ +_(aten, _mkldnn_reshape) \ +_(aten, _mkldnn_transpose) \ +_(aten, _mkldnn_transpose_) \ +_(aten, _mps_convolution) \ +_(aten, _mps_convolution_transpose) \ +_(aten, _native_batch_norm_legit) \ +_(aten, _native_batch_norm_legit_functional) \ +_(aten, _native_batch_norm_legit_no_training) \ +_(aten, _native_multi_head_attention) \ +_(aten, _neg_view) \ +_(aten, _neg_view_copy) \ +_(aten, _nested_compute_contiguous_strides_offsets) \ +_(aten, _nested_from_padded) \ +_(aten, _nested_from_padded_and_nested_example) \ +_(aten, _nested_from_padded_tensor) \ +_(aten, _nested_get_jagged_dummy) \ +_(aten, _nested_get_lengths) \ +_(aten, _nested_get_max_seqlen) \ +_(aten, _nested_get_min_seqlen) \ +_(aten, _nested_get_offsets) \ +_(aten, _nested_get_ragged_idx) \ +_(aten, _nested_get_values) \ +_(aten, _nested_get_values_copy) \ +_(aten, _nested_select_backward) \ +_(aten, _nested_sum_backward) \ +_(aten, _nested_tensor_from_mask) \ +_(aten, _nested_tensor_from_mask_left_aligned) \ +_(aten, _nested_tensor_from_tensor_list) \ +_(aten, _nested_tensor_size) \ +_(aten, _nested_tensor_softmax_with_shape) \ +_(aten, _nested_tensor_storage_offsets) \ +_(aten, _nested_tensor_strides) \ +_(aten, _nested_view_from_buffer) \ +_(aten, _nested_view_from_buffer_copy) \ +_(aten, _nested_view_from_jagged) \ +_(aten, _nested_view_from_jagged_copy) \ +_(aten, _new_zeros_with_same_feature_meta) \ +_(aten, _nnpack_available) \ +_(aten, _nnpack_spatial_convolution) \ +_(aten, _nnz) \ +_(aten, _pack_padded_sequence) \ +_(aten, _pack_padded_sequence_backward) \ +_(aten, _pad_circular) \ +_(aten, _pad_enum) \ +_(aten, _pad_packed_sequence) \ +_(aten, _padded_dense_to_jagged_forward) \ +_(aten, _pdist_backward) \ +_(aten, _pdist_forward) \ +_(aten, _pin_memory) \ +_(aten, _prelu_kernel) \ +_(aten, _prelu_kernel_backward) \ +_(aten, _print) \ +_(aten, _propagate_xla_data) \ +_(aten, _remove_batch_dim) \ +_(aten, _reshape_alias) \ +_(aten, _reshape_alias_copy) \ +_(aten, _reshape_copy) \ +_(aten, _reshape_from_tensor) \ +_(aten, _resize_output) \ +_(aten, _resize_output_) \ +_(aten, _rowwise_prune) \ +_(aten, _safe_softmax) \ +_(aten, _sample_dirichlet) \ +_(aten, _saturate_weight_to_fp16) \ +_(aten, _scaled_dot_product_attention_math) \ +_(aten, _scaled_dot_product_attention_math_for_mps) \ +_(aten, _scaled_dot_product_cudnn_attention) \ +_(aten, _scaled_dot_product_cudnn_attention_backward) \ +_(aten, _scaled_dot_product_efficient_attention) \ +_(aten, _scaled_dot_product_efficient_attention_backward) \ +_(aten, _scaled_dot_product_flash_attention) \ +_(aten, _scaled_dot_product_flash_attention_backward) \ +_(aten, _scaled_dot_product_flash_attention_for_cpu) \ +_(aten, _scaled_dot_product_flash_attention_for_cpu_backward) \ +_(aten, _scaled_dot_product_fused_attention_overrideable) \ +_(aten, _scaled_dot_product_fused_attention_overrideable_backward) \ +_(aten, _scaled_grouped_mm) \ +_(aten, _scaled_mm) \ +_(aten, _segment_reduce_backward) \ +_(aten, _shape_as_tensor) \ +_(aten, _slow_conv2d_backward) \ +_(aten, _slow_conv2d_forward) \ +_(aten, _sobol_engine_draw) \ +_(aten, _sobol_engine_ff) \ +_(aten, _sobol_engine_ff_) \ +_(aten, _sobol_engine_initialize_state) \ +_(aten, _sobol_engine_initialize_state_) \ +_(aten, _sobol_engine_scramble) \ +_(aten, _sobol_engine_scramble_) \ +_(aten, _softmax) \ +_(aten, _softmax_backward_data) \ +_(aten, _sparse_addmm) \ +_(aten, _sparse_broadcast_to) \ +_(aten, _sparse_broadcast_to_copy) \ +_(aten, _sparse_bsc_tensor_unsafe) \ +_(aten, _sparse_bsr_tensor_unsafe) \ +_(aten, _sparse_compressed_tensor_unsafe) \ +_(aten, _sparse_compressed_tensor_with_dims) \ +_(aten, _sparse_coo_tensor_unsafe) \ +_(aten, _sparse_coo_tensor_with_dims) \ +_(aten, _sparse_coo_tensor_with_dims_and_tensors) \ +_(aten, _sparse_csc_tensor_unsafe) \ +_(aten, _sparse_csr_prod) \ +_(aten, _sparse_csr_sum) \ +_(aten, _sparse_csr_tensor_unsafe) \ +_(aten, _sparse_log_softmax) \ +_(aten, _sparse_log_softmax_backward_data) \ +_(aten, _sparse_mask_projection) \ +_(aten, _sparse_mm) \ +_(aten, _sparse_mm_reduce_impl) \ +_(aten, _sparse_mm_reduce_impl_backward) \ +_(aten, _sparse_semi_structured_addmm) \ +_(aten, _sparse_semi_structured_apply) \ +_(aten, _sparse_semi_structured_apply_dense) \ +_(aten, _sparse_semi_structured_linear) \ +_(aten, _sparse_semi_structured_mm) \ +_(aten, _sparse_semi_structured_tile) \ +_(aten, _sparse_softmax) \ +_(aten, _sparse_softmax_backward_data) \ +_(aten, _sparse_sparse_matmul) \ +_(aten, _sparse_sum) \ +_(aten, _sparse_sum_backward) \ +_(aten, _spdiags) \ +_(aten, _spsolve) \ +_(aten, _stack) \ +_(aten, _standard_gamma) \ +_(aten, _standard_gamma_grad) \ +_(aten, _test_ambiguous_defaults) \ +_(aten, _test_autograd_multiple_dispatch) \ +_(aten, _test_autograd_multiple_dispatch_view) \ +_(aten, _test_autograd_multiple_dispatch_view_copy) \ +_(aten, _test_check_tensor) \ +_(aten, _test_functorch_fallback) \ +_(aten, _test_optional_filled_intlist) \ +_(aten, _test_optional_floatlist) \ +_(aten, _test_optional_intlist) \ +_(aten, _test_parallel_materialize) \ +_(aten, _test_serialization_subcmul) \ +_(aten, _test_string_default) \ +_(aten, _test_warn_in_autograd) \ +_(aten, _thnn_differentiable_gru_cell_backward) \ +_(aten, _thnn_differentiable_lstm_cell_backward) \ +_(aten, _thnn_fused_gru_cell) \ +_(aten, _thnn_fused_gru_cell_backward) \ +_(aten, _thnn_fused_lstm_cell) \ +_(aten, _thnn_fused_lstm_cell_backward) \ +_(aten, _thnn_fused_lstm_cell_backward_impl) \ +_(aten, _to_copy) \ +_(aten, _to_cpu) \ +_(aten, _to_dense) \ +_(aten, _to_sparse) \ +_(aten, _to_sparse_bsc) \ +_(aten, _to_sparse_bsr) \ +_(aten, _to_sparse_csc) \ +_(aten, _to_sparse_csr) \ +_(aten, _to_sparse_semi_structured) \ +_(aten, _transform_bias_rescale_qkv) \ +_(aten, _transformer_encoder_layer_fwd) \ +_(aten, _trilinear) \ +_(aten, _triton_multi_head_attention) \ +_(aten, _triton_scaled_dot_attention) \ +_(aten, _unique) \ +_(aten, _unique2) \ +_(aten, _unpack_dual) \ +_(aten, _unsafe_index) \ +_(aten, _unsafe_index_put) \ +_(aten, _unsafe_masked_index) \ +_(aten, _unsafe_masked_index_put_accumulate) \ +_(aten, _unsafe_view) \ +_(aten, _upsample_bicubic2d_aa) \ +_(aten, _upsample_bicubic2d_aa_backward) \ +_(aten, _upsample_bilinear2d_aa) \ +_(aten, _upsample_bilinear2d_aa_backward) \ +_(aten, _upsample_nearest_exact1d) \ +_(aten, _upsample_nearest_exact1d_backward) \ +_(aten, _upsample_nearest_exact2d) \ +_(aten, _upsample_nearest_exact2d_backward) \ +_(aten, _upsample_nearest_exact3d) \ +_(aten, _upsample_nearest_exact3d_backward) \ +_(aten, _use_cudnn_ctc_loss) \ +_(aten, _use_cudnn_rnn_flatten_weight) \ +_(aten, _validate_compressed_sparse_indices) \ +_(aten, _validate_sparse_bsc_tensor_args) \ +_(aten, _validate_sparse_bsr_tensor_args) \ +_(aten, _validate_sparse_compressed_tensor_args) \ +_(aten, _validate_sparse_coo_tensor_args) \ +_(aten, _validate_sparse_csc_tensor_args) \ +_(aten, _validate_sparse_csr_tensor_args) \ +_(aten, _values) \ +_(aten, _values_copy) \ +_(aten, _version) \ +_(aten, _weight_int4pack_mm) \ +_(aten, _weight_int4pack_mm_for_cpu) \ +_(aten, _weight_int4pack_mm_with_scales_and_zeros) \ +_(aten, _weight_int8pack_mm) \ +_(aten, _weight_norm) \ +_(aten, _weight_norm_differentiable_backward) \ +_(aten, _weight_norm_interface) \ +_(aten, _weight_norm_interface_backward) \ +_(aten, _wrapped_linear_prepack) \ +_(aten, _wrapped_quantized_linear_prepacked) \ +_(aten, abs) \ +_(aten, abs_) \ +_(aten, absolute) \ +_(aten, absolute_) \ +_(aten, acos) \ +_(aten, acos_) \ +_(aten, acosh) \ +_(aten, acosh_) \ +_(aten, adaptive_avg_pool1d) \ +_(aten, adaptive_avg_pool2d) \ +_(aten, adaptive_avg_pool3d) \ +_(aten, adaptive_avg_pool3d_backward) \ +_(aten, adaptive_max_pool1d) \ +_(aten, adaptive_max_pool2d) \ +_(aten, adaptive_max_pool2d_backward) \ +_(aten, adaptive_max_pool3d) \ +_(aten, adaptive_max_pool3d_backward) \ +_(aten, add) \ +_(aten, add_) \ +_(aten, addbmm) \ +_(aten, addbmm_) \ +_(aten, addcdiv) \ +_(aten, addcdiv_) \ +_(aten, addcmul) \ +_(aten, addcmul_) \ +_(aten, addmm) \ +_(aten, addmm_) \ +_(aten, addmv) \ +_(aten, addmv_) \ +_(aten, addr) \ +_(aten, addr_) \ +_(aten, adjoint) \ +_(aten, affine_grid_generator) \ +_(aten, affine_grid_generator_backward) \ +_(aten, alias) \ +_(aten, alias_copy) \ +_(aten, align_as) \ +_(aten, align_tensors) \ +_(aten, align_to) \ +_(aten, all) \ +_(aten, allclose) \ +_(aten, alpha_dropout) \ +_(aten, alpha_dropout_) \ +_(aten, amax) \ +_(aten, amin) \ +_(aten, aminmax) \ +_(aten, angle) \ +_(aten, any) \ +_(aten, arange) \ +_(aten, arccos) \ +_(aten, arccos_) \ +_(aten, arccosh) \ +_(aten, arccosh_) \ +_(aten, arcsin) \ +_(aten, arcsin_) \ +_(aten, arcsinh) \ +_(aten, arcsinh_) \ +_(aten, arctan) \ +_(aten, arctan2) \ +_(aten, arctan2_) \ +_(aten, arctan_) \ +_(aten, arctanh) \ +_(aten, arctanh_) \ +_(aten, argmax) \ +_(aten, argmin) \ +_(aten, argsort) \ +_(aten, argwhere) \ +_(aten, as_strided) \ +_(aten, as_strided_) \ +_(aten, as_strided_copy) \ +_(aten, as_strided_scatter) \ +_(aten, asin) \ +_(aten, asin_) \ +_(aten, asinh) \ +_(aten, asinh_) \ +_(aten, atan) \ +_(aten, atan2) \ +_(aten, atan2_) \ +_(aten, atan_) \ +_(aten, atanh) \ +_(aten, atanh_) \ +_(aten, atleast_1d) \ +_(aten, atleast_2d) \ +_(aten, atleast_3d) \ +_(aten, avg_pool1d) \ +_(aten, avg_pool2d) \ +_(aten, avg_pool2d_backward) \ +_(aten, avg_pool3d) \ +_(aten, avg_pool3d_backward) \ +_(aten, baddbmm) \ +_(aten, baddbmm_) \ +_(aten, bartlett_window) \ +_(aten, batch_norm) \ +_(aten, batch_norm_backward) \ +_(aten, batch_norm_backward_elemt) \ +_(aten, batch_norm_backward_reduce) \ +_(aten, batch_norm_elemt) \ +_(aten, batch_norm_gather_stats) \ +_(aten, batch_norm_gather_stats_with_counts) \ +_(aten, batch_norm_stats) \ +_(aten, batch_norm_update_stats) \ +_(aten, bernoulli) \ +_(aten, bernoulli_) \ +_(aten, bilinear) \ +_(aten, binary_cross_entropy) \ +_(aten, binary_cross_entropy_backward) \ +_(aten, binary_cross_entropy_with_logits) \ +_(aten, bincount) \ +_(aten, binomial) \ +_(aten, bitwise_and) \ +_(aten, bitwise_and_) \ +_(aten, bitwise_left_shift) \ +_(aten, bitwise_left_shift_) \ +_(aten, bitwise_not) \ +_(aten, bitwise_not_) \ +_(aten, bitwise_or) \ +_(aten, bitwise_or_) \ +_(aten, bitwise_right_shift) \ +_(aten, bitwise_right_shift_) \ +_(aten, bitwise_xor) \ +_(aten, bitwise_xor_) \ +_(aten, blackman_window) \ +_(aten, block_diag) \ +_(aten, bmm) \ +_(aten, broadcast_tensors) \ +_(aten, broadcast_to) \ +_(aten, bucketize) \ +_(aten, can_cast) \ +_(aten, cartesian_prod) \ +_(aten, cat) \ +_(aten, cauchy) \ +_(aten, cauchy_) \ +_(aten, ccol_indices) \ +_(aten, ccol_indices_copy) \ +_(aten, cdist) \ +_(aten, ceil) \ +_(aten, ceil_) \ +_(aten, celu) \ +_(aten, celu_) \ +_(aten, chain_matmul) \ +_(aten, chalf) \ +_(aten, channel_shuffle) \ +_(aten, cholesky) \ +_(aten, cholesky_inverse) \ +_(aten, cholesky_solve) \ +_(aten, choose_qparams_optimized) \ +_(aten, chunk) \ +_(aten, clamp) \ +_(aten, clamp_) \ +_(aten, clamp_max) \ +_(aten, clamp_max_) \ +_(aten, clamp_min) \ +_(aten, clamp_min_) \ +_(aten, clip) \ +_(aten, clip_) \ +_(aten, clone) \ +_(aten, coalesce) \ +_(aten, col2im) \ +_(aten, col_indices) \ +_(aten, col_indices_copy) \ +_(aten, column_stack) \ +_(aten, combinations) \ +_(aten, complex) \ +_(aten, concat) \ +_(aten, concatenate) \ +_(aten, conj) \ +_(aten, conj_physical) \ +_(aten, conj_physical_) \ +_(aten, constant_pad_nd) \ +_(aten, contiguous) \ +_(aten, conv1d) \ +_(aten, conv2d) \ +_(aten, conv3d) \ +_(aten, conv_depthwise3d) \ +_(aten, conv_tbc) \ +_(aten, conv_tbc_backward) \ +_(aten, conv_transpose1d) \ +_(aten, conv_transpose2d) \ +_(aten, conv_transpose3d) \ +_(aten, convolution) \ +_(aten, convolution_backward) \ +_(aten, convolution_backward_overrideable) \ +_(aten, convolution_overrideable) \ +_(aten, copy) \ +_(aten, copy_) \ +_(aten, copy_sparse_to_sparse) \ +_(aten, copy_sparse_to_sparse_) \ +_(aten, copysign) \ +_(aten, copysign_) \ +_(aten, corrcoef) \ +_(aten, cos) \ +_(aten, cos_) \ +_(aten, cosh) \ +_(aten, cosh_) \ +_(aten, cosine_embedding_loss) \ +_(aten, cosine_similarity) \ +_(aten, count_nonzero) \ +_(aten, cov) \ +_(aten, cross) \ +_(aten, cross_entropy_loss) \ +_(aten, crow_indices) \ +_(aten, crow_indices_copy) \ +_(aten, ctc_loss) \ +_(aten, cudnn_affine_grid_generator) \ +_(aten, cudnn_affine_grid_generator_backward) \ +_(aten, cudnn_batch_norm) \ +_(aten, cudnn_batch_norm_backward) \ +_(aten, cudnn_convolution) \ +_(aten, cudnn_convolution_add_relu) \ +_(aten, cudnn_convolution_relu) \ +_(aten, cudnn_convolution_transpose) \ +_(aten, cudnn_grid_sampler) \ +_(aten, cudnn_grid_sampler_backward) \ +_(aten, cudnn_is_acceptable) \ +_(aten, cummax) \ +_(aten, cummaxmin_backward) \ +_(aten, cummin) \ +_(aten, cumprod) \ +_(aten, cumprod_) \ +_(aten, cumprod_backward) \ +_(aten, cumsum) \ +_(aten, cumsum_) \ +_(aten, cumulative_trapezoid) \ +_(aten, data) \ +_(aten, deg2rad) \ +_(aten, deg2rad_) \ +_(aten, dense_dim) \ +_(aten, dequantize) \ +_(aten, det) \ +_(aten, detach) \ +_(aten, detach_) \ +_(aten, detach_copy) \ +_(aten, diag) \ +_(aten, diag_embed) \ +_(aten, diagflat) \ +_(aten, diagonal) \ +_(aten, diagonal_backward) \ +_(aten, diagonal_copy) \ +_(aten, diagonal_scatter) \ +_(aten, diff) \ +_(aten, digamma) \ +_(aten, digamma_) \ +_(aten, dist) \ +_(aten, div) \ +_(aten, div_) \ +_(aten, divide) \ +_(aten, divide_) \ +_(aten, dot) \ +_(aten, dropout) \ +_(aten, dropout_) \ +_(aten, dsplit) \ +_(aten, dstack) \ +_(aten, einsum) \ +_(aten, elu) \ +_(aten, elu_) \ +_(aten, elu_backward) \ +_(aten, embedding) \ +_(aten, embedding_backward) \ +_(aten, embedding_bag) \ +_(aten, embedding_dense_backward) \ +_(aten, embedding_renorm) \ +_(aten, embedding_renorm_) \ +_(aten, embedding_sparse_backward) \ +_(aten, empty) \ +_(aten, empty_like) \ +_(aten, empty_permuted) \ +_(aten, empty_quantized) \ +_(aten, empty_strided) \ +_(aten, eq) \ +_(aten, eq_) \ +_(aten, equal) \ +_(aten, erf) \ +_(aten, erf_) \ +_(aten, erfc) \ +_(aten, erfc_) \ +_(aten, erfinv) \ +_(aten, erfinv_) \ +_(aten, exp) \ +_(aten, exp2) \ +_(aten, exp2_) \ +_(aten, exp_) \ +_(aten, expand) \ +_(aten, expand_as) \ +_(aten, expand_copy) \ +_(aten, expm1) \ +_(aten, expm1_) \ +_(aten, exponential) \ +_(aten, exponential_) \ +_(aten, eye) \ +_(aten, fake_quantize_per_channel_affine) \ +_(aten, fake_quantize_per_channel_affine_cachemask) \ +_(aten, fake_quantize_per_channel_affine_cachemask_backward) \ +_(aten, fake_quantize_per_tensor_affine) \ +_(aten, fake_quantize_per_tensor_affine_cachemask) \ +_(aten, fake_quantize_per_tensor_affine_cachemask_backward) \ +_(aten, fbgemm_linear_fp16_weight) \ +_(aten, fbgemm_linear_fp16_weight_fp32_activation) \ +_(aten, fbgemm_linear_int8_weight) \ +_(aten, fbgemm_linear_int8_weight_fp32_activation) \ +_(aten, fbgemm_linear_quantize_weight) \ +_(aten, fbgemm_pack_gemm_matrix_fp16) \ +_(aten, fbgemm_pack_quantized_matrix) \ +_(aten, feature_alpha_dropout) \ +_(aten, feature_alpha_dropout_) \ +_(aten, feature_dropout) \ +_(aten, feature_dropout_) \ +_(aten, fft_fft) \ +_(aten, fft_fft2) \ +_(aten, fft_fftfreq) \ +_(aten, fft_fftn) \ +_(aten, fft_fftshift) \ +_(aten, fft_hfft) \ +_(aten, fft_hfft2) \ +_(aten, fft_hfftn) \ +_(aten, fft_ifft) \ +_(aten, fft_ifft2) \ +_(aten, fft_ifftn) \ +_(aten, fft_ifftshift) \ +_(aten, fft_ihfft) \ +_(aten, fft_ihfft2) \ +_(aten, fft_ihfftn) \ +_(aten, fft_irfft) \ +_(aten, fft_irfft2) \ +_(aten, fft_irfftn) \ +_(aten, fft_rfft) \ +_(aten, fft_rfft2) \ +_(aten, fft_rfftfreq) \ +_(aten, fft_rfftn) \ +_(aten, fill) \ +_(aten, fill_) \ +_(aten, fill_diagonal) \ +_(aten, fill_diagonal_) \ +_(aten, fix) \ +_(aten, fix_) \ +_(aten, flatten) \ +_(aten, flatten_dense_tensors) \ +_(aten, flip) \ +_(aten, fliplr) \ +_(aten, flipud) \ +_(aten, float_power) \ +_(aten, float_power_) \ +_(aten, floor) \ +_(aten, floor_) \ +_(aten, floor_divide) \ +_(aten, floor_divide_) \ +_(aten, fmax) \ +_(aten, fmin) \ +_(aten, fmod) \ +_(aten, fmod_) \ +_(aten, frac) \ +_(aten, frac_) \ +_(aten, fractional_max_pool2d) \ +_(aten, fractional_max_pool2d_backward) \ +_(aten, fractional_max_pool3d) \ +_(aten, fractional_max_pool3d_backward) \ +_(aten, frexp) \ +_(aten, frobenius_norm) \ +_(aten, from_file) \ +_(aten, full) \ +_(aten, full_like) \ +_(aten, fused_moving_avg_obs_fake_quant) \ +_(aten, gather) \ +_(aten, gather_backward) \ +_(aten, gcd) \ +_(aten, gcd_) \ +_(aten, ge) \ +_(aten, ge_) \ +_(aten, gelu) \ +_(aten, gelu_) \ +_(aten, gelu_backward) \ +_(aten, geometric) \ +_(aten, geometric_) \ +_(aten, geqrf) \ +_(aten, ger) \ +_(aten, glu) \ +_(aten, glu_backward) \ +_(aten, glu_backward_jvp) \ +_(aten, glu_jvp) \ +_(aten, gradient) \ +_(aten, greater) \ +_(aten, greater_) \ +_(aten, greater_equal) \ +_(aten, greater_equal_) \ +_(aten, grid_sampler) \ +_(aten, grid_sampler_2d) \ +_(aten, grid_sampler_2d_backward) \ +_(aten, grid_sampler_3d) \ +_(aten, grid_sampler_3d_backward) \ +_(aten, group_norm) \ +_(aten, gru) \ +_(aten, gru_cell) \ +_(aten, gt) \ +_(aten, gt_) \ +_(aten, hamming_window) \ +_(aten, hann_window) \ +_(aten, hardshrink) \ +_(aten, hardshrink_backward) \ +_(aten, hardsigmoid) \ +_(aten, hardsigmoid_) \ +_(aten, hardsigmoid_backward) \ +_(aten, hardswish) \ +_(aten, hardswish_) \ +_(aten, hardswish_backward) \ +_(aten, hardtanh) \ +_(aten, hardtanh_) \ +_(aten, hardtanh_backward) \ +_(aten, heaviside) \ +_(aten, heaviside_) \ +_(aten, hinge_embedding_loss) \ +_(aten, histc) \ +_(aten, histogram) \ +_(aten, histogramdd) \ +_(aten, hsplit) \ +_(aten, hspmm) \ +_(aten, hstack) \ +_(aten, huber_loss) \ +_(aten, huber_loss_backward) \ +_(aten, hypot) \ +_(aten, hypot_) \ +_(aten, i0) \ +_(aten, i0_) \ +_(aten, igamma) \ +_(aten, igamma_) \ +_(aten, igammac) \ +_(aten, igammac_) \ +_(aten, im2col) \ +_(aten, imag) \ +_(aten, index) \ +_(aten, index_add) \ +_(aten, index_add_) \ +_(aten, index_copy) \ +_(aten, index_copy_) \ +_(aten, index_fill) \ +_(aten, index_fill_) \ +_(aten, index_put) \ +_(aten, index_put_) \ +_(aten, index_reduce) \ +_(aten, index_reduce_) \ +_(aten, index_select) \ +_(aten, index_select_backward) \ +_(aten, indices) \ +_(aten, indices_copy) \ +_(aten, infinitely_differentiable_gelu_backward) \ +_(aten, inner) \ +_(aten, instance_norm) \ +_(aten, int_repr) \ +_(aten, inverse) \ +_(aten, is_coalesced) \ +_(aten, is_complex) \ +_(aten, is_conj) \ +_(aten, is_distributed) \ +_(aten, is_floating_point) \ +_(aten, is_inference) \ +_(aten, is_leaf) \ +_(aten, is_neg) \ +_(aten, is_nonzero) \ +_(aten, is_pinned) \ +_(aten, is_same_size) \ +_(aten, is_set_to) \ +_(aten, is_signed) \ +_(aten, is_vulkan_available) \ +_(aten, isclose) \ +_(aten, isfinite) \ +_(aten, isin) \ +_(aten, isinf) \ +_(aten, isnan) \ +_(aten, isneginf) \ +_(aten, isposinf) \ +_(aten, isreal) \ +_(aten, istft) \ +_(aten, item) \ +_(aten, kaiser_window) \ +_(aten, kl_div) \ +_(aten, kron) \ +_(aten, kthvalue) \ +_(aten, l1_loss) \ +_(aten, layer_norm) \ +_(aten, lcm) \ +_(aten, lcm_) \ +_(aten, ldexp) \ +_(aten, ldexp_) \ +_(aten, le) \ +_(aten, le_) \ +_(aten, leaky_relu) \ +_(aten, leaky_relu_) \ +_(aten, leaky_relu_backward) \ +_(aten, lerp) \ +_(aten, lerp_) \ +_(aten, less) \ +_(aten, less_) \ +_(aten, less_equal) \ +_(aten, less_equal_) \ +_(aten, lgamma) \ +_(aten, lgamma_) \ +_(aten, lift) \ +_(aten, lift_fresh) \ +_(aten, lift_fresh_copy) \ +_(aten, linalg_cholesky) \ +_(aten, linalg_cholesky_ex) \ +_(aten, linalg_cond) \ +_(aten, linalg_cross) \ +_(aten, linalg_det) \ +_(aten, linalg_diagonal) \ +_(aten, linalg_eig) \ +_(aten, linalg_eigh) \ +_(aten, linalg_eigvals) \ +_(aten, linalg_eigvalsh) \ +_(aten, linalg_householder_product) \ +_(aten, linalg_inv) \ +_(aten, linalg_inv_ex) \ +_(aten, linalg_ldl_factor) \ +_(aten, linalg_ldl_factor_ex) \ +_(aten, linalg_ldl_solve) \ +_(aten, linalg_lstsq) \ +_(aten, linalg_lu) \ +_(aten, linalg_lu_factor) \ +_(aten, linalg_lu_factor_ex) \ +_(aten, linalg_lu_solve) \ +_(aten, linalg_matmul) \ +_(aten, linalg_matrix_exp) \ +_(aten, linalg_matrix_norm) \ +_(aten, linalg_matrix_power) \ +_(aten, linalg_matrix_rank) \ +_(aten, linalg_multi_dot) \ +_(aten, linalg_norm) \ +_(aten, linalg_pinv) \ +_(aten, linalg_qr) \ +_(aten, linalg_slogdet) \ +_(aten, linalg_solve) \ +_(aten, linalg_solve_ex) \ +_(aten, linalg_solve_triangular) \ +_(aten, linalg_svd) \ +_(aten, linalg_svdvals) \ +_(aten, linalg_tensorinv) \ +_(aten, linalg_tensorsolve) \ +_(aten, linalg_vander) \ +_(aten, linalg_vecdot) \ +_(aten, linalg_vector_norm) \ +_(aten, linear) \ +_(aten, linear_backward) \ +_(aten, linspace) \ +_(aten, log) \ +_(aten, log10) \ +_(aten, log10_) \ +_(aten, log1p) \ +_(aten, log1p_) \ +_(aten, log2) \ +_(aten, log2_) \ +_(aten, log_) \ +_(aten, log_normal) \ +_(aten, log_normal_) \ +_(aten, log_sigmoid) \ +_(aten, log_sigmoid_backward) \ +_(aten, log_sigmoid_forward) \ +_(aten, log_softmax) \ +_(aten, logaddexp) \ +_(aten, logaddexp2) \ +_(aten, logcumsumexp) \ +_(aten, logdet) \ +_(aten, logical_and) \ +_(aten, logical_and_) \ +_(aten, logical_not) \ +_(aten, logical_not_) \ +_(aten, logical_or) \ +_(aten, logical_or_) \ +_(aten, logical_xor) \ +_(aten, logical_xor_) \ +_(aten, logit) \ +_(aten, logit_) \ +_(aten, logit_backward) \ +_(aten, logspace) \ +_(aten, logsumexp) \ +_(aten, lshift) \ +_(aten, lstm) \ +_(aten, lstm_cell) \ +_(aten, lstm_mps_backward) \ +_(aten, lt) \ +_(aten, lt_) \ +_(aten, lu_solve) \ +_(aten, lu_unpack) \ +_(aten, mH) \ +_(aten, mT) \ +_(aten, margin_ranking_loss) \ +_(aten, masked_fill) \ +_(aten, masked_fill_) \ +_(aten, masked_scatter) \ +_(aten, masked_scatter_) \ +_(aten, masked_scatter_backward) \ +_(aten, masked_select) \ +_(aten, masked_select_backward) \ +_(aten, matmul) \ +_(aten, matmul_backward) \ +_(aten, matrix_H) \ +_(aten, matrix_exp) \ +_(aten, matrix_exp_backward) \ +_(aten, matrix_power) \ +_(aten, max) \ +_(aten, max_pool1d) \ +_(aten, max_pool1d_with_indices) \ +_(aten, max_pool2d) \ +_(aten, max_pool2d_backward) \ +_(aten, max_pool2d_with_indices) \ +_(aten, max_pool2d_with_indices_backward) \ +_(aten, max_pool3d) \ +_(aten, max_pool3d_with_indices) \ +_(aten, max_pool3d_with_indices_backward) \ +_(aten, max_unpool2d) \ +_(aten, max_unpool3d) \ +_(aten, maximum) \ +_(aten, mean) \ +_(aten, median) \ +_(aten, meshgrid) \ +_(aten, min) \ +_(aten, minimum) \ +_(aten, miopen_batch_norm) \ +_(aten, miopen_batch_norm_backward) \ +_(aten, miopen_convolution) \ +_(aten, miopen_convolution_add_relu) \ +_(aten, miopen_convolution_relu) \ +_(aten, miopen_convolution_transpose) \ +_(aten, miopen_depthwise_convolution) \ +_(aten, miopen_rnn) \ +_(aten, miopen_rnn_backward) \ +_(aten, mish) \ +_(aten, mish_) \ +_(aten, mish_backward) \ +_(aten, mkldnn_adaptive_avg_pool2d) \ +_(aten, mkldnn_adaptive_avg_pool2d_backward) \ +_(aten, mkldnn_convolution) \ +_(aten, mkldnn_linear) \ +_(aten, mkldnn_linear_backward) \ +_(aten, mkldnn_linear_backward_input) \ +_(aten, mkldnn_linear_backward_weights) \ +_(aten, mkldnn_max_pool2d) \ +_(aten, mkldnn_max_pool2d_backward) \ +_(aten, mkldnn_max_pool3d) \ +_(aten, mkldnn_max_pool3d_backward) \ +_(aten, mkldnn_reorder_conv2d_weight) \ +_(aten, mkldnn_reorder_conv3d_weight) \ +_(aten, mkldnn_rnn_layer) \ +_(aten, mkldnn_rnn_layer_backward) \ +_(aten, mm) \ +_(aten, mode) \ +_(aten, moveaxis) \ +_(aten, movedim) \ +_(aten, mps_convolution_backward) \ +_(aten, mps_convolution_transpose_backward) \ +_(aten, mse_loss) \ +_(aten, mse_loss_backward) \ +_(aten, msort) \ +_(aten, mul) \ +_(aten, mul_) \ +_(aten, multi_margin_loss) \ +_(aten, multi_margin_loss_backward) \ +_(aten, multilabel_margin_loss) \ +_(aten, multilabel_margin_loss_backward) \ +_(aten, multilabel_margin_loss_forward) \ +_(aten, multinomial) \ +_(aten, multiply) \ +_(aten, multiply_) \ +_(aten, mv) \ +_(aten, mvlgamma) \ +_(aten, mvlgamma_) \ +_(aten, nan_to_num) \ +_(aten, nan_to_num_) \ +_(aten, nanmean) \ +_(aten, nanmedian) \ +_(aten, nanquantile) \ +_(aten, nansum) \ +_(aten, narrow) \ +_(aten, narrow_copy) \ +_(aten, native_batch_norm) \ +_(aten, native_batch_norm_backward) \ +_(aten, native_channel_shuffle) \ +_(aten, native_dropout) \ +_(aten, native_dropout_backward) \ +_(aten, native_group_norm) \ +_(aten, native_group_norm_backward) \ +_(aten, native_layer_norm) \ +_(aten, native_layer_norm_backward) \ +_(aten, native_norm) \ +_(aten, ne) \ +_(aten, ne_) \ +_(aten, neg) \ +_(aten, neg_) \ +_(aten, negative) \ +_(aten, negative_) \ +_(aten, nested_to_padded_tensor) \ +_(aten, new_empty) \ +_(aten, new_empty_strided) \ +_(aten, new_full) \ +_(aten, new_ones) \ +_(aten, new_zeros) \ +_(aten, nextafter) \ +_(aten, nextafter_) \ +_(aten, nll_loss) \ +_(aten, nll_loss2d) \ +_(aten, nll_loss2d_backward) \ +_(aten, nll_loss2d_forward) \ +_(aten, nll_loss_backward) \ +_(aten, nll_loss_forward) \ +_(aten, nll_loss_nd) \ +_(aten, nonzero) \ +_(aten, nonzero_numpy) \ +_(aten, nonzero_static) \ +_(aten, norm) \ +_(aten, norm_except_dim) \ +_(aten, normal) \ +_(aten, normal_) \ +_(aten, normal_functional) \ +_(aten, not_equal) \ +_(aten, not_equal_) \ +_(aten, nuclear_norm) \ +_(aten, numpy_T) \ +_(aten, one_hot) \ +_(aten, ones) \ +_(aten, ones_like) \ +_(aten, orgqr) \ +_(aten, ormqr) \ +_(aten, outer) \ +_(aten, output_nr) \ +_(aten, pad) \ +_(aten, pad_sequence) \ +_(aten, pairwise_distance) \ +_(aten, pdist) \ +_(aten, permute) \ +_(aten, permute_copy) \ +_(aten, pin_memory) \ +_(aten, pinverse) \ +_(aten, pixel_shuffle) \ +_(aten, pixel_unshuffle) \ +_(aten, poisson) \ +_(aten, poisson_nll_loss) \ +_(aten, polar) \ +_(aten, polygamma) \ +_(aten, polygamma_) \ +_(aten, positive) \ +_(aten, pow) \ +_(aten, pow_) \ +_(aten, prelu) \ +_(aten, prod) \ +_(aten, promote_types) \ +_(aten, put) \ +_(aten, put_) \ +_(aten, q_per_channel_axis) \ +_(aten, q_per_channel_scales) \ +_(aten, q_per_channel_zero_points) \ +_(aten, q_scale) \ +_(aten, q_zero_point) \ +_(aten, qr) \ +_(aten, qscheme) \ +_(aten, quantile) \ +_(aten, quantize_per_channel) \ +_(aten, quantize_per_tensor) \ +_(aten, quantize_per_tensor_dynamic) \ +_(aten, quantized_batch_norm) \ +_(aten, quantized_gru_cell) \ +_(aten, quantized_lstm_cell) \ +_(aten, quantized_max_pool1d) \ +_(aten, quantized_max_pool2d) \ +_(aten, quantized_max_pool3d) \ +_(aten, quantized_rnn_relu_cell) \ +_(aten, quantized_rnn_tanh_cell) \ +_(aten, rad2deg) \ +_(aten, rad2deg_) \ +_(aten, rand) \ +_(aten, rand_like) \ +_(aten, randint) \ +_(aten, randint_like) \ +_(aten, randn) \ +_(aten, randn_like) \ +_(aten, random) \ +_(aten, random_) \ +_(aten, randperm) \ +_(aten, range) \ +_(aten, ravel) \ +_(aten, real) \ +_(aten, reciprocal) \ +_(aten, reciprocal_) \ +_(aten, record_stream) \ +_(aten, refine_names) \ +_(aten, reflection_pad1d) \ +_(aten, reflection_pad1d_backward) \ +_(aten, reflection_pad2d) \ +_(aten, reflection_pad2d_backward) \ +_(aten, reflection_pad3d) \ +_(aten, reflection_pad3d_backward) \ +_(aten, relu) \ +_(aten, relu6) \ +_(aten, relu6_) \ +_(aten, relu_) \ +_(aten, remainder) \ +_(aten, remainder_) \ +_(aten, rename) \ +_(aten, rename_) \ +_(aten, renorm) \ +_(aten, renorm_) \ +_(aten, repeat) \ +_(aten, repeat_interleave) \ +_(aten, replication_pad1d) \ +_(aten, replication_pad1d_backward) \ +_(aten, replication_pad2d) \ +_(aten, replication_pad2d_backward) \ +_(aten, replication_pad3d) \ +_(aten, replication_pad3d_backward) \ +_(aten, requires_grad) \ +_(aten, requires_grad_) \ +_(aten, reshape) \ +_(aten, reshape_as) \ +_(aten, resize) \ +_(aten, resize_) \ +_(aten, resize_as) \ +_(aten, resize_as_) \ +_(aten, resize_as_sparse) \ +_(aten, resize_as_sparse_) \ +_(aten, resolve_conj) \ +_(aten, resolve_neg) \ +_(aten, result_type) \ +_(aten, retain_grad) \ +_(aten, retains_grad) \ +_(aten, rms_norm) \ +_(aten, rnn_relu) \ +_(aten, rnn_relu_cell) \ +_(aten, rnn_tanh) \ +_(aten, rnn_tanh_cell) \ +_(aten, roll) \ +_(aten, rot90) \ +_(aten, round) \ +_(aten, round_) \ +_(aten, row_indices) \ +_(aten, row_indices_copy) \ +_(aten, row_stack) \ +_(aten, rrelu) \ +_(aten, rrelu_) \ +_(aten, rrelu_with_noise) \ +_(aten, rrelu_with_noise_) \ +_(aten, rrelu_with_noise_backward) \ +_(aten, rrelu_with_noise_functional) \ +_(aten, rshift) \ +_(aten, rsqrt) \ +_(aten, rsqrt_) \ +_(aten, rsub) \ +_(aten, scalar_tensor) \ +_(aten, scaled_dot_product_attention) \ +_(aten, scatter) \ +_(aten, scatter_) \ +_(aten, scatter_add) \ +_(aten, scatter_add_) \ +_(aten, scatter_reduce) \ +_(aten, scatter_reduce_) \ +_(aten, searchsorted) \ +_(aten, segment_reduce) \ +_(aten, select) \ +_(aten, select_backward) \ +_(aten, select_copy) \ +_(aten, select_scatter) \ +_(aten, selu) \ +_(aten, selu_) \ +_(aten, set) \ +_(aten, set_) \ +_(aten, set_data) \ +_(aten, sgn) \ +_(aten, sgn_) \ +_(aten, sigmoid) \ +_(aten, sigmoid_) \ +_(aten, sigmoid_backward) \ +_(aten, sign) \ +_(aten, sign_) \ +_(aten, signbit) \ +_(aten, silu) \ +_(aten, silu_) \ +_(aten, silu_backward) \ +_(aten, sin) \ +_(aten, sin_) \ +_(aten, sinc) \ +_(aten, sinc_) \ +_(aten, sinh) \ +_(aten, sinh_) \ +_(aten, size) \ +_(aten, slice) \ +_(aten, slice_backward) \ +_(aten, slice_copy) \ +_(aten, slice_inverse) \ +_(aten, slice_scatter) \ +_(aten, slogdet) \ +_(aten, slow_conv3d) \ +_(aten, slow_conv3d_forward) \ +_(aten, slow_conv_dilated2d) \ +_(aten, slow_conv_dilated3d) \ +_(aten, slow_conv_transpose2d) \ +_(aten, slow_conv_transpose3d) \ +_(aten, smm) \ +_(aten, smooth_l1_loss) \ +_(aten, smooth_l1_loss_backward) \ +_(aten, soft_margin_loss) \ +_(aten, soft_margin_loss_backward) \ +_(aten, softmax) \ +_(aten, softplus) \ +_(aten, softplus_backward) \ +_(aten, softshrink) \ +_(aten, softshrink_backward) \ +_(aten, sort) \ +_(aten, sparse_bsc_tensor) \ +_(aten, sparse_bsr_tensor) \ +_(aten, sparse_compressed_tensor) \ +_(aten, sparse_coo_tensor) \ +_(aten, sparse_csc_tensor) \ +_(aten, sparse_csr_tensor) \ +_(aten, sparse_dim) \ +_(aten, sparse_mask) \ +_(aten, sparse_resize) \ +_(aten, sparse_resize_) \ +_(aten, sparse_resize_and_clear) \ +_(aten, sparse_resize_and_clear_) \ +_(aten, sparse_sampled_addmm) \ +_(aten, special_airy_ai) \ +_(aten, special_bessel_j0) \ +_(aten, special_bessel_j1) \ +_(aten, special_bessel_y0) \ +_(aten, special_bessel_y1) \ +_(aten, special_chebyshev_polynomial_t) \ +_(aten, special_chebyshev_polynomial_u) \ +_(aten, special_chebyshev_polynomial_v) \ +_(aten, special_chebyshev_polynomial_w) \ +_(aten, special_digamma) \ +_(aten, special_entr) \ +_(aten, special_erf) \ +_(aten, special_erfc) \ +_(aten, special_erfcx) \ +_(aten, special_erfinv) \ +_(aten, special_exp2) \ +_(aten, special_expit) \ +_(aten, special_expm1) \ +_(aten, special_gammainc) \ +_(aten, special_gammaincc) \ +_(aten, special_gammaln) \ +_(aten, special_hermite_polynomial_h) \ +_(aten, special_hermite_polynomial_he) \ +_(aten, special_i0) \ +_(aten, special_i0e) \ +_(aten, special_i1) \ +_(aten, special_i1e) \ +_(aten, special_laguerre_polynomial_l) \ +_(aten, special_legendre_polynomial_p) \ +_(aten, special_log1p) \ +_(aten, special_log_ndtr) \ +_(aten, special_log_softmax) \ +_(aten, special_logit) \ +_(aten, special_logsumexp) \ +_(aten, special_modified_bessel_i0) \ +_(aten, special_modified_bessel_i1) \ +_(aten, special_modified_bessel_k0) \ +_(aten, special_modified_bessel_k1) \ +_(aten, special_multigammaln) \ +_(aten, special_ndtr) \ +_(aten, special_ndtri) \ +_(aten, special_polygamma) \ +_(aten, special_psi) \ +_(aten, special_round) \ +_(aten, special_scaled_modified_bessel_k0) \ +_(aten, special_scaled_modified_bessel_k1) \ +_(aten, special_shifted_chebyshev_polynomial_t) \ +_(aten, special_shifted_chebyshev_polynomial_u) \ +_(aten, special_shifted_chebyshev_polynomial_v) \ +_(aten, special_shifted_chebyshev_polynomial_w) \ +_(aten, special_sinc) \ +_(aten, special_softmax) \ +_(aten, special_spherical_bessel_j0) \ +_(aten, special_xlog1py) \ +_(aten, special_xlogy) \ +_(aten, special_zeta) \ +_(aten, split) \ +_(aten, split_copy) \ +_(aten, split_with_sizes) \ +_(aten, split_with_sizes_copy) \ +_(aten, sqrt) \ +_(aten, sqrt_) \ +_(aten, square) \ +_(aten, square_) \ +_(aten, squeeze) \ +_(aten, squeeze_) \ +_(aten, squeeze_copy) \ +_(aten, sspaddmm) \ +_(aten, stack) \ +_(aten, std) \ +_(aten, std_mean) \ +_(aten, stft) \ +_(aten, stride) \ +_(aten, sub) \ +_(aten, sub_) \ +_(aten, subtract) \ +_(aten, subtract_) \ +_(aten, sum) \ +_(aten, sum_to_size) \ +_(aten, svd) \ +_(aten, swapaxes) \ +_(aten, swapaxes_) \ +_(aten, swapdims) \ +_(aten, swapdims_) \ +_(aten, sym_constrain_range) \ +_(aten, sym_constrain_range_for_size) \ +_(aten, sym_numel) \ +_(aten, sym_size) \ +_(aten, sym_storage_offset) \ +_(aten, sym_stride) \ +_(aten, t) \ +_(aten, t_) \ +_(aten, t_copy) \ +_(aten, take) \ +_(aten, take_along_dim) \ +_(aten, tan) \ +_(aten, tan_) \ +_(aten, tanh) \ +_(aten, tanh_) \ +_(aten, tanh_backward) \ +_(aten, tensor_split) \ +_(aten, tensordot) \ +_(aten, thnn_conv2d) \ +_(aten, threshold) \ +_(aten, threshold_) \ +_(aten, threshold_backward) \ +_(aten, tile) \ +_(aten, to) \ +_(aten, to_dense) \ +_(aten, to_dense_backward) \ +_(aten, to_mkldnn) \ +_(aten, to_mkldnn_backward) \ +_(aten, to_padded_tensor) \ +_(aten, to_sparse) \ +_(aten, to_sparse_bsc) \ +_(aten, to_sparse_bsr) \ +_(aten, to_sparse_csc) \ +_(aten, to_sparse_csr) \ +_(aten, topk) \ +_(aten, trace) \ +_(aten, trace_backward) \ +_(aten, transpose) \ +_(aten, transpose_) \ +_(aten, transpose_copy) \ +_(aten, trapezoid) \ +_(aten, trapz) \ +_(aten, triangular_solve) \ +_(aten, tril) \ +_(aten, tril_) \ +_(aten, tril_indices) \ +_(aten, triplet_margin_loss) \ +_(aten, triu) \ +_(aten, triu_) \ +_(aten, triu_indices) \ +_(aten, true_divide) \ +_(aten, true_divide_) \ +_(aten, trunc) \ +_(aten, trunc_) \ +_(aten, type_as) \ +_(aten, unbind) \ +_(aten, unbind_copy) \ +_(aten, unflatten) \ +_(aten, unflatten_dense_tensors) \ +_(aten, unfold) \ +_(aten, unfold_backward) \ +_(aten, unfold_copy) \ +_(aten, uniform) \ +_(aten, uniform_) \ +_(aten, unique_consecutive) \ +_(aten, unique_dim) \ +_(aten, unique_dim_consecutive) \ +_(aten, unsafe_chunk) \ +_(aten, unsafe_split) \ +_(aten, unsafe_split_with_sizes) \ +_(aten, unsqueeze) \ +_(aten, unsqueeze_) \ +_(aten, unsqueeze_copy) \ +_(aten, upsample_bicubic2d) \ +_(aten, upsample_bicubic2d_backward) \ +_(aten, upsample_bilinear2d) \ +_(aten, upsample_bilinear2d_backward) \ +_(aten, upsample_linear1d) \ +_(aten, upsample_linear1d_backward) \ +_(aten, upsample_nearest1d) \ +_(aten, upsample_nearest1d_backward) \ +_(aten, upsample_nearest2d) \ +_(aten, upsample_nearest2d_backward) \ +_(aten, upsample_nearest3d) \ +_(aten, upsample_nearest3d_backward) \ +_(aten, upsample_trilinear3d) \ +_(aten, upsample_trilinear3d_backward) \ +_(aten, value_selecting_reduction_backward) \ +_(aten, values) \ +_(aten, values_copy) \ +_(aten, vander) \ +_(aten, var) \ +_(aten, var_mean) \ +_(aten, vdot) \ +_(aten, view) \ +_(aten, view_as) \ +_(aten, view_as_complex) \ +_(aten, view_as_complex_copy) \ +_(aten, view_as_real) \ +_(aten, view_as_real_copy) \ +_(aten, view_copy) \ +_(aten, vsplit) \ +_(aten, vstack) \ +_(aten, where) \ +_(aten, xlogy) \ +_(aten, xlogy_) \ +_(aten, zero) \ +_(aten, zero_) \ +_(aten, zeros) \ +_(aten, zeros_like) + +#define FORALL_ATTR_BASE_SYMBOLS(_) \ +_(attr, A) \ +_(attr, B) \ +_(attr, C) \ +_(attr, H) \ +_(attr, HxW) \ +_(attr, K) \ +_(attr, L) \ +_(attr, LD) \ +_(attr, LU) \ +_(attr, LU_data) \ +_(attr, LU_pivots) \ +_(attr, M) \ +_(attr, N) \ +_(attr, P) \ +_(attr, Q) \ +_(attr, R) \ +_(attr, S) \ +_(attr, U) \ +_(attr, UPLO) \ +_(attr, V) \ +_(attr, Vh) \ +_(attr, W) \ +_(attr, X) \ +_(attr, a) \ +_(attr, abs) \ +_(attr, accumulate) \ +_(attr, accumulate_matches) \ +_(attr, activation) \ +_(attr, addends) \ +_(attr, adjoint) \ +_(attr, alg_id) \ +_(attr, algorithm) \ +_(attr, alibi_slopes) \ +_(attr, align_corners) \ +_(attr, align_to_window) \ +_(attr, allow_tf32) \ +_(attr, alpha) \ +_(attr, amsgrad) \ +_(attr, anchor) \ +_(attr, angle) \ +_(attr, any) \ +_(attr, api_name) \ +_(attr, append) \ +_(attr, approximate) \ +_(attr, arg1) \ +_(attr, arg2) \ +_(attr, arg3) \ +_(attr, arg_out) \ +_(attr, assert_msg) \ +_(attr, assume_unique) \ +_(attr, atol) \ +_(attr, attn_bias) \ +_(attr, attn_mask) \ +_(attr, average_attn_weights) \ +_(attr, averaging_const) \ +_(attr, aweights) \ +_(attr, axis) \ +_(attr, axis0) \ +_(attr, axis1) \ +_(attr, b) \ +_(attr, b_hh) \ +_(attr, b_ih) \ +_(attr, bag_size) \ +_(attr, base) \ +_(attr, batch1) \ +_(attr, batch2) \ +_(attr, batch_dim) \ +_(attr, batch_first) \ +_(attr, batch_size) \ +_(attr, batch_sizes) \ +_(attr, benchmark) \ +_(attr, beta) \ +_(attr, beta1) \ +_(attr, beta2) \ +_(attr, bias) \ +_(attr, bias_defined) \ +_(attr, bias_g) \ +_(attr, bias_requires_grad) \ +_(attr, bias_sizes) \ +_(attr, bidirectional) \ +_(attr, bin_edges) \ +_(attr, bins) \ +_(attr, bit_width) \ +_(attr, blank) \ +_(attr, block_size) \ +_(attr, blocksize) \ +_(attr, boundaries) \ +_(attr, buffer) \ +_(attr, ccol_indices) \ +_(attr, cdim) \ +_(attr, cdist) \ +_(attr, ceil_mode) \ +_(attr, cell_state_fwd) \ +_(attr, center) \ +_(attr, ch_axis) \ +_(attr, check_errors) \ +_(attr, check_pinning) \ +_(attr, chunks) \ +_(attr, coalesced) \ +_(attr, coefficients) \ +_(attr, col) \ +_(attr, col_indices) \ +_(attr, col_offsets) \ +_(attr, col_offsets_hh) \ +_(attr, col_offsets_ih) \ +_(attr, compressed_A) \ +_(attr, compressed_idx) \ +_(attr, compressed_indices) \ +_(attr, compressed_indices_dtype) \ +_(attr, compute_log_sumexp) \ +_(attr, compute_mode) \ +_(attr, compute_uv) \ +_(attr, compute_v) \ +_(attr, condition) \ +_(attr, copy) \ +_(attr, correction) \ +_(attr, count) \ +_(attr, count_include_pad) \ +_(attr, counts) \ +_(attr, cpu_dtype) \ +_(attr, cpu_enabled) \ +_(attr, cpu_nested_shape_example) \ +_(attr, create_graph) \ +_(attr, crow_indices) \ +_(attr, cu_seqlens_k) \ +_(attr, cu_seqlens_q) \ +_(attr, cuda_dtype) \ +_(attr, cuda_enabled) \ +_(attr, cudnn_enable) \ +_(attr, cudnn_enabled) \ +_(attr, cum_seq_k) \ +_(attr, cum_seq_q) \ +_(attr, custom_mask_type) \ +_(attr, cx) \ +_(attr, cx_) \ +_(attr, cx_tmp) \ +_(attr, cy) \ +_(attr, cy_) \ +_(attr, d) \ +_(attr, dampening) \ +_(attr, data) \ +_(attr, decimals) \ +_(attr, delta) \ +_(attr, dense) \ +_(attr, dense_B) \ +_(attr, dense_dim) \ +_(attr, density) \ +_(attr, dep_token) \ +_(attr, descending) \ +_(attr, destination) \ +_(attr, deterministic) \ +_(attr, device) \ +_(attr, device_index) \ +_(attr, dgrad_glu) \ +_(attr, diagonal) \ +_(attr, diagonals) \ +_(attr, dilation) \ +_(attr, dim) \ +_(attr, dim0) \ +_(attr, dim1) \ +_(attr, dim2) \ +_(attr, dimension) \ +_(attr, dims) \ +_(attr, dims_other) \ +_(attr, dims_self) \ +_(attr, divisor_override) \ +_(attr, downscale_factor) \ +_(attr, driver) \ +_(attr, dropout) \ +_(attr, dropout_mask) \ +_(attr, dropout_p) \ +_(attr, dropout_seed) \ +_(attr, dropout_state) \ +_(attr, dst) \ +_(attr, dtype) \ +_(attr, dual) \ +_(attr, dummy) \ +_(attr, dx) \ +_(attr, edge_order) \ +_(attr, eigenvalues) \ +_(attr, eigenvectors) \ +_(attr, eigvals) \ +_(attr, eigvecs) \ +_(attr, element) \ +_(attr, elements) \ +_(attr, ellipsis_idx) \ +_(attr, embed_dim) \ +_(attr, enable_gqa) \ +_(attr, end) \ +_(attr, end_dim) \ +_(attr, eps) \ +_(attr, epsilon) \ +_(attr, equal_nan) \ +_(attr, equation) \ +_(attr, exp_avg_sqs) \ +_(attr, exp_avgs) \ +_(attr, expand1) \ +_(attr, expand2) \ +_(attr, expand3) \ +_(attr, exponent) \ +_(attr, exponential_average_factor) \ +_(attr, fake_quant_enabled) \ +_(attr, fake_quant_on) \ +_(attr, ffn_bias_1) \ +_(attr, ffn_bias_2) \ +_(attr, ffn_weight_1) \ +_(attr, ffn_weight_2) \ +_(attr, filename) \ +_(attr, fill) \ +_(attr, fill_value) \ +_(attr, flat) \ +_(attr, forward) \ +_(attr, found_inf) \ +_(attr, from) \ +_(attr, from_) \ +_(attr, full) \ +_(attr, full_matrices) \ +_(attr, fuse_transform_0213) \ +_(attr, fweights) \ +_(attr, g) \ +_(attr, gO) \ +_(attr, generator) \ +_(attr, ggI) \ +_(attr, ggW) \ +_(attr, ggb) \ +_(attr, glu) \ +_(attr, grad) \ +_(attr, grad_bias) \ +_(attr, grad_cy) \ +_(attr, grad_factor) \ +_(attr, grad_glu) \ +_(attr, grad_hy) \ +_(attr, grad_in) \ +_(attr, grad_input) \ +_(attr, grad_input_mask) \ +_(attr, grad_out) \ +_(attr, grad_out_) \ +_(attr, grad_output) \ +_(attr, grad_scale) \ +_(attr, grad_w) \ +_(attr, grad_weight) \ +_(attr, grad_x) \ +_(attr, grad_y) \ +_(attr, gradient) \ +_(attr, grads) \ +_(attr, grid) \ +_(attr, group) \ +_(attr, groups) \ +_(attr, growth_interval) \ +_(attr, growth_tracker) \ +_(attr, half_to_float) \ +_(attr, has_bias) \ +_(attr, has_biases) \ +_(attr, hermitian) \ +_(attr, hidden_bias) \ +_(attr, hidden_gates) \ +_(attr, hidden_size) \ +_(attr, high) \ +_(attr, hist) \ +_(attr, hop_length) \ +_(attr, hx) \ +_(attr, hx_) \ +_(attr, hy_) \ +_(attr, i1) \ +_(attr, i2) \ +_(attr, i3) \ +_(attr, ignore_index) \ +_(attr, imag) \ +_(attr, impl_index) \ +_(attr, implicit) \ +_(attr, in_features) \ +_(attr, include_last_offset) \ +_(attr, include_self) \ +_(attr, increasing) \ +_(attr, ind) \ +_(attr, index) \ +_(attr, index_dtype) \ +_(attr, indexing) \ +_(attr, indices) \ +_(attr, info) \ +_(attr, initial) \ +_(attr, innerKTiles) \ +_(attr, inp) \ +_(attr, input) \ +_(attr, input1) \ +_(attr, input2) \ +_(attr, input3) \ +_(attr, input_bias) \ +_(attr, input_dtype) \ +_(attr, input_g) \ +_(attr, input_gates) \ +_(attr, input_lengths) \ +_(attr, input_scale) \ +_(attr, input_size) \ +_(attr, input_sizes) \ +_(attr, input_zero_point) \ +_(attr, inputs) \ +_(attr, interpolation) \ +_(attr, interpolation_mode) \ +_(attr, inv_scale) \ +_(attr, inverse) \ +_(attr, invert) \ +_(attr, invstd) \ +_(attr, is_causal) \ +_(attr, is_coalesced) \ +_(attr, is_crow) \ +_(attr, is_first_step) \ +_(attr, is_matrix) \ +_(attr, is_result) \ +_(attr, is_target) \ +_(attr, k) \ +_(attr, keepdim) \ +_(attr, kernel_size) \ +_(attr, key) \ +_(attr, label_smoothing) \ +_(attr, lambd) \ +_(attr, largest) \ +_(attr, last_dim_size) \ +_(attr, layersOutputs) \ +_(attr, layout) \ +_(attr, left) \ +_(attr, length) \ +_(attr, lengths) \ +_(attr, level) \ +_(attr, like) \ +_(attr, list) \ +_(attr, log_alpha) \ +_(attr, log_input) \ +_(attr, log_probs) \ +_(attr, log_target) \ +_(attr, logabsdet) \ +_(attr, logsumexp) \ +_(attr, low) \ +_(attr, lower) \ +_(attr, lr) \ +_(attr, lr_decay) \ +_(attr, ltm) \ +_(attr, m) \ +_(attr, mantissa) \ +_(attr, margin) \ +_(attr, mask) \ +_(attr, mask_check) \ +_(attr, mask_type) \ +_(attr, masked_grad) \ +_(attr, mat) \ +_(attr, mat1) \ +_(attr, mat1_meta) \ +_(attr, mat2) \ +_(attr, matrices) \ +_(attr, max) \ +_(attr, max_exp_avg_sqs) \ +_(attr, max_k) \ +_(attr, max_lengths) \ +_(attr, max_norm) \ +_(attr, max_q) \ +_(attr, max_seqlen) \ +_(attr, max_seqlen_k) \ +_(attr, max_seqlen_q) \ +_(attr, max_size) \ +_(attr, max_val) \ +_(attr, max_values) \ +_(attr, maximize) \ +_(attr, maximum_indices) \ +_(attr, maxnorm) \ +_(attr, mean) \ +_(attr, median) \ +_(attr, memory_format) \ +_(attr, meta) \ +_(attr, min) \ +_(attr, min_indices) \ +_(attr, min_seqlen) \ +_(attr, min_val) \ +_(attr, minlength) \ +_(attr, mode) \ +_(attr, momentum) \ +_(attr, momentum_buffer_list) \ +_(attr, n) \ +_(attr, n_bins) \ +_(attr, n_fft) \ +_(attr, names) \ +_(attr, nan) \ +_(attr, need_weights) \ +_(attr, neg_log_likelihood) \ +_(attr, negative) \ +_(attr, negative_slope) \ +_(attr, neginf) \ +_(attr, nested_size) \ +_(attr, nested_strides) \ +_(attr, nesterov) \ +_(attr, new_data) \ +_(attr, nnz) \ +_(attr, noise) \ +_(attr, non_blocking) \ +_(attr, norm) \ +_(attr, norm_bias_1) \ +_(attr, norm_bias_2) \ +_(attr, norm_first) \ +_(attr, norm_type) \ +_(attr, norm_weight_1) \ +_(attr, norm_weight_2) \ +_(attr, normalization) \ +_(attr, normalized) \ +_(attr, normalized_shape) \ +_(attr, normalized_shape_ndim) \ +_(attr, nt_example) \ +_(attr, num_chunks) \ +_(attr, num_classes) \ +_(attr, num_generated) \ +_(attr, num_groups) \ +_(attr, num_head) \ +_(attr, num_heads) \ +_(attr, num_layers) \ +_(attr, num_parallel) \ +_(attr, num_samples) \ +_(attr, num_splits_key) \ +_(attr, num_weights) \ +_(attr, numel) \ +_(attr, observer_on) \ +_(attr, offs) \ +_(attr, offset) \ +_(attr, offset2bag) \ +_(attr, offsets) \ +_(attr, onesided) \ +_(attr, ord) \ +_(attr, order) \ +_(attr, other) \ +_(attr, out) \ +_(attr, out0) \ +_(attr, out1) \ +_(attr, out2) \ +_(attr, out3) \ +_(attr, out4) \ +_(attr, out5) \ +_(attr, out6) \ +_(attr, out_channel) \ +_(attr, out_dim) \ +_(attr, out_dtype) \ +_(attr, out_features) \ +_(attr, out_int32) \ +_(attr, outdim) \ +_(attr, output) \ +_(attr, output_mask) \ +_(attr, output_padding) \ +_(attr, output_scale) \ +_(attr, output_size) \ +_(attr, output_zero_point) \ +_(attr, p) \ +_(attr, packed) \ +_(attr, packed_hh) \ +_(attr, packed_ih) \ +_(attr, packed_weight) \ +_(attr, packed_weights) \ +_(attr, pad) \ +_(attr, pad_mode) \ +_(attr, padded) \ +_(attr, padding) \ +_(attr, padding_idx) \ +_(attr, padding_mode) \ +_(attr, padding_side) \ +_(attr, padding_value) \ +_(attr, params) \ +_(attr, path) \ +_(attr, pdist) \ +_(attr, per_row_fake_quant) \ +_(attr, per_sample_weights) \ +_(attr, periodic) \ +_(attr, philox_offset) \ +_(attr, philox_seed) \ +_(attr, physical_layout) \ +_(attr, pin_memory) \ +_(attr, pivot) \ +_(attr, pivots) \ +_(attr, plain_idx) \ +_(attr, plain_indices) \ +_(attr, pos_weight) \ +_(attr, posinf) \ +_(attr, positive) \ +_(attr, pow) \ +_(attr, prepend) \ +_(attr, primal) \ +_(attr, prob) \ +_(attr, proj_bias) \ +_(attr, proj_size) \ +_(attr, proj_weight) \ +_(attr, q) \ +_(attr, qGroupSize) \ +_(attr, qScale) \ +_(attr, qScaleAndZeros) \ +_(attr, qZeros) \ +_(attr, qkv) \ +_(attr, qkv_bias) \ +_(attr, qkv_weight) \ +_(attr, qtensor) \ +_(attr, quant_max) \ +_(attr, quant_min) \ +_(attr, quasi) \ +_(attr, query) \ +_(attr, r) \ +_(attr, ragged_idx) \ +_(attr, random_samples) \ +_(attr, range) \ +_(attr, rank) \ +_(attr, ratio) \ +_(attr, rcond) \ +_(attr, real) \ +_(attr, reduce) \ +_(attr, reduce_range) \ +_(attr, reduction) \ +_(attr, repeats) \ +_(attr, replacement) \ +_(attr, requires_grad) \ +_(attr, reserve) \ +_(attr, reserveSpace) \ +_(attr, reservedSpace) \ +_(attr, residuals) \ +_(attr, result) \ +_(attr, retain_graph) \ +_(attr, return_complex) \ +_(attr, return_counts) \ +_(attr, return_debug_mask) \ +_(attr, return_inverse) \ +_(attr, reverse) \ +_(attr, right) \ +_(attr, rng_state) \ +_(attr, rounding_mode) \ +_(attr, row) \ +_(attr, row_indices) \ +_(attr, rstd) \ +_(attr, rtol) \ +_(attr, running_max) \ +_(attr, running_mean) \ +_(attr, running_min) \ +_(attr, running_var) \ +_(attr, s) \ +_(attr, save_invstd) \ +_(attr, save_mean) \ +_(attr, save_var) \ +_(attr, save_var_transform) \ +_(attr, saved_g) \ +_(attr, saved_norms) \ +_(attr, saved_v) \ +_(attr, scalar) \ +_(attr, scalar1) \ +_(attr, scalar2) \ +_(attr, scalars) \ +_(attr, scale) \ +_(attr, scale_a) \ +_(attr, scale_b) \ +_(attr, scale_backoff_factor) \ +_(attr, scale_factors) \ +_(attr, scale_grad_by_freq) \ +_(attr, scale_growth_factor) \ +_(attr, scale_hh) \ +_(attr, scale_ih) \ +_(attr, scale_result) \ +_(attr, scales) \ +_(attr, scales_d) \ +_(attr, scales_h) \ +_(attr, scales_w) \ +_(attr, scales_zeros) \ +_(attr, sections) \ +_(attr, seed) \ +_(attr, self) \ +_(attr, self_is_result) \ +_(attr, self_num_batch_dims) \ +_(attr, self_or_result) \ +_(attr, self_sizes) \ +_(attr, seqlen_k) \ +_(attr, sequences) \ +_(attr, seqused_k) \ +_(attr, shape) \ +_(attr, shared) \ +_(attr, shared_storage_dqdkdv) \ +_(attr, shifts) \ +_(attr, side) \ +_(attr, sigma) \ +_(attr, sign) \ +_(attr, singular_values) \ +_(attr, size) \ +_(attr, sizes) \ +_(attr, skip_first) \ +_(attr, sobolstate) \ +_(attr, solution) \ +_(attr, some) \ +_(attr, sorted) \ +_(attr, sorted_sequence) \ +_(attr, sorter) \ +_(attr, source) \ +_(attr, spacing) \ +_(attr, sparse) \ +_(attr, sparse_dim) \ +_(attr, sparse_grad) \ +_(attr, split_k) \ +_(attr, split_k_mode) \ +_(attr, split_size) \ +_(attr, split_sizes) \ +_(attr, src) \ +_(attr, stable) \ +_(attr, start) \ +_(attr, start_dim) \ +_(attr, state_steps) \ +_(attr, state_sums) \ +_(attr, std) \ +_(attr, step) \ +_(attr, steps) \ +_(attr, storage_offset) \ +_(attr, stride) \ +_(attr, sum_S) \ +_(attr, sum_dy) \ +_(attr, sum_dy_xmu) \ +_(attr, sumdim) \ +_(attr, swap) \ +_(attr, symmetric_quant) \ +_(attr, t) \ +_(attr, tangent) \ +_(attr, target) \ +_(attr, target_lengths) \ +_(attr, targets) \ +_(attr, tau) \ +_(attr, tensor) \ +_(attr, tensor1) \ +_(attr, tensor2) \ +_(attr, tensor_indices_or_sections) \ +_(attr, tensors) \ +_(attr, tensors1) \ +_(attr, test_element) \ +_(attr, test_elements) \ +_(attr, the_template) \ +_(attr, theta) \ +_(attr, thread_masks) \ +_(attr, threshold) \ +_(attr, to) \ +_(attr, tol) \ +_(attr, total) \ +_(attr, total_L) \ +_(attr, total_length) \ +_(attr, total_weight) \ +_(attr, train) \ +_(attr, training) \ +_(attr, transpose) \ +_(attr, transpose_result) \ +_(attr, transposed) \ +_(attr, type1) \ +_(attr, type2) \ +_(attr, unbiased) \ +_(attr, unitriangular) \ +_(attr, unpack_data) \ +_(attr, unpack_pivots) \ +_(attr, unroll_dim) \ +_(attr, unsafe) \ +_(attr, unused) \ +_(attr, update) \ +_(attr, upper) \ +_(attr, upscale_factor) \ +_(attr, use_cutlass) \ +_(attr, use_fast_accum) \ +_(attr, use_gelu) \ +_(attr, use_input_stats) \ +_(attr, v) \ +_(attr, value) \ +_(attr, values) \ +_(attr, var) \ +_(attr, vec) \ +_(attr, vec1) \ +_(attr, vec2) \ +_(attr, w_hh) \ +_(attr, w_ih) \ +_(attr, weight) \ +_(attr, weight0) \ +_(attr, weight1) \ +_(attr, weight2) \ +_(attr, weight3) \ +_(attr, weight4) \ +_(attr, weight_arr) \ +_(attr, weight_buf) \ +_(attr, weight_decay) \ +_(attr, weight_g) \ +_(attr, weight_scale) \ +_(attr, weight_stride0) \ +_(attr, weight_zero_point) \ +_(attr, weights) \ +_(attr, win_length) \ +_(attr, window) \ +_(attr, window_length) \ +_(attr, window_size) \ +_(attr, window_size_left) \ +_(attr, window_size_right) \ +_(attr, with_replacement) \ +_(attr, workspace) \ +_(attr, wrap) \ +_(attr, x) \ +_(attr, x1) \ +_(attr, x2) \ +_(attr, y) \ +_(attr, z) \ +_(attr, z_state) \ +_(attr, zero_infinity) \ +_(attr, zero_point) \ +_(attr, zero_point_hh) \ +_(attr, zero_point_ih) \ +_(attr, zero_points) diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/blob.h b/phivenv/Lib/site-packages/torch/include/ATen/core/blob.h new file mode 100644 index 0000000000000000000000000000000000000000..75487bd1350f100a33a7731fe8c1467e148aadcb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/blob.h @@ -0,0 +1,204 @@ +#pragma once + +#include + +#include +#include +#include + +namespace caffe2 { + +class Tensor; + +/** + * @brief Blob is a general container that hosts a typed pointer. + * + * A Blob hosts a pointer as well as its type, and takes charge of deleting it + * properly when the blob is deallocated or re-allocated with a new type. A blob + * could contain anything, although the most common case is to contain a Tensor. + */ +class TORCH_API Blob final : public c10::intrusive_ptr_target { + public: + /** + * Initializes an empty Blob. + */ + Blob() noexcept = default; + ~Blob() override { + Reset(); + } + + Blob(Blob&& other) noexcept : Blob() { + swap(other); + } + + Blob& operator=(Blob&& other) noexcept { + Blob(std::move(other)).swap(*this); + return *this; + } + + /** + * Checks if the content stored in the blob is of type T. + */ + template + bool IsType() const noexcept { + return meta_.Match(); + } + + /** + * Returns the meta info of the blob. + */ + const TypeMeta meta() const noexcept { + return meta_; + } + + /** + * Returns a printable typename of the blob. + */ + std::string_view TypeName() const noexcept { + return meta_.name(); + } + + /** + * @brief Gets the const reference of the stored object. The code checks if + * the stored object is of the desired type. + */ + // TODO(jerryzh): add a Get(c10::DeviceType) function? + template + const T& Get() const { + TORCH_INTERNAL_ASSERT( + IsType(), + "wrong type for the Blob instance. Blob contains ", + meta_.name(), + " while caller expects ", + TypeMeta::TypeName()); + // TODO: after we add Get(c10::DeviceType) + // and changed all the callsites, we can add + // a static assert here to enforce T != Tensor + return *static_cast(pointer_); + } + + const void* GetRaw() const noexcept { + return pointer_; + } + void* GetRaw() noexcept { + return pointer_; + } + + /** + * @brief Gets a mutable pointer to the stored object. + * + * If the current object is not of the right type, a new object is created + * and the old object is freed. Note that type T should have a default + * constructor. Otherwise, create the object yourself first, and use + * Reset(). + */ + template + T* GetMutable() { + static_assert( + std::is_default_constructible_v, + "GetMutable can't be called with non-default-constructible types. " + "Try using specialized methods"); + if (IsType()) { + return static_cast(pointer_); + } else { + // TODO Re-enable logging + // VLOG(1) << "Create new mutable object " << TypeMeta::TypeName(); + return Reset(new T()); + } + } + + template + T* GetMutableOrNull() { + if (IsType()) { + return static_cast(pointer_); + } else { + return nullptr; + } + } + + /** + * Sets the underlying object to the allocated one. The Blob then takes over + * the ownership of the passed in pointer. If there is already an object in + * the Blob, the old object is freed. + * + * This is used when the underlying class T does not have a default ctor, or + * complex initializations needs to be done outside the blob. + */ + template + T* Reset(T* allocated) { + free_(); + meta_ = TypeMeta::Make(); + pointer_ = static_cast(allocated); + has_ownership_ = true; + return allocated; + } + + /** + * Sets the underlying object to the allocated one, but does not take over + * the ownership of the passed in pointer. If there is already an object in + * the Blob, the old object is freed. + * + * Unlike Reset, this does not take over the ownership of the pointer and the + * caller is responsible for making sure that the lifetime of the allocated + * blob outlasts the lifetime of any access to this blob, until another Reset + * call is made or the blob is destructed. + */ + template + std::remove_const_t* ShareExternal( + std::remove_const_t* allocated) { + return static_cast(ShareExternal( + static_cast(allocated), + TypeMeta::Make>())); + } + + void* ShareExternal(void* allocated, const TypeMeta meta) { + free_(); + meta_ = meta; + pointer_ = allocated; + has_ownership_ = false; + return allocated; + } + + /** + * Resets the Blob to an empty one. + */ + void Reset() { + free_(); + pointer_ = nullptr; + meta_ = TypeMeta(); + has_ownership_ = false; + } + + /** + * @brief Swaps the underlying storage of two blobs. + */ + void swap(Blob& rhs) noexcept { + using std::swap; + swap(meta_, rhs.meta_); + swap(pointer_, rhs.pointer_); + swap(has_ownership_, rhs.has_ownership_); + } + + private: + void free_() { + if (has_ownership_ && pointer_ != nullptr) { + (*meta_.deleteFn())(pointer_); + } + } + + TypeMeta meta_; + void* pointer_{nullptr}; + bool has_ownership_{false}; + + C10_DISABLE_COPY_AND_ASSIGN(Blob); +}; + +inline void swap(Blob& lhs, Blob& rhs) noexcept { + lhs.swap(rhs); +} + +inline std::ostream& operator<<(std::ostream& out, const Blob& v) { + return out << "Blob[" << v.TypeName() << "]"; +} + +} // namespace caffe2 diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/builtin_function.h b/phivenv/Lib/site-packages/torch/include/ATen/core/builtin_function.h new file mode 100644 index 0000000000000000000000000000000000000000..1ecd19d3895e30010d60a721d31a23cbbb42b7ee --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/builtin_function.h @@ -0,0 +1,90 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace torch::jit { + +struct BuiltinOpFunction : public Function { + BuiltinOpFunction( + c10::QualifiedName qualname, + c10::FunctionSchema schema, + std::function callable, + std::string doc_string = "") + : name_(std::move(qualname)), + callable_(std::move(callable)), + schema_(std::move(schema)), + doc_string_(std::move(doc_string)) { + TORCH_INTERNAL_ASSERT(schema_.returns().size() == 1); + } + + std::string_view doc_string() const override { + return doc_string_; + } + + void run(Stack& stack) override { + callable_(stack); + } + + c10::intrusive_ptr runAsync( + Stack& stack, + TaskLauncher /* not used */) override { + run(stack); + auto res = c10::make_intrusive(stack.front().type()); + res->markCompleted(std::move(stack.front())); + return res; + } + + const c10::QualifiedName& qualname() const override { + return name_; + } + + // if this isn't yet defined, run its method_creator function + void ensure_defined() override { + // nop + } + + const c10::FunctionSchema& getSchema() const override { + return schema_; + } + + size_t num_inputs() const override { + return schema_.arguments().size(); + } + + Function& setSchema(c10::FunctionSchema schema) override { + schema_ = std::move(schema); + return *this; + } + + bool call( + Stack& stack, + std::optional, + c10::function_ref) override { + run(stack); + return false; + } + + bool call(Stack& stack, c10::function_ref) + override { + run(stack); + return false; + } + + ~BuiltinOpFunction() override = default; + + private: + c10::QualifiedName name_; + + std::function callable_; + + c10::FunctionSchema schema_; + + std::string doc_string_; +}; + +} // namespace torch::jit diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/class_type.h b/phivenv/Lib/site-packages/torch/include/ATen/core/class_type.h new file mode 100644 index 0000000000000000000000000000000000000000..34d94d4015ed41b185cd060eb0ef19f3f942e038 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/class_type.h @@ -0,0 +1,441 @@ +#pragma once + +#include + +#include +#include +#include + + +namespace torch::jit { +struct CompilationUnit; +struct Function; +} // namespace torch::jit + + +namespace c10 { + +struct FunctionSchema; + +// This enumerator represents the 'kind' of an attribute - a buffer, a parameter, or neither. +// This state is mutually exclusive. Buffers and Parameters can only appear on modules. +enum class AttributeKind { + BUFFER, + PARAMETER, + REGULAR_ATTRIBUTE +}; + +// This structure represents all notional booking entities in a class attribute: name, kind (see: AttributeKind), and type (see: TypePtr). +// Note: This structure does not represent the value of the attribute. +struct TORCH_API ClassAttribute { + public: + ClassAttribute(AttributeKind kind, + TypePtr attributeType, + std::string attributeName) : + kind_(kind), + attributeType_(std::move(attributeType)), + attributeName_(std::move(attributeName)) {} + + AttributeKind getKind() const { + return kind_; + } + + const TypePtr& getType() const { + return attributeType_; + } + + const std::string& getName() const { + return attributeName_; + } + + private: + AttributeKind kind_; + TypePtr attributeType_; + std::string attributeName_; +}; + +/** + * User Defined Types + */ + +struct ClassType; +using ClassTypePtr = std::shared_ptr; +using ::torch::jit::CompilationUnit; + +// This represents a class in TorchScript. +struct TORCH_API ClassType : public NamedType { + // This represents an attribute of a class; a name associated with an attribute, and a + // getter and (optional) setter for that attribute. + struct Property { + std::string name; + torch::jit::Function* getter; + torch::jit::Function* setter; + }; + + // Create a class type with name `name` and its methods stored in `cu`. + static ClassTypePtr create( + std::optional qualifiedName, + std::weak_ptr cu, + bool is_module = false, + std::string doc_string = "", + std::vector unresolved_class_attributes = {}); + + bool equals(const Type& rhs) const override { + if (this == &rhs) { + return true; + } + if (auto user_rhs = rhs.castRaw()) { + const auto& lhs_name = name(); + const auto& rhs_name = user_rhs->name(); + return lhs_name.has_value() && lhs_name == rhs_name && + this->compilation_unit() == user_rhs->compilation_unit(); + } + return false; + } + + std::string str() const override { + return annotation_str(); + } + + std::string repr_str() const override { + std::stringstream ss; + ss << str() + << " (of Python compilation unit at: " << compilation_unit().get() << ")"; + return ss.str(); + } + + const std::vector& methods() const; + + TypePtr findAttribute(const std::string& name) const { + size_t pos = 0; + for (const auto& attr : attributes_) { + if (name == attr.getName()) { + break; + } + ++pos; + } + + if (pos >= attributes_.size()) { + return nullptr; + } + return attributes_[pos].getType(); + } + + const TypePtr& getAttribute(const std::string& name) const { + auto slot = findAttributeSlot(name); + TORCH_CHECK( + slot, + repr_str(), + " does not have an attribute with name '", + name, + "'"); + return attributes_[*slot].getType(); + } + + size_t numAttributes() const { + return attributes_.size(); + } + + const TypePtr& getAttribute(size_t slot) const { + AT_ASSERT(slot < attributes_.size()); + return attributes_.at(slot).getType(); + } + + const std::string getAttributeName(size_t slot) const { + AT_ASSERT(slot < attributes_.size()); + return attributes_[slot].getName(); + } + + void checkNotExist(const std::string& name, const std::string& what) const; + + // Attributes are stored in a specific slot at runtime for effiency. + // When emitting instructions we specify the slot so that attribute access is + // a constant lookup + std::optional findAttributeSlot(const std::string& name) const { + size_t slot = 0; + for (const auto& attr : attributes_) { + if (name == attr.getName()) { + return slot; + } + slot++; + } + return std::nullopt; + } + size_t getAttributeSlot(const std::string& name) const { + if (auto r = findAttributeSlot(name)) { + return *r; + } + TORCH_CHECK( + false, + repr_str(), + " does not have an attribute with name '", + name, + "'"); + } + + bool hasAttribute(const std::string& name) const { + return std::find_if( + attributes_.cbegin(), + attributes_.cend(), + [&](const ClassAttribute& attr) { return attr.getName() == name; }) != + attributes_.cend(); + } + + bool isUnresolvedClassAttribute(const std::string& name) const; + + at::ArrayRef containedTypes() const override { + return attributeTypes_; + } + + size_t addAttribute( + const std::string& name, + TypePtr type, + bool is_parameter = false, + bool is_buffer = false); + + // [Internal Only] Remove attribute from the ClassType, + // caller is responsible to make sure the modification is safe: + // it is unsafe to having existing allocations + // of this object around anymore, and any code that works on + // the attribute is now invalid. Only newly created code is + // valid again. + void unsafeRemoveAttribute(const std::string& name); + + // [Internal Only] Change the type of an attribute of the ClassType, + // The caller is responsible to make sure the modification is safe: + // it is unsafe to maintain uses of the old type of the attribute, + // and any code that works on the attribute is now invalid. + // Only newly created code is valid again. + void unsafeChangeAttributeType(const std::string& name, const TypePtr& new_ty); + + // Add attribute \p NAME if it doesn't exist or verify that it has a + // compatible type otherwise. + size_t addOrCheckAttribute( + const std::string& name, + TypePtr ty, + bool is_parameter = false, + bool is_buffer = false) { + auto slot_idx = findAttributeSlot(name); + if (!slot_idx) { + return addAttribute(name, std::move(ty), is_parameter, is_buffer); + } + + TORCH_CHECK( + is_parameter == this->is_parameter(*slot_idx), + "Parameter field mismatch for the field '", + name, + "'"); + const TypePtr& atype = getAttribute(*slot_idx); + TORCH_CHECK( + ty->isSubtypeOf(*atype), + ty->repr_str(), + " is not compatible with the type ", + atype->repr_str(), + " for the field '", + name, + "'"); + return *slot_idx; + } + + // Get the property with the given \p name, if it exists on the class. + std::optional getProperty(const std::string& name); + // Add a property named \p name with \p getter and \p setter as its getter and setter. + void addProperty(const std::string& name, torch::jit::Function* getter, torch::jit::Function* setter); + // Get a list of all properties. + const std::vector& properties() const { + return properties_; + } + + bool hasConstant(const std::string& name) const { + return std::find_if( + constantNames_.cbegin(), + constantNames_.cend(), + [&](const std::string& constant) { return constant == name; }) != + constantNames_.cend(); + } + + size_t addConstant(const std::string& name, const IValue& value); + + std::optional findConstantSlot(const std::string& name) const; + + size_t getConstantSlot(const std::string& name) const { + if (auto r = findConstantSlot(name)) { + return *r; + } + TORCH_CHECK( + false, + repr_str(), + " does not have constant field with the name '", + name, + "'"); + } + + const std::string& getConstantName(size_t slot) const; + + const std::string& doc_string() const { + return doc_string_; + } + + IValue getConstant(const std::string& name) const; + + IValue getConstant(size_t slot) const; + + std::optional findConstant(const std::string& name) const; + + size_t numConstants() const; + + at::ArrayRef constantNames() const { + return constantNames_; + } + + at::ArrayRef constantValues() const; + + // [Internal Only] Remove constant from the ClassType + // caller is responsible to make sure the modification is safe: + // it is unsafe to having existing allocations + // of this object around anymore, and any code that works on + // the attribute is now invalid. Only newly created code is + // valid again. + void unsafeRemoveConstant(const std::string& name); + + TypePtr createWithContained(std::vector contained_types) const override { + auto ptr = ClassType::create(name(), compilation_unit_, is_module()); + AT_ASSERT(numAttributes() == contained_types.size()); + for(size_t i = 0; i < attributes_.size(); ++i) { + AT_ASSERT(attributes_[i].getType()->isSubtypeOf(*contained_types[i])); + ptr->addAttribute(attributes_[i].getName(), std::move(contained_types[i])); + } + // Copy methods over + for (const auto& method : methods()) { + ptr->addMethod(method); + } + return ptr; + } + + bool is_module() const override { + return isModule_; + } + + const std::vector& getAttributes() const { + return attributes_; + } + + bool is_parameter(size_t slot) const { + TORCH_INTERNAL_ASSERT( + is_module(), "asking for parameterSlots of non-Module"); + return attributes_.at(slot).getKind() == AttributeKind::PARAMETER; + } + + bool is_buffer(size_t slot) const { + TORCH_INTERNAL_ASSERT( + is_module(), "asking for bufferWrittenSlots of non-Module"); + return attributes_.at(slot).getKind() == AttributeKind::BUFFER; + } + + void addForwardPreHook(torch::jit::Function* pre_hook_ptr); + void addForwardHook(torch::jit::Function* hook_ptr); + torch::jit::Function* findForwardPreHook(const std::string& name) const; + torch::jit::Function* findForwardHook(const std::string& name) const; + const std::vector& getForwardHooks() const; + const std::vector& getForwardPreHooks() const; + + void checkForwardPreHookSchema( + size_t pre_hook_idx, + const FunctionSchema& pre_hook_schema) const; + void checkForwardHookSchema( + size_t hook_idx, + const FunctionSchema& hook_schema) const; + + void addMethod(torch::jit::Function* method); + torch::jit::Function* findMethod(const std::string& name) const; + torch::jit::Function& getMethod(const std::string& name) const; + torch::jit::Function* findHook(const std::string& name) const; + torch::jit::Function& getHook(const std::string& name) const; + bool hasMethod(const std::string& name) const; + + torch::jit::Function* findStaticMethod(const std::string& name) const; + void addStaticMethod(torch::jit::Function* method); + + // [Internal Only] Remove method from the ClassType + // caller is responsible to make sure the modification is safe: + // it is unsafe to having existing allocations + // of this object around anymore, and any code that works on + // the attribute is now invalid. Only newly created code is + // valid again. + // Note this method is intended for freezing only. + void unsafeRemoveMethod(const std::string& name); + + std::shared_ptr compilation_unit(); + + std::shared_ptr compilation_unit() const; + + // generate a refined version of this class. + // It has the same name but the slot Types are subtypes of + // the original slots. It is only valid to refine a class type in a context + // where it is know that there are not assignments to the objects slots + // that would invalidate the refinement. + // These variants are not registered in the global class table. + ClassTypePtr refine(at::ArrayRef refined_slots) const; + + bool isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const override; + + static const TypeKind Kind = TypeKind::ClassType; + + private: + ClassType( + std::optional name, + std::weak_ptr cu, + bool is_module = false, + std::string doc_string = "", + std::vector unresolved_class_attributes = {}); + + std::string annotation_str_impl( + [[maybe_unused]] const TypePrinter& printer = nullptr) const override { + // NOLINTNEXTLINE(bugprone-unchecked-optional-access) + return name()->qualifiedName(); + } + + void addAttribute(ClassAttribute classAttribute); + std::string getForwardPreHookErrorMessage(size_t pre_hook_idx) const; + std::string getForwardHookErrorMessage(size_t hook_idx) const; + + // Mapping of attribute names -> their type. + // NOTE: this does not contain methods, which are stored in the module + // TODO: once modules support arbitrary ivalue attributes, we don't need this + // anymore. + // TODO: This is better represented as an OrderedDict, but alas it is not yet + // available from c10 + + // Mapping of constant names -> their value. + std::vector constantNames_; + std::vector constantValues_; + // Holds method attributes + std::weak_ptr compilation_unit_; + + // Holds all atrributes, attribute details are found on ClassAttribute + std::vector attributes_; + // Construct mirroring attributes_, only around due to the fact that `containedTypes()` method returns an ArrayRef. + // Never fill this without using the appropriate provideNewClassAttribute method + std::vector attributeTypes_; + + // List of methods associated with this class. + std::vector methods_; + std::vector staticmethods_; + + // List of hooks to be run before/after forward. + std::vector forward_hooks_; + std::vector forward_pre_hooks_; + + // List of properties exposed by this class. + std::vector properties_; + + bool isModule_ = false; + + // Doc string of class. + std::string doc_string_; + + // For error reporting accesses to class level attributes. + std::vector unresolved_class_attributes_; +}; + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/core/custom_class.h b/phivenv/Lib/site-packages/torch/include/ATen/core/custom_class.h new file mode 100644 index 0000000000000000000000000000000000000000..601af3eb48c1222edfb07a5f98b3d02f4e4c5a57 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/core/custom_class.h @@ -0,0 +1,28 @@ +#pragma once + +#include +#include + +#include +#include +#include + +namespace c10 { + +struct ClassType; +using ClassTypePtr = std::shared_ptr; + +TORCH_API c10::ClassTypePtr getCustomClassTypeImpl(const std::type_index &tindex); + +template +const c10::ClassTypePtr& getCustomClassType() { + // Classes are never unregistered from getCustomClassTypeMap and the + // hash lookup can be a hot path, so just cache. + // For the same reason, it's fine If this ends up getting duplicated across + // DSO boundaries for whatever reason. + static c10::ClassTypePtr cache = getCustomClassTypeImpl( + std::type_index(typeid(T))); + return cache; +} + +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/cpp_custom_type_hack.h b/phivenv/Lib/site-packages/torch/include/ATen/cpp_custom_type_hack.h new file mode 100644 index 0000000000000000000000000000000000000000..e9e4e3e677d16b3001188f678ef2b985319b8405 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/cpp_custom_type_hack.h @@ -0,0 +1,110 @@ +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP + +// YOU ARE IN THE WRONG PLACE! TURN BACK NOW! + +// This code was a temporary hack to enable embedding arbitrary C++ structures +// into Tensors. THIS IS UNSAFE AND IS NOT SUPPORTED. IF YOU USE THIS CODE, +// IT __WILL__ BREAK. + +// This code has been superseded by custom classes: +// https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html + +// Please use custom classes and **DO NOT ADD MORE CALLSITES TO THINGS DEFINED +// IN THIS FILE**. + +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP +// STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP STOP + +#include +#include + +#ifndef AT_PER_OPERATOR_HEADERS +#include +#else +#include +#endif + +namespace at::cpp_custom_type_hack { + +template +[[deprecated( + "Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] bool +isa(const Tensor& packed) { + return (packed.scalar_type() == kByte) && + (packed.storage().data_ptr().get_deleter() == + caffe2::TypeMeta::Make().deleteFn()); +} + +template +[[deprecated( + "Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] T& +cast(const Tensor& packed) { + TORCH_CHECK( + packed.scalar_type() == kByte, "Expected temporary cpp type wrapper"); + TORCH_CHECK( + packed.storage().data_ptr().get_deleter() == + caffe2::TypeMeta::Make().deleteFn(), + "Expected temporary cpp type wrapper of type ", + caffe2::TypeMeta::TypeName()); + return *reinterpret_cast(packed.storage().data_ptr().get()); +} + +template +[[deprecated( + "Use custom classes instead: " + "https://pytorch.org/tutorials/advanced/torch_script_custom_classes.html")]] Tensor +create(std::unique_ptr ptr, TensorOptions options) { + // None of this should trace, so turn off Tracer dispatching + at::AutoDispatchBelowADInplaceOrView guard; // TODO: remove + at::tracer::impl::NoTracerDispatchMode tracer_guard; + + // We store this instance away in a Tensor and register a deleter function + // so that we do not leak memory. On the other side, we pull out the storage's + // data_ptr and get the right typed pointer. + void* raw_ptr = ptr.release(); + at::DataPtr at_ptr( + raw_ptr, raw_ptr, caffe2::TypeMeta::Make().deleteFn(), at::kCPU); + + // size doesn't really matter, but we can align it to the actual size + // returning variables because one likely want to use this hack from python + auto retval = at::empty({sizeof(T)}, options.device(kCPU).dtype(at::kByte)); + retval.storage().set_data_ptr_noswap(std::move(at_ptr)); + return retval; +} + +} // namespace at::cpp_custom_type_hack diff --git a/phivenv/Lib/site-packages/torch/include/ATen/div_rtn.h b/phivenv/Lib/site-packages/torch/include/ATen/div_rtn.h new file mode 100644 index 0000000000000000000000000000000000000000..4a6d088b798c2ac96e58107db224a35ba5c9e8c8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/div_rtn.h @@ -0,0 +1,11 @@ +#pragma once + +// Integer division rounding to -Infinity +template +static inline T div_rtn(T x, T y) { + int q = x / y; + int r = x % y; + if ((r != 0) && ((r < 0) != (y < 0))) + --q; + return q; +} diff --git a/phivenv/Lib/site-packages/torch/include/ATen/dlpack.h b/phivenv/Lib/site-packages/torch/include/ATen/dlpack.h new file mode 100644 index 0000000000000000000000000000000000000000..eb4f39db2b2fb4d8df1d9104b084e1fd4c4540f2 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/dlpack.h @@ -0,0 +1,236 @@ +/*! + * Copyright (c) 2017 by Contributors + * \file dlpack.h + * \brief The common header of DLPack. + */ +#ifndef DLPACK_DLPACK_H_ +#define DLPACK_DLPACK_H_ + +/** + * \brief Compatibility with C++ + */ +#ifdef __cplusplus +#define DLPACK_EXTERN_C extern "C" +#else +#define DLPACK_EXTERN_C +#endif + +/*! \brief The current version of dlpack */ +#define DLPACK_VERSION 80 + +/*! \brief The current ABI version of dlpack */ +#define DLPACK_ABI_VERSION 1 + +/*! \brief DLPACK_DLL prefix for windows */ +#ifdef _WIN32 +#ifdef DLPACK_EXPORTS +#define DLPACK_DLL __declspec(dllexport) +#else +#define DLPACK_DLL __declspec(dllimport) +#endif +#else +#define DLPACK_DLL +#endif + +// NOLINTNEXTLINE(modernize-deprecated-headers) +#include +// NOLINTNEXTLINE(modernize-deprecated-headers) +#include + +#ifdef __cplusplus +extern "C" { +#endif +/*! + * \brief The device type in DLDevice. + */ +#ifdef __cplusplus +typedef enum : int32_t { +#else +typedef enum { +#endif + /*! \brief CPU device */ + kDLCPU = 1, + /*! \brief CUDA GPU device */ + kDLCUDA = 2, + /*! + * \brief Pinned CUDA CPU memory by cudaMallocHost + */ + kDLCUDAHost = 3, + /*! \brief OpenCL devices. */ + kDLOpenCL = 4, + /*! \brief Vulkan buffer for next generation graphics. */ + kDLVulkan = 7, + /*! \brief Metal for Apple GPU. */ + kDLMetal = 8, + /*! \brief Verilog simulator buffer */ + kDLVPI = 9, + /*! \brief ROCm GPUs for AMD GPUs */ + kDLROCM = 10, + /*! + * \brief Pinned ROCm CPU memory allocated by hipMallocHost + */ + kDLROCMHost = 11, + /*! + * \brief Reserved extension device type, + * used for quickly test extension device + * The semantics can differ depending on the implementation. + */ + kDLExtDev = 12, + /*! + * \brief CUDA managed/unified memory allocated by cudaMallocManaged + */ + kDLCUDAManaged = 13, + /*! + * \brief Unified shared memory allocated on a oneAPI non-partititioned + * device. Call to oneAPI runtime is required to determine the device + * type, the USM allocation type and the sycl context it is bound to. + * + */ + kDLOneAPI = 14, + /*! \brief GPU support for next generation WebGPU standard. */ + kDLWebGPU = 15, + /*! \brief Qualcomm Hexagon DSP */ + kDLHexagon = 16, + /*! \brief Microsoft AI Accelerator */ + kDLMAIA = 17, +} DLDeviceType; + +/*! + * \brief A Device for Tensor and operator. + */ +typedef struct { + /*! \brief The device type used in the device. */ + DLDeviceType device_type; + /*! + * \brief The device index. + * For vanilla CPU memory, pinned memory, or managed memory, this is set to 0. + */ + int32_t device_id; +} DLDevice; + +/*! + * \brief The type code options DLDataType. + */ +typedef enum { + /*! \brief signed integer */ + kDLInt = 0U, + /*! \brief unsigned integer */ + kDLUInt = 1U, + /*! \brief IEEE floating point */ + kDLFloat = 2U, + /*! + * \brief Opaque handle type, reserved for testing purposes. + * Frameworks need to agree on the handle data type for the exchange to be well-defined. + */ + kDLOpaqueHandle = 3U, + /*! \brief bfloat16 */ + kDLBfloat = 4U, + /*! + * \brief complex number + * (C/C++/Python layout: compact struct per complex number) + */ + kDLComplex = 5U, + /*! \brief boolean */ + kDLBool = 6U, +} DLDataTypeCode; + +/*! + * \brief The data type the tensor can hold. The data type is assumed to follow the + * native endian-ness. An explicit error message should be raised when attempting to + * export an array with non-native endianness + * + * Examples + * - float: type_code = 2, bits = 32, lanes = 1 + * - float4(vectorized 4 float): type_code = 2, bits = 32, lanes = 4 + * - int8: type_code = 0, bits = 8, lanes = 1 + * - std::complex: type_code = 5, bits = 64, lanes = 1 + * - bool: type_code = 6, bits = 8, lanes = 1 (as per common array library convention, the underlying storage size of bool is 8 bits) + */ +typedef struct { + /*! + * \brief Type code of base types. + * We keep it uint8_t instead of DLDataTypeCode for minimal memory + * footprint, but the value should be one of DLDataTypeCode enum values. + * */ + uint8_t code; + /*! + * \brief Number of bits, common choices are 8, 16, 32. + */ + uint8_t bits; + /*! \brief Number of lanes in the type, used for vector types. */ + uint16_t lanes; +} DLDataType; + +/*! + * \brief Plain C Tensor object, does not manage memory. + */ +typedef struct { + /*! + * \brief The data pointer points to the allocated data. This will be CUDA + * device pointer or cl_mem handle in OpenCL. It may be opaque on some device + * types. This pointer is always aligned to 256 bytes as in CUDA. The + * `byte_offset` field should be used to point to the beginning of the data. + * + * Note that as of Nov 2021, multiply libraries (CuPy, PyTorch, TensorFlow, + * TVM, perhaps others) do not adhere to this 256 byte aligment requirement + * on CPU/CUDA/ROCm, and always use `byte_offset=0`. This must be fixed + * (after which this note will be updated); at the moment it is recommended + * to not rely on the data pointer being correctly aligned. + * + * For given DLTensor, the size of memory required to store the contents of + * data is calculated as follows: + * + * \code{.c} + * static inline size_t GetDataSize(const DLTensor* t) { + * size_t size = 1; + * for (tvm_index_t i = 0; i < t->ndim; ++i) { + * size *= t->shape[i]; + * } + * size *= (t->dtype.bits * t->dtype.lanes + 7) / 8; + * return size; + * } + * \endcode + */ + void* data; + /*! \brief The device of the tensor */ + DLDevice device; + /*! \brief Number of dimensions */ + int32_t ndim; + /*! \brief The data type of the pointer*/ + DLDataType dtype; + /*! \brief The shape of the tensor */ + const int64_t* shape; + /*! + * \brief strides of the tensor (in number of elements, not bytes) + * can be NULL, indicating tensor is compact and row-majored. + */ + const int64_t* strides; + /*! \brief The offset in bytes to the beginning pointer to data */ + uint64_t byte_offset; +} DLTensor; + +/*! + * \brief C Tensor object, manage memory of DLTensor. This data structure is + * intended to facilitate the borrowing of DLTensor by another framework. It is + * not meant to transfer the tensor. When the borrowing framework doesn't need + * the tensor, it should call the deleter to notify the host that the resource + * is no longer needed. + */ +typedef struct DLManagedTensor { + /*! \brief DLTensor which is being memory managed */ + DLTensor dl_tensor; + /*! \brief the context of the original host framework of DLManagedTensor in + * which DLManagedTensor is used in the framework. It can also be NULL. + */ + void * manager_ctx; + /*! \brief Destructor signature void (*)(void*) - this should be called + * to destruct manager_ctx which holds the DLManagedTensor. It can be NULL + * if there is no way for the caller to provide a reasonable destructor. + * The destructors deletes the argument self as well. + */ + void (*deleter)(struct DLManagedTensor * self); +} DLManagedTensor; +#ifdef __cplusplus +} // DLPACK_EXTERN_C +#endif +#endif // DLPACK_DLPACK_H_ diff --git a/phivenv/Lib/site-packages/torch/include/ATen/jit_macros.h b/phivenv/Lib/site-packages/torch/include/ATen/jit_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..ac6d0432425f11f761dcf26de7b0402a8daae5ac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/jit_macros.h @@ -0,0 +1,7 @@ +#pragma once +#include +#include + +// AT_USE_JITERATOR(), controls whether we jit some elementwise kernels +#define AT_USE_JITERATOR() true +#define jiterator_stringify(...) std::string(#__VA_ARGS__); diff --git a/phivenv/Lib/site-packages/torch/include/ATen/jiterator_macros.h b/phivenv/Lib/site-packages/torch/include/ATen/jiterator_macros.h new file mode 100644 index 0000000000000000000000000000000000000000..ccde91c67237707108eb61cc0eea38d0768aa2b5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/jiterator_macros.h @@ -0,0 +1,38 @@ +#pragma once +#include +#include + +#define JITERATOR_HOST_DEVICE C10_HOST_DEVICE +#if defined(_MSC_VER) && defined(__CUDACC__) +// NVRTC on Windows errors if __host__ __device__ attribute is +// present on kernel. +// error: attribute "__host__" does not apply here +// error: attribute "__device__" does not apply here +#define JITERATOR_HOST_DEVICE +#endif + +// jiterator_also_stringify_as macro is used to define code (for CPU/ROCm) +// and generate code string for `jiterator` (only when compiling for CUDA). +// Usage : +// jiterator_also_stringify_as( +// jiterator_code(template T identity(T x) { return x; }), +// identity_string); +// This will define the template `identity` as present in code and +// also define `std::string identity_string` with the code as the string +// if this is being compiled for CUDA. + +// `jiterator_code` macro is to deal with `,` in the kernel code. +// These `,`s confuse the preprocessor into thinking we are passing +// multiple arguments to the macro. +#define jiterator_code(...) __VA_ARGS__ +#if defined(__CUDACC__) || defined(__HIPCC__) +// CPU and CUDA and ROCm case +#define stringify_code(...) #__VA_ARGS__ +#define jiterator_also_stringify_as(code, str_name) \ + code /* define the function */ \ + const std::string str_name = std::string(stringify_code(code)); +#else +// CPU only or CPU and ROCm case +// Only needs the function +#define jiterator_also_stringify_as(code, str_name) code +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ATen/record_function.h b/phivenv/Lib/site-packages/torch/include/ATen/record_function.h new file mode 100644 index 0000000000000000000000000000000000000000..0403b98c32b540cc1b75b476155a042ce58535bc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ATen/record_function.h @@ -0,0 +1,802 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace c10 { +class TORCH_API OperatorHandle; +} + +namespace at { + +// Function name to record NCCL metadata +extern TORCH_API const std::string kParamCommsCallName; + +// Kind of record function scope; +enum class C10_API_ENUM RecordScope : uint8_t { + // c10/ATen ops, autograd nodes + FUNCTION = 0, + // Functions/nodes called from the autograd + BACKWARD_FUNCTION, + // TorchScript functions, methods + TORCHSCRIPT_FUNCTION, + // Kernel Function dtype Tag + KERNEL_FUNCTION_DTYPE, + // Torchbind custom class, + CUSTOM_CLASS, + // Generic Build Feature + BUILD_FEATURE, + // Kernel Function dtype Tag + LITE_INTERPRETER, + // User defined scope (e.g. with record_function()) + USER_SCOPE, + // Scopes for static runtime, a specialized TorchScript interpreter + STATIC_RUNTIME_OP, + STATIC_RUNTIME_MODEL, + NUM_SCOPES, // must be the last in the list +}; + +} // namespace at + +namespace std { +template <> +struct hash { + size_t operator()(const at::RecordScope& sc) const { + return static_cast(sc); + } +}; +} // namespace std + +namespace at { + +struct TORCH_API StringView { + StringView() : StringView(nullptr) {} + explicit StringView(const char* str_ptr) + : owned_str_ptr_(nullptr), str_ptr_(str_ptr) {} + explicit StringView(std::string str) + : owned_str_ptr_(std::make_shared(std::move(str))), + str_ptr_(owned_str_ptr_->c_str()) {} + + const char* str() const { + return str_ptr_; + } + + friend std::ostream& operator<<(std::ostream& os, const StringView& dt) { + os << dt.str(); + return os; + } + + friend bool operator==(const StringView& lhs, const StringView& rhs) { + return strcmp(lhs.str(), rhs.str()) == 0; + } + + friend bool operator!=(const StringView& lhs, const StringView& rhs) { + return !(lhs == rhs); + } + + private: + std::shared_ptr owned_str_ptr_; + const char* str_ptr_; +}; + +// Soft limit on the number of callbacks to use; +constexpr std::size_t kSoftLimitCallbacks = 4; + +// An abstract base class for various observer contexts that can be attached to +// the RecordFunction. +struct ObserverContext { + virtual ~ObserverContext() = default; + + protected: + ObserverContext() = default; +}; + +typedef c10::SmallVector CallbackHandles; +typedef c10::SmallVector, kSoftLimitCallbacks> + ObserverContextList; +typedef uint64_t RecordFunctionHandle; +struct RecordFunction; + +// +// PyTorch callbacks/observers API: +// + +/** + * RecordFunctionCallback represents a pair of callbacks to be used with + * RecordFunction, members: + * start, end - the callbacks to run when entering and exiting the scope; + * optionally, the start callback may return an ObserverContext which will + * be passed to the end callback, use appropriate constructor accordingly. + * needs_inputs - whether the callbacks need the inputs passed from the + * observed function/range; NOTE: passing the inputs incurs an additional + * overhead; sampling_probability - if not 1.0, then the callback is + * probabilistically sampled to run; NOTE: start and end callbacks always run as + * a pair and are sampled together; scopes - types of scopes to execute the + * callbacks on (see RecordScope); passing empty set means the callbacks will be + * executed for all possible scope types should_run - optional function that + * returns whether this callback should run; overwrites the effect of setting + * sampling_probability + */ +class TORCH_API RecordFunctionCallback { + public: + using StartCallback = + std::unique_ptr (*)(const RecordFunction&); + using EndCallback = void (*)(const RecordFunction&, ObserverContext*); + + // This interface supports observers that require passing an ObserverContext + // between start and end callbacks. + explicit RecordFunctionCallback( + StartCallback start, + EndCallback end = nullptr) + : start_(start), end_(end) { + scopes_.fill(true); + } + + RecordFunctionCallback& needsInputs(bool needs_inputs) { + needs_inputs_ = needs_inputs; + return *this; + } + + RecordFunctionCallback& needsOutputs(bool needs_outputs) { + needs_outputs_ = needs_outputs; + return *this; + } + + RecordFunctionCallback& needsIds(bool needs_ids) { + needs_ids_ = needs_ids; + return *this; + } + + RecordFunctionCallback& samplingProb(double sampling_prob) { + TORCH_CHECK( + sampling_prob >= 0.0 && sampling_prob <= 1.0, + "Invalid sampling probability"); + sampling_prob_ = sampling_prob; + return *this; + } + + RecordFunctionCallback& scopes( + const std::unordered_set>& scopes) { + if (!scopes.empty()) { + scopes_.fill(false); + for (auto sc : scopes) { + scopes_[static_cast(sc)] = true; + } + } else { + scopes_.fill(true); + } + return *this; + } + + bool needsInputs() const { + return needs_inputs_; + } + + bool needsOutputs() const { + return needs_outputs_; + } + + bool needsIds() const { + return needs_ids_; + } + + double samplingProb() const { + return sampling_prob_; + } + + bool checkScope(RecordScope sc) const { + return scopes_[(size_t)sc]; + } + + StartCallback start() const { + return start_; + } + + EndCallback end() const { + return end_; + } + + private: + StartCallback start_; + EndCallback end_; + double sampling_prob_ = 1.0; + std::array(RecordScope::NUM_SCOPES)> scopes_ = {}; + bool needs_inputs_ = false; + bool needs_outputs_ = false; + bool needs_ids_ = false; +}; + +// Notes: +// - two types of callbacks are provided: thread local and global +// - thread local callbacks are added/removed only for the given thread +// and are stored locally for each thread and separately from the list +// of the global callbacks +// - global callbacks are stored in a single per process list and are +// invoked by every RecordFunction, in addition to the thread local +// callbacks specific to the given thread +// - we allow the added callbacks to be sampled, by specifying a sampling +// probability for each callback pair, if the start callback is +// not picked to run, the corresponding end callback won't be called +// - a typical use case for the global callbacks is passive monitoring +// in the background (e.g. fleet-wide monitoring), without focusing on +// the specific piece of code +// - in contrast, thread local callbacks are enabled locally, on demand, +// for the specific piece of code (range) and are not sampled +// - a typical use case for thread local callbacks is profiler and code +// execution tracer +// - note, thread local callbacks are automatically propagated with +// ThreadLocalState across JIT continuations and async tasks (at::launch) + +typedef uint64_t CallbackHandle; + +constexpr CallbackHandle INVALID_CALLBACK_HANDLE{0}; + +// It is unnecessary to use atomic operations for enabling +// thread-local function callbacks. Moreover, it prevents saving to +// ThreadLocalState because std::atomic is non-copyable. +struct RecordFunctionCallbacksEntry { + RecordFunctionCallbacksEntry(RecordFunctionCallback cb, CallbackHandle h) + : callback_(cb), handle_(h) {} + + RecordFunctionCallback callback_; + bool enabled_{true}; + CallbackHandle handle_; +}; + +// Holds pairs (callbacks, unique_id) +using RecordFunctionCallbacks = std::vector; + +// Generated by the callback managers to determine which functions to run. +struct StepCallbacks { + StepCallbacks() = default; + StepCallbacks(uint64_t thread_id, RecordScope scope) + : thread_id_{thread_id}, scope_{scope} {} + + bool empty() const { + return callbacks_.empty(); + } + + struct StartEndPair { + RecordFunctionCallback::StartCallback start_; + RecordFunctionCallback::EndCallback end_; + }; + + using StartEndPairs = c10::SmallVector; + + StartEndPairs callbacks_; + uint64_t thread_id_{0}; + RecordScope scope_{RecordScope::FUNCTION}; + bool needs_inputs_{false}; + bool needs_outputs_{false}; + bool needs_ids_{false}; +}; + +struct TORCH_API RecordFunction { + // Default constructor is used with before function called afterwards: + // scope - record scope that this function tracks + // pre_sampled - whether this RecordFunction was already pre-sampled with + // kLowProb probability + explicit RecordFunction(RecordScope scope = RecordScope::FUNCTION); + explicit RecordFunction(StepCallbacks&& step_callbacks); + + template + void before( + F fn, + c10::ArrayRef args, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + inputs_ = args; + before(fn, current_sequence_nr); + } + + template + void before( + F fn, + c10::ArrayRef args, + const std::unordered_map* kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(std::move(fn), args, current_sequence_nr); + } + + template + void before( + F fn, + const std::unordered_map* kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(fn, current_sequence_nr); + } + + template + void before( + F fn, + const std::vector* args, + int64_t current_sequence_nr = -1) { + before( + std::move(fn), + c10::ArrayRef(args->data(), args->size()), + current_sequence_nr); + } + + template + void before( + F fn, + const std::vector* args, + const std::unordered_map* kwargs, + int64_t current_sequence_nr = -1) { + if (!isActive()) { + return; + } + kwinputs_ = *kwargs; + before(std::move(fn), args, current_sequence_nr); + } + + // Destructor calls end callbacks + virtual ~RecordFunction(); + + RecordFunction(const RecordFunction&) = delete; + RecordFunction& operator=(const RecordFunction&) = delete; + RecordFunction(RecordFunction&&) = delete; + RecordFunction& operator=(RecordFunction&&) = delete; + + const char* name() const; + const char* overload_name() const; + + int64_t seqNr() const { + return sequence_nr_; + } + + c10::ArrayRef inputs() const { +#ifndef NDEBUG + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + inputs_valid_, "Called inputs() outside RecordFunction start callback"); +#endif + return inputs_; + } + + std::unordered_map kwinputs() const { +#ifndef NDEBUG + TORCH_INTERNAL_ASSERT_DEBUG_ONLY( + inputs_valid_, + "Called kwinputs() outside RecordFunction start callback"); +#endif + return kwinputs_; + } + + const std::vector& outputs() const { + return outputs_; + } + + void setOutputs(std::vector&& outputs) { + outputs_ = std::move(outputs); + } + + void setOutputs(c10::ArrayRef outputs) { + outputs_ = outputs.vec(); + } + + size_t num_inputs() const; + size_t num_outputs() const; + + // Retrieves the thread_id that this RecordFunction ran start callbacks with. + // Useful for writing thread safe end callbacks that may be potentially + // executed in a different thread (async ops) + uint64_t threadId() const { + return step_callbacks_.thread_id_; + } + + // For backward functions - thread id of the corresponding forward function, + // or zero otherwise; + // used alongside with sequence number to correlate backward functions with + // the forward ones + uint64_t forwardThreadId() const { + return fwd_thread_id_; + } + + void setForwardThreadId(uint64_t thread_id) { + fwd_thread_id_ = thread_id; + } + + RecordScope scope() const { + return step_callbacks_.scope_; + } + + // Returns logical thread_id for the current thread + static uint64_t currentThreadId(); + + // Internal functions, do not use directly; + // used in python's context manager + + // before functions initialize RecordFunction members and call + // start callbacks + using schema_ref_t = std::reference_wrapper; + void before(const char* name, int64_t sequence_nr = -1); + void before(std::string name, int64_t sequence_nr = -1); + void before(schema_ref_t schema, int64_t sequence_nr = -1); + + // Sets node ID for distributed profiling + static void setDefaultNodeId(int64_t defaultNodeId); + // Gets node ID for distributed profiling + static int64_t getDefaultNodeId(); + + // Calls end callbacks. After end(), accessors will no longer provide useful + // results. + void end(); + + // Internal-only, used only force async event for distributed events + // profiling. + void _setAsync(); + + // Returns whether this RecordFunction corresponds to an async event or not. + bool isAsync() const; + + // Returns whether this RecordFunction corresponds to NCCL metadata collection + // or not. + bool isNcclMeta() const { + return is_nccl_meta_; + } + + // Internal-only, used to denote out variant used for Static Runtime execution + void _setStaticRuntimeOutVariant(); + bool isStaticRuntimeOutVariant() const; + + RecordFunctionHandle handle() const { + return handle_; + } + + std::optional operator_name() const; + + // This method returns a copy of the FunctionSchema and can be expensive. + std::optional operator_schema() const; + + void setHandle(RecordFunctionHandle handle) { + handle_ = handle; + } + + // Whether this RecordFunction runs any callbacks. + bool isActive() const { + return !step_callbacks_.empty(); + } + + bool needsInputs() const { + return step_callbacks_.needs_inputs_; + } + + bool needsOutputs() const { + return step_callbacks_.needs_outputs_; + } + + int64_t debugHandle() const { + return debug_handle_; + } + + void setDebugHandle(int64_t debug_handle) { + debug_handle_ = debug_handle; + } + + void invalidateInputs() { +#ifndef NDEBUG + inputs_valid_ = false; +#endif + } + + private: + void runStartCallbacks(); + + StepCallbacks step_callbacks_; + + // In cases when RecordFunction might be active but we chose not to + // use the observers (e.g. operator is not observed), this boolean + // flag is used to check whether the start callbacks were called + bool called_start_callbacks_ = false; + +#ifndef NDEBUG + bool inputs_valid_ = false; +#endif + + // Stores various ObserverContext objects with event metadata for callbacks. + ObserverContextList ctx_; + + std::variant fn_; + + int64_t sequence_nr_ = -1; + c10::ArrayRef inputs_; + std::unordered_map kwinputs_; + std::vector outputs_; + + // For backward functions - thread id of the forward function + uint64_t fwd_thread_id_ = 0; + + // Unique id for this RecordFunction, used in callbacks to track start + // and end of ranges + RecordFunctionHandle handle_{0}; + + // Whether this record_function corresponds to an async event or not. Async + // events can complete in different threads or follow a future-like pattern + // of use. + bool is_async_{false}; + + // Debug handles are used for lazy annotation of module hierarchy + // and callstack. + // This is specifically is useful for mobile runtime, where generated + // debug handles can be lazily symbolicated using debug information + int64_t debug_handle_{-1}; + + // Whether this RecordFunction is used for an out variant run with + // Static Runtime + bool is_static_runtime_out_variant_{false}; + + // Whether this RecordFunction is used for NCCL metadata collection + bool is_nccl_meta_{false}; +}; + +TORCH_API StepCallbacks getStepCallbacks(RecordScope scope); + +TORCH_API std::optional getStepCallbacksUnlessEmpty( + RecordScope scope); + +namespace detail { +template +void record_function_with_scope( + RecordFunction& guard, + F fn, + const Inputs& inputs, + Args&&... args) { + if (guard.needsInputs()) { + guard.before( + fn, + c10::ArrayRef(inputs.data(), inputs.size()), + std::forward(args)...); + } else { + guard.before(fn, std::forward(args)...); + } +} + +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + F fn, + int64_t debug_handle, + const Inputs& inputs, + Args&&... args) { + guard.setDebugHandle(debug_handle); + if (guard.needsInputs()) { + guard.before( + fn, + c10::ArrayRef(inputs.data(), inputs.size()), + std::forward(args)...); + } else { + guard.before(fn, std::forward(args)...); + } +} + +template +void record_function_with_scope( + RecordFunction& guard, + F fn, + c10::ArrayRef inputs, + Args&&... args) { + return record_function_with_scope< + c10::ArrayRef, + F, + Args...>(guard, std::move(fn), inputs, std::forward(args)...); +} + +template +void record_function_with_scope_and_debug_handle( + RecordFunction& guard, + F fn, + int64_t debug_handle, + c10::ArrayRef inputs, + Args&&... args) { + return record_function_with_scope_and_debug_handle< + c10::ArrayRef, + F, + Args...>( + guard, std::move(fn), debug_handle, inputs, std::forward(args)...); +} + +} // namespace detail + +// optional argument - function's seq_no +#define RECORD_FUNCTION_WITH_SCOPE(scope, fn, inputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + ::at::detail::record_function_with_scope( \ + guard, fn, inputs, ##__VA_ARGS__); \ + } + +#define RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \ + scope, fn, inputs, outputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + if (guard.needsInputs()) { \ + guard.before(fn, inputs, ##__VA_ARGS__); \ + } else { \ + guard.before(fn, ##__VA_ARGS__); \ + } \ + if (guard.needsOutputs()) { \ + guard.setOutputs(outputs); \ + } \ + } + +#define RECORD_FUNCTION(fn, inputs, ...) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::FUNCTION, fn, inputs, ##__VA_ARGS__) + +#define RECORD_TORCHSCRIPT_FUNCTION(mn, inputs) \ + RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::TORCHSCRIPT_FUNCTION, mn, inputs) + +#define RECORD_FUNCTION_WITH_INPUTS_OUTPUTS(fn, inputs, outputs, ...) \ + RECORD_FUNCTION_WITH_SCOPE_INPUTS_OUTPUTS( \ + at::RecordScope::FUNCTION, fn, inputs, outputs, ##__VA_ARGS__) + +// Custom user scopes in C++; similar to Python's 'with record_function("..."):' +#define RECORD_USER_SCOPE(fn) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::USER_SCOPE, fn, c10::ArrayRef{}) + +// RECORD_USER_SCOPE with inputs +#define RECORD_USER_SCOPE_WITH_INPUTS(fn, inputs) \ + RECORD_FUNCTION_WITH_SCOPE(at::RecordScope::USER_SCOPE, fn, inputs) + +#define RECORD_USER_SCOPE_WITH_KWARGS_ONLY(fn, kwargs) \ + RECORD_FUNCTION_WITH_SCOPE( \ + at::RecordScope::USER_SCOPE, \ + fn, \ + c10::ArrayRef{}, \ + kwargs) + +// Helper macro to pass in debug handle that is used to +// post process events +#define RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ + scope, fn, debug_handle, inputs, ...) \ + at::RecordFunction guard(scope); \ + if (guard.isActive()) { \ + ::at::detail::record_function_with_scope_and_debug_handle( \ + guard, fn, debug_handle, inputs, ##__VA_ARGS__); \ + } + +// Helper macros to record LITE INTERPETER scope events with debug handles +#define RECORD_EDGE_SCOPE_WITH_DEBUG_HANDLE_AND_INPUTS( \ + fn, debug_handle, inputs) \ + RECORD_WITH_SCOPE_DEBUG_HANDLE_AND_INPUTS( \ + at::RecordScope::LITE_INTERPRETER, fn, debug_handle, inputs) + +// Bookend to the RECORD_FUNCTION macros. Use this after the kernel +// launch to let the profiler bind the outputs to the op that produced +// them. Note that guard is declared by RECORD_FUNCTION so this macro +// needs to be called from the same scope as RECORD_FUNCTION +#define RECORD_OUTPUTS(outputs) \ + if (guard.needsOutputs()) { \ + guard.setOutputs( \ + std::vector(outputs.begin(), outputs.end())); \ + } + +/** + * addThreadLocalCallback adds a thread local callback to run with + * RecordFunction, returns handle to use with removeThreadLocalCallback + */ +TORCH_API CallbackHandle addThreadLocalCallback(RecordFunctionCallback cb); + +/** + * hasThreadLocalCallbacks returns whether there're callbacks registered + * with addThreadLocalCallback + */ +TORCH_API bool hasThreadLocalCallbacks(); + +/** + * clearThreadLocalCallbacks removes all thread local callbacks + */ +TORCH_API void clearThreadLocalCallbacks(); + +/** + * addGlobalCallback adds a global callback to run with RecordFunction: + * + * only during the program initialization + */ +TORCH_API CallbackHandle addGlobalCallback(RecordFunctionCallback cb); + +/** + * removeCallback removes a callback given the handle returned by + * addThreadLocalCallback or addGlobalCallback; + * + * no other code can run simultaneously + */ +TORCH_API void removeCallback(CallbackHandle handle); + +/** + * Prevent the given callback from executing. If handle is invalid, + * does nothing. + */ +TORCH_API void disableCallback(CallbackHandle handle); + +/** + * Allow the given callback, previously disabled with disableCallback, to + * execute again. If handle is invalid, does nothing. + */ +TORCH_API void reenableCallback(CallbackHandle handle); + +/** + * hasGlobalCallbacks returns whether there're global callbacks + * registered with pushGlobalCallback + */ +TORCH_API bool hasGlobalCallbacks(); + +/** + * clearGlobalCallbacks removes all global callbacks + */ +TORCH_API void clearGlobalCallbacks(); + +// for both thread local and global callbacks +TORCH_API bool hasCallbacks(); +TORCH_API void clearCallbacks(); + +/** + * enableRecordFunction enables RecordFunction thread locally + */ +TORCH_API void enableRecordFunction(bool enable = true); + +/** + * isRecordFunctionEnabled returns whether RecordFunction + * is enabled thread locally + */ +TORCH_API bool isRecordFunctionEnabled(); + +class TORCH_API RecordFunctionGuard { + public: + explicit RecordFunctionGuard(bool is_enabled = true) + : prev_value_(isRecordFunctionEnabled()) { + enableRecordFunction(is_enabled); + } + + RecordFunctionGuard(RecordFunctionGuard&& other) = delete; + RecordFunctionGuard(const RecordFunctionGuard&) = delete; + RecordFunctionGuard& operator=(const RecordFunctionGuard&) = delete; + RecordFunctionGuard& operator=(RecordFunctionGuard&&) = delete; + virtual ~RecordFunctionGuard() { + enableRecordFunction(prev_value_); + } + + private: + bool prev_value_ = false; +}; + +class TORCH_API DisableRecordFunctionGuard : public RecordFunctionGuard { + public: + DisableRecordFunctionGuard() : RecordFunctionGuard(false) {} + ~DisableRecordFunctionGuard() override = default; +}; + +struct TORCH_API RecordFunctionTLS { + // Thread local vector of callbacks, holds pairs (callbacks, unique_id); + // must be sorted in increasing handles order + RecordFunctionCallbacks sorted_tls_callbacks_; + + bool tls_record_function_enabled_ = true; +}; + +TORCH_API const RecordFunctionTLS& get_record_function_tls_(); + +TORCH_API void set_record_function_tls_(const RecordFunctionTLS& tls); + +TORCH_API void set_record_function_seed_for_testing(uint32_t seed); + +} // namespace at diff --git a/phivenv/Lib/site-packages/torch/include/advisor-annotate.h b/phivenv/Lib/site-packages/torch/include/advisor-annotate.h new file mode 100644 index 0000000000000000000000000000000000000000..aebceee48e2121843d7d5db0e309e6c07b7f5ad7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/advisor-annotate.h @@ -0,0 +1,520 @@ +/* + * Copyright (C) 2005-2019 Intel Corporation + * SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause + */ + +/* This file defines macros and inline functions used by + * the Intel(R) Advisor XE "Dependencies Modeling" and + * "Suitability Modeling" analysis, and are in described + * in the "Annotations" section of the help. + * + * Expansion Options + * + * There are several options you can used to control how advisor-annotate.h + * is included. To use these, define the option prior to including + * advisor-annotate.h. e.g. + * #define ANNOTATE_DECLARE + * #include "advisor-annotate.h" + * + * Controlling inclusion of windows.h + * + * windows.h is included for declarations for LoadLibrary, GetProcSymbol, + * but this can have interactions with user code, such as conflicting + * definitions of types. There are two general approaches to work around + * this if this triggers problems building your application: + * + * 1. Reduce the amount declared by windows.h by using the following: + * #define NOMINMAX + * #define WIN32_LEAN_AND_MEAN + * prior to including advisor-annotate.h in your code. + * The first avoids problems with STL min/max in particular + * This is sufficient in some cases, and may be the easiest. + * + * 2. Use a declaration/definition approach, where all uses of advisor-annotate.h + * other than one, generate a set of declarations, and windows.h is only + * needed in a single implementation module. In this model, all includes + * of advisor-annotate.h except one specify ANNOTATE_DECLARE, which causes + * advisor-annotate.h to declare an external routine, and not include + * windows.h. A final include of advisor-annotate.h than specifies + * ANNOTATE_DEFINE, to actually define the global routine to resolve + * the external reference. This one include is the only one that winds up + * using windows.h. If necessary, this can be placed in a file by itself. + * + * An example using this mechanism: + * + * ... + * // Some header(s) used in places in your system where you want + * // to be able to use annotations + * #define ANNOTATE_DECLARE + * #include "advisor-annotate.h" + * ... + * // annotation uses + * ANNOTATE_SITE_BEGIN(MySite1) + * ... + * ANNOTATE_SITE_END(MySite1) + * ... + * + * ... + * // A single implementation file (.cpp/.cxx) causes windows.h + * // to be included, and the support routine to be defined as a + * // global routine called from the various annotation uses. + * #define ANNOTATE_DEFINE + * #include "advisor-annotate.h" + * ... + * + * Null expansion of annotations + * + * Some people may find it useful to have no expansion for annotations, + * if you have a project that you want to build without any annotation + * effects at all. (e.g. if you have a project where you want to have + * some annotations in a shared source pool, but only particular + * developers are actually building with the annotations enabled.) + * Defining ANNOTATE_EXPAND_NULL avoids declaring comdat routines, + * and avoids any textual expansion for annotation macros. + */ + +#ifndef _ADVISOR_ANNOTATE_H_ +#define _ADVISOR_ANNOTATE_H_ + +/* Version of the annotations. + * The presence of this macro serves to idetify the annotation definition + * file and the form of annotations. + */ +#define INTEL_ADVISOR_ANNOTATION_VERSION 1.0 + +#ifdef ANNOTATE_EXPAND_NULL + +#define ANNOTATE_SITE_BEGIN(_SITE) +#define ANNOTATE_SITE_END(...) +#define ANNOTATE_TASK_BEGIN(_TASK) +#define ANNOTATE_TASK_END(...) +#define ANNOTATE_ITERATION_TASK(_TASK) +#define ANNOTATE_LOCK_ACQUIRE(_ADDR) +#define ANNOTATE_LOCK_RELEASE(_ADDR) +#define ANNOTATE_RECORD_ALLOCATION(_ADDR, _SIZE) +#define ANNOTATE_RECORD_DEALLOCATION(_ADDR) +#define ANNOTATE_INDUCTION_USES(_ADDR, _SIZE) +#define ANNOTATE_REDUCTION_USES(_ADDR, _SIZE) +#define ANNOTATE_OBSERVE_USES(_ADDR, _SIZE) +#define ANNOTATE_CLEAR_USES(_ADDR) +#define ANNOTATE_DISABLE_OBSERVATION_PUSH +#define ANNOTATE_DISABLE_OBSERVATION_POP +#define ANNOTATE_DISABLE_COLLECTION_PUSH +#define ANNOTATE_DISABLE_COLLECTION_POP +#define ANNOTATE_AGGREGATE_TASK(_COUNT) + +#else /* ANNOTATE_EXPAND_NULL */ + +#if defined(WIN32) || defined(_WIN32) + +#define ANNOTATEAPI __cdecl + +#ifndef ANNOTATE_DECLARE +#include + +typedef HMODULE lib_t; + +#define __itt_get_proc(lib, name) GetProcAddress(lib, name) +#define __itt_load_lib(name) LoadLibraryA(name) +#define __itt_unload_lib(handle) FreeLibrary(handle) +#define __itt_system_error() (int)GetLastError() +#endif /* ANNOTATE_DECLARE */ + +#else /* defined(WIN32) || defined(_WIN32) */ + +#if defined _M_IX86 || __i386__ +# define ANNOTATEAPI __attribute__ ((cdecl)) +#else +# define ANNOTATEAPI /* actual only on x86 platform */ +#endif + + +#ifndef ANNOTATE_DECLARE +#include +#include +#include + +typedef void* lib_t; + +#define __itt_get_proc(lib, name) dlsym(lib, name) +#define __itt_load_lib(name) dlopen(name, RTLD_LAZY) +#define __itt_unload_lib(handle) dlclose(handle) +#define __itt_system_error() errno +#endif /* ANNOTATE_DECLARE */ + +#endif /* defined(WIN32) || defined(_WIN32) */ + +#include + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#ifndef _ITTNOTIFY_H_ + +/* Handles for sites and tasks. + */ +typedef void* __itt_model_site; /* handle for lexical site */ +typedef void* __itt_model_site_instance; /* handle for dynamic instance */ +typedef void* __itt_model_task; /* handle for lexical site */ +typedef void* __itt_model_task_instance; /* handle for dynamic instance */ + +typedef enum { + __itt_model_disable_observation, + __itt_model_disable_collection +} __itt_model_disable; + +#endif /* _ITTNOTIFY_H_ */ + +/*** Use the routines in libittnotify.dll. ***/ + +/* Basic approach: + * For the case of calling the dll, there is an __annotate_routine function + * declared as a comdat in each compilation unit with annotations present. + * That routine in turn has an internal static structure that is initialized + * once to contain the address of functions occuring in libittnotify.dll. + * Each time an annotation macro is invoked, that causes a call to the + * __annotate_routine function to get addresses of the routines, followed + * by calling the specific routine, provided the address is non-null. + */ + +/* This set of macros generates calls that are part of application images, + * which call the __itt_model_xxx routines in the dynamically loaded + * libittnotify.dll. + */ +#ifndef _ITTNOTIFY_H_ +#define ITT_NOTIFY_DECL(_text) _text +#else +#define ITT_NOTIFY_DECL(_text) +#endif + +/* For C++, a static initialization is used */ +#if defined(__cplusplus) && defined(WIN32) +#define _ANNOTATE_ROUTINES_ADDR __annotate_routines_s +#else +#define _ANNOTATE_ROUTINES_ADDR __annotate_routines_init( __annotate_routines() ) +#endif /* __cplusplus */ + + +#define _ANNOTATE_DECLARE_0(_BASENAME) \ +typedef void (ANNOTATEAPI * __annotate_##_BASENAME##_t)(); \ +static __inline void ANNOTATEAPI __annotate_##_BASENAME##_t_nop() { }; \ +ITT_NOTIFY_DECL( extern void ANNOTATEAPI __itt_model_##_BASENAME(); ) + +#define _ANNOTATE_DECLARE_0_INT(_BASENAME) \ +typedef int (ANNOTATEAPI * __annotate_##_BASENAME##_t)(); \ +static __inline int ANNOTATEAPI __annotate_##_BASENAME##_t_nop() { return 0; }; \ +ITT_NOTIFY_DECL( extern void ANNOTATEAPI __itt_model_##_BASENAME(); ) + +#define _ANNOTATE_CALL_0(_BASENAME) { _ANNOTATE_ROUTINES_ADDR->_BASENAME(); } + +#define _ANNOTATE_DECLARE_1(_BASENAME, _P1TYPE) \ +typedef void (ANNOTATEAPI * __annotate_##_BASENAME##_t)(_P1TYPE p1); \ +static __inline void ANNOTATEAPI __annotate_##_BASENAME##_t_nop(_P1TYPE p1) { (void)p1; }; \ +ITT_NOTIFY_DECL( extern void ANNOTATEAPI __itt_model_##_BASENAME(_P1TYPE p1); ) + +#define _ANNOTATE_CALL_1(_BASENAME, _P1) { _ANNOTATE_ROUTINES_ADDR->_BASENAME(_P1); } + +#define _ANNOTATE_DECLARE_2(_BASENAME, _P1TYPE, _P2TYPE) \ +typedef void (ANNOTATEAPI * __annotate_##_BASENAME##_t)(_P1TYPE p1, _P2TYPE p2); \ +static __inline void ANNOTATEAPI __annotate_##_BASENAME##_t_nop(_P1TYPE p1, _P2TYPE p2) { (void)p1; (void)p2; }; \ +ITT_NOTIFY_DECL( extern void ANNOTATEAPI __itt_model_##_BASENAME(_P1TYPE p1, _P2TYPE p2); ) + +#define _ANNOTATE_CALL_2(_BASENAME, _P1, _P2) { _ANNOTATE_ROUTINES_ADDR->_BASENAME((_P1), (_P2)); } + +/*** Declare routines appropriately based on usage style ***/ + +/* Depending on above, this will either expand to comdats that are + * used directly, or comdats that call routines in libittnotify.dll + */ +_ANNOTATE_DECLARE_1(site_beginA, const char *) +_ANNOTATE_DECLARE_0(site_end_2) +_ANNOTATE_DECLARE_1(task_beginA, const char *) +_ANNOTATE_DECLARE_0(task_end_2) +_ANNOTATE_DECLARE_1(iteration_taskA, const char *) +_ANNOTATE_DECLARE_1(lock_acquire_2, void *) +_ANNOTATE_DECLARE_1(lock_release_2, void *) +_ANNOTATE_DECLARE_2(record_allocation, void *, size_t) +_ANNOTATE_DECLARE_1(record_deallocation, void *) +_ANNOTATE_DECLARE_2(induction_uses, void *, size_t) +_ANNOTATE_DECLARE_2(reduction_uses, void *, size_t) +_ANNOTATE_DECLARE_2(observe_uses, void *, size_t) +_ANNOTATE_DECLARE_1(clear_uses, void *) +_ANNOTATE_DECLARE_1(disable_push, __itt_model_disable) +_ANNOTATE_DECLARE_0(disable_pop) +_ANNOTATE_DECLARE_1(aggregate_task, size_t) +_ANNOTATE_DECLARE_0_INT(is_collection_disabled) + +/* All of the symbols potentially in the library + */ +struct __annotate_routines { + volatile int initialized; + __annotate_site_beginA_t site_beginA; + __annotate_site_end_2_t site_end_2; + __annotate_task_beginA_t task_beginA; + __annotate_task_end_2_t task_end_2; + __annotate_iteration_taskA_t iteration_taskA; + __annotate_lock_acquire_2_t lock_acquire_2; + __annotate_lock_release_2_t lock_release_2; + __annotate_record_allocation_t record_allocation; + __annotate_record_deallocation_t record_deallocation; + __annotate_induction_uses_t induction_uses; + __annotate_reduction_uses_t reduction_uses; + __annotate_observe_uses_t observe_uses; + __annotate_clear_uses_t clear_uses; + __annotate_disable_push_t disable_push; + __annotate_disable_pop_t disable_pop; + __annotate_aggregate_task_t aggregate_task; + __annotate_is_collection_disabled_t is_collection_disabled; +}; + +/* This comdat-ed routine means there is a single instance of the function pointer + * structure per image + */ +static __inline struct __annotate_routines* __annotate_routines() +{ + static struct __annotate_routines __annotate_routines; + return &__annotate_routines; +} + +/* This routine is called to get the address of an initialized + * set of function pointers for the annotation routines. + */ + +#ifdef ANNOTATE_DECLARE +extern struct __annotate_routines* ANNOTATEAPI __annotate_routines_init(struct __annotate_routines* itt); +#else +#ifdef ANNOTATE_DEFINE + /* */ +#else + static __inline +#endif +struct __annotate_routines* +ANNOTATEAPI +__annotate_routines_init(struct __annotate_routines* itt) { + + if (itt->initialized) { + return itt; + } else { + + /* Initialized by first invocation + * This assumes that the code here can be executed successfully + * by multiple threads, should that ever happen. + */ + int do_disable_pop = 0; + char* lib_name = NULL; + lib_t itt_notify = 0; + + if (sizeof(void*) > 4) { + lib_name = getenv("INTEL_LIBITTNOTIFY64"); + } else { + lib_name = getenv("INTEL_LIBITTNOTIFY32"); + } + + if (lib_name) { + itt_notify = __itt_load_lib(lib_name); + } else { +#if defined(WIN32) || defined(_WIN32) + itt_notify = __itt_load_lib("libittnotify.dll"); +#elif defined(__APPLE__) + itt_notify = __itt_load_lib("libittnotify.dylib"); +#else + itt_notify = __itt_load_lib("libittnotify.so"); +#endif + } + + if (itt_notify != NULL) { + /* The static variables initialized and itt are reported as race conditions + * or inconsistent lock usage by Dependencies Modeling in some obscure cases + * involving multiple dlls. Ignoring this initialization phase gets rid of + * this problem. + */ + __annotate_disable_push_t disable_push; + __annotate_is_collection_disabled_t is_collection_disabled; + disable_push = (__annotate_disable_push_t) __itt_get_proc(itt_notify, "__itt_model_disable_push"); + is_collection_disabled = (__annotate_is_collection_disabled_t) __itt_get_proc(itt_notify, "__itt_model_is_collection_disabled"); + if (disable_push) { + if ( ! (is_collection_disabled && is_collection_disabled()) ) + { + // disable collection only if it is not disabled already (for example, started paused) + disable_push(__itt_model_disable_observation); + do_disable_pop = 1; + } + } + itt->site_beginA = (__annotate_site_beginA_t) __itt_get_proc(itt_notify, "__itt_model_site_beginA"); + itt->site_end_2 = (__annotate_site_end_2_t) __itt_get_proc(itt_notify, "__itt_model_site_end_2"); + itt->task_beginA = (__annotate_task_beginA_t) __itt_get_proc(itt_notify, "__itt_model_task_beginA"); + itt->task_end_2 = (__annotate_task_end_2_t) __itt_get_proc(itt_notify, "__itt_model_task_end_2"); + itt->iteration_taskA = (__annotate_iteration_taskA_t) __itt_get_proc(itt_notify, "__itt_model_iteration_taskA"); + itt->lock_acquire_2 = (__annotate_lock_acquire_2_t) __itt_get_proc(itt_notify, "__itt_model_lock_acquire_2"); + itt->lock_release_2 = (__annotate_lock_release_2_t) __itt_get_proc(itt_notify, "__itt_model_lock_release_2"); + itt->record_allocation = (__annotate_record_allocation_t) __itt_get_proc(itt_notify, "__itt_model_record_allocation"); + itt->record_deallocation = (__annotate_record_deallocation_t)__itt_get_proc(itt_notify, "__itt_model_record_deallocation"); + itt->induction_uses = (__annotate_induction_uses_t) __itt_get_proc(itt_notify, "__itt_model_induction_uses"); + itt->reduction_uses = (__annotate_reduction_uses_t) __itt_get_proc(itt_notify, "__itt_model_reduction_uses"); + itt->observe_uses = (__annotate_observe_uses_t) __itt_get_proc(itt_notify, "__itt_model_observe_uses"); + itt->clear_uses = (__annotate_clear_uses_t) __itt_get_proc(itt_notify, "__itt_model_clear_uses"); + itt->disable_push = disable_push; + itt->disable_pop = (__annotate_disable_pop_t) __itt_get_proc(itt_notify, "__itt_model_disable_pop"); + itt->aggregate_task = (__annotate_aggregate_task_t) __itt_get_proc(itt_notify, "__itt_model_aggregate_task"); + itt->is_collection_disabled = is_collection_disabled; + } + /* No-op routine for any that didn't get resolved */ + if (!itt->site_beginA) itt->site_beginA = __annotate_site_beginA_t_nop; + if (!itt->site_end_2) itt->site_end_2 = __annotate_site_end_2_t_nop; + if (!itt->task_beginA) itt->task_beginA = __annotate_task_beginA_t_nop; + if (!itt->task_end_2) itt->task_end_2 = __annotate_task_end_2_t_nop; + if (!itt->iteration_taskA) itt->iteration_taskA = __annotate_iteration_taskA_t_nop; + if (!itt->lock_acquire_2) itt->lock_acquire_2 = __annotate_lock_acquire_2_t_nop; + if (!itt->lock_release_2) itt->lock_release_2 = __annotate_lock_release_2_t_nop; + if (!itt->record_allocation) itt->record_allocation = __annotate_record_allocation_t_nop; + if (!itt->record_deallocation) itt->record_deallocation=__annotate_record_deallocation_t_nop; + if (!itt->induction_uses) itt->induction_uses = __annotate_induction_uses_t_nop; + if (!itt->reduction_uses) itt->reduction_uses = __annotate_reduction_uses_t_nop; + if (!itt->observe_uses) itt->observe_uses = __annotate_observe_uses_t_nop; + if (!itt->clear_uses) itt->clear_uses = __annotate_clear_uses_t_nop; + if (!itt->disable_push) itt->disable_push = __annotate_disable_push_t_nop; + if (!itt->disable_pop) itt->disable_pop = __annotate_disable_pop_t_nop; + if (!itt->aggregate_task) itt->aggregate_task = __annotate_aggregate_task_t_nop; + if (!itt->is_collection_disabled) itt->is_collection_disabled = __annotate_is_collection_disabled_t_nop; + + itt->initialized = 1; + + if (do_disable_pop) { + itt->disable_pop(); + } + } + return itt; +} +#endif /* ANNOTATE_DECLARE */ + +/* For C++ only, use a class to force initialization */ + +#if defined(__cplusplus) && defined(WIN32) +/* Force one-shot initialization so individual calls don't need it */ +static struct __annotate_routines* __annotate_routines_s = __annotate_routines_init( __annotate_routines() ); +#endif + +/* For C++, allow the Annotate::SiteBegin(x) form. For Windows CLR, this is the default + * expansion for the macros (with no-inline) to get the best call stacks in the tools. */ +#if defined(__cplusplus) +/* Ensure this code is managed and non-inlinable */ +#if defined(WIN32) && defined(__CLR_VER) +#pragma managed(push, on) +#define ANNOTATE_CLR_NOINLINE __declspec(noinline) +#else +#define ANNOTATE_CLR_NOINLINE +#endif +class Annotate { +public: + static ANNOTATE_CLR_NOINLINE void SiteBegin(const char* site) { _ANNOTATE_ROUTINES_ADDR->site_beginA(site); } + static ANNOTATE_CLR_NOINLINE void SiteEnd() { _ANNOTATE_ROUTINES_ADDR->site_end_2(); } + static ANNOTATE_CLR_NOINLINE void TaskBegin(const char* task) { _ANNOTATE_ROUTINES_ADDR->task_beginA(task); } + static ANNOTATE_CLR_NOINLINE void TaskEnd() { _ANNOTATE_ROUTINES_ADDR->task_end_2(); } + static ANNOTATE_CLR_NOINLINE void IterationTask(const char* task) { _ANNOTATE_ROUTINES_ADDR->iteration_taskA(task); } + static ANNOTATE_CLR_NOINLINE void LockAcquire(void* lockId) { _ANNOTATE_ROUTINES_ADDR->lock_acquire_2(lockId); } + static ANNOTATE_CLR_NOINLINE void LockRelease(void* lockId) { _ANNOTATE_ROUTINES_ADDR->lock_release_2(lockId); } + static ANNOTATE_CLR_NOINLINE void RecordAllocation(void *p, size_t s) { _ANNOTATE_ROUTINES_ADDR->record_allocation(p, s); } + static ANNOTATE_CLR_NOINLINE void RecordDeallocation(void *p) { _ANNOTATE_ROUTINES_ADDR->record_deallocation(p); } + static ANNOTATE_CLR_NOINLINE void InductionUses(void *p, size_t s) { _ANNOTATE_ROUTINES_ADDR->induction_uses(p, s); } + static ANNOTATE_CLR_NOINLINE void ReductionUses(void *p, size_t s) { _ANNOTATE_ROUTINES_ADDR->reduction_uses(p, s); } + static ANNOTATE_CLR_NOINLINE void ObserveUses(void *p, size_t s) { _ANNOTATE_ROUTINES_ADDR->observe_uses(p, s); } + static ANNOTATE_CLR_NOINLINE void ClearUses(void *p) { _ANNOTATE_ROUTINES_ADDR->clear_uses(p); } + static ANNOTATE_CLR_NOINLINE void DisablePush(__itt_model_disable d) { _ANNOTATE_ROUTINES_ADDR->disable_push(d); } + static ANNOTATE_CLR_NOINLINE void DisablePop() { _ANNOTATE_ROUTINES_ADDR->disable_pop(); } + static ANNOTATE_CLR_NOINLINE void AggregateTask(size_t c) { _ANNOTATE_ROUTINES_ADDR->aggregate_task(c); } +}; +#if defined(WIN32) && defined(__CLR_VER) +#pragma managed(pop) +#endif +#undef ANNOTATE_CLR_NOINLINE +#endif + +#if defined(__cplusplus) && defined(WIN32) && defined(__CLR_VER) + +#define ANNOTATE_SITE_BEGIN(_SITE) Annotate::SiteBegin(#_SITE) +#define ANNOTATE_SITE_END(...) Annotate::SiteEnd() +#define ANNOTATE_TASK_BEGIN(_TASK) Annotate::TaskBegin(#_TASK) +#define ANNOTATE_TASK_END(...) Annotate::TaskEnd() +#define ANNOTATE_ITERATION_TASK(_TASK) Annotate::IterationTask(#_TASK) +#define ANNOTATE_LOCK_ACQUIRE(_ADDR) Annotate::LockAcquire(_ADDR) +#define ANNOTATE_LOCK_RELEASE(_ADDR) Annotate::LockRelease(_ADDR) +#define ANNOTATE_RECORD_ALLOCATION(_ADDR, _SIZE) Annotate::RecordAllocation((_ADDR), (_SIZE)) +#define ANNOTATE_RECORD_DEALLOCATION(_ADDR) Annotate::RecordDeallocation(_ADDR) +#define ANNOTATE_INDUCTION_USES(_ADDR, _SIZE) Annotate::InductionUses((_ADDR), (_SIZE)) +#define ANNOTATE_REDUCTION_USES(_ADDR, _SIZE) Annotate::ReductionUses((_ADDR), (_SIZE)) +#define ANNOTATE_OBSERVE_USES(_ADDR, _SIZE) Annotate::ObserveUses((_ADDR), (_SIZE)) +#define ANNOTATE_CLEAR_USES(_ADDR) Annotate::ClearUses(_ADDR) +#define ANNOTATE_DISABLE_OBSERVATION_PUSH Annotate::DisablePush(itt_model_disable_observation) +#define ANNOTATE_DISABLE_OBSERVATION_POP Annotate::DisablePop() +#define ANNOTATE_DISABLE_COLLECTION_PUSH Annotate::DisablePush(__itt_model_disable_collection) +#define ANNOTATE_DISABLE_COLLECTION_POP Annotate::DisablePop() +#define ANNOTATE_AGGREGATE_TASK(_COUNT) Annotate::AggregateTask(_COUNT) + +#else + +/* Mark the start of a site (region) to be analyzed by the tool */ +#define ANNOTATE_SITE_BEGIN(_SITE) _ANNOTATE_CALL_1(site_beginA, #_SITE) + +/* Mark the end of a site (region) to be analyzed by the tool and + * indicate a WaitForAll task synchronization */ +#define ANNOTATE_SITE_END(...) _ANNOTATE_CALL_0(site_end_2) + +/* Mark the beginning of a region of code that constitutes a task */ +#define ANNOTATE_TASK_BEGIN(_TASK) _ANNOTATE_CALL_1(task_beginA, #_TASK) + +/* Mark the end of a region of code that constitutes a task */ +#define ANNOTATE_TASK_END(...) _ANNOTATE_CALL_0(task_end_2) + +/* Mark the break between one task and the next task (a "split" description model + * rather than a "begin/end" description model. */ +#define ANNOTATE_ITERATION_TASK(_TASK) _ANNOTATE_CALL_1(iteration_taskA, #_TASK) + +/* Acquire a lock identified by lockId */ +#define ANNOTATE_LOCK_ACQUIRE(_ADDR) _ANNOTATE_CALL_1(lock_acquire_2, (_ADDR)) + +/* Release a lock identified by lockId */ +#define ANNOTATE_LOCK_RELEASE(_ADDR) _ANNOTATE_CALL_1(lock_release_2, (_ADDR)) + +/* Record user allocation of memory */ +#define ANNOTATE_RECORD_ALLOCATION(_ADDR, _SIZE) _ANNOTATE_CALL_2(record_allocation, (_ADDR), (_SIZE)) + +/* Record user deallocation of memory */ +#define ANNOTATE_RECORD_DEALLOCATION(_ADDR) _ANNOTATE_CALL_1(record_deallocation, (_ADDR)) + +/* Denote storage as an inductive value */ +#define ANNOTATE_INDUCTION_USES(_ADDR, _SIZE) _ANNOTATE_CALL_2(induction_uses, (_ADDR), (_SIZE)) + +/* Denote storage as a reduction */ +#define ANNOTATE_REDUCTION_USES(_ADDR, _SIZE) _ANNOTATE_CALL_2(reduction_uses, (_ADDR), (_SIZE)) + +/* Record all observations of uses */ +#define ANNOTATE_OBSERVE_USES(_ADDR, _SIZE) _ANNOTATE_CALL_2(observe_uses, (_ADDR), (_SIZE)) + +/* Clear handling of values */ +#define ANNOTATE_CLEAR_USES(_ADDR) _ANNOTATE_CALL_1(clear_uses, (_ADDR)) + +/* Push disable of observations */ +#define ANNOTATE_DISABLE_OBSERVATION_PUSH _ANNOTATE_CALL_1(disable_push, __itt_model_disable_observation) + +/* Pop disable of observations */ +#define ANNOTATE_DISABLE_OBSERVATION_POP _ANNOTATE_CALL_0(disable_pop) + +/* Push disable of collection */ +#define ANNOTATE_DISABLE_COLLECTION_PUSH _ANNOTATE_CALL_1(disable_push, __itt_model_disable_collection) + +/* Pop disable of collection */ +#define ANNOTATE_DISABLE_COLLECTION_POP _ANNOTATE_CALL_0(disable_pop) + +/* Task aggregation */ +#define ANNOTATE_AGGREGATE_TASK(_COUNT) _ANNOTATE_CALL_1(aggregate_task, (_COUNT)) + +#endif + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif /* ANNOTATE_EXPAND_NULL */ + +#endif /* _ADVISOR_ANNOTATE_H_ */ diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/a64.h b/phivenv/Lib/site-packages/torch/include/asmjit/a64.h new file mode 100644 index 0000000000000000000000000000000000000000..1c19e2f72cc9922425e709e2d157152da7d089a1 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/a64.h @@ -0,0 +1,60 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_A64_H_INCLUDED +#define ASMJIT_A64_H_INCLUDED + +//! \addtogroup asmjit_a64 +//! +//! ### Emitters +//! +//! - \ref a64::Assembler - AArch64 assembler (must read, provides examples). +//! - \ref a64::Builder - AArch64 builder. +//! - \ref a64::Compiler - AArch64 compiler. +//! - \ref a64::Emitter - AArch64 emitter (abstract). +//! +//! ### Supported Instructions +//! +//! - Emitters: +//! - \ref a64::EmitterExplicitT - Provides all instructions that use explicit operands, provides also utility +//! functions. The member functions provided are part of all AArch64 emitters. +//! +//! - Instruction representation: +//! - \ref a64::Inst::Id - instruction identifiers. +//! +//! ### Register Operands +//! +//! - \ref arm::Reg - Base class of all AArch32/AArch64 registers. +//! - \ref a64::Gp - General purpose register (AArch64): +//! - \ref a64::GpW - 32-bit general purpose register (AArch64). +//! - \ref a64::GpX - 64-bit general purpose register (AArch64). +//! - \ref a64::Vec - Vector (SIMD) register: +//! - \ref a64::VecB - 8-bit SIMD register. +//! - \ref a64::VecH - 16-bit SIMD register. +//! - \ref a64::VecS - 32-bit SIMD register. +//! - \ref a64::VecD - 64-bit SIMD register. +//! - \ref a64::VecV - 128-bit SIMD register. +//! +//! ### Memory Operands +//! +//! - \ref arm::Mem - AArch32/AArch64 memory operand that provides support for all ARM addressing features +//! including base, index, pre/post increment, and ARM-specific shift addressing and index extending. +//! +//! ### Other +//! +//! - \ref arm::Shift - Shift operation and value. +//! - \ref arm::Utils - Utilities that can help during code generation for AArch32 and AArch64. + +#include "./arm.h" +#include "./arm/a64assembler.h" +#include "./arm/a64builder.h" +#include "./arm/a64compiler.h" +#include "./arm/a64emitter.h" +#include "./arm/a64globals.h" +#include "./arm/a64instdb.h" +#include "./arm/a64operand.h" + +#endif // ASMJIT_A64_H_INCLUDED + diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm.h new file mode 100644 index 0000000000000000000000000000000000000000..65a17212c40b8a72e87204f5c661311bbdbaff68 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm.h @@ -0,0 +1,84 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_H_INCLUDED +#define ASMJIT_ARM_H_INCLUDED + +//! \addtogroup asmjit_arm +//! +//! ### Namespaces +//! +//! - \ref arm - arm namespace provides common functionality for both AArch32 and AArch64 backends. +//! - \ref a32 - a32 namespace provides support for AArch32 architecture. In addition it includes +//! \ref arm namespace, so you can only use a single namespace when targeting AArch32 architecture. +//! - \ref a64 - a64 namespace provides support for AArch64 architecture. In addition it includes +//! \ref arm namespace, so you can only use a single namespace when targeting AArch64 architecture. +//! +//! ### Emitters +//! +//! - AArch32 +//! - \ref a32::Assembler - AArch32 assembler (must read, provides examples). +//! - \ref a32::Builder - AArch32 builder. +//! - \ref a32::Compiler - AArch32 compiler. +//! - \ref a32::Emitter - AArch32 emitter (abstract). +//! +//! - AArch64 +//! - \ref a64::Assembler - AArch64 assembler (must read, provides examples). +//! - \ref a64::Builder - AArch64 builder. +//! - \ref a64::Compiler - AArch64 compiler. +//! - \ref a64::Emitter - AArch64 emitter (abstract). +//! +//! ### Supported Instructions +//! +//! - AArch32: +//! - Emitters: +//! - \ref a32::EmitterExplicitT - Provides all instructions that use explicit operands, provides also +//! utility functions. The member functions provided are part of all AArch32 emitters. +//! - Instruction representation: +//! - \ref a32::Inst::Id - instruction identifiers. +//! +//! - AArch64: +//! - Emitters: +//! - \ref a64::EmitterExplicitT - Provides all instructions that use explicit operands, provides also +//! utility functions. The member functions provided are part of all AArch64 emitters. +//! - Instruction representation: +//! - \ref a64::Inst::Id - instruction identifiers. +//! +//! ### Register Operands +//! +//! - \ref arm::Reg - Base class of all AArch32/AArch64 registers. +//! - \ref a32::Gp - 32-bit general purpose register used by AArch32: +//! - \ref a64::Gp - 32-bit or 64-bit general purpose register used by AArch64: +//! - \ref a64::GpW - 32-bit register (AArch64). +//! - \ref a64::GpX - 64-bit register (AArch64). +//! - \ref arm::BaseVec - Base vector (SIMD) register. +//! - \ref a32::Vec - Vector (SIMD) register (AArch32): +//! - \ref a32::VecS - 32-bit SIMD register (AArch32). +//! - \ref a32::VecD - 64-bit SIMD register (AArch32). +//! - \ref a32::VecV - 128-bit SIMD register (AArch32). +//! - \ref a64::Vec - Vector (SIMD) register (AArch64): +//! - \ref a64::VecB - 8-bit SIMD register (AArch64). +//! - \ref a64::VecH - 16-bit SIMD register (AArch64). +//! - \ref a64::VecS - 32-bit SIMD register (AArch64). +//! - \ref a64::VecD - 64-bit SIMD register (AArch64). +//! - \ref a64::VecV - 128-bit SIMD register (AArch64). +//! +//! ### Memory Operands +//! +//! - \ref arm::Mem - AArch32/AArch64 memory operand that provides support for all ARM addressing features +//! including base, index, pre/post increment, and ARM-specific shift addressing and index extending. +//! +//! ### Other +//! +//! - \ref arm::Shift - Shift operation and value (both AArch32 and AArch64). +//! - \ref arm::DataType - Data type that is part of an instruction in AArch32 mode. +//! - \ref arm::Utils - Utilities that can help during code generation for AArch32 and AArch64. + +#include "./core.h" +#include "./arm/armglobals.h" +#include "./arm/armoperand.h" +#include "./arm/armutils.h" + +#endif // ASMJIT_ARM_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64assembler.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64assembler.h new file mode 100644 index 0000000000000000000000000000000000000000..5bb756c37df13f1d39fe2a9e81989cfdaa657df8 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64assembler.h @@ -0,0 +1,61 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64ASSEMBLER_H_INCLUDED +#define ASMJIT_ARM_A64ASSEMBLER_H_INCLUDED + +#include "../core/assembler.h" +#include "../arm/a64emitter.h" +#include "../arm/a64operand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 assembler implementation. +class ASMJIT_VIRTAPI Assembler + : public BaseAssembler, + public EmitterExplicitT { + +public: + typedef BaseAssembler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API Assembler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Assembler() noexcept override; + + //! \} + + //! \name Emit + //! \{ + + ASMJIT_API Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* opExt) override; + + //! \} + + //! \name Align + //! \{ + + ASMJIT_API Error align(AlignMode alignMode, uint32_t alignment) override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_A64ASSEMBLER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64builder.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64builder.h new file mode 100644 index 0000000000000000000000000000000000000000..acb6b1bb3aeac22336df0fce5a0480895f40f569 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64builder.h @@ -0,0 +1,57 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64BUILDER_H_INCLUDED +#define ASMJIT_ARM_A64BUILDER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_BUILDER + +#include "../core/builder.h" +#include "../arm/a64emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 builder implementation. +class ASMJIT_VIRTAPI Builder + : public BaseBuilder, + public EmitterExplicitT { +public: + ASMJIT_NONCOPYABLE(Builder) + typedef BaseBuilder Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Builder(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Builder() noexcept override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_BUILDER +#endif // ASMJIT_ARM_A64BUILDER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64compiler.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..2d45b166e2c730fb07213d105372fa3050da7872 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64compiler.h @@ -0,0 +1,254 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64COMPILER_H_INCLUDED +#define ASMJIT_ARM_A64COMPILER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_COMPILER + +#include "../core/compiler.h" +#include "../core/type.h" +#include "../arm/a64emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 compiler implementation. +class ASMJIT_VIRTAPI Compiler + : public BaseCompiler, + public EmitterExplicitT { +public: + ASMJIT_NONCOPYABLE(Compiler) + typedef BaseCompiler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Compiler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Compiler() noexcept override; + + //! \} + + //! \name Virtual Registers + //! \{ + + //! \cond INTERNAL + template + ASMJIT_INLINE_NODEBUG RegT _newRegInternal(const Type& type) { + RegT reg(Globals::NoInit); + _newReg(®, type, nullptr); + return reg; + } + + template + ASMJIT_INLINE_NODEBUG RegT _newRegInternal(const Type& type, const char* s) { +#ifndef ASMJIT_NO_LOGGING + RegT reg(Globals::NoInit); + _newReg(®, type, s); + return reg; +#else + DebugUtils::unused(s); + return _newRegInternal(type); +#endif + } + + template + ASMJIT_INLINE_NODEBUG RegT _newRegInternal(const Type& type, const char* s, Args&&... args) { +#ifndef ASMJIT_NO_LOGGING + RegT reg(Globals::NoInit); + _newRegFmt(®, type, s, std::forward(args)...); + return reg; +#else + DebugUtils::unused(s, std::forward(args)...); + return _newRegInternal(type); +#endif + } + //! \endcond + + template + ASMJIT_INLINE_NODEBUG RegT newSimilarReg(const RegT& ref, Args&&... args) { + return _newRegInternal(ref, std::forward(args)...); + } + + template + ASMJIT_INLINE_NODEBUG Reg newReg(TypeId typeId, Args&&... args) { return _newRegInternal(typeId, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newGp(TypeId typeId, Args&&... args) { return _newRegInternal(typeId, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVec(TypeId typeId, Args&&... args) { return _newRegInternal(typeId, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newInt32(Args&&... args) { return _newRegInternal(TypeId::kInt32, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newUInt32(Args&&... args) { return _newRegInternal(TypeId::kUInt32, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newInt64(Args&&... args) { return _newRegInternal(TypeId::kInt64, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newUInt64(Args&&... args) { return _newRegInternal(TypeId::kUInt64, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newIntPtr(Args&&... args) { return _newRegInternal(TypeId::kIntPtr, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newUIntPtr(Args&&... args) { return _newRegInternal(TypeId::kUIntPtr, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Gp newGpw(Args&&... args) { return _newRegInternal(TypeId::kUInt32, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newGpx(Args&&... args) { return _newRegInternal(TypeId::kUInt64, std::forward(args)...); } + template + ASMJIT_INLINE_NODEBUG Gp newGpz(Args&&... args) { return _newRegInternal(TypeId::kUIntPtr, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVecS(Args&&... args) { return _newRegInternal(TypeId::kFloat32, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVecD(Args&&... args) { return _newRegInternal(TypeId::kFloat64, std::forward(args)...); } + + template + ASMJIT_INLINE_NODEBUG Vec newVecQ(Args&&... args) { return _newRegInternal(TypeId::kUInt8x16, std::forward(args)...); } + + //! \} + + //! \name Stack + //! \{ + + //! Creates a new memory chunk allocated on the current function's stack. + ASMJIT_INLINE_NODEBUG Mem newStack(uint32_t size, uint32_t alignment, const char* name = nullptr) { + Mem m(Globals::NoInit); + _newStack(&m, size, alignment, name); + return m; + } + + //! \} + + //! \name Constants + //! \{ + + //! Put data to a constant-pool and get a memory reference to it. + ASMJIT_INLINE_NODEBUG Mem newConst(ConstPoolScope scope, const void* data, size_t size) { + Mem m(Globals::NoInit); + _newConst(&m, scope, data, size); + return m; + } + + //! Put a BYTE `val` to a constant-pool (8 bits). + ASMJIT_INLINE_NODEBUG Mem newByteConst(ConstPoolScope scope, uint8_t val) noexcept { return newConst(scope, &val, 1); } + //! Put a HWORD `val` to a constant-pool (16 bits). + ASMJIT_INLINE_NODEBUG Mem newHWordConst(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a WORD `val` to a constant-pool (32 bits). + ASMJIT_INLINE_NODEBUG Mem newWordConst(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a DWORD `val` to a constant-pool (64 bits). + ASMJIT_INLINE_NODEBUG Mem newDWordConst(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt16Const(ConstPoolScope scope, int16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt16Const(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt32Const(ConstPoolScope scope, int32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt32Const(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt64Const(ConstPoolScope scope, int64_t val) noexcept { return newConst(scope, &val, 8); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt64Const(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a SP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newFloatConst(ConstPoolScope scope, float val) noexcept { return newConst(scope, &val, 4); } + //! Put a DP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newDoubleConst(ConstPoolScope scope, double val) noexcept { return newConst(scope, &val, 8); } + + //! \} + + //! \name Instruction Options + //! \{ + + //! Force the compiler to not follow the conditional or unconditional jump. + ASMJIT_INLINE_NODEBUG Compiler& unfollow() noexcept { _instOptions |= InstOptions::kUnfollow; return *this; } + + //! \} + + //! \name Compiler specific + //! \{ + + //! Special pseudo-instruction that can be used to load a memory address into `o0` GP register. + //! + //! \note At the moment this instruction is only useful to load a stack allocated address into a GP register + //! for further use. It makes very little sense to use it for anything else. The semantics of this instruction + //! is the same as X86 `LEA` (load effective address) instruction. + ASMJIT_INLINE_NODEBUG Error loadAddressOf(const Gp& o0, const Mem& o1) { return _emitter()->_emitI(Inst::kIdAdr, o0, o1); } + + //! \} + + //! \name Function Call & Ret Intrinsics + //! \{ + + //! Invoke a function call without `target` type enforcement. + ASMJIT_INLINE_NODEBUG Error invoke_(InvokeNode** out, const Operand_& target, const FuncSignature& signature) { + return addInvokeNode(out, Inst::kIdBlr, target, signature); + } + + //! Invoke a function call of the given `target` and `signature` and store the added node to `out`. + //! + //! Creates a new \ref InvokeNode, initializes all the necessary members to match the given function `signature`, + //! adds the node to the compiler, and stores its pointer to `out`. The operation is atomic, if anything fails + //! nullptr is stored in `out` and error code is returned. + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Gp& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Mem& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Label& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Imm& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, uint64_t target, const FuncSignature& signature) { return invoke_(out, Imm(int64_t(target)), signature); } + + //! Return. + ASMJIT_INLINE_NODEBUG Error ret() { return addRet(Operand(), Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0) { return addRet(o0, Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0, const BaseReg& o1) { return addRet(o0, o1); } + + //! \} + + //! \name Jump Tables Support + //! \{ + + using EmitterExplicitT::br; + + //! Adds a jump to the given `target` with the provided jump `annotation`. + ASMJIT_INLINE_NODEBUG Error br(const BaseReg& target, JumpAnnotation* annotation) { return emitAnnotatedJump(Inst::kIdBr, target, annotation); } + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_COMPILER +#endif // ASMJIT_ARM_A64COMPILER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64emitter.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..64a1f827fa98b60f830dc2e12a017f7e76a636b3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64emitter.h @@ -0,0 +1,1232 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64EMITTER_H_INCLUDED +#define ASMJIT_ARM_A64EMITTER_H_INCLUDED + +#include "../core/emitter.h" +#include "../core/support.h" +#include "../arm/a64instdb.h" +#include "../arm/a64operand.h" + +// MSVC targeting AArch64 defines a lot of macros without underscores clashing +// with AArch64 instruction names. We have to workaround until it's fixed in SDK. +#if defined(_MSC_VER) && defined(mvn) + #define ASMJIT_RESTORE_MSVC_AARCH64_MACROS + #pragma push_macro("mvn") + #undef mvn +#endif + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +#define ASMJIT_INST_0x(NAME, ID) \ + inline Error NAME() { return _emitter()->_emitI(Inst::kId##ID); } + +#define ASMJIT_INST_1x(NAME, ID, T0) \ + inline Error NAME(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID, o0); } + +#define ASMJIT_INST_2x(NAME, ID, T0, T1) \ + inline Error NAME(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID, o0, o1); } + +#define ASMJIT_INST_3x(NAME, ID, T0, T1, T2) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2); } + +#define ASMJIT_INST_4x(NAME, ID, T0, T1, T2, T3) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3); } + +#define ASMJIT_INST_5x(NAME, ID, T0, T1, T2, T3, T4) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4); } + +#define ASMJIT_INST_6x(NAME, ID, T0, T1, T2, T3, T4, T5) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4, const T5& o5) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4, o5); } + +#define ASMJIT_INST_1cc(NAME, ID, T0) \ + inline Error NAME(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID, o0); } \ + \ + inline Error NAME(CondCode cc, const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, cc), o0); } \ + \ + inline Error NAME##_eq(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kEQ), o0); } \ + inline Error NAME##_ne(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kNE), o0); } \ + inline Error NAME##_cs(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kCS), o0); } \ + inline Error NAME##_hs(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kHS), o0); } \ + inline Error NAME##_cc(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kCC), o0); } \ + inline Error NAME##_lo(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLO), o0); } \ + inline Error NAME##_mi(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kMI), o0); } \ + inline Error NAME##_pl(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kPL), o0); } \ + inline Error NAME##_vs(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kVS), o0); } \ + inline Error NAME##_vc(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kVC), o0); } \ + inline Error NAME##_hi(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kHI), o0); } \ + inline Error NAME##_ls(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLS), o0); } \ + inline Error NAME##_ge(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kGE), o0); } \ + inline Error NAME##_lt(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLT), o0); } \ + inline Error NAME##_gt(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kGT), o0); } \ + inline Error NAME##_le(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kLE), o0); } \ + inline Error NAME##_al(const T0& o0) { return _emitter()->_emitI(BaseInst::composeARMInstId(Inst::kId##ID, CondCode::kAL), o0); } + +//! \addtogroup asmjit_a64 +//! \{ + +//! ARM emitter. +//! +//! NOTE: This class cannot be instantiated, you can only cast to it and use it as emitter that emits to either +//! \ref Assembler, \ref Builder, or \ref Compiler (use with caution with \ref Compiler as it expects virtual +//! registers to be used). +template +struct EmitterExplicitT { + //! \cond + + // These two are unfortunately reported by the sanitizer. We know what we do, however, the sanitizer doesn't. + // I have tried to use reinterpret_cast instead, but that would generate bad code when compiled by MSC. + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG This* _emitter() noexcept { return static_cast(this); } + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG const This* _emitter() const noexcept { return static_cast(this); } + + //! \endcond + + //! \name General Purpose Instructions + //! \{ + + ASMJIT_INST_3x(adc, Adc, Gp, Gp, Gp) + ASMJIT_INST_3x(adcs, Adcs, Gp, Gp, Gp) + + ASMJIT_INST_3x(add, Add, Gp, Gp, Gp) + ASMJIT_INST_4x(add, Add, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(add, Add, Gp, Gp, Imm) + ASMJIT_INST_4x(add, Add, Gp, Gp, Imm, Imm) + ASMJIT_INST_3x(adds, Adds, Gp, Gp, Gp) + ASMJIT_INST_3x(adds, Adds, Gp, Gp, Imm) + ASMJIT_INST_4x(adds, Adds, Gp, Gp, Gp, Imm) + ASMJIT_INST_4x(adds, Adds, Gp, Gp, Imm, Imm) + + ASMJIT_INST_2x(adr, Adr, Gp, Imm) + ASMJIT_INST_2x(adr, Adr, Gp, Label) + ASMJIT_INST_2x(adrp, Adrp, Gp, Imm) + ASMJIT_INST_2x(adrp, Adrp, Gp, Label) + + ASMJIT_INST_3x(and_, And, Gp, Gp, Imm) + ASMJIT_INST_3x(and_, And, Gp, Gp, Gp) + ASMJIT_INST_4x(and_, And, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(ands, Ands, Gp, Gp, Imm) + ASMJIT_INST_3x(ands, Ands, Gp, Gp, Gp) + ASMJIT_INST_4x(ands, Ands, Gp, Gp, Gp, Imm) + + ASMJIT_INST_3x(asr, Asr, Gp, Gp, Imm) + ASMJIT_INST_3x(asr, Asr, Gp, Gp, Gp) + ASMJIT_INST_3x(asrv, Asrv, Gp, Gp, Gp) + + ASMJIT_INST_2x(at, At, Imm, Gp) + + ASMJIT_INST_3x(bfc, Bfc, Gp, Imm, Imm) + ASMJIT_INST_4x(bfi, Bfi, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(bfm, Bfm, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(bfxil, Bfxil, Gp, Gp, Imm, Imm) + + ASMJIT_INST_3x(bic, Bic, Gp, Gp, Imm); + ASMJIT_INST_3x(bic, Bic, Gp, Gp, Gp); + ASMJIT_INST_4x(bic, Bic, Gp, Gp, Gp, Imm); + ASMJIT_INST_3x(bics, Bics, Gp, Gp, Imm); + ASMJIT_INST_3x(bics, Bics, Gp, Gp, Gp); + ASMJIT_INST_4x(bics, Bics, Gp, Gp, Gp, Imm); + + ASMJIT_INST_1x(brk, Brk, Imm) + + ASMJIT_INST_4x(ccmn, Ccmn, Gp, Gp, Imm, Imm); + ASMJIT_INST_4x(ccmn, Ccmn, Gp, Imm, Imm, Imm); + ASMJIT_INST_4x(ccmp, Ccmp, Gp, Gp, Imm, Imm); + ASMJIT_INST_4x(ccmp, Ccmp, Gp, Imm, Imm, Imm); + + ASMJIT_INST_3x(cinc, Cinc, Gp, Gp, Imm); + ASMJIT_INST_3x(cinv, Cinv, Gp, Gp, Imm); + + ASMJIT_INST_1x(clrex, Clrex, Imm) + + ASMJIT_INST_2x(cls, Cls, Gp, Gp) + ASMJIT_INST_2x(clz, Clz, Gp, Gp) + + ASMJIT_INST_2x(cmn, Cmn, Gp, Gp) + ASMJIT_INST_3x(cmn, Cmn, Gp, Gp, Imm) + ASMJIT_INST_2x(cmn, Cmn, Gp, Imm) + ASMJIT_INST_3x(cmn, Cmn, Gp, Imm, Imm) + ASMJIT_INST_2x(cmp, Cmp, Gp, Gp) + ASMJIT_INST_3x(cmp, Cmp, Gp, Gp, Imm) + ASMJIT_INST_2x(cmp, Cmp, Gp, Imm) + ASMJIT_INST_3x(cmp, Cmp, Gp, Imm, Imm) + + ASMJIT_INST_3x(cneg, Cneg, Gp, Gp, Imm); + + ASMJIT_INST_4x(csel, Csel, Gp, Gp, Gp, Imm); + ASMJIT_INST_2x(cset, Cset, Gp, Imm); + ASMJIT_INST_2x(csetm, Csetm, Gp, Imm); + + ASMJIT_INST_4x(csinc, Csinc, Gp, Gp, Gp, Imm); + ASMJIT_INST_4x(csinv, Csinv, Gp, Gp, Gp, Imm); + ASMJIT_INST_4x(csneg, Csneg, Gp, Gp, Gp, Imm); + + ASMJIT_INST_2x(dc, Dc, Imm, Gp) + ASMJIT_INST_1x(dmb, Dmb, Imm) + ASMJIT_INST_1x(dsb, Dsb, Imm) + ASMJIT_INST_0x(drps, Drps) + + ASMJIT_INST_3x(eon, Eon, Gp, Gp, Gp) + ASMJIT_INST_4x(eon, Eon, Gp, Gp, Gp, Imm) + + ASMJIT_INST_3x(eor, Eor, Gp, Gp, Imm) + ASMJIT_INST_3x(eor, Eor, Gp, Gp, Gp) + ASMJIT_INST_4x(eor, Eor, Gp, Gp, Gp, Imm) + + ASMJIT_INST_0x(eret, Eret) + ASMJIT_INST_0x(esb, Esb) + + ASMJIT_INST_4x(extr, Extr, Gp, Gp, Gp, Imm) + + ASMJIT_INST_1x(hlt, Hlt, Imm) + ASMJIT_INST_1x(hvc, Hvc, Imm) + ASMJIT_INST_2x(ic, Ic, Imm, Gp) + ASMJIT_INST_1x(isb, Isb, Imm) + + ASMJIT_INST_3x(lsl, Lsl, Gp, Gp, Imm) + ASMJIT_INST_3x(lsl, Lsl, Gp, Gp, Gp) + ASMJIT_INST_3x(lslv, Lslv, Gp, Gp, Gp) + + ASMJIT_INST_3x(lsr, Lsr, Gp, Gp, Imm) + ASMJIT_INST_3x(lsr, Lsr, Gp, Gp, Gp) + ASMJIT_INST_3x(lsrv, Lsrv, Gp, Gp, Gp) + + ASMJIT_INST_4x(madd, Madd, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(mneg, Mneg, Gp, Gp, Gp) + + ASMJIT_INST_2x(mov, Mov, Gp, Gp) + ASMJIT_INST_2x(mov, Mov, Gp, Imm) + ASMJIT_INST_2x(movk, Movk, Gp, Imm) + ASMJIT_INST_3x(movk, Movk, Gp, Imm, Imm) + ASMJIT_INST_2x(movn, Movn, Gp, Imm) + ASMJIT_INST_3x(movn, Movn, Gp, Imm, Imm) + ASMJIT_INST_2x(movz, Movz, Gp, Imm) + ASMJIT_INST_3x(movz, Movz, Gp, Imm, Imm) + + ASMJIT_INST_2x(mrs, Mrs, Gp, Imm) + ASMJIT_INST_2x(msr, Msr, Imm, Gp) + ASMJIT_INST_2x(msr, Msr, Imm, Imm) + + ASMJIT_INST_4x(msub, Msub, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(mul, Mul, Gp, Gp, Gp) + + ASMJIT_INST_2x(mvn, Mvn, Gp, Gp) + ASMJIT_INST_3x(mvn, Mvn, Gp, Gp, Imm) + + ASMJIT_INST_2x(neg, Neg, Gp, Gp) + ASMJIT_INST_3x(neg, Neg, Gp, Gp, Imm) + ASMJIT_INST_2x(negs, Negs, Gp, Gp) + ASMJIT_INST_3x(negs, Negs, Gp, Gp, Imm) + + ASMJIT_INST_2x(ngc, Ngc, Gp, Gp) + ASMJIT_INST_2x(ngcs, Ngcs, Gp, Gp) + + ASMJIT_INST_3x(orn, Orn, Gp, Gp, Gp) + ASMJIT_INST_4x(orn, Orn, Gp, Gp, Gp, Imm) + + ASMJIT_INST_3x(orr, Orr, Gp, Gp, Imm) + ASMJIT_INST_3x(orr, Orr, Gp, Gp, Gp) + ASMJIT_INST_4x(orr, Orr, Gp, Gp, Gp, Imm) + + ASMJIT_INST_2x(rbit, Rbit, Gp, Gp) + ASMJIT_INST_1x(ret, Ret, Gp) + + ASMJIT_INST_2x(rev, Rev, Gp, Gp) + ASMJIT_INST_2x(rev16, Rev16, Gp, Gp) + ASMJIT_INST_2x(rev32, Rev32, Gp, Gp) + ASMJIT_INST_2x(rev64, Rev64, Gp, Gp) + + ASMJIT_INST_3x(ror, Ror, Gp, Gp, Imm) + ASMJIT_INST_3x(ror, Ror, Gp, Gp, Gp) + ASMJIT_INST_3x(rorv, Rorv, Gp, Gp, Gp) + + ASMJIT_INST_3x(sbc, Sbc, Gp, Gp, Gp) + ASMJIT_INST_3x(sbcs, Sbcs, Gp, Gp, Gp) + + ASMJIT_INST_4x(sbfiz, Sbfiz, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(sbfm, Sbfm, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(sbfx, Sbfx, Gp, Gp, Imm, Imm) + + ASMJIT_INST_3x(sdiv, Sdiv, Gp, Gp, Gp) + + ASMJIT_INST_4x(smaddl, Smaddl, Gp, Gp, Gp, Gp) + ASMJIT_INST_1x(smc, Smc, Imm) + ASMJIT_INST_3x(smnegl, Smnegl, Gp, Gp, Gp) + ASMJIT_INST_4x(smsubl, Smsubl, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(smulh, Smulh, Gp, Gp, Gp) + ASMJIT_INST_3x(smull, Smull, Gp, Gp, Gp) + + ASMJIT_INST_3x(sub, Sub, Gp, Gp, Gp) + ASMJIT_INST_4x(sub, Sub, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(sub, Sub, Gp, Gp, Imm) + ASMJIT_INST_4x(sub, Sub, Gp, Gp, Imm, Imm) + ASMJIT_INST_3x(subs, Subs, Gp, Gp, Gp) + ASMJIT_INST_4x(subs, Subs, Gp, Gp, Gp, Imm) + ASMJIT_INST_3x(subs, Subs, Gp, Gp, Imm) + ASMJIT_INST_4x(subs, Subs, Gp, Gp, Imm, Imm) + + ASMJIT_INST_1x(svc, Svc, Imm) + + ASMJIT_INST_2x(sxtb, Sxtb, Gp, Gp) + ASMJIT_INST_2x(sxth, Sxth, Gp, Gp) + ASMJIT_INST_2x(sxtw, Sxtw, Gp, Gp) + + ASMJIT_INST_4x(sys, Sys, Imm, Imm, Imm, Imm) + ASMJIT_INST_5x(sys, Sys, Imm, Imm, Imm, Imm, Gp) + + ASMJIT_INST_2x(tlbi, Tlbi, Imm, Gp) + ASMJIT_INST_2x(tst, Tst, Gp, Imm) + ASMJIT_INST_2x(tst, Tst, Gp, Gp) + ASMJIT_INST_3x(tst, Tst, Gp, Gp, Imm) + + ASMJIT_INST_3x(udiv, Udiv, Gp, Gp, Gp) + + ASMJIT_INST_4x(ubfiz, Ubfiz, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(ubfm, Ubfm, Gp, Gp, Imm, Imm) + ASMJIT_INST_4x(ubfx, Ubfx, Gp, Gp, Imm, Imm) + + ASMJIT_INST_4x(umaddl, Umaddl, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(umnegl, Umnegl, Gp, Gp, Gp) + ASMJIT_INST_4x(umsubl, Umsubl, Gp, Gp, Gp, Gp) + ASMJIT_INST_3x(umull, Umull, Gp, Gp, Gp) + ASMJIT_INST_3x(umulh, Umulh, Gp, Gp, Gp) + + ASMJIT_INST_2x(uxtb, Uxtb, Gp, Gp) + ASMJIT_INST_2x(uxth, Uxth, Gp, Gp) + + ASMJIT_INST_0x(csdb, Csdb) + ASMJIT_INST_1x(dcps1, Dcps1, Imm) + ASMJIT_INST_1x(dcps2, Dcps2, Imm) + ASMJIT_INST_1x(dcps3, Dcps3, Imm) + ASMJIT_INST_0x(dgh, Dgh) + ASMJIT_INST_0x(pssbb, Pssbb) + ASMJIT_INST_0x(ssbb, Ssbb) + ASMJIT_INST_1x(udf, Udf, Imm) + ASMJIT_INST_1x(setf8, Setf8, Gp) + ASMJIT_INST_1x(setf16, Setf16, Gp) + + //! \} + + //! \name ARMv8.4 Instructions + //! \{ + + ASMJIT_INST_0x(cfinv, Cfinv) + + //! \} + + //! \name ARMv8.5 Instructions + //! \{ + + ASMJIT_INST_0x(axflag, Axflag) + ASMJIT_INST_0x(xaflag, Xaflag) + + //! \} + + //! \name Branch Instructions + //! \{ + + ASMJIT_INST_1cc(b, B, Imm) + ASMJIT_INST_1cc(b, B, Label) + ASMJIT_INST_1x(bl, Bl, Imm) + ASMJIT_INST_1x(bl, Bl, Label) + ASMJIT_INST_1x(blr, Blr, Gp) + ASMJIT_INST_1x(br, Br, Gp) + ASMJIT_INST_2x(cbz, Cbz, Gp, Imm) + ASMJIT_INST_2x(cbz, Cbz, Gp, Label) + ASMJIT_INST_2x(cbnz, Cbnz, Gp, Imm) + ASMJIT_INST_2x(cbnz, Cbnz, Gp, Label) + ASMJIT_INST_3x(tbnz, Tbnz, Gp, Imm, Imm) + ASMJIT_INST_3x(tbnz, Tbnz, Gp, Imm, Label) + ASMJIT_INST_3x(tbz, Tbz, Gp, Imm, Imm) + ASMJIT_INST_3x(tbz, Tbz, Gp, Imm, Label) + + //! \} + + //! \name Load & Store Instructions + //! \{ + + ASMJIT_INST_3x(cas, Cas, Gp, Gp, Mem) + ASMJIT_INST_3x(casa, Casa, Gp, Gp, Mem) + ASMJIT_INST_3x(casab, Casab, Gp, Gp, Mem) + ASMJIT_INST_3x(casah, Casah, Gp, Gp, Mem) + ASMJIT_INST_3x(casal, Casal, Gp, Gp, Mem) + ASMJIT_INST_3x(casalb, Casalb, Gp, Gp, Mem) + ASMJIT_INST_3x(casalh, Casalh, Gp, Gp, Mem) + ASMJIT_INST_3x(casb, Casb, Gp, Gp, Mem) + ASMJIT_INST_3x(cash, Cash, Gp, Gp, Mem) + ASMJIT_INST_3x(casl, Casl, Gp, Gp, Mem) + ASMJIT_INST_3x(caslb, Caslb, Gp, Gp, Mem) + ASMJIT_INST_3x(caslh, Caslh, Gp, Gp, Mem) + + ASMJIT_INST_5x(casp, Casp, Gp, Gp, Gp, Gp, Mem) + ASMJIT_INST_5x(caspa, Caspa, Gp, Gp, Gp, Gp, Mem) + ASMJIT_INST_5x(caspal, Caspal, Gp, Gp, Gp, Gp, Mem) + ASMJIT_INST_5x(caspl, Caspl, Gp, Gp, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldadd, Ldadd, Gp, Gp, Mem) + ASMJIT_INST_3x(ldadda, Ldadda, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddab, Ldaddab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddah, Ldaddah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddal, Ldaddal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddalb, Ldaddalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddalh, Ldaddalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddb, Ldaddb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddh, Ldaddh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddl, Ldaddl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddlb, Ldaddlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaddlh, Ldaddlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldar, Ldar, Gp, Mem) + ASMJIT_INST_2x(ldarb, Ldarb, Gp, Mem) + ASMJIT_INST_2x(ldarh, Ldarh, Gp, Mem) + + ASMJIT_INST_2x(ldaxr, Ldaxr, Gp, Mem) + ASMJIT_INST_2x(ldaxrb, Ldaxrb, Gp, Mem) + ASMJIT_INST_2x(ldaxrh, Ldaxrh, Gp, Mem) + + ASMJIT_INST_3x(ldclr, Ldclr, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclra, Ldclra, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrab, Ldclrab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrah, Ldclrah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclral, Ldclral, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclralb, Ldclralb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclralh, Ldclralh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrb, Ldclrb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrh, Ldclrh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrl, Ldclrl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrlb, Ldclrlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldclrlh, Ldclrlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldeor, Ldeor, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeora, Ldeora, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorab, Ldeorab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorah, Ldeorah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeoral, Ldeoral, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeoralb, Ldeoralb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeoralh, Ldeoralh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorb, Ldeorb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorh, Ldeorh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorl, Ldeorl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorlb, Ldeorlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldeorlh, Ldeorlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldlar, Ldlar, Gp, Mem) + ASMJIT_INST_2x(ldlarb, Ldlarb, Gp, Mem) + ASMJIT_INST_2x(ldlarh, Ldlarh, Gp, Mem) + + ASMJIT_INST_3x(ldnp, Ldnp, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldp, Ldp, Gp, Gp, Mem) + ASMJIT_INST_3x(ldpsw, Ldpsw, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldr, Ldr, Gp, Mem) + ASMJIT_INST_2x(ldrb, Ldrb, Gp, Mem) + ASMJIT_INST_2x(ldrh, Ldrh, Gp, Mem) + ASMJIT_INST_2x(ldrsb, Ldrsb, Gp, Mem) + ASMJIT_INST_2x(ldrsh, Ldrsh, Gp, Mem) + ASMJIT_INST_2x(ldrsw, Ldrsw, Gp, Mem) + + ASMJIT_INST_3x(ldset, Ldset, Gp, Gp, Mem) + ASMJIT_INST_3x(ldseta, Ldseta, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetab, Ldsetab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetah, Ldsetah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetal, Ldsetal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetalb, Ldsetalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetalh, Ldsetalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetb, Ldsetb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldseth, Ldseth, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetl, Ldsetl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetlb, Ldsetlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsetlh, Ldsetlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldsmax, Ldsmax, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxa, Ldsmaxa, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxab, Ldsmaxab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxah, Ldsmaxah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxal, Ldsmaxal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxalb, Ldsmaxalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxalh, Ldsmaxalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxb, Ldsmaxb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxh, Ldsmaxh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxl, Ldsmaxl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxlb, Ldsmaxlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmaxlh, Ldsmaxlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldsmin, Ldsmin, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsmina, Ldsmina, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminab, Ldsminab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminah, Ldsminah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminal, Ldsminal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminalb, Ldsminalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminalh, Ldsminalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminb, Ldsminb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminh, Ldsminh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminl, Ldsminl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminlb, Ldsminlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldsminlh, Ldsminlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldtr, Ldtr, Gp, Mem) + ASMJIT_INST_2x(ldtrb, Ldtrb, Gp, Mem) + ASMJIT_INST_2x(ldtrh, Ldtrh, Gp, Mem) + ASMJIT_INST_2x(ldtrsb, Ldtrsb, Gp, Mem) + ASMJIT_INST_2x(ldtrsh, Ldtrsh, Gp, Mem) + ASMJIT_INST_2x(ldtrsw, Ldtrsw, Gp, Mem) + + ASMJIT_INST_3x(ldumax, Ldumax, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxa, Ldumaxa, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxab, Ldumaxab, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxah, Ldumaxah, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxal, Ldumaxal, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxalb, Ldumaxalb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxalh, Ldumaxalh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxb, Ldumaxb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxh, Ldumaxh, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxl, Ldumaxl, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxlb, Ldumaxlb, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumaxlh, Ldumaxlh, Gp, Gp, Mem) + + ASMJIT_INST_3x(ldumin, Ldumin, Gp, Gp, Mem) + ASMJIT_INST_3x(ldumina, Ldumina, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminab, Lduminab, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminah, Lduminah, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminal, Lduminal, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminalb, Lduminalb, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminalh, Lduminalh, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminb, Lduminb, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminh, Lduminh, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminl, Lduminl, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminlb, Lduminlb, Gp, Gp, Mem) + ASMJIT_INST_3x(lduminlh, Lduminlh, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldur, Ldur, Gp, Mem) + ASMJIT_INST_2x(ldurb, Ldurb, Gp, Mem) + ASMJIT_INST_2x(ldurh, Ldurh, Gp, Mem) + ASMJIT_INST_2x(ldursb, Ldursb, Gp, Mem) + ASMJIT_INST_2x(ldursh, Ldursh, Gp, Mem) + ASMJIT_INST_2x(ldursw, Ldursw, Gp, Mem) + + ASMJIT_INST_3x(ldxp, Ldxp, Gp, Gp, Mem) + ASMJIT_INST_3x(ldaxp, Ldaxp, Gp, Gp, Mem) + + ASMJIT_INST_2x(ldxr, Ldxr, Gp, Mem) + ASMJIT_INST_2x(ldxrb, Ldxrb, Gp, Mem) + ASMJIT_INST_2x(ldxrh, Ldxrh, Gp, Mem) + + ASMJIT_INST_2x(prfm, Prfm, Imm, Mem) + + ASMJIT_INST_2x(stadd, Stadd, Gp, Mem) + ASMJIT_INST_2x(staddb, Staddb, Gp, Mem) + ASMJIT_INST_2x(staddh, Staddh, Gp, Mem) + ASMJIT_INST_2x(staddl, Staddl, Gp, Mem) + ASMJIT_INST_2x(staddlb, Staddlb, Gp, Mem) + ASMJIT_INST_2x(staddlh, Staddlh, Gp, Mem) + + ASMJIT_INST_2x(stclr, Stclr, Gp, Mem) + ASMJIT_INST_2x(stclrb, Stclrb, Gp, Mem) + ASMJIT_INST_2x(stclrh, Stclrh, Gp, Mem) + ASMJIT_INST_2x(stclrl, Stclrl, Gp, Mem) + ASMJIT_INST_2x(stclrlb, Stclrlb, Gp, Mem) + ASMJIT_INST_2x(stclrlh, Stclrlh, Gp, Mem) + + ASMJIT_INST_2x(steor, Steor, Gp, Mem) + ASMJIT_INST_2x(steorb, Steorb, Gp, Mem) + ASMJIT_INST_2x(steorh, Steorh, Gp, Mem) + ASMJIT_INST_2x(steorl, Steorl, Gp, Mem) + ASMJIT_INST_2x(steorlb, Steorlb, Gp, Mem) + ASMJIT_INST_2x(steorlh, Steorlh, Gp, Mem) + + ASMJIT_INST_2x(stllr, Stllr, Gp, Mem) + ASMJIT_INST_2x(stllrb, Stllrb, Gp, Mem) + ASMJIT_INST_2x(stllrh, Stllrh, Gp, Mem) + + ASMJIT_INST_2x(stlr, Stllr, Gp, Mem) + ASMJIT_INST_2x(stlrb, Stllrb, Gp, Mem) + ASMJIT_INST_2x(stlrh, Stllrh, Gp, Mem) + + ASMJIT_INST_3x(stlxr, Stlxr, Gp, Gp, Mem) + ASMJIT_INST_3x(stlxrb, Stlxrb, Gp, Gp, Mem) + ASMJIT_INST_3x(stlxrh, Stlxrh, Gp, Gp, Mem) + + ASMJIT_INST_3x(stnp, Stnp, Gp, Gp, Mem) + ASMJIT_INST_3x(stp, Stp, Gp, Gp, Mem) + + ASMJIT_INST_2x(str, Str, Gp, Mem) + ASMJIT_INST_2x(strb, Strb, Gp, Mem) + ASMJIT_INST_2x(strh, Strh, Gp, Mem) + + ASMJIT_INST_2x(stset, Stset, Gp, Mem) + ASMJIT_INST_2x(stsetb, Stsetb, Gp, Mem) + ASMJIT_INST_2x(stseth, Stseth, Gp, Mem) + ASMJIT_INST_2x(stsetl, Stsetl, Gp, Mem) + ASMJIT_INST_2x(stsetlb, Stsetlb, Gp, Mem) + ASMJIT_INST_2x(stsetlh, Stsetlh, Gp, Mem) + + ASMJIT_INST_2x(stsmax, Stsmax, Gp, Mem) + ASMJIT_INST_2x(stsmaxb, Stsmaxb, Gp, Mem) + ASMJIT_INST_2x(stsmaxh, Stsmaxh, Gp, Mem) + ASMJIT_INST_2x(stsmaxl, Stsmaxl, Gp, Mem) + ASMJIT_INST_2x(stsmaxlb, Stsmaxlb, Gp, Mem) + ASMJIT_INST_2x(stsmaxlh, Stsmaxlh, Gp, Mem) + + ASMJIT_INST_2x(stsmin, Stsmin, Gp, Mem) + ASMJIT_INST_2x(stsminb, Stsminb, Gp, Mem) + ASMJIT_INST_2x(stsminh, Stsminh, Gp, Mem) + ASMJIT_INST_2x(stsminl, Stsminl, Gp, Mem) + ASMJIT_INST_2x(stsminlb, Stsminlb, Gp, Mem) + ASMJIT_INST_2x(stsminlh, Stsminlh, Gp, Mem) + + ASMJIT_INST_2x(sttr, Sttr, Gp, Mem) + ASMJIT_INST_2x(sttrb, Sttrb, Gp, Mem) + ASMJIT_INST_2x(sttrh, Sttrh, Gp, Mem) + + ASMJIT_INST_2x(stumax, Stumax, Gp, Mem) + ASMJIT_INST_2x(stumaxb, Stumaxb, Gp, Mem) + ASMJIT_INST_2x(stumaxh, Stumaxh, Gp, Mem) + ASMJIT_INST_2x(stumaxl, Stumaxl, Gp, Mem) + ASMJIT_INST_2x(stumaxlb, Stumaxlb, Gp, Mem) + ASMJIT_INST_2x(stumaxlh, Stumaxlh, Gp, Mem) + + ASMJIT_INST_2x(stumin, Stumin, Gp, Mem) + ASMJIT_INST_2x(stuminb, Stuminb, Gp, Mem) + ASMJIT_INST_2x(stuminh, Stuminh, Gp, Mem) + ASMJIT_INST_2x(stuminl, Stuminl, Gp, Mem) + ASMJIT_INST_2x(stuminlb, Stuminlb, Gp, Mem) + ASMJIT_INST_2x(stuminlh, Stuminlh, Gp, Mem) + + ASMJIT_INST_2x(stur, Stur, Gp, Mem) + ASMJIT_INST_2x(sturb, Sturb, Gp, Mem) + ASMJIT_INST_2x(sturh, Sturh, Gp, Mem) + + ASMJIT_INST_4x(stxp, Stxp, Gp, Gp, Gp, Mem) + ASMJIT_INST_4x(stlxp, Stlxp, Gp, Gp, Gp, Mem) + + ASMJIT_INST_3x(stxr, Stxr, Gp, Gp, Mem) + ASMJIT_INST_3x(stxrb, Stxrb, Gp, Gp, Mem) + ASMJIT_INST_3x(stxrh, Stxrh, Gp, Gp, Mem) + + ASMJIT_INST_3x(swp, Swp, Gp, Gp, Mem) + ASMJIT_INST_3x(swpa, Swpa, Gp, Gp, Mem) + ASMJIT_INST_3x(swpab, Swpab, Gp, Gp, Mem) + ASMJIT_INST_3x(swpah, Swpah, Gp, Gp, Mem) + ASMJIT_INST_3x(swpal, Swpal, Gp, Gp, Mem) + ASMJIT_INST_3x(swpalb, Swpalb, Gp, Gp, Mem) + ASMJIT_INST_3x(swpalh, Swpalh, Gp, Gp, Mem) + ASMJIT_INST_3x(swpb, Swpb, Gp, Gp, Mem) + ASMJIT_INST_3x(swph, Swph, Gp, Gp, Mem) + ASMJIT_INST_3x(swpl, Swpl, Gp, Gp, Mem) + ASMJIT_INST_3x(swplb, Swplb, Gp, Gp, Mem) + ASMJIT_INST_3x(swplh, Swplh, Gp, Gp, Mem) + //! \} + + //! \name CRC Instructions (ARMv8.1-A, optional in ARMv8.0-A) + //! \{ + + ASMJIT_INST_3x(crc32b, Crc32b, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32h, Crc32h, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32w, Crc32w, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32x, Crc32x, Gp, Gp, Gp); + + ASMJIT_INST_3x(crc32cb, Crc32cb, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32ch, Crc32ch, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32cw, Crc32cw, Gp, Gp, Gp); + ASMJIT_INST_3x(crc32cx, Crc32cx, Gp, Gp, Gp); + + //! \} + + //! \name MTE Instructions + //! \{ + + ASMJIT_INST_2x(autda, Autda, Gp, Gp); + ASMJIT_INST_2x(autdb, Autdb, Gp, Gp); + ASMJIT_INST_1x(autdza, Autdza, Gp); + ASMJIT_INST_1x(autdzb, Autdzb, Gp); + ASMJIT_INST_2x(autia, Autia, Gp, Gp); + ASMJIT_INST_0x(autia1716, Autia1716); + ASMJIT_INST_0x(autiasp, Autiasp); + ASMJIT_INST_0x(autiaz, Autiaz); + ASMJIT_INST_2x(autib, Autib, Gp, Gp); + ASMJIT_INST_0x(autib1716, Autib1716); + ASMJIT_INST_0x(autibsp, Autibsp); + ASMJIT_INST_0x(autibz, Autibz); + ASMJIT_INST_1x(autiza, Autiza, Gp); + ASMJIT_INST_1x(autizb, Autizb, Gp); + + ASMJIT_INST_3x(gmi, Gmi, Gp, Gp, Gp); + + ASMJIT_INST_2x(cmpp, Cmpp, Gp, Gp); + ASMJIT_INST_4x(addg, Addg, Gp, Gp, Imm, Imm); + + ASMJIT_INST_2x(ldg, Ldg, Gp, Mem) + ASMJIT_INST_2x(ldgm, Ldgm, Gp, Mem) + ASMJIT_INST_2x(ldraa, Ldraa, Gp, Mem) + ASMJIT_INST_2x(ldrab, Ldrab, Gp, Mem) + + ASMJIT_INST_2x(pacda, Pacda, Gp, Gp); + ASMJIT_INST_2x(pacdb, Pacdb, Gp, Gp); + ASMJIT_INST_1x(pacdza, Pacdza, Gp); + ASMJIT_INST_1x(pacdzb, Pacdzb, Gp); + ASMJIT_INST_3x(pacga, Pacga, Gp, Gp, Gp); + + ASMJIT_INST_3x(subp, Subp, Gp, Gp, Gp); + ASMJIT_INST_3x(subps, Subps, Gp, Gp, Gp); + ASMJIT_INST_4x(subg, Subg, Gp, Gp, Imm, Imm); + + ASMJIT_INST_2x(st2g, St2g, Gp, Mem) + ASMJIT_INST_2x(stg, Stg, Gp, Mem) + ASMJIT_INST_3x(stgp, Stgp, Gp, Gp, Mem) + ASMJIT_INST_2x(stgm, Stgm, Gp, Mem) + ASMJIT_INST_2x(stzg, Stzg, Gp, Mem) + ASMJIT_INST_2x(stz2g, Stz2g, Gp, Mem) + ASMJIT_INST_2x(stzgm, Stzgm, Gp, Mem) + + ASMJIT_INST_1x(xpacd, Xpacd, Gp); + ASMJIT_INST_1x(xpaci, Xpaci, Gp); + ASMJIT_INST_0x(xpaclri, Xpaclri); + + //! \} + + //! \name Hint Instructions + //! \{ + + ASMJIT_INST_1x(hint, Hint, Imm) + ASMJIT_INST_0x(nop, Nop) + ASMJIT_INST_0x(sev, Sev) + ASMJIT_INST_0x(sevl, Sevl) + ASMJIT_INST_0x(wfe, Wfe) + ASMJIT_INST_0x(wfi, Wfi) + ASMJIT_INST_0x(yield, Yield) + + //! \} + + //! \name SIMD & FP Instructions + //! \{ + + ASMJIT_INST_2x(abs, Abs_v, Vec, Vec); + ASMJIT_INST_3x(add, Add_v, Vec, Vec, Vec); + ASMJIT_INST_3x(addhn, Addhn_v, Vec, Vec, Vec); + ASMJIT_INST_3x(addhn2, Addhn2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(addp, Addp_v, Vec, Vec); + ASMJIT_INST_3x(addp, Addp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(addv, Addv_v, Vec, Vec); + ASMJIT_INST_3x(and_, And_v, Vec, Vec, Vec); + ASMJIT_INST_2x(bic, Bic_v, Vec, Imm); + ASMJIT_INST_3x(bic, Bic_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bic, Bic_v, Vec, Imm, Imm); + ASMJIT_INST_3x(bif, Bif_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bit, Bit_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bsl, Bsl_v, Vec, Vec, Vec); + ASMJIT_INST_2x(cls, Cls_v, Vec, Vec); + ASMJIT_INST_2x(clz, Clz_v, Vec, Vec); + ASMJIT_INST_3x(cmeq, Cmeq_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmeq, Cmeq_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmge, Cmge_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmge, Cmge_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmgt, Cmgt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmgt, Cmgt_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmhi, Cmhi_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmhs, Cmhs_v, Vec, Vec, Vec); + ASMJIT_INST_3x(cmle, Cmle_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmlt, Cmlt_v, Vec, Vec, Imm); + ASMJIT_INST_3x(cmtst, Cmtst_v, Vec, Vec, Vec); + ASMJIT_INST_2x(cnt, Cnt_v, Vec, Vec); + ASMJIT_INST_2x(dup, Dup_v, Vec, Gp); + ASMJIT_INST_2x(dup, Dup_v, Vec, Vec); + ASMJIT_INST_3x(eor, Eor_v, Vec, Vec, Vec); + ASMJIT_INST_4x(ext, Ext_v, Vec, Vec, Vec, Imm); + ASMJIT_INST_3x(fabd, Fabd_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fabs, Fabs_v, Vec, Vec); + ASMJIT_INST_3x(facge, Facge_v, Vec, Vec, Vec); + ASMJIT_INST_3x(facgt, Facgt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fadd, Fadd_v, Vec, Vec, Vec); + ASMJIT_INST_2x(faddp, Faddp_v, Vec, Vec); + ASMJIT_INST_3x(faddp, Faddp_v, Vec, Vec, Vec); + ASMJIT_INST_4x(fccmp, Fccmp_v, Vec, Vec, Imm, Imm); + ASMJIT_INST_4x(fccmpe, Fccmpe_v, Vec, Vec, Imm, Imm); + ASMJIT_INST_3x(fcmeq, Fcmeq_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fcmeq, Fcmeq_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fcmge, Fcmge_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fcmge, Fcmge_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fcmgt, Fcmgt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fcmgt, Fcmgt_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fcmle, Fcmle_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fcmlt, Fcmlt_v, Vec, Vec, Imm); + ASMJIT_INST_2x(fcmp, Fcmp_v, Vec, Vec); + ASMJIT_INST_2x(fcmp, Fcmp_v, Vec, Imm); + ASMJIT_INST_2x(fcmpe, Fcmpe_v, Vec, Vec); + ASMJIT_INST_2x(fcmpe, Fcmpe_v, Vec, Imm); + ASMJIT_INST_4x(fcsel, Fcsel_v, Vec, Vec, Vec, Imm); + ASMJIT_INST_2x(fcvt, Fcvt_v, Vec, Vec); + ASMJIT_INST_2x(fcvtas, Fcvtas_v, Gp, Vec); + ASMJIT_INST_2x(fcvtas, Fcvtas_v, Vec, Vec); + ASMJIT_INST_2x(fcvtau, Fcvtau_v, Gp, Vec); + ASMJIT_INST_2x(fcvtau, Fcvtau_v, Vec, Vec); + ASMJIT_INST_2x(fcvtl, Fcvtl_v, Vec, Vec); + ASMJIT_INST_2x(fcvtl2, Fcvtl2_v, Vec, Vec); + ASMJIT_INST_2x(fcvtms, Fcvtms_v, Gp, Vec); + ASMJIT_INST_2x(fcvtms, Fcvtms_v, Vec, Vec); + ASMJIT_INST_2x(fcvtmu, Fcvtmu_v, Gp, Vec); + ASMJIT_INST_2x(fcvtmu, Fcvtmu_v, Vec, Vec); + ASMJIT_INST_2x(fcvtn, Fcvtn_v, Vec, Vec); + ASMJIT_INST_2x(fcvtn2, Fcvtn2_v, Vec, Vec); + ASMJIT_INST_2x(fcvtns, Fcvtns_v, Gp, Vec); + ASMJIT_INST_2x(fcvtns, Fcvtns_v, Vec, Vec); + ASMJIT_INST_2x(fcvtnu, Fcvtnu_v, Gp, Vec); + ASMJIT_INST_2x(fcvtnu, Fcvtnu_v, Vec, Vec); + ASMJIT_INST_2x(fcvtps, Fcvtps_v, Gp, Vec); + ASMJIT_INST_2x(fcvtps, Fcvtps_v, Vec, Vec); + ASMJIT_INST_2x(fcvtpu, Fcvtpu_v, Gp, Vec); + ASMJIT_INST_2x(fcvtpu, Fcvtpu_v, Vec, Vec); + ASMJIT_INST_2x(fcvtxn, Fcvtxn_v, Vec, Vec); + ASMJIT_INST_2x(fcvtxn2, Fcvtxn2_v, Vec, Vec); + ASMJIT_INST_2x(fcvtzs, Fcvtzs_v, Gp, Vec); + ASMJIT_INST_3x(fcvtzs, Fcvtzs_v, Gp, Vec, Imm); + ASMJIT_INST_2x(fcvtzs, Fcvtzs_v, Vec, Vec); + ASMJIT_INST_3x(fcvtzs, Fcvtzs_v, Vec, Vec, Imm); + ASMJIT_INST_2x(fcvtzu, Fcvtzu_v, Gp, Vec); + ASMJIT_INST_3x(fcvtzu, Fcvtzu_v, Gp, Vec, Imm); + ASMJIT_INST_2x(fcvtzu, Fcvtzu_v, Vec, Vec); + ASMJIT_INST_3x(fcvtzu, Fcvtzu_v, Vec, Vec, Imm); + ASMJIT_INST_3x(fdiv, Fdiv_v, Vec, Vec, Vec); + ASMJIT_INST_4x(fmadd, Fmadd_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(fmax, Fmax_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmaxnm, Fmaxnm_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmaxnmp, Fmaxnmp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fmaxnmp, Fmaxnmp_v, Vec, Vec); + ASMJIT_INST_2x(fmaxnmv, Fmaxnmv_v, Vec, Vec); + ASMJIT_INST_3x(fmaxp, Fmaxp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fmaxp, Fmaxp_v, Vec, Vec); + ASMJIT_INST_2x(fmaxv, Fmaxv_v, Vec, Vec); + ASMJIT_INST_3x(fmin, Fmin_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fminnm, Fminnm_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fminnmv, Fminnmv_v, Vec, Vec); + ASMJIT_INST_3x(fminnmp, Fminnmp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fminnmp, Fminnmp_v, Vec, Vec); + ASMJIT_INST_2x(fminp, Fminp_v, Vec, Vec); + ASMJIT_INST_3x(fminp, Fminp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fminv, Fminv_v, Vec, Vec); + ASMJIT_INST_3x(fmla, Fmla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmls, Fmls_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fmov, Fmov_v, Gp, Vec); + ASMJIT_INST_2x(fmov, Fmov_v, Vec, Gp); + ASMJIT_INST_2x(fmov, Fmov_v, Vec, Vec); + ASMJIT_INST_2x(fmov, Fmov_v, Vec, Imm); + ASMJIT_INST_4x(fmsub, Fmsub_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(fmul, Fmul_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmulx, Fmulx_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fneg, Fneg_v, Vec, Vec); + ASMJIT_INST_4x(fnmadd, Fnmadd_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_4x(fnmsub, Fnmsub_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(fnmul, Fnmul_v, Vec, Vec, Vec); + ASMJIT_INST_2x(frecpe, Frecpe_v, Vec, Vec); + ASMJIT_INST_3x(frecps, Frecps_v, Vec, Vec, Vec); + ASMJIT_INST_2x(frecpx, Frecpx_v, Vec, Vec); + ASMJIT_INST_2x(frint32x, Frint32x_v, Vec, Vec); + ASMJIT_INST_2x(frint32z, Frint32z_v, Vec, Vec); + ASMJIT_INST_2x(frint64x, Frint64x_v, Vec, Vec); + ASMJIT_INST_2x(frint64z, Frint64z_v, Vec, Vec); + ASMJIT_INST_2x(frinta, Frinta_v, Vec, Vec); + ASMJIT_INST_2x(frinti, Frinti_v, Vec, Vec); + ASMJIT_INST_2x(frintm, Frintm_v, Vec, Vec); + ASMJIT_INST_2x(frintn, Frintn_v, Vec, Vec); + ASMJIT_INST_2x(frintp, Frintp_v, Vec, Vec); + ASMJIT_INST_2x(frintx, Frintx_v, Vec, Vec); + ASMJIT_INST_2x(frintz, Frintz_v, Vec, Vec); + ASMJIT_INST_2x(frsqrte, Frsqrte_v, Vec, Vec); + ASMJIT_INST_3x(frsqrts, Frsqrts_v, Vec, Vec, Vec); + ASMJIT_INST_2x(fsqrt, Fsqrt_v, Vec, Vec); + ASMJIT_INST_3x(fsub, Fsub_v, Vec, Vec, Vec); + ASMJIT_INST_2x(ins, Ins_v, Vec, Gp); + ASMJIT_INST_2x(ins, Ins_v, Vec, Vec); + ASMJIT_INST_2x(ld1, Ld1_v, Vec, Mem); + ASMJIT_INST_3x(ld1, Ld1_v, Vec, Vec, Mem); + ASMJIT_INST_4x(ld1, Ld1_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(ld1, Ld1_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_2x(ld1r, Ld1r_v, Vec, Mem); + ASMJIT_INST_3x(ld2, Ld2_v, Vec, Vec, Mem); + ASMJIT_INST_3x(ld2r, Ld2r_v, Vec, Vec, Mem); + ASMJIT_INST_4x(ld3, Ld3_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_4x(ld3r, Ld3r_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(ld4, Ld4_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(ld4r, Ld4r_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_3x(ldnp, Ldnp_v, Vec, Vec, Mem); + ASMJIT_INST_3x(ldp, Ldp_v, Vec, Vec, Mem); + ASMJIT_INST_2x(ldr, Ldr_v, Vec, Mem); + ASMJIT_INST_2x(ldur, Ldur_v, Vec, Mem); + ASMJIT_INST_3x(mla, Mla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(mls, Mls_v, Vec, Vec, Vec); + ASMJIT_INST_2x(mov, Mov_v, Vec, Vec); + ASMJIT_INST_2x(mov, Mov_v, Gp, Vec); + ASMJIT_INST_2x(mov, Mov_v, Vec, Gp); + ASMJIT_INST_2x(movi, Movi_v, Vec, Imm); + ASMJIT_INST_3x(movi, Movi_v, Vec, Imm, Imm); + ASMJIT_INST_3x(mul, Mul_v, Vec, Vec, Vec); + ASMJIT_INST_2x(mvn, Mvn_v, Vec, Vec); + ASMJIT_INST_2x(mvni, Mvni_v, Vec, Imm); + ASMJIT_INST_3x(mvni, Mvni_v, Vec, Imm, Imm); + ASMJIT_INST_2x(neg, Neg_v, Vec, Vec); + ASMJIT_INST_2x(not_, Not_v, Vec, Vec); + ASMJIT_INST_3x(orn, Orn_v, Vec, Vec, Vec); + ASMJIT_INST_2x(orr, Orr_v, Vec, Imm); + ASMJIT_INST_3x(orr, Orr_v, Vec, Vec, Vec); + ASMJIT_INST_3x(orr, Orr_v, Vec, Imm, Imm); + ASMJIT_INST_3x(pmul, Pmul_v, Vec, Vec, Vec); + ASMJIT_INST_3x(pmull, Pmull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(pmull2, Pmull2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(raddhn, Raddhn_v, Vec, Vec, Vec); + ASMJIT_INST_3x(raddhn2, Raddhn2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(rbit, Rbit_v, Vec, Vec); + ASMJIT_INST_2x(rev16, Rev16_v, Vec, Vec); + ASMJIT_INST_2x(rev32, Rev32_v, Vec, Vec); + ASMJIT_INST_2x(rev64, Rev64_v, Vec, Vec); + ASMJIT_INST_3x(rshrn, Rshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(rshrn2, Rshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(rsubhn, Rsubhn_v, Vec, Vec, Vec); + ASMJIT_INST_3x(rsubhn2, Rsubhn2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(saba, Saba_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabal, Sabal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabal2, Sabal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabd, Sabd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabdl, Sabdl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sabdl2, Sabdl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sadalp, Sadalp_v, Vec, Vec); + ASMJIT_INST_3x(saddl, Saddl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(saddl2, Saddl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(saddlp, Saddlp_v, Vec, Vec); + ASMJIT_INST_2x(saddlv, Saddlv_v, Vec, Vec); + ASMJIT_INST_3x(saddw, Saddw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(saddw2, Saddw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(scvtf, Scvtf_v, Vec, Gp); + ASMJIT_INST_3x(scvtf, Scvtf_v, Vec, Gp, Imm); + ASMJIT_INST_2x(scvtf, Scvtf_v, Vec, Vec); + ASMJIT_INST_3x(scvtf, Scvtf_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shadd, Shadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(shl, Shl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shll, Shll_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shll2, Shll2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shrn, Shrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shrn2, Shrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(shsub, Shsub_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sli, Sli_v, Vec, Vec, Imm); + ASMJIT_INST_3x(smax, Smax_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smaxp, Smaxp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(smaxv, Smaxv_v, Vec, Vec); + ASMJIT_INST_3x(smin, Smin_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sminp, Sminp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sminv, Sminv_v, Vec, Vec); + ASMJIT_INST_3x(smlal, Smlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smlal2, Smlal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smlsl, Smlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smlsl2, Smlsl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(smov, Smov_v, Gp, Vec); + ASMJIT_INST_3x(smull, Smull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(smull2, Smull2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sqabs, Sqabs_v, Vec, Vec); + ASMJIT_INST_3x(sqadd, Sqadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlal, Sqdmlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlal2, Sqdmlal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlsl, Sqdmlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmlsl2, Sqdmlsl2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmulh, Sqdmulh_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmull, Sqdmull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqdmull2, Sqdmull2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sqneg, Sqneg_v, Vec, Vec); + ASMJIT_INST_3x(sqrdmulh, Sqrdmulh_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqrshl, Sqrshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqrshrn, Sqrshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqrshrn2, Sqrshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqrshrun, Sqrshrun_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqrshrun2, Sqrshrun2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshl, Sqshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqshl, Sqshl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshlu, Sqshlu_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrn, Sqshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrn2, Sqshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrun, Sqshrun_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqshrun2, Sqshrun2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sqsub, Sqsub_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sqxtn, Sqxtn_v, Vec, Vec); + ASMJIT_INST_2x(sqxtn2, Sqxtn2_v, Vec, Vec); + ASMJIT_INST_2x(sqxtun, Sqxtun_v, Vec, Vec); + ASMJIT_INST_2x(sqxtun2, Sqxtun2_v, Vec, Vec); + ASMJIT_INST_3x(srhadd, Srhadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sri, Sri_v, Vec, Vec, Imm); + ASMJIT_INST_3x(srshl, Srshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(srshr, Srshr_v, Vec, Vec, Imm); + ASMJIT_INST_3x(srsra, Srsra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sshl, Sshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sshll, Sshll_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sshll2, Sshll2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(sshr, Sshr_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ssra, Ssra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ssubl, Ssubl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ssubl2, Ssubl2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ssubw, Ssubw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ssubw2, Ssubw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(st1, St1_v, Vec, Mem); + ASMJIT_INST_3x(st1, St1_v, Vec, Vec, Mem); + ASMJIT_INST_4x(st1, St1_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(st1, St1_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_3x(st2, St2_v, Vec, Vec, Mem); + ASMJIT_INST_4x(st3, St3_v, Vec, Vec, Vec, Mem); + ASMJIT_INST_5x(st4, St4_v, Vec, Vec, Vec, Vec, Mem); + ASMJIT_INST_3x(stnp, Stnp_v, Vec, Vec, Mem); + ASMJIT_INST_3x(stp, Stp_v, Vec, Vec, Mem); + ASMJIT_INST_2x(str, Str_v, Vec, Mem); + ASMJIT_INST_2x(stur, Stur_v, Vec, Mem); + ASMJIT_INST_3x(sub, Sub_v, Vec, Vec, Vec); + ASMJIT_INST_3x(subhn, Subhn_v, Vec, Vec, Vec); + ASMJIT_INST_3x(subhn2, Subhn2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(suqadd, Suqadd_v, Vec, Vec); + ASMJIT_INST_2x(sxtl, Sxtl_v, Vec, Vec); + ASMJIT_INST_2x(sxtl2, Sxtl2_v, Vec, Vec); + ASMJIT_INST_3x(tbl, Tbl_v, Vec, Vec, Vec); + ASMJIT_INST_4x(tbl, Tbl_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_5x(tbl, Tbl_v, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_6x(tbl, Tbl_v, Vec, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(tbx, Tbx_v, Vec, Vec, Vec); + ASMJIT_INST_4x(tbx, Tbx_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_5x(tbx, Tbx_v, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_6x(tbx, Tbx_v, Vec, Vec, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(trn1, Trn1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(trn2, Trn2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uaba, Uaba_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabal, Uabal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabal2, Uabal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabd, Uabd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabdl, Uabdl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uabdl2, Uabdl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uadalp, Uadalp_v, Vec, Vec); + ASMJIT_INST_3x(uaddl, Uaddl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uaddl2, Uaddl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uaddlp, Uaddlp_v, Vec, Vec); + ASMJIT_INST_2x(uaddlv, Uaddlv_v, Vec, Vec); + ASMJIT_INST_3x(uaddw, Uaddw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uaddw2, Uaddw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(ucvtf, Ucvtf_v, Vec, Gp); + ASMJIT_INST_3x(ucvtf, Ucvtf_v, Vec, Gp, Imm); + ASMJIT_INST_2x(ucvtf, Ucvtf_v, Vec, Vec); + ASMJIT_INST_3x(ucvtf, Ucvtf_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uhadd, Uhadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uhsub, Uhsub_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umax, Umax_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umaxp, Umaxp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(umaxv, Umaxv_v, Vec, Vec); + ASMJIT_INST_3x(umin, Umin_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uminp, Uminp_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uminv, Uminv_v, Vec, Vec); + ASMJIT_INST_3x(umlal, Umlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umlal2, Umlal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umlsl, Umlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umlsl2, Umlsl2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(umov, Umov_v, Gp, Vec); + ASMJIT_INST_3x(umull, Umull_v, Vec, Vec, Vec); + ASMJIT_INST_3x(umull2, Umull2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqadd, Uqadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqrshl, Uqrshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqrshl, Uqrshl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqrshrn, Uqrshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqrshrn2, Uqrshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqshl, Uqshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uqshl, Uqshl_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqshrn, Uqshrn_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqshrn2, Uqshrn2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(uqsub, Uqsub_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uqxtn, Uqxtn_v, Vec, Vec); + ASMJIT_INST_2x(uqxtn2, Uqxtn2_v, Vec, Vec); + ASMJIT_INST_2x(urecpe, Urecpe_v, Vec, Vec); + ASMJIT_INST_3x(urhadd, Urhadd_v, Vec, Vec, Vec); + ASMJIT_INST_3x(urshl, Urshl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(urshr, Urshr_v, Vec, Vec, Imm); + ASMJIT_INST_2x(ursqrte, Ursqrte_v, Vec, Vec); + ASMJIT_INST_3x(ursra, Ursra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ushl, Ushl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ushll, Ushll_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ushll2, Ushll2_v, Vec, Vec, Imm); + ASMJIT_INST_3x(ushr, Ushr_v, Vec, Vec, Imm); + ASMJIT_INST_2x(usqadd, Usqadd_v, Vec, Vec); + ASMJIT_INST_3x(usra, Usra_v, Vec, Vec, Imm); + ASMJIT_INST_3x(usubl, Usubl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usubl2, Usubl2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usubw, Usubw_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usubw2, Usubw2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(uxtl, Uxtl_v, Vec, Vec); + ASMJIT_INST_2x(uxtl2, Uxtl2_v, Vec, Vec); + ASMJIT_INST_3x(uzp1, Uzp1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(uzp2, Uzp2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(xtn, Xtn_v, Vec, Vec); + ASMJIT_INST_2x(xtn2, Xtn2_v, Vec, Vec); + ASMJIT_INST_3x(zip1, Zip1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(zip2, Zip2_v, Vec, Vec, Vec); + + //! \} + + //! \name AES Instructions + //! \{ + + ASMJIT_INST_2x(aesd, Aesd_v, Vec, Vec); + ASMJIT_INST_2x(aese, Aese_v, Vec, Vec); + ASMJIT_INST_2x(aesimc, Aesimc_v, Vec, Vec); + ASMJIT_INST_2x(aesmc, Aesmc_v, Vec, Vec); + + //! \} + + //! \name SHA1 Instructions + //! \{ + + ASMJIT_INST_3x(sha1c, Sha1c_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha1h, Sha1h_v, Vec, Vec); + ASMJIT_INST_3x(sha1m, Sha1m_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha1p, Sha1p_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha1su0, Sha1su0_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha1su1, Sha1su1_v, Vec, Vec); + + //! \} + + //! \name SHA2 Instructions + //! \{ + + ASMJIT_INST_3x(sha256h, Sha256h_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha256h2, Sha256h2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha256su0, Sha256su0_v, Vec, Vec); + ASMJIT_INST_3x(sha256su1, Sha256su1_v, Vec, Vec, Vec); + + //! \} + + //! \name RDMA Instructions (ARMv8.1-A) + //! \{ + + ASMJIT_INST_3x(sqrdmlah, Sqrdmlah_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sqrdmlsh, Sqrdmlsh_v, Vec, Vec, Vec); + + //! \} + + //! \name FCMA Instruction (ARMv8.3-A) + //! \{ + + ASMJIT_INST_4x(fcadd, Fcadd_v, Vec, Vec, Vec, Imm); + ASMJIT_INST_4x(fcmla, Fcmla_v, Vec, Vec, Vec, Imm); + + //! \} + + //! \name JSCVT Instruction (ARMv8.3-A) + //! \{ + + ASMJIT_INST_2x(fjcvtzs, Fjcvtzs_v, Gp, Vec); + + //! \} + + //! \name FHM Instructions + //! \{ + + ASMJIT_INST_3x(fmlal, Fmlal_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmlal2, Fmlal2_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmlsl, Fmlsl_v, Vec, Vec, Vec); + ASMJIT_INST_3x(fmlsl2, Fmlsl2_v, Vec, Vec, Vec); + + + //! \} + + //! \name SHA3 Instructions (ARMv8.4-A, optional in ARMv8.2-A) + //! \{ + + ASMJIT_INST_4x(bcax, Bcax_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_4x(eor3, Eor3_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(rax1, Rax1_v, Vec, Vec, Vec); + ASMJIT_INST_4x(xar, Xar_v, Vec, Vec, Vec, Imm); + + //! \} + + //! \name SHA512 Instructions (ARMv8.4-A) + //! \{ + + ASMJIT_INST_3x(sha512h, Sha512h_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sha512h2, Sha512h2_v, Vec, Vec, Vec); + ASMJIT_INST_2x(sha512su0, Sha512su0_v, Vec, Vec); + ASMJIT_INST_3x(sha512su1, Sha512su1_v, Vec, Vec, Vec); + + //! \} + + //! \name SM3 Instructions (ARMv8.4-A) + //! \{ + + ASMJIT_INST_3x(sm3partw1, Sm3partw1_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3partw2, Sm3partw2_v, Vec, Vec, Vec); + ASMJIT_INST_4x(sm3ss1, Sm3ss1_v, Vec, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt1a, Sm3tt1a_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt1b, Sm3tt1b_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt2a, Sm3tt2a_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sm3tt2b, Sm3tt2b_v, Vec, Vec, Vec); + + //! \} + + //! \name SM4 Instructions (ARMv8.4-A) + //! \{ + + ASMJIT_INST_2x(sm4e, Sm4e_v, Vec, Vec); + ASMJIT_INST_3x(sm4ekey, Sm4ekey_v, Vec, Vec, Vec); + + //! \} + + //! \name DOTPROD Instructions (ARMv8.4-A, optional in ARMv8.2-A) + //! \{ + + ASMJIT_INST_3x(sdot, Sdot_v, Vec, Vec, Vec); + ASMJIT_INST_3x(udot, Udot_v, Vec, Vec, Vec); + + //! \} + + //! \name BF16 Instructions (ARMv8.6-A) + //! \{ + + ASMJIT_INST_2x(bfcvt, Bfcvt_v, Vec, Vec); + ASMJIT_INST_2x(bfcvtn, Bfcvtn_v, Vec, Vec); + ASMJIT_INST_2x(bfcvtn2, Bfcvtn2_v, Vec, Vec); + ASMJIT_INST_3x(bfmlalb, Bfmlalb_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bfmlalt, Bfmlalt_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bfmmla, Bfmmla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(bfdot, Bfdot_v, Vec, Vec, Vec); + + //! \} + + //! \name I8MM Instructions (ARMv8.6-A) + //! \{ + + ASMJIT_INST_3x(smmla, Smmla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(sudot, Sudot_v, Vec, Vec, Vec); + ASMJIT_INST_3x(ummla, Ummla_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usdot, Usdot_v, Vec, Vec, Vec); + ASMJIT_INST_3x(usmmla, Usmmla_v, Vec, Vec, Vec); + + //! \} +}; + +//! Emitter (ARM). +//! +//! \note This class cannot be instantiated, you can only cast to it and use it as emitter that emits to either +//! `a64::Assembler`, `a64::Builder`, or `a64::Compiler` (use with caution with `a64::Compiler` as it requires +//! virtual registers). +class Emitter : public BaseEmitter, public EmitterExplicitT { + ASMJIT_NONCONSTRUCTIBLE(Emitter) +}; + +//! \} + +#undef ASMJIT_INST_0x +#undef ASMJIT_INST_1x +#undef ASMJIT_INST_2x +#undef ASMJIT_INST_3x +#undef ASMJIT_INST_4x +#undef ASMJIT_INST_5x +#undef ASMJIT_INST_6x +#undef ASMJIT_INST_1cc + +ASMJIT_END_SUB_NAMESPACE + +// Restore undefined MSVC AArch64 macros. +#if defined(ASMJIT_RESTORE_MSVC_AARCH64_MACROS) + #pragma pop_macro("mvn") +#endif + +#endif // ASMJIT_ARM_A64EMITTER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64globals.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64globals.h new file mode 100644 index 0000000000000000000000000000000000000000..0d0d4fda5950f127208268ff4a0bee22902630fb --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64globals.h @@ -0,0 +1,1895 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64GLOBALS_H_INCLUDED +#define ASMJIT_ARM_A64GLOBALS_H_INCLUDED + +#include "../arm/armglobals.h" + +//! \namespace asmjit::a64 +//! \ingroup asmjit_a64 +//! +//! AArch64 backend. + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! AArch64 instruction. +//! +//! \note Only used to hold ARM-specific enumerations and static functions. +struct Inst { + //! Instruction id. + enum Id : uint32_t { + // ${InstId:Begin} + kIdNone = 0, //!< Instruction ''. + kIdAdc, //!< Instruction 'adc'. + kIdAdcs, //!< Instruction 'adcs'. + kIdAdd, //!< Instruction 'add'. + kIdAddg, //!< Instruction 'addg'. + kIdAdds, //!< Instruction 'adds'. + kIdAdr, //!< Instruction 'adr'. + kIdAdrp, //!< Instruction 'adrp'. + kIdAnd, //!< Instruction 'and'. + kIdAnds, //!< Instruction 'ands'. + kIdAsr, //!< Instruction 'asr'. + kIdAsrv, //!< Instruction 'asrv'. + kIdAt, //!< Instruction 'at'. + kIdAutda, //!< Instruction 'autda'. + kIdAutdza, //!< Instruction 'autdza'. + kIdAutdb, //!< Instruction 'autdb'. + kIdAutdzb, //!< Instruction 'autdzb'. + kIdAutia, //!< Instruction 'autia'. + kIdAutia1716, //!< Instruction 'autia1716'. + kIdAutiasp, //!< Instruction 'autiasp'. + kIdAutiaz, //!< Instruction 'autiaz'. + kIdAutib, //!< Instruction 'autib'. + kIdAutib1716, //!< Instruction 'autib1716'. + kIdAutibsp, //!< Instruction 'autibsp'. + kIdAutibz, //!< Instruction 'autibz'. + kIdAutiza, //!< Instruction 'autiza'. + kIdAutizb, //!< Instruction 'autizb'. + kIdAxflag, //!< Instruction 'axflag'. + kIdB, //!< Instruction 'b'. + kIdBfc, //!< Instruction 'bfc'. + kIdBfi, //!< Instruction 'bfi'. + kIdBfm, //!< Instruction 'bfm'. + kIdBfxil, //!< Instruction 'bfxil'. + kIdBic, //!< Instruction 'bic'. + kIdBics, //!< Instruction 'bics'. + kIdBl, //!< Instruction 'bl'. + kIdBlr, //!< Instruction 'blr'. + kIdBr, //!< Instruction 'br'. + kIdBrk, //!< Instruction 'brk'. + kIdCas, //!< Instruction 'cas'. + kIdCasa, //!< Instruction 'casa'. + kIdCasab, //!< Instruction 'casab'. + kIdCasah, //!< Instruction 'casah'. + kIdCasal, //!< Instruction 'casal'. + kIdCasalb, //!< Instruction 'casalb'. + kIdCasalh, //!< Instruction 'casalh'. + kIdCasb, //!< Instruction 'casb'. + kIdCash, //!< Instruction 'cash'. + kIdCasl, //!< Instruction 'casl'. + kIdCaslb, //!< Instruction 'caslb'. + kIdCaslh, //!< Instruction 'caslh'. + kIdCasp, //!< Instruction 'casp'. + kIdCaspa, //!< Instruction 'caspa'. + kIdCaspal, //!< Instruction 'caspal'. + kIdCaspl, //!< Instruction 'caspl'. + kIdCbnz, //!< Instruction 'cbnz'. + kIdCbz, //!< Instruction 'cbz'. + kIdCcmn, //!< Instruction 'ccmn'. + kIdCcmp, //!< Instruction 'ccmp'. + kIdCfinv, //!< Instruction 'cfinv'. + kIdCinc, //!< Instruction 'cinc'. + kIdCinv, //!< Instruction 'cinv'. + kIdClrex, //!< Instruction 'clrex'. + kIdCls, //!< Instruction 'cls'. + kIdClz, //!< Instruction 'clz'. + kIdCmn, //!< Instruction 'cmn'. + kIdCmp, //!< Instruction 'cmp'. + kIdCmpp, //!< Instruction 'cmpp'. + kIdCneg, //!< Instruction 'cneg'. + kIdCrc32b, //!< Instruction 'crc32b'. + kIdCrc32cb, //!< Instruction 'crc32cb'. + kIdCrc32ch, //!< Instruction 'crc32ch'. + kIdCrc32cw, //!< Instruction 'crc32cw'. + kIdCrc32cx, //!< Instruction 'crc32cx'. + kIdCrc32h, //!< Instruction 'crc32h'. + kIdCrc32w, //!< Instruction 'crc32w'. + kIdCrc32x, //!< Instruction 'crc32x'. + kIdCsdb, //!< Instruction 'csdb'. + kIdCsel, //!< Instruction 'csel'. + kIdCset, //!< Instruction 'cset'. + kIdCsetm, //!< Instruction 'csetm'. + kIdCsinc, //!< Instruction 'csinc'. + kIdCsinv, //!< Instruction 'csinv'. + kIdCsneg, //!< Instruction 'csneg'. + kIdDc, //!< Instruction 'dc'. + kIdDcps1, //!< Instruction 'dcps1'. + kIdDcps2, //!< Instruction 'dcps2'. + kIdDcps3, //!< Instruction 'dcps3'. + kIdDgh, //!< Instruction 'dgh'. + kIdDmb, //!< Instruction 'dmb'. + kIdDrps, //!< Instruction 'drps'. + kIdDsb, //!< Instruction 'dsb'. + kIdEon, //!< Instruction 'eon'. + kIdEor, //!< Instruction 'eor'. + kIdEsb, //!< Instruction 'esb'. + kIdExtr, //!< Instruction 'extr'. + kIdEret, //!< Instruction 'eret'. + kIdGmi, //!< Instruction 'gmi'. + kIdHint, //!< Instruction 'hint'. + kIdHlt, //!< Instruction 'hlt'. + kIdHvc, //!< Instruction 'hvc'. + kIdIc, //!< Instruction 'ic'. + kIdIsb, //!< Instruction 'isb'. + kIdLdadd, //!< Instruction 'ldadd'. + kIdLdadda, //!< Instruction 'ldadda'. + kIdLdaddab, //!< Instruction 'ldaddab'. + kIdLdaddah, //!< Instruction 'ldaddah'. + kIdLdaddal, //!< Instruction 'ldaddal'. + kIdLdaddalb, //!< Instruction 'ldaddalb'. + kIdLdaddalh, //!< Instruction 'ldaddalh'. + kIdLdaddb, //!< Instruction 'ldaddb'. + kIdLdaddh, //!< Instruction 'ldaddh'. + kIdLdaddl, //!< Instruction 'ldaddl'. + kIdLdaddlb, //!< Instruction 'ldaddlb'. + kIdLdaddlh, //!< Instruction 'ldaddlh'. + kIdLdar, //!< Instruction 'ldar'. + kIdLdarb, //!< Instruction 'ldarb'. + kIdLdarh, //!< Instruction 'ldarh'. + kIdLdaxp, //!< Instruction 'ldaxp'. + kIdLdaxr, //!< Instruction 'ldaxr'. + kIdLdaxrb, //!< Instruction 'ldaxrb'. + kIdLdaxrh, //!< Instruction 'ldaxrh'. + kIdLdclr, //!< Instruction 'ldclr'. + kIdLdclra, //!< Instruction 'ldclra'. + kIdLdclrab, //!< Instruction 'ldclrab'. + kIdLdclrah, //!< Instruction 'ldclrah'. + kIdLdclral, //!< Instruction 'ldclral'. + kIdLdclralb, //!< Instruction 'ldclralb'. + kIdLdclralh, //!< Instruction 'ldclralh'. + kIdLdclrb, //!< Instruction 'ldclrb'. + kIdLdclrh, //!< Instruction 'ldclrh'. + kIdLdclrl, //!< Instruction 'ldclrl'. + kIdLdclrlb, //!< Instruction 'ldclrlb'. + kIdLdclrlh, //!< Instruction 'ldclrlh'. + kIdLdeor, //!< Instruction 'ldeor'. + kIdLdeora, //!< Instruction 'ldeora'. + kIdLdeorab, //!< Instruction 'ldeorab'. + kIdLdeorah, //!< Instruction 'ldeorah'. + kIdLdeoral, //!< Instruction 'ldeoral'. + kIdLdeoralb, //!< Instruction 'ldeoralb'. + kIdLdeoralh, //!< Instruction 'ldeoralh'. + kIdLdeorb, //!< Instruction 'ldeorb'. + kIdLdeorh, //!< Instruction 'ldeorh'. + kIdLdeorl, //!< Instruction 'ldeorl'. + kIdLdeorlb, //!< Instruction 'ldeorlb'. + kIdLdeorlh, //!< Instruction 'ldeorlh'. + kIdLdg, //!< Instruction 'ldg'. + kIdLdgm, //!< Instruction 'ldgm'. + kIdLdlar, //!< Instruction 'ldlar'. + kIdLdlarb, //!< Instruction 'ldlarb'. + kIdLdlarh, //!< Instruction 'ldlarh'. + kIdLdnp, //!< Instruction 'ldnp'. + kIdLdp, //!< Instruction 'ldp'. + kIdLdpsw, //!< Instruction 'ldpsw'. + kIdLdr, //!< Instruction 'ldr'. + kIdLdraa, //!< Instruction 'ldraa'. + kIdLdrab, //!< Instruction 'ldrab'. + kIdLdrb, //!< Instruction 'ldrb'. + kIdLdrh, //!< Instruction 'ldrh'. + kIdLdrsb, //!< Instruction 'ldrsb'. + kIdLdrsh, //!< Instruction 'ldrsh'. + kIdLdrsw, //!< Instruction 'ldrsw'. + kIdLdset, //!< Instruction 'ldset'. + kIdLdseta, //!< Instruction 'ldseta'. + kIdLdsetab, //!< Instruction 'ldsetab'. + kIdLdsetah, //!< Instruction 'ldsetah'. + kIdLdsetal, //!< Instruction 'ldsetal'. + kIdLdsetalb, //!< Instruction 'ldsetalb'. + kIdLdsetalh, //!< Instruction 'ldsetalh'. + kIdLdsetb, //!< Instruction 'ldsetb'. + kIdLdseth, //!< Instruction 'ldseth'. + kIdLdsetl, //!< Instruction 'ldsetl'. + kIdLdsetlb, //!< Instruction 'ldsetlb'. + kIdLdsetlh, //!< Instruction 'ldsetlh'. + kIdLdsmax, //!< Instruction 'ldsmax'. + kIdLdsmaxa, //!< Instruction 'ldsmaxa'. + kIdLdsmaxab, //!< Instruction 'ldsmaxab'. + kIdLdsmaxah, //!< Instruction 'ldsmaxah'. + kIdLdsmaxal, //!< Instruction 'ldsmaxal'. + kIdLdsmaxalb, //!< Instruction 'ldsmaxalb'. + kIdLdsmaxalh, //!< Instruction 'ldsmaxalh'. + kIdLdsmaxb, //!< Instruction 'ldsmaxb'. + kIdLdsmaxh, //!< Instruction 'ldsmaxh'. + kIdLdsmaxl, //!< Instruction 'ldsmaxl'. + kIdLdsmaxlb, //!< Instruction 'ldsmaxlb'. + kIdLdsmaxlh, //!< Instruction 'ldsmaxlh'. + kIdLdsmin, //!< Instruction 'ldsmin'. + kIdLdsmina, //!< Instruction 'ldsmina'. + kIdLdsminab, //!< Instruction 'ldsminab'. + kIdLdsminah, //!< Instruction 'ldsminah'. + kIdLdsminal, //!< Instruction 'ldsminal'. + kIdLdsminalb, //!< Instruction 'ldsminalb'. + kIdLdsminalh, //!< Instruction 'ldsminalh'. + kIdLdsminb, //!< Instruction 'ldsminb'. + kIdLdsminh, //!< Instruction 'ldsminh'. + kIdLdsminl, //!< Instruction 'ldsminl'. + kIdLdsminlb, //!< Instruction 'ldsminlb'. + kIdLdsminlh, //!< Instruction 'ldsminlh'. + kIdLdtr, //!< Instruction 'ldtr'. + kIdLdtrb, //!< Instruction 'ldtrb'. + kIdLdtrh, //!< Instruction 'ldtrh'. + kIdLdtrsb, //!< Instruction 'ldtrsb'. + kIdLdtrsh, //!< Instruction 'ldtrsh'. + kIdLdtrsw, //!< Instruction 'ldtrsw'. + kIdLdumax, //!< Instruction 'ldumax'. + kIdLdumaxa, //!< Instruction 'ldumaxa'. + kIdLdumaxab, //!< Instruction 'ldumaxab'. + kIdLdumaxah, //!< Instruction 'ldumaxah'. + kIdLdumaxal, //!< Instruction 'ldumaxal'. + kIdLdumaxalb, //!< Instruction 'ldumaxalb'. + kIdLdumaxalh, //!< Instruction 'ldumaxalh'. + kIdLdumaxb, //!< Instruction 'ldumaxb'. + kIdLdumaxh, //!< Instruction 'ldumaxh'. + kIdLdumaxl, //!< Instruction 'ldumaxl'. + kIdLdumaxlb, //!< Instruction 'ldumaxlb'. + kIdLdumaxlh, //!< Instruction 'ldumaxlh'. + kIdLdumin, //!< Instruction 'ldumin'. + kIdLdumina, //!< Instruction 'ldumina'. + kIdLduminab, //!< Instruction 'lduminab'. + kIdLduminah, //!< Instruction 'lduminah'. + kIdLduminal, //!< Instruction 'lduminal'. + kIdLduminalb, //!< Instruction 'lduminalb'. + kIdLduminalh, //!< Instruction 'lduminalh'. + kIdLduminb, //!< Instruction 'lduminb'. + kIdLduminh, //!< Instruction 'lduminh'. + kIdLduminl, //!< Instruction 'lduminl'. + kIdLduminlb, //!< Instruction 'lduminlb'. + kIdLduminlh, //!< Instruction 'lduminlh'. + kIdLdur, //!< Instruction 'ldur'. + kIdLdurb, //!< Instruction 'ldurb'. + kIdLdurh, //!< Instruction 'ldurh'. + kIdLdursb, //!< Instruction 'ldursb'. + kIdLdursh, //!< Instruction 'ldursh'. + kIdLdursw, //!< Instruction 'ldursw'. + kIdLdxp, //!< Instruction 'ldxp'. + kIdLdxr, //!< Instruction 'ldxr'. + kIdLdxrb, //!< Instruction 'ldxrb'. + kIdLdxrh, //!< Instruction 'ldxrh'. + kIdLsl, //!< Instruction 'lsl'. + kIdLslv, //!< Instruction 'lslv'. + kIdLsr, //!< Instruction 'lsr'. + kIdLsrv, //!< Instruction 'lsrv'. + kIdMadd, //!< Instruction 'madd'. + kIdMneg, //!< Instruction 'mneg'. + kIdMov, //!< Instruction 'mov'. + kIdMovk, //!< Instruction 'movk'. + kIdMovn, //!< Instruction 'movn'. + kIdMovz, //!< Instruction 'movz'. + kIdMrs, //!< Instruction 'mrs'. + kIdMsr, //!< Instruction 'msr'. + kIdMsub, //!< Instruction 'msub'. + kIdMul, //!< Instruction 'mul'. + kIdMvn, //!< Instruction 'mvn'. + kIdNeg, //!< Instruction 'neg'. + kIdNegs, //!< Instruction 'negs'. + kIdNgc, //!< Instruction 'ngc'. + kIdNgcs, //!< Instruction 'ngcs'. + kIdNop, //!< Instruction 'nop'. + kIdOrn, //!< Instruction 'orn'. + kIdOrr, //!< Instruction 'orr'. + kIdPacda, //!< Instruction 'pacda'. + kIdPacdb, //!< Instruction 'pacdb'. + kIdPacdza, //!< Instruction 'pacdza'. + kIdPacdzb, //!< Instruction 'pacdzb'. + kIdPacga, //!< Instruction 'pacga'. + kIdPrfm, //!< Instruction 'prfm'. + kIdPssbb, //!< Instruction 'pssbb'. + kIdRbit, //!< Instruction 'rbit'. + kIdRet, //!< Instruction 'ret'. + kIdRev, //!< Instruction 'rev'. + kIdRev16, //!< Instruction 'rev16'. + kIdRev32, //!< Instruction 'rev32'. + kIdRev64, //!< Instruction 'rev64'. + kIdRor, //!< Instruction 'ror'. + kIdRorv, //!< Instruction 'rorv'. + kIdSbc, //!< Instruction 'sbc'. + kIdSbcs, //!< Instruction 'sbcs'. + kIdSbfiz, //!< Instruction 'sbfiz'. + kIdSbfm, //!< Instruction 'sbfm'. + kIdSbfx, //!< Instruction 'sbfx'. + kIdSdiv, //!< Instruction 'sdiv'. + kIdSetf8, //!< Instruction 'setf8'. + kIdSetf16, //!< Instruction 'setf16'. + kIdSev, //!< Instruction 'sev'. + kIdSevl, //!< Instruction 'sevl'. + kIdSmaddl, //!< Instruction 'smaddl'. + kIdSmc, //!< Instruction 'smc'. + kIdSmnegl, //!< Instruction 'smnegl'. + kIdSmsubl, //!< Instruction 'smsubl'. + kIdSmulh, //!< Instruction 'smulh'. + kIdSmull, //!< Instruction 'smull'. + kIdSsbb, //!< Instruction 'ssbb'. + kIdSt2g, //!< Instruction 'st2g'. + kIdStadd, //!< Instruction 'stadd'. + kIdStaddl, //!< Instruction 'staddl'. + kIdStaddb, //!< Instruction 'staddb'. + kIdStaddlb, //!< Instruction 'staddlb'. + kIdStaddh, //!< Instruction 'staddh'. + kIdStaddlh, //!< Instruction 'staddlh'. + kIdStclr, //!< Instruction 'stclr'. + kIdStclrl, //!< Instruction 'stclrl'. + kIdStclrb, //!< Instruction 'stclrb'. + kIdStclrlb, //!< Instruction 'stclrlb'. + kIdStclrh, //!< Instruction 'stclrh'. + kIdStclrlh, //!< Instruction 'stclrlh'. + kIdSteor, //!< Instruction 'steor'. + kIdSteorl, //!< Instruction 'steorl'. + kIdSteorb, //!< Instruction 'steorb'. + kIdSteorlb, //!< Instruction 'steorlb'. + kIdSteorh, //!< Instruction 'steorh'. + kIdSteorlh, //!< Instruction 'steorlh'. + kIdStg, //!< Instruction 'stg'. + kIdStgm, //!< Instruction 'stgm'. + kIdStgp, //!< Instruction 'stgp'. + kIdStllr, //!< Instruction 'stllr'. + kIdStllrb, //!< Instruction 'stllrb'. + kIdStllrh, //!< Instruction 'stllrh'. + kIdStlr, //!< Instruction 'stlr'. + kIdStlrb, //!< Instruction 'stlrb'. + kIdStlrh, //!< Instruction 'stlrh'. + kIdStlxp, //!< Instruction 'stlxp'. + kIdStlxr, //!< Instruction 'stlxr'. + kIdStlxrb, //!< Instruction 'stlxrb'. + kIdStlxrh, //!< Instruction 'stlxrh'. + kIdStnp, //!< Instruction 'stnp'. + kIdStp, //!< Instruction 'stp'. + kIdStr, //!< Instruction 'str'. + kIdStrb, //!< Instruction 'strb'. + kIdStrh, //!< Instruction 'strh'. + kIdStset, //!< Instruction 'stset'. + kIdStsetl, //!< Instruction 'stsetl'. + kIdStsetb, //!< Instruction 'stsetb'. + kIdStsetlb, //!< Instruction 'stsetlb'. + kIdStseth, //!< Instruction 'stseth'. + kIdStsetlh, //!< Instruction 'stsetlh'. + kIdStsmax, //!< Instruction 'stsmax'. + kIdStsmaxl, //!< Instruction 'stsmaxl'. + kIdStsmaxb, //!< Instruction 'stsmaxb'. + kIdStsmaxlb, //!< Instruction 'stsmaxlb'. + kIdStsmaxh, //!< Instruction 'stsmaxh'. + kIdStsmaxlh, //!< Instruction 'stsmaxlh'. + kIdStsmin, //!< Instruction 'stsmin'. + kIdStsminl, //!< Instruction 'stsminl'. + kIdStsminb, //!< Instruction 'stsminb'. + kIdStsminlb, //!< Instruction 'stsminlb'. + kIdStsminh, //!< Instruction 'stsminh'. + kIdStsminlh, //!< Instruction 'stsminlh'. + kIdSttr, //!< Instruction 'sttr'. + kIdSttrb, //!< Instruction 'sttrb'. + kIdSttrh, //!< Instruction 'sttrh'. + kIdStumax, //!< Instruction 'stumax'. + kIdStumaxl, //!< Instruction 'stumaxl'. + kIdStumaxb, //!< Instruction 'stumaxb'. + kIdStumaxlb, //!< Instruction 'stumaxlb'. + kIdStumaxh, //!< Instruction 'stumaxh'. + kIdStumaxlh, //!< Instruction 'stumaxlh'. + kIdStumin, //!< Instruction 'stumin'. + kIdStuminl, //!< Instruction 'stuminl'. + kIdStuminb, //!< Instruction 'stuminb'. + kIdStuminlb, //!< Instruction 'stuminlb'. + kIdStuminh, //!< Instruction 'stuminh'. + kIdStuminlh, //!< Instruction 'stuminlh'. + kIdStur, //!< Instruction 'stur'. + kIdSturb, //!< Instruction 'sturb'. + kIdSturh, //!< Instruction 'sturh'. + kIdStxp, //!< Instruction 'stxp'. + kIdStxr, //!< Instruction 'stxr'. + kIdStxrb, //!< Instruction 'stxrb'. + kIdStxrh, //!< Instruction 'stxrh'. + kIdStz2g, //!< Instruction 'stz2g'. + kIdStzg, //!< Instruction 'stzg'. + kIdStzgm, //!< Instruction 'stzgm'. + kIdSub, //!< Instruction 'sub'. + kIdSubg, //!< Instruction 'subg'. + kIdSubp, //!< Instruction 'subp'. + kIdSubps, //!< Instruction 'subps'. + kIdSubs, //!< Instruction 'subs'. + kIdSvc, //!< Instruction 'svc'. + kIdSwp, //!< Instruction 'swp'. + kIdSwpa, //!< Instruction 'swpa'. + kIdSwpab, //!< Instruction 'swpab'. + kIdSwpah, //!< Instruction 'swpah'. + kIdSwpal, //!< Instruction 'swpal'. + kIdSwpalb, //!< Instruction 'swpalb'. + kIdSwpalh, //!< Instruction 'swpalh'. + kIdSwpb, //!< Instruction 'swpb'. + kIdSwph, //!< Instruction 'swph'. + kIdSwpl, //!< Instruction 'swpl'. + kIdSwplb, //!< Instruction 'swplb'. + kIdSwplh, //!< Instruction 'swplh'. + kIdSxtb, //!< Instruction 'sxtb'. + kIdSxth, //!< Instruction 'sxth'. + kIdSxtw, //!< Instruction 'sxtw'. + kIdSys, //!< Instruction 'sys'. + kIdTlbi, //!< Instruction 'tlbi'. + kIdTst, //!< Instruction 'tst'. + kIdTbnz, //!< Instruction 'tbnz'. + kIdTbz, //!< Instruction 'tbz'. + kIdUbfiz, //!< Instruction 'ubfiz'. + kIdUbfm, //!< Instruction 'ubfm'. + kIdUbfx, //!< Instruction 'ubfx'. + kIdUdf, //!< Instruction 'udf'. + kIdUdiv, //!< Instruction 'udiv'. + kIdUmaddl, //!< Instruction 'umaddl'. + kIdUmnegl, //!< Instruction 'umnegl'. + kIdUmull, //!< Instruction 'umull'. + kIdUmulh, //!< Instruction 'umulh'. + kIdUmsubl, //!< Instruction 'umsubl'. + kIdUxtb, //!< Instruction 'uxtb'. + kIdUxth, //!< Instruction 'uxth'. + kIdWfe, //!< Instruction 'wfe'. + kIdWfi, //!< Instruction 'wfi'. + kIdXaflag, //!< Instruction 'xaflag'. + kIdXpacd, //!< Instruction 'xpacd'. + kIdXpaci, //!< Instruction 'xpaci'. + kIdXpaclri, //!< Instruction 'xpaclri'. + kIdYield, //!< Instruction 'yield'. + kIdAbs_v, //!< Instruction 'abs' {ASIMD}. + kIdAdd_v, //!< Instruction 'add' {ASIMD}. + kIdAddhn_v, //!< Instruction 'addhn' {ASIMD}. + kIdAddhn2_v, //!< Instruction 'addhn2' {ASIMD}. + kIdAddp_v, //!< Instruction 'addp' {ASIMD}. + kIdAddv_v, //!< Instruction 'addv' {ASIMD}. + kIdAesd_v, //!< Instruction 'aesd' {ASIMD}. + kIdAese_v, //!< Instruction 'aese' {ASIMD}. + kIdAesimc_v, //!< Instruction 'aesimc' {ASIMD}. + kIdAesmc_v, //!< Instruction 'aesmc' {ASIMD}. + kIdAnd_v, //!< Instruction 'and' {ASIMD}. + kIdBcax_v, //!< Instruction 'bcax' {ASIMD}. + kIdBfcvt_v, //!< Instruction 'bfcvt' {ASIMD}. + kIdBfcvtn_v, //!< Instruction 'bfcvtn' {ASIMD}. + kIdBfcvtn2_v, //!< Instruction 'bfcvtn2' {ASIMD}. + kIdBfdot_v, //!< Instruction 'bfdot' {ASIMD}. + kIdBfmlalb_v, //!< Instruction 'bfmlalb' {ASIMD}. + kIdBfmlalt_v, //!< Instruction 'bfmlalt' {ASIMD}. + kIdBfmmla_v, //!< Instruction 'bfmmla' {ASIMD}. + kIdBic_v, //!< Instruction 'bic' {ASIMD}. + kIdBif_v, //!< Instruction 'bif' {ASIMD}. + kIdBit_v, //!< Instruction 'bit' {ASIMD}. + kIdBsl_v, //!< Instruction 'bsl' {ASIMD}. + kIdCls_v, //!< Instruction 'cls' {ASIMD}. + kIdClz_v, //!< Instruction 'clz' {ASIMD}. + kIdCmeq_v, //!< Instruction 'cmeq' {ASIMD}. + kIdCmge_v, //!< Instruction 'cmge' {ASIMD}. + kIdCmgt_v, //!< Instruction 'cmgt' {ASIMD}. + kIdCmhi_v, //!< Instruction 'cmhi' {ASIMD}. + kIdCmhs_v, //!< Instruction 'cmhs' {ASIMD}. + kIdCmle_v, //!< Instruction 'cmle' {ASIMD}. + kIdCmlt_v, //!< Instruction 'cmlt' {ASIMD}. + kIdCmtst_v, //!< Instruction 'cmtst' {ASIMD}. + kIdCnt_v, //!< Instruction 'cnt' {ASIMD}. + kIdDup_v, //!< Instruction 'dup' {ASIMD}. + kIdEor_v, //!< Instruction 'eor' {ASIMD}. + kIdEor3_v, //!< Instruction 'eor3' {ASIMD}. + kIdExt_v, //!< Instruction 'ext' {ASIMD}. + kIdFabd_v, //!< Instruction 'fabd' {ASIMD}. + kIdFabs_v, //!< Instruction 'fabs' {ASIMD}. + kIdFacge_v, //!< Instruction 'facge' {ASIMD}. + kIdFacgt_v, //!< Instruction 'facgt' {ASIMD}. + kIdFadd_v, //!< Instruction 'fadd' {ASIMD}. + kIdFaddp_v, //!< Instruction 'faddp' {ASIMD}. + kIdFcadd_v, //!< Instruction 'fcadd' {ASIMD}. + kIdFccmp_v, //!< Instruction 'fccmp' {ASIMD}. + kIdFccmpe_v, //!< Instruction 'fccmpe' {ASIMD}. + kIdFcmeq_v, //!< Instruction 'fcmeq' {ASIMD}. + kIdFcmge_v, //!< Instruction 'fcmge' {ASIMD}. + kIdFcmgt_v, //!< Instruction 'fcmgt' {ASIMD}. + kIdFcmla_v, //!< Instruction 'fcmla' {ASIMD}. + kIdFcmle_v, //!< Instruction 'fcmle' {ASIMD}. + kIdFcmlt_v, //!< Instruction 'fcmlt' {ASIMD}. + kIdFcmp_v, //!< Instruction 'fcmp' {ASIMD}. + kIdFcmpe_v, //!< Instruction 'fcmpe' {ASIMD}. + kIdFcsel_v, //!< Instruction 'fcsel' {ASIMD}. + kIdFcvt_v, //!< Instruction 'fcvt' {ASIMD}. + kIdFcvtas_v, //!< Instruction 'fcvtas' {ASIMD}. + kIdFcvtau_v, //!< Instruction 'fcvtau' {ASIMD}. + kIdFcvtl_v, //!< Instruction 'fcvtl' {ASIMD}. + kIdFcvtl2_v, //!< Instruction 'fcvtl2' {ASIMD}. + kIdFcvtms_v, //!< Instruction 'fcvtms' {ASIMD}. + kIdFcvtmu_v, //!< Instruction 'fcvtmu' {ASIMD}. + kIdFcvtn_v, //!< Instruction 'fcvtn' {ASIMD}. + kIdFcvtn2_v, //!< Instruction 'fcvtn2' {ASIMD}. + kIdFcvtns_v, //!< Instruction 'fcvtns' {ASIMD}. + kIdFcvtnu_v, //!< Instruction 'fcvtnu' {ASIMD}. + kIdFcvtps_v, //!< Instruction 'fcvtps' {ASIMD}. + kIdFcvtpu_v, //!< Instruction 'fcvtpu' {ASIMD}. + kIdFcvtxn_v, //!< Instruction 'fcvtxn' {ASIMD}. + kIdFcvtxn2_v, //!< Instruction 'fcvtxn2' {ASIMD}. + kIdFcvtzs_v, //!< Instruction 'fcvtzs' {ASIMD}. + kIdFcvtzu_v, //!< Instruction 'fcvtzu' {ASIMD}. + kIdFdiv_v, //!< Instruction 'fdiv' {ASIMD}. + kIdFjcvtzs_v, //!< Instruction 'fjcvtzs' {ASIMD}. + kIdFmadd_v, //!< Instruction 'fmadd' {ASIMD}. + kIdFmax_v, //!< Instruction 'fmax' {ASIMD}. + kIdFmaxnm_v, //!< Instruction 'fmaxnm' {ASIMD}. + kIdFmaxnmp_v, //!< Instruction 'fmaxnmp' {ASIMD}. + kIdFmaxnmv_v, //!< Instruction 'fmaxnmv' {ASIMD}. + kIdFmaxp_v, //!< Instruction 'fmaxp' {ASIMD}. + kIdFmaxv_v, //!< Instruction 'fmaxv' {ASIMD}. + kIdFmin_v, //!< Instruction 'fmin' {ASIMD}. + kIdFminnm_v, //!< Instruction 'fminnm' {ASIMD}. + kIdFminnmp_v, //!< Instruction 'fminnmp' {ASIMD}. + kIdFminnmv_v, //!< Instruction 'fminnmv' {ASIMD}. + kIdFminp_v, //!< Instruction 'fminp' {ASIMD}. + kIdFminv_v, //!< Instruction 'fminv' {ASIMD}. + kIdFmla_v, //!< Instruction 'fmla' {ASIMD}. + kIdFmlal_v, //!< Instruction 'fmlal' {ASIMD}. + kIdFmlal2_v, //!< Instruction 'fmlal2' {ASIMD}. + kIdFmls_v, //!< Instruction 'fmls' {ASIMD}. + kIdFmlsl_v, //!< Instruction 'fmlsl' {ASIMD}. + kIdFmlsl2_v, //!< Instruction 'fmlsl2' {ASIMD}. + kIdFmov_v, //!< Instruction 'fmov' {ASIMD}. + kIdFmsub_v, //!< Instruction 'fmsub' {ASIMD}. + kIdFmul_v, //!< Instruction 'fmul' {ASIMD}. + kIdFmulx_v, //!< Instruction 'fmulx' {ASIMD}. + kIdFneg_v, //!< Instruction 'fneg' {ASIMD}. + kIdFnmadd_v, //!< Instruction 'fnmadd' {ASIMD}. + kIdFnmsub_v, //!< Instruction 'fnmsub' {ASIMD}. + kIdFnmul_v, //!< Instruction 'fnmul' {ASIMD}. + kIdFrecpe_v, //!< Instruction 'frecpe' {ASIMD}. + kIdFrecps_v, //!< Instruction 'frecps' {ASIMD}. + kIdFrecpx_v, //!< Instruction 'frecpx' {ASIMD}. + kIdFrint32x_v, //!< Instruction 'frint32x' {ASIMD}. + kIdFrint32z_v, //!< Instruction 'frint32z' {ASIMD}. + kIdFrint64x_v, //!< Instruction 'frint64x' {ASIMD}. + kIdFrint64z_v, //!< Instruction 'frint64z' {ASIMD}. + kIdFrinta_v, //!< Instruction 'frinta' {ASIMD}. + kIdFrinti_v, //!< Instruction 'frinti' {ASIMD}. + kIdFrintm_v, //!< Instruction 'frintm' {ASIMD}. + kIdFrintn_v, //!< Instruction 'frintn' {ASIMD}. + kIdFrintp_v, //!< Instruction 'frintp' {ASIMD}. + kIdFrintx_v, //!< Instruction 'frintx' {ASIMD}. + kIdFrintz_v, //!< Instruction 'frintz' {ASIMD}. + kIdFrsqrte_v, //!< Instruction 'frsqrte' {ASIMD}. + kIdFrsqrts_v, //!< Instruction 'frsqrts' {ASIMD}. + kIdFsqrt_v, //!< Instruction 'fsqrt' {ASIMD}. + kIdFsub_v, //!< Instruction 'fsub' {ASIMD}. + kIdIns_v, //!< Instruction 'ins' {ASIMD}. + kIdLd1_v, //!< Instruction 'ld1' {ASIMD}. + kIdLd1r_v, //!< Instruction 'ld1r' {ASIMD}. + kIdLd2_v, //!< Instruction 'ld2' {ASIMD}. + kIdLd2r_v, //!< Instruction 'ld2r' {ASIMD}. + kIdLd3_v, //!< Instruction 'ld3' {ASIMD}. + kIdLd3r_v, //!< Instruction 'ld3r' {ASIMD}. + kIdLd4_v, //!< Instruction 'ld4' {ASIMD}. + kIdLd4r_v, //!< Instruction 'ld4r' {ASIMD}. + kIdLdnp_v, //!< Instruction 'ldnp' {ASIMD}. + kIdLdp_v, //!< Instruction 'ldp' {ASIMD}. + kIdLdr_v, //!< Instruction 'ldr' {ASIMD}. + kIdLdur_v, //!< Instruction 'ldur' {ASIMD}. + kIdMla_v, //!< Instruction 'mla' {ASIMD}. + kIdMls_v, //!< Instruction 'mls' {ASIMD}. + kIdMov_v, //!< Instruction 'mov' {ASIMD}. + kIdMovi_v, //!< Instruction 'movi' {ASIMD}. + kIdMul_v, //!< Instruction 'mul' {ASIMD}. + kIdMvn_v, //!< Instruction 'mvn' {ASIMD}. + kIdMvni_v, //!< Instruction 'mvni' {ASIMD}. + kIdNeg_v, //!< Instruction 'neg' {ASIMD}. + kIdNot_v, //!< Instruction 'not' {ASIMD}. + kIdOrn_v, //!< Instruction 'orn' {ASIMD}. + kIdOrr_v, //!< Instruction 'orr' {ASIMD}. + kIdPmul_v, //!< Instruction 'pmul' {ASIMD}. + kIdPmull_v, //!< Instruction 'pmull' {ASIMD}. + kIdPmull2_v, //!< Instruction 'pmull2' {ASIMD}. + kIdRaddhn_v, //!< Instruction 'raddhn' {ASIMD}. + kIdRaddhn2_v, //!< Instruction 'raddhn2' {ASIMD}. + kIdRax1_v, //!< Instruction 'rax1' {ASIMD}. + kIdRbit_v, //!< Instruction 'rbit' {ASIMD}. + kIdRev16_v, //!< Instruction 'rev16' {ASIMD}. + kIdRev32_v, //!< Instruction 'rev32' {ASIMD}. + kIdRev64_v, //!< Instruction 'rev64' {ASIMD}. + kIdRshrn_v, //!< Instruction 'rshrn' {ASIMD}. + kIdRshrn2_v, //!< Instruction 'rshrn2' {ASIMD}. + kIdRsubhn_v, //!< Instruction 'rsubhn' {ASIMD}. + kIdRsubhn2_v, //!< Instruction 'rsubhn2' {ASIMD}. + kIdSaba_v, //!< Instruction 'saba' {ASIMD}. + kIdSabal_v, //!< Instruction 'sabal' {ASIMD}. + kIdSabal2_v, //!< Instruction 'sabal2' {ASIMD}. + kIdSabd_v, //!< Instruction 'sabd' {ASIMD}. + kIdSabdl_v, //!< Instruction 'sabdl' {ASIMD}. + kIdSabdl2_v, //!< Instruction 'sabdl2' {ASIMD}. + kIdSadalp_v, //!< Instruction 'sadalp' {ASIMD}. + kIdSaddl_v, //!< Instruction 'saddl' {ASIMD}. + kIdSaddl2_v, //!< Instruction 'saddl2' {ASIMD}. + kIdSaddlp_v, //!< Instruction 'saddlp' {ASIMD}. + kIdSaddlv_v, //!< Instruction 'saddlv' {ASIMD}. + kIdSaddw_v, //!< Instruction 'saddw' {ASIMD}. + kIdSaddw2_v, //!< Instruction 'saddw2' {ASIMD}. + kIdScvtf_v, //!< Instruction 'scvtf' {ASIMD}. + kIdSdot_v, //!< Instruction 'sdot' {ASIMD}. + kIdSha1c_v, //!< Instruction 'sha1c' {ASIMD}. + kIdSha1h_v, //!< Instruction 'sha1h' {ASIMD}. + kIdSha1m_v, //!< Instruction 'sha1m' {ASIMD}. + kIdSha1p_v, //!< Instruction 'sha1p' {ASIMD}. + kIdSha1su0_v, //!< Instruction 'sha1su0' {ASIMD}. + kIdSha1su1_v, //!< Instruction 'sha1su1' {ASIMD}. + kIdSha256h_v, //!< Instruction 'sha256h' {ASIMD}. + kIdSha256h2_v, //!< Instruction 'sha256h2' {ASIMD}. + kIdSha256su0_v, //!< Instruction 'sha256su0' {ASIMD}. + kIdSha256su1_v, //!< Instruction 'sha256su1' {ASIMD}. + kIdSha512h_v, //!< Instruction 'sha512h' {ASIMD}. + kIdSha512h2_v, //!< Instruction 'sha512h2' {ASIMD}. + kIdSha512su0_v, //!< Instruction 'sha512su0' {ASIMD}. + kIdSha512su1_v, //!< Instruction 'sha512su1' {ASIMD}. + kIdShadd_v, //!< Instruction 'shadd' {ASIMD}. + kIdShl_v, //!< Instruction 'shl' {ASIMD}. + kIdShll_v, //!< Instruction 'shll' {ASIMD}. + kIdShll2_v, //!< Instruction 'shll2' {ASIMD}. + kIdShrn_v, //!< Instruction 'shrn' {ASIMD}. + kIdShrn2_v, //!< Instruction 'shrn2' {ASIMD}. + kIdShsub_v, //!< Instruction 'shsub' {ASIMD}. + kIdSli_v, //!< Instruction 'sli' {ASIMD}. + kIdSm3partw1_v, //!< Instruction 'sm3partw1' {ASIMD}. + kIdSm3partw2_v, //!< Instruction 'sm3partw2' {ASIMD}. + kIdSm3ss1_v, //!< Instruction 'sm3ss1' {ASIMD}. + kIdSm3tt1a_v, //!< Instruction 'sm3tt1a' {ASIMD}. + kIdSm3tt1b_v, //!< Instruction 'sm3tt1b' {ASIMD}. + kIdSm3tt2a_v, //!< Instruction 'sm3tt2a' {ASIMD}. + kIdSm3tt2b_v, //!< Instruction 'sm3tt2b' {ASIMD}. + kIdSm4e_v, //!< Instruction 'sm4e' {ASIMD}. + kIdSm4ekey_v, //!< Instruction 'sm4ekey' {ASIMD}. + kIdSmax_v, //!< Instruction 'smax' {ASIMD}. + kIdSmaxp_v, //!< Instruction 'smaxp' {ASIMD}. + kIdSmaxv_v, //!< Instruction 'smaxv' {ASIMD}. + kIdSmin_v, //!< Instruction 'smin' {ASIMD}. + kIdSminp_v, //!< Instruction 'sminp' {ASIMD}. + kIdSminv_v, //!< Instruction 'sminv' {ASIMD}. + kIdSmlal_v, //!< Instruction 'smlal' {ASIMD}. + kIdSmlal2_v, //!< Instruction 'smlal2' {ASIMD}. + kIdSmlsl_v, //!< Instruction 'smlsl' {ASIMD}. + kIdSmlsl2_v, //!< Instruction 'smlsl2' {ASIMD}. + kIdSmmla_v, //!< Instruction 'smmla' {ASIMD}. + kIdSmov_v, //!< Instruction 'smov' {ASIMD}. + kIdSmull_v, //!< Instruction 'smull' {ASIMD}. + kIdSmull2_v, //!< Instruction 'smull2' {ASIMD}. + kIdSqabs_v, //!< Instruction 'sqabs' {ASIMD}. + kIdSqadd_v, //!< Instruction 'sqadd' {ASIMD}. + kIdSqdmlal_v, //!< Instruction 'sqdmlal' {ASIMD}. + kIdSqdmlal2_v, //!< Instruction 'sqdmlal2' {ASIMD}. + kIdSqdmlsl_v, //!< Instruction 'sqdmlsl' {ASIMD}. + kIdSqdmlsl2_v, //!< Instruction 'sqdmlsl2' {ASIMD}. + kIdSqdmulh_v, //!< Instruction 'sqdmulh' {ASIMD}. + kIdSqdmull_v, //!< Instruction 'sqdmull' {ASIMD}. + kIdSqdmull2_v, //!< Instruction 'sqdmull2' {ASIMD}. + kIdSqneg_v, //!< Instruction 'sqneg' {ASIMD}. + kIdSqrdmlah_v, //!< Instruction 'sqrdmlah' {ASIMD}. + kIdSqrdmlsh_v, //!< Instruction 'sqrdmlsh' {ASIMD}. + kIdSqrdmulh_v, //!< Instruction 'sqrdmulh' {ASIMD}. + kIdSqrshl_v, //!< Instruction 'sqrshl' {ASIMD}. + kIdSqrshrn_v, //!< Instruction 'sqrshrn' {ASIMD}. + kIdSqrshrn2_v, //!< Instruction 'sqrshrn2' {ASIMD}. + kIdSqrshrun_v, //!< Instruction 'sqrshrun' {ASIMD}. + kIdSqrshrun2_v, //!< Instruction 'sqrshrun2' {ASIMD}. + kIdSqshl_v, //!< Instruction 'sqshl' {ASIMD}. + kIdSqshlu_v, //!< Instruction 'sqshlu' {ASIMD}. + kIdSqshrn_v, //!< Instruction 'sqshrn' {ASIMD}. + kIdSqshrn2_v, //!< Instruction 'sqshrn2' {ASIMD}. + kIdSqshrun_v, //!< Instruction 'sqshrun' {ASIMD}. + kIdSqshrun2_v, //!< Instruction 'sqshrun2' {ASIMD}. + kIdSqsub_v, //!< Instruction 'sqsub' {ASIMD}. + kIdSqxtn_v, //!< Instruction 'sqxtn' {ASIMD}. + kIdSqxtn2_v, //!< Instruction 'sqxtn2' {ASIMD}. + kIdSqxtun_v, //!< Instruction 'sqxtun' {ASIMD}. + kIdSqxtun2_v, //!< Instruction 'sqxtun2' {ASIMD}. + kIdSrhadd_v, //!< Instruction 'srhadd' {ASIMD}. + kIdSri_v, //!< Instruction 'sri' {ASIMD}. + kIdSrshl_v, //!< Instruction 'srshl' {ASIMD}. + kIdSrshr_v, //!< Instruction 'srshr' {ASIMD}. + kIdSrsra_v, //!< Instruction 'srsra' {ASIMD}. + kIdSshl_v, //!< Instruction 'sshl' {ASIMD}. + kIdSshll_v, //!< Instruction 'sshll' {ASIMD}. + kIdSshll2_v, //!< Instruction 'sshll2' {ASIMD}. + kIdSshr_v, //!< Instruction 'sshr' {ASIMD}. + kIdSsra_v, //!< Instruction 'ssra' {ASIMD}. + kIdSsubl_v, //!< Instruction 'ssubl' {ASIMD}. + kIdSsubl2_v, //!< Instruction 'ssubl2' {ASIMD}. + kIdSsubw_v, //!< Instruction 'ssubw' {ASIMD}. + kIdSsubw2_v, //!< Instruction 'ssubw2' {ASIMD}. + kIdSt1_v, //!< Instruction 'st1' {ASIMD}. + kIdSt2_v, //!< Instruction 'st2' {ASIMD}. + kIdSt3_v, //!< Instruction 'st3' {ASIMD}. + kIdSt4_v, //!< Instruction 'st4' {ASIMD}. + kIdStnp_v, //!< Instruction 'stnp' {ASIMD}. + kIdStp_v, //!< Instruction 'stp' {ASIMD}. + kIdStr_v, //!< Instruction 'str' {ASIMD}. + kIdStur_v, //!< Instruction 'stur' {ASIMD}. + kIdSub_v, //!< Instruction 'sub' {ASIMD}. + kIdSubhn_v, //!< Instruction 'subhn' {ASIMD}. + kIdSubhn2_v, //!< Instruction 'subhn2' {ASIMD}. + kIdSudot_v, //!< Instruction 'sudot' {ASIMD}. + kIdSuqadd_v, //!< Instruction 'suqadd' {ASIMD}. + kIdSxtl_v, //!< Instruction 'sxtl' {ASIMD}. + kIdSxtl2_v, //!< Instruction 'sxtl2' {ASIMD}. + kIdTbl_v, //!< Instruction 'tbl' {ASIMD}. + kIdTbx_v, //!< Instruction 'tbx' {ASIMD}. + kIdTrn1_v, //!< Instruction 'trn1' {ASIMD}. + kIdTrn2_v, //!< Instruction 'trn2' {ASIMD}. + kIdUaba_v, //!< Instruction 'uaba' {ASIMD}. + kIdUabal_v, //!< Instruction 'uabal' {ASIMD}. + kIdUabal2_v, //!< Instruction 'uabal2' {ASIMD}. + kIdUabd_v, //!< Instruction 'uabd' {ASIMD}. + kIdUabdl_v, //!< Instruction 'uabdl' {ASIMD}. + kIdUabdl2_v, //!< Instruction 'uabdl2' {ASIMD}. + kIdUadalp_v, //!< Instruction 'uadalp' {ASIMD}. + kIdUaddl_v, //!< Instruction 'uaddl' {ASIMD}. + kIdUaddl2_v, //!< Instruction 'uaddl2' {ASIMD}. + kIdUaddlp_v, //!< Instruction 'uaddlp' {ASIMD}. + kIdUaddlv_v, //!< Instruction 'uaddlv' {ASIMD}. + kIdUaddw_v, //!< Instruction 'uaddw' {ASIMD}. + kIdUaddw2_v, //!< Instruction 'uaddw2' {ASIMD}. + kIdUcvtf_v, //!< Instruction 'ucvtf' {ASIMD}. + kIdUdot_v, //!< Instruction 'udot' {ASIMD}. + kIdUhadd_v, //!< Instruction 'uhadd' {ASIMD}. + kIdUhsub_v, //!< Instruction 'uhsub' {ASIMD}. + kIdUmax_v, //!< Instruction 'umax' {ASIMD}. + kIdUmaxp_v, //!< Instruction 'umaxp' {ASIMD}. + kIdUmaxv_v, //!< Instruction 'umaxv' {ASIMD}. + kIdUmin_v, //!< Instruction 'umin' {ASIMD}. + kIdUminp_v, //!< Instruction 'uminp' {ASIMD}. + kIdUminv_v, //!< Instruction 'uminv' {ASIMD}. + kIdUmlal_v, //!< Instruction 'umlal' {ASIMD}. + kIdUmlal2_v, //!< Instruction 'umlal2' {ASIMD}. + kIdUmlsl_v, //!< Instruction 'umlsl' {ASIMD}. + kIdUmlsl2_v, //!< Instruction 'umlsl2' {ASIMD}. + kIdUmmla_v, //!< Instruction 'ummla' {ASIMD}. + kIdUmov_v, //!< Instruction 'umov' {ASIMD}. + kIdUmull_v, //!< Instruction 'umull' {ASIMD}. + kIdUmull2_v, //!< Instruction 'umull2' {ASIMD}. + kIdUqadd_v, //!< Instruction 'uqadd' {ASIMD}. + kIdUqrshl_v, //!< Instruction 'uqrshl' {ASIMD}. + kIdUqrshrn_v, //!< Instruction 'uqrshrn' {ASIMD}. + kIdUqrshrn2_v, //!< Instruction 'uqrshrn2' {ASIMD}. + kIdUqshl_v, //!< Instruction 'uqshl' {ASIMD}. + kIdUqshrn_v, //!< Instruction 'uqshrn' {ASIMD}. + kIdUqshrn2_v, //!< Instruction 'uqshrn2' {ASIMD}. + kIdUqsub_v, //!< Instruction 'uqsub' {ASIMD}. + kIdUqxtn_v, //!< Instruction 'uqxtn' {ASIMD}. + kIdUqxtn2_v, //!< Instruction 'uqxtn2' {ASIMD}. + kIdUrecpe_v, //!< Instruction 'urecpe' {ASIMD}. + kIdUrhadd_v, //!< Instruction 'urhadd' {ASIMD}. + kIdUrshl_v, //!< Instruction 'urshl' {ASIMD}. + kIdUrshr_v, //!< Instruction 'urshr' {ASIMD}. + kIdUrsqrte_v, //!< Instruction 'ursqrte' {ASIMD}. + kIdUrsra_v, //!< Instruction 'ursra' {ASIMD}. + kIdUsdot_v, //!< Instruction 'usdot' {ASIMD}. + kIdUshl_v, //!< Instruction 'ushl' {ASIMD}. + kIdUshll_v, //!< Instruction 'ushll' {ASIMD}. + kIdUshll2_v, //!< Instruction 'ushll2' {ASIMD}. + kIdUshr_v, //!< Instruction 'ushr' {ASIMD}. + kIdUsmmla_v, //!< Instruction 'usmmla' {ASIMD}. + kIdUsqadd_v, //!< Instruction 'usqadd' {ASIMD}. + kIdUsra_v, //!< Instruction 'usra' {ASIMD}. + kIdUsubl_v, //!< Instruction 'usubl' {ASIMD}. + kIdUsubl2_v, //!< Instruction 'usubl2' {ASIMD}. + kIdUsubw_v, //!< Instruction 'usubw' {ASIMD}. + kIdUsubw2_v, //!< Instruction 'usubw2' {ASIMD}. + kIdUxtl_v, //!< Instruction 'uxtl' {ASIMD}. + kIdUxtl2_v, //!< Instruction 'uxtl2' {ASIMD}. + kIdUzp1_v, //!< Instruction 'uzp1' {ASIMD}. + kIdUzp2_v, //!< Instruction 'uzp2' {ASIMD}. + kIdXar_v, //!< Instruction 'xar' {ASIMD}. + kIdXtn_v, //!< Instruction 'xtn' {ASIMD}. + kIdXtn2_v, //!< Instruction 'xtn2' {ASIMD}. + kIdZip1_v, //!< Instruction 'zip1' {ASIMD}. + kIdZip2_v, //!< Instruction 'zip2' {ASIMD}. + _kIdCount + // ${InstId:End} + }; + + //! Tests whether the `instId` is defined (counts also Inst::kIdNone, which must be zero). + static ASMJIT_INLINE_NODEBUG bool isDefinedId(InstId instId) noexcept { return (instId & uint32_t(InstIdParts::kRealId)) < _kIdCount; } +}; + +namespace Predicate { + +//! Address translate options (AT). +namespace AT { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + enum Value : uint32_t { + kS1E1R = encode(0b000, 0b0111, 0b1000, 0b000), + kS1E2R = encode(0b100, 0b0111, 0b1000, 0b000), + kS1E3R = encode(0b110, 0b0111, 0b1000, 0b000), + kS1E1W = encode(0b000, 0b0111, 0b1000, 0b001), + kS1E2W = encode(0b100, 0b0111, 0b1000, 0b001), + kS1E3W = encode(0b110, 0b0111, 0b1000, 0b001), + kS1E0R = encode(0b000, 0b0111, 0b1000, 0b010), + kS1E0W = encode(0b000, 0b0111, 0b1000, 0b011), + kS12E1R = encode(0b100, 0b0111, 0b1000, 0b100), + kS12E1W = encode(0b100, 0b0111, 0b1000, 0b101), + kS12E0R = encode(0b100, 0b0111, 0b1000, 0b110), + kS12E0W = encode(0b100, 0b0111, 0b1000, 0b111), + kS1E1RP = encode(0b000, 0b0111, 0b1001, 0b000), + kS1E1WP = encode(0b000, 0b0111, 0b1001, 0b001) + }; +} + +//! Data barrier options (DMB/DSB). +namespace DB { + //! Data barrier immediate values. + enum Value : uint32_t { + //! Waits only for loads to complete, and only applies to the outer shareable domain. + kOSHLD = 0x01u, + //! Waits only for stores to complete, and only applies to the outer shareable domain. + kOSHST = 0x02u, + //! Only applies to the outer shareable domain. + kOSH = 0x03u, + + //! Waits only for loads to complete and only applies out to the point of unification. + kNSHLD = 0x05u, + //! Waits only for stores to complete and only applies out to the point of unification. + kNSHST = 0x06u, + //! Only applies out to the point of unification. + kNSH = 0x07u, + + //! Waits only for loads to complete, and only applies to the inner shareable domain. + kISHLD = 0x09u, + //! Waits only for stores to complete, and only applies to the inner shareable domain. + kISHST = 0x0Au, + //! Only applies to the inner shareable domain. + kISH = 0x0Bu, + + //! Waits only for loads to complete. + kLD = 0x0Du, + //! Waits only for stores to complete. + kST = 0x0Eu, + //! Full system memory barrier operation. + kSY = 0x0Fu + }; +} + +//! Data cache maintenance options. +namespace DC { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + //! Data cache maintenance immediate values. + enum Value : uint32_t { + kZVA = encode(0b011, 0b0111, 0b0100, 0b001), + kIVAC = encode(0b000, 0b0111, 0b0110, 0b001), + kISW = encode(0b000, 0b0111, 0b0110, 0b010), + kCVAC = encode(0b011, 0b0111, 0b1010, 0b001), + kCSW = encode(0b000, 0b0111, 0b1010, 0b010), + kCVAU = encode(0b011, 0b0111, 0b1011, 0b001), + kCIVAC = encode(0b011, 0b0111, 0b1110, 0b001), + kCISW = encode(0b000, 0b0111, 0b1110, 0b010), + kCVAP = encode(0b011, 0b0111, 0b1100, 0b001), + kCVADP = encode(0b011, 0b0111, 0b1101, 0b001), + kIGVAC = encode(0b000, 0b0111, 0b0110, 0b011), + kIGSW = encode(0b000, 0b0111, 0b0110, 0b100), + kCGSW = encode(0b000, 0b0111, 0b1010, 0b100), + kCIGSW = encode(0b000, 0b0111, 0b1110, 0b100), + kCGVAC = encode(0b011, 0b0111, 0b1010, 0b011), + kCGVAP = encode(0b011, 0b0111, 0b1100, 0b011), + kCGVADP = encode(0b011, 0b0111, 0b1101, 0b011), + kCIGVAC = encode(0b011, 0b0111, 0b1110, 0b011), + kGVA = encode(0b011, 0b0111, 0b0100, 0b011), + kIGDVAC = encode(0b000, 0b0111, 0b0110, 0b101), + kIGDSW = encode(0b000, 0b0111, 0b0110, 0b110), + kCGDSW = encode(0b000, 0b0111, 0b1010, 0b110), + kCIGDSW = encode(0b000, 0b0111, 0b1110, 0b110), + kCGDVAC = encode(0b011, 0b0111, 0b1010, 0b101), + kCGDVAP = encode(0b011, 0b0111, 0b1100, 0b101), + kCGDVADP = encode(0b011, 0b0111, 0b1101, 0b101), + kCIGDVAC = encode(0b011, 0b0111, 0b1110, 0b101), + kGZVA = encode(0b011, 0b0111, 0b0100, 0b100) + }; +} + +//! Instruction cache maintenance options. +namespace IC { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + //! Instruction cache maintenance immediate values. + enum Value : uint32_t { + kIALLUIS = encode(0b000, 0b0111, 0b0001, 0b000), + kIALLU = encode(0b000, 0b0111, 0b0101, 0b000), + kIVAU = encode(0b011, 0b0111, 0b0101, 0b001) + }; +} + +//! Instruction-fetch barrier options. +namespace ISB { + //! Instruction-fetch barrier immediate values. + enum Value : uint32_t { + kSY = 0xF + }; +} + +//! Prefetch options. +namespace PRFOp { + //! Prefetch immediate values. + enum Value : uint32_t { + kPLDL1KEEP = 0x00, + kPLDL1STRM = 0x01, + kPLDL2KEEP = 0x02, + kPLDL2STRM = 0x03, + kPLDL3KEEP = 0x04, + kPLDL3STRM = 0x05, + kPLIL1KEEP = 0x08, + kPLIL1STRM = 0x09, + kPLIL2KEEP = 0x0A, + kPLIL2STRM = 0x0B, + kPLIL3KEEP = 0x0C, + kPLIL3STRM = 0x0D, + kPSTL1KEEP = 0x10, + kPSTL1STRM = 0x11, + kPSTL2KEEP = 0x12, + kPSTL2STRM = 0x13, + kPSTL3KEEP = 0x14, + kPSTL3STRM = 0x15 + }; +} + +//! PSB instruction options. +namespace PSB { + //! PSB immediate values. + enum Value : uint32_t { + kCSYNC = 0x11u + }; +} + +namespace TLBI { + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + enum Value : uint32_t { + kIPAS2E1IS = encode(0b100, 0b1000, 0b0000, 0b001), + kIPAS2LE1IS = encode(0b100, 0b1000, 0b0000, 0b101), + kVMALLE1IS = encode(0b000, 0b1000, 0b0011, 0b000), + kALLE2IS = encode(0b100, 0b1000, 0b0011, 0b000), + kALLE3IS = encode(0b110, 0b1000, 0b0011, 0b000), + kVAE1IS = encode(0b000, 0b1000, 0b0011, 0b001), + kVAE2IS = encode(0b100, 0b1000, 0b0011, 0b001), + kVAE3IS = encode(0b110, 0b1000, 0b0011, 0b001), + kASIDE1IS = encode(0b000, 0b1000, 0b0011, 0b010), + kVAAE1IS = encode(0b000, 0b1000, 0b0011, 0b011), + kALLE1IS = encode(0b100, 0b1000, 0b0011, 0b100), + kVALE1IS = encode(0b000, 0b1000, 0b0011, 0b101), + kVALE2IS = encode(0b100, 0b1000, 0b0011, 0b101), + kVALE3IS = encode(0b110, 0b1000, 0b0011, 0b101), + kVMALLS12E1IS = encode(0b100, 0b1000, 0b0011, 0b110), + kVAALE1IS = encode(0b000, 0b1000, 0b0011, 0b111), + kIPAS2E1 = encode(0b100, 0b1000, 0b0100, 0b001), + kIPAS2LE1 = encode(0b100, 0b1000, 0b0100, 0b101), + kVMALLE1 = encode(0b000, 0b1000, 0b0111, 0b000), + kALLE2 = encode(0b100, 0b1000, 0b0111, 0b000), + kALLE3 = encode(0b110, 0b1000, 0b0111, 0b000), + kVAE1 = encode(0b000, 0b1000, 0b0111, 0b001), + kVAE2 = encode(0b100, 0b1000, 0b0111, 0b001), + kVAE3 = encode(0b110, 0b1000, 0b0111, 0b001), + kASIDE1 = encode(0b000, 0b1000, 0b0111, 0b010), + kVAAE1 = encode(0b000, 0b1000, 0b0111, 0b011), + kALLE1 = encode(0b100, 0b1000, 0b0111, 0b100), + kVALE1 = encode(0b000, 0b1000, 0b0111, 0b101), + kVALE2 = encode(0b100, 0b1000, 0b0111, 0b101), + kVALE3 = encode(0b110, 0b1000, 0b0111, 0b101), + kVMALLS12E1 = encode(0b100, 0b1000, 0b0111, 0b110), + kVAALE1 = encode(0b000, 0b1000, 0b0111, 0b111), + + kVMALLE1OS = encode(0b000, 0b1000, 0b0001, 0b000), + kVAE1OS = encode(0b000, 0b1000, 0b0001, 0b001), + kASIDE1OS = encode(0b000, 0b1000, 0b0001, 0b010), + kVAAE1OS = encode(0b000, 0b1000, 0b0001, 0b011), + kVALE1OS = encode(0b000, 0b1000, 0b0001, 0b101), + kVAALE1OS = encode(0b000, 0b1000, 0b0001, 0b111), + kIPAS2E1OS = encode(0b100, 0b1000, 0b0100, 0b000), + kIPAS2LE1OS = encode(0b100, 0b1000, 0b0100, 0b100), + kVAE2OS = encode(0b100, 0b1000, 0b0001, 0b001), + kVALE2OS = encode(0b100, 0b1000, 0b0001, 0b101), + kVMALLS12E1OS = encode(0b100, 0b1000, 0b0001, 0b110), + kVAE3OS = encode(0b110, 0b1000, 0b0001, 0b001), + kVALE3OS = encode(0b110, 0b1000, 0b0001, 0b101), + kALLE2OS = encode(0b100, 0b1000, 0b0001, 0b000), + kALLE1OS = encode(0b100, 0b1000, 0b0001, 0b100), + kALLE3OS = encode(0b110, 0b1000, 0b0001, 0b000), + + kRVAE1 = encode(0b000, 0b1000, 0b0110, 0b001), + kRVAAE1 = encode(0b000, 0b1000, 0b0110, 0b011), + kRVALE1 = encode(0b000, 0b1000, 0b0110, 0b101), + kRVAALE1 = encode(0b000, 0b1000, 0b0110, 0b111), + kRVAE1IS = encode(0b000, 0b1000, 0b0010, 0b001), + kRVAAE1IS = encode(0b000, 0b1000, 0b0010, 0b011), + kRVALE1IS = encode(0b000, 0b1000, 0b0010, 0b101), + kRVAALE1IS = encode(0b000, 0b1000, 0b0010, 0b111), + kRVAE1OS = encode(0b000, 0b1000, 0b0101, 0b001), + kRVAAE1OS = encode(0b000, 0b1000, 0b0101, 0b011), + kRVALE1OS = encode(0b000, 0b1000, 0b0101, 0b101), + kRVAALE1OS = encode(0b000, 0b1000, 0b0101, 0b111), + kRIPAS2E1IS = encode(0b100, 0b1000, 0b0000, 0b010), + kRIPAS2LE1IS = encode(0b100, 0b1000, 0b0000, 0b110), + kRIPAS2E1 = encode(0b100, 0b1000, 0b0100, 0b010), + kRIPAS2LE1 = encode(0b100, 0b1000, 0b0100, 0b110), + kRIPAS2E1OS = encode(0b100, 0b1000, 0b0100, 0b011), + kRIPAS2LE1OS = encode(0b100, 0b1000, 0b0100, 0b111), + kRVAE2 = encode(0b100, 0b1000, 0b0110, 0b001), + kRVALE2 = encode(0b100, 0b1000, 0b0110, 0b101), + kRVAE2IS = encode(0b100, 0b1000, 0b0010, 0b001), + kRVALE2IS = encode(0b100, 0b1000, 0b0010, 0b101), + kRVAE2OS = encode(0b100, 0b1000, 0b0101, 0b001), + kRVALE2OS = encode(0b100, 0b1000, 0b0101, 0b101), + kRVAE3 = encode(0b110, 0b1000, 0b0110, 0b001), + kRVALE3 = encode(0b110, 0b1000, 0b0110, 0b101), + kRVAE3IS = encode(0b110, 0b1000, 0b0010, 0b001), + kRVALE3IS = encode(0b110, 0b1000, 0b0010, 0b101), + kRVAE3OS = encode(0b110, 0b1000, 0b0101, 0b001), + kRVALE3OS = encode(0b110, 0b1000, 0b0101, 0b101), + }; +} + +//! Trace synchronization barrier options. +namespace TSB { + //! Trace synchronization immediate values. + enum Value : uint32_t { + kCSYNC = 0 + }; +} + +//! Processor state access through MSR. +namespace PState { + //! Encodes a pstate from `op0` and `op1`. + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op0, uint32_t op1) noexcept { + return (op0 << 3) | (op1 << 0); + } + + //! Processor state access immediates. + enum Value : uint32_t { + kSPSel = encode(0b000, 0b101), + kDAIFSet = encode(0b011, 0b110), + kDAIFClr = encode(0b011, 0b111), + kPAN = encode(0b000, 0b100), + kUAO = encode(0b000, 0b011), + kDIT = encode(0b011, 0b010), + kSSBS = encode(0b011, 0b001), + kTCO = encode(0b011, 0b100) + }; +}; + +//! System register identifiers and utilities (MSR/MRS). +namespace SysReg { + //! System register fields. + struct Fields { + uint8_t op0; + uint8_t op1; + uint8_t cRn; + uint8_t cRm; + uint8_t op2; + }; + + //! Encodes a system register from `op0`, `op1`, `cRn`, `cRm`, and `op2` fields. + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(uint32_t op0, uint32_t op1, uint32_t cRn, uint32_t cRm, uint32_t op2) noexcept { + return (op0 << 14) | (op1 << 11) | (cRn << 7) | (cRm << 3) | (op2 << 0); + } + + //! Encodes a system register from `fields`. + static ASMJIT_INLINE_NODEBUG constexpr uint32_t encode(const Fields& fields) noexcept { + return encode(fields.op0, fields.op1, fields.cRn, fields.cRm, fields.op2); + } + + //! Decodes a system register to \ref Fields. + static ASMJIT_INLINE_NODEBUG constexpr Fields decode(uint32_t id) noexcept { + return Fields { + uint8_t((id >> 14) & 0x3u), + uint8_t((id >> 11) & 0x7u), + uint8_t((id >> 7) & 0xFu), + uint8_t((id >> 3) & 0xFu), + uint8_t((id >> 0) & 0x7u) + }; + } + + //! System register identifiers. + enum Id : uint32_t { + kACTLR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b001), // RW + kACTLR_EL2 = encode(0b11, 0b100, 0b0001, 0b0000, 0b001), // RW + kACTLR_EL3 = encode(0b11, 0b110, 0b0001, 0b0000, 0b001), // RW + kAFSR0_EL1 = encode(0b11, 0b000, 0b0101, 0b0001, 0b000), // RW + kAFSR0_EL12 = encode(0b11, 0b101, 0b0101, 0b0001, 0b000), // RW + kAFSR0_EL2 = encode(0b11, 0b100, 0b0101, 0b0001, 0b000), // RW + kAFSR0_EL3 = encode(0b11, 0b110, 0b0101, 0b0001, 0b000), // RW + kAFSR1_EL1 = encode(0b11, 0b000, 0b0101, 0b0001, 0b001), // RW + kAFSR1_EL12 = encode(0b11, 0b101, 0b0101, 0b0001, 0b001), // RW + kAFSR1_EL2 = encode(0b11, 0b100, 0b0101, 0b0001, 0b001), // RW + kAFSR1_EL3 = encode(0b11, 0b110, 0b0101, 0b0001, 0b001), // RW + kAIDR_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b111), // RO + kAMAIR_EL1 = encode(0b11, 0b000, 0b1010, 0b0011, 0b000), // RW + kAMAIR_EL12 = encode(0b11, 0b101, 0b1010, 0b0011, 0b000), // RW + kAMAIR_EL2 = encode(0b11, 0b100, 0b1010, 0b0011, 0b000), // RW + kAMAIR_EL3 = encode(0b11, 0b110, 0b1010, 0b0011, 0b000), // RW + kAMCFGR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b001), // RO + kAMCGCR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b010), // RO + kAMCNTENCLR0_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b100), // RW + kAMCNTENCLR1_EL0 = encode(0b11, 0b011, 0b1101, 0b0011, 0b000), // RW + kAMCNTENSET0_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b101), // RW + kAMCNTENSET1_EL0 = encode(0b11, 0b011, 0b1101, 0b0011, 0b001), // RW + kAMCR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b000), // RW + kAMEVCNTR00_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b000), // RW + kAMEVCNTR01_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b001), // RW + kAMEVCNTR02_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b010), // RW + kAMEVCNTR03_EL0 = encode(0b11, 0b011, 0b1101, 0b0100, 0b011), // RW + kAMEVCNTR10_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b000), // RW + kAMEVCNTR110_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b010), // RW + kAMEVCNTR111_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b011), // RW + kAMEVCNTR112_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b100), // RW + kAMEVCNTR113_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b101), // RW + kAMEVCNTR114_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b110), // RW + kAMEVCNTR115_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b111), // RW + kAMEVCNTR11_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b001), // RW + kAMEVCNTR12_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b010), // RW + kAMEVCNTR13_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b011), // RW + kAMEVCNTR14_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b100), // RW + kAMEVCNTR15_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b101), // RW + kAMEVCNTR16_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b110), // RW + kAMEVCNTR17_EL0 = encode(0b11, 0b011, 0b1101, 0b1100, 0b111), // RW + kAMEVCNTR18_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b000), // RW + kAMEVCNTR19_EL0 = encode(0b11, 0b011, 0b1101, 0b1101, 0b001), // RW + kAMEVTYPER00_EL0 = encode(0b11, 0b011, 0b1101, 0b0110, 0b000), // RO + kAMEVTYPER01_EL0 = encode(0b11, 0b011, 0b1101, 0b0110, 0b001), // RO + kAMEVTYPER02_EL0 = encode(0b11, 0b011, 0b1101, 0b0110, 0b010), // RO + kAMEVTYPER03_EL0 = encode(0b11, 0b011, 0b1101, 0b0110, 0b011), // RO + kAMEVTYPER10_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b000), // RW + kAMEVTYPER110_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b010), // RW + kAMEVTYPER111_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b011), // RW + kAMEVTYPER112_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b100), // RW + kAMEVTYPER113_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b101), // RW + kAMEVTYPER114_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b110), // RW + kAMEVTYPER115_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b111), // RW + kAMEVTYPER11_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b001), // RW + kAMEVTYPER12_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b010), // RW + kAMEVTYPER13_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b011), // RW + kAMEVTYPER14_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b100), // RW + kAMEVTYPER15_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b101), // RW + kAMEVTYPER16_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b110), // RW + kAMEVTYPER17_EL0 = encode(0b11, 0b011, 0b1101, 0b1110, 0b111), // RW + kAMEVTYPER18_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b000), // RW + kAMEVTYPER19_EL0 = encode(0b11, 0b011, 0b1101, 0b1111, 0b001), // RW + kAMUSERENR_EL0 = encode(0b11, 0b011, 0b1101, 0b0010, 0b011), // RW + kAPDAKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b001), // RW + kAPDAKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b000), // RW + kAPDBKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b011), // RW + kAPDBKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0010, 0b010), // RW + kAPGAKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0011, 0b001), // RW + kAPGAKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0011, 0b000), // RW + kAPIAKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b001), // RW + kAPIAKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b000), // RW + kAPIBKeyHi_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b011), // RW + kAPIBKeyLo_EL1 = encode(0b11, 0b000, 0b0010, 0b0001, 0b010), // RW + kCCSIDR2_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b010), // RO + kCCSIDR_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b000), // RO + kCLIDR_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b001), // RO + kCNTFRQ_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b000), // RW + kCNTHCTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0001, 0b000), // RW + kCNTHPS_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0101, 0b001), // RW + kCNTHPS_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0101, 0b010), // RW + kCNTHPS_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0101, 0b000), // RW + kCNTHP_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0010, 0b001), // RW + kCNTHP_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0010, 0b010), // RW + kCNTHP_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0010, 0b000), // RW + kCNTHVS_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0100, 0b001), // RW + kCNTHVS_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0100, 0b010), // RW + kCNTHVS_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0100, 0b000), // RW + kCNTHV_CTL_EL2 = encode(0b11, 0b100, 0b1110, 0b0011, 0b001), // RW + kCNTHV_CVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0011, 0b010), // RW + kCNTHV_TVAL_EL2 = encode(0b11, 0b100, 0b1110, 0b0011, 0b000), // RW + kCNTISCALE_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b101), // RW + kCNTKCTL_EL1 = encode(0b11, 0b000, 0b1110, 0b0001, 0b000), // RW + kCNTKCTL_EL12 = encode(0b11, 0b101, 0b1110, 0b0001, 0b000), // RW + kCNTPCTSS_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b101), // RW + kCNTPCT_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b001), // RO + kCNTPOFF_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b110), // RW + kCNTPS_CTL_EL1 = encode(0b11, 0b111, 0b1110, 0b0010, 0b001), // RW + kCNTPS_CVAL_EL1 = encode(0b11, 0b111, 0b1110, 0b0010, 0b010), // RW + kCNTPS_TVAL_EL1 = encode(0b11, 0b111, 0b1110, 0b0010, 0b000), // RW + kCNTP_CTL_EL0 = encode(0b11, 0b011, 0b1110, 0b0010, 0b001), // RW + kCNTP_CTL_EL02 = encode(0b11, 0b101, 0b1110, 0b0010, 0b001), // RW + kCNTP_CVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0010, 0b010), // RW + kCNTP_CVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0010, 0b010), // RW + kCNTP_TVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0010, 0b000), // RW + kCNTP_TVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0010, 0b000), // RW + kCNTSCALE_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b100), // RW + kCNTVCTSS_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b110), // RW + kCNTVCT_EL0 = encode(0b11, 0b011, 0b1110, 0b0000, 0b010), // RO + kCNTVFRQ_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b111), // RW + kCNTVOFF_EL2 = encode(0b11, 0b100, 0b1110, 0b0000, 0b011), // RW + kCNTV_CTL_EL0 = encode(0b11, 0b011, 0b1110, 0b0011, 0b001), // RW + kCNTV_CTL_EL02 = encode(0b11, 0b101, 0b1110, 0b0011, 0b001), // RW + kCNTV_CVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0011, 0b010), // RW + kCNTV_CVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0011, 0b010), // RW + kCNTV_TVAL_EL0 = encode(0b11, 0b011, 0b1110, 0b0011, 0b000), // RW + kCNTV_TVAL_EL02 = encode(0b11, 0b101, 0b1110, 0b0011, 0b000), // RW + kCONTEXTIDR_EL1 = encode(0b11, 0b000, 0b1101, 0b0000, 0b001), // RW + kCONTEXTIDR_EL12 = encode(0b11, 0b101, 0b1101, 0b0000, 0b001), // RW + kCONTEXTIDR_EL2 = encode(0b11, 0b100, 0b1101, 0b0000, 0b001), // RW + kCPACR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b010), // RW + kCPACR_EL12 = encode(0b11, 0b101, 0b0001, 0b0000, 0b010), // RW + kCPM_IOACC_CTL_EL3 = encode(0b11, 0b111, 0b1111, 0b0010, 0b000), // RW + kCPTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b010), // RW + kCPTR_EL3 = encode(0b11, 0b110, 0b0001, 0b0001, 0b010), // RW + kCSSELR_EL1 = encode(0b11, 0b010, 0b0000, 0b0000, 0b000), // RW + kCTR_EL0 = encode(0b11, 0b011, 0b0000, 0b0000, 0b001), // RO + kCurrentEL = encode(0b11, 0b000, 0b0100, 0b0010, 0b010), // RO + kDACR32_EL2 = encode(0b11, 0b100, 0b0011, 0b0000, 0b000), // RW + kDAIF = encode(0b11, 0b011, 0b0100, 0b0010, 0b001), // RW + kDBGAUTHSTATUS_EL1 = encode(0b10, 0b000, 0b0111, 0b1110, 0b110), // RO + kDBGBCR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b101), // RW + kDBGBCR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b101), // RW + kDBGBCR11_EL1 = encode(0b10, 0b000, 0b0000, 0b1011, 0b101), // RW + kDBGBCR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b101), // RW + kDBGBCR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b101), // RW + kDBGBCR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b101), // RW + kDBGBCR15_EL1 = encode(0b10, 0b000, 0b0000, 0b1111, 0b101), // RW + kDBGBCR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b101), // RW + kDBGBCR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b101), // RW + kDBGBCR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b101), // RW + kDBGBCR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b101), // RW + kDBGBCR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b101), // RW + kDBGBCR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b101), // RW + kDBGBCR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b101), // RW + kDBGBCR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b101), // RW + kDBGBCR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b101), // RW + kDBGBVR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b100), // RW + kDBGBVR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b100), // RW + kDBGBVR11_EL1 = encode(0b10, 0b000, 0b0000, 0b1011, 0b100), // RW + kDBGBVR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b100), // RW + kDBGBVR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b100), // RW + kDBGBVR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b100), // RW + kDBGBVR15_EL1 = encode(0b10, 0b000, 0b0000, 0b1111, 0b100), // RW + kDBGBVR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b100), // RW + kDBGBVR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b100), // RW + kDBGBVR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b100), // RW + kDBGBVR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b100), // RW + kDBGBVR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b100), // RW + kDBGBVR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b100), // RW + kDBGBVR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b100), // RW + kDBGBVR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b100), // RW + kDBGBVR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b100), // RW + kDBGCLAIMCLR_EL1 = encode(0b10, 0b000, 0b0111, 0b1001, 0b110), // RW + kDBGCLAIMSET_EL1 = encode(0b10, 0b000, 0b0111, 0b1000, 0b110), // RW + kDBGDTRRX_EL0 = encode(0b10, 0b011, 0b0000, 0b0101, 0b000), // RO + kDBGDTRTX_EL0 = encode(0b10, 0b011, 0b0000, 0b0101, 0b000), // WO + kDBGDTR_EL0 = encode(0b10, 0b011, 0b0000, 0b0100, 0b000), // RW + kDBGPRCR_EL1 = encode(0b10, 0b000, 0b0001, 0b0100, 0b100), // RW + kDBGVCR32_EL2 = encode(0b10, 0b100, 0b0000, 0b0111, 0b000), // RW + kDBGWCR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b111), // RW + kDBGWCR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b111), // RW + kDBGWCR11_EL1 = encode(0b10, 0b000, 0b0000, 0b1011, 0b111), // RW + kDBGWCR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b111), // RW + kDBGWCR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b111), // RW + kDBGWCR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b111), // RW + kDBGWCR15_EL1 = encode(0b10, 0b000, 0b0000, 0b1111, 0b111), // RW + kDBGWCR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b111), // RW + kDBGWCR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b111), // RW + kDBGWCR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b111), // RW + kDBGWCR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b111), // RW + kDBGWCR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b111), // RW + kDBGWCR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b111), // RW + kDBGWCR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b111), // RW + kDBGWCR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b111), // RW + kDBGWCR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b111), // RW + kDBGWVR0_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b110), // RW + kDBGWVR10_EL1 = encode(0b10, 0b000, 0b0000, 0b1010, 0b110), // RW + kDBGWVR11_EL1 = encode(0b10, 0b000, 0b0000, 0b1011, 0b110), // RW + kDBGWVR12_EL1 = encode(0b10, 0b000, 0b0000, 0b1100, 0b110), // RW + kDBGWVR13_EL1 = encode(0b10, 0b000, 0b0000, 0b1101, 0b110), // RW + kDBGWVR14_EL1 = encode(0b10, 0b000, 0b0000, 0b1110, 0b110), // RW + kDBGWVR15_EL1 = encode(0b10, 0b000, 0b0000, 0b1111, 0b110), // RW + kDBGWVR1_EL1 = encode(0b10, 0b000, 0b0000, 0b0001, 0b110), // RW + kDBGWVR2_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b110), // RW + kDBGWVR3_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b110), // RW + kDBGWVR4_EL1 = encode(0b10, 0b000, 0b0000, 0b0100, 0b110), // RW + kDBGWVR5_EL1 = encode(0b10, 0b000, 0b0000, 0b0101, 0b110), // RW + kDBGWVR6_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b110), // RW + kDBGWVR7_EL1 = encode(0b10, 0b000, 0b0000, 0b0111, 0b110), // RW + kDBGWVR8_EL1 = encode(0b10, 0b000, 0b0000, 0b1000, 0b110), // RW + kDBGWVR9_EL1 = encode(0b10, 0b000, 0b0000, 0b1001, 0b110), // RW + kDCZID_EL0 = encode(0b11, 0b011, 0b0000, 0b0000, 0b111), // RO + kDISR_EL1 = encode(0b11, 0b000, 0b1100, 0b0001, 0b001), // RW + kDIT = encode(0b11, 0b011, 0b0100, 0b0010, 0b101), // RW + kDLR_EL0 = encode(0b11, 0b011, 0b0100, 0b0101, 0b001), // RW + kDSPSR_EL0 = encode(0b11, 0b011, 0b0100, 0b0101, 0b000), // RW + kELR_EL1 = encode(0b11, 0b000, 0b0100, 0b0000, 0b001), // RW + kELR_EL12 = encode(0b11, 0b101, 0b0100, 0b0000, 0b001), // RW + kELR_EL2 = encode(0b11, 0b100, 0b0100, 0b0000, 0b001), // RW + kELR_EL3 = encode(0b11, 0b110, 0b0100, 0b0000, 0b001), // RW + kERRIDR_EL1 = encode(0b11, 0b000, 0b0101, 0b0011, 0b000), // RO + kERRSELR_EL1 = encode(0b11, 0b000, 0b0101, 0b0011, 0b001), // RW + kERXADDR_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b011), // RW + kERXCTLR_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b001), // RW + kERXFR_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b000), // RO + kERXMISC0_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b000), // RW + kERXMISC1_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b001), // RW + kERXMISC2_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b010), // RW + kERXMISC3_EL1 = encode(0b11, 0b000, 0b0101, 0b0101, 0b011), // RW + kERXPFGCDN_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b110), // RW + kERXPFGCTL_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b101), // RW + kERXPFGF_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b100), // RO + kERXSTATUS_EL1 = encode(0b11, 0b000, 0b0101, 0b0100, 0b010), // RW + kESR_EL1 = encode(0b11, 0b000, 0b0101, 0b0010, 0b000), // RW + kESR_EL12 = encode(0b11, 0b101, 0b0101, 0b0010, 0b000), // RW + kESR_EL2 = encode(0b11, 0b100, 0b0101, 0b0010, 0b000), // RW + kESR_EL3 = encode(0b11, 0b110, 0b0101, 0b0010, 0b000), // RW + kFAR_EL1 = encode(0b11, 0b000, 0b0110, 0b0000, 0b000), // RW + kFAR_EL12 = encode(0b11, 0b101, 0b0110, 0b0000, 0b000), // RW + kFAR_EL2 = encode(0b11, 0b100, 0b0110, 0b0000, 0b000), // RW + kFAR_EL3 = encode(0b11, 0b110, 0b0110, 0b0000, 0b000), // RW + kFPCR = encode(0b11, 0b011, 0b0100, 0b0100, 0b000), // RW + kFPEXC32_EL2 = encode(0b11, 0b100, 0b0101, 0b0011, 0b000), // RW + kFPSR = encode(0b11, 0b011, 0b0100, 0b0100, 0b001), // RW + kGCR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b110), // RW + kGMID_EL1 = encode(0b11, 0b001, 0b0000, 0b0000, 0b100), // RO + kHACR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b111), // RW + kHCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b000), // RW + kHDFGRTR_EL2 = encode(0b11, 0b100, 0b0011, 0b0001, 0b100), // RW + kHDFGWTR_EL2 = encode(0b11, 0b100, 0b0011, 0b0001, 0b101), // RW + kHFGITR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b110), // RW + kHFGRTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b100), // RW + kHFGWTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b101), // RW + kHPFAR_EL2 = encode(0b11, 0b100, 0b0110, 0b0000, 0b100), // RW + kHSTR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b011), // RW + kICC_AP0R0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b100), // RW + kICC_AP0R1_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b101), // RW + kICC_AP0R2_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b110), // RW + kICC_AP0R3_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b111), // RW + kICC_AP1R0_EL1 = encode(0b11, 0b000, 0b1100, 0b1001, 0b000), // RW + kICC_AP1R1_EL1 = encode(0b11, 0b000, 0b1100, 0b1001, 0b001), // RW + kICC_AP1R2_EL1 = encode(0b11, 0b000, 0b1100, 0b1001, 0b010), // RW + kICC_AP1R3_EL1 = encode(0b11, 0b000, 0b1100, 0b1001, 0b011), // RW + kICC_ASGI1R_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b110), // WO + kICC_BPR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b011), // RW + kICC_BPR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b011), // RW + kICC_CTLR_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b100), // RW + kICC_CTLR_EL3 = encode(0b11, 0b110, 0b1100, 0b1100, 0b100), // RW + kICC_DIR_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b001), // WO + kICC_EOIR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b001), // WO + kICC_EOIR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b001), // WO + kICC_HPPIR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b010), // RO + kICC_HPPIR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b010), // RO + kICC_IAR0_EL1 = encode(0b11, 0b000, 0b1100, 0b1000, 0b000), // RO + kICC_IAR1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b000), // RO + kICC_IGRPEN0_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b110), // RW + kICC_IGRPEN1_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b111), // RW + kICC_IGRPEN1_EL3 = encode(0b11, 0b110, 0b1100, 0b1100, 0b111), // RW + kICC_PMR_EL1 = encode(0b11, 0b000, 0b0100, 0b0110, 0b000), // RW + kICC_RPR_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b011), // RO + kICC_SGI0R_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b111), // WO + kICC_SGI1R_EL1 = encode(0b11, 0b000, 0b1100, 0b1011, 0b101), // WO + kICC_SRE_EL1 = encode(0b11, 0b000, 0b1100, 0b1100, 0b101), // RW + kICC_SRE_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b101), // RW + kICC_SRE_EL3 = encode(0b11, 0b110, 0b1100, 0b1100, 0b101), // RW + kICH_AP0R0_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b000), // RW + kICH_AP0R1_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b001), // RW + kICH_AP0R2_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b010), // RW + kICH_AP0R3_EL2 = encode(0b11, 0b100, 0b1100, 0b1000, 0b011), // RW + kICH_AP1R0_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b000), // RW + kICH_AP1R1_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b001), // RW + kICH_AP1R2_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b010), // RW + kICH_AP1R3_EL2 = encode(0b11, 0b100, 0b1100, 0b1001, 0b011), // RW + kICH_EISR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b011), // RO + kICH_ELRSR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b101), // RO + kICH_HCR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b000), // RW + kICH_LR0_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b000), // RW + kICH_LR10_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b010), // RW + kICH_LR11_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b011), // RW + kICH_LR12_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b100), // RW + kICH_LR13_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b101), // RW + kICH_LR14_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b110), // RW + kICH_LR15_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b111), // RW + kICH_LR1_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b001), // RW + kICH_LR2_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b010), // RW + kICH_LR3_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b011), // RW + kICH_LR4_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b100), // RW + kICH_LR5_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b101), // RW + kICH_LR6_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b110), // RW + kICH_LR7_EL2 = encode(0b11, 0b100, 0b1100, 0b1100, 0b111), // RW + kICH_LR8_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b000), // RW + kICH_LR9_EL2 = encode(0b11, 0b100, 0b1100, 0b1101, 0b001), // RW + kICH_MISR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b010), // RO + kICH_VMCR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b111), // RW + kICH_VTR_EL2 = encode(0b11, 0b100, 0b1100, 0b1011, 0b001), // RO + kID_AA64AFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b100), // RO + kID_AA64AFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b101), // RO + kID_AA64DFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b000), // RO + kID_AA64DFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0101, 0b001), // RO + kID_AA64ISAR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0110, 0b000), // RO + kID_AA64ISAR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0110, 0b001), // RO + kID_AA64ISAR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0110, 0b010), // RO + kID_AA64MMFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b000), // RO + kID_AA64MMFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b001), // RO + kID_AA64MMFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b010), // RO + kID_AA64MMFR3_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b011), // RO + kID_AA64MMFR4_EL1 = encode(0b11, 0b000, 0b0000, 0b0111, 0b100), // RO + kID_AA64PFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0100, 0b000), // RO + kID_AA64PFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0100, 0b001), // RO + kID_AA64ZFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0100, 0b100), // RO + kID_AFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b011), // RO + kID_DFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b010), // RO + kID_ISAR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b000), // RO + kID_ISAR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b001), // RO + kID_ISAR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b010), // RO + kID_ISAR3_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b011), // RO + kID_ISAR4_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b100), // RO + kID_ISAR5_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b101), // RO + kID_ISAR6_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b111), // RO + kID_MMFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b100), // RO + kID_MMFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b101), // RO + kID_MMFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b110), // RO + kID_MMFR3_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b111), // RO + kID_MMFR4_EL1 = encode(0b11, 0b000, 0b0000, 0b0010, 0b110), // RO + kID_MMFR5_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b110), // RO + kID_PFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b000), // RO + kID_PFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0001, 0b001), // RO + kID_PFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b100), // RO + kIFSR32_EL2 = encode(0b11, 0b100, 0b0101, 0b0000, 0b001), // RW + kISR_EL1 = encode(0b11, 0b000, 0b1100, 0b0001, 0b000), // RO + kLORC_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b011), // RW + kLOREA_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b001), // RW + kLORID_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b111), // RO + kLORN_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b010), // RW + kLORSA_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b000), // RW + kMAIR_EL1 = encode(0b11, 0b000, 0b1010, 0b0010, 0b000), // RW + kMAIR_EL12 = encode(0b11, 0b101, 0b1010, 0b0010, 0b000), // RW + kMAIR_EL2 = encode(0b11, 0b100, 0b1010, 0b0010, 0b000), // RW + kMAIR_EL3 = encode(0b11, 0b110, 0b1010, 0b0010, 0b000), // RW + kMDCCINT_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b000), // RW + kMDCCSR_EL0 = encode(0b10, 0b011, 0b0000, 0b0001, 0b000), // RO + kMDCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0001, 0b001), // RW + kMDCR_EL3 = encode(0b11, 0b110, 0b0001, 0b0011, 0b001), // RW + kMDRAR_EL1 = encode(0b10, 0b000, 0b0001, 0b0000, 0b000), // RO + kMDSCR_EL1 = encode(0b10, 0b000, 0b0000, 0b0010, 0b010), // RW + kMIDR_EL1 = encode(0b11, 0b000, 0b0000, 0b0000, 0b000), // RO + kMPAM0_EL1 = encode(0b11, 0b000, 0b1010, 0b0101, 0b001), // RW + kMPAM1_EL1 = encode(0b11, 0b000, 0b1010, 0b0101, 0b000), // RW + kMPAM1_EL12 = encode(0b11, 0b101, 0b1010, 0b0101, 0b000), // RW + kMPAM2_EL2 = encode(0b11, 0b100, 0b1010, 0b0101, 0b000), // RW + kMPAM3_EL3 = encode(0b11, 0b110, 0b1010, 0b0101, 0b000), // RW + kMPAMHCR_EL2 = encode(0b11, 0b100, 0b1010, 0b0100, 0b000), // RW + kMPAMIDR_EL1 = encode(0b11, 0b000, 0b1010, 0b0100, 0b100), // RO + kMPAMVPM0_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b000), // RW + kMPAMVPM1_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b001), // RW + kMPAMVPM2_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b010), // RW + kMPAMVPM3_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b011), // RW + kMPAMVPM4_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b100), // RW + kMPAMVPM5_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b101), // RW + kMPAMVPM6_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b110), // RW + kMPAMVPM7_EL2 = encode(0b11, 0b100, 0b1010, 0b0110, 0b111), // RW + kMPAMVPMV_EL2 = encode(0b11, 0b100, 0b1010, 0b0100, 0b001), // RW + kMPIDR_EL1 = encode(0b11, 0b000, 0b0000, 0b0000, 0b101), // RO + kMVFR0_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b000), // RO + kMVFR1_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b001), // RO + kMVFR2_EL1 = encode(0b11, 0b000, 0b0000, 0b0011, 0b010), // RO + kNZCV = encode(0b11, 0b011, 0b0100, 0b0010, 0b000), // RW + kOSDLR_EL1 = encode(0b10, 0b000, 0b0001, 0b0011, 0b100), // RW + kOSDTRRX_EL1 = encode(0b10, 0b000, 0b0000, 0b0000, 0b010), // RW + kOSDTRTX_EL1 = encode(0b10, 0b000, 0b0000, 0b0011, 0b010), // RW + kOSECCR_EL1 = encode(0b10, 0b000, 0b0000, 0b0110, 0b010), // RW + kOSLAR_EL1 = encode(0b10, 0b000, 0b0001, 0b0000, 0b100), // WO + kOSLSR_EL1 = encode(0b10, 0b000, 0b0001, 0b0001, 0b100), // RO + kPAN = encode(0b11, 0b000, 0b0100, 0b0010, 0b011), // RW + kPAR_EL1 = encode(0b11, 0b000, 0b0111, 0b0100, 0b000), // RW + kPMBIDR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b111), // RO + kPMBLIMITR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b000), // RW + kPMBPTR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b001), // RW + kPMBSR_EL1 = encode(0b11, 0b000, 0b1001, 0b1010, 0b011), // RW + kPMCCFILTR_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b111), // RW + kPMCCNTR_EL0 = encode(0b11, 0b011, 0b1001, 0b1101, 0b000), // RW + kPMCEID0_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b110), // RO + kPMCEID1_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b111), // RO + kPMCNTENCLR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b010), // RW + kPMCNTENSET_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b001), // RW + kPMCR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b000), // RW + kPMEVCNTR0_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b000), // RW + kPMEVCNTR10_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b010), // RW + kPMEVCNTR11_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b011), // RW + kPMEVCNTR12_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b100), // RW + kPMEVCNTR13_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b101), // RW + kPMEVCNTR14_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b110), // RW + kPMEVCNTR15_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b111), // RW + kPMEVCNTR16_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b000), // RW + kPMEVCNTR17_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b001), // RW + kPMEVCNTR18_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b010), // RW + kPMEVCNTR19_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b011), // RW + kPMEVCNTR1_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b001), // RW + kPMEVCNTR20_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b100), // RW + kPMEVCNTR21_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b101), // RW + kPMEVCNTR22_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b110), // RW + kPMEVCNTR23_EL0 = encode(0b11, 0b011, 0b1110, 0b1010, 0b111), // RW + kPMEVCNTR24_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b000), // RW + kPMEVCNTR25_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b001), // RW + kPMEVCNTR26_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b010), // RW + kPMEVCNTR27_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b011), // RW + kPMEVCNTR28_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b100), // RW + kPMEVCNTR29_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b101), // RW + kPMEVCNTR2_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b010), // RW + kPMEVCNTR30_EL0 = encode(0b11, 0b011, 0b1110, 0b1011, 0b110), // RW + kPMEVCNTR3_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b011), // RW + kPMEVCNTR4_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b100), // RW + kPMEVCNTR5_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b101), // RW + kPMEVCNTR6_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b110), // RW + kPMEVCNTR7_EL0 = encode(0b11, 0b011, 0b1110, 0b1000, 0b111), // RW + kPMEVCNTR8_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b000), // RW + kPMEVCNTR9_EL0 = encode(0b11, 0b011, 0b1110, 0b1001, 0b001), // RW + kPMEVTYPER0_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b000), // RW + kPMEVTYPER10_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b010), // RW + kPMEVTYPER11_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b011), // RW + kPMEVTYPER12_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b100), // RW + kPMEVTYPER13_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b101), // RW + kPMEVTYPER14_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b110), // RW + kPMEVTYPER15_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b111), // RW + kPMEVTYPER16_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b000), // RW + kPMEVTYPER17_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b001), // RW + kPMEVTYPER18_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b010), // RW + kPMEVTYPER19_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b011), // RW + kPMEVTYPER1_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b001), // RW + kPMEVTYPER20_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b100), // RW + kPMEVTYPER21_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b101), // RW + kPMEVTYPER22_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b110), // RW + kPMEVTYPER23_EL0 = encode(0b11, 0b011, 0b1110, 0b1110, 0b111), // RW + kPMEVTYPER24_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b000), // RW + kPMEVTYPER25_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b001), // RW + kPMEVTYPER26_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b010), // RW + kPMEVTYPER27_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b011), // RW + kPMEVTYPER28_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b100), // RW + kPMEVTYPER29_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b101), // RW + kPMEVTYPER2_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b010), // RW + kPMEVTYPER30_EL0 = encode(0b11, 0b011, 0b1110, 0b1111, 0b110), // RW + kPMEVTYPER3_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b011), // RW + kPMEVTYPER4_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b100), // RW + kPMEVTYPER5_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b101), // RW + kPMEVTYPER6_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b110), // RW + kPMEVTYPER7_EL0 = encode(0b11, 0b011, 0b1110, 0b1100, 0b111), // RW + kPMEVTYPER8_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b000), // RW + kPMEVTYPER9_EL0 = encode(0b11, 0b011, 0b1110, 0b1101, 0b001), // RW + kPMINTENCLR_EL1 = encode(0b11, 0b000, 0b1001, 0b1110, 0b010), // RW + kPMINTENSET_EL1 = encode(0b11, 0b000, 0b1001, 0b1110, 0b001), // RW + kPMMIR_EL1 = encode(0b11, 0b000, 0b1001, 0b1110, 0b110), // RW + kPMOVSCLR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b011), // RW + kPMOVSSET_EL0 = encode(0b11, 0b011, 0b1001, 0b1110, 0b011), // RW + kPMSCR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b000), // RW + kPMSCR_EL12 = encode(0b11, 0b101, 0b1001, 0b1001, 0b000), // RW + kPMSCR_EL2 = encode(0b11, 0b100, 0b1001, 0b1001, 0b000), // RW + kPMSELR_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b101), // RW + kPMSEVFR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b101), // RW + kPMSFCR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b100), // RW + kPMSICR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b010), // RW + kPMSIDR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b111), // RO + kPMSIRR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b011), // RW + kPMSLATFR_EL1 = encode(0b11, 0b000, 0b1001, 0b1001, 0b110), // RW + kPMSWINC_EL0 = encode(0b11, 0b011, 0b1001, 0b1100, 0b100), // WO + kPMUSERENR_EL0 = encode(0b11, 0b011, 0b1001, 0b1110, 0b000), // RW + kPMXEVCNTR_EL0 = encode(0b11, 0b011, 0b1001, 0b1101, 0b010), // RW + kPMXEVTYPER_EL0 = encode(0b11, 0b011, 0b1001, 0b1101, 0b001), // RW + kREVIDR_EL1 = encode(0b11, 0b000, 0b0000, 0b0000, 0b110), // RO + kRGSR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b101), // RW + kRMR_EL1 = encode(0b11, 0b000, 0b1100, 0b0000, 0b010), // RW + kRMR_EL2 = encode(0b11, 0b100, 0b1100, 0b0000, 0b010), // RW + kRMR_EL3 = encode(0b11, 0b110, 0b1100, 0b0000, 0b010), // RW + kRNDR = encode(0b11, 0b011, 0b0010, 0b0100, 0b000), // RO + kRNDRRS = encode(0b11, 0b011, 0b0010, 0b0100, 0b001), // RO + kRVBAR_EL1 = encode(0b11, 0b000, 0b1100, 0b0000, 0b001), // RO + kRVBAR_EL2 = encode(0b11, 0b100, 0b1100, 0b0000, 0b001), // RO + kRVBAR_EL3 = encode(0b11, 0b110, 0b1100, 0b0000, 0b001), // RO + kSCR_EL3 = encode(0b11, 0b110, 0b0001, 0b0001, 0b000), // RW + kSCTLR_EL1 = encode(0b11, 0b000, 0b0001, 0b0000, 0b000), // RW + kSCTLR_EL12 = encode(0b11, 0b101, 0b0001, 0b0000, 0b000), // RW + kSCTLR_EL2 = encode(0b11, 0b100, 0b0001, 0b0000, 0b000), // RW + kSCTLR_EL3 = encode(0b11, 0b110, 0b0001, 0b0000, 0b000), // RW + kSCXTNUM_EL0 = encode(0b11, 0b011, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL1 = encode(0b11, 0b000, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL12 = encode(0b11, 0b101, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL2 = encode(0b11, 0b100, 0b1101, 0b0000, 0b111), // RW + kSCXTNUM_EL3 = encode(0b11, 0b110, 0b1101, 0b0000, 0b111), // RW + kSDER32_EL2 = encode(0b11, 0b100, 0b0001, 0b0011, 0b001), // RW + kSDER32_EL3 = encode(0b11, 0b110, 0b0001, 0b0001, 0b001), // RW + kSPSR_EL1 = encode(0b11, 0b000, 0b0100, 0b0000, 0b000), // RW + kSPSR_EL12 = encode(0b11, 0b101, 0b0100, 0b0000, 0b000), // RW + kSPSR_EL2 = encode(0b11, 0b100, 0b0100, 0b0000, 0b000), // RW + kSPSR_EL3 = encode(0b11, 0b110, 0b0100, 0b0000, 0b000), // RW + kSPSR_abt = encode(0b11, 0b100, 0b0100, 0b0011, 0b001), // RW + kSPSR_fiq = encode(0b11, 0b100, 0b0100, 0b0011, 0b011), // RW + kSPSR_irq = encode(0b11, 0b100, 0b0100, 0b0011, 0b000), // RW + kSPSR_und = encode(0b11, 0b100, 0b0100, 0b0011, 0b010), // RW + kSPSel = encode(0b11, 0b000, 0b0100, 0b0010, 0b000), // RW + kSP_EL0 = encode(0b11, 0b000, 0b0100, 0b0001, 0b000), // RW + kSP_EL1 = encode(0b11, 0b100, 0b0100, 0b0001, 0b000), // RW + kSP_EL2 = encode(0b11, 0b110, 0b0100, 0b0001, 0b000), // RW + kSSBS = encode(0b11, 0b011, 0b0100, 0b0010, 0b110), // RW + kTCO = encode(0b11, 0b011, 0b0100, 0b0010, 0b111), // RW + kTCR_EL1 = encode(0b11, 0b000, 0b0010, 0b0000, 0b010), // RW + kTCR_EL12 = encode(0b11, 0b101, 0b0010, 0b0000, 0b010), // RW + kTCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0000, 0b010), // RW + kTCR_EL3 = encode(0b11, 0b110, 0b0010, 0b0000, 0b010), // RW + kTEECR32_EL1 = encode(0b10, 0b010, 0b0000, 0b0000, 0b000), // RW + kTEEHBR32_EL1 = encode(0b10, 0b010, 0b0001, 0b0000, 0b000), // RW + kTFSRE0_EL1 = encode(0b11, 0b000, 0b0101, 0b0110, 0b001), // RW + kTFSR_EL1 = encode(0b11, 0b000, 0b0101, 0b0110, 0b000), // RW + kTFSR_EL12 = encode(0b11, 0b101, 0b0101, 0b0110, 0b000), // RW + kTFSR_EL2 = encode(0b11, 0b100, 0b0101, 0b0110, 0b000), // RW + kTFSR_EL3 = encode(0b11, 0b110, 0b0101, 0b0110, 0b000), // RW + kTPIDRRO_EL0 = encode(0b11, 0b011, 0b1101, 0b0000, 0b011), // RW + kTPIDR_EL0 = encode(0b11, 0b011, 0b1101, 0b0000, 0b010), // RW + kTPIDR_EL1 = encode(0b11, 0b000, 0b1101, 0b0000, 0b100), // RW + kTPIDR_EL2 = encode(0b11, 0b100, 0b1101, 0b0000, 0b010), // RW + kTPIDR_EL3 = encode(0b11, 0b110, 0b1101, 0b0000, 0b010), // RW + kTRBBASER_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b010), // RW + kTRBIDR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b111), // RO + kTRBLIMITR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b000), // RW + kTRBMAR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b100), // RW + kTRBPTR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b001), // RW + kTRBSR_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b011), // RW + kTRBTRG_EL1 = encode(0b11, 0b000, 0b1001, 0b1011, 0b110), // RW + kTRCACATR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b010), // RW + kTRCACATR1 = encode(0b10, 0b001, 0b0010, 0b0010, 0b010), // RW + kTRCACATR10 = encode(0b10, 0b001, 0b0010, 0b0100, 0b011), // RW + kTRCACATR11 = encode(0b10, 0b001, 0b0010, 0b0110, 0b011), // RW + kTRCACATR12 = encode(0b10, 0b001, 0b0010, 0b1000, 0b011), // RW + kTRCACATR13 = encode(0b10, 0b001, 0b0010, 0b1010, 0b011), // RW + kTRCACATR14 = encode(0b10, 0b001, 0b0010, 0b1100, 0b011), // RW + kTRCACATR15 = encode(0b10, 0b001, 0b0010, 0b1110, 0b011), // RW + kTRCACATR2 = encode(0b10, 0b001, 0b0010, 0b0100, 0b010), // RW + kTRCACATR3 = encode(0b10, 0b001, 0b0010, 0b0110, 0b010), // RW + kTRCACATR4 = encode(0b10, 0b001, 0b0010, 0b1000, 0b010), // RW + kTRCACATR5 = encode(0b10, 0b001, 0b0010, 0b1010, 0b010), // RW + kTRCACATR6 = encode(0b10, 0b001, 0b0010, 0b1100, 0b010), // RW + kTRCACATR7 = encode(0b10, 0b001, 0b0010, 0b1110, 0b010), // RW + kTRCACATR8 = encode(0b10, 0b001, 0b0010, 0b0000, 0b011), // RW + kTRCACATR9 = encode(0b10, 0b001, 0b0010, 0b0010, 0b011), // RW + kTRCACVR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b000), // RW + kTRCACVR1 = encode(0b10, 0b001, 0b0010, 0b0010, 0b000), // RW + kTRCACVR10 = encode(0b10, 0b001, 0b0010, 0b0100, 0b001), // RW + kTRCACVR11 = encode(0b10, 0b001, 0b0010, 0b0110, 0b001), // RW + kTRCACVR12 = encode(0b10, 0b001, 0b0010, 0b1000, 0b001), // RW + kTRCACVR13 = encode(0b10, 0b001, 0b0010, 0b1010, 0b001), // RW + kTRCACVR14 = encode(0b10, 0b001, 0b0010, 0b1100, 0b001), // RW + kTRCACVR15 = encode(0b10, 0b001, 0b0010, 0b1110, 0b001), // RW + kTRCACVR2 = encode(0b10, 0b001, 0b0010, 0b0100, 0b000), // RW + kTRCACVR3 = encode(0b10, 0b001, 0b0010, 0b0110, 0b000), // RW + kTRCACVR4 = encode(0b10, 0b001, 0b0010, 0b1000, 0b000), // RW + kTRCACVR5 = encode(0b10, 0b001, 0b0010, 0b1010, 0b000), // RW + kTRCACVR6 = encode(0b10, 0b001, 0b0010, 0b1100, 0b000), // RW + kTRCACVR7 = encode(0b10, 0b001, 0b0010, 0b1110, 0b000), // RW + kTRCACVR8 = encode(0b10, 0b001, 0b0010, 0b0000, 0b001), // RW + kTRCACVR9 = encode(0b10, 0b001, 0b0010, 0b0010, 0b001), // RW + kTRCAUTHSTATUS = encode(0b10, 0b001, 0b0111, 0b1110, 0b110), // RO + kTRCAUXCTLR = encode(0b10, 0b001, 0b0000, 0b0110, 0b000), // RW + kTRCBBCTLR = encode(0b10, 0b001, 0b0000, 0b1111, 0b000), // RW + kTRCCCCTLR = encode(0b10, 0b001, 0b0000, 0b1110, 0b000), // RW + kTRCCIDCCTLR0 = encode(0b10, 0b001, 0b0011, 0b0000, 0b010), // RW + kTRCCIDCCTLR1 = encode(0b10, 0b001, 0b0011, 0b0001, 0b010), // RW + kTRCCIDCVR0 = encode(0b10, 0b001, 0b0011, 0b0000, 0b000), // RW + kTRCCIDCVR1 = encode(0b10, 0b001, 0b0011, 0b0010, 0b000), // RW + kTRCCIDCVR2 = encode(0b10, 0b001, 0b0011, 0b0100, 0b000), // RW + kTRCCIDCVR3 = encode(0b10, 0b001, 0b0011, 0b0110, 0b000), // RW + kTRCCIDCVR4 = encode(0b10, 0b001, 0b0011, 0b1000, 0b000), // RW + kTRCCIDCVR5 = encode(0b10, 0b001, 0b0011, 0b1010, 0b000), // RW + kTRCCIDCVR6 = encode(0b10, 0b001, 0b0011, 0b1100, 0b000), // RW + kTRCCIDCVR7 = encode(0b10, 0b001, 0b0011, 0b1110, 0b000), // RW + kTRCCIDR0 = encode(0b10, 0b001, 0b0111, 0b1100, 0b111), // RO + kTRCCIDR1 = encode(0b10, 0b001, 0b0111, 0b1101, 0b111), // RO + kTRCCIDR2 = encode(0b10, 0b001, 0b0111, 0b1110, 0b111), // RO + kTRCCIDR3 = encode(0b10, 0b001, 0b0111, 0b1111, 0b111), // RO + kTRCCLAIMCLR = encode(0b10, 0b001, 0b0111, 0b1001, 0b110), // RW + kTRCCLAIMSET = encode(0b10, 0b001, 0b0111, 0b1000, 0b110), // RW + kTRCCNTCTLR0 = encode(0b10, 0b001, 0b0000, 0b0100, 0b101), // RW + kTRCCNTCTLR1 = encode(0b10, 0b001, 0b0000, 0b0101, 0b101), // RW + kTRCCNTCTLR2 = encode(0b10, 0b001, 0b0000, 0b0110, 0b101), // RW + kTRCCNTCTLR3 = encode(0b10, 0b001, 0b0000, 0b0111, 0b101), // RW + kTRCCNTRLDVR0 = encode(0b10, 0b001, 0b0000, 0b0000, 0b101), // RW + kTRCCNTRLDVR1 = encode(0b10, 0b001, 0b0000, 0b0001, 0b101), // RW + kTRCCNTRLDVR2 = encode(0b10, 0b001, 0b0000, 0b0010, 0b101), // RW + kTRCCNTRLDVR3 = encode(0b10, 0b001, 0b0000, 0b0011, 0b101), // RW + kTRCCNTVR0 = encode(0b10, 0b001, 0b0000, 0b1000, 0b101), // RW + kTRCCNTVR1 = encode(0b10, 0b001, 0b0000, 0b1001, 0b101), // RW + kTRCCNTVR2 = encode(0b10, 0b001, 0b0000, 0b1010, 0b101), // RW + kTRCCNTVR3 = encode(0b10, 0b001, 0b0000, 0b1011, 0b101), // RW + kTRCCONFIGR = encode(0b10, 0b001, 0b0000, 0b0100, 0b000), // RW + kTRCDEVAFF0 = encode(0b10, 0b001, 0b0111, 0b1010, 0b110), // RO + kTRCDEVAFF1 = encode(0b10, 0b001, 0b0111, 0b1011, 0b110), // RO + kTRCDEVARCH = encode(0b10, 0b001, 0b0111, 0b1111, 0b110), // RO + kTRCDEVID = encode(0b10, 0b001, 0b0111, 0b0010, 0b111), // RO + kTRCDEVTYPE = encode(0b10, 0b001, 0b0111, 0b0011, 0b111), // RO + kTRCDVCMR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b110), // RW + kTRCDVCMR1 = encode(0b10, 0b001, 0b0010, 0b0100, 0b110), // RW + kTRCDVCMR2 = encode(0b10, 0b001, 0b0010, 0b1000, 0b110), // RW + kTRCDVCMR3 = encode(0b10, 0b001, 0b0010, 0b1100, 0b110), // RW + kTRCDVCMR4 = encode(0b10, 0b001, 0b0010, 0b0000, 0b111), // RW + kTRCDVCMR5 = encode(0b10, 0b001, 0b0010, 0b0100, 0b111), // RW + kTRCDVCMR6 = encode(0b10, 0b001, 0b0010, 0b1000, 0b111), // RW + kTRCDVCMR7 = encode(0b10, 0b001, 0b0010, 0b1100, 0b111), // RW + kTRCDVCVR0 = encode(0b10, 0b001, 0b0010, 0b0000, 0b100), // RW + kTRCDVCVR1 = encode(0b10, 0b001, 0b0010, 0b0100, 0b100), // RW + kTRCDVCVR2 = encode(0b10, 0b001, 0b0010, 0b1000, 0b100), // RW + kTRCDVCVR3 = encode(0b10, 0b001, 0b0010, 0b1100, 0b100), // RW + kTRCDVCVR4 = encode(0b10, 0b001, 0b0010, 0b0000, 0b101), // RW + kTRCDVCVR5 = encode(0b10, 0b001, 0b0010, 0b0100, 0b101), // RW + kTRCDVCVR6 = encode(0b10, 0b001, 0b0010, 0b1000, 0b101), // RW + kTRCDVCVR7 = encode(0b10, 0b001, 0b0010, 0b1100, 0b101), // RW + kTRCEVENTCTL0R = encode(0b10, 0b001, 0b0000, 0b1000, 0b000), // RW + kTRCEVENTCTL1R = encode(0b10, 0b001, 0b0000, 0b1001, 0b000), // RW + kTRCEXTINSELR = encode(0b10, 0b001, 0b0000, 0b1000, 0b100), // RW + kTRCEXTINSELR0 = encode(0b10, 0b001, 0b0000, 0b1000, 0b100), // RW + kTRCEXTINSELR1 = encode(0b10, 0b001, 0b0000, 0b1001, 0b100), // RW + kTRCEXTINSELR2 = encode(0b10, 0b001, 0b0000, 0b1010, 0b100), // RW + kTRCEXTINSELR3 = encode(0b10, 0b001, 0b0000, 0b1011, 0b100), // RW + kTRCIDR0 = encode(0b10, 0b001, 0b0000, 0b1000, 0b111), // RO + kTRCIDR1 = encode(0b10, 0b001, 0b0000, 0b1001, 0b111), // RO + kTRCIDR10 = encode(0b10, 0b001, 0b0000, 0b0010, 0b110), // RO + kTRCIDR11 = encode(0b10, 0b001, 0b0000, 0b0011, 0b110), // RO + kTRCIDR12 = encode(0b10, 0b001, 0b0000, 0b0100, 0b110), // RO + kTRCIDR13 = encode(0b10, 0b001, 0b0000, 0b0101, 0b110), // RO + kTRCIDR2 = encode(0b10, 0b001, 0b0000, 0b1010, 0b111), // RO + kTRCIDR3 = encode(0b10, 0b001, 0b0000, 0b1011, 0b111), // RO + kTRCIDR4 = encode(0b10, 0b001, 0b0000, 0b1100, 0b111), // RO + kTRCIDR5 = encode(0b10, 0b001, 0b0000, 0b1101, 0b111), // RO + kTRCIDR6 = encode(0b10, 0b001, 0b0000, 0b1110, 0b111), // RO + kTRCIDR7 = encode(0b10, 0b001, 0b0000, 0b1111, 0b111), // RO + kTRCIDR8 = encode(0b10, 0b001, 0b0000, 0b0000, 0b110), // RO + kTRCIDR9 = encode(0b10, 0b001, 0b0000, 0b0001, 0b110), // RO + kTRCIMSPEC0 = encode(0b10, 0b001, 0b0000, 0b0000, 0b111), // RW + kTRCIMSPEC1 = encode(0b10, 0b001, 0b0000, 0b0001, 0b111), // RW + kTRCIMSPEC2 = encode(0b10, 0b001, 0b0000, 0b0010, 0b111), // RW + kTRCIMSPEC3 = encode(0b10, 0b001, 0b0000, 0b0011, 0b111), // RW + kTRCIMSPEC4 = encode(0b10, 0b001, 0b0000, 0b0100, 0b111), // RW + kTRCIMSPEC5 = encode(0b10, 0b001, 0b0000, 0b0101, 0b111), // RW + kTRCIMSPEC6 = encode(0b10, 0b001, 0b0000, 0b0110, 0b111), // RW + kTRCIMSPEC7 = encode(0b10, 0b001, 0b0000, 0b0111, 0b111), // RW + kTRCITCTRL = encode(0b10, 0b001, 0b0111, 0b0000, 0b100), // RW + kTRCLAR = encode(0b10, 0b001, 0b0111, 0b1100, 0b110), // WO + kTRCLSR = encode(0b10, 0b001, 0b0111, 0b1101, 0b110), // RO + kTRCOSLAR = encode(0b10, 0b001, 0b0001, 0b0000, 0b100), // WO + kTRCOSLSR = encode(0b10, 0b001, 0b0001, 0b0001, 0b100), // RO + kTRCPDCR = encode(0b10, 0b001, 0b0001, 0b0100, 0b100), // RW + kTRCPDSR = encode(0b10, 0b001, 0b0001, 0b0101, 0b100), // RO + kTRCPIDR0 = encode(0b10, 0b001, 0b0111, 0b1000, 0b111), // RO + kTRCPIDR1 = encode(0b10, 0b001, 0b0111, 0b1001, 0b111), // RO + kTRCPIDR2 = encode(0b10, 0b001, 0b0111, 0b1010, 0b111), // RO + kTRCPIDR3 = encode(0b10, 0b001, 0b0111, 0b1011, 0b111), // RO + kTRCPIDR4 = encode(0b10, 0b001, 0b0111, 0b0100, 0b111), // RO + kTRCPIDR5 = encode(0b10, 0b001, 0b0111, 0b0101, 0b111), // RO + kTRCPIDR6 = encode(0b10, 0b001, 0b0111, 0b0110, 0b111), // RO + kTRCPIDR7 = encode(0b10, 0b001, 0b0111, 0b0111, 0b111), // RO + kTRCPRGCTLR = encode(0b10, 0b001, 0b0000, 0b0001, 0b000), // RW + kTRCPROCSELR = encode(0b10, 0b001, 0b0000, 0b0010, 0b000), // RW + kTRCQCTLR = encode(0b10, 0b001, 0b0000, 0b0001, 0b001), // RW + kTRCRSCTLR10 = encode(0b10, 0b001, 0b0001, 0b1010, 0b000), // RW + kTRCRSCTLR11 = encode(0b10, 0b001, 0b0001, 0b1011, 0b000), // RW + kTRCRSCTLR12 = encode(0b10, 0b001, 0b0001, 0b1100, 0b000), // RW + kTRCRSCTLR13 = encode(0b10, 0b001, 0b0001, 0b1101, 0b000), // RW + kTRCRSCTLR14 = encode(0b10, 0b001, 0b0001, 0b1110, 0b000), // RW + kTRCRSCTLR15 = encode(0b10, 0b001, 0b0001, 0b1111, 0b000), // RW + kTRCRSCTLR16 = encode(0b10, 0b001, 0b0001, 0b0000, 0b001), // RW + kTRCRSCTLR17 = encode(0b10, 0b001, 0b0001, 0b0001, 0b001), // RW + kTRCRSCTLR18 = encode(0b10, 0b001, 0b0001, 0b0010, 0b001), // RW + kTRCRSCTLR19 = encode(0b10, 0b001, 0b0001, 0b0011, 0b001), // RW + kTRCRSCTLR2 = encode(0b10, 0b001, 0b0001, 0b0010, 0b000), // RW + kTRCRSCTLR20 = encode(0b10, 0b001, 0b0001, 0b0100, 0b001), // RW + kTRCRSCTLR21 = encode(0b10, 0b001, 0b0001, 0b0101, 0b001), // RW + kTRCRSCTLR22 = encode(0b10, 0b001, 0b0001, 0b0110, 0b001), // RW + kTRCRSCTLR23 = encode(0b10, 0b001, 0b0001, 0b0111, 0b001), // RW + kTRCRSCTLR24 = encode(0b10, 0b001, 0b0001, 0b1000, 0b001), // RW + kTRCRSCTLR25 = encode(0b10, 0b001, 0b0001, 0b1001, 0b001), // RW + kTRCRSCTLR26 = encode(0b10, 0b001, 0b0001, 0b1010, 0b001), // RW + kTRCRSCTLR27 = encode(0b10, 0b001, 0b0001, 0b1011, 0b001), // RW + kTRCRSCTLR28 = encode(0b10, 0b001, 0b0001, 0b1100, 0b001), // RW + kTRCRSCTLR29 = encode(0b10, 0b001, 0b0001, 0b1101, 0b001), // RW + kTRCRSCTLR3 = encode(0b10, 0b001, 0b0001, 0b0011, 0b000), // RW + kTRCRSCTLR30 = encode(0b10, 0b001, 0b0001, 0b1110, 0b001), // RW + kTRCRSCTLR31 = encode(0b10, 0b001, 0b0001, 0b1111, 0b001), // RW + kTRCRSCTLR4 = encode(0b10, 0b001, 0b0001, 0b0100, 0b000), // RW + kTRCRSCTLR5 = encode(0b10, 0b001, 0b0001, 0b0101, 0b000), // RW + kTRCRSCTLR6 = encode(0b10, 0b001, 0b0001, 0b0110, 0b000), // RW + kTRCRSCTLR7 = encode(0b10, 0b001, 0b0001, 0b0111, 0b000), // RW + kTRCRSCTLR8 = encode(0b10, 0b001, 0b0001, 0b1000, 0b000), // RW + kTRCRSCTLR9 = encode(0b10, 0b001, 0b0001, 0b1001, 0b000), // RW + kTRCRSR = encode(0b10, 0b001, 0b0000, 0b1010, 0b000), // RW + kTRCSEQEVR0 = encode(0b10, 0b001, 0b0000, 0b0000, 0b100), // RW + kTRCSEQEVR1 = encode(0b10, 0b001, 0b0000, 0b0001, 0b100), // RW + kTRCSEQEVR2 = encode(0b10, 0b001, 0b0000, 0b0010, 0b100), // RW + kTRCSEQRSTEVR = encode(0b10, 0b001, 0b0000, 0b0110, 0b100), // RW + kTRCSEQSTR = encode(0b10, 0b001, 0b0000, 0b0111, 0b100), // RW + kTRCSSCCR0 = encode(0b10, 0b001, 0b0001, 0b0000, 0b010), // RW + kTRCSSCCR1 = encode(0b10, 0b001, 0b0001, 0b0001, 0b010), // RW + kTRCSSCCR2 = encode(0b10, 0b001, 0b0001, 0b0010, 0b010), // RW + kTRCSSCCR3 = encode(0b10, 0b001, 0b0001, 0b0011, 0b010), // RW + kTRCSSCCR4 = encode(0b10, 0b001, 0b0001, 0b0100, 0b010), // RW + kTRCSSCCR5 = encode(0b10, 0b001, 0b0001, 0b0101, 0b010), // RW + kTRCSSCCR6 = encode(0b10, 0b001, 0b0001, 0b0110, 0b010), // RW + kTRCSSCCR7 = encode(0b10, 0b001, 0b0001, 0b0111, 0b010), // RW + kTRCSSCSR0 = encode(0b10, 0b001, 0b0001, 0b1000, 0b010), // RW + kTRCSSCSR1 = encode(0b10, 0b001, 0b0001, 0b1001, 0b010), // RW + kTRCSSCSR2 = encode(0b10, 0b001, 0b0001, 0b1010, 0b010), // RW + kTRCSSCSR3 = encode(0b10, 0b001, 0b0001, 0b1011, 0b010), // RW + kTRCSSCSR4 = encode(0b10, 0b001, 0b0001, 0b1100, 0b010), // RW + kTRCSSCSR5 = encode(0b10, 0b001, 0b0001, 0b1101, 0b010), // RW + kTRCSSCSR6 = encode(0b10, 0b001, 0b0001, 0b1110, 0b010), // RW + kTRCSSCSR7 = encode(0b10, 0b001, 0b0001, 0b1111, 0b010), // RW + kTRCSSPCICR0 = encode(0b10, 0b001, 0b0001, 0b0000, 0b011), // RW + kTRCSSPCICR1 = encode(0b10, 0b001, 0b0001, 0b0001, 0b011), // RW + kTRCSSPCICR2 = encode(0b10, 0b001, 0b0001, 0b0010, 0b011), // RW + kTRCSSPCICR3 = encode(0b10, 0b001, 0b0001, 0b0011, 0b011), // RW + kTRCSSPCICR4 = encode(0b10, 0b001, 0b0001, 0b0100, 0b011), // RW + kTRCSSPCICR5 = encode(0b10, 0b001, 0b0001, 0b0101, 0b011), // RW + kTRCSSPCICR6 = encode(0b10, 0b001, 0b0001, 0b0110, 0b011), // RW + kTRCSSPCICR7 = encode(0b10, 0b001, 0b0001, 0b0111, 0b011), // RW + kTRCSTALLCTLR = encode(0b10, 0b001, 0b0000, 0b1011, 0b000), // RW + kTRCSTATR = encode(0b10, 0b001, 0b0000, 0b0011, 0b000), // RO + kTRCSYNCPR = encode(0b10, 0b001, 0b0000, 0b1101, 0b000), // RW + kTRCTRACEIDR = encode(0b10, 0b001, 0b0000, 0b0000, 0b001), // RW + kTRCTSCTLR = encode(0b10, 0b001, 0b0000, 0b1100, 0b000), // RW + kTRCVDARCCTLR = encode(0b10, 0b001, 0b0000, 0b1010, 0b010), // RW + kTRCVDCTLR = encode(0b10, 0b001, 0b0000, 0b1000, 0b010), // RW + kTRCVDSACCTLR = encode(0b10, 0b001, 0b0000, 0b1001, 0b010), // RW + kTRCVICTLR = encode(0b10, 0b001, 0b0000, 0b0000, 0b010), // RW + kTRCVIIECTLR = encode(0b10, 0b001, 0b0000, 0b0001, 0b010), // RW + kTRCVIPCSSCTLR = encode(0b10, 0b001, 0b0000, 0b0011, 0b010), // RW + kTRCVISSCTLR = encode(0b10, 0b001, 0b0000, 0b0010, 0b010), // RW + kTRCVMIDCCTLR0 = encode(0b10, 0b001, 0b0011, 0b0010, 0b010), // RW + kTRCVMIDCCTLR1 = encode(0b10, 0b001, 0b0011, 0b0011, 0b010), // RW + kTRCVMIDCVR0 = encode(0b10, 0b001, 0b0011, 0b0000, 0b001), // RW + kTRCVMIDCVR1 = encode(0b10, 0b001, 0b0011, 0b0010, 0b001), // RW + kTRCVMIDCVR2 = encode(0b10, 0b001, 0b0011, 0b0100, 0b001), // RW + kTRCVMIDCVR3 = encode(0b10, 0b001, 0b0011, 0b0110, 0b001), // RW + kTRCVMIDCVR4 = encode(0b10, 0b001, 0b0011, 0b1000, 0b001), // RW + kTRCVMIDCVR5 = encode(0b10, 0b001, 0b0011, 0b1010, 0b001), // RW + kTRCVMIDCVR6 = encode(0b10, 0b001, 0b0011, 0b1100, 0b001), // RW + kTRCVMIDCVR7 = encode(0b10, 0b001, 0b0011, 0b1110, 0b001), // RW + kTRFCR_EL1 = encode(0b11, 0b000, 0b0001, 0b0010, 0b001), // RW + kTRFCR_EL12 = encode(0b11, 0b101, 0b0001, 0b0010, 0b001), // RW + kTRFCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0010, 0b001), // RW + kTTBR0_EL1 = encode(0b11, 0b000, 0b0010, 0b0000, 0b000), // RW + kTTBR0_EL12 = encode(0b11, 0b101, 0b0010, 0b0000, 0b000), // RW + kTTBR0_EL2 = encode(0b11, 0b100, 0b0010, 0b0000, 0b000), // RW + kTTBR0_EL3 = encode(0b11, 0b110, 0b0010, 0b0000, 0b000), // RW + kTTBR1_EL1 = encode(0b11, 0b000, 0b0010, 0b0000, 0b001), // RW + kTTBR1_EL12 = encode(0b11, 0b101, 0b0010, 0b0000, 0b001), // RW + kTTBR1_EL2 = encode(0b11, 0b100, 0b0010, 0b0000, 0b001), // RW + kUAO = encode(0b11, 0b000, 0b0100, 0b0010, 0b100), // RW + kVBAR_EL1 = encode(0b11, 0b000, 0b1100, 0b0000, 0b000), // RW + kVBAR_EL12 = encode(0b11, 0b101, 0b1100, 0b0000, 0b000), // RW + kVBAR_EL2 = encode(0b11, 0b100, 0b1100, 0b0000, 0b000), // RW + kVBAR_EL3 = encode(0b11, 0b110, 0b1100, 0b0000, 0b000), // RW + kVDISR_EL2 = encode(0b11, 0b100, 0b1100, 0b0001, 0b001), // RW + kVMPIDR_EL2 = encode(0b11, 0b100, 0b0000, 0b0000, 0b101), // RW + kVNCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0010, 0b000), // RW + kVPIDR_EL2 = encode(0b11, 0b100, 0b0000, 0b0000, 0b000), // RW + kVSESR_EL2 = encode(0b11, 0b100, 0b0101, 0b0010, 0b011), // RW + kVSTCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0110, 0b010), // RW + kVSTTBR_EL2 = encode(0b11, 0b100, 0b0010, 0b0110, 0b000), // RW + kVTCR_EL2 = encode(0b11, 0b100, 0b0010, 0b0001, 0b010), // RW + kVTTBR_EL2 = encode(0b11, 0b100, 0b0010, 0b0001, 0b000), // RW + kZCR_EL1 = encode(0b11, 0b000, 0b0001, 0b0010, 0b000), // RW + kZCR_EL12 = encode(0b11, 0b101, 0b0001, 0b0010, 0b000), // RW + kZCR_EL2 = encode(0b11, 0b100, 0b0001, 0b0010, 0b000), // RW + kZCR_EL3 = encode(0b11, 0b110, 0b0001, 0b0010, 0b000) // RW + }; +}; + +} // {Predicate} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_A64GLOBALS_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64instdb.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64instdb.h new file mode 100644 index 0000000000000000000000000000000000000000..3adeee88cfdc0d780f963acc85c55049f9f01aa3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64instdb.h @@ -0,0 +1,72 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64INSTDB_H_INCLUDED +#define ASMJIT_ARM_A64INSTDB_H_INCLUDED + +#include "../arm/a64globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +//! Instruction database (AArch64). +namespace InstDB { + +//! Instruction flags. +enum InstFlags : uint32_t { + //! The instruction provides conditional execution. + kInstFlagCond = 0x00000001u, + //! SIMD instruction that processes elements in pairs. + kInstFlagPair = 0x00000002u, + //! SIMD instruction that does widening (Long). + kInstFlagLong = 0x00000004u, + //! SIMD instruction that does narrowing (Narrow). + kInstFlagNarrow = 0x00000008u, + //! SIMD element access of half-words can only be used with v0..15. + kInstFlagVH0_15 = 0x00000010u, + + //! Instruction uses consecutive registers if the number of operands is greater than 2. + kInstFlagConsecutive = 0x00000080u +}; + +//! Instruction information (AArch64). +struct InstInfo { + //! Instruction encoding type. + uint32_t _encoding : 8; + //! Index to data specific to each encoding type. + uint32_t _encodingDataIndex : 8; + uint32_t _reserved : 16; + + uint16_t _rwInfoIndex; + uint16_t _flags; + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG uint32_t rwInfoIndex() const noexcept { return _rwInfoIndex; } + ASMJIT_INLINE_NODEBUG uint32_t flags() const noexcept { return _flags; } + + ASMJIT_INLINE_NODEBUG bool hasFlag(uint32_t flag) const { return (_flags & flag) != 0; } + + //! \} +}; + +ASMJIT_VARAPI const InstInfo _instInfoTable[]; + +static inline const InstInfo& infoById(InstId instId) noexcept { + instId &= uint32_t(InstIdParts::kRealId); + ASMJIT_ASSERT(Inst::isDefinedId(instId)); + return _instInfoTable[instId]; +} + +} // {InstDB} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_A64INSTDB_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64operand.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64operand.h new file mode 100644 index 0000000000000000000000000000000000000000..65947d73b1d6c51aff03a90a9e47ba65def35aae --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/a64operand.h @@ -0,0 +1,650 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_A64OPERAND_H_INCLUDED +#define ASMJIT_ARM_A64OPERAND_H_INCLUDED + +#include "../arm/armoperand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) + +//! \addtogroup asmjit_a64 +//! \{ + +class GpW; +class GpX; + +class VecB; +class VecH; +class VecS; +class VecD; +class VecV; + +//! General purpose register (AArch64). +class Gp : public Reg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Gp, Reg) + + //! Special register id. + enum Id : uint32_t { + //! Register that depends on OS, could be used as TLS offset. + kIdOs = 18, + //! Frame pointer register id. + kIdFp = 29, + //! Link register id. + kIdLr = 30, + //! Stack register id. + kIdSp = 31, + //! Zero register id. + //! + //! Although zero register has the same id as stack register it has a special treatment, because we need to be + //! able to distinguish between these two at API level. Some instructions were designed to be used with SP and + //! some other with ZR - so we need a way to distinguish these two to make sure we emit the right thing. + //! + //! The number 63 is not random, when you perform `id & 31` you would always get 31 for both SP and ZR inputs, + //! which is the identifier used by AArch64 ISA to encode either SP or ZR depending on the instruction. + kIdZr = 63 + }; + + //! Test whether this register is ZR register. + ASMJIT_INLINE_NODEBUG constexpr bool isZR() const noexcept { return id() == kIdZr; } + //! Test whether this register is SP register. + ASMJIT_INLINE_NODEBUG constexpr bool isSP() const noexcept { return id() == kIdSp; } + + //! Cast this register to a 32-bit W register (returns a new operand). + ASMJIT_INLINE_NODEBUG GpW w() const noexcept; + //! \overload + ASMJIT_INLINE_NODEBUG GpW r32() const noexcept; + //! Cast this register to a 64-bit X register (returns a new operand). + ASMJIT_INLINE_NODEBUG GpX x() const noexcept; + //! \overload + ASMJIT_INLINE_NODEBUG GpX r64() const noexcept; +}; + +//! 32-bit general purpose W register (AArch64). +class GpW : public Gp { ASMJIT_DEFINE_FINAL_REG(GpW, Gp, RegTraits); }; +//! 64-bit general purpose X register (AArch64). +class GpX : public Gp { ASMJIT_DEFINE_FINAL_REG(GpX, Gp, RegTraits); }; + +#ifndef _DOXYGEN +ASMJIT_INLINE_NODEBUG GpW Gp::w() const noexcept { return GpW(id()); } +ASMJIT_INLINE_NODEBUG GpX Gp::x() const noexcept { return GpX(id()); } +ASMJIT_INLINE_NODEBUG GpW Gp::r32() const noexcept { return GpW(id()); } +ASMJIT_INLINE_NODEBUG GpX Gp::r64() const noexcept { return GpX(id()); } +#endif + +//! Vector element type (AArch64). +enum class VecElementType : uint32_t { + //! No element type specified. + kNone = 0, + //! Byte elements (B8 or B16). + kB, + //! Halfword elements (H4 or H8). + kH, + //! Singleword elements (S2 or S4). + kS, + //! Doubleword elements (D2). + kD, + //! Byte elements grouped by 4 bytes (B4). + //! + //! \note This element-type is only used by few instructions. + kB4, + //! Halfword elements grouped by 2 halfwords (H2). + //! + //! \note This element-type is only used by few instructions. + kH2, + + //! Maximum value of \ref VecElementType + kMaxValue = kH2 +}; + +//! Vector register (AArch64). +class Vec : public BaseVec { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Vec, BaseVec) + + //! \cond + //! Shortcuts. + enum SignatureReg : uint32_t { + kSignatureElementB = uint32_t(VecElementType::kB) << kSignatureRegElementTypeShift, + kSignatureElementH = uint32_t(VecElementType::kH) << kSignatureRegElementTypeShift, + kSignatureElementS = uint32_t(VecElementType::kS) << kSignatureRegElementTypeShift, + kSignatureElementD = uint32_t(VecElementType::kD) << kSignatureRegElementTypeShift, + kSignatureElementB4 = uint32_t(VecElementType::kB4) << kSignatureRegElementTypeShift, + kSignatureElementH2 = uint32_t(VecElementType::kH2) << kSignatureRegElementTypeShift + }; + //! \endcond + + //! Returns whether the register has element type or element index (or both). + ASMJIT_INLINE_NODEBUG constexpr bool hasElementTypeOrIndex() const noexcept { return _signature.hasField(); } + + //! Returns whether the vector register has associated a vector element type. + ASMJIT_INLINE_NODEBUG constexpr bool hasElementType() const noexcept { return _signature.hasField(); } + //! Returns vector element type of the register. + ASMJIT_INLINE_NODEBUG constexpr VecElementType elementType() const noexcept { return VecElementType(_signature.getField()); } + //! Sets vector element type of the register to `elementType`. + ASMJIT_INLINE_NODEBUG void setElementType(VecElementType elementType) noexcept { _signature.setField(uint32_t(elementType)); } + //! Resets vector element type to none. + ASMJIT_INLINE_NODEBUG void resetElementType() noexcept { _signature.setField(0); } + + ASMJIT_INLINE_NODEBUG constexpr bool isVecB8() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementB); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecH4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementH); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecS2() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementS); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecD1() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature); } + + ASMJIT_INLINE_NODEBUG constexpr bool isVecB16() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementB); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecH8() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementH); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecS4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementS); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecD2() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementD); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecB4x4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementB4); } + ASMJIT_INLINE_NODEBUG constexpr bool isVecH2x4() const noexcept { return _signature.subset(kBaseSignatureMask | kSignatureRegElementTypeMask) == (RegTraits::kSignature | kSignatureElementH2); } + + //! Creates a cloned register with element access. + ASMJIT_INLINE_NODEBUG Vec at(uint32_t elementIndex) const noexcept { + return Vec((signature() & ~kSignatureRegElementIndexMask) | (elementIndex << kSignatureRegElementIndexShift) | kSignatureRegElementFlagMask, id()); + } + + //! Cast this register to an 8-bit B register (AArch64 only). + ASMJIT_INLINE_NODEBUG VecB b() const noexcept; + //! Cast this register to a 16-bit H register (AArch64 only). + ASMJIT_INLINE_NODEBUG VecH h() const noexcept; + //! Cast this register to a 32-bit S register. + ASMJIT_INLINE_NODEBUG VecS s() const noexcept; + //! Cast this register to a 64-bit D register. + ASMJIT_INLINE_NODEBUG VecD d() const noexcept; + //! Cast this register to a 128-bit Q register. + ASMJIT_INLINE_NODEBUG VecV q() const noexcept; + //! Cast this register to a 128-bit V register. + ASMJIT_INLINE_NODEBUG VecV v() const noexcept; + + //! Casts this register to b (clone). + ASMJIT_INLINE_NODEBUG Vec v8() const noexcept; + //! Casts this register to h (clone). + ASMJIT_INLINE_NODEBUG Vec v16() const noexcept; + //! Casts this register to s (clone). + ASMJIT_INLINE_NODEBUG Vec v32() const noexcept; + //! Casts this register to d (clone). + ASMJIT_INLINE_NODEBUG Vec v64() const noexcept; + //! Casts this register to q (clone). + ASMJIT_INLINE_NODEBUG Vec v128() const noexcept; + + //! Cast this register to a 128-bit V.B[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV b(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.H[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV h(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.S[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV s(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.D[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV d(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.H2[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV h2(uint32_t elementIndex) const noexcept; + //! Cast this register to a 128-bit V.B4[elementIndex] register. + ASMJIT_INLINE_NODEBUG VecV b4(uint32_t elementIndex) const noexcept; + + //! Cast this register to V.8B. + ASMJIT_INLINE_NODEBUG VecD b8() const noexcept; + //! Cast this register to V.16B. + ASMJIT_INLINE_NODEBUG VecV b16() const noexcept; + //! Cast this register to V.2H. + ASMJIT_INLINE_NODEBUG VecS h2() const noexcept; + //! Cast this register to V.4H. + ASMJIT_INLINE_NODEBUG VecD h4() const noexcept; + //! Cast this register to V.8H. + ASMJIT_INLINE_NODEBUG VecV h8() const noexcept; + //! Cast this register to V.2S. + ASMJIT_INLINE_NODEBUG VecD s2() const noexcept; + //! Cast this register to V.4S. + ASMJIT_INLINE_NODEBUG VecV s4() const noexcept; + //! Cast this register to V.2D. + ASMJIT_INLINE_NODEBUG VecV d2() const noexcept; + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature _makeElementAccessSignature(VecElementType elementType, uint32_t elementIndex) noexcept { + return OperandSignature{ + uint32_t(RegTraits::kSignature) | + uint32_t(kSignatureRegElementFlagMask) | + (uint32_t(elementType) << kSignatureRegElementTypeShift) | + (uint32_t(elementIndex << kSignatureRegElementIndexShift))}; + } +}; + +//! 8-bit view (S) of VFP/SIMD register. +class VecB : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecB, Vec, RegTraits) +}; + +//! 16-bit view (S) of VFP/SIMD register. +class VecH : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecH, Vec, RegTraits) +}; + +//! 32-bit view (S) of VFP/SIMD register. +class VecS : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecS, Vec, RegTraits) +}; + +//! 64-bit view (D) of VFP/SIMD register. +class VecD : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecD, Vec, RegTraits) +}; + +//! 128-bit vector register (Q or V). +class VecV : public Vec { +public: + ASMJIT_DEFINE_FINAL_REG(VecV, Vec, RegTraits) +}; + +ASMJIT_INLINE_NODEBUG VecB Vec::b() const noexcept { return VecB(id()); } +ASMJIT_INLINE_NODEBUG VecH Vec::h() const noexcept { return VecH(id()); } +ASMJIT_INLINE_NODEBUG VecS Vec::s() const noexcept { return VecS(id()); } +ASMJIT_INLINE_NODEBUG VecD Vec::d() const noexcept { return VecD(id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::q() const noexcept { return VecV(id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::v() const noexcept { return VecV(id()); } + +ASMJIT_INLINE_NODEBUG Vec Vec::v8() const noexcept { return VecB(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v16() const noexcept { return VecH(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v32() const noexcept { return VecS(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v64() const noexcept { return VecD(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v128() const noexcept { return VecV(id()); } + +ASMJIT_INLINE_NODEBUG VecV Vec::b(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kB, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::h(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kH, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::s(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kS, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::d(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kD, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::h2(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kH2, elementIndex), id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::b4(uint32_t elementIndex) const noexcept { return VecV(_makeElementAccessSignature(VecElementType::kB4, elementIndex), id()); } + +ASMJIT_INLINE_NODEBUG VecD Vec::b8() const noexcept { return VecD(OperandSignature{VecD::kSignature | kSignatureElementB}, id()); } +ASMJIT_INLINE_NODEBUG VecS Vec::h2() const noexcept { return VecS(OperandSignature{VecS::kSignature | kSignatureElementH}, id()); } +ASMJIT_INLINE_NODEBUG VecD Vec::h4() const noexcept { return VecD(OperandSignature{VecD::kSignature | kSignatureElementH}, id()); } +ASMJIT_INLINE_NODEBUG VecD Vec::s2() const noexcept { return VecD(OperandSignature{VecD::kSignature | kSignatureElementS}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::b16() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementB}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::h8() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementH}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::s4() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementS}, id()); } +ASMJIT_INLINE_NODEBUG VecV Vec::d2() const noexcept { return VecV(OperandSignature{VecV::kSignature | kSignatureElementD}, id()); } + +#ifndef _DOXYGEN +namespace regs { +#endif + +//! Creates a 32-bit W register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpW w(uint32_t id) noexcept { return GpW(id); } +//! Creates a 64-bit X register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpX x(uint32_t id) noexcept { return GpX(id); } + +//! Creates a 32-bit S register operand. +static ASMJIT_INLINE_NODEBUG constexpr VecS s(uint32_t id) noexcept { return VecS(id); } +//! Creates a 64-bit D register operand. +static ASMJIT_INLINE_NODEBUG constexpr VecD d(uint32_t id) noexcept { return VecD(id); } +//! Creates a 1282-bit V register operand. +static ASMJIT_INLINE_NODEBUG constexpr VecV v(uint32_t id) noexcept { return VecV(id); } + +static constexpr GpW w0 = GpW(0); +static constexpr GpW w1 = GpW(1); +static constexpr GpW w2 = GpW(2); +static constexpr GpW w3 = GpW(3); +static constexpr GpW w4 = GpW(4); +static constexpr GpW w5 = GpW(5); +static constexpr GpW w6 = GpW(6); +static constexpr GpW w7 = GpW(7); +static constexpr GpW w8 = GpW(8); +static constexpr GpW w9 = GpW(9); +static constexpr GpW w10 = GpW(10); +static constexpr GpW w11 = GpW(11); +static constexpr GpW w12 = GpW(12); +static constexpr GpW w13 = GpW(13); +static constexpr GpW w14 = GpW(14); +static constexpr GpW w15 = GpW(15); +static constexpr GpW w16 = GpW(16); +static constexpr GpW w17 = GpW(17); +static constexpr GpW w18 = GpW(18); +static constexpr GpW w19 = GpW(19); +static constexpr GpW w20 = GpW(20); +static constexpr GpW w21 = GpW(21); +static constexpr GpW w22 = GpW(22); +static constexpr GpW w23 = GpW(23); +static constexpr GpW w24 = GpW(24); +static constexpr GpW w25 = GpW(25); +static constexpr GpW w26 = GpW(26); +static constexpr GpW w27 = GpW(27); +static constexpr GpW w28 = GpW(28); +static constexpr GpW w29 = GpW(29); +static constexpr GpW w30 = GpW(30); +static constexpr GpW wzr = GpW(Gp::kIdZr); +static constexpr GpW wsp = GpW(Gp::kIdSp); + +static constexpr GpX x0 = GpX(0); +static constexpr GpX x1 = GpX(1); +static constexpr GpX x2 = GpX(2); +static constexpr GpX x3 = GpX(3); +static constexpr GpX x4 = GpX(4); +static constexpr GpX x5 = GpX(5); +static constexpr GpX x6 = GpX(6); +static constexpr GpX x7 = GpX(7); +static constexpr GpX x8 = GpX(8); +static constexpr GpX x9 = GpX(9); +static constexpr GpX x10 = GpX(10); +static constexpr GpX x11 = GpX(11); +static constexpr GpX x12 = GpX(12); +static constexpr GpX x13 = GpX(13); +static constexpr GpX x14 = GpX(14); +static constexpr GpX x15 = GpX(15); +static constexpr GpX x16 = GpX(16); +static constexpr GpX x17 = GpX(17); +static constexpr GpX x18 = GpX(18); +static constexpr GpX x19 = GpX(19); +static constexpr GpX x20 = GpX(20); +static constexpr GpX x21 = GpX(21); +static constexpr GpX x22 = GpX(22); +static constexpr GpX x23 = GpX(23); +static constexpr GpX x24 = GpX(24); +static constexpr GpX x25 = GpX(25); +static constexpr GpX x26 = GpX(26); +static constexpr GpX x27 = GpX(27); +static constexpr GpX x28 = GpX(28); +static constexpr GpX x29 = GpX(29); +static constexpr GpX x30 = GpX(30); +static constexpr GpX xzr = GpX(Gp::kIdZr); +static constexpr GpX sp = GpX(Gp::kIdSp); + +static constexpr VecB b0 = VecB(0); +static constexpr VecB b1 = VecB(1); +static constexpr VecB b2 = VecB(2); +static constexpr VecB b3 = VecB(3); +static constexpr VecB b4 = VecB(4); +static constexpr VecB b5 = VecB(5); +static constexpr VecB b6 = VecB(6); +static constexpr VecB b7 = VecB(7); +static constexpr VecB b8 = VecB(8); +static constexpr VecB b9 = VecB(9); +static constexpr VecB b10 = VecB(10); +static constexpr VecB b11 = VecB(11); +static constexpr VecB b12 = VecB(12); +static constexpr VecB b13 = VecB(13); +static constexpr VecB b14 = VecB(14); +static constexpr VecB b15 = VecB(15); +static constexpr VecB b16 = VecB(16); +static constexpr VecB b17 = VecB(17); +static constexpr VecB b18 = VecB(18); +static constexpr VecB b19 = VecB(19); +static constexpr VecB b20 = VecB(20); +static constexpr VecB b21 = VecB(21); +static constexpr VecB b22 = VecB(22); +static constexpr VecB b23 = VecB(23); +static constexpr VecB b24 = VecB(24); +static constexpr VecB b25 = VecB(25); +static constexpr VecB b26 = VecB(26); +static constexpr VecB b27 = VecB(27); +static constexpr VecB b28 = VecB(28); +static constexpr VecB b29 = VecB(29); +static constexpr VecB b30 = VecB(30); +static constexpr VecB b31 = VecB(31); + +static constexpr VecH h0 = VecH(0); +static constexpr VecH h1 = VecH(1); +static constexpr VecH h2 = VecH(2); +static constexpr VecH h3 = VecH(3); +static constexpr VecH h4 = VecH(4); +static constexpr VecH h5 = VecH(5); +static constexpr VecH h6 = VecH(6); +static constexpr VecH h7 = VecH(7); +static constexpr VecH h8 = VecH(8); +static constexpr VecH h9 = VecH(9); +static constexpr VecH h10 = VecH(10); +static constexpr VecH h11 = VecH(11); +static constexpr VecH h12 = VecH(12); +static constexpr VecH h13 = VecH(13); +static constexpr VecH h14 = VecH(14); +static constexpr VecH h15 = VecH(15); +static constexpr VecH h16 = VecH(16); +static constexpr VecH h17 = VecH(17); +static constexpr VecH h18 = VecH(18); +static constexpr VecH h19 = VecH(19); +static constexpr VecH h20 = VecH(20); +static constexpr VecH h21 = VecH(21); +static constexpr VecH h22 = VecH(22); +static constexpr VecH h23 = VecH(23); +static constexpr VecH h24 = VecH(24); +static constexpr VecH h25 = VecH(25); +static constexpr VecH h26 = VecH(26); +static constexpr VecH h27 = VecH(27); +static constexpr VecH h28 = VecH(28); +static constexpr VecH h29 = VecH(29); +static constexpr VecH h30 = VecH(30); +static constexpr VecH h31 = VecH(31); + +static constexpr VecS s0 = VecS(0); +static constexpr VecS s1 = VecS(1); +static constexpr VecS s2 = VecS(2); +static constexpr VecS s3 = VecS(3); +static constexpr VecS s4 = VecS(4); +static constexpr VecS s5 = VecS(5); +static constexpr VecS s6 = VecS(6); +static constexpr VecS s7 = VecS(7); +static constexpr VecS s8 = VecS(8); +static constexpr VecS s9 = VecS(9); +static constexpr VecS s10 = VecS(10); +static constexpr VecS s11 = VecS(11); +static constexpr VecS s12 = VecS(12); +static constexpr VecS s13 = VecS(13); +static constexpr VecS s14 = VecS(14); +static constexpr VecS s15 = VecS(15); +static constexpr VecS s16 = VecS(16); +static constexpr VecS s17 = VecS(17); +static constexpr VecS s18 = VecS(18); +static constexpr VecS s19 = VecS(19); +static constexpr VecS s20 = VecS(20); +static constexpr VecS s21 = VecS(21); +static constexpr VecS s22 = VecS(22); +static constexpr VecS s23 = VecS(23); +static constexpr VecS s24 = VecS(24); +static constexpr VecS s25 = VecS(25); +static constexpr VecS s26 = VecS(26); +static constexpr VecS s27 = VecS(27); +static constexpr VecS s28 = VecS(28); +static constexpr VecS s29 = VecS(29); +static constexpr VecS s30 = VecS(30); +static constexpr VecS s31 = VecS(31); + +static constexpr VecD d0 = VecD(0); +static constexpr VecD d1 = VecD(1); +static constexpr VecD d2 = VecD(2); +static constexpr VecD d3 = VecD(3); +static constexpr VecD d4 = VecD(4); +static constexpr VecD d5 = VecD(5); +static constexpr VecD d6 = VecD(6); +static constexpr VecD d7 = VecD(7); +static constexpr VecD d8 = VecD(8); +static constexpr VecD d9 = VecD(9); +static constexpr VecD d10 = VecD(10); +static constexpr VecD d11 = VecD(11); +static constexpr VecD d12 = VecD(12); +static constexpr VecD d13 = VecD(13); +static constexpr VecD d14 = VecD(14); +static constexpr VecD d15 = VecD(15); +static constexpr VecD d16 = VecD(16); +static constexpr VecD d17 = VecD(17); +static constexpr VecD d18 = VecD(18); +static constexpr VecD d19 = VecD(19); +static constexpr VecD d20 = VecD(20); +static constexpr VecD d21 = VecD(21); +static constexpr VecD d22 = VecD(22); +static constexpr VecD d23 = VecD(23); +static constexpr VecD d24 = VecD(24); +static constexpr VecD d25 = VecD(25); +static constexpr VecD d26 = VecD(26); +static constexpr VecD d27 = VecD(27); +static constexpr VecD d28 = VecD(28); +static constexpr VecD d29 = VecD(29); +static constexpr VecD d30 = VecD(30); +static constexpr VecD d31 = VecD(31); + +static constexpr VecV q0 = VecV(0); +static constexpr VecV q1 = VecV(1); +static constexpr VecV q2 = VecV(2); +static constexpr VecV q3 = VecV(3); +static constexpr VecV q4 = VecV(4); +static constexpr VecV q5 = VecV(5); +static constexpr VecV q6 = VecV(6); +static constexpr VecV q7 = VecV(7); +static constexpr VecV q8 = VecV(8); +static constexpr VecV q9 = VecV(9); +static constexpr VecV q10 = VecV(10); +static constexpr VecV q11 = VecV(11); +static constexpr VecV q12 = VecV(12); +static constexpr VecV q13 = VecV(13); +static constexpr VecV q14 = VecV(14); +static constexpr VecV q15 = VecV(15); +static constexpr VecV q16 = VecV(16); +static constexpr VecV q17 = VecV(17); +static constexpr VecV q18 = VecV(18); +static constexpr VecV q19 = VecV(19); +static constexpr VecV q20 = VecV(20); +static constexpr VecV q21 = VecV(21); +static constexpr VecV q22 = VecV(22); +static constexpr VecV q23 = VecV(23); +static constexpr VecV q24 = VecV(24); +static constexpr VecV q25 = VecV(25); +static constexpr VecV q26 = VecV(26); +static constexpr VecV q27 = VecV(27); +static constexpr VecV q28 = VecV(28); +static constexpr VecV q29 = VecV(29); +static constexpr VecV q30 = VecV(30); +static constexpr VecV q31 = VecV(31); + +static constexpr VecV v0 = VecV(0); +static constexpr VecV v1 = VecV(1); +static constexpr VecV v2 = VecV(2); +static constexpr VecV v3 = VecV(3); +static constexpr VecV v4 = VecV(4); +static constexpr VecV v5 = VecV(5); +static constexpr VecV v6 = VecV(6); +static constexpr VecV v7 = VecV(7); +static constexpr VecV v8 = VecV(8); +static constexpr VecV v9 = VecV(9); +static constexpr VecV v10 = VecV(10); +static constexpr VecV v11 = VecV(11); +static constexpr VecV v12 = VecV(12); +static constexpr VecV v13 = VecV(13); +static constexpr VecV v14 = VecV(14); +static constexpr VecV v15 = VecV(15); +static constexpr VecV v16 = VecV(16); +static constexpr VecV v17 = VecV(17); +static constexpr VecV v18 = VecV(18); +static constexpr VecV v19 = VecV(19); +static constexpr VecV v20 = VecV(20); +static constexpr VecV v21 = VecV(21); +static constexpr VecV v22 = VecV(22); +static constexpr VecV v23 = VecV(23); +static constexpr VecV v24 = VecV(24); +static constexpr VecV v25 = VecV(25); +static constexpr VecV v26 = VecV(26); +static constexpr VecV v27 = VecV(27); +static constexpr VecV v28 = VecV(28); +static constexpr VecV v29 = VecV(29); +static constexpr VecV v30 = VecV(30); +static constexpr VecV v31 = VecV(31); + +#ifndef _DOXYGEN +} // {regs} + +// Make `a64::regs` accessible through `a64` namespace as well. +using namespace regs; +#endif + +//! \name Shift Operation Construction +//! \{ + +//! Constructs a `UXTB #value` extend and shift (unsigned byte extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxtb(uint32_t value) noexcept { return Shift(ShiftOp::kUXTB, value); } +//! Constructs a `UXTH #value` extend and shift (unsigned hword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxth(uint32_t value) noexcept { return Shift(ShiftOp::kUXTH, value); } +//! Constructs a `UXTW #value` extend and shift (unsigned word extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxtw(uint32_t value) noexcept { return Shift(ShiftOp::kUXTW, value); } +//! Constructs a `UXTX #value` extend and shift (unsigned dword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift uxtx(uint32_t value) noexcept { return Shift(ShiftOp::kUXTX, value); } + +//! Constructs a `SXTB #value` extend and shift (signed byte extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxtb(uint32_t value) noexcept { return Shift(ShiftOp::kSXTB, value); } +//! Constructs a `SXTH #value` extend and shift (signed hword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxth(uint32_t value) noexcept { return Shift(ShiftOp::kSXTH, value); } +//! Constructs a `SXTW #value` extend and shift (signed word extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxtw(uint32_t value) noexcept { return Shift(ShiftOp::kSXTW, value); } +//! Constructs a `SXTX #value` extend and shift (signed dword extend) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Shift sxtx(uint32_t value) noexcept { return Shift(ShiftOp::kSXTX, value); } + +//! \} + +//! \name Memory Operand Construction +//! \{ + +//! Creates `[base, offset]` memory operand (offset mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, int32_t offset = 0) noexcept { + return Mem(base, offset); +} + +//! Creates `[base, offset]!` memory operand (pre-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_pre(const Gp& base, int32_t offset = 0) noexcept { + return Mem(base, offset, OperandSignature::fromValue(OffsetMode::kPreIndex)); +} + +//! Creates `[base], offset` memory operand (post-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_post(const Gp& base, int32_t offset = 0) noexcept { + return Mem(base, offset, OperandSignature::fromValue(OffsetMode::kPostIndex)); +} + +//! Creates `[base, index]` memory operand (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Gp& index) noexcept { + return Mem(base, index); +} + +//! Creates `[base, index]!` memory operand (pre-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_pre(const Gp& base, const Gp& index) noexcept { + return Mem(base, index, OperandSignature::fromValue(OffsetMode::kPreIndex)); +} + +//! Creates `[base], index` memory operand (post-index mode) (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_post(const Gp& base, const Gp& index) noexcept { + return Mem(base, index, OperandSignature::fromValue(OffsetMode::kPostIndex)); +} + +//! Creates `[base, index, SHIFT_OP #shift]` memory operand (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Gp& index, const Shift& shift) noexcept { + return Mem(base, index, shift); +} + +//! Creates `[base, offset]` memory operand (AArch64). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, int32_t offset = 0) noexcept { + return Mem(base, offset); +} + +// TODO: [ARM] PC + offset address. +#if 0 +//! Creates `[PC + offset]` (relative) memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const PC& pc, int32_t offset = 0) noexcept { + return Mem(pc, offset); +} +#endif + +//! \} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +//! \cond INTERNAL +ASMJIT_BEGIN_NAMESPACE +ASMJIT_DEFINE_TYPE_ID(a64::GpW, TypeId::kInt32); +ASMJIT_DEFINE_TYPE_ID(a64::GpX, TypeId::kInt64); +ASMJIT_DEFINE_TYPE_ID(a64::VecS, TypeId::kFloat32x1); +ASMJIT_DEFINE_TYPE_ID(a64::VecD, TypeId::kFloat64x1); +ASMJIT_DEFINE_TYPE_ID(a64::VecV, TypeId::kInt32x4); +ASMJIT_END_NAMESPACE +//! \endcond + +#endif // ASMJIT_ARM_A64OPERAND_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/armglobals.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/armglobals.h new file mode 100644 index 0000000000000000000000000000000000000000..6d1948ea4e7d3271769ae33166496b19c8db5522 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/armglobals.h @@ -0,0 +1,17 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_ARMGLOBALS_H_INCLUDED +#define ASMJIT_ARM_ARMGLOBALS_H_INCLUDED + +#include "../core/archcommons.h" +#include "../core/inst.h" + +//! \namespace asmjit::arm +//! \ingroup asmjit_arm +//! +//! API shared between AArch32 & AArch64 backends. + +#endif // ASMJIT_ARM_ARMGLOBALS_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/armoperand.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/armoperand.h new file mode 100644 index 0000000000000000000000000000000000000000..c4757ff9590031297a16b415a4c8a6ebedc75d86 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/armoperand.h @@ -0,0 +1,396 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_ARMOPERAND_H_INCLUDED +#define ASMJIT_ARM_ARMOPERAND_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../arm/armglobals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(arm) + +//! \addtogroup asmjit_arm +//! \{ + +class Reg; +class Mem; + +//! Register traits (AArch32/AArch64). +//! +//! Register traits contains information about a particular register type. It's used by asmjit to setup register +//! information on-the-fly and to populate tables that contain register information (this way it's possible to +//! change register types and groups without having to reorder these tables). +template +struct RegTraits : public BaseRegTraits {}; + +//! \cond +// <--------------------+------------------------+------------------------+---+------------------+ +// | Reg-Type | Reg-Group |Sz | TypeId | +// <--------------------+------------------------+------------------------+---+------------------+ +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_GpW , RegGroup::kGp , 4 , TypeId::kInt32 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_GpX , RegGroup::kGp , 8 , TypeId::kInt64 ); // AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecB , RegGroup::kVec , 1 , TypeId::kVoid ); // AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecH , RegGroup::kVec , 2 , TypeId::kVoid ); // AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecS , RegGroup::kVec , 4 , TypeId::kInt32x1 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecD , RegGroup::kVec , 8 , TypeId::kInt32x2 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_VecQ , RegGroup::kVec , 16, TypeId::kInt32x4 ); // AArch32 & AArch64 +ASMJIT_DEFINE_REG_TRAITS(RegType::kARM_PC , RegGroup::kPC , 8 , TypeId::kInt64 ); // AArch64 +//! \endcond + +//! Register operand that can represent AArch32 and AArch64 registers. +class Reg : public BaseReg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Reg, BaseReg) + + //! Gets whether the register is either `R` or `W` register (32-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpR() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is either `R` or `W` register (32-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpW() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is an `X` register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpX() const noexcept { return baseSignature() == RegTraits::kSignature; } + + //! Gets whether the register is a VEC-B register (8-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecB() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-H register (16-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecH() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-S register (32-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecS() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-D register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecD() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a VEC-Q register (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecQ() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is either VEC-D (64-bit) or VEC-Q (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecDOrQ() const noexcept { return uint32_t(type()) - uint32_t(RegType::kARM_VecD) <= 1u; } + //! Gets whether the register is a VEC-V register (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isVecV() const noexcept { return baseSignature() == RegTraits::kSignature; } + + //! Gets whether the register is an 8-bit vector register or view, alias if \ref isVecB(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec8() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a 16-bit vector register or view, alias if \ref isVecH(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec16() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a 32-bit vector register or view, alias if \ref isVecS(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec32() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a 64-bit vector register or view, alias if \ref isVecD(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec64() const noexcept { return baseSignature() == RegTraits::kSignature; } + //! Gets whether the register is a 128-bit vector register or view, alias if \ref isVecQ(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec128() const noexcept { return baseSignature() == RegTraits::kSignature; } + + template + ASMJIT_INLINE_NODEBUG void setRegT(uint32_t id) noexcept { + setSignature(RegTraits::kSignature); + setId(id); + } + + ASMJIT_INLINE_NODEBUG void setTypeAndId(RegType type, uint32_t id) noexcept { + setSignature(signatureOf(type)); + setId(id); + } + + static ASMJIT_INLINE_NODEBUG RegGroup groupOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kAArch64).regTypeToGroup(type); } + static ASMJIT_INLINE_NODEBUG TypeId typeIdOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kAArch64).regTypeToTypeId(type); } + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kAArch64).regTypeToSignature(type); } + + template + static ASMJIT_INLINE_NODEBUG RegGroup groupOfT() noexcept { return RegTraits::kGroup; } + + template + static ASMJIT_INLINE_NODEBUG TypeId typeIdOfT() noexcept { return RegTraits::kTypeId; } + + template + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfT() noexcept { return OperandSignature{RegTraits::kSignature}; } + + static ASMJIT_INLINE_NODEBUG bool isGpW(const Operand_& op) noexcept { return op.as().isGpW(); } + static ASMJIT_INLINE_NODEBUG bool isGpX(const Operand_& op) noexcept { return op.as().isGpX(); } + static ASMJIT_INLINE_NODEBUG bool isVecB(const Operand_& op) noexcept { return op.as().isVecB(); } + static ASMJIT_INLINE_NODEBUG bool isVecH(const Operand_& op) noexcept { return op.as().isVecH(); } + static ASMJIT_INLINE_NODEBUG bool isVecS(const Operand_& op) noexcept { return op.as().isVecS(); } + static ASMJIT_INLINE_NODEBUG bool isVecD(const Operand_& op) noexcept { return op.as().isVecD(); } + static ASMJIT_INLINE_NODEBUG bool isVecQ(const Operand_& op) noexcept { return op.as().isVecQ(); } + static ASMJIT_INLINE_NODEBUG bool isVecV(const Operand_& op) noexcept { return op.as().isVecV(); } + + static ASMJIT_INLINE_NODEBUG bool isGpW(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isGpW(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isGpX(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isGpX(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecB(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecB(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecH(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecH(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecS(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecS(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecD(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecD(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecQ(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecQ(op)) & unsigned(op.id() == id)); } + static ASMJIT_INLINE_NODEBUG bool isVecV(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVecV(op)) & unsigned(op.id() == id)); } +}; + +//! Vector register base - a common base for both AArch32 & AArch64 vector register. +class BaseVec : public Reg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(BaseVec, Reg) + + //! Additional signature bits used by a vector register. + enum AdditionalBits : uint32_t { + // Register element type (3 bits). + // |........|........|.XXX....|........| + kSignatureRegElementTypeShift = 12, + kSignatureRegElementTypeMask = 0x07 << kSignatureRegElementTypeShift, + + // Register has element index (1 bit). + // |........|........|X.......|........| + kSignatureRegElementFlagShift = 15, + kSignatureRegElementFlagMask = 0x01 << kSignatureRegElementFlagShift, + + // Register element index (4 bits). + // |........|....XXXX|........|........| + kSignatureRegElementIndexShift = 16, + kSignatureRegElementIndexMask = 0x0F << kSignatureRegElementIndexShift + }; + + //! Returns whether the register has element index (it's an element index access). + ASMJIT_INLINE_NODEBUG constexpr bool hasElementIndex() const noexcept { return _signature.hasField(); } + //! Returns element index of the register. + ASMJIT_INLINE_NODEBUG constexpr uint32_t elementIndex() const noexcept { return _signature.getField(); } + //! Sets element index of the register to `elementType`. + ASMJIT_INLINE_NODEBUG void setElementIndex(uint32_t elementIndex) noexcept { + _signature |= kSignatureRegElementFlagMask; + _signature.setField(elementIndex); + } + //! Resets element index of the register. + ASMJIT_INLINE_NODEBUG void resetElementIndex() noexcept { + _signature &= ~(kSignatureRegElementFlagMask | kSignatureRegElementIndexMask); + } +}; + +//! Memory operand (ARM). +class Mem : public BaseMem { +public: + //! \cond INTERNAL + //! Additional bits of operand's signature used by `arm::Mem`. + enum AdditionalBits : uint32_t { + // Index shift value (5 bits). + // |........|.....XXX|XX......|........| + kSignatureMemShiftValueShift = 14, + kSignatureMemShiftValueMask = 0x1Fu << kSignatureMemShiftValueShift, + + // Index shift operation (4 bits). + // |........|XXXX....|........|........| + kSignatureMemShiftOpShift = 20, + kSignatureMemShiftOpMask = 0x0Fu << kSignatureMemShiftOpShift, + + // Offset mode type (2 bits). + // |......XX|........|........|........| + kSignatureMemOffsetModeShift = 24, + kSignatureMemOffsetModeMask = 0x03u << kSignatureMemOffsetModeShift + }; + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Construct a default `Mem` operand, that points to [0]. + ASMJIT_INLINE_NODEBUG constexpr Mem() noexcept + : BaseMem() {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Mem& other) noexcept + : BaseMem(other) {} + + ASMJIT_INLINE_NODEBUG explicit Mem(Globals::NoInit_) noexcept + : BaseMem(Globals::NoInit) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Signature& signature, uint32_t baseId, uint32_t indexId, int32_t offset) noexcept + : BaseMem(signature, baseId, indexId, offset) {} + + ASMJIT_INLINE_NODEBUG constexpr explicit Mem(const Label& base, int32_t off = 0, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(RegType::kLabelTag) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr explicit Mem(const BaseReg& base, int32_t off = 0, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, const BaseReg& index, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromMemIndexType(index.type()) | + signature, base.id(), index.id(), 0) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, const BaseReg& index, const Shift& shift, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromMemIndexType(index.type()) | + Signature::fromValue(uint32_t(shift.op())) | + Signature::fromValue(shift.value()) | + signature, base.id(), index.id(), 0) {} + + ASMJIT_INLINE_NODEBUG constexpr explicit Mem(uint64_t base, Signature signature = Signature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + signature, uint32_t(base >> 32), 0, int32_t(uint32_t(base & 0xFFFFFFFFu))) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Mem& operator=(const Mem& other) noexcept = default; + + //! \} + + //! \name Clone + //! \{ + + //! Clones the memory operand. + ASMJIT_INLINE_NODEBUG constexpr Mem clone() const noexcept { return Mem(*this); } + + //! Gets new memory operand adjusted by `off`. + ASMJIT_INLINE_NODEBUG Mem cloneAdjusted(int64_t off) const noexcept { + Mem result(*this); + result.addOffset(off); + return result; + } + + //! Clones the memory operand and makes it pre-index. + ASMJIT_INLINE_NODEBUG Mem pre() const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPreIndex); + return result; + } + + //! Clones the memory operand, applies a given offset `off` and makes it pre-index. + ASMJIT_INLINE_NODEBUG Mem pre(int64_t off) const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPreIndex); + result.addOffset(off); + return result; + } + + //! Clones the memory operand and makes it post-index. + ASMJIT_INLINE_NODEBUG Mem post() const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPostIndex); + return result; + } + + //! Clones the memory operand, applies a given offset `off` and makes it post-index. + ASMJIT_INLINE_NODEBUG Mem post(int64_t off) const noexcept { + Mem result(*this); + result.setOffsetMode(OffsetMode::kPostIndex); + result.addOffset(off); + return result; + } + + //! \} + + //! \name Base & Index + //! \{ + + //! Converts memory `baseType` and `baseId` to `arm::Reg` instance. + //! + //! The memory must have a valid base register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg baseReg() const noexcept { return Reg::fromTypeAndId(baseType(), baseId()); } + + //! Converts memory `indexType` and `indexId` to `arm::Reg` instance. + //! + //! The memory must have a valid index register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg indexReg() const noexcept { return Reg::fromTypeAndId(indexType(), indexId()); } + + using BaseMem::setIndex; + + ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index, uint32_t shift) noexcept { + setIndex(index); + setShift(shift); + } + + ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index, Shift shift) noexcept { + setIndex(index); + setShift(shift); + } + + //! \} + + //! \name ARM Specific Features + //! \{ + + //! Gets offset mode. + ASMJIT_INLINE_NODEBUG constexpr OffsetMode offsetMode() const noexcept { return OffsetMode(_signature.getField()); } + //! Sets offset mode to `mode`. + ASMJIT_INLINE_NODEBUG void setOffsetMode(OffsetMode mode) noexcept { _signature.setField(uint32_t(mode)); } + //! Resets offset mode to default (fixed offset, without write-back). + ASMJIT_INLINE_NODEBUG void resetOffsetMode() noexcept { _signature.setField(uint32_t(OffsetMode::kFixed)); } + + //! Tests whether the current memory offset mode is fixed (see \ref OffsetMode::kFixed). + ASMJIT_INLINE_NODEBUG constexpr bool isFixedOffset() const noexcept { return offsetMode() == OffsetMode::kFixed; } + //! Tests whether the current memory offset mode is either pre-index or post-index (write-back is used). + ASMJIT_INLINE_NODEBUG constexpr bool isPreOrPost() const noexcept { return offsetMode() != OffsetMode::kFixed; } + //! Tests whether the current memory offset mode is pre-index (write-back is used). + ASMJIT_INLINE_NODEBUG constexpr bool isPreIndex() const noexcept { return offsetMode() == OffsetMode::kPreIndex; } + //! Tests whether the current memory offset mode is post-index (write-back is used). + ASMJIT_INLINE_NODEBUG constexpr bool isPostIndex() const noexcept { return offsetMode() == OffsetMode::kPostIndex; } + + //! Sets offset mode of this memory operand to pre-index (write-back is used). + ASMJIT_INLINE_NODEBUG void makePreIndex() noexcept { setOffsetMode(OffsetMode::kPreIndex); } + //! Sets offset mode of this memory operand to post-index (write-back is used). + ASMJIT_INLINE_NODEBUG void makePostIndex() noexcept { setOffsetMode(OffsetMode::kPostIndex); } + + //! Gets shift operation that is used by index register. + ASMJIT_INLINE_NODEBUG constexpr ShiftOp shiftOp() const noexcept { return ShiftOp(_signature.getField()); } + //! Sets shift operation that is used by index register. + ASMJIT_INLINE_NODEBUG void setShiftOp(ShiftOp sop) noexcept { _signature.setField(uint32_t(sop)); } + //! Resets shift operation that is used by index register to LSL (default value). + ASMJIT_INLINE_NODEBUG void resetShiftOp() noexcept { _signature.setField(uint32_t(ShiftOp::kLSL)); } + + //! Gets whether the memory operand has shift (aka scale) constant. + ASMJIT_INLINE_NODEBUG constexpr bool hasShift() const noexcept { return _signature.hasField(); } + //! Gets the memory operand's shift (aka scale) constant. + ASMJIT_INLINE_NODEBUG constexpr uint32_t shift() const noexcept { return _signature.getField(); } + //! Sets the memory operand's shift (aka scale) constant. + ASMJIT_INLINE_NODEBUG void setShift(uint32_t shift) noexcept { _signature.setField(shift); } + + //! Sets the memory operand's shift and shift operation. + ASMJIT_INLINE_NODEBUG void setShift(Shift shift) noexcept { + _signature.setField(uint32_t(shift.op())); + _signature.setField(shift.value()); + } + + //! Resets the memory operand's shift (aka scale) constant to zero. + ASMJIT_INLINE_NODEBUG void resetShift() noexcept { _signature.setField(0); } + + //! \} +}; + +//! \name Shift Operation Construction +//! \{ + +//! Constructs a `LSL #value` shift (logical shift left). +static ASMJIT_INLINE_NODEBUG constexpr Shift lsl(uint32_t value) noexcept { return Shift(ShiftOp::kLSL, value); } +//! Constructs a `LSR #value` shift (logical shift right). +static ASMJIT_INLINE_NODEBUG constexpr Shift lsr(uint32_t value) noexcept { return Shift(ShiftOp::kLSR, value); } +//! Constructs a `ASR #value` shift (arithmetic shift right). +static ASMJIT_INLINE_NODEBUG constexpr Shift asr(uint32_t value) noexcept { return Shift(ShiftOp::kASR, value); } +//! Constructs a `ROR #value` shift (rotate right). +static ASMJIT_INLINE_NODEBUG constexpr Shift ror(uint32_t value) noexcept { return Shift(ShiftOp::kROR, value); } +//! Constructs a `RRX` shift (rotate with carry by 1). +static ASMJIT_INLINE_NODEBUG constexpr Shift rrx() noexcept { return Shift(ShiftOp::kRRX, 0); } +//! Constructs a `MSL #value` shift (logical shift left filling ones). +static ASMJIT_INLINE_NODEBUG constexpr Shift msl(uint32_t value) noexcept { return Shift(ShiftOp::kMSL, value); } + +//! \} + +//! \name Memory Operand Construction +//! \{ + +//! Creates `[base]` absolute memory operand (AArch32 or AArch64). +//! +//! \note The concept of absolute memory operands doesn't exist on ARM, the ISA only provides PC relative addressing. +//! Absolute memory operands can only be used if it's known that the PC relative offset is encodable and that it +//! would be within the limits. Absolute address is also often output from disassemblers, so AsmJit supports it to +//! make it possible to assemble such output back. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base) noexcept { return Mem(base); } + +//! \} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_ARMOPERAND_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/arm/armutils.h b/phivenv/Lib/site-packages/torch/include/asmjit/arm/armutils.h new file mode 100644 index 0000000000000000000000000000000000000000..f97b97527a5b3c7d6a5e50dd8c7f2fadcb683dc5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/arm/armutils.h @@ -0,0 +1,226 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_ARM_ARMUTILS_H_INCLUDED +#define ASMJIT_ARM_ARMUTILS_H_INCLUDED + +#include "../core/support.h" +#include "../arm/armglobals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(arm) + +//! \addtogroup asmjit_arm +//! \{ + +//! Public utilities and helpers for targeting AArch32 and AArch64 architectures. +namespace Utils { + +//! Encodes a 12-bit immediate part of opcode that ise used by a standard 32-bit ARM encoding. +ASMJIT_MAYBE_UNUSED +static inline bool encodeAArch32Imm(uint64_t imm, uint32_t* encodedImmOut) noexcept { + if (imm & 0xFFFFFFFF00000000u) + return false; + + uint32_t v = uint32_t(imm); + uint32_t r = 0; + + if (v <= 0xFFu) { + *encodedImmOut = v; + return true; + } + + // Rotate if there are bits on both ends (LSB and MSB) + // (otherwise we would not be able to calculate the rotation with ctz). + if (v & 0xFF0000FFu) { + v = Support::ror(v, 16); + r = 16u; + } + + uint32_t n = Support::ctz(v) & ~0x1u; + r = (r - n) & 0x1Eu; + v = Support::ror(v, n); + + if (v > 0xFFu) + return false; + + *encodedImmOut = v | (r << 7); + return true; +} + +//! Decomposed fields of a logical immediate value. +struct LogicalImm { + uint32_t n; + uint32_t s; + uint32_t r; +}; + +//! Encodes the given `imm` value of the given `width` to a logical immediate value represented as N, S, and R fields +//! and writes these fields to `out`. +//! +//! Encoding Table: +//! +//! ``` +//! +---+--------+--------+------+ +//! | N | ImmS | ImmR | Size | +//! +---+--------+--------+------+ +//! | 1 | ssssss | rrrrrr | 64 | +//! | 0 | 0sssss | .rrrrr | 32 | +//! | 0 | 10ssss | ..rrrr | 16 | +//! | 0 | 110sss | ...rrr | 8 | +//! | 0 | 1110ss | ....rr | 4 | +//! | 0 | 11110s | .....r | 2 | +//! +---+--------+--------+------+ +//! ``` +ASMJIT_MAYBE_UNUSED +static bool encodeLogicalImm(uint64_t imm, uint32_t width, LogicalImm* out) noexcept { + // Determine the element width, which must be 2, 4, 8, 16, 32, or 64 bits. + do { + width /= 2; + uint64_t mask = (uint64_t(1) << width) - 1u; + if ((imm & mask) != ((imm >> width) & mask)) { + width *= 2; + break; + } + } while (width > 2); + + // Patterns of all zeros and all ones are not encodable. + uint64_t lsbMask = Support::lsbMask(width); + imm &= lsbMask; + + if (imm == 0 || imm == lsbMask) + return false; + + // Inspect the pattern and get the most important bit indexes. + // + // oIndex <-+ +-> zIndex + // | | + // |..zeros..|oCount|zCount|..ones..| + // |000000000|111111|000000|11111111| + + uint32_t zIndex = Support::ctz(~imm); + uint64_t zImm = imm ^ ((uint64_t(1) << zIndex) - 1); + uint32_t zCount = (zImm ? Support::ctz(zImm) : width) - zIndex; + + uint32_t oIndex = zIndex + zCount; + uint64_t oImm = ~(zImm ^ Support::lsbMask(oIndex)); + uint32_t oCount = (oImm ? Support::ctz(oImm) : width) - (oIndex); + + // Verify whether the bit-pattern is encodable. + uint64_t mustBeZero = oImm ^ ~Support::lsbMask(oIndex + oCount); + if (mustBeZero != 0 || (zIndex > 0 && width - (oIndex + oCount) != 0)) + return false; + + out->n = width == 64; + out->s = (oCount + zIndex - 1) | (Support::neg(width * 2) & 0x3F); + out->r = width - oIndex; + return true; +} + +//! Returns true if the given `imm` value is encodable as a logical immediate. The `width` argument describes the +//! width of the operation, and must be either 32 or 64. This function can be used to test whether an immediate +//! value can be used with AND, ANDS, BIC, BICS, EON, EOR, ORN, and ORR instruction. +ASMJIT_MAYBE_UNUSED +static ASMJIT_INLINE_NODEBUG bool isLogicalImm(uint64_t imm, uint32_t width) noexcept { + LogicalImm dummy; + return encodeLogicalImm(imm, width, &dummy); +} + +//! Returns true if the given `imm` value is encodable as an immediate with `add` and `sub` instructions on AArch64. +//! These two instructions can encode 12-bit immediate value optionally shifted left by 12 bits. +ASMJIT_MAYBE_UNUSED +static ASMJIT_INLINE_NODEBUG bool isAddSubImm(uint64_t imm) noexcept { + return imm <= 0xFFFu || (imm & ~uint64_t(0xFFFu << 12)) == 0; +} + +//! Returns true if the given `imm` value is a byte mask. Byte mask has each byte part of the value set to either +//! 0x00 or 0xFF. Some ARM instructions accept immediates that form a byte-mask and this function can be used to +//! verify that the immediate is encodable before using the value. +template +static ASMJIT_INLINE_NODEBUG bool isByteMaskImm8(const T& imm) noexcept { + constexpr T kMask = T(0x0101010101010101 & Support::allOnes()); + return imm == (imm & kMask) * T(255); +} + +// [.......A|B.......|.......C|D.......|.......E|F.......|.......G|H.......] +static ASMJIT_INLINE_NODEBUG uint32_t encodeImm64ByteMaskToImm8(uint64_t imm) noexcept { + return uint32_t(((imm >> (7 - 0)) & 0b00000011) | // [.......G|H.......] + ((imm >> (23 - 2)) & 0b00001100) | // [.......E|F.......] + ((imm >> (39 - 4)) & 0b00110000) | // [.......C|D.......] + ((imm >> (55 - 6)) & 0b11000000)); // [.......A|B.......] +} +//! \cond +//! A generic implementation that checjs whether a floating point value can be converted to ARM Imm8. +template +static ASMJIT_FORCE_INLINE bool isFPImm8Generic(T val) noexcept { + constexpr uint32_t kAllBsMask = Support::lsbMask(kNumBBits); + constexpr uint32_t kB0Pattern = Support::bitMask(kNumBBits - 1); + constexpr uint32_t kB1Pattern = kAllBsMask ^ kB0Pattern; + + T immZ = val & Support::lsbMask(kNumZeroBits); + uint32_t immB = uint32_t(val >> (kNumZeroBits + kNumCDEFGHBits)) & kAllBsMask; + + // ImmZ must be all zeros and ImmB must either be B0 or B1 pattern. + return immZ == 0 && (immB == kB0Pattern || immB == kB1Pattern); +} +//! \endcond + +//! Returns true if the given half precision floating point `val` can be encoded as ARM IMM8 value, which represents +//! a limited set of floating point immediate values, which can be used with FMOV instruction. +//! +//! The floating point must have bits distributed in the following way: +//! +//! ``` +//! [aBbbcdef|gh000000] +//! ``` +static ASMJIT_INLINE_NODEBUG bool isFP16Imm8(uint32_t val) noexcept { return isFPImm8Generic(val); } + +//! Returns true if the given single precision floating point `val` can be encoded as ARM IMM8 value, which represents +//! a limited set of floating point immediate values, which can be used with FMOV instruction. +//! +//! The floating point must have bits distributed in the following way: +//! +//! ``` +//! [aBbbbbbc|defgh000|00000000|00000000] +//! ``` +static ASMJIT_INLINE_NODEBUG bool isFP32Imm8(uint32_t val) noexcept { return isFPImm8Generic(val); } +//! \overload +static ASMJIT_INLINE_NODEBUG bool isFP32Imm8(float val) noexcept { return isFP32Imm8(Support::bitCast(val)); } + +//! Returns true if the given double precision floating point `val` can be encoded as ARM IMM8 value, which represents +//! a limited set of floating point immediate values, which can be used with FMOV instruction. +//! +//! The floating point must have bits distributed in the following way: +//! +//! ``` +//! [aBbbbbbb|bbcdefgh|00000000|00000000|00000000|00000000|00000000|00000000] +//! ``` +static ASMJIT_INLINE_NODEBUG bool isFP64Imm8(uint64_t val) noexcept { return isFPImm8Generic(val); } +//! \overload +static ASMJIT_INLINE_NODEBUG bool isFP64Imm8(double val) noexcept { return isFP64Imm8(Support::bitCast(val)); } + +//! \cond +template +static ASMJIT_INLINE_NODEBUG uint32_t encodeFPToImm8Generic(T val) noexcept { + uint32_t bits = uint32_t(val >> kNumZeroBits); + return ((bits >> (kNumBBits + kNumCDEFGHBits - 7)) & 0x80u) | (bits & 0x7F); +} +//! \endcond + +//! Encodes a double precision floating point value into IMM8 format. +//! +//! \note This function expects that `isFP64Imm8(val) == true` so it doesn't perform any checks of the value and just +//! rearranges some bits into Imm8 order. +static ASMJIT_INLINE_NODEBUG uint32_t encodeFP64ToImm8(uint64_t val) noexcept { return encodeFPToImm8Generic(val); } +//! \overload +static ASMJIT_INLINE_NODEBUG uint32_t encodeFP64ToImm8(double val) noexcept { return encodeFP64ToImm8(Support::bitCast(val)); } + +} // {Utils} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_ARM_ARMUTILS_H_INCLUDED + diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/asmjit-scope-begin.h b/phivenv/Lib/site-packages/torch/include/asmjit/asmjit-scope-begin.h new file mode 100644 index 0000000000000000000000000000000000000000..6c292260b4a5a8d02b3333f06e9b404d2dfff2c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/asmjit-scope-begin.h @@ -0,0 +1,17 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifdef _WIN32 + #pragma push_macro("min") + #pragma push_macro("max") + + #ifdef min + #undef min + #endif + + #ifdef max + #undef max + #endif +#endif diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/asmjit-scope-end.h b/phivenv/Lib/site-packages/torch/include/asmjit/asmjit-scope-end.h new file mode 100644 index 0000000000000000000000000000000000000000..0baec9ce22882f35b9d6d791b606a725b2f81e4e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/asmjit-scope-end.h @@ -0,0 +1,9 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifdef _WIN32 + #pragma pop_macro("min") + #pragma pop_macro("max") +#endif diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/asmjit.h b/phivenv/Lib/site-packages/torch/include/asmjit/asmjit.h new file mode 100644 index 0000000000000000000000000000000000000000..2e82d507ac7069bc068c3c694f91a8ecf285b732 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/asmjit.h @@ -0,0 +1,33 @@ +// This file is part of AsmJit project +// +// SPDX-License-Identifier: Zlib +// Official GitHub Repository: https://github.com/asmjit/asmjit +// +// Copyright (c) 2008-2024 The AsmJit Authors +// +// This software is provided 'as-is', without any express or implied +// warranty. In no event will the authors be held liable for any damages +// arising from the use of this software. +// +// Permission is granted to anyone to use this software for any purpose, +// including commercial applications, and to alter it and redistribute it +// freely, subject to the following restrictions: +// +// 1. The origin of this software must not be misrepresented; you must not +// claim that you wrote the original software. If you use this software +// in a product, an acknowledgment in the product documentation would be +// appreciated but is not required. +// 2. Altered source versions must be plainly marked as such, and must not be +// misrepresented as being the original software. +// 3. This notice may not be removed or altered from any source distribution. + +#ifndef ASMJIT_ASMJIT_H_INCLUDED +#define ASMJIT_ASMJIT_H_INCLUDED + +#include "./core.h" + +#ifndef ASMJIT_NO_X86 + #include "./x86.h" +#endif + +#endif // ASMJIT_ASMJIT_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core.h b/phivenv/Lib/site-packages/torch/include/asmjit/core.h new file mode 100644 index 0000000000000000000000000000000000000000..ea3542eadd17676d28828c81d54d1e9a723b8ccf --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core.h @@ -0,0 +1,1991 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_H_INCLUDED +#define ASMJIT_CORE_H_INCLUDED + +//! Root namespace used by AsmJit. +namespace asmjit { + +//! \mainpage API Reference +//! +//! AsmJit C++ API reference documentation generated by Doxygen. +//! +//! AsmJit library uses one global namespace called \ref asmjit, which provides the whole functionality. Core +//! functionality is within \ref asmjit namespace and architecture specific functionality is always in its own +//! namespace. For example \ref asmjit::x86 provides both 32-bit and 64-bit X86 code generation. +//! +//! \section main_groups Documentation Groups +//! +//! AsmJit documentation is structured into groups. Groups can be followed in order to learn AsmJit, but knowledge +//! from multiple groups is required to use AsmJit properly: +//! +//! $$DOCS_GROUP_OVERVIEW$$ +//! +//! \note It's important to understand that in order to learn AsmJit all groups are important. Some groups can be +//! omitted if a particular tool is out of interest - for example \ref asmjit_assembler users don't need to know +//! about \ref asmjit_builder, but it's not the opposite. \ref asmjit_builder users should know about \ref +//! asmjit_assembler as it also uses operands, labels, and other concepts. Similarly \ref asmjit_compiler users +//! should know how both \ref asmjit_assembler and \ref asmjit_builder tools work. +//! +//! \section where_to_start Where To Start +//! +//! AsmJit \ref asmjit_core provides the following two classes that are essential from the code generation perspective: +//! +//! - \ref CodeHolder provides functionality to temporarily hold the generated code. It stores all the necessary +//! information about the code - code buffers, sections, labels, symbols, and information about relocations. +//! +//! - \ref BaseEmitter provides interface used by emitter implementations. The interface provides basic building +//! blocks that are then implemented by \ref BaseAssembler, \ref BaseBuilder, and \ref BaseCompiler. +//! +//! Code emitters: +//! +//! - \ref asmjit_assembler - provides direct machine code generation. +//! +//! - \ref asmjit_builder - provides intermediate code generation that can be processed before it's serialized to +//! \ref BaseAssembler. +//! +//! - \ref asmjit_compiler - provides high-level code generation with built-in register allocation. +//! +//! - \ref FuncNode - provides insight into how function looks from the Compiler perspective and how it's stored in +//! a node-list. +//! +//! \section main_recommendations Recommendations +//! +//! The following steps are recommended for all AsmJit users: +//! +//! - Make sure that you use \ref Logger, see \ref asmjit_logging. +//! +//! - Make sure that you use \ref ErrorHandler, see \ref asmjit_error_handling. +//! +//! - Instruction validation in your debug builds can reveal problems too. AsmJit provides validation at instruction +//! level that can be enabled via \ref BaseEmitter::addDiagnosticOptions(). See \ref DiagnosticOptions for more +//! details. +//! +//! - If you are a Compiler user, use diagnostic options and read carefully if anything suspicious pops out. +//! Diagnostic options can be enabled via \ref BaseEmitter::addDiagnosticOptions(). If unsure which ones to use, +//! enable annotations and all debug options: `DiagnosticOptions::kRAAnnotate | DiagnosticOptions::kRADebugAll`. +//! +//! - Make sure you put a breakpoint into \ref DebugUtils::errored() function if you have a problem with AsmJit +//! returning errors during instruction encoding or register allocation. Having an active breakpoint there can +//! help to reveal the origin of the error, to inspect variables and other conditions that caused it. +//! +//! The reason for using \ref Logger and \ref ErrorHandler is that they provide a very useful information about what's +//! happening inside emitters. In many cases the information provided by these two is crucial to quickly identify and +//! fix issues that happen during development (for example wrong instruction, address, or register used). In addition, +//! output from \ref Logger is always necessary when filling bug reports. In other words, using logging and proper error +//! handling can save a lot of time during the development and can also save users from submitting issues. +//! +//! \section main_other Other Pages +//! +//! - Class List - List of classes sorted alphabetically +//! - AsmJit Namespace - List of symbols provided by `asmjit` namespace + + +//! \defgroup asmjit_build Build Instructions +//! \brief Build instructions, supported environments, and feature selection. +//! +//! ### Overview +//! +//! AsmJit is designed to be easy embeddable in any project. However, it depends on some compile-time definitions that +//! can be used to enable or disable features to decrease the resulting binary size. A typical way of building AsmJit +//! is to use [cmake](https://www.cmake.org), but it's also possible to just include AsmJit source code in your project +//! and to just build it. The easiest way to include AsmJit in your project is to just include **src** directory in +//! your project and to define \ref ASMJIT_STATIC. AsmJit can be just updated from time to time without any changes to +//! this integration process. Do not embed AsmJit's `test` files in such case as these are used exclusively for testing. +//! +//! ### Supported C++ Compilers +//! +//! - Requirements: +//! +//! - AsmJit won't build without C++11 enabled. If you use older GCC or Clang you would have to enable at least +//! C++11 standard through compiler flags. +//! +//! - Tested: +//! +//! - **Clang** - Tested by GitHub Actions - Clang 10+ is officially supported and tested by CI, older Clang versions +//! having C++11 should work, but are not tested anymore due to upgraded CI images. +//! +//! - **GNU** - Tested by GitHub Actions - GCC 7+ is officially supported, older GCC versions from 4.8+ having C++11 +//! enabled should also work, but are not tested anymore due to upgraded CI images. +//! +//! - **MINGW** - Reported to work, but not tested in our CI environment (help welcome). +//! +//! - **MSVC** - Tested by GitHub Actions - VS2019+ is officially supported, VS2015 and VS2017 is reported to work, +//! but not tested by CI anymore. +//! +//! ### Supported Operating Systems and Platforms +//! +//! - Tested: +//! +//! - **BSD** - FreeBSD, NetBSD, and OpenBSD tested by GitHub Actions (only recent images are tested by CI). BSD +//! runners only test BSD images with clang compiler. +//! +//! - **Linux** - Tested by GitHub Actions (only recent Ubuntu images are tested by CI, in general any distribution +//! should be supported as AsmJit has no dependencies). +//! +//! - **Mac OS** - Tested by GitHub Actions. +//! +//! - **Windows** - Tested by GitHub Actions - (Windows 7+ is officially supported). +//! +//! - **Emscripten** - Works if compiled with \ref ASMJIT_NO_JIT. AsmJit cannot generate WASM code, but can be +//! used to generate X86/X64/AArch64 code within a browser, for example. +//! +//! - Untested: +//! +//! - **Haiku** - Reported to work, not tested by CI. +//! +//! - **Other** operating systems would require some testing and support in the following files: +//! - [core/api-config.h](https://github.com/asmjit/asmjit/tree/master/src/asmjit/core/api-config.h) +//! - [core/osutils.cpp](https://github.com/asmjit/asmjit/tree/master/src/asmjit/core/osutils.cpp) +//! - [core/virtmem.cpp](https://github.com/asmjit/asmjit/tree/master/src/asmjit/core/virtmem.cpp) +//! +//! ### Supported Backends / Architectures +//! +//! - **X86** and **X86_64** - Both 32-bit and 64-bit backends tested on CI. +//! - **AArch64** - AArch64 backend is currently only partially tested (there is no native AArch64 runner to test +//! AsmJit Builder/Compiler). +//! +//! ### Static Builds and Embedding +//! +//! These definitions can be used to enable static library build. Embed is used when AsmJit's source code is embedded +//! directly in another project, implies static build as well. +//! +//! - \ref ASMJIT_EMBED - Asmjit is embedded, implies \ref ASMJIT_STATIC. +//! - \ref ASMJIT_STATIC - Enable static-library build. +//! +//! \note Projects that use AsmJit statically must define \ref ASMJIT_STATIC in all compilation units that use AsmJit, +//! otherwise AsmJit would use dynamic library imports in \ref ASMJIT_API decorator. The recommendation is to define +//! this macro across the whole project that uses AsmJit this way. +//! +//! ### Build Configuration +//! +//! These definitions control whether asserts are active or not. By default AsmJit would autodetect build configuration +//! from existing pre-processor definitions, but this behavior can be overridden, for example to enable debug asserts +//! in release configuration. +//! +//! - \ref ASMJIT_BUILD_DEBUG - Overrides build configuration to debug, asserts will be enabled in this case. +//! - \ref ASMJIT_BUILD_RELEASE - Overrides build configuration to release, asserts will be disabled in this case. +//! +//! \note There is usually no need to override the build configuration. AsmJit detects the build configuration by +//! checking whether `NDEBUG` is defined and automatically defines \ref ASMJIT_BUILD_RELEASE if configuration overrides +//! were not used. We only recommend using build configuration overrides in special situations, like using AsmJit in +//! release configuration with asserts enabled for whatever reason. +//! +//! ### AsmJit Backends +//! +//! AsmJit currently supports only X86/X64 backend, but the plan is to add more backends in the future. By default +//! AsmJit builds only the host backend, which is auto-detected at compile-time, but this can be overridden. +//! +//! - \ref ASMJIT_NO_X86 - Disables both X86 and X86_64 backends. +//! - \ref ASMJIT_NO_AARCH64 - Disables AArch64 backend. +//! - \ref ASMJIT_NO_FOREIGN - Disables the support for foreign architecture backends, only keeps a native backend. +//! +//! ### AsmJit Compilation Options +//! +//! - \ref ASMJIT_NO_DEPRECATED - Disables deprecated API at compile time so it won't be available and the +//! compilation will fail if there is attempt to use such API. This includes deprecated classes, namespaces, +//! enumerations, and functions. +//! +//! - \ref ASMJIT_NO_SHM_OPEN - Disables functionality that uses `shm_open()`. +//! +//! - \ref ASMJIT_NO_ABI_NAMESPACE - Disables inline ABI namespace within `asmjit` namespace. This is only provided +//! for users that control all the dependencies (even transitive ones) and that make sure that no two AsmJit +//! versions are used at the same time. This option can be debugging a little simpler as there would not be ABI +//! tag after `asmjit::` namespace. Otherwise asmjit would look like `asmjit::_abi_1_13::`, for example. +//! +//! ### Features Selection +//! +//! AsmJit builds by defaults all supported features, which includes all emitters, logging, instruction validation and +//! introspection, and JIT memory allocation. Features can be disabled at compile time by using `ASMJIT_NO_...` +//! definitions. +//! - \ref ASMJIT_NO_JIT - Disables JIT memory management and \ref JitRuntime. +//! +//! - \ref ASMJIT_NO_TEXT - Disables everything that contains string representation of AsmJit constants, should +//! be used together with \ref ASMJIT_NO_LOGGING as logging doesn't make sense without the ability to query +//! instruction names, register names, etc... +//! +//! - \ref ASMJIT_NO_LOGGING - Disables \ref Logger and \ref Formatter. +//! +//! - \ref ASMJIT_NO_VALIDATION - Disables validation API. +//! +//! - \ref ASMJIT_NO_INTROSPECTION - Disables instruction introspection API, must be used together with \ref +//! ASMJIT_NO_COMPILER as \ref asmjit_compiler requires introspection for its liveness analysis and register +//! allocation. +//! +//! - \ref ASMJIT_NO_BUILDER - Disables \ref asmjit_builder functionality completely. This implies \ref +//! ASMJIT_NO_COMPILER as \ref asmjit_compiler cannot be used without \ref asmjit_builder. +//! +//! - \ref ASMJIT_NO_COMPILER - Disables \ref asmjit_compiler functionality completely. +//! +//! \note It's not recommended to disable features if you plan to build AsmJit as a shared library that will be +//! used by multiple projects that you don't control how AsmJit was built (for example AsmJit in a Linux distribution). +//! The possibility to disable certain features exists mainly for customized AsmJit builds. + + +//! \defgroup asmjit_breaking_changes Breaking Changes +//! \brief Documentation of breaking changes +//! +//! ### Overview +//! +//! AsmJit is a live project that is being actively developed. Deprecating the existing API in favor of a new +//! one is preferred, but it's not always possible if the changes are significant. AsmJit authors prefer to do +//! accumulated breaking changes at once instead of breaking the API often. This page documents deprecated and +//! removed APIs and should serve as a how-to guide for people that want to port existing code to work with the +//! newest AsmJit. +//! +//! ### Tips +//! +//! Useful tips before you start: +//! +//! - Visit our [Public Gitter Chat](https://app.gitter.im/#/room/#asmjit:gitter.im) if you need a quick help. +//! +//! - Build AsmJit with `ASMJIT_NO_DEPRECATED` macro defined to make sure that you are not using deprecated +//! functionality at all. Deprecated functions are decorated with `ASMJIT_DEPRECATED()` macro, but sometimes +//! it's not possible to decorate everything like classes, which are used by deprecated functions as well, +//! because some compilers would warn about that. If your project compiles fine with `ASMJIT_NO_DEPRECATED` +//! it's not using anything, which was deprecated. +//! +//! ### Changes committed at 2024-01-01 +//! +//! Core changes: +//! +//! - Renamed equality functions `eq()` to `equals()` - Only related to `String`, `ZoneVector`, and `CpuFeatures`. +//! Old function names were deprecated. +//! +//! - Removed `CallConvId::kNone` in favor of `CallConvId::kCDecl`, which is now the default calling convention. +//! +//! - Deprecated `CallConvId::kHost` in favor of `CallConvId::kCDecl` - host calling convention is now not part +//! of CallConvId, it can be calculated from CallConvId and Environment instead. +//! +//! ### Changes committed at 2023-12-27 +//! +//! Core changes: +//! +//! - Renamed `a64::Vec::ElementType` to `a64::VecElementType` and made it a typed enum. This enum was used mostly +//! internally, but there is a public API using it, so it's a breaking change. +//! +//! - Refactored `FuncSignature`, `FuncSignatureT`, and `FuncSignatureBuilder`. There is only `FuncSignature` now, +//! which acts as a function signature holder and builder. Replace `FuncSignatureBuilder` with `FuncSignature` +//! and use `FuncSignature::build` instead of `FuncSignatureT`. The old API has been deprecated. +//! +//! - The maximum number of function arguments was raised from 16 to 32. +//! +//! ### Changes committed at 2023-12-26 +//! +//! Core changes: +//! +//! - Reworked InstNode and InstExNode to be friendlier to static analysis and to not cause undefined behavior. +//! InstNode has no operands visually embedded within the struct so there is no _opArray (which was internal). +//! This means that sizeof(InstNode) changed, but since it's allocated by AsmJit this should be fine. Moreover, +//! there is no longer InstExNode as that was more a hack, instead there is now InstNodeWithOperands, which is +//! a template and specifies the number of operands embedded (InstNode accesses these). All nodes that inherited +//! InstExNode now just inherit InstNodeWithOperands, which would provide the same +//! number of nodes as InstNode. +//! +//! - Moved GP and Vec registers from asmjit::arm namespace to asmjit::a64 namespace. At this time there was +//! no prior deprecation as having arm::Vec would collide with a64::Vec as arm namespace is used within a64 +//! namespace. Just change `arm::Gp` to `a64::Gp` and `arm::Vec` to `a64::Vec`. +//! +//! ### Changes committed at 2023-09-10 +//! +//! Core changes: +//! +//! - Changed allocation API to work with spans (JitAllocator). +//! +//! - This change is required to support more hardened platforms in the future that make it very difficult +//! to write JIT compilers. +//! - `JitAllocator::Span` now represents a memory that the user can access. It abstracts both regular and +//! dual mappings. +//! - The `Span` is mostly designed to make it possible to write into it, so in general the read+execute +//! pointer is what user is intended to keep. Use `span.rx()` to access RX pointer. `Span` is not needed +//! after the memory it references has been modified, only remember `span.rx()` pointer, which is then +//! used to deallocate or change the memory the span references. +//! - Use a new `JitAllocator::alloc()` to allocate a `Span`, then pass the populated Span to `JitAllocator` +//! write API such as `JitAllocator::write()` - note that JitAllocator can also establish a scope, so you +//! can use a lambda function that would perform the write, but since it's going through JitAllocator it's +//! able to ensure that the memory is actually writable. +//! - If you need to repopulate a `Span` from rx pointer, use `JitAllocator::query(, rx)` to get it. +//! - Study what JitRuntime is doing to better understand how this new API works in detail. +//! - Users of JitRuntime do not have to do anything as JitRuntime properly abstracts the allocation. +//! +//! - Renamed some X86 CPU features to make them compatible with architecture manuals: +//! +//! - Changed `AVX512_CDI` to `AVX512_CD`. +//! - Changed `AVX512_ERI` to `AVX512_ER`. +//! - Changed `AVX512_PFI` to `AVX512_PF`. +//! +//! - Old names were deprecated. +//! +//! ### Changes committed at 2021-12-13 +//! +//! Core changes: +//! +//! - Removed old deprecated API. +//! +//! - Many enumerations were changed to enum class, and many public APIs were changed to use such enums instead +//! of uint32_t. This change makes some APIs backward incompatible - there are no deprecations this time. +//! +//! - Extracted operand signature manipulation to `OperandSignature`. +//! - Setting function arguments through `Compiler::setArg()` was deprecated, use FuncNode::setArg() instead. +//! - Moved `{arch}::Features::k` to `CpuFeatures::{arch}::k`. +//! - Moved `BaseEmitter::kEncodingOption` to `EncodingOptions::k`. +//! - Moved `BaseEmitter::kFlag` to `EmitterFlags::k`. +//! - Moved `BaseEmitter::kType` to `EmitterType::k`. +//! - Moved `BaseEmitter::kValidationOption` to `DiagnosticOptions::kValidate`. +//! - Moved `BaseFeatures` to `CpuFeatures`. +//! - Moved `BaseInst::kControl` to `InstControlFlow::k`. +//! - Moved `BaseInst::kOption` and `x86::Inst::kOption` to `InstOptions::k`. +//! - Moved `BaseNode::kNode` to `NodeType::k`. +//! - Moved `BaseReg::kGroup` and `x86::Reg::kGroup` to `RegGroup::k`. +//! - Moved `BaseReg::kType` and `x86::Reg::kType` to `RegType::k`. +//! - Moved `CallConv::kFlag` to `CallConvFlags::k`. +//! - Moved `CallConv::kId` to `CallConvId::k`. +//! - Moved `CallConv::kStrategy` to `CallConvStrategy::k`. +//! - Moved `CodeBuffer::kFlag` to `CodeBufferFlags`. +//! - Moved `ConstPool::kScope` to `ConstPoolScope::k`. +//! - Moved `Environment::kArch` to `Arch::k`. +//! - Moved `Environment::kSubArch` to `SubArch::k`. +//! - Moved `Environment::kFormat` to `OjectFormat::k`. +//! - Moved `Environment::kPlatform` to `Platform::k`. +//! - Moved `Environment::kAbi` to `PlatformABI::k`. +//! - Moved `Environment::kVendor` to `Vendor::k`. +//! - Moved `FormatOptions::kFlag` to `FormatFlags::k` and `DiagnosticOptions::k` (Compiler diagnostics flags). +//! - Moved `FormatOptions::kIndentation` to `FormatIndentationGroup::k`. +//! - Moved `FuncFrame::kAttr` to `FuncAttributes::k`. +//! - Moved `Globals::kReset` to `ResetPolicy::k`. +//! - Moved `InstDB::kAvx512Flag` to `InstDB::Avx512Flags::k`. +//! - Moved `InstDB::kFlag` to `InstDB::InstFlags::k`. +//! - Moved `InstDB::kMemFlag` to `InstDB::OpFlags::kMem`. +//! - Moved `InstDB::kMode` to `InstDB::Mode::k`. +//! - Moved `InstDB::kOpFlag` to `InstDB::OpFlags::k{OpType}...`. +//! - Moved `JitAllocator::kOption` to `JitAllocatorOptions::k`. +//! - Moved `Label::kType` to `LabelType::k`. +//! - Moved `Operand::kOpType` to `OperandType::k`. +//! - Moved `OpRWInfo::kFlag` to `OpRWFlags::k`. +//! - Moved `Type::kId` to `TypeId::k`. +//! - Moved `VirtMem::k` to `VirtMem::MemoryFlags::k`. +//! +//! ### Changes committed at 2020-05-30 +//! +//! AsmJit has been cleaned up significantly, many todo items have been fixed and many functions and classes have +//! been redesigned, some in an incompatible way. +//! +//! Core changes: +//! +//! - `Imm` operand has now only `Imm::value()` and `Imm::valueAs()` functions that return its value content, +//! and `Imm::setValue()` function that sets the content. Functions like `setI8()`, `setU8()` were deprecated. +//! +//! Old functions were deprecated, but code using them should still compile. +//! +//! - `ArchInfo` has been replaced with `Environment`. Environment provides more details about the architecture, +//! but drops some properties that were used by arch info - `gpSize(`) and `gpCount()`. `gpSize()` can be replaced +//! with `registerSize()` getter, which returns a native register size of the architecture the environment uses. +//! However, `gpCount()` was removed - at the moment `ArchTraits` can be used to access such properties. +//! +//! Some other functions were renamed, like `ArchInfo::isX86Family()` is now `Environment::isFamilyX86()`, etc. +//! The reason for changing the order was support for more properties and all the accessors now start with the +//! type of the property, like `Environment::isPlatformWindows()`. +//! +//! This function causes many other classes to provide `environment()` getter instead of `archInfo()` getter. +//! In addition, AsmJit now uses `arch()` to get an architecture instead of `archId()`. `ArchInfo::kIdXXX` was +//! renamed to `Environment::kArchXXX`. +//! +//! Some functions were deprecated, some removed... +//! +//! - `CodeInfo` has been removed in favor of `Environment`. If you used `CodeInfo` to set architecture and base +//! address, this is now possible with `Environment` and setting base address explicitly by `CodeHolder::init()` +//! - the first argument is `Environment`, and the second argument is base address, which defaults to +//! `Globals::kNoBaseAddress`. +//! +//! CodeInfo class was deprecated, but the code using it should still compile with warnings. +//! +//! - `CallConv` has been updated to offer a more unified way of representing calling conventions - many calling +//! conventions were abstracted to follow standard naming like `CallConvId::kCDecl` or `CallConvId::kStdCall`. +//! +//! This change means that other APIs like `FuncDetail::init()` now require both, calling convention and target +//! `Environment`. +//! +//! - `Logging` namespace has been renamed to `Formatter`, which now provides general functionality for formatting +//! in AsmJit. +//! +//! Logging namespace should still work, but its use is deprecated. Unfortunately this will be without deprecation +//! warnings, so make sure you don't use it. +//! +//! - `Data64`, `Data128`, and `Data256` structs were deprecated and should no longer be used. There is no replacement, +//! AsmJit users should simply create their own structures if they need them or use the new repeated embed API in +//! emitters, see `BaseEmitter::embedDataArray()`. +//! +//! Emitter changes: +//! +//! - `BaseEmitter::emit()` function signature has been changed to accept 3 operands by reference and the rest 3 +//! operands as a continuous array. This change is purely cosmetic and shouldn't affect users as emit() has many +//! overloads that dispatch to the right function. +//! +//! - `x86::Emitter` (Assembler, Builder, Compiler) deprecates embed utilities like `dint8()`, `duint8()`, `duint16()`, +//! `dxmm()`, etc... in favor of a new and more powerful `BaseEmitter::embedDataArray()`. This function also allows +//! emitting repeated values and/or patterns, which is used by helpers `BaseEmitter::embedUInt8()`, and others... +//! +//! - Validation is now available through `BaseEmitter::DiagnosticOptions`, which can be enabled/disabled through +//! `BaseEmitter::addDiagnosticOptions()` and `BaseEmitter::clearDiagnosticOptions()`, respectively. Validation +//! options now separate between encoding and Builder/Compiler so it's possible to choose the granularity required. +//! +//! Builder changes: +//! +//! - Internal functions for creating nodes were redesigned. They now accept a pointer to the node created as +//! a first parameter. These changes should not affect AsmJit users as these functions were used internally. +//! +//! Compiler changes: +//! +//! - `FuncCallNode` has been renamed to `InvokeNode`. Additionally, function calls should now use +//! `x86::Compiler::invoke()` instead of `call()`. The reason behind this is to remove the confusion between a +//! `call` instruction and AsmJit's `call()` intrinsic, which is now `invoke()`. +//! +//! - Creating new nodes also changed. Now the preferred way of invoking a function is to call +//! `x86::Compiler::invoke()` where the first argument is `InvokeNode**`. The function now returns an error and +//! would call `ErrorHandler` in case of a failure. Error handling was unspecified in the past - the function was +//! marked noexcept, but called error handler, which could throw. +//! +//! The reason behind this change is to make the API consistent with other changes and to also make it possible +//! to inspect the possible error. In the previous API it returned a new node or `nullptr` in case of error, +//! which the user couldn't inspect unless there was an attached `ErrorHandler`. +//! +//! Samples: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! // The basic setup of JitRuntime and CodeHolder changed, use environment() +//! // instead of codeInfo(). +//! void basicSetup() { +//! JitRuntime rt; +//! CodeHolder code(rt.environment()); +//! } +//! +//! // Calling a function (Compiler) changed - use invoke() instead of call(). +//! void functionInvocation(x86::Compiler& cc) { +//! InvokeNode* invokeNode; +//! cc.invoke(&invokeNode, targetOperand, FuncSignature::build<...>(...)); +//! } +//! ``` + + +//! \defgroup asmjit_core Core +//! \brief Globals, code storage, and emitter interface. +//! +//! ### Overview +//! +//! AsmJit library uses \ref CodeHolder to hold code during code generation and emitters inheriting from \ref +//! BaseEmitter to emit code. CodeHolder uses containers to manage its data: +//! +//! - \ref Section - stores information about a code or data section. +//! - \ref CodeBuffer - stores actual code or data, part of \ref Section. +//! - \ref LabelEntry - stores information about a label - its name, offset, section where it belongs to, and +//! other bits. +//! - \ref LabelLink - stores information about yet unbound label, which was already used by the assembler. +//! - \ref RelocEntry - stores information about a relocation. +//! - \ref AddressTableEntry - stores information about an address, which was used in a jump or call. Such +//! address may need relocation. +//! +//! To generate code you would need to instantiate at least the following classes: +//! +//! - \ref CodeHolder - to hold code during code generation. +//! - \ref BaseEmitter - to emit code into \ref CodeHolder. +//! - \ref Target (optional) - most likely \ref JitRuntime to keep the generated code in executable memory. \ref +//! Target can be customized by inheriting from it. +//! +//! There are also other core classes that are important: +//! +//! - \ref Environment - describes where the code will run. Environment brings the concept of target triples or +//! tuples into AsmJit, which means that users can specify target architecture, platform, and ABI. +//! - \ref TypeId - encapsulates lightweight type functionality that can be used to describe primitive and vector +//! types. Types are used by higher level utilities, for example by \ref asmjit_function and \ref asmjit_compiler. +//! - \ref CpuInfo - encapsulates CPU information - stores both CPU information and CPU features described by \ref +//! CpuFeatures. +//! +//! AsmJit also provides global constants: +//! +//! - \ref Globals - namespace that provides global constants. +//! - \ref ByteOrder - byte-order constants and functionality. +//! +//! \note CodeHolder examples use \ref x86::Assembler as abstract interfaces cannot be used to generate code. +//! +//! ### CodeHolder & Emitters +//! +//! The example below shows how the mentioned classes interact to generate X86 code: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef int (*Func)(void); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! +//! CodeHolder code; // Holds code and relocation information. +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! a.mov(x86::eax, 1); // Move one to eax register. +//! a.ret(); // Return from function. +//! // ===== x86::Assembler is no longer needed from here and can be destroyed ===== +//! +//! Func fn; // Holds address to the generated function. +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ===== CodeHolder is no longer needed from here and can be destroyed ===== +//! +//! int result = fn(); // Execute the generated code. +//! printf("%d\n", result); // Print the resulting "1". +//! +//! // All classes use RAII, all resources will be released before `main()` returns, +//! // the generated function can be, however, released explicitly if you intend to +//! // reuse or keep the runtime alive, which you should in a production-ready code. +//! rt.release(fn); +//! +//! return 0; +//! } +//! ``` +//! +//! The example above used \ref x86::Assembler as an emitter. AsmJit provides the following emitters that offer various +//! levels of abstraction: +//! +//! - \ref asmjit_assembler - Low-level emitter that emits directly to \ref CodeBuffer. +//! - \ref asmjit_builder - Low-level emitter that emits to a \ref BaseNode list. +//! - \ref asmjit_compiler - High-level emitter that provides register allocation. +//! +//! ### Targets and JitRuntime +//! +//! AsmJit's \ref Target is an interface that provides basic target abstraction. At the moment AsmJit provides only +//! one implementation called \ref JitRuntime, which as the name suggests provides JIT code target and execution +//! runtime. \ref JitRuntime provides all the necessary stuff to implement a simple JIT compiler with basic memory +//! management. It only provides \ref JitRuntime::add() and \ref JitRuntime::release() functions that are used to +//! either add code to the runtime or release it. \ref JitRuntime doesn't do any decisions on when the code should be +//! released, the decision is up to the developer. +//! +//! See more at \ref asmjit_virtual_memory group. +//! +//! ### More About Environment +//! +//! In the previous example the \ref Environment is retrieved from \ref JitRuntime. It's logical as \ref JitRuntime +//! always returns an \ref Environment that is compatible with the host. For example if your application runs on X86_64 +//! CPU the \ref Environment returned will use \ref Arch::kX64 architecture in contrast to \ref Arch::kX86, which will +//! be used in 32-bit mode on an X86 target. +//! +//! AsmJit allows to setup the \ref Environment manually and to select a different architecture and ABI when necessary. +//! So let's do something else this time, let's always generate a 32-bit code and print its binary representation. To +//! do that, we can create our own \ref Environment and initialize it to \ref Arch::kX86. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! int main(int argc, char* argv[]) { +//! using namespace asmjit::x86; +//! +//! // Create a custom environment initialized to 32-bit X86 architecture. +//! Environment env; +//! env.setArch(Arch::kX86); +//! +//! CodeHolder code; // Create a CodeHolder. +//! code.init(env); // Initialize CodeHolder with custom environment. +//! +//! // Generate a 32-bit function that sums 4 floats and looks like: +//! // void func(float* dst, const float* a, const float* b) +//! x86::Assembler a(&code); // Create and attach x86::Assembler to `code`. +//! +//! a.mov(eax, dword_ptr(esp, 4)); // Load the destination pointer. +//! a.mov(ecx, dword_ptr(esp, 8)); // Load the first source pointer. +//! a.mov(edx, dword_ptr(esp, 12)); // Load the second source pointer. +//! +//! a.movups(xmm0, ptr(ecx)); // Load 4 floats from [ecx] to XMM0. +//! a.movups(xmm1, ptr(edx)); // Load 4 floats from [edx] to XMM1. +//! a.addps(xmm0, xmm1); // Add 4 floats in XMM1 to XMM0. +//! a.movups(ptr(eax), xmm0); // Store the result to [eax]. +//! a.ret(); // Return from function. +//! +//! // We have no Runtime this time, it's on us what we do with the code. +//! // CodeHolder stores code in Section, which provides some basic properties +//! // and CodeBuffer structure. We are interested in section's CodeBuffer. +//! // +//! // NOTE: The first section is always '.text', it can be retrieved by +//! // code.sectionById(0) or simply by code.textSection(). +//! CodeBuffer& buffer = code.textSection()->buffer(); +//! +//! // Print the machine-code generated or do something else with it... +//! // 8B4424048B4C24048B5424040F28010F58010F2900C3 +//! for (size_t i = 0; i < buffer.length; i++) +//! printf("%02X", buffer.data[i]); +//! +//! return 0; +//! } +//! ``` +//! +//! ### Explicit Code Relocation +//! +//! In addition to \ref Environment, \ref CodeHolder can be configured to specify a base-address (or a virtual base +//! address in a linker terminology), which could be static (useful when you know the location where the target's +//! machine code will be) or dynamic. AsmJit assumes dynamic base-address by default and relocates the code held by +//! \ref CodeHolder to a user provided address on-demand. To be able to relocate to a user provided address it needs +//! to store some information about relocations, which is represented by \ref RelocEntry. Relocation entries are only +//! required if you call external functions from the generated code that cannot be encoded by using a 32-bit +//! displacement (64-bit displacements are not provided by aby supported architecture). +//! +//! There is also a concept called \ref LabelLink - label link is a lightweight data structure that doesn't have any +//! identifier and is stored in \ref LabelEntry as a single-linked list. Label link represents either unbound yet used +//! label and cross-sections links (only relevant to code that uses multiple sections). Since crossing sections is +//! something that cannot be resolved immediately these links persist until offsets of these sections are assigned and +//! until \ref CodeHolder::resolveUnresolvedLinks() is called. It's an error if you end up with code that has +//! unresolved label links after flattening. You can verify it by calling \ref CodeHolder::hasUnresolvedLinks(), which +//! inspects the value returned by \ref CodeHolder::unresolvedLinkCount(). +//! +//! AsmJit can flatten code that uses multiple sections by assigning each section an incrementing offset that respects +//! its alignment. Use \ref CodeHolder::flatten() to do that. After the sections are flattened their offsets and +//! virtual sizes are adjusted to respect each section's buffer size and alignment. The \ref +//! CodeHolder::resolveUnresolvedLinks() function must be called before relocating the code held by \ref CodeHolder. +//! You can also flatten your code manually by iterating over all sections and calculating their offsets (relative to +//! base) by your own algorithm. In that case \ref CodeHolder::flatten() should not be called, however, +//! \ref CodeHolder::resolveUnresolvedLinks() should be. +//! +//! The example below shows how to use a built-in virtual memory allocator \ref JitAllocator instead of using \ref +//! JitRuntime (just in case you want to use your own memory management) and how to relocate the generated code +//! into your own memory block - you can use your own virtual memory allocator if you prefer that, but that's OS +//! specific and not covered by the documentation. +//! +//! The following code is similar to the previous one, but implements a function working in both 32-bit and 64-bit +//! environments: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef void (*SumIntsFunc)(int* dst, const int* a, const int* b); +//! +//! int main() { +//! // Create a custom environment that matches the current host environment. +//! Environment env = Environment::host(); +//! CpuFeatures cpuFeatures = CpuInfo::host().features(); +//! +//! CodeHolder code; // Create a CodeHolder. +//! code.init(env, cpuFeatures); // Initialize CodeHolder with environment. +//! +//! x86::Assembler a(&code); // Create and attach x86::Assembler to `code`. +//! +//! // Signature: 'void func(int* dst, const int* a, const int* b)'. +//! x86::Gp dst; +//! x86::Gp src_a; +//! x86::Gp src_b; +//! +//! // Handle the difference between 32-bit and 64-bit calling conventions +//! // (arguments passed through stack vs. arguments passed by registers). +//! if (env.is32Bit()) { +//! dst = x86::eax; +//! src_a = x86::ecx; +//! src_b = x86::edx; +//! a.mov(dst , x86::dword_ptr(x86::esp, 4)); +//! a.mov(src_a, x86::dword_ptr(x86::esp, 8)); +//! a.mov(src_b, x86::dword_ptr(x86::esp, 12)); +//! } +//! else { +//! if (env.isPlatformWindows()) { +//! dst = x86::rcx; // First argument (destination pointer). +//! src_a = x86::rdx; // Second argument (source 'a' pointer). +//! src_b = x86::r8; // Third argument (source 'b' pointer). +//! } +//! else { +//! dst = x86::rdi; // First argument (destination pointer). +//! src_a = x86::rsi; // Second argument (source 'a' pointer). +//! src_b = x86::rdx; // Third argument (source 'b' pointer). +//! } +//! } +//! +//! a.movdqu(x86::xmm0, x86::ptr(src_a)); // Load 4 ints from [src_a] to XMM0. +//! a.movdqu(x86::xmm1, x86::ptr(src_b)); // Load 4 ints from [src_b] to XMM1. +//! a.paddd(x86::xmm0, x86::xmm1); // Add 4 ints in XMM1 to XMM0. +//! a.movdqu(x86::ptr(dst), x86::xmm0); // Store the result to [dst]. +//! a.ret(); // Return from function. +//! +//! // Even when we didn't use multiple sections AsmJit could insert one section +//! // called '.addrtab' (address table section), which would be filled by data +//! // required by relocations (absolute jumps and calls). You can omit this code +//! // if you are 100% sure your code doesn't contain multiple sections and +//! // such relocations. You can use `CodeHolder::hasAddressTable()` to verify +//! // whether the address table section does exist. +//! code.flatten(); +//! code.resolveUnresolvedLinks(); +//! +//! // After the code was generated it can be relocated manually to any memory +//! // location, however, we need to know it's size before we perform memory +//! // allocation. `CodeHolder::codeSize()` returns the worst estimated code +//! // size in case that relocations are not possible without trampolines (in +//! // that case some extra code at the end of the current code buffer is +//! // generated during relocation). +//! size_t estimatedSize = code.codeSize(); +//! +//! // Instead of rolling up our own memory allocator we can use the one AsmJit +//! // provides. It's decoupled so you don't need to use `JitRuntime` for that. +//! JitAllocator allocator; +//! +//! // Allocate an executable virtual memory and handle a possible failure. +//! void* p = allocator.alloc(estimatedSize); +//! if (!p) +//! return 0; +//! +//! // Now relocate the code to the address provided by the memory allocator. +//! // Please note that this DOESN'T COPY anything to `p`. This function will +//! // store the address in CodeHolder and use relocation entries to patch the +//! // existing code in all sections to respect the base address provided. +//! code.relocateToBase((uint64_t)p); +//! +//! // This is purely optional. There are cases in which the relocation can omit +//! // unneeded data, which would shrink the size of address table. If that +//! // happened the codeSize returned after relocateToBase() would be smaller +//! // than the originally `estimatedSize`. +//! size_t codeSize = code.codeSize(); +//! +//! // This will copy code from all sections to `p`. Iterating over all sections +//! // and calling `memcpy()` would work as well, however, this function supports +//! // additional options that can be used to also zero pad sections' virtual +//! // size, etc. +//! // +//! // With some additional features, copyFlattenData() does roughly this: +//! // for (Section* section : code.sections()) +//! // memcpy((uint8_t*)p + section->offset(), +//! // section->data(), +//! // section->bufferSize()); +//! code.copyFlattenedData(p, codeSize, CopySectionFlags::kPadSectionBuffer); +//! +//! // Execute the generated function. +//! int inA[4] = { 4, 3, 2, 1 }; +//! int inB[4] = { 1, 5, 2, 8 }; +//! int out[4]; +//! +//! // This code uses AsmJit's ptr_as_func<> to cast between void* and SumIntsFunc. +//! ptr_as_func(p)(out, inA, inB); +//! +//! // Prints {5 8 4 9} +//! printf("{%d %d %d %d}\n", out[0], out[1], out[2], out[3]); +//! +//! // Release 'p' is it's no longer needed. It will be destroyed with 'vm' +//! // instance anyway, but it's a good practice to release it explicitly +//! // when you know that the function will not be needed anymore. +//! allocator.release(p); +//! +//! return 0; +//! } +//! ``` +//! +//! If you know the base-address in advance (before the code generation) it can be passed as a second argument to +//! \ref CodeHolder::init(). In that case the Assembler will know the absolute position of each instruction and +//! would be able to use it during instruction encoding to prevent relocations where possible. The following example +//! shows how to configure the base address: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! void initializeCodeHolder(CodeHolder& code) { +//! Environment env = Environment::host(); +//! CpuFeatures cpuFeatures = CpuInfo::host().features(); +//! uint64_t baseAddress = uint64_t(0x1234); +//! +//! // initialize CodeHolder with environment and custom base address. +//! code.init(env, cpuFeatures, baseAddress); +//! } +//! ``` +//! +//! ### Label Offsets and Links +//! +//! When a label that is not yet bound is used by the Assembler, it creates a \ref LabelLink, which is then added to +//! a \ref LabelEntry. These links are also created if a label is used in a different section than in which it was +//! bound. Let's examine some functions that can be used to check whether there are any unresolved links. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! void labelLinksExample(CodeHolder& code, const Label& label) { +//! // Tests whether the `label` is bound. +//! bool isBound = code.isLabelBound(label); +//! printf("Label %u is %s\n", label.id(), isBound ? "bound" : "not bound"); +//! +//! // Returns true if the code contains either referenced, but unbound +//! // labels, or cross-section label links that are not resolved yet. +//! bool hasUnresolved = code.hasUnresolvedLinks(); // Boolean answer. +//! size_t nUnresolved = code.unresolvedLinkCount(); // Count of unresolved links. +//! +//! printf("Number of unresolved links: %zu\n", nUnresolved); +//! } +//! ``` +//! +//! There is no function that would return the number of unbound labels as this is completely unimportant from +//! CodeHolder's perspective. If a label is not used then it doesn't matter whether it's bound or not, only actually +//! used labels matter. After a Label is bound it's possible to query its offset relative to the start of the +//! section where it was bound: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! void labelOffsetExample(CodeHolder& code, const Label& label) { +//! // Label offset is known after it's bound. The offset provided is relative +//! // to the start of the section, see below for alternative. If the given +//! // label is not bound the offset returned will be zero. It's recommended +//! // to always check whether the label is bound before using its offset. +//! uint64_t sectionOffset = code.labelOffset(label); +//! printf("Label offset relative to section: %llu\n", (unsigned long long)sectionOffset); +//! +//! // If you use multiple sections and want the offset relative to the base. +//! // NOTE: This function expects that the section has already an offset and +//! // the label-link was resolved (if this is not true you will still get an +//! // offset relative to the start of the section). +//! uint64_t baseOffset = code.labelOffsetFromBase(label); +//! printf("Label offset relative to base: %llu\n", (unsigned long long)baseOffset); +//! } +//! ``` +//! +//! ### Sections +//! +//! AsmJit allows to create multiple sections within the same \ref CodeHolder. A test-case +//! [asmjit_test_x86_sections.cpp](https://github.com/asmjit/asmjit/blob/master/test/asmjit_test_x86_sections.cpp) +//! can be used as a reference point although the following example should also provide a useful insight: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! void sectionsExample(CodeHolder& code) { +//! // Text section is always provided as the first section. +//! Section* text = code.textSection(); // or code.sectionById(0); +//! +//! // To create another section use CodeHolder::newSection(). +//! Section* data; +//! Error err = code.newSection(&data, +//! ".data", // Section name +//! SIZE_MAX, // Name length if the name is not null terminated (or SIZE_MAX). +//! SectionFlags::kNone, // Section flags, see SectionFlags. +//! 8, // Section alignment, must be power of 2. +//! 0); // Section order value (optional, default 0). +//! +//! // When you switch sections in Assembler, Builder, or Compiler the cursor +//! // will always move to the end of that section. When you create an Assembler +//! // the cursor would be placed at the end of the first (.text) section, which +//! // is initially empty. +//! x86::Assembler a(&code); +//! Label L_Data = a.newLabel(); +//! +//! a.mov(x86::eax, x86::ebx); // Emits in .text section. +//! +//! a.section(data); // Switches to the end of .data section. +//! a.bind(L_Data); // Binds label in this .data section +//! a.db(0x01); // Emits byte in .data section. +//! +//! a.section(text); // Switches to the end of .text section. +//! a.add(x86::ebx, x86::eax); // Emits in .text section. +//! +//! // References a label in .text section, which was bound in .data section. +//! // This would create a LabelLink even when the L_Data is already bound, +//! // because the reference crosses sections. See below... +//! a.lea(x86::rsi, x86::ptr(L_Data)); +//! } +//! ``` +//! +//! The last line in the example above shows that a LabelLink would be created even for bound labels that cross +//! sections. In this case a referenced label was bound in another section, which means that the link couldn't be +//! resolved at that moment. If your code uses sections, but you wish AsmJit to flatten these sections (you don't +//! plan to flatten them manually) then there is an API for that. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // ... (continuing the previous example) ... +//! void sectionsExampleContinued(CodeHolder& code) { +//! // Suppose we have some code that contains multiple sections and +//! // we would like to flatten it by using AsmJit's built-in API: +//! Error err = code.flatten(); +//! if (err) { +//! // There are many reasons it can fail, so always handle a possible error. +//! printf("Failed to flatten the code: %s\n", DebugUtils::errorAsString(err)); +//! exit(1); +//! } +//! +//! // After flattening all sections would contain assigned offsets +//! // relative to base. Offsets are 64-bit unsigned integers so we +//! // cast them to `size_t` for simplicity. On 32-bit targets it's +//! // guaranteed that the offset cannot be greater than `2^32 - 1`. +//! printf("Data section offset %zu", size_t(data->offset())); +//! +//! // The flattening doesn't resolve unresolved label links, this +//! // has to be done manually as flattening can be done separately. +//! err = code.resolveUnresolvedLinks(); +//! if (err) { +//! // This is the kind of error that should always be handled... +//! printf("Failed to resolve label links: %s\n", DebugUtils::errorAsString(err)); +//! exit(1); +//! } +//! +//! if (code.hasUnresolvedLinks()) { +//! // This would mean either unbound label or some other issue. +//! printf("The code has %zu unbound labels\n", code.unresolvedLinkCount()); +//! exit(1); +//! } +//! } +//! ``` + + +//! \defgroup asmjit_assembler Assembler +//! \brief Assembler interface and operands. +//! +//! ### Overview +//! +//! AsmJit's Assembler is used to emit machine code directly into a \ref CodeBuffer. In general, code generation +//! with assembler requires the knowledge of the following: +//! +//! - \ref BaseAssembler and architecture-specific assemblers: +//! - \ref x86::Assembler - Assembler implementation targeting X86 and X86_64 architectures. +//! - \ref a64::Assembler - Assembler implementation targeting AArch64 architecture. +//! - \ref Operand and its variations: +//! - \ref BaseReg - Base class for a register operand, inherited by: +//! - \ref x86::Reg - Register operand specific to X86 and X86_64 architectures. +//! - \ref arm::Reg - Register operand specific to AArch64 architecture. +//! - \ref BaseMem - Base class for a memory operand, inherited by: +//! - \ref x86::Mem - Memory operand specific to X86 architecture. +//! - \ref arm::Mem - Memory operand specific to AArch64 architecture. +//! - \ref Imm - Immediate (value) operand. +//! - \ref Label - Label operand. +//! +//! \note Assembler examples use \ref x86::Assembler as abstract interfaces cannot be used to generate code. +//! +//! ### Operand Basics +//! +//! Let's start with operands. \ref Operand is a data structure that defines a data layout of any operand. It can be +//! inherited, but any class inheriting it cannot add any members to it, only the existing layout can be reused. +//! AsmJit allows to construct operands dynamically, to store them, and to query a complete information about them +//! at run-time. Operands are small (always 16 bytes per \ref Operand) and can be copied and passed by value. Please +//! never allocate individual operands dynamically by using a `new` keyword - it would work, but then you would have +//! to be responsible for deleting such operands. In AsmJit operands are always part of some other data structures +//! like \ref InstNode, which is part of \ref asmjit_builder tool. +//! +//! Operands contain only identifiers, but not pointers to any code-generation data. For example \ref Label operand +//! only provides label identifier, but not a pointer to \ref LabelEntry structure. In AsmJit such IDs are used to +//! link stuff together without having to deal with pointers. +//! +//! AsmJit's operands all inherit from a base class called \ref Operand. Operands have the following properties that +//! are commonly accessible by getters and setters: +//! +//! - \ref Operand - Base operand, which only provides accessors that are common to all operand types. +//! - \ref BaseReg - Describes either physical or virtual register. Physical registers have id that matches the +//! target's machine id directly whereas virtual registers must be allocated into physical registers by a register +//! allocator pass. Register operand provides: +//! - Register Type (\ref RegType) - Unique id that describes each possible register provided by the target +//! architecture - for example X86 backend provides general purpose registers (GPB-LO, GPB-HI, GPW, GPD, and GPQ) +//! and all types of other registers like K, MM, BND, XMM, YMM, ZMM, and TMM. +//! - Register Group (\ref RegGroup) - Groups multiple register types under a single group - for example all +//! general-purpose registers (of all sizes) on X86 are part of \ref RegGroup::kGp and all SIMD registers +//! (XMM, YMM, ZMM) are part of \ref RegGroup::kVec. +//! - Register Size - Contains the size of the register in bytes. If the size depends on the mode (32-bit vs +//! 64-bit) then generally the higher size is used (for example RIP register has size 8 by default). +//! - Register Id - Contains physical or virtual id of the register. +//! - \ref BaseMem - Used to reference a memory location. Memory operand provides: +//! - Base Register - A base register type and id (physical or virtual). +//! - Index Register - An index register type and id (physical or virtual). +//! - Offset - Displacement or absolute address to be referenced (32-bit if base register is used and 64-bit if +//! base register is not used). +//! - Flags that can describe various architecture dependent information (like scale and segment-override on X86). +//! - \ref Imm - Immediate values are usually part of instructions (encoded within the instruction itself) or data. +//! - \ref Label - used to reference a location in code or data. Labels must be created by the \ref BaseEmitter or +//! by \ref CodeHolder. Each label has its unique id per \ref CodeHolder instance. +//! +//! ### Operand Manipulation +//! +//! AsmJit allows to construct operands dynamically, to store them, and to query a complete information about them at +//! run-time. Operands are small (always 16 bytes per `Operand`) and should be always copied (by value) if you intend +//! to store them (don't create operands by using `new` keyword, it's not recommended). Operands are safe to be passed +//! to `memcpy()` and `memset()`, which becomes handy when working with arrays of operands. If you set all members of +//! an \ref Operand to zero the operand would become NONE operand, which is the same as a default constructed Operand. +//! +//! The example below illustrates how operands can be used and modified even without using any other code generation +//! classes. The example uses X86 architecture-specific operands. +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! // Registers can be copied, it's a common practice. +//! x86::Gp dstRegByValue() { return x86::ecx; } +//! +//! void usingOperandsExample(x86::Assembler& a) { +//! // Gets `ecx` register returned by a function. +//! x86::Gp dst = dstRegByValue(); +//! // Gets `rax` register directly from the provided `x86` namespace. +//! x86::Gp src = x86::rax; +//! // Constructs `r10` dynamically. +//! x86::Gp idx = x86::gpq(10); +//! // Constructs [src + idx] memory address - referencing [rax + r10]. +//! x86::Mem m = x86::ptr(src, idx); +//! +//! // Examine `m`: Returns `RegType::kX86_Gpq`. +//! m.indexType(); +//! // Examine `m`: Returns 10 (`r10`). +//! m.indexId(); +//! +//! // Reconstruct `idx` stored in mem: +//! x86::Gp idx_2 = x86::Gp::fromTypeAndId(m.indexType(), m.indexId()); +//! +//! // True, `idx` and idx_2` are identical. +//! idx == idx_2; +//! +//! // Possible - op will still be the same as `m`. +//! Operand op = m; +//! // True (can be casted to BaseMem or architecture-specific Mem). +//! op.isMem(); +//! +//! // True, `op` is just a copy of `m`. +//! m == op; +//! +//! // Static cast is fine and valid here. +//! static_cast(op).addOffset(1); +//! // However, using `as()` to cast to a derived type is preferred. +//! op.as().addOffset(1); +//! // False, `op` now points to [rax + r10 + 2], which is not [rax + r10]. +//! m == op; +//! +//! // Emitting 'mov' - type safe way. +//! a.mov(dst, m); +//! // Not possible, `mov` doesn't provide mov(x86::Gp, Operand) overload. +//! a.mov(dst, op); +//! +//! // Type-unsafe, but possible. +//! a.emit(x86::Inst::kIdMov, dst, m); +//! // Also possible, `emit()` is type-less and can be used with raw Operand. +//! a.emit(x86::Inst::kIdMov, dst, op); +//! } +//! ``` +//! +//! Some operands have to be created explicitly by emitters. For example labels must be created by \ref +//! BaseEmitter::newLabel(), which creates a label entry and returns a \ref Label operand with the id that refers +//! to it. Such label then can be used by emitters. +//! +//! ### Memory Operands +//! +//! Some architectures like X86 provide a complex memory addressing model that allows to encode addresses having a +//! BASE register, INDEX register with a possible scale (left shift), and displacement (called offset in AsmJit). +//! Memory address on X86 can also specify memory segment (segment-override in X86 terminology) and some instructions +//! (gather / scatter) require INDEX to be a \ref x86::Vec register instead of a general-purpose register. +//! +//! AsmJit allows to encode and work with all forms of addresses mentioned and implemented by X86. In addition, it +//! also allows to construct absolute 64-bit memory address operands, which is only allowed in one form of 'mov' +//! instruction. +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void testX86Mem() { +//! // Makes it easier to access x86 stuff... +//! using namespace asmjit::x86; +//! +//! // BASE + OFFSET. +//! Mem a = ptr(rax); // a = [rax] +//! Mem b = ptr(rax, 15); // b = [rax + 15] +//! +//! // BASE + INDEX << SHIFT - Shift is in BITS as used by X86! +//! Mem c = ptr(rax, rbx); // c = [rax + rbx] +//! Mem d = ptr(rax, rbx, 2); // d = [rax + rbx << 2] +//! Mem e = ptr(rax, rbx, 2, 15); // e = [rax + rbx << 2 + 15] +//! +//! // BASE + VM (Vector Index) (encoded as MOD+VSIB). +//! Mem f = ptr(rax, xmm1); // f = [rax + xmm1] +//! Mem g = ptr(rax, xmm1, 2); // g = [rax + xmm1 << 2] +//! Mem h = ptr(rax, xmm1, 2, 15); // h = [rax + xmm1 << 2 + 15] +//! +//! // Absolute address: +//! uint64_t addr = (uint64_t)0x1234; +//! Mem i = ptr(addr); // i = [0x1234] +//! Mem j = ptr(addr, rbx); // j = [0x1234 + rbx] +//! Mem k = ptr(addr, rbx, 2); // k = [0x1234 + rbx << 2] +//! +//! // LABEL - Will be encoded as RIP (64-bit) or absolute address (32-bit). +//! Label L = ...; +//! Mem m = ptr(L); // m = [L] +//! Mem n = ptr(L, rbx); // n = [L + rbx] +//! Mem o = ptr(L, rbx, 2); // o = [L + rbx << 2] +//! Mem p = ptr(L, rbx, 2, 15); // p = [L + rbx << 2 + 15] +//! +//! // RIP - 64-bit only (RIP can't use INDEX). +//! Mem q = ptr(rip, 24); // q = [rip + 24] +//! } +//! ``` +//! +//! Memory operands can optionally contain memory size. This is required by instructions where the memory size cannot +//! be deduced from other operands, like `inc` and `dec` on X86: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void testX86Mem() { +//! // The same as: dword ptr [rax + rbx]. +//! x86::Mem a = x86::dword_ptr(x86::rax, x86::rbx); +//! +//! // The same as: qword ptr [rdx + rsi << 0 + 1]. +//! x86::Mem b = x86::qword_ptr(x86::rdx, x86::rsi, 0, 1); +//! } +//! ``` +//! +//! Memory operands provide API that can be used to access its properties: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void testX86Mem() { +//! // The same as: dword ptr [rax + 12]. +//! x86::Mem mem = x86::dword_ptr(x86::rax, 12); +//! +//! mem.hasBase(); // true. +//! mem.hasIndex(); // false. +//! mem.size(); // 4. +//! mem.offset(); // 12. +//! +//! mem.setSize(0); // Sets the size to 0 (makes it size-less). +//! mem.addOffset(-1); // Adds -1 to the offset and makes it 11. +//! mem.setOffset(0); // Sets the offset to 0. +//! mem.setBase(x86::rcx); // Changes BASE to RCX. +//! mem.setIndex(x86::rax); // Changes INDEX to RAX. +//! mem.hasIndex(); // true. +//! } +//! // ... +//! ``` +//! +//! Making changes to memory operand is very comfortable when emitting loads +//! and stores: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void testX86Mem(CodeHolder& code) { +//! x86::Assembler a(code); // Your initialized x86::Assembler. +//! x86::Mem mSrc = x86::ptr(eax); // Construct [eax] memory operand. +//! +//! // One way of emitting bunch of loads is to use `mem.adjusted()`, which +//! // returns a new memory operand and keeps the source operand unchanged. +//! a.movaps(x86::xmm0, mSrc); // No adjustment needed to load [eax]. +//! a.movaps(x86::xmm1, mSrc.adjusted(16)); // Loads from [eax + 16]. +//! a.movaps(x86::xmm2, mSrc.adjusted(32)); // Loads from [eax + 32]. +//! a.movaps(x86::xmm3, mSrc.adjusted(48)); // Loads from [eax + 48]. +//! +//! // ... do something with xmm0-3 ... +//! +//! // Another way of adjusting memory is to change the operand in-place. +//! // If you want to keep the original operand you can simply clone it. +//! x86::Mem mDst = mSrc.clone(); // Clone mSrc. +//! +//! a.movaps(mDst, x86::xmm0); // Stores xmm0 to [eax]. +//! mDst.addOffset(16); // Adds 16 to `mDst`. +//! +//! a.movaps(mDst, x86::xmm1); // Stores to [eax + 16] . +//! mDst.addOffset(16); // Adds 16 to `mDst`. +//! +//! a.movaps(mDst, x86::xmm2); // Stores to [eax + 32]. +//! mDst.addOffset(16); // Adds 16 to `mDst`. +//! +//! a.movaps(mDst, x86::xmm3); // Stores to [eax + 48]. +//! } +//! ``` +//! +//! ### Assembler Examples +//! +//! - \ref x86::Assembler provides many X86/X64 examples. + + +//! \defgroup asmjit_builder Builder +//! \brief Builder interface, nodes, and passes. +//! +//! ### Overview +//! +//! Both \ref BaseBuilder and \ref BaseCompiler interfaces describe emitters that emit into a representation that +//! allows further processing. The code stored in such representation is completely safe to be patched, simplified, +//! reordered, obfuscated, removed, injected, analyzed, or processed some other way. Each instruction, label, +//! directive, or other building block is stored as \ref BaseNode (or derived class like \ref InstNode or \ref +//! LabelNode) and contains all the information necessary to pass that node later to the assembler. +//! +//! \ref BaseBuilder is an emitter that inherits from \ref BaseEmitter interface. It was designed to provide a maximum +//! compatibility with the existing \ref BaseAssembler emitter so users can move from assembler to builder when needed, +//! for example to implement post-processing, which is not possible with Assembler. +//! +//! ### Builder Nodes +//! +//! \ref BaseBuilder doesn't generate machine code directly, it uses an intermediate representation based on nodes, +//! however, it allows to serialize to \ref BaseAssembler when the code is ready to be encoded. +//! +//! There are multiple node types used by both \ref BaseBuilder and \ref BaseCompiler : +//! +//! - Basic nodes: +//! - \ref BaseNode - Base class for all nodes. +//! - \ref InstNode - Represents an instruction node. +//! - \ref AlignNode - Represents an alignment directive (.align). +//! - \ref LabelNode - Represents a location where to bound a \ref Label. +//! +//! - Data nodes: +//! - \ref EmbedDataNode - Represents data. +//! - \ref EmbedLabelNode - Represents \ref Label address embedded as data. +//! - \ref EmbedLabelDeltaNode - Represents a difference of two labels embedded in data. +//! - \ref ConstPoolNode - Represents a constant pool data embedded as data. +//! +//! - Informative nodes: +//! - \ref CommentNode - Represents a comment string, doesn't affect code generation. +//! - \ref SentinelNode - A marker that can be used to remember certain position in code or data, doesn't affect +//! code generation. Used by \ref FuncNode to mark the end of a function. +//! +//! - Other nodes are provided by \ref asmjit_compiler infrastructure. +//! +//! ### Builder Examples +//! +//! - \ref x86::Builder - Builder implementation targeting X86 and X86_64 architectures. +//! - \ref a64::Builder - Builder implementation targeting AArch64 architecture. + + +//! \defgroup asmjit_compiler Compiler +//! \brief Compiler interface. +//! +//! ### Overview +//! +//! \ref BaseCompiler is a high-level interface, which provides register allocation and support for defining and +//! invoking functions, built on top of \ref BaseBuilder interface At the moment it's the easiest way of generating +//! code in AsmJit as most architecture and OS specifics is properly abstracted and handled by AsmJit automatically. +//! However, abstractions also mean restrictions, which means that \ref BaseCompiler has more limitations than \ref +//! BaseAssembler or \ref BaseBuilder. +//! +//! Since \ref BaseCompiler provides register allocation it also establishes the concept of functions - a function +//! in Compiler sense is a unit in which virtual registers are allocated into physical registers by the register +//! allocator. In addition, it enables to use such virtual registers in function invocations. +//! +//! \ref BaseCompiler automatically handles function calling conventions. It's still architecture dependent, but +//! makes the code generation much easies. Functions are essential; the first-step to generate some code is to define +//! a signature of the function to be generated (before generating the function body itself). Function arguments and +//! return value(s) are handled by assigning virtual registers to them. Similarly, function calls are handled the same +//! way. +//! +//! ### Compiler Nodes +//! +//! \ref BaseCompiler adds some nodes that are required for function generation and invocation: +//! +//! - \ref FuncNode - Represents a function definition. +//! - \ref FuncRetNode - Represents a function return. +//! - \ref InvokeNode - Represents a function invocation. +//! +//! \ref BaseCompiler also makes the use of passes (\ref Pass) and automatically adds an architecture-dependent +//! register allocator pass to the list of passes when attached to \ref CodeHolder. +//! +//! ### Compiler Examples +//! +//! - \ref x86::Compiler - Compiler implementation targeting X86 and X86_64 architectures. +//! - \ref a64::Compiler - Compiler implementation targeting AArch64 architecture. +//! +//! ### Compiler Tips +//! +//! Users of AsmJit have done mistakes in the past, this section should provide some useful tips for beginners: +//! +//! - Virtual registers in compiler are bound to a single function. At the moment the implementation doesn't +//! care whether a single virtual register is used in multiple functions, but it sees it as two independent +//! virtual registers in that case. This means that virtual registers cannot be used to implement global +//! variables. Global variables are basically memory addresses which functions can read from and write to, +//! and they have to be implemented in the same way. +//! +//! - Compiler provides a useful debugging functionality, which can be turned on through \ref FormatFlags. Use +//! \ref Logger::addFlags() to turn on additional logging features when using Compiler. + + +//! \defgroup asmjit_function Function +//! \brief Function definitions. +//! +//! ### Overview +//! +//! AsmJit provides functionality that can be used to define function signatures and to calculate automatically +//! optimal function frame that can be used directly by a prolog and epilog insertion. This feature was exclusive +//! to AsmJit's Compiler for a very long time, but was abstracted out and is now available for all users regardless +//! of the emitter they use. The following use cases are possible: +//! +//! - Calculate function frame before the function is generated - this is the only way available to \ref +//! BaseAssembler users and it will be described in this section. +//! +//! - Calculate function frame after the function is generated - this way is generally used by \ref BaseBuilder +//! and \ref BaseCompiler emitters and this way is generally described in \ref asmjit_compiler section. +//! +//! The following concepts are used to describe and create functions in AsmJit: +//! +//! - \ref TypeId - Type-id is an 8-bit value that describes a platform independent type as we know from C/C++. +//! It provides abstractions for most common types like `int8_t`, `uint32_t`, `uintptr_t`, `float`, `double`, +//! and all possible vector types to match ISAs up to AVX512. \ref TypeId was introduced originally for \ref +//! asmjit_compiler, but it's now used by \ref FuncSignature as well. +//! +//! - \ref CallConv - Describes a calling convention - this class contains instructions to assign registers and +//! stack addresses to function arguments and return value(s), but doesn't specify any function signature itself. +//! Calling conventions are architecture and OS dependent. +//! +//! - \ref FuncSignature - Describes a function signature, for example `int func(int, int)`. FuncSignature contains +//! a function calling convention id, return value type, and function arguments. The signature itself is platform +//! independent and uses \ref TypeId to describe types of function arguments and function return value(s). +//! +//! - \ref FuncDetail - Architecture and ABI dependent information that describes \ref CallConv and expanded \ref +//! FuncSignature. Each function argument and return value is represented as \ref FuncValue that contains the +//! original \ref TypeId enriched with additional information that specifies whether the value is passed or +//! returned by register (and which register) or by stack. Each value also contains some other metadata that +//! provide additional information required to handle it properly (for example whether a vector is passed +//! indirectly by a pointer as required by WIN64 calling convention). +//! +//! - \ref FuncFrame - Contains information about the function frame that can be used by prolog/epilog inserter +//! (PEI). Holds call stack size size and alignment, local stack size and alignment, and various attributes that +//! describe how prolog and epilog should be constructed. `FuncFrame` doesn't know anything about function's +//! arguments or return values, it hold only information necessary to create a valid and ABI conforming function +//! prologs and epilogs. +//! +//! - \ref FuncArgsAssignment - A helper class that can be used to reassign function arguments into user specified +//! registers. It's architecture and ABI dependent mapping from function arguments described by \ref CallConv +//! and \ref FuncDetail into registers specified by the user. +//! +//! It's a lot of concepts where each represents one step in a function frame calculation. It can be used to create +//! function prologs, epilogs, and also to calculate information necessary to perform function calls. + + +//! \defgroup asmjit_logging Logging +//! \brief Logging and formatting. +//! +//! ### Overview +//! +//! The initial phase of a project that generates machine code is not always smooth. Failure cases are common not just +//! at the beginning phase, but also during the development or refactoring. AsmJit provides logging functionality to +//! address this issue. AsmJit does already a good job with function overloading to prevent from emitting unencodable +//! instructions, but it can't prevent from emitting machine code that is correct at instruction level, but doesn't +//! work when it's executed asa whole. Logging has always been an important part of AsmJit's infrastructure and looking +//! at logs can sometimes reveal code generation issues quickly. +//! +//! AsmJit provides API for logging and formatting: +//! +//! - \ref Logger - A logger that you can pass to \ref CodeHolder and all emitters that inherit from \ref BaseEmitter. +//! +//! - \ref FormatOptions - Formatting options that can change how instructions and operands are formatted. +//! +//! - \ref Formatter - A namespace that provides functions that can format input data like \ref Operand, \ref BaseReg, +//! \ref Label, and \ref BaseNode into \ref String. +//! +//! AsmJit's \ref Logger serves the following purposes: +//! +//! - Provides a basic foundation for logging. +//! +//! - Abstract class leaving the implementation on users. The following built-in implementations are provided for +//! simplicity: +//! +//! - \ref FileLogger implements logging into a standard `FILE` stream. +//! - \ref StringLogger serializes all logs into a \ref String instance. +//! +//! AsmJit's \ref FormatOptions provides the following to customize the formatting of instructions and operands through: +//! +//! - \ref FormatFlags +//! - \ref FormatIndentationGroup +//! +//! ### Logging +//! +//! A \ref Logger is typically attached to a \ref CodeHolder, which propagates it to all attached emitters +//! automatically. The example below illustrates how to use \ref FileLogger that outputs to standard output: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! FileLogger logger(stdout); // Logger should always survive CodeHolder. +//! +//! CodeHolder code; // Holds code and relocation information. +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! code.setLogger(&logger); // Attach the `logger` to `code` holder. +//! +//! // ... code as usual, everything emitted will be logged to `stdout` ... +//! return 0; +//! } +//! ``` +//! +//! If output to FILE stream is not desired it's possible to use \ref StringLogger, which concatenates everything +//! into a multi-line string: +//! +//! ``` +//! #include +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! StringLogger logger; // Logger should always survive CodeHolder. +//! +//! CodeHolder code; // Holds code and relocation information. +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! code.setLogger(&logger); // Attach the `logger` to `code` holder. +//! +//! // ... code as usual, logging will be concatenated to logger string ... +//! +//! // You can either use the string from StringLogger directly or you can +//! // move it. Logger::data() returns its content as null terminated char[]. +//! printf("Logger content: %s\n", logger.data()); +//! +//! // It can be moved into your own string like this: +//! String content = std::move(logger.content()); +//! printf("The same content: %s\n", content.data()); +//! +//! return 0; +//! } +//! ``` +//! +//! ### Formatting +//! +//! AsmJit uses \ref Formatter to format inputs that are then passed to \ref Logger. Formatting is public and can be +//! used by AsmJit users as well. The most important thing to know regarding formatting is that \ref Formatter always +//! appends to the output string, so it can be used to build complex strings without having to concatenate +//! intermediate strings. +//! +//! The first example illustrates how to format operands: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! void logOperand(Arch arch, const Operand_& op) { +//! // The emitter is optional (named labels and virtual registers need it). +//! BaseEmitter* emitter = nullptr; +//! +//! // No flags by default. +//! FormatFlags formatFlags = FormatFlags::kNone; +//! +//! StringTmp<128> sb; +//! Formatter::formatOperand(sb, formatFlags, emitter, arch, op); +//! printf("%s\n", sb.data()); +//! } +//! +//! void formattingExample() { +//! using namespace x86; +//! +//! // Architecture is not part of operand, it must be passed explicitly. +//! // Format flags. We pass it explicitly also to 'logOperand' to make +//! // compatible with what AsmJit normally does. +//! Arch arch = Arch::kX64; +//! +//! logOperand(arch, rax); // Prints 'rax'. +//! logOperand(arch, ptr(rax, rbx, 2)); // Prints '[rax + rbx * 4]`. +//! logOperand(arch, dword_ptr(rax, rbx, 2)); // Prints 'dword [rax + rbx * 4]`. +//! logOperand(arch, imm(42)); // Prints '42'. +//! } +//! ``` +//! +//! Next example illustrates how to format whole instructions: +//! +//! ``` +//! #include +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! template +//! void logInstruction(Arch arch, const BaseInst& inst, Args&&... args) { +//! // The emitter is optional (named labels and virtual registers need it). +//! BaseEmitter* emitter = nullptr; +//! +//! // No flags by default. +//! FormatFlags formatFlags = FormatFlags::kNone; +//! +//! // The formatter expects operands in an array. +//! Operand_ operands[] { std::forward(args)... }; +//! +//! StringTmp<128> sb; +//! Formatter::formatInstruction( +//! sb, formatFlags, emitter, arch, inst, operands, sizeof...(args)); +//! printf("%s\n", sb.data()); +//! } +//! +//! void formattingExample() { +//! using namespace x86; +//! +//! // Architecture is not part of operand, it must be passed explicitly. +//! // Format flags. We pass it explicitly also to 'logOperand' to make +//! // compatible with what AsmJit normally does. +//! Arch arch = Arch::kX64; +//! +//! // Prints 'mov rax, rcx'. +//! logInstruction(arch, BaseInst(Inst::kIdMov), rax, rcx); +//! +//! // Prints 'vaddpd zmm0, zmm1, [rax] {1to8}'. +//! logInstruction(arch, +//! BaseInst(Inst::kIdVaddpd), +//! zmm0, zmm1, ptr(rax)._1to8()); +//! +//! // BaseInst abstracts instruction id, instruction options, and extraReg. +//! // Prints 'lock add [rax], rcx'. +//! logInstruction(arch, +//! BaseInst(Inst::kIdAdd, InstOptions::kX86_Lock), +//! ptr(rax), rcx); +//! +//! // Similarly an extra register (like AVX-512 selector) can be used. +//! // Prints 'vaddpd zmm0 {k2} {z}, zmm1, [rax]'. +//! logInstruction(arch, +//! BaseInst(Inst::kIdAdd, InstOptions::kX86_ZMask, k2), +//! zmm0, zmm1, ptr(rax)); +//! } +//! ``` +//! +//! And finally, the example below illustrates how to use a built-in function to format the content of +//! \ref BaseBuilder, which consists of nodes: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! void formattingExample(BaseBuilder* builder) { +//! FormatOptions formatOptions {}; +//! +//! // This also shows how temporary strings can be used. +//! StringTmp<512> sb; +//! +//! // FormatNodeList requires the String for output, formatting flags, which +//! // were zero (no extra flags), and the builder instance, which we have +//! // provided. An overloaded version also exists, which accepts begin and +//! // and end nodes, which can be used to only format a range of nodes. +//! Formatter::formatNodeList(sb, formatOptions, builder); +//! +//! // You can do whatever else with the string, it's always null terminated, +//! // so it can be passed to C functions like printf(). +//! printf("%s\n", sb.data()); +//! } +//! ``` + + +//! \defgroup asmjit_error_handling Error Handling +//! \brief Error handling. +//! +//! ### Overview +//! +//! AsmJit uses error codes to represent and return errors. Every function that can fail returns an \ref Error code. +//! Exceptions are never thrown by AsmJit itself even in extreme conditions like out-of-memory, but it's possible to +//! override \ref ErrorHandler::handleError() to throw, in that case no error will be returned and exception will be +//! thrown instead. All functions where this can happen are not marked `noexcept`. +//! +//! Errors should never be ignored, however, checking errors after each AsmJit API call would simply over-complicate +//! the whole code generation experience. \ref ErrorHandler exists to make the use of AsmJit API simpler as it allows +//! to customize how errors can be handled: +//! +//! - Record the error and continue (the way how the error is user-implemented). +//! - Throw an exception. AsmJit doesn't use exceptions and is completely exception-safe, but it's perfectly legal +//! to throw an exception from the error handler. +//! - Use plain old C's `setjmp()` and `longjmp()`. Asmjit always puts Assembler, Builder and Compiler to a +//! consistent state before calling \ref ErrorHandler::handleError(), so `longjmp()` can be used without issues +//! to cancel the code-generation if an error occurred. This method can be used if exception handling in your +//! project is turned off and you still want some comfort. In most cases it should be safe as AsmJit uses \ref +//! Zone memory and the ownership of memory it allocates always ends with the instance that allocated it. If +//! using this approach please never jump outside the life-time of \ref CodeHolder and \ref BaseEmitter. +//! +//! ### Using ErrorHandler +//! +//! An example of attaching \ref ErrorHandler to \ref CodeHolder. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // A simple error handler implementation, extend according to your needs. +//! class MyErrorHandler : public ErrorHandler { +//! public: +//! void handleError(Error err, const char* message, BaseEmitter* origin) override { +//! printf("AsmJit error: %s\n", message); +//! } +//! }; +//! +//! int main() { +//! JitRuntime rt; +//! +//! MyErrorHandler myErrorHandler; +//! CodeHolder code; +//! +//! code.init(rt.environment(), rt.cpuFeatures()); +//! code.setErrorHandler(&myErrorHandler); +//! +//! x86::Assembler a(&code); +//! // ... code generation ... +//! +//! return 0; +//! } +//! ``` +//! +//! Useful classes in error handling group: +//! +//! - See \ref DebugUtils that provides utilities useful for debugging. +//! - See \ref Error that lists error codes that AsmJit uses. +//! - See \ref ErrorHandler for more details about error handling. + + +//! \defgroup asmjit_instruction_db Instruction DB +//! \brief Instruction database (introspection, read/write, validation, ...). +//! +//! ### Overview +//! +//! AsmJit provides a public instruction database that can be used to query information about a complete instruction. +//! The instruction database requires the knowledge of the following: +//! +//! - \ref BaseInst - Base instruction that contains instruction id, options, and a possible extra-register that +//! represents either REP prefix counter or AVX-512 selector (mask). +//! +//! - \ref Operand - Represents operands of an instruction. +//! +//! Each instruction can be then queried for the following information: +//! +//! - \ref InstRWInfo - Read/write information of instruction and its operands (includes \ref OpRWInfo). +//! +//! - \ref CpuFeatures - CPU features required to execute the instruction. +//! +//! In addition to query functionality AsmJit is also able to validate whether an instruction and its operands are +//! valid. This is useful for making sure that what user tries to emit is correct and it can be also used by other +//! projects that parse user input, like AsmTK project. +//! +//! ### Query API +//! +//! The instruction query API is provided by \ref InstAPI namespace. The following queries are possible: +//! +//! - \ref InstAPI::queryRWInfo() - queries read/write information of the given instruction and its operands. +//! Includes also CPU flags read/written. +//! +//! - \ref InstAPI::queryFeatures() - queries CPU features that are required to execute the given instruction. A full +//! instruction with operands must be given as some architectures like X86 may require different features for the +//! same instruction based on its operands. +//! +//! - asmjit_test_instinfo.cpp +//! can be also used as a reference about accessing instruction information. +//! +//! ### Validation API +//! +//! The instruction validation API is provided by \ref InstAPI namespace in the similar fashion like the Query API, +//! however, validation can also be turned on at \ref BaseEmitter level. The following is possible: +//! +//! - \ref InstAPI::validate() - low-level instruction validation function that is used internally by emitters +//! if strict validation is enabled. +//! +//! - \ref BaseEmitter::addDiagnosticOptions() - can be used to enable validation at emitter level, see \ref +//! DiagnosticOptions. + + +//! \defgroup asmjit_virtual_memory Virtual Memory +//! \brief Virtual memory management. +//! +//! ### Overview +//! +//! AsmJit's virtual memory management is divided into three main categories: +//! +//! - Low level interface that provides cross-platform abstractions for virtual memory allocation. Implemented in +//! \ref VirtMem namespace. This API is a thin wrapper around operating system specific calls such as +//! `VirtualAlloc()` and `mmap()` and it's intended to be used by AsmJit's higher level API. Low-level virtual +//! memory functions can be used to allocate virtual memory, change its permissions, and to release it. +//! Additionally, an API that allows to create dual mapping (to support hardened environments) is provided. +//! +//! - Middle level API that is provided by \ref JitAllocator, which uses \ref VirtMem internally and offers nicer +//! API that can be used by users to allocate executable memory conveniently. \ref JitAllocator tries to be smart, +//! for example automatically using dual mapping or `MAP_JIT` on hardened environments. +//! +//! - High level API that is provided by \ref JitRuntime, which implements \ref Target interface and uses \ref +//! JitAllocator under the hood. Since \ref JitRuntime inherits from \ref Target it makes it easy to use with +//! \ref CodeHolder. Many AsmJit examples use \ref JitRuntime for its simplicity and easy integration. +//! +//! The main difference between \ref VirtMem and \ref JitAllocator is that \ref VirtMem can only be used to allocate +//! whole pages, whereas \ref JitAllocator has `malloc()` like API that allows to allocate smaller quantities that +//! usually represent the size of an assembled function or a chunk of functions that can represent a module, for +//! example. \ref JitAllocator then tracks used space of each page it maintains. Internally, \ref JitAllocator uses +//! two bit arrays to track occupied regions in each allocated block of pages. +//! +//! ### Hardened Environments +//! +//! In the past, allocating virtual memory with Read+Write+Execute (RWX) access permissions was easy. However, modern +//! operating systems and runtime environments often use hardening, which typically prohibits mapping pages with both +//! Write and Execute permissions (known as the W^X policy). This presents a challenge for JIT compilers because +//! generated code for a single function is unlikely to fit in exactly N pages without leaving some space empty. To +//! accommodate this, the execution environment may need to temporarily change the permissions of existing pages to +//! read+write (RW) to insert new code into them, however, sometimes it's not possible to ensure that no thread is +//! executing code in such affected pages in a multithreaded environment, in which multiple threads may be executing +//! generated code. +//! +//! Such restrictions leave a lot of complexity on the application, so AsmJit implements a dual mapping technique to +//! make the life of AsmJit users easier. In this technique, a region of memory is mapped to two different virtual +//! addresses with different access permissions. One virtual address is mapped with read and write (RW) access, which +//! is used by the JIT compiler to write generated code. The other virtual address is mapped with read and execute (RX) +//! access, which is used by the application to execute the generated code. +//! +//! However, implementing dual mapping can be challenging because it typically requires obtaining an anonymous file +//! descriptor on most Unix-like operating systems. This file descriptor is then passed to mmap() twice to create +//! the two mappings. AsmJit handles this challenge by using system-specific techniques such as `memfd_create()` on +//! Linux, `shm_open(SHM_ANON)` on BSD, and `MAP_REMAPDUP` with `mremap()` on NetBSD. The latter approach does not +//! require a file descriptor. If none of these options are available, AsmJit uses a plain `open()` call followed by +//! `unlink()`. +//! +//! The most challenging part is actually obtaining a file descriptor that can be passed to `mmap()` with `PROT_EXEC`. +//! This is still something that may fail, for example the environment could be hardened in a way that this would +//! not be possible at all, and thus dual mapping would not work. +//! +//! Dual mapping is provided by both \ref VirtMem and \ref JitAllocator. + + +//! \defgroup asmjit_zone Zone Memory +//! \brief Zone memory allocator and containers. +//! +//! ### Overview +//! +//! AsmJit uses zone memory allocation (also known as Arena allocation) to allocate most of the data it uses. It's a +//! fast allocator that allows AsmJit to allocate a lot of small data structures fast and without `malloc()` overhead. +//! Since code generators and all related classes are usually short-lived this approach decreases memory usage and +//! fragmentation as arena-based allocators always allocate larger blocks of memory, which are then split into smaller +//! chunks. +//! +//! Another advantage of zone memory allocation is that since the whole library uses this strategy it's very easy to +//! deallocate everything that a particular instance is holding by simply releasing the memory the allocator holds. +//! This improves destruction time of such objects as there is no destruction at all. Long-lived objects just reset +//! its data in destructor or in their reset() member function for a future reuse. For this purpose all containers in +//! AsmJit are also zone allocated. +//! +//! ### Zone Allocation +//! +//! - \ref Zone - Incremental zone memory allocator with minimum features. It can only allocate memory without the +//! possibility to return it back to the allocator. +//! +//! - \ref ZoneTmp - A temporary \ref Zone with some initial static storage. If the allocation requests fit the +//! static storage allocated then there will be no dynamic memory allocation during the lifetime of \ref ZoneTmp, +//! otherwise it would act as \ref Zone with one preallocated block on the stack. +//! +//! - \ref ZoneAllocator - A wrapper of \ref Zone that provides the capability of returning memory to the allocator. +//! Such memory is stored in a pool for later reuse. +//! +//! ### Zone Allocated Containers +//! +//! - \ref ZoneString - Zone allocated string. +//! - \ref ZoneHash - Zone allocated hash table. +//! - \ref ZoneTree - Zone allocated red-black tree. +//! - \ref ZoneList - Zone allocated double-linked list. +//! - \ref ZoneStack - Zone allocated stack. +//! - \ref ZoneVector - Zone allocated vector. +//! - \ref ZoneBitVector - Zone allocated vector of bits. +//! +//! ### Using Zone Allocated Containers +//! +//! The most common data structure exposed by AsmJit is \ref ZoneVector. It's very similar to `std::vector`, but the +//! implementation doesn't use exceptions and uses the mentioned \ref ZoneAllocator for performance reasons. You don't +//! have to worry about allocations as you should not need to add items to AsmJit's data structures directly as there +//! should be API for all required operations. +//! +//! The following APIs in \ref CodeHolder returns \ref ZoneVector reference: +//! +//! ``` +//! using namespace asmjit; +//! +//! void example(CodeHolder& code) { +//! // Contains all emitters attached to CodeHolder. +//! const ZoneVector& emitters = code.emitters(); +//! +//! // Contains all section entries managed by CodeHolder. +//! const ZoneVector& sections = code.sections(); +//! +//! // Contains all label entries managed by CodeHolder. +//! const ZoneVector& labelEntries = code.labelEntries(); +//! +//! // Contains all relocation entries managed by CodeHolder. +//! const ZoneVector& relocEntries = code.relocEntries(); +//! } +//! ``` +//! +//! \ref ZoneVector has overloaded array access operator to make it possible to access its elements through operator[]. +//! Some standard functions like \ref ZoneVector::empty(), \ref ZoneVector::size(), and \ref ZoneVector::data() are +//! provided as well. Vectors are also iterable through a range-based for loop: +//! +//! ``` +//! using namespace asmjit; +//! +//! void example(CodeHolder& code) { +//! for (LabelEntry* le : code.labelEntries()) { +//! printf("Label #%u {Bound=%s Offset=%llu}", +//! le->id(), +//! le->isBound() ? "true" : "false", +//! (unsigned long long)le->offset()); +//! } +//! } +//! ``` +//! +//! ### Design Considerations +//! +//! Zone-allocated containers do not store the allocator within the container. This decision was made to reduce the +//! footprint of such containers as AsmJit tooling, especially Compiler's register allocation, may use many instances +//! of such containers to perform code analysis and register allocation. +//! +//! For example to append an item into a \ref ZoneVector it's required to pass the allocator as the first argument, +//! so it can be used in case that the vector needs a reallocation. Such function also returns an error, which must +//! be propagated to the caller. +//! +//! ``` +//! using namespace asmjit +//! +//! Error example(ZoneAllocator* allocator) { +//! ZoneVector vector; +//! +//! // Unfortunately, allocator must be provided to all functions that mutate +//! // the vector. However, AsmJit users should never need to do this as all +//! // manipulation should be done through public API, which takes care of +//! // that. +//! for (int i = 0; i < 100; i++) { +//! ASMJIT_PROPAGATE(vector.append(allocator, i)); +//! } +//! +//! // By default vector's destructor doesn't release anything as it knows +//! // that its content is zone allocated. However, \ref ZoneVector::release +//! // can be used to explicitly release the vector data to the allocator if +//! // necessary +//! vector.release(allocator); +//! } +//! ``` +//! +//! Containers like \ref ZoneVector also provide a functionality to reserve a certain number of items before any items +//! are added to it. This approach is used internally in most places as it allows to prepare space for data that will +//! be added to some container before the data itself was created. +//! +//! ``` +//! using namespace asmjit +//! +//! Error example(ZoneAllocator* allocator) { +//! ZoneVector vector; +//! +//! ASMJIT_PROPAGATE(vector.willGrow(100)); +//! for (int i = 0; i < 100; i++) { +//! // Cannot fail. +//! vector.appendUnsafe(allocator, i); +//! } +//! +//! vector.release(allocator); +//! } +//! ``` + + +//! \defgroup asmjit_utilities Utilities +//! \brief Utility classes and functions. +//! +//! ### Overview +//! +//! AsmJit uses and provides utility classes and functions, that can be used with AsmJit. The functionality can be +//! divided into the following topics: +//! +//! ### String Functionality +//! +//! - \ref String - AsmJit's string container, which is used internally and which doesn't use exceptions and has +//! a stable layout, which is not dependent on C++ standard library. +//! +//! - \ref StringTmp - String that can have base storage allocated on stack. The amount of storage on stack can +//! be specified as a template parameter. +//! +//! - \ref FixedString - Fixed string container limited up to N characters. +//! +//! ### Code Generation Utilities +//! +//! - \ref ConstPool - Constant pool used by \ref BaseCompiler, but also available to users that may find use of it. +//! +//! ### Support Functionality Used by AsmJit +//! +//! - \ref Support namespace provides many other utility functions and classes that are used by AsmJit, and made +//! public. + + +//! \defgroup asmjit_x86 X86 Backend +//! \brief X86/X64 backend. + + +//! \defgroup asmjit_arm ARM Commons +//! \brief ARM commons shared between AArch32 and AArch64. + + +//! \defgroup asmjit_a64 AArch64 Backend +//! \brief AArch64 backend. + + +//! \cond INTERNAL +//! \defgroup asmjit_ra RA +//! \brief Register allocator internals. +//! \endcond + +} // {asmjit} + +#include "asmjit-scope-begin.h" +#include "core/archtraits.h" +#include "core/assembler.h" +#include "core/builder.h" +#include "core/codeholder.h" +#include "core/compiler.h" +#include "core/constpool.h" +#include "core/cpuinfo.h" +#include "core/emitter.h" +#include "core/environment.h" +#include "core/errorhandler.h" +#include "core/formatter.h" +#include "core/func.h" +#include "core/globals.h" +#include "core/inst.h" +#include "core/jitallocator.h" +#include "core/jitruntime.h" +#include "core/logger.h" +#include "core/operand.h" +#include "core/osutils.h" +#include "core/string.h" +#include "core/support.h" +#include "core/target.h" +#include "core/type.h" +#include "core/virtmem.h" +#include "core/zone.h" +#include "core/zonehash.h" +#include "core/zonelist.h" +#include "core/zonetree.h" +#include "core/zonestack.h" +#include "core/zonestring.h" +#include "core/zonevector.h" +#include "asmjit-scope-end.h" + +#endif // ASMJIT_CORE_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/api-config.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/api-config.h new file mode 100644 index 0000000000000000000000000000000000000000..c9682e44adeb6888106682dfec6837a585f48833 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/api-config.h @@ -0,0 +1,664 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_API_CONFIG_H_INCLUDED +#define ASMJIT_CORE_API_CONFIG_H_INCLUDED + +// AsmJit Library & ABI Version +// ============================ + +//! \addtogroup asmjit_core +//! \{ + +//! Makes a 32-bit integer that represents AsmJit version in `(major << 16) | (minor << 8) | patch` form. +#define ASMJIT_LIBRARY_MAKE_VERSION(major, minor, patch) ((major << 16) | (minor << 8) | (patch)) + +//! AsmJit library version, see \ref ASMJIT_LIBRARY_MAKE_VERSION for a version format reference. +#define ASMJIT_LIBRARY_VERSION ASMJIT_LIBRARY_MAKE_VERSION(1, 13, 0) + +//! \def ASMJIT_ABI_NAMESPACE +//! +//! AsmJit ABI namespace is an inline namespace within \ref asmjit namespace. +//! +//! It's used to make sure that when user links to an incompatible version of AsmJit, it won't link. It has also +//! some additional properties as well. When `ASMJIT_ABI_NAMESPACE` is defined by the user it would override the +//! AsmJit default, which makes it possible to use multiple AsmJit libraries within a single project, totally +//! controlled by users. This is useful especially in cases in which some of such library comes from third party. +#if !defined(ASMJIT_ABI_NAMESPACE) + #define ASMJIT_ABI_NAMESPACE _abi_1_13 +#endif // !ASMJIT_ABI_NAMESPACE + +//! \} + +// Global Dependencies +// =================== + +#include +#include +#include // We really want std types as globals, not under 'std' namespace. +#include +#include +#include + +#include +#include +#include +#include + +#if !defined(_WIN32) && !defined(__EMSCRIPTEN__) + #include +#endif + +// Build Options +// ============= + +// NOTE: Doxygen cannot document macros that are not defined, that's why we have to define them and then undefine +// them immediately, so it won't use the macros with its own preprocessor. +#ifdef _DOXYGEN +namespace asmjit { + +//! \addtogroup asmjit_build +//! \{ + +//! Asmjit is embedded, implies \ref ASMJIT_STATIC. +#define ASMJIT_EMBED + +//! Enables static-library build. +#define ASMJIT_STATIC + +//! Defined when AsmJit's build configuration is 'Debug'. +//! +//! \note Can be defined explicitly to bypass auto-detection. +#define ASMJIT_BUILD_DEBUG + +//! Defined when AsmJit's build configuration is 'Release'. +//! +//! \note Can be defined explicitly to bypass auto-detection. +#define ASMJIT_BUILD_RELEASE + +//! Disables X86/X64 backends. +#define ASMJIT_NO_X86 + +//! Disables AArch64 backend. +#define ASMJIT_NO_AARCH64 + +//! Disables non-host backends entirely (useful for JIT compilers to minimize the library size). +#define ASMJIT_NO_FOREIGN + +//! Disables deprecated API at compile time (deprecated API won't be available). +#define ASMJIT_NO_DEPRECATED + +//! Disables \ref asmjit_builder functionality completely. +#define ASMJIT_NO_BUILDER + +//! Disables \ref asmjit_compiler functionality completely. +#define ASMJIT_NO_COMPILER + +//! Disables JIT memory management and \ref asmjit::JitRuntime. +#define ASMJIT_NO_JIT + +//! Disables \ref asmjit::Logger and \ref asmjit::Formatter. +#define ASMJIT_NO_LOGGING + +//! Disables everything that contains text. +#define ASMJIT_NO_TEXT + +//! Disables instruction validation API. +#define ASMJIT_NO_VALIDATION + +//! Disables instruction introspection API. +#define ASMJIT_NO_INTROSPECTION + +// Avoid doxygen preprocessor using feature-selection definitions. +#undef ASMJIT_BUILD_EMBED +#undef ASMJIT_BUILD_STATIC +#undef ASMJIT_BUILD_DEBUG +#undef ASMJIT_BUILD_RELEASE +#undef ASMJIT_NO_X86 +#undef ASMJIT_NO_FOREIGN +// (keep ASMJIT_NO_DEPRECATED defined, we don't document deprecated APIs). +#undef ASMJIT_NO_BUILDER +#undef ASMJIT_NO_COMPILER +#undef ASMJIT_NO_JIT +#undef ASMJIT_NO_LOGGING +#undef ASMJIT_NO_TEXT +#undef ASMJIT_NO_VALIDATION +#undef ASMJIT_NO_INTROSPECTION + +//! \} + +} // {asmjit} +#endif // _DOXYGEN + +// ASMJIT_NO_BUILDER implies ASMJIT_NO_COMPILER. +#if defined(ASMJIT_NO_BUILDER) && !defined(ASMJIT_NO_COMPILER) + #define ASMJIT_NO_COMPILER +#endif + +// Prevent compile-time errors caused by misconfiguration. +#if defined(ASMJIT_NO_TEXT) && !defined(ASMJIT_NO_LOGGING) + #pragma message("'ASMJIT_NO_TEXT' can only be defined when 'ASMJIT_NO_LOGGING' is defined.") + #undef ASMJIT_NO_TEXT +#endif + +#if defined(ASMJIT_NO_INTROSPECTION) && !defined(ASMJIT_NO_COMPILER) + #pragma message("'ASMJIT_NO_INTROSPECTION' can only be defined when 'ASMJIT_NO_COMPILER' is defined") + #undef ASMJIT_NO_INTROSPECTION +#endif + +// Build Mode +// ========== + +// Detect ASMJIT_BUILD_DEBUG and ASMJIT_BUILD_RELEASE if not defined. +#if !defined(ASMJIT_BUILD_DEBUG) && !defined(ASMJIT_BUILD_RELEASE) + #if !defined(NDEBUG) + #define ASMJIT_BUILD_DEBUG + #else + #define ASMJIT_BUILD_RELEASE + #endif +#endif + +// Target Architecture Detection +// ============================= + +//! \addtogroup asmjit_core +//! \{ + +//! \def ASMJIT_ARCH_X86 +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is X86 (32) or X86_64 (64). + +//! \def ASMJIT_ARCH_ARM +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is ARM (32) or AArch64 (64). + +//! \def ASMJIT_ARCH_MIPS +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is MIPS (32) or MISP64 (64). + +//! \def ASMJIT_ARCH_RISCV +//! +//! Defined to either 0, 32, or 64 depending on whether the target CPU is RV32 (32) or RV64 (64). + +//! \def ASMJIT_ARCH_BITS +//! +//! Defined to either 32 or 64 depending on the target. + +//! \def ASMJIT_ARCH_LE +//! +//! Defined to 1 if the target architecture is little endian. + +//! \def ASMJIT_ARCH_BE +//! +//! Defined to 1 if the target architecture is big endian. + +//! \} + +//! \cond NONE + +#if defined(_M_X64) || defined(__x86_64__) + #define ASMJIT_ARCH_X86 64 +#elif defined(_M_IX86) || defined(__X86__) || defined(__i386__) + #define ASMJIT_ARCH_X86 32 +#else + #define ASMJIT_ARCH_X86 0 +#endif + +#if defined(_M_ARM64) || defined(__arm64__) || defined(__aarch64__) +# define ASMJIT_ARCH_ARM 64 +#elif defined(_M_ARM) || defined(_M_ARMT) || defined(__arm__) || defined(__thumb__) || defined(__thumb2__) + #define ASMJIT_ARCH_ARM 32 +#else + #define ASMJIT_ARCH_ARM 0 +#endif + +#if defined(_MIPS_ARCH_MIPS64) || defined(__mips64) + #define ASMJIT_ARCH_MIPS 64 +#elif defined(_MIPS_ARCH_MIPS32) || defined(_M_MRX000) || defined(__mips__) + #define ASMJIT_ARCH_MIPS 32 +#else + #define ASMJIT_ARCH_MIPS 0 +#endif + +// NOTE `__riscv` is the correct macro in this case as specified by "RISC-V Toolchain Conventions". +#if (defined(__riscv) || defined(__riscv__)) && defined(__riscv_xlen) + #define ASMJIT_ARCH_RISCV __riscv_xlen +#else + #define ASMJIT_ARCH_RISCV 0 +#endif + +#define ASMJIT_ARCH_BITS (ASMJIT_ARCH_X86 | ASMJIT_ARCH_ARM | ASMJIT_ARCH_MIPS | ASMJIT_ARCH_RISCV) +#if ASMJIT_ARCH_BITS == 0 + #undef ASMJIT_ARCH_BITS + #if defined (__LP64__) || defined(_LP64) + #define ASMJIT_ARCH_BITS 64 + #else + #define ASMJIT_ARCH_BITS 32 + #endif +#endif + +#if (defined(__ARMEB__)) || \ + (defined(__MIPSEB__)) || \ + (defined(__BYTE_ORDER__) && (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)) + #define ASMJIT_ARCH_LE 0 + #define ASMJIT_ARCH_BE 1 +#else + #define ASMJIT_ARCH_LE 1 + #define ASMJIT_ARCH_BE 0 +#endif + +#if defined(ASMJIT_NO_FOREIGN) + #if !ASMJIT_ARCH_X86 && !defined(ASMJIT_NO_X86) + #define ASMJIT_NO_X86 + #endif + + #if ASMJIT_ARCH_ARM != 64 && !defined(ASMJIT_NO_AARCH64) + #define ASMJIT_NO_AARCH64 + #endif +#endif + +//! \endcond + +// C++ Compiler and Features Detection +// =================================== + +#if defined(__GNUC__) && defined(__has_attribute) + #define ASMJIT_CXX_HAS_ATTRIBUTE(NAME, CHECK) (__has_attribute(NAME)) +#else + #define ASMJIT_CXX_HAS_ATTRIBUTE(NAME, CHECK) (!(!(CHECK))) +#endif // !ASMJIT_CXX_HAS_ATTRIBUTE + +// API Decorators & C++ Extensions +// =============================== + +//! \addtogroup asmjit_core +//! \{ + +//! \def ASMJIT_API +//! +//! A decorator that is used to decorate API that AsmJit exports when built as a shared library. + +//! \def ASMJIT_VIRTAPI +//! +//! This is basically a workaround. When using MSVC and marking class as DLL export everything gets exported, which +//! is unwanted in most projects. MSVC automatically exports typeinfo and vtable if at least one symbol of the class +//! is exported. However, GCC has some strange behavior that even if one or more symbol is exported it doesn't export +//! typeinfo unless the class itself is decorated with "visibility(default)" (i.e. ASMJIT_API). + +//! \def ASMJIT_FORCE_INLINE +//! +//! Decorator to force inlining of functions, uses either `__attribute__((__always_inline__))` or __forceinline, +//! depending on C++ compiler. + +//! \def ASMJIT_INLINE_NODEBUG +//! +//! Like \ref ASMJIT_FORCE_INLINE, but uses additionally `__nodebug__` or `__artificial__` attribute to make the +//! debugging of some AsmJit functions easier, especially getters and one-line abstractions where usually you don't +//! want to step in. + +//! \def ASMJIT_NOINLINE +//! +//! Decorator to avoid inlining of functions, uses either `__attribute__((__noinline__))` or `__declspec(noinline)` +//! depending on C++ compiler. + +//! \def ASMJIT_NORETURN +//! +//! Decorator that marks functions that should never return. Typically used to implement assertion handlers that +//! terminate, so the function never returns. + +//! \def ASMJIT_CDECL +//! +//! CDECL function attribute - either `__attribute__((__cdecl__))` or `__cdecl`. + +//! \def ASMJIT_STDCALL +//! +//! STDCALL function attribute - either `__attribute__((__stdcall__))` or `__stdcall`. +//! +//! \note This expands to nothing on non-x86 targets as STDCALL is X86 specific. + +//! \def ASMJIT_FASTCALL +//! +//! FASTCALL function attribute - either `__attribute__((__fastcall__))` or `__fastcall`. +//! +//! \note Expands to nothing on non-x86 targets as FASTCALL is X86 specific. + +//! \def ASMJIT_REGPARM(N) +//! +//! Expands to `__attribute__((__regparm__(N)))` when compiled by GCC or clang, nothing otherwise. + +//! \def ASMJIT_VECTORCALL +//! +//! VECTORCALL function attribute - either `__attribute__((__vectorcall__))` or `__vectorcall`. +//! +//! \note Expands to nothing on non-x86 targets as VECTORCALL is X86 specific. + +//! \} + +// API (Export / Import). +#if !defined(ASMJIT_STATIC) + #if defined(_WIN32) && (defined(_MSC_VER) || defined(__MINGW32__)) + #ifdef ASMJIT_EXPORTS + #define ASMJIT_API __declspec(dllexport) + #else + #define ASMJIT_API __declspec(dllimport) + #endif + #elif defined(_WIN32) && defined(__GNUC__) + #ifdef ASMJIT_EXPORTS + #define ASMJIT_API __attribute__((__dllexport__)) + #else + #define ASMJIT_API __attribute__((__dllimport__)) + #endif + #elif defined(__GNUC__) + #define ASMJIT_API __attribute__((__visibility__("default"))) + #endif +#endif + +#if !defined(ASMJIT_API) + #define ASMJIT_API +#endif + +#if !defined(ASMJIT_VARAPI) + #define ASMJIT_VARAPI extern ASMJIT_API +#endif + +#if defined(__GNUC__) && !defined(_WIN32) + #define ASMJIT_VIRTAPI ASMJIT_API +#else + #define ASMJIT_VIRTAPI +#endif + +// Function attributes. +#if !defined(ASMJIT_BUILD_DEBUG) && defined(__GNUC__) + #define ASMJIT_FORCE_INLINE inline __attribute__((__always_inline__)) +#elif !defined(ASMJIT_BUILD_DEBUG) && defined(_MSC_VER) + #define ASMJIT_FORCE_INLINE __forceinline +#else + #define ASMJIT_FORCE_INLINE inline +#endif + + +#if defined(__clang__) + #define ASMJIT_INLINE_NODEBUG inline __attribute__((__always_inline__, __nodebug__)) +#elif defined(__GNUC__) + #define ASMJIT_INLINE_NODEBUG inline __attribute__((__always_inline__, __artificial__)) +#else + #define ASMJIT_INLINE_NODEBUG inline +#endif + +#if defined(__GNUC__) + #define ASMJIT_NOINLINE __attribute__((__noinline__)) + #define ASMJIT_NORETURN __attribute__((__noreturn__)) +#elif defined(_MSC_VER) + #define ASMJIT_NOINLINE __declspec(noinline) + #define ASMJIT_NORETURN __declspec(noreturn) +#else + #define ASMJIT_NOINLINE + #define ASMJIT_NORETURN +#endif + +// Calling conventions. +#if ASMJIT_ARCH_X86 == 32 && defined(__GNUC__) + #define ASMJIT_CDECL __attribute__((__cdecl__)) + #define ASMJIT_STDCALL __attribute__((__stdcall__)) + #define ASMJIT_FASTCALL __attribute__((__fastcall__)) + #define ASMJIT_REGPARM(N) __attribute__((__regparm__(N))) +#elif ASMJIT_ARCH_X86 == 32 && defined(_MSC_VER) + #define ASMJIT_CDECL __cdecl + #define ASMJIT_STDCALL __stdcall + #define ASMJIT_FASTCALL __fastcall + #define ASMJIT_REGPARM(N) +#else + #define ASMJIT_CDECL + #define ASMJIT_STDCALL + #define ASMJIT_FASTCALL + #define ASMJIT_REGPARM(N) +#endif + +#if ASMJIT_ARCH_X86 && defined(_WIN32) && defined(_MSC_VER) + #define ASMJIT_VECTORCALL __vectorcall +#elif ASMJIT_ARCH_X86 && defined(_WIN32) + #define ASMJIT_VECTORCALL __attribute__((__vectorcall__)) +#else + #define ASMJIT_VECTORCALL +#endif + +// Type alignment (not allowed by C++11 'alignas' keyword). +#if defined(__GNUC__) + #define ASMJIT_ALIGN_TYPE(TYPE, N) __attribute__((__aligned__(N))) TYPE +#elif defined(_MSC_VER) + #define ASMJIT_ALIGN_TYPE(TYPE, N) __declspec(align(N)) TYPE +#else + #define ASMJIT_ALIGN_TYPE(TYPE, N) TYPE +#endif + +//! \def ASMJIT_MAY_ALIAS +//! +//! Expands to `__attribute__((__may_alias__))` if supported. +#if defined(__GNUC__) + #define ASMJIT_MAY_ALIAS __attribute__((__may_alias__)) +#else + #define ASMJIT_MAY_ALIAS +#endif + +//! \def ASMJIT_MAYBE_UNUSED +//! +//! Expands to `[[maybe_unused]]` if supported or a compiler attribute instead. +#if __cplusplus >= 201703L + #define ASMJIT_MAYBE_UNUSED [[maybe_unused]] +#elif defined(__GNUC__) + #define ASMJIT_MAYBE_UNUSED __attribute__((unused)) +#else + #define ASMJIT_MAYBE_UNUSED +#endif + +#if defined(__clang_major__) && __clang_major__ >= 4 && !defined(_DOXYGEN) + // NOTE: Clang allows to apply this attribute to function arguments, which is what we want. Once GCC decides to + // support this use, we will enable it for GCC as well. However, until that, it will be clang only, which is + // what we need for static analysis. + #define ASMJIT_NONNULL(FUNCTION_ARGUMENT) FUNCTION_ARGUMENT __attribute__((__nonnull__)) +#else + #define ASMJIT_NONNULL(FUNCTION_ARGUMENT) FUNCTION_ARGUMENT +#endif + +//! \def ASMJIT_NOEXCEPT_TYPE +//! +//! Defined to `noexcept` in C++17 mode or nothing otherwise. Used by function typedefs. +#if __cplusplus >= 201703L + #define ASMJIT_NOEXCEPT_TYPE noexcept +#else + #define ASMJIT_NOEXCEPT_TYPE +#endif + +//! \def ASMJIT_ASSUME(...) +//! +//! Macro that tells the C/C++ compiler that the expression `...` evaluates to true. +//! +//! This macro has two purposes: +//! +//! 1. Enable optimizations that would not be possible without the assumption. +//! 2. Hint static analysis tools that a certain condition is true to prevent false positives. +#if defined(__clang__) + #define ASMJIT_ASSUME(...) __builtin_assume(__VA_ARGS__) +#elif defined(__GNUC__) + #define ASMJIT_ASSUME(...) do { if (!(__VA_ARGS__)) __builtin_unreachable(); } while (0) +#elif defined(_MSC_VER) + #define ASMJIT_ASSUME(...) __assume(__VA_ARGS__) +#else + #define ASMJIT_ASSUME(...) (void)0 +#endif + +//! \def ASMJIT_LIKELY(...) +//! +//! Condition is likely to be taken (mostly error handling and edge cases). + +//! \def ASMJIT_UNLIKELY(...) +//! +//! Condition is unlikely to be taken (mostly error handling and edge cases). +#if defined(__GNUC__) + #define ASMJIT_LIKELY(...) __builtin_expect(!!(__VA_ARGS__), 1) + #define ASMJIT_UNLIKELY(...) __builtin_expect(!!(__VA_ARGS__), 0) +#else + #define ASMJIT_LIKELY(...) (__VA_ARGS__) + #define ASMJIT_UNLIKELY(...) (__VA_ARGS__) +#endif + +//! \def ASMJIT_FALLTHROUGH +//! +//! Portable [[fallthrough]] attribute. +#if defined(__clang__) && __cplusplus >= 201103L + #define ASMJIT_FALLTHROUGH [[clang::fallthrough]] +#elif defined(__GNUC__) && __GNUC__ >= 7 + #define ASMJIT_FALLTHROUGH __attribute__((__fallthrough__)) +#else + #define ASMJIT_FALLTHROUGH ((void)0) /* fallthrough */ +#endif + +//! \def ASMJIT_DEPRECATED +//! +//! Marks function, class, struct, enum, or anything else as deprecated. +#if defined(__GNUC__) + #define ASMJIT_DEPRECATED(MESSAGE) __attribute__((__deprecated__(MESSAGE))) +#elif defined(_MSC_VER) + #define ASMJIT_DEPRECATED(MESSAGE) __declspec(deprecated(MESSAGE)) +#else + #define ASMJIT_DEPRECATED(MESSAGE) +#endif + +// Utilities. +#define ASMJIT_OFFSET_OF(STRUCT, MEMBER) ((int)(intptr_t)((const char*)&((const STRUCT*)0x100)->MEMBER) - 0x100) +#define ASMJIT_ARRAY_SIZE(X) uint32_t(sizeof(X) / sizeof(X[0])) + +#if ASMJIT_CXX_HAS_ATTRIBUTE(no_sanitize, 0) + #define ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF __attribute__((__no_sanitize__("undefined"))) +#elif defined(__GNUC__) && __GNUC__ >= 5 + #define ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF __attribute__((__no_sanitize_undefined__)) +#else + #define ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF +#endif + +// Diagnostic Macros +// ====================================== + +#if !defined(__clang__) && !defined(__INTEL_COMPILER) && !defined(_DOXYGEN) + #if defined(__GNUC__) && __GNUC__ == 4 + // There is a bug in GCC 4.X that has been fixed in GCC 5+, so just silence the warning. + #define ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + _Pragma("GCC diagnostic push") \ + _Pragma("GCC diagnostic ignored \"-Wmissing-field-initializers\"") + #define ASMJIT_END_DIAGNOSTIC_SCOPE \ + _Pragma("GCC diagnostic pop") + #elif defined(_MSC_VER) + #define ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + __pragma(warning(push)) \ + __pragma(warning(disable: 4127)) /* conditional expression is const */ \ + __pragma(warning(disable: 4201)) /* nameless struct/union */ + #define ASMJIT_END_DIAGNOSTIC_SCOPE \ + __pragma(warning(pop)) + #endif +#endif + +#if !defined(ASMJIT_BEGIN_DIAGNOSTIC_SCOPE) && !defined(ASMJIT_END_DIAGNOSTIC_SCOPE) + #define ASMJIT_BEGIN_DIAGNOSTIC_SCOPE + #define ASMJIT_END_DIAGNOSTIC_SCOPE +#endif + +// Begin-Namespace & End-Namespace Macros +// ====================================== + +#if !defined(ASMJIT_NO_ABI_NAMESPACE) && !defined(_DOXYGEN) + #define ASMJIT_BEGIN_NAMESPACE \ + ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + namespace asmjit { \ + inline namespace ASMJIT_ABI_NAMESPACE { + #define ASMJIT_END_NAMESPACE \ + }} \ + ASMJIT_END_DIAGNOSTIC_SCOPE +#else + #define ASMJIT_BEGIN_NAMESPACE \ + ASMJIT_BEGIN_DIAGNOSTIC_SCOPE \ + namespace asmjit { + #define ASMJIT_END_NAMESPACE \ + } \ + ASMJIT_END_DIAGNOSTIC_SCOPE +#endif + +#define ASMJIT_BEGIN_SUB_NAMESPACE(NAMESPACE) ASMJIT_BEGIN_NAMESPACE namespace NAMESPACE { +#define ASMJIT_END_SUB_NAMESPACE } ASMJIT_END_NAMESPACE + +// C++ Utilities +// ============= + +#define ASMJIT_NONCOPYABLE(Type) \ + Type(const Type& other) = delete; \ + Type& operator=(const Type& other) = delete; + +#define ASMJIT_NONCONSTRUCTIBLE(Type) \ + Type() = delete; \ + Type(const Type& other) = delete; \ + Type& operator=(const Type& other) = delete; + +//! \def ASMJIT_DEFINE_ENUM_FLAGS(T) +//! +//! Defines bit operations for enumeration flags. +#ifdef _DOXYGEN + #define ASMJIT_DEFINE_ENUM_FLAGS(T) +#else + #define ASMJIT_DEFINE_ENUM_FLAGS(T) \ + static ASMJIT_INLINE_NODEBUG constexpr T operator~(T a) noexcept { \ + return T(~(std::underlying_type::type)(a)); \ + } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr T operator|(T a, T b) noexcept { \ + return T((std::underlying_type::type)(a) | \ + (std::underlying_type::type)(b)); \ + } \ + static ASMJIT_INLINE_NODEBUG constexpr T operator&(T a, T b) noexcept { \ + return T((std::underlying_type::type)(a) & \ + (std::underlying_type::type)(b)); \ + } \ + static ASMJIT_INLINE_NODEBUG constexpr T operator^(T a, T b) noexcept { \ + return T((std::underlying_type::type)(a) ^ \ + (std::underlying_type::type)(b)); \ + } \ + \ + static ASMJIT_INLINE_NODEBUG T& operator|=(T& a, T b) noexcept { \ + a = T((std::underlying_type::type)(a) | \ + (std::underlying_type::type)(b)); \ + return a; \ + } \ + static ASMJIT_INLINE_NODEBUG T& operator&=(T& a, T b) noexcept { \ + a = T((std::underlying_type::type)(a) & \ + (std::underlying_type::type)(b)); \ + return a; \ + } \ + static ASMJIT_INLINE_NODEBUG T& operator^=(T& a, T b) noexcept { \ + a = T((std::underlying_type::type)(a) ^ \ + (std::underlying_type::type)(b)); \ + return a; \ + } +#endif + +//! \def ASMJIT_DEFINE_ENUM_COMPARE(T) +//! +//! Defines comparison operations for enumeration flags. +#if defined(_DOXYGEN) || (defined(_MSC_VER) && _MSC_VER <= 1900) + #define ASMJIT_DEFINE_ENUM_COMPARE(T) +#else + #define ASMJIT_DEFINE_ENUM_COMPARE(T) \ + static ASMJIT_INLINE_NODEBUG bool operator<(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) < (std::underlying_type::type)(b); \ + } \ + static ASMJIT_INLINE_NODEBUG bool operator<=(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) <= (std::underlying_type::type)(b); \ + } \ + static ASMJIT_INLINE_NODEBUG bool operator>(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) > (std::underlying_type::type)(b); \ + } \ + static ASMJIT_INLINE_NODEBUG bool operator>=(T a, T b) noexcept { \ + return (std::underlying_type::type)(a) >= (std::underlying_type::type)(b); \ + } +#endif + +#endif // ASMJIT_CORE_API_CONFIG_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/archcommons.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/archcommons.h new file mode 100644 index 0000000000000000000000000000000000000000..f14648bf004d31cf2e1d2422a5a615da04462332 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/archcommons.h @@ -0,0 +1,261 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ARCHCOMMONS_H_INCLUDED +#define ASMJIT_CORE_ARCHCOMMONS_H_INCLUDED + +// This file provides architecture-specific classes that are required in the core library. For example Imm operand +// allows to be created from arm::Shift in a const-expr way, so the arm::Shift must be provided. So this header file +// provides everything architecture-specific that is used by the Core API. + +#include "../core/globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(arm) + +//! \addtogroup asmjit_arm +//! \{ + +//! Condition code (both AArch32 & AArch64). +//! +//! \note This enumeration doesn't match condition code that is used in AArch32/AArch64 opcodes. In general this +//! condition code is encoded as `(cc - 2) & 0xF` so that `kAL` condition code is zero and encoded as 0xE in opcode. +//! This makes it easier to use a condition code as an instruction modifier that defaults to 'al'. +enum class CondCode : uint8_t { + kAL = 0x00u, //!< (no condition code) (always) + kNA = 0x01u, //!< (not available) (special) + kEQ = 0x02u, //!< Z==1 (any_sign ==) + kNE = 0x03u, //!< Z==0 (any_sign !=) + kCS = 0x04u, //!< C==1 (unsigned >=) + kHS = 0x04u, //!< C==1 (unsigned >=) + kLO = 0x05u, //!< C==0 (unsigned < ) + kCC = 0x05u, //!< C==0 (unsigned < ) + kMI = 0x06u, //!< N==1 (is negative) + kPL = 0x07u, //!< N==0 (is positive or zero) + kVS = 0x08u, //!< V==1 (is overflow) + kVC = 0x09u, //!< V==0 (no overflow) + kHI = 0x0Au, //!< C==1 & Z==0 (unsigned > ) + kLS = 0x0Bu, //!< C==0 | Z==1 (unsigned <=) + kGE = 0x0Cu, //!< N==V (signed >=) + kLT = 0x0Du, //!< N!=V (signed < ) + kGT = 0x0Eu, //!< Z==0 & N==V (signed > ) + kLE = 0x0Fu, //!< Z==1 | N!=V (signed <=) + + kZero = kEQ, //!< Zero flag (alias to equal). + kNotZero = kNE, //!< Not zero (alias to Not Equal). + + kEqual = kEQ, //!< Equal `a == b`. + kNotEqual = kNE, //!< Not Equal `a != b`. + + kCarry = kCS, //!< Carry flag. + kNotCarry = kCC, //!< Not carry. + + kSign = kMI, //!< Sign flag. + kNotSign = kPL, //!< Not sign. + + kNegative = kMI, //!< Negative. + kPositive = kPL, //!< Positive or zero. + + kOverflow = kVS, //!< Signed overflow. + kNotOverflow = kVC, //!< Not signed overflow. + + kSignedLT = kLT, //!< Signed `a < b`. + kSignedLE = kLE, //!< Signed `a <= b`. + kSignedGT = kGT, //!< Signed `a > b`. + kSignedGE = kGE, //!< Signed `a >= b`. + + kUnsignedLT = kLO, //!< Unsigned `a < b`. + kUnsignedLE = kLS, //!< Unsigned `a <= b`. + kUnsignedGT = kHI, //!< Unsigned `a > b`. + kUnsignedGE = kHS, //!< Unsigned `a >= b`. + + kBTZero = kZero, //!< Tested bit is zero. + kBTNotZero = kNotZero, //!< Tested bit is not zero. + + kAlways = kAL, //!< No condition code (always). + + kMaxValue = 0x0Fu //!< Maximum value of `CondCode`. +}; + + +//! \cond +static constexpr CondCode _reverseCondTable[] = { + CondCode::kAL, // AL <- AL + CondCode::kNA, // NA <- NA + CondCode::kEQ, // EQ <- EQ + CondCode::kNE, // NE <- NE + CondCode::kLS, // LS <- CS + CondCode::kHI, // HI <- LO + CondCode::kMI, // MI <- MI + CondCode::kPL, // PL <- PL + CondCode::kVS, // VS <- VS + CondCode::kVC, // VC <- VC + CondCode::kLO, // LO <- HI + CondCode::kCS, // CS <- LS + CondCode::kLE, // LE <- GE + CondCode::kGT, // GT <- LT + CondCode::kLT, // LT <- GT + CondCode::kGE // GE <- LE +}; +//! \endcond + +//! Reverses a condition code (reverses the corresponding operands of a comparison). +static ASMJIT_INLINE_NODEBUG constexpr CondCode reverseCond(CondCode cond) noexcept { return _reverseCondTable[uint8_t(cond)]; } +//! Negates a condition code. +static ASMJIT_INLINE_NODEBUG constexpr CondCode negateCond(CondCode cond) noexcept { return CondCode(uint8_t(cond) ^ uint8_t(1)); } + +//! Memory offset mode. +//! +//! Describes either fixed, pre-index, or post-index offset modes. +enum class OffsetMode : uint32_t { + //! Fixed offset mode (either no index at all or a regular index without a write-back). + kFixed = 0u, + //! Pre-index "[BASE, #Offset {, }]!" with write-back. + kPreIndex = 1u, + //! Post-index "[BASE], #Offset {, }" with write-back. + kPostIndex = 2u +}; + +//! Shift operation predicate (ARM) describes either SHIFT or EXTEND operation. +//! +//! \note The constants are AsmJit specific. The first 5 values describe real constants on ARM32 and AArch64 hardware, +//! however, the addition constants that describe extend modes are specific to AsmJit and would be translated to the +//! AArch64 specific constants by the assembler. +enum class ShiftOp : uint32_t { + //! Shift left logical operation (default). + //! + //! Available to all ARM architectures. + kLSL = 0x00u, + + //! Shift right logical operation. + //! + //! Available to all ARM architectures. + kLSR = 0x01u, + + //! Shift right arithmetic operation. + //! + //! Available to all ARM architectures. + kASR = 0x02u, + + //! Rotate right operation (AArch32 only). + kROR = 0x03u, + + //! Rotate right with carry operation (encoded as `ShiftOp::kROR` with zero) (AArch32 only). + kRRX = 0x04u, + + //! Shift left by filling low order bits with ones. + kMSL = 0x05u, + + //! UXTN extend register operation (AArch64 only). + kUXTB = 0x06u, + //! UXTH extend register operation (AArch64 only). + kUXTH = 0x07u, + //! UXTW extend register operation (AArch64 only). + kUXTW = 0x08u, + //! UXTX extend register operation (AArch64 only). + kUXTX = 0x09u, + + //! SXTB extend register operation (AArch64 only). + kSXTB = 0x0Au, + //! SXTH extend register operation (AArch64 only). + kSXTH = 0x0Bu, + //! SXTW extend register operation (AArch64 only). + kSXTW = 0x0Cu, + //! SXTX extend register operation (AArch64 only). + kSXTX = 0x0Du + + // NOTE: 0xE and 0xF are used by memory operand to specify POST|PRE offset mode. +}; + +//! Represents ARM immediate shift operation type and value. +class Shift { +public: + //! Shift operation. + ShiftOp _op; + //! Shift Value. + uint32_t _value; + + //! Default constructed Shift is not initialized. + ASMJIT_INLINE_NODEBUG Shift() noexcept = default; + + //! Copy constructor (default) + ASMJIT_INLINE_NODEBUG constexpr Shift(const Shift& other) noexcept = default; + + //! Constructs Shift from operation `op` and shift `value`. + ASMJIT_INLINE_NODEBUG constexpr Shift(ShiftOp op, uint32_t value) noexcept + : _op(op), + _value(value) {} + + //! Returns the shift operation. + ASMJIT_INLINE_NODEBUG constexpr ShiftOp op() const noexcept { return _op; } + //! Sets shift operation to `op`. + ASMJIT_INLINE_NODEBUG void setOp(ShiftOp op) noexcept { _op = op; } + + //! Returns the shift amount. + ASMJIT_INLINE_NODEBUG constexpr uint32_t value() const noexcept { return _value; } + //! Sets shift amount to `value`. + ASMJIT_INLINE_NODEBUG void setValue(uint32_t value) noexcept { _value = value; } +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +ASMJIT_BEGIN_SUB_NAMESPACE(a32) + +using namespace arm; + +//! Data type that can be encoded with AArch32 instruction identifier. +//! +//! \note Data types are frequently used with AArch32 SIMD instructions. For example `VMAX` instruction can +//! use almost all datatypes in a form `VMAX.F32`, `VMAX.S16`, `VMAX.U32`, etc... Emitter automatically adds +//! the required data type at emit level. +enum class DataType : uint32_t { + //! No data type specified (default for all general purpose instructions). + kNone = 0, + //! 8-bit signed integer, specified as `.s8` in assembly. + kS8 = 1, + //! 16-bit signed integer, specified as `.s16` in assembly. + kS16 = 2, + //! 32-bit signed integer, specified as `.s32` in assembly. + kS32 = 3, + //! 64-bit signed integer, specified as `.s64` in assembly. + kS64 = 4, + //! 8-bit unsigned integer, specified as `.u8` in assembly. + kU8 = 5, + //! 16-bit unsigned integer, specified as `.u16` in assembly. + kU16 = 6, + //! 32-bit unsigned integer, specified as `.u32` in assembly. + kU32 = 7, + //! 64-bit unsigned integer, specified as `.u64` in assembly. + kU64 = 8, + //! 16-bit floating point (half precision), specified as `.f16` in assembly. + kF16 = 10, + //! 32-bit floating point (single precision), specified as `.f32` in assembly. + kF32 = 11, + //! 64-bit floating point (double precision), specified as `.f64` in assembly. + kF64 = 12, + //! 8-bit polynomial. + kP8 = 13, + //! 16-bit BF16 floating point. + kBF16 = 14, + //! 64-bit polynomial. + kP64 = 15, + + //! Maximum value of `DataType`. + kMaxValue = 15 +}; + +static ASMJIT_INLINE_NODEBUG uint32_t dataTypeSize(DataType dt) noexcept { + static constexpr uint8_t table[] = { 0, 1, 2, 4, 8, 1, 2, 4, 8, 2, 4, 8, 1, 2, 8 }; + return table[size_t(dt)]; +} + +ASMJIT_END_SUB_NAMESPACE + +ASMJIT_BEGIN_SUB_NAMESPACE(a64) +using namespace arm; +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_CORE_ARCHCOMMONS_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/archtraits.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/archtraits.h new file mode 100644 index 0000000000000000000000000000000000000000..507618b70b5aedc695158baba29ac3bdaa86a88d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/archtraits.h @@ -0,0 +1,293 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ARCHTRAITS_H_INCLUDED +#define ASMJIT_CORE_ARCHTRAITS_H_INCLUDED + +#include "../core/operand.h" +#include "../core/support.h" +#include "../core/type.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Instruction set architecture (ISA). +enum class Arch : uint8_t { + //! Unknown or uninitialized ISA. + kUnknown = 0, + + //! 32-bit X86 ISA. + kX86 = 1, + //! 64-bit X86 ISA also known as X64, X86_64, and AMD64. + kX64 = 2, + + //! 32-bit RISC-V ISA. + kRISCV32 = 3, + //! 64-bit RISC-V ISA. + kRISCV64 = 4, + + //! 32-bit ARM ISA (little endian). + kARM = 5, + //! 64-bit ARM ISA in (little endian). + kAArch64 = 6, + //! 32-bit ARM ISA in Thumb mode (little endian). + kThumb = 7, + + // 8 is not used at the moment, even numbers are 64-bit architectures. + + //! 32-bit MIPS ISA in (little endian). + kMIPS32_LE = 9, + //! 64-bit MIPS ISA in (little endian). + kMIPS64_LE = 10, + + //! 32-bit ARM ISA (big endian). + kARM_BE = 11, + //! 64-bit ARM ISA in (big endian). + kAArch64_BE = 12, + //! 32-bit ARM ISA in Thumb mode (big endian). + kThumb_BE = 13, + + // 14 is not used at the moment, even numbers are 64-bit architectures. + + //! 32-bit MIPS ISA in (big endian). + kMIPS32_BE = 15, + //! 64-bit MIPS ISA in (big endian). + kMIPS64_BE = 16, + + //! Maximum value of `Arch`. + kMaxValue = kMIPS64_BE, + + //! Mask used by 32-bit ISAs (odd are 32-bit, even are 64-bit). + k32BitMask = 0x01, + //! First big-endian architecture. + kBigEndian = kARM_BE, + + //! ISA detected at compile-time (ISA of the host). + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#else + ASMJIT_ARCH_X86 == 32 ? kX86 : + ASMJIT_ARCH_X86 == 64 ? kX64 : + + ASMJIT_ARCH_RISCV == 32 ? kRISCV32 : + ASMJIT_ARCH_RISCV == 64 ? kRISCV64 : + + ASMJIT_ARCH_ARM == 32 && ASMJIT_ARCH_LE ? kARM : + ASMJIT_ARCH_ARM == 32 && ASMJIT_ARCH_BE ? kARM_BE : + ASMJIT_ARCH_ARM == 64 && ASMJIT_ARCH_LE ? kAArch64 : + ASMJIT_ARCH_ARM == 64 && ASMJIT_ARCH_BE ? kAArch64_BE : + + ASMJIT_ARCH_MIPS == 32 && ASMJIT_ARCH_LE ? kMIPS32_LE : + ASMJIT_ARCH_MIPS == 32 && ASMJIT_ARCH_BE ? kMIPS32_BE : + ASMJIT_ARCH_MIPS == 64 && ASMJIT_ARCH_LE ? kMIPS64_LE : + ASMJIT_ARCH_MIPS == 64 && ASMJIT_ARCH_BE ? kMIPS64_BE : + + kUnknown +#endif +}; + +//! Sub-architecture. +enum class SubArch : uint8_t { + //! Unknown or uninitialized architecture sub-type. + kUnknown = 0, + + //! Maximum value of `SubArch`. + kMaxValue = kUnknown, + + //! Sub-architecture detected at compile-time (sub-architecture of the host). + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#else + kUnknown +#endif +}; + +//! Identifier used to represent names of different data types across architectures. +enum class ArchTypeNameId : uint8_t { + //! Describes 'db' (X86/X86_64 convention, always 8-bit quantity). + kDB = 0, + //! Describes 'dw' (X86/X86_64 convention, always 16-bit word). + kDW, + //! Describes 'dd' (X86/X86_64 convention, always 32-bit word). + kDD, + //! Describes 'dq' (X86/X86_64 convention, always 64-bit word). + kDQ, + //! Describes 'byte' (always 8-bit quantity). + kByte, + //! Describes 'half' (most likely 16-bit word). + kHalf, + //! Describes 'word' (either 16-bit or 32-bit word). + kWord, + //! Describes 'hword' (most likely 16-bit word). + kHWord, + //! Describes 'dword' (either 32-bit or 64-bit word). + kDWord, + //! Describes 'qword' (64-bit word). + kQWord, + //! Describes 'xword' (64-bit word). + kXWord, + //! Describes 'short' (always 16-bit word). + kShort, + //! Describes 'long' (most likely 32-bit word). + kLong, + //! Describes 'quad' (64-bit word). + kQuad, + + //! Maximum value of `ArchTypeNameId`. + kMaxValue = kQuad +}; + +//! Instruction feature hints for each register group provided by \ref ArchTraits. +//! +//! Instruction feature hints describe miscellaneous instructions provided by the architecture that can be used by +//! register allocator to make certain things simpler - like register swaps or emitting register push/pop sequences. +//! +//! \remarks Instruction feature hints are only defined for register groups that can be used with \ref +//! asmjit_compiler infrastructure. Register groups that are not managed by Compiler are not provided by +//! \ref ArchTraits and cannot be queried. +enum class InstHints : uint8_t { + //! No feature hints. + kNoHints = 0, + + //! Architecture supports a register swap by using a single instruction. + kRegSwap = 0x01u, + //! Architecture provides push/pop instructions. + kPushPop = 0x02u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstHints) + +//! Architecture traits used by Function API and Compiler's register allocator. +struct ArchTraits { + //! \name Members + //! \{ + + //! Stack pointer register id. + uint8_t _spRegId; + //! Frame pointer register id. + uint8_t _fpRegId; + //! Link register id. + uint8_t _linkRegId; + //! Instruction pointer (or program counter) register id, if accessible. + uint8_t _ipRegId; + + // Reserved. + uint8_t _reserved[3]; + //! Hardware stack alignment requirement. + uint8_t _hwStackAlignment; + + //! Minimum addressable offset on stack guaranteed for all instructions. + uint32_t _minStackOffset; + //! Maximum addressable offset on stack depending on specific instruction. + uint32_t _maxStackOffset; + + //! Flags for each virtual register group. + Support::Array _instHints; + + //! Maps register type into a signature, that provides group, size and can be used to construct register operands. + Support::Array _regSignature; + //! Maps a register to type-id, see \ref TypeId. + Support::Array _regTypeToTypeId; + //! Maps scalar TypeId values (from TypeId::_kIdBaseStart) to register types, see \ref TypeId. + Support::Array _typeIdToRegType; + + //! Word name identifiers of 8-bit, 16-bit, 32-biit, and 64-bit quantities that appear in formatted text. + ArchTypeNameId _typeNameIdTable[4]; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns stack pointer register id. + ASMJIT_INLINE_NODEBUG uint32_t spRegId() const noexcept { return _spRegId; } + //! Returns stack frame register id. + ASMJIT_INLINE_NODEBUG uint32_t fpRegId() const noexcept { return _fpRegId; } + //! Returns link register id, if the architecture provides it. + ASMJIT_INLINE_NODEBUG uint32_t linkRegId() const noexcept { return _linkRegId; } + //! Returns instruction pointer register id, if the architecture provides it. + ASMJIT_INLINE_NODEBUG uint32_t ipRegId() const noexcept { return _ipRegId; } + + //! Returns a hardware stack alignment requirement. + //! + //! \note This is a hardware constraint. Architectures that don't constrain it would return the lowest alignment + //! (1), however, some architectures may constrain the alignment, for example AArch64 requires 16-byte alignment. + ASMJIT_INLINE_NODEBUG uint32_t hwStackAlignment() const noexcept { return _hwStackAlignment; } + + //! Tests whether the architecture provides link register, which is used across function calls. If the link + //! register is not provided then a function call pushes the return address on stack (X86/X64). + ASMJIT_INLINE_NODEBUG bool hasLinkReg() const noexcept { return _linkRegId != BaseReg::kIdBad; } + + //! Returns minimum addressable offset on stack guaranteed for all instructions. + ASMJIT_INLINE_NODEBUG uint32_t minStackOffset() const noexcept { return _minStackOffset; } + //! Returns maximum addressable offset on stack depending on specific instruction. + ASMJIT_INLINE_NODEBUG uint32_t maxStackOffset() const noexcept { return _maxStackOffset; } + + //! Returns ISA flags of the given register `group`. + ASMJIT_INLINE_NODEBUG InstHints instFeatureHints(RegGroup group) const noexcept { return _instHints[group]; } + //! Tests whether the given register `group` has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasInstHint(RegGroup group, InstHints feature) const noexcept { return Support::test(_instHints[group], feature); } + //! Tests whether the ISA provides register swap instruction for the given register `group`. + ASMJIT_INLINE_NODEBUG bool hasInstRegSwap(RegGroup group) const noexcept { return hasInstHint(group, InstHints::kRegSwap); } + //! Tests whether the ISA provides push/pop instructions for the given register `group`. + ASMJIT_INLINE_NODEBUG bool hasInstPushPop(RegGroup group) const noexcept { return hasInstHint(group, InstHints::kPushPop); } + + ASMJIT_INLINE_NODEBUG bool hasRegType(RegType type) const noexcept { + return type <= RegType::kMaxValue && _regSignature[type].isValid(); + } + + //! Returns an operand signature from the given register `type` of this architecture. + ASMJIT_INLINE_NODEBUG OperandSignature regTypeToSignature(RegType type) const noexcept { return _regSignature[type]; } + //! Returns a register from the given register `type` of this architecture. + ASMJIT_INLINE_NODEBUG RegGroup regTypeToGroup(RegType type) const noexcept { return _regSignature[type].regGroup(); } + //! Returns a register size the given register `type` of this architecture. + ASMJIT_INLINE_NODEBUG uint32_t regTypeToSize(RegType type) const noexcept { return _regSignature[type].size(); } + //! Returns a corresponding `TypeId` from the given register `type` of this architecture. + ASMJIT_INLINE_NODEBUG TypeId regTypeToTypeId(RegType type) const noexcept { return _regTypeToTypeId[type]; } + + //! Returns a table of ISA word names that appear in formatted text. Word names are ISA dependent. + //! + //! The index of this table is log2 of the size: + //! - [0] 8-bits + //! - [1] 16-bits + //! - [2] 32-bits + //! - [3] 64-bits + ASMJIT_INLINE_NODEBUG const ArchTypeNameId* typeNameIdTable() const noexcept { return _typeNameIdTable; } + + //! Returns an ISA word name identifier of the given `index`, see \ref typeNameIdTable() for more details. + ASMJIT_INLINE_NODEBUG ArchTypeNameId typeNameIdByIndex(uint32_t index) const noexcept { return _typeNameIdTable[index]; } + + //! \} + + //! \name Statics + //! \{ + + //! Returns a const reference to `ArchTraits` for the given architecture `arch`. + static ASMJIT_INLINE_NODEBUG const ArchTraits& byArch(Arch arch) noexcept; + + //! \} +}; + +ASMJIT_VARAPI const ArchTraits _archTraits[uint32_t(Arch::kMaxValue) + 1]; + +//! \cond +ASMJIT_INLINE_NODEBUG const ArchTraits& ArchTraits::byArch(Arch arch) noexcept { return _archTraits[uint32_t(arch)]; } +//! \endcond + +//! Architecture utilities. +namespace ArchUtils { + +ASMJIT_API Error typeIdToRegSignature(Arch arch, TypeId typeId, TypeId* typeIdOut, OperandSignature* regSignatureOut) noexcept; + +} // {ArchUtils} + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ARCHTRAITS_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/assembler.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/assembler.h new file mode 100644 index 0000000000000000000000000000000000000000..7609ade579ff919515ea5ed47d8355987e897209 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/assembler.h @@ -0,0 +1,130 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ASSEMBLER_H_INCLUDED +#define ASMJIT_CORE_ASSEMBLER_H_INCLUDED + +#include "../core/codeholder.h" +#include "../core/emitter.h" +#include "../core/operand.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_assembler +//! \{ + +//! Base assembler. +//! +//! This is a base class that provides interface used by architecture specific +//! assembler implementations. Assembler doesn't hold any data, instead it's +//! attached to \ref CodeHolder, which provides all the data that Assembler +//! needs and which can be altered by it. +//! +//! Check out architecture specific assemblers for more details and examples: +//! +//! - \ref x86::Assembler - X86/X64 assembler implementation. +//! - \ref a64::Assembler - AArch64 assembler implementation. +class ASMJIT_VIRTAPI BaseAssembler : public BaseEmitter { +public: + ASMJIT_NONCOPYABLE(BaseAssembler) + typedef BaseEmitter Base; + + //! Current section where the assembling happens. + Section* _section = nullptr; + //! Start of the CodeBuffer of the current section. + uint8_t* _bufferData = nullptr; + //! End (first invalid byte) of the current section. + uint8_t* _bufferEnd = nullptr; + //! Pointer in the CodeBuffer of the current section. + uint8_t* _bufferPtr = nullptr; + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseAssembler` instance. + ASMJIT_API BaseAssembler() noexcept; + //! Destroys the `BaseAssembler` instance. + ASMJIT_API ~BaseAssembler() noexcept override; + + //! \} + + //! \name Code-Buffer Management + //! \{ + + //! Returns the capacity of the current CodeBuffer. + ASMJIT_INLINE_NODEBUG size_t bufferCapacity() const noexcept { return (size_t)(_bufferEnd - _bufferData); } + //! Returns the number of remaining bytes in the current CodeBuffer. + ASMJIT_INLINE_NODEBUG size_t remainingSpace() const noexcept { return (size_t)(_bufferEnd - _bufferPtr); } + + //! Returns the current position in the CodeBuffer. + ASMJIT_INLINE_NODEBUG size_t offset() const noexcept { return (size_t)(_bufferPtr - _bufferData); } + + //! Sets the current position in the CodeBuffer to `offset`. + //! + //! \note The `offset` cannot be greater than buffer size even if it's + //! within the buffer's capacity. + ASMJIT_API Error setOffset(size_t offset); + + //! Returns the start of the CodeBuffer in the current section. + ASMJIT_INLINE_NODEBUG uint8_t* bufferData() const noexcept { return _bufferData; } + //! Returns the end (first invalid byte) in the current section. + ASMJIT_INLINE_NODEBUG uint8_t* bufferEnd() const noexcept { return _bufferEnd; } + //! Returns the current pointer in the CodeBuffer in the current section. + ASMJIT_INLINE_NODEBUG uint8_t* bufferPtr() const noexcept { return _bufferPtr; } + + //! \} + + //! \name Section Management + //! \{ + + //! Returns the current section. + ASMJIT_INLINE_NODEBUG Section* currentSection() const noexcept { return _section; } + + ASMJIT_API Error section(Section* section) override; + + //! \} + + //! \name Label Management + //! \{ + + ASMJIT_API Label newLabel() override; + ASMJIT_API Label newNamedLabel(const char* name, size_t nameSize = SIZE_MAX, LabelType type = LabelType::kGlobal, uint32_t parentId = Globals::kInvalidId) override; + ASMJIT_API Error bind(const Label& label) override; + + //! \} + + //! \name Embed + //! \{ + + ASMJIT_API Error embed(const void* data, size_t dataSize) override; + ASMJIT_API Error embedDataArray(TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1) override; + ASMJIT_API Error embedConstPool(const Label& label, const ConstPool& pool) override; + + ASMJIT_API Error embedLabel(const Label& label, size_t dataSize = 0) override; + ASMJIT_API Error embedLabelDelta(const Label& label, const Label& base, size_t dataSize = 0) override; + + //! \} + + //! \name Comment + //! \{ + + ASMJIT_API Error comment(const char* data, size_t size = SIZE_MAX) override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ASSEMBLER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/builder.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/builder.h new file mode 100644 index 0000000000000000000000000000000000000000..ab7cbee5a7b51febb153a64e5996b1a72883bd2a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/builder.h @@ -0,0 +1,1499 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_BUILDER_H_INCLUDED +#define ASMJIT_CORE_BUILDER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_BUILDER + +#include "../core/assembler.h" +#include "../core/codeholder.h" +#include "../core/constpool.h" +#include "../core/formatter.h" +#include "../core/inst.h" +#include "../core/operand.h" +#include "../core/string.h" +#include "../core/support.h" +#include "../core/type.h" +#include "../core/zone.h" +#include "../core/zonevector.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_builder +//! \{ + +class BaseBuilder; +class Pass; + +class BaseNode; +class InstNode; +class SectionNode; +class LabelNode; +class AlignNode; +class EmbedDataNode; +class EmbedLabelNode; +class ConstPoolNode; +class CommentNode; +class SentinelNode; +class LabelDeltaNode; + +//! Type of node used by \ref BaseBuilder and \ref BaseCompiler. +enum class NodeType : uint8_t { + //! Invalid node (internal, don't use). + kNone = 0, + + // [BaseBuilder] + + //! Node is \ref InstNode. + kInst = 1, + //! Node is \ref SectionNode. + kSection = 2, + //! Node is \ref LabelNode. + kLabel = 3, + //! Node is \ref AlignNode. + kAlign = 4, + //! Node is \ref EmbedDataNode. + kEmbedData = 5, + //! Node is \ref EmbedLabelNode. + kEmbedLabel = 6, + //! Node is \ref EmbedLabelDeltaNode. + kEmbedLabelDelta = 7, + //! Node is \ref ConstPoolNode. + kConstPool = 8, + //! Node is \ref CommentNode. + kComment = 9, + //! Node is \ref SentinelNode. + kSentinel = 10, + + // [BaseCompiler] + + //! Node is \ref JumpNode (acts as InstNode). + kJump = 15, + //! Node is \ref FuncNode (acts as LabelNode). + kFunc = 16, + //! Node is \ref FuncRetNode (acts as InstNode). + kFuncRet = 17, + //! Node is \ref InvokeNode (acts as InstNode). + kInvoke = 18, + + // [UserDefined] + + //! First id of a user-defined node. + kUser = 32 +}; + +//! Node flags, specify what the node is and/or does. +enum class NodeFlags : uint8_t { + //! No flags. + kNone = 0, + //! Node is code that can be executed (instruction, label, align, etc...). + kIsCode = 0x01u, + //! Node is data that cannot be executed (data, const-pool, etc...). + kIsData = 0x02u, + //! Node is informative, can be removed and ignored. + kIsInformative = 0x04u, + //! Node can be safely removed if unreachable. + kIsRemovable = 0x08u, + //! Node does nothing when executed (label, align, explicit nop). + kHasNoEffect = 0x10u, + //! Node is an instruction or acts as it. + kActsAsInst = 0x20u, + //! Node is a label or acts as it. + kActsAsLabel = 0x40u, + //! Node is active (part of the code). + kIsActive = 0x80u +}; +ASMJIT_DEFINE_ENUM_FLAGS(NodeFlags) + +//! Type of the sentinel (purely informative purpose). +enum class SentinelType : uint8_t { + //! Type of the sentinel is not known. + kUnknown = 0u, + //! This is a sentinel used at the end of \ref FuncNode. + kFuncEnd = 1u +}; + +//! Node list. +//! +//! A double-linked list of pointers to \ref BaseNode, managed by \ref BaseBuilder or \ref BaseCompiler. +//! +//! \note At the moment NodeList is just a view, but it's planned that it will get more functionality in the future. +class NodeList { +public: + //! \name Members + //! \{ + + //! First node in the list or nullptr if there are no nodes in the list. + BaseNode* _first = nullptr; + //! Last node in the list or nullptr if there are no nodes in the list. + BaseNode* _last = nullptr; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG NodeList() noexcept {} + + ASMJIT_INLINE_NODEBUG NodeList(BaseNode* first, BaseNode* last) noexcept + : _first(first), + _last(last) {} + + //! \} + + //! \name Reset + //! \{ + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _first = nullptr; + _last = nullptr; + } + + ASMJIT_INLINE_NODEBUG void reset(BaseNode* first, BaseNode* last) noexcept { + _first = first; + _last = last; + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _first == nullptr; } + + ASMJIT_INLINE_NODEBUG BaseNode* first() const noexcept { return _first; } + ASMJIT_INLINE_NODEBUG BaseNode* last() const noexcept { return _last; } + + //! \} +}; + +//! Builder interface. +//! +//! `BaseBuilder` interface was designed to be used as a \ref BaseAssembler replacement in case pre-processing or +//! post-processing of the generated code is required. The code can be modified during or after code generation. +//! Pre processing or post processing can be done manually or through a \ref Pass object. \ref BaseBuilder stores +//! the emitted code as a double-linked list of nodes, which allows O(1) insertion and removal during processing. +//! +//! Check out architecture specific builders for more details and examples: +//! +//! - \ref x86::Builder - X86/X64 builder implementation. +//! - \ref a64::Builder - AArch64 builder implementation. +class ASMJIT_VIRTAPI BaseBuilder : public BaseEmitter { +public: + ASMJIT_NONCOPYABLE(BaseBuilder) + typedef BaseEmitter Base; + + //! \name Members + //! \{ + + //! Base zone used to allocate nodes and passes. + Zone _codeZone; + //! Data zone used to allocate data and names. + Zone _dataZone; + //! Pass zone, passed to `Pass::run()`. + Zone _passZone; + //! Allocator that uses `_codeZone`. + ZoneAllocator _allocator; + + //! Array of `Pass` objects. + ZoneVector _passes {}; + //! Maps section indexes to `LabelNode` nodes. + ZoneVector _sectionNodes {}; + //! Maps label indexes to `LabelNode` nodes. + ZoneVector _labelNodes {}; + + //! Current node (cursor). + BaseNode* _cursor = nullptr; + //! First and last nodes. + NodeList _nodeList; + + //! Flags assigned to each new node. + NodeFlags _nodeFlags = NodeFlags::kNone; + //! The sections links are dirty (used internally). + bool _dirtySectionLinks = false; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseBuilder` instance. + ASMJIT_API BaseBuilder() noexcept; + //! Destroys the `BaseBuilder` instance. + ASMJIT_API ~BaseBuilder() noexcept override; + + //! \} + + //! \name Node Management + //! \{ + + ASMJIT_INLINE_NODEBUG NodeList nodeList() const noexcept { return _nodeList; } + + //! Returns the first node. + ASMJIT_INLINE_NODEBUG BaseNode* firstNode() const noexcept { return _nodeList.first(); } + //! Returns the last node. + ASMJIT_INLINE_NODEBUG BaseNode* lastNode() const noexcept { return _nodeList.last(); } + + //! Allocates and instantiates a new node of type `T` and returns its instance. If the allocation fails `nullptr` + //! is returned. + //! + //! The template argument `T` must be a type that is extends \ref BaseNode. + //! + //! \remarks The pointer returned (if non-null) is owned by the Builder or Compiler. When the Builder/Compiler + //! is destroyed it destroys all nodes it created so no manual memory management is required. + template + inline Error _newNodeT(T** ASMJIT_NONNULL(out), Args&&... args) { + *out = _allocator.newT(this, std::forward(args)...); + if (ASMJIT_UNLIKELY(!*out)) + return reportError(DebugUtils::errored(kErrorOutOfMemory)); + return kErrorOk; + } + + //! Creates a new \ref InstNode. + ASMJIT_API Error newInstNode(InstNode** ASMJIT_NONNULL(out), InstId instId, InstOptions instOptions, uint32_t opCount); + //! Creates a new \ref LabelNode. + ASMJIT_API Error newLabelNode(LabelNode** ASMJIT_NONNULL(out)); + //! Creates a new \ref AlignNode. + ASMJIT_API Error newAlignNode(AlignNode** ASMJIT_NONNULL(out), AlignMode alignMode, uint32_t alignment); + //! Creates a new \ref EmbedDataNode. + ASMJIT_API Error newEmbedDataNode(EmbedDataNode** ASMJIT_NONNULL(out), TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1); + //! Creates a new \ref ConstPoolNode. + ASMJIT_API Error newConstPoolNode(ConstPoolNode** ASMJIT_NONNULL(out)); + //! Creates a new \ref CommentNode. + ASMJIT_API Error newCommentNode(CommentNode** ASMJIT_NONNULL(out), const char* data, size_t size); + + //! Adds `node` after the current and sets the current node to the given `node`. + ASMJIT_API BaseNode* addNode(BaseNode* ASMJIT_NONNULL(node)) noexcept; + //! Inserts the given `node` after `ref`. + ASMJIT_API BaseNode* addAfter(BaseNode* ASMJIT_NONNULL(node), BaseNode* ASMJIT_NONNULL(ref)) noexcept; + //! Inserts the given `node` before `ref`. + ASMJIT_API BaseNode* addBefore(BaseNode* ASMJIT_NONNULL(node), BaseNode* ASMJIT_NONNULL(ref)) noexcept; + //! Removes the given `node`. + ASMJIT_API BaseNode* removeNode(BaseNode* ASMJIT_NONNULL(node)) noexcept; + //! Removes multiple nodes. + ASMJIT_API void removeNodes(BaseNode* first, BaseNode* last) noexcept; + + //! Returns the cursor. + //! + //! When the Builder/Compiler is created it automatically creates a '.text' \ref SectionNode, which will be the + //! initial one. When instructions are added they are always added after the cursor and the cursor is changed + //! to be that newly added node. Use `setCursor()` to change where new nodes are inserted. + ASMJIT_INLINE_NODEBUG BaseNode* cursor() const noexcept { return _cursor; } + + //! Sets the current node to `node` and return the previous one. + ASMJIT_API BaseNode* setCursor(BaseNode* node) noexcept; + + //! Sets the current node without returning the previous node. + //! + //! Only use this function if you are concerned about performance and want this inlined (for example if you set + //! the cursor in a loop, etc...). + ASMJIT_INLINE_NODEBUG void _setCursor(BaseNode* node) noexcept { _cursor = node; } + + //! \} + + //! \name Section Management + //! \{ + + //! Returns a vector of SectionNode objects. + //! + //! \note If a section of some id is not associated with the Builder/Compiler it would be null, so always check + //! for nulls if you iterate over the vector. + ASMJIT_INLINE_NODEBUG const ZoneVector& sectionNodes() const noexcept { + return _sectionNodes; + } + + //! Tests whether the `SectionNode` of the given `sectionId` was registered. + ASMJIT_INLINE_NODEBUG bool hasRegisteredSectionNode(uint32_t sectionId) const noexcept { + return sectionId < _sectionNodes.size() && _sectionNodes[sectionId] != nullptr; + } + + //! Returns or creates a `SectionNode` that matches the given `sectionId`. + //! + //! \remarks This function will either get the existing `SectionNode` or create it in case it wasn't created before. + //! You can check whether a section has a registered `SectionNode` by using `BaseBuilder::hasRegisteredSectionNode()`. + ASMJIT_API Error sectionNodeOf(SectionNode** ASMJIT_NONNULL(out), uint32_t sectionId); + + ASMJIT_API Error section(Section* ASMJIT_NONNULL(section)) override; + + //! Returns whether the section links of active section nodes are dirty. You can update these links by calling + //! `updateSectionLinks()` in such case. + ASMJIT_INLINE_NODEBUG bool hasDirtySectionLinks() const noexcept { return _dirtySectionLinks; } + + //! Updates links of all active section nodes. + ASMJIT_API void updateSectionLinks() noexcept; + + //! \} + + //! \name Label Management + //! \{ + + //! Returns a vector of \ref LabelNode nodes. + //! + //! \note If a label of some id is not associated with the Builder/Compiler it would be null, so always check for + //! nulls if you iterate over the vector. + ASMJIT_INLINE_NODEBUG const ZoneVector& labelNodes() const noexcept { return _labelNodes; } + + //! Tests whether the `LabelNode` of the given `labelId` was registered. + ASMJIT_INLINE_NODEBUG bool hasRegisteredLabelNode(uint32_t labelId) const noexcept { + return labelId < _labelNodes.size() && _labelNodes[labelId] != nullptr; + } + + //! \overload + ASMJIT_INLINE_NODEBUG bool hasRegisteredLabelNode(const Label& label) const noexcept { + return hasRegisteredLabelNode(label.id()); + } + + //! Gets or creates a \ref LabelNode that matches the given `labelId`. + //! + //! \remarks This function will either get the existing `LabelNode` or create it in case it wasn't created before. + //! You can check whether a label has a registered `LabelNode` by calling \ref BaseBuilder::hasRegisteredLabelNode(). + ASMJIT_API Error labelNodeOf(LabelNode** ASMJIT_NONNULL(out), uint32_t labelId); + + //! \overload + ASMJIT_INLINE_NODEBUG Error labelNodeOf(LabelNode** ASMJIT_NONNULL(out), const Label& label) { + return labelNodeOf(out, label.id()); + } + + //! Registers this \ref LabelNode (internal). + //! + //! This function is used internally to register a newly created `LabelNode` with this instance of Builder/Compiler. + //! Use \ref labelNodeOf() functions to get back \ref LabelNode from a label or its identifier. + ASMJIT_API Error registerLabelNode(LabelNode* ASMJIT_NONNULL(node)); + + ASMJIT_API Label newLabel() override; + ASMJIT_API Label newNamedLabel(const char* name, size_t nameSize = SIZE_MAX, LabelType type = LabelType::kGlobal, uint32_t parentId = Globals::kInvalidId) override; + ASMJIT_API Error bind(const Label& label) override; + + //! \} + + //! \name Passes + //! \{ + + //! Returns a vector of `Pass` instances that will be executed by `runPasses()`. + ASMJIT_INLINE_NODEBUG const ZoneVector& passes() const noexcept { return _passes; } + + //! Allocates and instantiates a new pass of type `T` and returns its instance. If the allocation fails `nullptr` is + //! returned. + //! + //! The template argument `T` must be a type that is extends \ref Pass. + //! + //! \remarks The pointer returned (if non-null) is owned by the Builder or Compiler. When the Builder/Compiler is + //! destroyed it destroys all passes it created so no manual memory management is required. + template + inline T* newPassT() noexcept { return _codeZone.newT(); } + + //! \overload + template + inline T* newPassT(Args&&... args) noexcept { return _codeZone.newT(std::forward(args)...); } + + template + inline Error addPassT() { return addPass(newPassT()); } + + template + inline Error addPassT(Args&&... args) { return addPass(newPassT(std::forward(args)...)); } + + //! Returns `Pass` by name. + //! + //! If the pass having the given `name` doesn't exist `nullptr` is returned. + ASMJIT_API Pass* passByName(const char* name) const noexcept; + //! Adds `pass` to the list of passes. + ASMJIT_API Error addPass(Pass* pass) noexcept; + //! Removes `pass` from the list of passes and delete it. + ASMJIT_API Error deletePass(Pass* pass) noexcept; + + //! Runs all passes in order. + ASMJIT_API Error runPasses(); + + //! \} + + //! \name Emit + //! \{ + + ASMJIT_API Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* opExt) override; + + //! \} + + //! \name Align + //! \{ + + ASMJIT_API Error align(AlignMode alignMode, uint32_t alignment) override; + + //! \} + + //! \name Embed + //! \{ + + ASMJIT_API Error embed(const void* data, size_t dataSize) override; + ASMJIT_API Error embedDataArray(TypeId typeId, const void* data, size_t count, size_t repeat = 1) override; + ASMJIT_API Error embedConstPool(const Label& label, const ConstPool& pool) override; + + ASMJIT_API Error embedLabel(const Label& label, size_t dataSize = 0) override; + ASMJIT_API Error embedLabelDelta(const Label& label, const Label& base, size_t dataSize = 0) override; + + //! \} + + //! \name Comment + //! \{ + + ASMJIT_API Error comment(const char* data, size_t size = SIZE_MAX) override; + + //! \} + + //! \name Serialization + //! \{ + + //! Serializes everything the given emitter `dst`. + //! + //! Although not explicitly required the emitter will most probably be of Assembler type. The reason is that + //! there is no known use of serializing nodes held by Builder/Compiler into another Builder-like emitter. + ASMJIT_API Error serializeTo(BaseEmitter* dst); + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! Base node. +//! +//! Every node represents a building-block used by \ref BaseBuilder. It can be instruction, data, label, comment, +//! directive, or any other high-level representation that can be transformed to the building blocks mentioned. +//! Every class that inherits \ref BaseBuilder can define its own high-level nodes that can be later lowered to +//! basic nodes like instructions. +class BaseNode { +public: + ASMJIT_NONCOPYABLE(BaseNode) + + //! \name Members + //! \{ + + union { + struct { + //! Previous node. + BaseNode* _prev; + //! Next node. + BaseNode* _next; + }; + //! Links (an alternative view to previous and next nodes). + BaseNode* _links[2]; + }; + + //! Data shared between all types of nodes. + struct AnyData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Not used by BaseNode. + uint8_t _reserved0; + //! Not used by BaseNode. + uint8_t _reserved1; + }; + + //! Data used by \ref AlignNode. + struct AlignData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Align mode. + AlignMode _alignMode; + //! Not used by AlignNode. + uint8_t _reserved; + }; + + //! Data used by \ref InstNode. + struct InstData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Instruction operands count (used). + uint8_t _opCount; + //! Instruction operands capacity (allocated). + uint8_t _opCapacity; + }; + + //! Data used by \ref EmbedDataNode. + struct EmbedData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Type id. + TypeId _typeId; + //! Size of `_typeId`. + uint8_t _typeSize; + }; + + //! Data used by \ref SentinelNode. + struct SentinelData { + //! Node type. + NodeType _nodeType; + //! Node flags. + NodeFlags _nodeFlags; + //! Sentinel type. + SentinelType _sentinelType; + //! Not used by BaseNode. + uint8_t _reserved1; + }; + + //! Data that can have different meaning depending on \ref NodeType. + union { + //! Data useful by any node type. + AnyData _any; + //! Data specific to \ref AlignNode. + AlignData _alignData; + //! Data specific to \ref InstNode. + InstData _inst; + //! Data specific to \ref EmbedDataNode. + EmbedData _embed; + //! Data specific to \ref SentinelNode. + SentinelData _sentinel; + }; + + //! Node position in code (should be unique). + uint32_t _position; + + //! Value reserved for AsmJit users never touched by AsmJit itself. + union { + //! User data as 64-bit integer. + uint64_t _userDataU64; + //! User data as pointer. + void* _userDataPtr; + }; + + //! Data used exclusively by the current `Pass`. + void* _passData; + + //! Inline comment/annotation or nullptr if not used. + const char* _inlineComment; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseNode` - always use `BaseBuilder` to allocate nodes. + ASMJIT_INLINE_NODEBUG BaseNode(BaseBuilder* cb, NodeType nodeType, NodeFlags nodeFlags = NodeFlags::kNone) noexcept { + _prev = nullptr; + _next = nullptr; + _any._nodeType = nodeType; + _any._nodeFlags = nodeFlags | cb->_nodeFlags; + _any._reserved0 = 0; + _any._reserved1 = 0; + _position = 0; + _userDataU64 = 0; + _passData = nullptr; + _inlineComment = nullptr; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Casts this node to `T*`. + template + ASMJIT_INLINE_NODEBUG T* as() noexcept { return static_cast(this); } + //! Casts this node to `const T*`. + template + ASMJIT_INLINE_NODEBUG const T* as() const noexcept { return static_cast(this); } + + //! Returns previous node or `nullptr` if this node is either first or not + //! part of Builder/Compiler node-list. + ASMJIT_INLINE_NODEBUG BaseNode* prev() const noexcept { return _prev; } + //! Returns next node or `nullptr` if this node is either last or not part + //! of Builder/Compiler node-list. + ASMJIT_INLINE_NODEBUG BaseNode* next() const noexcept { return _next; } + + //! Returns the type of the node, see `NodeType`. + ASMJIT_INLINE_NODEBUG NodeType type() const noexcept { return _any._nodeType; } + + //! Sets the type of the node, see `NodeType` (internal). + //! + //! \remarks You should never set a type of a node to anything else than the initial value. This function is only + //! provided for users that use custom nodes and need to change the type either during construction or later. + ASMJIT_INLINE_NODEBUG void setType(NodeType type) noexcept { _any._nodeType = type; } + + //! Tests whether this node is either `InstNode` or extends it. + ASMJIT_INLINE_NODEBUG bool isInst() const noexcept { return hasFlag(NodeFlags::kActsAsInst); } + //! Tests whether this node is `SectionNode`. + ASMJIT_INLINE_NODEBUG bool isSection() const noexcept { return type() == NodeType::kSection; } + //! Tests whether this node is either `LabelNode` or extends it. + ASMJIT_INLINE_NODEBUG bool isLabel() const noexcept { return hasFlag(NodeFlags::kActsAsLabel); } + //! Tests whether this node is `AlignNode`. + ASMJIT_INLINE_NODEBUG bool isAlign() const noexcept { return type() == NodeType::kAlign; } + //! Tests whether this node is `EmbedDataNode`. + ASMJIT_INLINE_NODEBUG bool isEmbedData() const noexcept { return type() == NodeType::kEmbedData; } + //! Tests whether this node is `EmbedLabelNode`. + ASMJIT_INLINE_NODEBUG bool isEmbedLabel() const noexcept { return type() == NodeType::kEmbedLabel; } + //! Tests whether this node is `EmbedLabelDeltaNode`. + ASMJIT_INLINE_NODEBUG bool isEmbedLabelDelta() const noexcept { return type() == NodeType::kEmbedLabelDelta; } + //! Tests whether this node is `ConstPoolNode`. + ASMJIT_INLINE_NODEBUG bool isConstPool() const noexcept { return type() == NodeType::kConstPool; } + //! Tests whether this node is `CommentNode`. + ASMJIT_INLINE_NODEBUG bool isComment() const noexcept { return type() == NodeType::kComment; } + //! Tests whether this node is `SentinelNode`. + ASMJIT_INLINE_NODEBUG bool isSentinel() const noexcept { return type() == NodeType::kSentinel; } + + //! Tests whether this node is `FuncNode`. + ASMJIT_INLINE_NODEBUG bool isFunc() const noexcept { return type() == NodeType::kFunc; } + //! Tests whether this node is `FuncRetNode`. + ASMJIT_INLINE_NODEBUG bool isFuncRet() const noexcept { return type() == NodeType::kFuncRet; } + //! Tests whether this node is `InvokeNode`. + ASMJIT_INLINE_NODEBUG bool isInvoke() const noexcept { return type() == NodeType::kInvoke; } + + //! Returns the node flags. + ASMJIT_INLINE_NODEBUG NodeFlags flags() const noexcept { return _any._nodeFlags; } + //! Tests whether the node has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasFlag(NodeFlags flag) const noexcept { return Support::test(_any._nodeFlags, flag); } + //! Replaces node flags with `flags`. + ASMJIT_INLINE_NODEBUG void setFlags(NodeFlags flags) noexcept { _any._nodeFlags = flags; } + //! Adds the given `flags` to node flags. + ASMJIT_INLINE_NODEBUG void addFlags(NodeFlags flags) noexcept { _any._nodeFlags |= flags; } + //! Clears the given `flags` from node flags. + ASMJIT_INLINE_NODEBUG void clearFlags(NodeFlags flags) noexcept { _any._nodeFlags &= ~flags; } + + //! Tests whether the node is code that can be executed. + ASMJIT_INLINE_NODEBUG bool isCode() const noexcept { return hasFlag(NodeFlags::kIsCode); } + //! Tests whether the node is data that cannot be executed. + ASMJIT_INLINE_NODEBUG bool isData() const noexcept { return hasFlag(NodeFlags::kIsData); } + //! Tests whether the node is informative only (is never encoded like comment, etc...). + ASMJIT_INLINE_NODEBUG bool isInformative() const noexcept { return hasFlag(NodeFlags::kIsInformative); } + //! Tests whether the node is removable if it's in an unreachable code block. + ASMJIT_INLINE_NODEBUG bool isRemovable() const noexcept { return hasFlag(NodeFlags::kIsRemovable); } + //! Tests whether the node has no effect when executed (label, .align, nop, ...). + ASMJIT_INLINE_NODEBUG bool hasNoEffect() const noexcept { return hasFlag(NodeFlags::kHasNoEffect); } + //! Tests whether the node is part of the code. + ASMJIT_INLINE_NODEBUG bool isActive() const noexcept { return hasFlag(NodeFlags::kIsActive); } + + //! Tests whether the node has a position assigned. + //! + //! \remarks Returns `true` if node position is non-zero. + ASMJIT_INLINE_NODEBUG bool hasPosition() const noexcept { return _position != 0; } + //! Returns node position. + ASMJIT_INLINE_NODEBUG uint32_t position() const noexcept { return _position; } + //! Sets node position. + //! + //! Node position is a 32-bit unsigned integer that is used by Compiler to track where the node is relatively to + //! the start of the function. It doesn't describe a byte position in a binary, instead it's just a pseudo position + //! used by liveness analysis and other tools around Compiler. + //! + //! If you don't use Compiler then you may use `position()` and `setPosition()` freely for your own purposes if + //! the 32-bit value limit is okay for you. + ASMJIT_INLINE_NODEBUG void setPosition(uint32_t position) noexcept { _position = position; } + + //! Returns user data casted to `T*`. + //! + //! User data is dedicated to be used only by AsmJit users and not touched by the library. The data is of a pointer + //! size so you can either store a pointer or `int64_t` value through `setUserDataAsPtr()`, `setUserDataAsInt64()` + //! and `setUserDataAsUInt64()`. + template + ASMJIT_INLINE_NODEBUG T* userDataAsPtr() const noexcept { return static_cast(_userDataPtr); } + //! Returns user data casted to `int64_t`. + ASMJIT_INLINE_NODEBUG int64_t userDataAsInt64() const noexcept { return int64_t(_userDataU64); } + //! Returns user data casted to `uint64_t`. + ASMJIT_INLINE_NODEBUG uint64_t userDataAsUInt64() const noexcept { return _userDataU64; } + + //! Sets user data to `data`. + template + ASMJIT_INLINE_NODEBUG void setUserDataAsPtr(T* data) noexcept { _userDataPtr = static_cast(data); } + //! Sets used data to the given 64-bit signed `value`. + ASMJIT_INLINE_NODEBUG void setUserDataAsInt64(int64_t value) noexcept { _userDataU64 = uint64_t(value); } + //! Sets used data to the given 64-bit unsigned `value`. + ASMJIT_INLINE_NODEBUG void setUserDataAsUInt64(uint64_t value) noexcept { _userDataU64 = value; } + + //! Resets user data to zero / nullptr. + ASMJIT_INLINE_NODEBUG void resetUserData() noexcept { _userDataU64 = 0; } + + //! Tests whether the node has an associated pass data. + ASMJIT_INLINE_NODEBUG bool hasPassData() const noexcept { return _passData != nullptr; } + //! Returns the node pass data - data used during processing & transformations. + template + ASMJIT_INLINE_NODEBUG T* passData() const noexcept { return (T*)_passData; } + //! Sets the node pass data to `data`. + template + ASMJIT_INLINE_NODEBUG void setPassData(T* data) noexcept { _passData = (void*)data; } + //! Resets the node pass data to nullptr. + ASMJIT_INLINE_NODEBUG void resetPassData() noexcept { _passData = nullptr; } + + //! Tests whether the node has an inline comment/annotation. + ASMJIT_INLINE_NODEBUG bool hasInlineComment() const noexcept { return _inlineComment != nullptr; } + //! Returns an inline comment/annotation string. + ASMJIT_INLINE_NODEBUG const char* inlineComment() const noexcept { return _inlineComment; } + //! Sets an inline comment/annotation string to `s`. + ASMJIT_INLINE_NODEBUG void setInlineComment(const char* s) noexcept { _inlineComment = s; } + //! Resets an inline comment/annotation string to nullptr. + ASMJIT_INLINE_NODEBUG void resetInlineComment() noexcept { _inlineComment = nullptr; } + + //! \} +}; + +//! Instruction node. +//! +//! Wraps an instruction with its options and operands. +class InstNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(InstNode) + + //! \name Constants + //! \{ + + //! The number of embedded operands for a default \ref InstNode instance that are always allocated as a part of + //! the instruction itself. Minimum embedded operands is 4, but in 32-bit more pointers are smaller and we can + //! embed 5. The rest (up to 6 operands) is considered extended. + //! + //! The number of operands InstNode holds is decided when \ref InstNode is created. + static constexpr uint32_t kBaseOpCapacity = uint32_t((128 - sizeof(BaseNode) - sizeof(BaseInst)) / sizeof(Operand_)); + + //! Count of maximum number of operands \ref InstNode can hold. + static constexpr uint32_t kFullOpCapacity = Globals::kMaxOpCount; + + //! \} + + //! \name Members + //! \{ + + //! Base instruction data. + BaseInst _baseInst; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `InstNode` instance. + ASMJIT_INLINE_NODEBUG InstNode(BaseBuilder* cb, InstId instId, InstOptions options, uint32_t opCount, uint32_t opCapacity = kBaseOpCapacity) noexcept + : BaseNode(cb, NodeType::kInst, NodeFlags::kIsCode | NodeFlags::kIsRemovable | NodeFlags::kActsAsInst), + _baseInst(instId, options) { + _inst._opCapacity = uint8_t(opCapacity); + _inst._opCount = uint8_t(opCount); + } + + //! \cond INTERNAL + //! Reset all built-in operands, including `extraReg`. + ASMJIT_INLINE_NODEBUG void _resetOps() noexcept { + _baseInst.resetExtraReg(); + resetOpRange(0, opCapacity()); + } + //! \endcond + + //! \} + + //! \name Instruction Object + //! \{ + + ASMJIT_INLINE_NODEBUG BaseInst& baseInst() noexcept { return _baseInst; } + ASMJIT_INLINE_NODEBUG const BaseInst& baseInst() const noexcept { return _baseInst; } + + //! \} + + //! \name Instruction Id + //! \{ + + //! Returns the instruction id, see `BaseInst::Id`. + ASMJIT_INLINE_NODEBUG InstId id() const noexcept { return _baseInst.id(); } + //! Returns the instruction real id, see `BaseInst::Id`. + ASMJIT_INLINE_NODEBUG InstId realId() const noexcept { return _baseInst.realId(); } + + //! Sets the instruction id to `id`, see `BaseInst::Id`. + ASMJIT_INLINE_NODEBUG void setId(InstId id) noexcept { _baseInst.setId(id); } + + //! \} + + //! \name Instruction Options + //! \{ + + //! Returns instruction options, see \ref InstOptions for more details. + ASMJIT_INLINE_NODEBUG InstOptions options() const noexcept { return _baseInst.options(); } + //! Tests whether instruction has the given \option` set/enabled. + ASMJIT_INLINE_NODEBUG bool hasOption(InstOptions option) const noexcept { return _baseInst.hasOption(option); } + //! Sets instruction `options` to the provided value, resetting all others. + ASMJIT_INLINE_NODEBUG void setOptions(InstOptions options) noexcept { _baseInst.setOptions(options); } + //! Adds instruction `options` to the instruction. + ASMJIT_INLINE_NODEBUG void addOptions(InstOptions options) noexcept { _baseInst.addOptions(options); } + //! Clears instruction `options` of the instruction (disables the given options). + ASMJIT_INLINE_NODEBUG void clearOptions(InstOptions options) noexcept { _baseInst.clearOptions(options); } + //! Resets instruction options to none - disabling all instruction options. + ASMJIT_INLINE_NODEBUG void resetOptions() noexcept { _baseInst.resetOptions(); } + + //! \} + + //! \name Extra Register + //! \{ + + //! Tests whether the node has an extra register operand. + ASMJIT_INLINE_NODEBUG bool hasExtraReg() const noexcept { return _baseInst.hasExtraReg(); } + //! Returns extra register operand. + ASMJIT_INLINE_NODEBUG RegOnly& extraReg() noexcept { return _baseInst.extraReg(); } + //! \overload + ASMJIT_INLINE_NODEBUG const RegOnly& extraReg() const noexcept { return _baseInst.extraReg(); } + //! Sets extra register operand to `reg`. + ASMJIT_INLINE_NODEBUG void setExtraReg(const BaseReg& reg) noexcept { _baseInst.setExtraReg(reg); } + //! Sets extra register operand to `reg`. + ASMJIT_INLINE_NODEBUG void setExtraReg(const RegOnly& reg) noexcept { _baseInst.setExtraReg(reg); } + //! Resets extra register operand. + ASMJIT_INLINE_NODEBUG void resetExtraReg() noexcept { _baseInst.resetExtraReg(); } + + //! \} + + //! \name Instruction Operands + //! \{ + + //! Returns operand count. + ASMJIT_INLINE_NODEBUG uint32_t opCount() const noexcept { return _inst._opCount; } + //! Returns operand capacity. + ASMJIT_INLINE_NODEBUG uint32_t opCapacity() const noexcept { return _inst._opCapacity; } + + //! Sets operand count. + ASMJIT_INLINE_NODEBUG void setOpCount(uint32_t opCount) noexcept { _inst._opCount = uint8_t(opCount); } + + //! Returns operands array. + ASMJIT_INLINE_NODEBUG Operand* operands() noexcept { + return reinterpret_cast(reinterpret_cast(this) + sizeof(InstNode)); + } + + //! Returns operands array (const). + ASMJIT_INLINE_NODEBUG const Operand* operands() const noexcept { + return reinterpret_cast(reinterpret_cast(this) + sizeof(InstNode)); + } + + //! Returns operand at the given `index`. + inline Operand& op(uint32_t index) noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + Operand* ops = operands(); + return ops[index].as(); + } + + //! Returns operand at the given `index` (const). + inline const Operand& op(uint32_t index) const noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + const Operand* ops = operands(); + return ops[index].as(); + } + + //! Sets operand at the given `index` to `op`. + inline void setOp(uint32_t index, const Operand_& op) noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + Operand* ops = operands(); + ops[index].copyFrom(op); + } + + //! Resets operand at the given `index` to none. + inline void resetOp(uint32_t index) noexcept { + ASMJIT_ASSERT(index < opCapacity()); + + Operand* ops = operands(); + ops[index].reset(); + } + + //! Resets operands at `[start, end)` range. + inline void resetOpRange(uint32_t start, uint32_t end) noexcept { + Operand* ops = operands(); + for (uint32_t i = start; i < end; i++) + ops[i].reset(); + } + + //! \} + + //! \name Utilities + //! \{ + + //! Tests whether the given operand type `opType` is used by the instruction. + inline bool hasOpType(OperandType opType) const noexcept { + const Operand* ops = operands(); + for (uint32_t i = 0, count = opCount(); i < count; i++) + if (ops[i].opType() == opType) + return true; + return false; + } + + //! Tests whether the instruction uses at least one register operand. + inline bool hasRegOp() const noexcept { return hasOpType(OperandType::kReg); } + //! Tests whether the instruction uses at least one memory operand. + inline bool hasMemOp() const noexcept { return hasOpType(OperandType::kMem); } + //! Tests whether the instruction uses at least one immediate operand. + inline bool hasImmOp() const noexcept { return hasOpType(OperandType::kImm); } + //! Tests whether the instruction uses at least one label operand. + inline bool hasLabelOp() const noexcept { return hasOpType(OperandType::kLabel); } + + //! Returns the index of the given operand type `opType`. + //! + //! \note If the operand type wa found, the value returned represents its index in \ref operands() + //! array, otherwise \ref Globals::kNotFound is returned to signalize that the operand was not found. + inline uint32_t indexOfOpType(OperandType opType) const noexcept { + uint32_t i = 0; + uint32_t count = opCount(); + const Operand* ops = operands(); + + while (i < count) { + if (ops[i].opType() == opType) + return i; + i++; + } + + return Globals::kNotFound; + } + + //! A shortcut that calls `indexOfOpType(OperandType::kMem)`. + inline uint32_t indexOfMemOp() const noexcept { return indexOfOpType(OperandType::kMem); } + //! A shortcut that calls `indexOfOpType(OperandType::kImm)`. + inline uint32_t indexOfImmOp() const noexcept { return indexOfOpType(OperandType::kImm); } + //! A shortcut that calls `indexOfOpType(OperandType::kLabel)`. + inline uint32_t indexOfLabelOp() const noexcept { return indexOfOpType(OperandType::kLabel); } + + //! \} + + //! \name Rewriting + //! \{ + + //! \cond INTERNAL + + //! Returns uint32_t[] view that represents BaseInst::RegOnly and instruction operands. + ASMJIT_INLINE_NODEBUG uint32_t* _getRewriteArray() noexcept { return &_baseInst._extraReg._id; } + //! \overload + ASMJIT_INLINE_NODEBUG const uint32_t* _getRewriteArray() const noexcept { return &_baseInst._extraReg._id; } + + //! Maximum value of rewrite id - 6 operands each having 4 slots is 24, one RegOnly having 2 slots => 26. + static constexpr uint32_t kMaxRewriteId = 26 - 1; + + //! Returns a rewrite index of the given pointer to `id`. + //! + //! This function returns a value that can be then passed to `\ref rewriteIdAtIndex() function. It can address + //! any id from any operand that is used by the instruction in addition to \ref BaseInst::regOnly field, which + //! can also be used by the register allocator. + inline uint32_t getRewriteIndex(const uint32_t* id) const noexcept { + const uint32_t* array = _getRewriteArray(); + ASMJIT_ASSERT(array <= id); + + size_t index = (size_t)(id - array); + ASMJIT_ASSERT(index <= kMaxRewriteId); + + return uint32_t(index); + } + + //! Rewrites the given `index` to the provided identifier `id`. + //! + //! \note This is an internal function that is used by a \ref BaseCompiler implementation to rewrite virtual + //! registers to physical registers. The rewriter in this case sees all operands as array of uint32 values + //! and the given `index` describes a position in this array. For example a single \ref Operand would be + //! decomposed to 4 uint32_t values, where the first at index 0 would be operand signature, next would be + //! base id, etc... This is a comfortable way of patching operands without having to check for their types. + inline void rewriteIdAtIndex(uint32_t index, uint32_t id) noexcept { + ASMJIT_ASSERT(index <= kMaxRewriteId); + + uint32_t* array = _getRewriteArray(); + array[index] = id; + } + //! \endcond + + //! \} + + //! \name Static Functions + //! \{ + + //! \cond INTERNAL + + //! Returns the capacity required for the given operands count `opCount`. + //! + //! There are only two capacities used - \ref kBaseOpCapacity and \ref kFullOpCapacity, so this function + //! is used to decide between these two. The general rule is that instructions that can be represented with + //! \ref kBaseOpCapacity would use this value, and all others would take \ref kFullOpCapacity. + static ASMJIT_INLINE_NODEBUG constexpr uint32_t capacityOfOpCount(uint32_t opCount) noexcept { + return opCount <= kBaseOpCapacity ? kBaseOpCapacity : kFullOpCapacity; + } + + //! Calculates the size of \ref InstNode required to hold at most `opCapacity` operands. + //! + //! This function is used internally to allocate \ref InstNode. + static ASMJIT_INLINE_NODEBUG constexpr size_t nodeSizeOfOpCapacity(uint32_t opCapacity) noexcept { + return sizeof(InstNode) + opCapacity * sizeof(Operand); + } + //! \endcond + + //! \} +}; + +//! Instruction node with embedded operands following \ref InstNode layout. +//! +//! \note This is used to make tools such as static analysis and compilers happy about the layout. There were two +//! instruction nodes in the past, having the second extend the operand array of the first, but that has caused +//! undefined behavior and made recent tools unhappy about that. +template +class InstNodeWithOperands : public InstNode { +public: + Operand_ _operands[kN]; + + //! Creates a new `InstNodeWithOperands` instance. + ASMJIT_INLINE_NODEBUG InstNodeWithOperands(BaseBuilder* cb, InstId instId, InstOptions options, uint32_t opCount) noexcept + : InstNode(cb, instId, options, opCount, kN) {} +}; + +//! Section node. +class SectionNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(SectionNode) + + //! \name Members + //! \{ + + //! Section id. + uint32_t _id; + + //! Next section node that follows this section. + //! + //! This link is only valid when the section is active (is part of the code) and when `Builder::hasDirtySectionLinks()` + //! returns `false`. If you intend to use this field you should always call `Builder::updateSectionLinks()` before you + //! do so. + SectionNode* _nextSection; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `SectionNode` instance. + ASMJIT_INLINE_NODEBUG SectionNode(BaseBuilder* cb, uint32_t sectionId = 0) noexcept + : BaseNode(cb, NodeType::kSection, NodeFlags::kHasNoEffect), + _id(sectionId), + _nextSection(nullptr) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the section id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + + //! \} +}; + +//! Label node. +class LabelNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(LabelNode) + + //! \name Members + //! \{ + + //! Label identifier. + uint32_t _labelId; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `LabelNode` instance. + ASMJIT_INLINE_NODEBUG LabelNode(BaseBuilder* cb, uint32_t labelId = 0) noexcept + : BaseNode(cb, NodeType::kLabel, NodeFlags::kHasNoEffect | NodeFlags::kActsAsLabel), + _labelId(labelId) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns \ref Label representation of the \ref LabelNode. + ASMJIT_INLINE_NODEBUG Label label() const noexcept { return Label(_labelId); } + //! Returns the id of the label. + ASMJIT_INLINE_NODEBUG uint32_t labelId() const noexcept { return _labelId; } + + //! \} +}; + +//! Align directive (BaseBuilder). +//! +//! Wraps `.align` directive. +class AlignNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(AlignNode) + + //! \name Members + //! \{ + + //! Alignment (in bytes). + uint32_t _alignment; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `AlignNode` instance. + ASMJIT_INLINE_NODEBUG AlignNode(BaseBuilder* cb, AlignMode alignMode, uint32_t alignment) noexcept + : BaseNode(cb, NodeType::kAlign, NodeFlags::kIsCode | NodeFlags::kHasNoEffect) { + + _alignData._alignMode = alignMode; + _alignment = alignment; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns align mode. + ASMJIT_INLINE_NODEBUG AlignMode alignMode() const noexcept { return _alignData._alignMode; } + //! Sets align mode to `alignMode`. + ASMJIT_INLINE_NODEBUG void setAlignMode(AlignMode alignMode) noexcept { _alignData._alignMode = alignMode; } + + //! Returns align offset in bytes. + ASMJIT_INLINE_NODEBUG uint32_t alignment() const noexcept { return _alignment; } + //! Sets align offset in bytes to `offset`. + ASMJIT_INLINE_NODEBUG void setAlignment(uint32_t alignment) noexcept { _alignment = alignment; } + + //! \} +}; + +//! Embed data node. +//! +//! Wraps `.data` directive. The node contains data that will be placed at the node's position in the assembler +//! stream. The data is considered to be RAW; no analysis nor byte-order conversion is performed on RAW data. +class EmbedDataNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(EmbedDataNode) + + //! \cond INTERNAL + enum : uint32_t { + kInlineBufferSize = 128 - (sizeof(BaseNode) + sizeof(size_t) * 2) + }; + //! \endcond + + //! \name Members + //! \{ + + size_t _itemCount; + size_t _repeatCount; + + union { + uint8_t* _externalData; + uint8_t _inlineData[kInlineBufferSize]; + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `EmbedDataNode` instance. + ASMJIT_INLINE_NODEBUG EmbedDataNode(BaseBuilder* cb) noexcept + : BaseNode(cb, NodeType::kEmbedData, NodeFlags::kIsData), + _itemCount(0), + _repeatCount(0) { + _embed._typeId = TypeId::kUInt8; + _embed._typeSize = uint8_t(1); + memset(_inlineData, 0, kInlineBufferSize); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns data type as \ref TypeId. + ASMJIT_INLINE_NODEBUG TypeId typeId() const noexcept { return _embed._typeId; } + //! Returns the size of a single data element. + ASMJIT_INLINE_NODEBUG uint32_t typeSize() const noexcept { return _embed._typeSize; } + + //! Returns a pointer to the data casted to `uint8_t`. + ASMJIT_INLINE_NODEBUG uint8_t* data() const noexcept { + return dataSize() <= kInlineBufferSize ? const_cast(_inlineData) : _externalData; + } + + //! Returns a pointer to the data casted to `T`. + template + ASMJIT_INLINE_NODEBUG T* dataAs() const noexcept { return reinterpret_cast(data()); } + + //! Returns the number of (typed) items in the array. + ASMJIT_INLINE_NODEBUG size_t itemCount() const noexcept { return _itemCount; } + + //! Returns how many times the data is repeated (default 1). + //! + //! Repeated data is useful when defining constants for SIMD, for example. + ASMJIT_INLINE_NODEBUG size_t repeatCount() const noexcept { return _repeatCount; } + + //! Returns the size of the data, not considering the number of times it repeats. + //! + //! \note The returned value is the same as `typeSize() * itemCount()`. + ASMJIT_INLINE_NODEBUG size_t dataSize() const noexcept { return typeSize() * _itemCount; } + + //! \} +}; + +//! Label data node. +class EmbedLabelNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(EmbedLabelNode) + + //! \name Members + //! \{ + + uint32_t _labelId; + uint32_t _dataSize; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `EmbedLabelNode` instance. + ASMJIT_INLINE_NODEBUG EmbedLabelNode(BaseBuilder* cb, uint32_t labelId = 0, uint32_t dataSize = 0) noexcept + : BaseNode(cb, NodeType::kEmbedLabel, NodeFlags::kIsData), + _labelId(labelId), + _dataSize(dataSize) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the label to embed as \ref Label operand. + ASMJIT_INLINE_NODEBUG Label label() const noexcept { return Label(_labelId); } + //! Returns the id of the label. + ASMJIT_INLINE_NODEBUG uint32_t labelId() const noexcept { return _labelId; } + + //! Sets the label id from `label` operand. + ASMJIT_INLINE_NODEBUG void setLabel(const Label& label) noexcept { setLabelId(label.id()); } + //! Sets the label id (use with caution, improper use can break a lot of things). + ASMJIT_INLINE_NODEBUG void setLabelId(uint32_t labelId) noexcept { _labelId = labelId; } + + //! Returns the data size. + ASMJIT_INLINE_NODEBUG uint32_t dataSize() const noexcept { return _dataSize; } + //! Sets the data size. + ASMJIT_INLINE_NODEBUG void setDataSize(uint32_t dataSize) noexcept { _dataSize = dataSize; } + + //! \} +}; + +//! Label data node. +class EmbedLabelDeltaNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(EmbedLabelDeltaNode) + + //! \name Members + //! \{ + + uint32_t _labelId; + uint32_t _baseLabelId; + uint32_t _dataSize; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `EmbedLabelDeltaNode` instance. + ASMJIT_INLINE_NODEBUG EmbedLabelDeltaNode(BaseBuilder* cb, uint32_t labelId = 0, uint32_t baseLabelId = 0, uint32_t dataSize = 0) noexcept + : BaseNode(cb, NodeType::kEmbedLabelDelta, NodeFlags::kIsData), + _labelId(labelId), + _baseLabelId(baseLabelId), + _dataSize(dataSize) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the label as `Label` operand. + ASMJIT_INLINE_NODEBUG Label label() const noexcept { return Label(_labelId); } + //! Returns the id of the label. + ASMJIT_INLINE_NODEBUG uint32_t labelId() const noexcept { return _labelId; } + + //! Sets the label id from `label` operand. + ASMJIT_INLINE_NODEBUG void setLabel(const Label& label) noexcept { setLabelId(label.id()); } + //! Sets the label id. + ASMJIT_INLINE_NODEBUG void setLabelId(uint32_t labelId) noexcept { _labelId = labelId; } + + //! Returns the base label as `Label` operand. + ASMJIT_INLINE_NODEBUG Label baseLabel() const noexcept { return Label(_baseLabelId); } + //! Returns the id of the base label. + ASMJIT_INLINE_NODEBUG uint32_t baseLabelId() const noexcept { return _baseLabelId; } + + //! Sets the base label id from `label` operand. + ASMJIT_INLINE_NODEBUG void setBaseLabel(const Label& baseLabel) noexcept { setBaseLabelId(baseLabel.id()); } + //! Sets the base label id. + ASMJIT_INLINE_NODEBUG void setBaseLabelId(uint32_t baseLabelId) noexcept { _baseLabelId = baseLabelId; } + + //! Returns the size of the embedded label address. + ASMJIT_INLINE_NODEBUG uint32_t dataSize() const noexcept { return _dataSize; } + //! Sets the size of the embedded label address. + ASMJIT_INLINE_NODEBUG void setDataSize(uint32_t dataSize) noexcept { _dataSize = dataSize; } + + //! \} +}; + +//! A node that wraps `ConstPool`. +class ConstPoolNode : public LabelNode { +public: + ASMJIT_NONCOPYABLE(ConstPoolNode) + + //! \name Members + //! \{ + + ConstPool _constPool; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `ConstPoolNode` instance. + ASMJIT_INLINE_NODEBUG ConstPoolNode(BaseBuilder* cb, uint32_t id = 0) noexcept + : LabelNode(cb, id), + _constPool(&cb->_codeZone) { + + setType(NodeType::kConstPool); + addFlags(NodeFlags::kIsData); + clearFlags(NodeFlags::kIsCode | NodeFlags::kHasNoEffect); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the constant-pool is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _constPool.empty(); } + //! Returns the size of the constant-pool in bytes. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _constPool.size(); } + //! Returns minimum alignment. + ASMJIT_INLINE_NODEBUG size_t alignment() const noexcept { return _constPool.alignment(); } + + //! Returns the wrapped `ConstPool` instance. + ASMJIT_INLINE_NODEBUG ConstPool& constPool() noexcept { return _constPool; } + //! Returns the wrapped `ConstPool` instance (const). + ASMJIT_INLINE_NODEBUG const ConstPool& constPool() const noexcept { return _constPool; } + + //! \} + + //! \name Utilities + //! \{ + + //! See `ConstPool::add()`. + ASMJIT_INLINE_NODEBUG Error add(const void* data, size_t size, size_t& dstOffset) noexcept { + return _constPool.add(data, size, dstOffset); + } + + //! \} +}; + +//! Comment node. +class CommentNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(CommentNode) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `CommentNode` instance. + ASMJIT_INLINE_NODEBUG CommentNode(BaseBuilder* cb, const char* comment) noexcept + : BaseNode(cb, NodeType::kComment, NodeFlags::kIsInformative | NodeFlags::kHasNoEffect | NodeFlags::kIsRemovable) { + _inlineComment = comment; + } + + //! \} +}; + +//! Sentinel node. +//! +//! Sentinel is a marker that is completely ignored by the code builder. It's used to remember a position in a code +//! as it never gets removed by any pass. +class SentinelNode : public BaseNode { +public: + ASMJIT_NONCOPYABLE(SentinelNode) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `SentinelNode` instance. + ASMJIT_INLINE_NODEBUG SentinelNode(BaseBuilder* cb, SentinelType sentinelType = SentinelType::kUnknown) noexcept + : BaseNode(cb, NodeType::kSentinel, NodeFlags::kIsInformative | NodeFlags::kHasNoEffect) { + + _sentinel._sentinelType = sentinelType; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the type of the sentinel. + ASMJIT_INLINE_NODEBUG SentinelType sentinelType() const noexcept { + return _sentinel._sentinelType; + } + + //! Sets the type of the sentinel. + ASMJIT_INLINE_NODEBUG void setSentinelType(SentinelType type) noexcept { + _sentinel._sentinelType = type; + } + + //! \} +}; + +//! Pass can be used to implement code transformations, analysis, and lowering. +class ASMJIT_VIRTAPI Pass { +public: + ASMJIT_BASE_CLASS(Pass) + ASMJIT_NONCOPYABLE(Pass) + + //! \name Members + //! \{ + + //! BaseBuilder this pass is assigned to. + BaseBuilder* _cb = nullptr; + //! Name of the pass. + const char* _name = nullptr; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API Pass(const char* name) noexcept; + ASMJIT_API virtual ~Pass() noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns \ref BaseBuilder associated with the pass. + ASMJIT_INLINE_NODEBUG const BaseBuilder* cb() const noexcept { return _cb; } + //! Returns the name of the pass. + ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name; } + + //! \} + + //! \name Pass Interface + //! \{ + + //! Processes the code stored in Builder or Compiler. + //! + //! This is the only function that is called by the `BaseBuilder` to process the code. It passes `zone`, + //! which will be reset after the `run()` finishes. + ASMJIT_API virtual Error run(Zone* zone, Logger* logger); + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // !ASMJIT_NO_BUILDER +#endif // ASMJIT_CORE_BUILDER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/codebuffer.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/codebuffer.h new file mode 100644 index 0000000000000000000000000000000000000000..58d55c62937fd6de1c1433cd8da7217d80a2025a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/codebuffer.h @@ -0,0 +1,113 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CODEBUFFER_H_INCLUDED +#define ASMJIT_CORE_CODEBUFFER_H_INCLUDED + +#include "../core/globals.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Flags used by \ref CodeBuffer. +enum class CodeBufferFlags : uint32_t { + //! No flags. + kNone = 0, + //! Buffer is external (not allocated by asmjit). + kIsExternal = 0x00000001u, + //! Buffer is fixed (cannot be reallocated). + kIsFixed = 0x00000002u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CodeBufferFlags) + +//! Code or data buffer. +struct CodeBuffer { + //! \name Members + //! \{ + + //! The content of the buffer (data). + uint8_t* _data; + //! Number of bytes of `data` used. + size_t _size; + //! Buffer capacity (in bytes). + size_t _capacity; + //! Buffer flags. + CodeBufferFlags _flags; + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Returns a reference to the byte at the given `index`. + inline uint8_t& operator[](size_t index) noexcept { + ASMJIT_ASSERT(index < _size); + return _data[index]; + } + //! \overload + inline const uint8_t& operator[](size_t index) const noexcept { + ASMJIT_ASSERT(index < _size); + return _data[index]; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns code buffer flags. + ASMJIT_INLINE_NODEBUG CodeBufferFlags flags() const noexcept { return _flags; } + //! Tests whether the code buffer has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasFlag(CodeBufferFlags flag) const noexcept { return Support::test(_flags, flag); } + + //! Tests whether this code buffer has a fixed size. + //! + //! Fixed size means that the code buffer is fixed and cannot grow. + ASMJIT_INLINE_NODEBUG bool isFixed() const noexcept { return hasFlag(CodeBufferFlags::kIsFixed); } + + //! Tests whether the data in this code buffer is external. + //! + //! External data can only be provided by users, it's never used by AsmJit. + ASMJIT_INLINE_NODEBUG bool isExternal() const noexcept { return hasFlag(CodeBufferFlags::kIsExternal); } + + //! Tests whether the data in this code buffer is allocated (non-null). + ASMJIT_INLINE_NODEBUG bool isAllocated() const noexcept { return _data != nullptr; } + + //! Tests whether the code buffer is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return !_size; } + + //! Returns the size of the data. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + //! Returns the capacity of the data. + ASMJIT_INLINE_NODEBUG size_t capacity() const noexcept { return _capacity; } + + //! Returns the pointer to the data the buffer references. + ASMJIT_INLINE_NODEBUG uint8_t* data() noexcept { return _data; } + //! \overload + ASMJIT_INLINE_NODEBUG const uint8_t* data() const noexcept { return _data; } + + //! \} + + //! \name Iterators + //! \{ + + ASMJIT_INLINE_NODEBUG uint8_t* begin() noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const uint8_t* begin() const noexcept { return _data; } + + ASMJIT_INLINE_NODEBUG uint8_t* end() noexcept { return _data + _size; } + ASMJIT_INLINE_NODEBUG const uint8_t* end() const noexcept { return _data + _size; } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CODEBUFFER_H_INCLUDED + diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/codeholder.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/codeholder.h new file mode 100644 index 0000000000000000000000000000000000000000..7c52b677ffa8cadef80c362d74ff73876cbac590 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/codeholder.h @@ -0,0 +1,1123 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CODEHOLDER_H_INCLUDED +#define ASMJIT_CORE_CODEHOLDER_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/codebuffer.h" +#include "../core/errorhandler.h" +#include "../core/operand.h" +#include "../core/string.h" +#include "../core/support.h" +#include "../core/target.h" +#include "../core/zone.h" +#include "../core/zonehash.h" +#include "../core/zonestring.h" +#include "../core/zonetree.h" +#include "../core/zonevector.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +class BaseEmitter; +class CodeHolder; +class LabelEntry; +class Logger; + +//! Operator type that can be used within an \ref Expression. +enum class ExpressionOpType : uint8_t { + //! Addition. + kAdd = 0, + //! Subtraction. + kSub = 1, + //! Multiplication + kMul = 2, + //! Logical left shift. + kSll = 3, + //! Logical right shift. + kSrl = 4, + //! Arithmetic right shift. + kSra = 5 +}; + +//! Value type that can be used within an \ref Expression. +enum class ExpressionValueType : uint8_t { + //! No value or invalid. + kNone = 0, + //! Value is 64-bit unsigned integer (constant). + kConstant = 1, + //! Value is \ref LabelEntry, which references a \ref Label. + kLabel = 2, + //! Value is \ref Expression + kExpression = 3 +}; + +//! Expression node that can reference constants, labels, and another expressions. +struct Expression { + //! Expression value. + union Value { + //! Constant. + uint64_t constant; + //! Pointer to another expression. + Expression* expression; + //! Pointer to \ref LabelEntry. + LabelEntry* label; + }; + + //! \name Members + //! \{ + + //! Operation type. + ExpressionOpType opType; + //! Value types of \ref value. + ExpressionValueType valueType[2]; + //! Reserved for future use, should be initialized to zero. + uint8_t reserved[5]; + //! Expression left and right values. + Value value[2]; + + //! \} + + //! \name Accessors + //! \{ + + //! Resets the whole expression. + //! + //! Changes both values to \ref ExpressionValueType::kNone. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = Expression{}; } + + //! Sets the value type at `index` to \ref ExpressionValueType::kConstant and its content to `constant`. + ASMJIT_INLINE_NODEBUG void setValueAsConstant(size_t index, uint64_t constant) noexcept { + valueType[index] = ExpressionValueType::kConstant; + value[index].constant = constant; + } + + //! Sets the value type at `index` to \ref ExpressionValueType::kLabel and its content to `labelEntry`. + ASMJIT_INLINE_NODEBUG void setValueAsLabel(size_t index, LabelEntry* labelEntry) noexcept { + valueType[index] = ExpressionValueType::kLabel; + value[index].label = labelEntry; + } + + //! Sets the value type at `index` to \ref ExpressionValueType::kExpression and its content to `expression`. + ASMJIT_INLINE_NODEBUG void setValueAsExpression(size_t index, Expression* expression) noexcept { + valueType[index] = ExpressionValueType::kExpression; + value[index].expression = expression; + } + + //! \} +}; + +//! Section flags, used by \ref Section. +enum class SectionFlags : uint32_t { + //! No flags. + kNone = 0, + //! Executable (.text sections). + kExecutable = 0x00000001u, + //! Read-only (.text and .data sections). + kReadOnly = 0x00000002u, + //! Zero initialized by the loader (BSS). + kZeroInitialized = 0x00000004u, + //! Info / comment flag. + kComment = 0x00000008u, + //! Section created implicitly, can be deleted by \ref Target. + kImplicit = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(SectionFlags) + +//! Flags that can be used with \ref CodeHolder::copySectionData() and \ref CodeHolder::copyFlattenedData(). +enum class CopySectionFlags : uint32_t { + //! No flags. + kNone = 0, + + //! If virtual size of a section is greater than the size of its \ref CodeBuffer then all bytes between the buffer + //! size and virtual size will be zeroed. If this option is not set then those bytes would be left as is, which + //! means that if the user didn't initialize them they would have a previous content, which may be unwanted. + kPadSectionBuffer = 0x00000001u, + + //! Clears the target buffer if the flattened data is less than the destination size. This option works + //! only with \ref CodeHolder::copyFlattenedData() as it processes multiple sections. It is ignored by + //! \ref CodeHolder::copySectionData(). + kPadTargetBuffer = 0x00000002u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CopySectionFlags) + +//! Section entry. +class Section { +public: + //! \name Members + //! \{ + + //! Section id. + uint32_t _id; + //! Section flags. + SectionFlags _flags; + //! Section alignment requirements (0 if no requirements). + uint32_t _alignment; + //! Order (lower value means higher priority). + int32_t _order; + //! Offset of this section from base-address. + uint64_t _offset; + //! Virtual size of the section (zero initialized sections). + uint64_t _virtualSize; + //! Section name (max 35 characters, PE allows max 8). + FixedString _name; + //! Code or data buffer. + CodeBuffer _buffer; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the section id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + //! Returns the section name, as a null terminated string. + ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name.str; } + + //! Returns the section data. + ASMJIT_INLINE_NODEBUG uint8_t* data() noexcept { return _buffer.data(); } + //! \overload + ASMJIT_INLINE_NODEBUG const uint8_t* data() const noexcept { return _buffer.data(); } + + //! Returns the section flags. + ASMJIT_INLINE_NODEBUG SectionFlags flags() const noexcept { return _flags; } + //! Tests whether the section has the given `flag`. + ASMJIT_INLINE_NODEBUG bool hasFlag(SectionFlags flag) const noexcept { return Support::test(_flags, flag); } + //! Adds `flags` to the section flags. + ASMJIT_INLINE_NODEBUG void addFlags(SectionFlags flags) noexcept { _flags |= flags; } + //! Removes `flags` from the section flags. + ASMJIT_INLINE_NODEBUG void clearFlags(SectionFlags flags) noexcept { _flags &= ~flags; } + + //! Returns the minimum section alignment + ASMJIT_INLINE_NODEBUG uint32_t alignment() const noexcept { return _alignment; } + //! Sets the minimum section alignment + ASMJIT_INLINE_NODEBUG void setAlignment(uint32_t alignment) noexcept { _alignment = alignment; } + + //! Returns the section order, which has a higher priority than section id. + ASMJIT_INLINE_NODEBUG int32_t order() const noexcept { return _order; } + + //! Returns the section offset, relative to base. + ASMJIT_INLINE_NODEBUG uint64_t offset() const noexcept { return _offset; } + //! Set the section offset. + ASMJIT_INLINE_NODEBUG void setOffset(uint64_t offset) noexcept { _offset = offset; } + + //! Returns the virtual size of the section. + //! + //! Virtual size is initially zero and is never changed by AsmJit. It's normal if virtual size is smaller than + //! size returned by `bufferSize()` as the buffer stores real data emitted by assemblers or appended by users. + //! + //! Use `realSize()` to get the real and final size of this section. + ASMJIT_INLINE_NODEBUG uint64_t virtualSize() const noexcept { return _virtualSize; } + //! Sets the virtual size of the section. + ASMJIT_INLINE_NODEBUG void setVirtualSize(uint64_t virtualSize) noexcept { _virtualSize = virtualSize; } + + //! Returns the buffer size of the section. + ASMJIT_INLINE_NODEBUG size_t bufferSize() const noexcept { return _buffer.size(); } + //! Returns the real size of the section calculated from virtual and buffer sizes. + ASMJIT_INLINE_NODEBUG uint64_t realSize() const noexcept { return Support::max(virtualSize(), bufferSize()); } + + //! Returns the `CodeBuffer` used by this section. + ASMJIT_INLINE_NODEBUG CodeBuffer& buffer() noexcept { return _buffer; } + //! Returns the `CodeBuffer` used by this section (const). + ASMJIT_INLINE_NODEBUG const CodeBuffer& buffer() const noexcept { return _buffer; } + + //! \} +}; + +//! Entry in an address table. +class AddressTableEntry : public ZoneTreeNodeT { +public: + ASMJIT_NONCOPYABLE(AddressTableEntry) + + //! \name Members + //! \{ + + //! Address. + uint64_t _address; + //! Slot. + uint32_t _slot; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG explicit AddressTableEntry(uint64_t address) noexcept + : _address(address), + _slot(0xFFFFFFFFu) {} + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG uint64_t address() const noexcept { return _address; } + ASMJIT_INLINE_NODEBUG uint32_t slot() const noexcept { return _slot; } + + ASMJIT_INLINE_NODEBUG bool hasAssignedSlot() const noexcept { return _slot != 0xFFFFFFFFu; } + + ASMJIT_INLINE_NODEBUG bool operator<(const AddressTableEntry& other) const noexcept { return _address < other._address; } + ASMJIT_INLINE_NODEBUG bool operator>(const AddressTableEntry& other) const noexcept { return _address > other._address; } + + ASMJIT_INLINE_NODEBUG bool operator<(uint64_t queryAddress) const noexcept { return _address < queryAddress; } + ASMJIT_INLINE_NODEBUG bool operator>(uint64_t queryAddress) const noexcept { return _address > queryAddress; } + + //! \} +}; + +//! Offset format type, used by \ref OffsetFormat. +enum class OffsetType : uint8_t { + // Common Offset Formats + // --------------------- + + //! A value having `_immBitCount` bits and shifted by `_immBitShift`. + //! + //! This offset type is sufficient for many targets that store offset as a continuous set bits within an + //! instruction word / sequence of bytes. + kSignedOffset, + + //! An unsigned value having `_immBitCount` bits and shifted by `_immBitShift`. + kUnsignedOffset, + + // AArch64 Specific Offset Formats + // ------------------------------- + + //! AArch64 ADR format of `[.|immlo:2|.....|immhi:19|.....]`. + kAArch64_ADR, + + //! AArch64 ADRP format of `[.|immlo:2|.....|immhi:19|.....]` (4kB pages). + kAArch64_ADRP, + + // AArch32 Specific Offset Formats (T16 & T32) + // ------------------------------------------- + + //! AArch32 THUMBv2 immediate encoding of 'ADR' instruction (12-bit payload and sign bit): + //! + //! `|.....|imm:1|..N.N|......|imm:3|....|imm:8|` + //! + //! Where `N` is one if the offset is negative. The immediate is encoded as absolute value of the offset if negative. + kThumb32_ADR, + + //! AArch32 THUMBv2 immediate encoding of 'BLX' instruction (23-bit immediate payload, multiplied by 4): + //! + //! `|.....|imm[22]|imm[19:10]|..|ja|1|jb|imm[9:0]|0` + //! + //! Where: + //! + //! - `ja` is calculated as imm[22] ^ imm[21] ^ 1. + //! - `jb` is calculated as imm[22] ^ imm[20] ^ 1. + kThumb32_BLX, + + //! AArch32 THUMBv2 immediate encoding of 'B' instruction without `` (24-bit immediate payload, multiplied by 2): + //! + //! `|.....|imm[23]|imm[20:11]|..|ja|1|jb|imm[10:0]` + //! + //! Where: + //! + //! - `ja` is calculated as imm[23] ^ imm[22] ^ 1. + //! - `jb` is calculated as imm[23] ^ imm[21] ^ 1. + kThumb32_B, + + //! AArch32 THUMBv2 immediate encoding of 'B' instruction with `` (20-bit immediate payload, multiplied by 2). + //! + //! `|.....|imm[19]|....|imm[16:11]|..|ja|1|jb|imm[10:0]` + //! + //! Where: + //! + //! - `ja` is calculated as imm[19] ^ imm[18] ^ 1. + //! - `jb` is calculated as imm[19] ^ imm[17] ^ 1. + kThumb32_BCond, + + // AArch32 Specific Offset Formats (A32) + // ------------------------------------- + + //! AArch32 ADR instruction, which uses a standard 12-bit immediate encoding that is used by other ARM instructions. + kAArch32_ADR, + + //! AArch32 signed offset that is similar to `kSignedOffset`, however it uses absolute value of the offset and its + //! sign is encoded in 23rd bit of the opcode. + //! + //! `|........|U.......|........|........|` + //! + kAArch32_U23_SignedOffset, + + //! AArch32 offset format that encodes 8-bit offset as: + //! + //! `|........|U.......|....|imm[7:4]|....|imm[3:0]|` + //! + //! in a 32-bit word, where U is a sign of the displacement and the displacement itself is encoded as its absolute + //! value. + kAArch32_U23_0To3At0_4To7At8, + + //! AArch32 offset format that encodes a signed 25-bit offset as: + //! + //! `|.......|imm[0]|imm[24:1]|` + //! + //! in a 32-bit word. + kAArch32_1To24At0_0At24, + + //! Maximum value of `OffsetFormatType`. + kMaxValue = kAArch32_1To24At0_0At24 +}; + +//! Provides information about formatting offsets, absolute addresses, or their parts. Offset format is used by both +//! \ref RelocEntry and \ref LabelLink. The illustration below describes the relation of region size and offset size. +//! Region size is the size of the whole unit whereas offset size is the size of the unit that will be patched. +//! +//! ``` +//! +-> Code buffer | The subject of the relocation (region) | +//! | | (Word-Offset) (Word-Size) | +//! |xxxxxxxxxxxxxxx|................|*PATCHED*|................|xxxxxxxxxxxx-> +//! | | +//! [Word Offset points here]----+ +--- [WordOffset + WordSize] +//! ``` +//! +//! Once the offset word has been located it can be patched like this: +//! +//! ``` +//! |ImmDiscardLSB (discard LSB bits). +//! |.. +//! [0000000000000iiiiiiiiiiiiiiiiiDD] - Offset value (32-bit) +//! [000000000000000iiiiiiiiiiiiiiiii] - Offset value after discard LSB. +//! [00000000000iiiiiiiiiiiiiiiii0000] - Offset value shifted by ImmBitShift. +//! [xxxxxxxxxxxiiiiiiiiiiiiiiiiixxxx] - Patched word (32-bit) +//! |...............| +//! (ImmBitCount) +- ImmBitShift +//! ``` +struct OffsetFormat { + //! \name Members + //! \{ + + //! Type of the offset. + OffsetType _type; + //! Encoding flags. + uint8_t _flags; + //! Size of the region (in bytes) containing the offset value, if the offset value is part of an instruction, + //! otherwise it would be the same as `_valueSize`. + uint8_t _regionSize; + //! Size of the offset value, in bytes (1, 2, 4, or 8). + uint8_t _valueSize; + //! Offset of the offset value, in bytes, relative to the start of the region or data. Value offset would be + //! zero if both region size and value size are equal. + uint8_t _valueOffset; + //! Size of the offset immediate value in bits. + uint8_t _immBitCount; + //! Shift of the offset immediate value in bits in the target word. + uint8_t _immBitShift; + //! Number of least significant bits to discard before writing the immediate to the destination. All discarded + //! bits must be zero otherwise the value is invalid. + uint8_t _immDiscardLsb; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the type of the offset. + ASMJIT_INLINE_NODEBUG OffsetType type() const noexcept { return _type; } + + //! Returns whether the offset is encoded as an absolute value of the offset with additional field(s) that represent + //! the sign (AArch32 U/N fields in the opcode). + //! + //! If true, the offset itself is always positive and a separate U/N field is used to indicate the sign of the offset + //! (usually `U==1` means ADD, but sometimes `N==1` means negative offset, which implies SUB). + ASMJIT_INLINE_NODEBUG bool hasSignBit() const noexcept { + return _type == OffsetType::kThumb32_ADR || + _type == OffsetType::kAArch32_ADR || + _type == OffsetType::kAArch32_U23_SignedOffset || + _type == OffsetType::kAArch32_U23_0To3At0_4To7At8; + } + + //! Returns flags. + ASMJIT_INLINE_NODEBUG uint32_t flags() const noexcept { return _flags; } + //! Returns the size of the region/instruction where the offset is encoded. + ASMJIT_INLINE_NODEBUG uint32_t regionSize() const noexcept { return _regionSize; } + //! Returns the offset of the word relative to the start of the region where the offset is. + ASMJIT_INLINE_NODEBUG uint32_t valueOffset() const noexcept { return _valueOffset; } + //! Returns the size of the data-type (word) that contains the offset, in bytes. + ASMJIT_INLINE_NODEBUG uint32_t valueSize() const noexcept { return _valueSize; } + //! Returns the count of bits of the offset value in the data it's stored in. + ASMJIT_INLINE_NODEBUG uint32_t immBitCount() const noexcept { return _immBitCount; } + //! Returns the bit-shift of the offset value in the data it's stored in. + ASMJIT_INLINE_NODEBUG uint32_t immBitShift() const noexcept { return _immBitShift; } + //! Returns the number of least significant bits of the offset value, that must be zero and that are not part of + //! the encoded data. + ASMJIT_INLINE_NODEBUG uint32_t immDiscardLsb() const noexcept { return _immDiscardLsb; } + + //! Resets this offset format to a simple data value of `dataSize` bytes. + //! + //! The region will be the same size as data and immediate bits would correspond to `dataSize * 8`. There will be + //! no immediate bit shift or discarded bits. + inline void resetToSimpleValue(OffsetType type, size_t valueSize) noexcept { + ASMJIT_ASSERT(valueSize <= 8u); + + _type = type; + _flags = uint8_t(0); + _regionSize = uint8_t(valueSize); + _valueSize = uint8_t(valueSize); + _valueOffset = uint8_t(0); + _immBitCount = uint8_t(valueSize * 8u); + _immBitShift = uint8_t(0); + _immDiscardLsb = uint8_t(0); + } + + inline void resetToImmValue(OffsetType type, size_t valueSize, uint32_t immBitShift, uint32_t immBitCount, uint32_t immDiscardLsb) noexcept { + ASMJIT_ASSERT(valueSize <= 8u); + ASMJIT_ASSERT(immBitShift < valueSize * 8u); + ASMJIT_ASSERT(immBitCount <= 64u); + ASMJIT_ASSERT(immDiscardLsb <= 64u); + + _type = type; + _flags = uint8_t(0); + _regionSize = uint8_t(valueSize); + _valueSize = uint8_t(valueSize); + _valueOffset = uint8_t(0); + _immBitCount = uint8_t(immBitCount); + _immBitShift = uint8_t(immBitShift); + _immDiscardLsb = uint8_t(immDiscardLsb); + } + + inline void setRegion(size_t regionSize, size_t valueOffset) noexcept { + _regionSize = uint8_t(regionSize); + _valueOffset = uint8_t(valueOffset); + } + + inline void setLeadingAndTrailingSize(size_t leadingSize, size_t trailingSize) noexcept { + _regionSize = uint8_t(leadingSize + trailingSize + _valueSize); + _valueOffset = uint8_t(leadingSize); + } + + //! \} +}; + +//! Relocation type. +enum class RelocType : uint32_t { + //! None/deleted (no relocation). + kNone = 0, + //! Expression evaluation, `_payload` is pointer to `Expression`. + kExpression = 1, + //! Relocate absolute to absolute. + kAbsToAbs = 2, + //! Relocate relative to absolute. + kRelToAbs = 3, + //! Relocate absolute to relative. + kAbsToRel = 4, + //! Relocate absolute to relative or use trampoline. + kX64AddressEntry = 5 +}; + +//! Relocation entry. +struct RelocEntry { + //! \name Members + //! \{ + + //! Relocation id. + uint32_t _id; + //! Type of the relocation. + RelocType _relocType; + //! Format of the relocated value. + OffsetFormat _format; + //! Source section id. + uint32_t _sourceSectionId; + //! Target section id. + uint32_t _targetSectionId; + //! Source offset (relative to start of the section). + uint64_t _sourceOffset; + //! Payload (target offset, target address, expression, etc). + uint64_t _payload; + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + + ASMJIT_INLINE_NODEBUG RelocType relocType() const noexcept { return _relocType; } + ASMJIT_INLINE_NODEBUG const OffsetFormat& format() const noexcept { return _format; } + + ASMJIT_INLINE_NODEBUG uint32_t sourceSectionId() const noexcept { return _sourceSectionId; } + ASMJIT_INLINE_NODEBUG uint32_t targetSectionId() const noexcept { return _targetSectionId; } + + ASMJIT_INLINE_NODEBUG uint64_t sourceOffset() const noexcept { return _sourceOffset; } + ASMJIT_INLINE_NODEBUG uint64_t payload() const noexcept { return _payload; } + + ASMJIT_INLINE_NODEBUG Expression* payloadAsExpression() const noexcept { + return reinterpret_cast(uintptr_t(_payload)); + } + + //! \} +}; + +//! Type of the \ref Label. +enum class LabelType : uint8_t { + //! Anonymous label that can optionally have a name, which is only used for debugging purposes. + kAnonymous = 0, + //! Local label (always has parentId). + kLocal = 1, + //! Global label (never has parentId). + kGlobal = 2, + //! External label (references an external symbol). + kExternal = 3, + + //! Maximum value of `LabelType`. + kMaxValue = kExternal +}; + +//! Data structure used to link either unbound labels or cross-section links. +struct LabelLink { + //! Next link (single-linked list). + LabelLink* next; + //! Section id where the label is bound. + uint32_t sectionId; + //! Relocation id or Globals::kInvalidId. + uint32_t relocId; + //! Label offset relative to the start of the section. + size_t offset; + //! Inlined rel8/rel32. + intptr_t rel; + //! Offset format information. + OffsetFormat format; +}; + +//! Label entry. +//! +//! Contains the following properties: +//! - Label id - This is the only thing that is set to the `Label` operand. +//! - Label name - Optional, used mostly to create executables and libraries. +//! - Label type - Type of the label, default `LabelType::kAnonymous`. +//! - Label parent id - Derived from many assemblers that allow to define a local label that falls under a global +//! label. This allows to define many labels of the same name that have different parent (global) label. +//! - Offset - offset of the label bound by `Assembler`. +//! - Links - single-linked list that contains locations of code that has to be patched when the label gets bound. +//! Every use of unbound label adds one link to `_links` list. +//! - HVal - Hash value of label's name and optionally parentId. +//! - HashNext - Hash-table implementation detail. +class LabelEntry : public ZoneHashNode { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + //! SSO size of \ref _name. + //! + //! \cond INTERNAL + //! Let's round the size of `LabelEntry` to 64 bytes (as `ZoneAllocator` has granularity of 32 bytes anyway). This + //! gives `_name` the remaining space, which is should be 16 bytes on 64-bit and 28 bytes on 32-bit architectures. + //! \endcond + kStaticNameSize = 64 - (sizeof(ZoneHashNode) + 8 + sizeof(Section*) + sizeof(size_t) + sizeof(LabelLink*)) + }; + + //! \} + + //! \name Members + //! \{ + + //! Type of the label. + LabelType _type; + //! Must be zero. + uint8_t _reserved[3]; + //! Label parent id or zero. + uint32_t _parentId; + //! Label offset relative to the start of the `_section`. + uint64_t _offset; + //! Section where the label was bound. + Section* _section; + //! Label links. + LabelLink* _links; + //! Label name. + ZoneString _name; + + //! \} + + //! \name Accessors + //! \{ + + // NOTE: Label id is stored in `_customData`, which is provided by ZoneHashNode to fill a padding that a C++ + // compiler targeting 64-bit CPU will add to align the structure to 64-bits. + + //! Returns label id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _customData; } + //! Sets label id (internal, used only by `CodeHolder`). + ASMJIT_INLINE_NODEBUG void _setId(uint32_t id) noexcept { _customData = id; } + + //! Returns label type. + ASMJIT_INLINE_NODEBUG LabelType type() const noexcept { return _type; } + + //! Tests whether the label has a parent label. + ASMJIT_INLINE_NODEBUG bool hasParent() const noexcept { return _parentId != Globals::kInvalidId; } + //! Returns label's parent id. + ASMJIT_INLINE_NODEBUG uint32_t parentId() const noexcept { return _parentId; } + + //! Returns the section where the label was bound. + //! + //! If the label was not yet bound the return value is `nullptr`. + ASMJIT_INLINE_NODEBUG Section* section() const noexcept { return _section; } + + //! Tests whether the label has name. + ASMJIT_INLINE_NODEBUG bool hasName() const noexcept { return !_name.empty(); } + + //! Returns the label's name. + //! + //! \note Local labels will return their local name without their parent part, for example ".L1". + ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name.data(); } + + //! Returns size of label's name. + //! + //! \note Label name is always null terminated, so you can use `strlen()` to get it, however, it's also cached in + //! `LabelEntry` itself, so if you want to know the size the fastest way is to call `LabelEntry::nameSize()`. + ASMJIT_INLINE_NODEBUG uint32_t nameSize() const noexcept { return _name.size(); } + + //! Returns links associated with this label. + ASMJIT_INLINE_NODEBUG LabelLink* links() const noexcept { return _links; } + + //! Tests whether the label is bound. + ASMJIT_INLINE_NODEBUG bool isBound() const noexcept { return _section != nullptr; } + //! Tests whether the label is bound to a the given `sectionId`. + ASMJIT_INLINE_NODEBUG bool isBoundTo(Section* section) const noexcept { return _section == section; } + + //! Returns the label offset (only useful if the label is bound). + ASMJIT_INLINE_NODEBUG uint64_t offset() const noexcept { return _offset; } + + //! Returns the hash-value of label's name and its parent label (if any). + //! + //! Label hash is calculated as `HASH(Name) ^ ParentId`. The hash function is implemented in `Support::hashString()` + //! and `Support::hashRound()`. + ASMJIT_INLINE_NODEBUG uint32_t hashCode() const noexcept { return _hashCode; } + + //! \} +}; + +//! Holds assembled code and data (including sections, labels, and relocation information). +//! +//! CodeHolder connects emitters with their targets. It provides them interface that can be used to query information +//! about the target environment (architecture, etc...) and API to create labels, sections, relocations, and to write +//! data to a \ref CodeBuffer, which is always part of \ref Section. More than one emitter can be attached to a single +//! CodeHolder instance at a time, which is used in practice +//! +//! CodeHolder provides interface for all emitter types. Assemblers use CodeHolder to write into \ref CodeBuffer, and +//! higher level emitters like Builder and Compiler use CodeHolder to manage labels and sections so higher level code +//! can be serialized to Assembler by \ref BaseEmitter::finalize() and \ref BaseBuilder::serializeTo(). +//! +//! In order to use CodeHolder, it must be first initialized by \ref init(). After the CodeHolder has been successfully +//! initialized it can be used to hold assembled code, sections, labels, relocations, and to attach / detach code +//! emitters. After the end of code generation it can be used to query physical locations of labels and to relocate +//! the assembled code into the right address. +//! +//! \note \ref CodeHolder has an ability to attach an \ref ErrorHandler, however, the error handler is not triggered +//! by \ref CodeHolder itself, it's instead propagated to all emitters that attach to it. +class CodeHolder { +public: + ASMJIT_NONCOPYABLE(CodeHolder) + + //! \name Members + //! \{ + + //! Environment information. + Environment _environment; + //! CPU features of the target architecture. + CpuFeatures _cpuFeatures; + //! Base address or \ref Globals::kNoBaseAddress. + uint64_t _baseAddress; + + //! Attached `Logger`, used by all consumers. + Logger* _logger; + //! Attached `ErrorHandler`. + ErrorHandler* _errorHandler; + + //! Code zone (used to allocate core structures). + Zone _zone; + //! Zone allocator, used to manage internal containers. + ZoneAllocator _allocator; + + //! Attached emitters. + ZoneVector _emitters; + //! Section entries. + ZoneVector _sections; + //! Section entries sorted by section order and then section id. + ZoneVector _sectionsByOrder; + //! Label entries. + ZoneVector _labelEntries; + //! Relocation entries. + ZoneVector _relocations; + //! Label name -> LabelEntry (only named labels). + ZoneHash _namedLabels; + + //! Count of label links, which are not resolved. + size_t _unresolvedLinkCount; + //! Pointer to an address table section (or null if this section doesn't exist). + Section* _addressTableSection; + //! Address table entries. + ZoneTree _addressTableEntries; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates an uninitialized CodeHolder (you must init() it before it can be used). + //! + //! An optional `temporary` argument can be used to initialize the first block of \ref Zone that the CodeHolder + //! uses into a temporary memory provided by the user. + ASMJIT_API explicit CodeHolder(const Support::Temporary* temporary = nullptr) noexcept; + + //! \overload + ASMJIT_INLINE_NODEBUG explicit CodeHolder(const Support::Temporary& temporary) noexcept + : CodeHolder(&temporary) {} + + //! Destroys the CodeHolder and frees all resources it has allocated. + ASMJIT_API ~CodeHolder() noexcept; + + //! Tests whether the `CodeHolder` has been initialized. + //! + //! Emitters can be only attached to initialized `CodeHolder` instances. + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _environment.isInitialized(); } + + //! Initializes CodeHolder to hold code described by the given `environment` and `baseAddress`. + ASMJIT_API Error init(const Environment& environment, uint64_t baseAddress = Globals::kNoBaseAddress) noexcept; + //! Initializes CodeHolder to hold code described by the given `environment`, `cpuFeatures`, and `baseAddress`. + ASMJIT_API Error init(const Environment& environment, const CpuFeatures& cpuFeatures, uint64_t baseAddress = Globals::kNoBaseAddress) noexcept; + //! Detaches all code-generators attached and resets the `CodeHolder`. + ASMJIT_API void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept; + + //! \} + + //! \name Attach & Detach + //! \{ + + //! Attaches an emitter to this `CodeHolder`. + ASMJIT_API Error attach(BaseEmitter* emitter) noexcept; + //! Detaches an emitter from this `CodeHolder`. + ASMJIT_API Error detach(BaseEmitter* emitter) noexcept; + + //! \} + + //! \name Allocators + //! \{ + + //! Returns the allocator that the `CodeHolder` uses. + //! + //! \note This should be only used for AsmJit's purposes. Code holder uses arena allocator to allocate everything, + //! so anything allocated through this allocator will be invalidated by \ref CodeHolder::reset() or by CodeHolder's + //! destructor. + ASMJIT_INLINE_NODEBUG ZoneAllocator* allocator() const noexcept { return const_cast(&_allocator); } + + //! \} + + //! \name Code & Architecture + //! \{ + + //! Returns the target environment information. + ASMJIT_INLINE_NODEBUG const Environment& environment() const noexcept { return _environment; } + + //! Returns the target architecture. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return environment().arch(); } + //! Returns the target sub-architecture. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return environment().subArch(); } + + //! Returns the minimum CPU features of the target architecture. + ASMJIT_INLINE_NODEBUG const CpuFeatures& cpuFeatures() const noexcept { return _cpuFeatures; } + + //! Tests whether a static base-address is set. + ASMJIT_INLINE_NODEBUG bool hasBaseAddress() const noexcept { return _baseAddress != Globals::kNoBaseAddress; } + //! Returns a static base-address or \ref Globals::kNoBaseAddress, if not set. + ASMJIT_INLINE_NODEBUG uint64_t baseAddress() const noexcept { return _baseAddress; } + + //! \} + + //! \name Emitters + //! \{ + + //! Returns a vector of attached emitters. + ASMJIT_INLINE_NODEBUG const ZoneVector& emitters() const noexcept { return _emitters; } + + //! \} + + //! \name Logging + //! \{ + + //! Returns the attached logger. + ASMJIT_INLINE_NODEBUG Logger* logger() const noexcept { return _logger; } + //! Attaches a `logger` to CodeHolder and propagates it to all attached emitters. + ASMJIT_API void setLogger(Logger* logger) noexcept; + //! Resets the logger to none. + ASMJIT_INLINE_NODEBUG void resetLogger() noexcept { setLogger(nullptr); } + + //! \name Error Handling + //! \{ + + //! Tests whether the CodeHolder has an attached error handler, see \ref ErrorHandler. + ASMJIT_INLINE_NODEBUG bool hasErrorHandler() const noexcept { return _errorHandler != nullptr; } + //! Returns the attached error handler. + ASMJIT_INLINE_NODEBUG ErrorHandler* errorHandler() const noexcept { return _errorHandler; } + //! Attach an error handler to this `CodeHolder`. + ASMJIT_API void setErrorHandler(ErrorHandler* errorHandler) noexcept; + //! Resets the error handler to none. + ASMJIT_INLINE_NODEBUG void resetErrorHandler() noexcept { setErrorHandler(nullptr); } + + //! \} + + //! \name Code Buffer + //! \{ + + //! Makes sure that at least `n` bytes can be added to CodeHolder's buffer `cb`. + //! + //! \note The buffer `cb` must be managed by `CodeHolder` - otherwise the behavior of the function is undefined. + ASMJIT_API Error growBuffer(CodeBuffer* cb, size_t n) noexcept; + + //! Reserves the size of `cb` to at least `n` bytes. + //! + //! \note The buffer `cb` must be managed by `CodeHolder` - otherwise the behavior of the function is undefined. + ASMJIT_API Error reserveBuffer(CodeBuffer* cb, size_t n) noexcept; + + //! \} + + //! \name Sections + //! \{ + + //! Returns an array of `Section*` records. + ASMJIT_INLINE_NODEBUG const ZoneVector& sections() const noexcept { return _sections; } + //! Returns an array of `Section*` records sorted according to section order first, then section id. + ASMJIT_INLINE_NODEBUG const ZoneVector& sectionsByOrder() const noexcept { return _sectionsByOrder; } + //! Returns the number of sections. + ASMJIT_INLINE_NODEBUG uint32_t sectionCount() const noexcept { return _sections.size(); } + + //! Tests whether the given `sectionId` is valid. + ASMJIT_INLINE_NODEBUG bool isSectionValid(uint32_t sectionId) const noexcept { return sectionId < _sections.size(); } + + //! Creates a new section and return its pointer in `sectionOut`. + //! + //! Returns `Error`, does not report a possible error to `ErrorHandler`. + ASMJIT_API Error newSection(Section** sectionOut, const char* name, size_t nameSize = SIZE_MAX, SectionFlags flags = SectionFlags::kNone, uint32_t alignment = 1, int32_t order = 0) noexcept; + + //! Returns a section entry of the given index. + ASMJIT_INLINE_NODEBUG Section* sectionById(uint32_t sectionId) const noexcept { return _sections[sectionId]; } + + //! Returns section-id that matches the given `name`. + //! + //! If there is no such section `Section::kInvalidId` is returned. + ASMJIT_API Section* sectionByName(const char* name, size_t nameSize = SIZE_MAX) const noexcept; + + //! Returns '.text' section (section that commonly represents code). + //! + //! \note Text section is always the first section in \ref CodeHolder::sections() array. + ASMJIT_INLINE_NODEBUG Section* textSection() const noexcept { return _sections[0]; } + + //! Tests whether '.addrtab' section exists. + ASMJIT_INLINE_NODEBUG bool hasAddressTable() const noexcept { return _addressTableSection != nullptr; } + + //! Returns '.addrtab' section. + //! + //! This section is used exclusively by AsmJit to store absolute 64-bit + //! addresses that cannot be encoded in instructions like 'jmp' or 'call'. + //! + //! \note This section is created on demand, the returned pointer can be null. + ASMJIT_INLINE_NODEBUG Section* addressTableSection() const noexcept { return _addressTableSection; } + + //! Ensures that '.addrtab' section exists (creates it if it doesn't) and + //! returns it. Can return `nullptr` on out of memory condition. + ASMJIT_API Section* ensureAddressTableSection() noexcept; + + //! Used to add an address to an address table. + //! + //! This implicitly calls `ensureAddressTableSection()` and then creates `AddressTableEntry` that is inserted + //! to `_addressTableEntries`. If the address already exists this operation does nothing as the same addresses + //! use the same slot. + //! + //! This function should be considered internal as it's used by assemblers to insert an absolute address into the + //! address table. Inserting address into address table without creating a particular relocation entry makes no sense. + ASMJIT_API Error addAddressToAddressTable(uint64_t address) noexcept; + + //! \} + + //! \name Labels & Symbols + //! \{ + + //! Returns array of `LabelEntry*` records. + ASMJIT_INLINE_NODEBUG const ZoneVector& labelEntries() const noexcept { return _labelEntries; } + + //! Returns number of labels created. + ASMJIT_INLINE_NODEBUG uint32_t labelCount() const noexcept { return _labelEntries.size(); } + + //! Tests whether the label having `id` is valid (i.e. created by `newLabelEntry()`). + ASMJIT_INLINE_NODEBUG bool isLabelValid(uint32_t labelId) const noexcept { + return labelId < _labelEntries.size(); + } + + //! Tests whether the `label` is valid (i.e. created by `newLabelEntry()`). + ASMJIT_INLINE_NODEBUG bool isLabelValid(const Label& label) const noexcept { + return label.id() < _labelEntries.size(); + } + + //! \overload + ASMJIT_INLINE_NODEBUG bool isLabelBound(uint32_t labelId) const noexcept { + return isLabelValid(labelId) && _labelEntries[labelId]->isBound(); + } + + //! Tests whether the `label` is already bound. + //! + //! Returns `false` if the `label` is not valid. + ASMJIT_INLINE_NODEBUG bool isLabelBound(const Label& label) const noexcept { + return isLabelBound(label.id()); + } + + //! Returns LabelEntry of the given label `id`. + ASMJIT_INLINE_NODEBUG LabelEntry* labelEntry(uint32_t labelId) const noexcept { + return isLabelValid(labelId) ? _labelEntries[labelId] : static_cast(nullptr); + } + + //! Returns LabelEntry of the given `label`. + ASMJIT_INLINE_NODEBUG LabelEntry* labelEntry(const Label& label) const noexcept { + return labelEntry(label.id()); + } + + //! Returns offset of a `Label` by its `labelId`. + //! + //! The offset returned is relative to the start of the section. Zero offset is returned for unbound labels, + //! which is their initial offset value. + ASMJIT_INLINE_NODEBUG uint64_t labelOffset(uint32_t labelId) const noexcept { + ASMJIT_ASSERT(isLabelValid(labelId)); + return _labelEntries[labelId]->offset(); + } + + //! \overload + ASMJIT_INLINE_NODEBUG uint64_t labelOffset(const Label& label) const noexcept { + return labelOffset(label.id()); + } + + //! Returns offset of a label by it's `labelId` relative to the base offset. + //! + //! \remarks The offset of the section where the label is bound must be valid in order to use this function, + //! otherwise the value returned will not be reliable. + inline uint64_t labelOffsetFromBase(uint32_t labelId) const noexcept { + ASMJIT_ASSERT(isLabelValid(labelId)); + const LabelEntry* le = _labelEntries[labelId]; + return (le->isBound() ? le->section()->offset() : uint64_t(0)) + le->offset(); + } + + //! \overload + inline uint64_t labelOffsetFromBase(const Label& label) const noexcept { + return labelOffsetFromBase(label.id()); + } + + //! Creates a new anonymous label and return its id in `idOut`. + //! + //! Returns `Error`, does not report error to `ErrorHandler`. + ASMJIT_API Error newLabelEntry(LabelEntry** entryOut) noexcept; + + //! Creates a new named \ref LabelEntry of the given label `type`. + //! + //! \param entryOut Where to store the created \ref LabelEntry. + //! \param name The name of the label. + //! \param nameSize The length of `name` argument, or `SIZE_MAX` if `name` is a null terminated string, which + //! means that the `CodeHolder` will use `strlen()` to determine the length. + //! \param type The type of the label to create, see \ref LabelType. + //! \param parentId Parent id of a local label, otherwise it must be \ref Globals::kInvalidId. + //! \retval Always returns \ref Error, does not report a possible error to the attached \ref ErrorHandler. + //! + //! AsmJit has a support for local labels (\ref LabelType::kLocal) which require a parent label id (parentId). + //! The names of local labels can conflict with names of other local labels that have a different parent. In + //! addition, AsmJit supports named anonymous labels, which are useful only for debugging purposes as the + //! anonymous name will have a name, which will be formatted, but the label itself cannot be queried by such + //! name. + ASMJIT_API Error newNamedLabelEntry(LabelEntry** entryOut, const char* name, size_t nameSize, LabelType type, uint32_t parentId = Globals::kInvalidId) noexcept; + + //! Returns a label by name. + //! + //! If the named label doesn't a default constructed \ref Label is returned, + //! which has its id set to \ref Globals::kInvalidId. + ASMJIT_INLINE_NODEBUG Label labelByName(const char* name, size_t nameSize = SIZE_MAX, uint32_t parentId = Globals::kInvalidId) noexcept { + return Label(labelIdByName(name, nameSize, parentId)); + } + + //! Returns a label id by name. + //! + //! If the named label doesn't exist \ref Globals::kInvalidId is returned. + ASMJIT_API uint32_t labelIdByName(const char* name, size_t nameSize = SIZE_MAX, uint32_t parentId = Globals::kInvalidId) noexcept; + + //! Tests whether there are any unresolved label links. + ASMJIT_INLINE_NODEBUG bool hasUnresolvedLinks() const noexcept { return _unresolvedLinkCount != 0; } + //! Returns the number of label links, which are unresolved. + ASMJIT_INLINE_NODEBUG size_t unresolvedLinkCount() const noexcept { return _unresolvedLinkCount; } + + //! Creates a new label-link used to store information about yet unbound labels. + //! + //! Returns `null` if the allocation failed. + ASMJIT_API LabelLink* newLabelLink(LabelEntry* le, uint32_t sectionId, size_t offset, intptr_t rel, const OffsetFormat& format) noexcept; + + //! Resolves cross-section links (`LabelLink`) associated with each label that was used as a destination in code + //! of a different section. It's only useful to people that use multiple sections as it will do nothing if the code + //! only contains a single section in which cross-section links are not possible. + ASMJIT_API Error resolveUnresolvedLinks() noexcept; + + //! Binds a label to a given `sectionId` and `offset` (relative to start of the section). + //! + //! This function is generally used by `BaseAssembler::bind()` to do the heavy lifting. + ASMJIT_API Error bindLabel(const Label& label, uint32_t sectionId, uint64_t offset) noexcept; + + //! \} + + //! \name Relocations + //! \{ + + //! Tests whether the code contains relocation entries. + ASMJIT_INLINE_NODEBUG bool hasRelocEntries() const noexcept { return !_relocations.empty(); } + //! Returns array of `RelocEntry*` records. + ASMJIT_INLINE_NODEBUG const ZoneVector& relocEntries() const noexcept { return _relocations; } + + //! Returns a RelocEntry of the given `id`. + ASMJIT_INLINE_NODEBUG RelocEntry* relocEntry(uint32_t id) const noexcept { return _relocations[id]; } + + //! Creates a new relocation entry of type `relocType`. + //! + //! Additional fields can be set after the relocation entry was created. + ASMJIT_API Error newRelocEntry(RelocEntry** dst, RelocType relocType) noexcept; + + //! \} + + //! \name Utilities + //! \{ + + //! Flattens all sections by recalculating their offsets, starting at 0. + //! + //! \note This should never be called more than once. + ASMJIT_API Error flatten() noexcept; + + //! Returns computed the size of code & data of all sections. + //! + //! \note All sections will be iterated over and the code size returned would represent the minimum code size of + //! all combined sections after applying minimum alignment. Code size may decrease after calling `flatten()` and + //! `relocateToBase()`. + ASMJIT_API size_t codeSize() const noexcept; + + //! Relocates the code to the given `baseAddress`. + //! + //! \param baseAddress Absolute base address where the code will be relocated to. Please note that nothing is + //! copied to such base address, it's just an absolute value used by the relocation code to resolve all stored + //! relocations. + //! + //! \note This should never be called more than once. + ASMJIT_API Error relocateToBase(uint64_t baseAddress) noexcept; + + //! Copies a single section into `dst`. + ASMJIT_API Error copySectionData(void* dst, size_t dstSize, uint32_t sectionId, CopySectionFlags copyFlags = CopySectionFlags::kNone) noexcept; + + //! Copies all sections into `dst`. + //! + //! This should only be used if the data was flattened and there are no gaps between the sections. The `dstSize` + //! is always checked and the copy will never write anything outside the provided buffer. + ASMJIT_API Error copyFlattenedData(void* dst, size_t dstSize, CopySectionFlags copyFlags = CopySectionFlags::kNone) noexcept; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CODEHOLDER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/compiler.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..24b133c5407fae00a9453f0183a9e3d053fe1e6d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/compiler.h @@ -0,0 +1,741 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_COMPILER_H_INCLUDED +#define ASMJIT_CORE_COMPILER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_COMPILER + +#include "../core/assembler.h" +#include "../core/builder.h" +#include "../core/constpool.h" +#include "../core/compilerdefs.h" +#include "../core/func.h" +#include "../core/inst.h" +#include "../core/operand.h" +#include "../core/support.h" +#include "../core/zone.h" +#include "../core/zonevector.h" + +ASMJIT_BEGIN_NAMESPACE + +class JumpAnnotation; +class JumpNode; +class FuncNode; +class FuncRetNode; +class InvokeNode; + +//! \addtogroup asmjit_compiler +//! \{ + +//! Code emitter that uses virtual registers and performs register allocation. +//! +//! Compiler is a high-level code-generation tool that provides register allocation and automatic handling of function +//! calling conventions. It was primarily designed for merging multiple parts of code into a function without worrying +//! about registers and function calling conventions. +//! +//! BaseCompiler can be used, with a minimum effort, to handle 32-bit and 64-bit code generation within a single code +//! base. +//! +//! BaseCompiler is based on BaseBuilder and contains all the features it provides. It means that the code it stores +//! can be modified (removed, added, injected) and analyzed. When the code is finalized the compiler can emit the code +//! into an Assembler to translate the abstract representation into a machine code. +//! +//! Check out architecture specific compilers for more details and examples: +//! +//! - \ref x86::Compiler - X86/X64 compiler implementation. +//! - \ref a64::Compiler - AArch64 compiler implementation. +class ASMJIT_VIRTAPI BaseCompiler : public BaseBuilder { +public: + ASMJIT_NONCOPYABLE(BaseCompiler) + typedef BaseBuilder Base; + + //! \name Members + //! \{ + + //! Current function. + FuncNode* _func; + //! Allocates `VirtReg` objects. + Zone _vRegZone; + //! Stores array of `VirtReg` pointers. + ZoneVector _vRegArray; + //! Stores jump annotations. + ZoneVector _jumpAnnotations; + + //! Local and global constant pools. + //! + //! Local constant pool is flushed with each function, global constant pool is flushed only by \ref finalize(). + ConstPoolNode* _constPools[2]; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `BaseCompiler` instance. + ASMJIT_API BaseCompiler() noexcept; + //! Destroys the `BaseCompiler` instance. + ASMJIT_API ~BaseCompiler() noexcept override; + + //! \} + + //! \name Function Management + //! \{ + + //! Creates a new \ref FuncNode. + ASMJIT_API Error newFuncNode(FuncNode** ASMJIT_NONNULL(out), const FuncSignature& signature); + //! Creates a new \ref FuncNode adds it to the instruction stream. + ASMJIT_API Error addFuncNode(FuncNode** ASMJIT_NONNULL(out), const FuncSignature& signature); + + //! Creates a new \ref FuncRetNode. + ASMJIT_API Error newFuncRetNode(FuncRetNode** ASMJIT_NONNULL(out), const Operand_& o0, const Operand_& o1); + //! Creates a new \ref FuncRetNode and adds it to the instruction stream. + ASMJIT_API Error addFuncRetNode(FuncRetNode** ASMJIT_NONNULL(out), const Operand_& o0, const Operand_& o1); + + //! Returns the current function. + ASMJIT_INLINE_NODEBUG FuncNode* func() const noexcept { return _func; } + + //! Creates a new \ref FuncNode with the given `signature` and returns it. + inline FuncNode* newFunc(const FuncSignature& signature) { + FuncNode* node; + newFuncNode(&node, signature); + return node; + } + + //! Creates a new \ref FuncNode with the given `signature`, adds it to the instruction stream by using + //! the \ref addFunc(FuncNode*) overload, and returns it. + inline FuncNode* addFunc(const FuncSignature& signature) { + FuncNode* node; + addFuncNode(&node, signature); + return node; + } + + //! Adds a function `node` to the instruction stream. + ASMJIT_API FuncNode* addFunc(FuncNode* ASMJIT_NONNULL(func)); + //! Emits a sentinel that marks the end of the current function. + ASMJIT_API Error endFunc(); + +#if !defined(ASMJIT_NO_DEPRECATED) + inline Error _setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg); + + //! Sets a function argument at `argIndex` to `reg`. + ASMJIT_DEPRECATED("Setting arguments through Compiler is deprecated, use FuncNode->setArg() instead") + inline Error setArg(size_t argIndex, const BaseReg& reg) { return _setArg(argIndex, 0, reg); } + + //! Sets a function argument at `argIndex` at `valueIndex` to `reg`. + ASMJIT_DEPRECATED("Setting arguments through Compiler is deprecated, use FuncNode->setArg() instead") + inline Error setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg) { return _setArg(argIndex, valueIndex, reg); } +#endif + + inline Error addRet(const Operand_& o0, const Operand_& o1) { + FuncRetNode* node; + return addFuncRetNode(&node, o0, o1); + } + + //! \} + + //! \name Function Invocation + //! \{ + + //! Creates a new \ref InvokeNode. + ASMJIT_API Error newInvokeNode(InvokeNode** ASMJIT_NONNULL(out), InstId instId, const Operand_& o0, const FuncSignature& signature); + //! Creates a new \ref InvokeNode and adds it to the instruction stream. + ASMJIT_API Error addInvokeNode(InvokeNode** ASMJIT_NONNULL(out), InstId instId, const Operand_& o0, const FuncSignature& signature); + + //! \} + + //! \name Virtual Registers + //! \{ + + //! Creates a new virtual register representing the given `typeId` and `signature`. + //! + //! \note This function is public, but it's not generally recommended to be used by AsmJit users, use architecture + //! specific `newReg()` functionality instead or functions like \ref _newReg() and \ref _newRegFmt(). + ASMJIT_API Error newVirtReg(VirtReg** ASMJIT_NONNULL(out), TypeId typeId, OperandSignature signature, const char* name); + + //! Creates a new virtual register of the given `typeId` and stores it to `out` operand. + ASMJIT_API Error _newReg(BaseReg* ASMJIT_NONNULL(out), TypeId typeId, const char* name = nullptr); + + //! Creates a new virtual register of the given `typeId` and stores it to `out` operand. + //! + //! \note This version accepts a snprintf() format `fmt` followed by a variadic arguments. + ASMJIT_API Error _newRegFmt(BaseReg* ASMJIT_NONNULL(out), TypeId typeId, const char* fmt, ...); + //! \overload + inline Error _newRegFmt(BaseReg* ASMJIT_NONNULL(out), TypeId typeId) { return _newRegFmt(out, typeId, nullptr); } + + //! Creates a new virtual register compatible with the provided reference register `ref`. + ASMJIT_API Error _newReg(BaseReg* ASMJIT_NONNULL(out), const BaseReg& ref, const char* name = nullptr); + + //! Creates a new virtual register compatible with the provided reference register `ref`. + //! + //! \note This version accepts a snprintf() format `fmt` followed by a variadic arguments. + ASMJIT_API Error _newRegFmt(BaseReg* ASMJIT_NONNULL(out), const BaseReg& ref, const char* fmt, ...); + + //! Tests whether the given `id` is a valid virtual register id. + ASMJIT_INLINE_NODEBUG bool isVirtIdValid(uint32_t id) const noexcept { + uint32_t index = Operand::virtIdToIndex(id); + return index < _vRegArray.size(); + } + //! Tests whether the given `reg` is a virtual register having a valid id. + ASMJIT_INLINE_NODEBUG bool isVirtRegValid(const BaseReg& reg) const noexcept { + return isVirtIdValid(reg.id()); + } + + //! Returns \ref VirtReg associated with the given `id`. + inline VirtReg* virtRegById(uint32_t id) const noexcept { + ASMJIT_ASSERT(isVirtIdValid(id)); + return _vRegArray[Operand::virtIdToIndex(id)]; + } + + //! Returns \ref VirtReg associated with the given `reg`. + ASMJIT_INLINE_NODEBUG VirtReg* virtRegByReg(const BaseReg& reg) const noexcept { return virtRegById(reg.id()); } + + //! Returns \ref VirtReg associated with the given virtual register `index`. + //! + //! \note This is not the same as virtual register id. The conversion between id and its index is implemented + //! by \ref Operand_::virtIdToIndex() and \ref Operand_::indexToVirtId() functions. + ASMJIT_INLINE_NODEBUG VirtReg* virtRegByIndex(uint32_t index) const noexcept { return _vRegArray[index]; } + + //! Returns an array of all virtual registers managed by the Compiler. + ASMJIT_INLINE_NODEBUG const ZoneVector& virtRegs() const noexcept { return _vRegArray; } + + //! \name Stack + //! \{ + + //! Creates a new stack of the given `size` and `alignment` and stores it to `out`. + //! + //! \note `name` can be used to give the stack a name, for debugging purposes. + ASMJIT_API Error _newStack(BaseMem* ASMJIT_NONNULL(out), uint32_t size, uint32_t alignment, const char* name = nullptr); + + //! Updates the stack size of a stack created by `_newStack()` by its `virtId`. + ASMJIT_API Error setStackSize(uint32_t virtId, uint32_t newSize, uint32_t newAlignment = 0); + + //! Updates the stack size of a stack created by `_newStack()`. + ASMJIT_INLINE_NODEBUG Error setStackSize(const BaseMem& mem, uint32_t newSize, uint32_t newAlignment = 0) { + return setStackSize(mem.id(), newSize, newAlignment); + } + + //! \} + + //! \name Constants + //! \{ + + //! Creates a new constant of the given `scope` (see \ref ConstPoolScope). + //! + //! This function adds a constant of the given `size` to the built-in \ref ConstPool and stores the reference to that + //! constant to the `out` operand. + ASMJIT_API Error _newConst(BaseMem* ASMJIT_NONNULL(out), ConstPoolScope scope, const void* data, size_t size); + + //! \} + + //! \name Miscellaneous + //! \{ + + //! Rename the given virtual register `reg` to a formatted string `fmt`. + ASMJIT_API void rename(const BaseReg& reg, const char* fmt, ...); + + //! \} + + //! \name Jump Annotations + //! \{ + + ASMJIT_INLINE_NODEBUG const ZoneVector& jumpAnnotations() const noexcept { + return _jumpAnnotations; + } + + ASMJIT_API Error newJumpNode(JumpNode** ASMJIT_NONNULL(out), InstId instId, InstOptions instOptions, const Operand_& o0, JumpAnnotation* annotation); + ASMJIT_API Error emitAnnotatedJump(InstId instId, const Operand_& o0, JumpAnnotation* annotation); + + //! Returns a new `JumpAnnotation` instance, which can be used to aggregate possible targets of a jump where the + //! target is not a label, for example to implement jump tables. + ASMJIT_API JumpAnnotation* newJumpAnnotation(); + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! Jump annotation used to annotate jumps. +//! +//! \ref BaseCompiler allows to emit jumps where the target is either register or memory operand. Such jumps cannot be +//! trivially inspected, so instead of doing heuristics AsmJit allows to annotate such jumps with possible targets. +//! Register allocator then uses the annotation to construct control-flow, which is then used by liveness analysis and +//! other tools to prepare ground for register allocation. +class JumpAnnotation { +public: + ASMJIT_NONCOPYABLE(JumpAnnotation) + + //! \name Members + //! \{ + + //! Compiler that owns this JumpAnnotation. + BaseCompiler* _compiler; + //! Annotation identifier. + uint32_t _annotationId; + //! Vector of label identifiers, see \ref labelIds(). + ZoneVector _labelIds; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG JumpAnnotation(BaseCompiler* ASMJIT_NONNULL(compiler), uint32_t annotationId) noexcept + : _compiler(compiler), + _annotationId(annotationId) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the compiler that owns this JumpAnnotation. + ASMJIT_INLINE_NODEBUG BaseCompiler* compiler() const noexcept { return _compiler; } + //! Returns the annotation id. + ASMJIT_INLINE_NODEBUG uint32_t annotationId() const noexcept { return _annotationId; } + //! Returns a vector of label identifiers that lists all targets of the jump. + ASMJIT_INLINE_NODEBUG const ZoneVector& labelIds() const noexcept { return _labelIds; } + + //! Tests whether the given `label` is a target of this JumpAnnotation. + ASMJIT_INLINE_NODEBUG bool hasLabel(const Label& label) const noexcept { return hasLabelId(label.id()); } + //! Tests whether the given `labelId` is a target of this JumpAnnotation. + ASMJIT_INLINE_NODEBUG bool hasLabelId(uint32_t labelId) const noexcept { return _labelIds.contains(labelId); } + + //! \} + + //! \name Annotation Building API + //! \{ + + //! Adds the `label` to the list of targets of this JumpAnnotation. + ASMJIT_INLINE_NODEBUG Error addLabel(const Label& label) noexcept { return addLabelId(label.id()); } + //! Adds the `labelId` to the list of targets of this JumpAnnotation. + ASMJIT_INLINE_NODEBUG Error addLabelId(uint32_t labelId) noexcept { return _labelIds.append(&_compiler->_allocator, labelId); } + + //! \} +}; + +//! Jump instruction with \ref JumpAnnotation. +//! +//! \note This node should be only used to represent jump where the jump target cannot be deduced by examining +//! instruction operands. For example if the jump target is register or memory location. This pattern is often +//! used to perform indirect jumps that use jump table, e.g. to implement `switch{}` statement. +class JumpNode : public InstNodeWithOperands { +public: + ASMJIT_NONCOPYABLE(JumpNode) + + //! \name Members + //! \{ + + JumpAnnotation* _annotation; + + //! \} + + //! \name Construction & Destruction + //! \{ + + inline JumpNode(BaseCompiler* ASMJIT_NONNULL(cc), InstId instId, InstOptions options, uint32_t opCount, JumpAnnotation* annotation) noexcept + : InstNodeWithOperands(cc, instId, options, opCount), + _annotation(annotation) { + setType(NodeType::kJump); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether this JumpNode has associated a \ref JumpAnnotation. + ASMJIT_INLINE_NODEBUG bool hasAnnotation() const noexcept { return _annotation != nullptr; } + //! Returns the \ref JumpAnnotation associated with this jump, or `nullptr`. + ASMJIT_INLINE_NODEBUG JumpAnnotation* annotation() const noexcept { return _annotation; } + //! Sets the \ref JumpAnnotation associated with this jump to `annotation`. + ASMJIT_INLINE_NODEBUG void setAnnotation(JumpAnnotation* annotation) noexcept { _annotation = annotation; } + + //! \} +}; + +//! Function node represents a function used by \ref BaseCompiler. +//! +//! A function is composed of the following: +//! +//! - Function entry, \ref FuncNode acts as a label, so the entry is implicit. To get the entry, simply use +//! \ref FuncNode::label(), which is the same as \ref LabelNode::label(). +//! +//! - Function exit, which is represented by \ref FuncNode::exitNode(). A helper function +//! \ref FuncNode::exitLabel() exists and returns an exit label instead of node. +//! +//! - Function \ref FuncNode::endNode() sentinel. This node marks the end of a function - there should be no +//! code that belongs to the function after this node, but the Compiler doesn't enforce that at the moment. +//! +//! - Function detail, see \ref FuncNode::detail(). +//! +//! - Function frame, see \ref FuncNode::frame(). +//! +//! - Function arguments mapped to virtual registers, see \ref FuncNode::argPacks(). +//! +//! In a node list, the function and its body looks like the following: +//! +//! \code{.unparsed} +//! [...] - Anything before the function. +//! +//! [FuncNode] - Entry point of the function, acts as a label as well. +//! - Prolog inserted by the register allocator. +//! {...} - Function body - user code basically. +//! [ExitLabel] - Exit label +//! - Epilog inserted by the register allocator. +//! - Return inserted by the register allocator. +//! {...} - Can contain data or user code (error handling, special cases, ...). +//! [FuncEnd] - End sentinel +//! +//! [...] - Anything after the function. +//! \endcode +//! +//! When a function is added to the instruction stream by \ref BaseCompiler::addFunc() it actually inserts 3 nodes +//! (FuncNode, ExitLabel, and FuncEnd) and sets the current cursor to be FuncNode. When \ref BaseCompiler::endFunc() +//! is called the cursor is set to FuncEnd. This guarantees that user can use ExitLabel as a marker after additional +//! code or data can be placed, which is a common practice. +class FuncNode : public LabelNode { +public: + ASMJIT_NONCOPYABLE(FuncNode) + + //! Arguments pack. + struct ArgPack { + RegOnly _data[Globals::kMaxValuePack]; + + inline void reset() noexcept { + for (size_t valueIndex = 0; valueIndex < Globals::kMaxValuePack; valueIndex++) + _data[valueIndex].reset(); + } + + inline RegOnly& operator[](size_t valueIndex) noexcept { return _data[valueIndex]; } + inline const RegOnly& operator[](size_t valueIndex) const noexcept { return _data[valueIndex]; } + }; + + //! \name Members + //! \{ + + //! Function detail. + FuncDetail _funcDetail; + //! Function frame. + FuncFrame _frame; + //! Function exit label. + LabelNode* _exitNode; + //! Function end (sentinel). + SentinelNode* _end; + //! Argument packs. + ArgPack* _args; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `FuncNode` instance. + //! + //! Always use `BaseCompiler::addFunc()` to create a new `FuncNode`. + inline FuncNode(BaseBuilder* ASMJIT_NONNULL(cb)) noexcept + : LabelNode(cb), + _funcDetail(), + _frame(), + _exitNode(nullptr), + _end(nullptr), + _args(nullptr) { + setType(NodeType::kFunc); + } + + //! \} + + //! \{ + //! \name Accessors + + //! Returns function exit `LabelNode`. + ASMJIT_INLINE_NODEBUG LabelNode* exitNode() const noexcept { return _exitNode; } + //! Returns function exit label. + ASMJIT_INLINE_NODEBUG Label exitLabel() const noexcept { return _exitNode->label(); } + + //! Returns "End of Func" sentinel node. + ASMJIT_INLINE_NODEBUG SentinelNode* endNode() const noexcept { return _end; } + + //! Returns function detail. + ASMJIT_INLINE_NODEBUG FuncDetail& detail() noexcept { return _funcDetail; } + //! Returns function detail. + ASMJIT_INLINE_NODEBUG const FuncDetail& detail() const noexcept { return _funcDetail; } + + //! Returns function frame. + ASMJIT_INLINE_NODEBUG FuncFrame& frame() noexcept { return _frame; } + //! Returns function frame. + ASMJIT_INLINE_NODEBUG const FuncFrame& frame() const noexcept { return _frame; } + + //! Returns function attributes. + ASMJIT_INLINE_NODEBUG FuncAttributes attributes() const noexcept { return _frame.attributes(); } + //! Adds `attrs` to the function attributes. + ASMJIT_INLINE_NODEBUG void addAttributes(FuncAttributes attrs) noexcept { _frame.addAttributes(attrs); } + + //! Returns arguments count. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _funcDetail.argCount(); } + //! Returns argument packs. + ASMJIT_INLINE_NODEBUG ArgPack* argPacks() const noexcept { return _args; } + + //! Tests whether the function has a return value. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return _funcDetail.hasRet(); } + + //! Returns argument pack at `argIndex`. + inline ArgPack& argPack(size_t argIndex) const noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex]; + } + + //! Sets argument at `argIndex`. + inline void setArg(size_t argIndex, const BaseReg& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][0].init(vReg); + } + + //! \overload + inline void setArg(size_t argIndex, const RegOnly& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][0].init(vReg); + } + + //! Sets argument at `argIndex` and `valueIndex`. + inline void setArg(size_t argIndex, size_t valueIndex, const BaseReg& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex].init(vReg); + } + + //! \overload + inline void setArg(size_t argIndex, size_t valueIndex, const RegOnly& vReg) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex].init(vReg); + } + + //! Resets argument pack at `argIndex`. + inline void resetArg(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex].reset(); + } + + //! Resets argument pack at `argIndex`. + inline void resetArg(size_t argIndex, size_t valueIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex].reset(); + } + + //! \} +}; + +//! Function return, used by \ref BaseCompiler. +class FuncRetNode : public InstNodeWithOperands { +public: + ASMJIT_NONCOPYABLE(FuncRetNode) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `FuncRetNode` instance. + inline FuncRetNode(BaseBuilder* ASMJIT_NONNULL(cb)) noexcept + : InstNodeWithOperands(cb, BaseInst::kIdAbstract, InstOptions::kNone, 0) { + _any._nodeType = NodeType::kFuncRet; + } + + //! \} +}; + +//! Function invocation, used by \ref BaseCompiler. +class InvokeNode : public InstNodeWithOperands { +public: + ASMJIT_NONCOPYABLE(InvokeNode) + + //! Operand pack provides multiple operands that can be associated with a single return value of function + //! argument. Sometimes this is necessary to express an argument or return value that requires multiple + //! registers, for example 64-bit value in 32-bit mode or passing / returning homogeneous data structures. + struct OperandPack { + //! Operands. + Operand_ _data[Globals::kMaxValuePack]; + + //! Reset the pack by resetting all operands in the pack. + inline void reset() noexcept { + for (size_t valueIndex = 0; valueIndex < Globals::kMaxValuePack; valueIndex++) + _data[valueIndex].reset(); + } + + //! Returns an operand at the given `valueIndex`. + inline Operand& operator[](size_t valueIndex) noexcept { + ASMJIT_ASSERT(valueIndex < Globals::kMaxValuePack); + return _data[valueIndex].as(); + } + + //! Returns an operand at the given `valueIndex` (const). + const inline Operand& operator[](size_t valueIndex) const noexcept { + ASMJIT_ASSERT(valueIndex < Globals::kMaxValuePack); + return _data[valueIndex].as(); + } + }; + + //! \name Members + //! \{ + + //! Function detail. + FuncDetail _funcDetail; + //! Function return value(s). + OperandPack _rets; + //! Function arguments. + OperandPack* _args; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `InvokeNode` instance. + inline InvokeNode(BaseBuilder* ASMJIT_NONNULL(cb), InstId instId, InstOptions options) noexcept + : InstNodeWithOperands(cb, instId, options, 0), + _funcDetail(), + _args(nullptr) { + setType(NodeType::kInvoke); + _resetOps(); + _rets.reset(); + addFlags(NodeFlags::kIsRemovable); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Sets the function signature. + inline Error init(const FuncSignature& signature, const Environment& environment) noexcept { + return _funcDetail.init(signature, environment); + } + + //! Returns the function detail. + ASMJIT_INLINE_NODEBUG FuncDetail& detail() noexcept { return _funcDetail; } + //! Returns the function detail. + ASMJIT_INLINE_NODEBUG const FuncDetail& detail() const noexcept { return _funcDetail; } + + //! Returns the target operand. + ASMJIT_INLINE_NODEBUG Operand& target() noexcept { return op(0); } + //! \overload + ASMJIT_INLINE_NODEBUG const Operand& target() const noexcept { return op(0); } + + //! Returns the number of function return values. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return _funcDetail.hasRet(); } + //! Returns the number of function arguments. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _funcDetail.argCount(); } + + //! Returns operand pack representing function return value(s). + ASMJIT_INLINE_NODEBUG OperandPack& retPack() noexcept { return _rets; } + //! Returns operand pack representing function return value(s). + ASMJIT_INLINE_NODEBUG const OperandPack& retPack() const noexcept { return _rets; } + + //! Returns the return value at the given `valueIndex`. + ASMJIT_INLINE_NODEBUG Operand& ret(size_t valueIndex = 0) noexcept { return _rets[valueIndex]; } + //! \overload + ASMJIT_INLINE_NODEBUG const Operand& ret(size_t valueIndex = 0) const noexcept { return _rets[valueIndex]; } + + //! Returns operand pack representing function return value(s). + inline OperandPack& argPack(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex]; + } + //! \overload + inline const OperandPack& argPack(size_t argIndex) const noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex]; + } + + //! Returns a function argument at the given `argIndex`. + inline Operand& arg(size_t argIndex, size_t valueIndex) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex][valueIndex]; + } + //! \overload + inline const Operand& arg(size_t argIndex, size_t valueIndex) const noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + return _args[argIndex][valueIndex]; + } + + //! Sets the function return value at `i` to `op`. + inline void _setRet(size_t valueIndex, const Operand_& op) noexcept { _rets[valueIndex] = op; } + //! Sets the function argument at `i` to `op`. + inline void _setArg(size_t argIndex, size_t valueIndex, const Operand_& op) noexcept { + ASMJIT_ASSERT(argIndex < argCount()); + _args[argIndex][valueIndex] = op; + } + + //! Sets the function return value at `valueIndex` to `reg`. + ASMJIT_INLINE_NODEBUG void setRet(size_t valueIndex, const BaseReg& reg) noexcept { _setRet(valueIndex, reg); } + + //! Sets the first function argument in a value-pack at `argIndex` to `reg`. + ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, const BaseReg& reg) noexcept { _setArg(argIndex, 0, reg); } + //! Sets the first function argument in a value-pack at `argIndex` to `imm`. + ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, const Imm& imm) noexcept { _setArg(argIndex, 0, imm); } + + //! Sets the function argument at `argIndex` and `valueIndex` to `reg`. + ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg) noexcept { _setArg(argIndex, valueIndex, reg); } + //! Sets the function argument at `argIndex` and `valueIndex` to `imm`. + ASMJIT_INLINE_NODEBUG void setArg(size_t argIndex, size_t valueIndex, const Imm& imm) noexcept { _setArg(argIndex, valueIndex, imm); } + + //! \} +}; + +//! Function pass extends \ref Pass with \ref FuncPass::runOnFunction(). +class ASMJIT_VIRTAPI FuncPass : public Pass { +public: + ASMJIT_NONCOPYABLE(FuncPass) + typedef Pass Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API FuncPass(const char* name) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the associated `BaseCompiler`. + ASMJIT_INLINE_NODEBUG BaseCompiler* cc() const noexcept { return static_cast(_cb); } + + //! \} + + //! \name Pass Interface + //! \{ + + //! Calls `runOnFunction()` on each `FuncNode` node found. + ASMJIT_API Error run(Zone* zone, Logger* logger) override; + + //! Called once per `FuncNode`. + ASMJIT_API virtual Error runOnFunction(Zone* zone, Logger* logger, FuncNode* func); + + //! \} +}; + +#if !defined(ASMJIT_NO_DEPRECATED) +inline Error BaseCompiler::_setArg(size_t argIndex, size_t valueIndex, const BaseReg& reg) { + FuncNode* func = _func; + + if (ASMJIT_UNLIKELY(!func)) + return reportError(DebugUtils::errored(kErrorInvalidState)); + + func->setArg(argIndex, valueIndex, reg); + return kErrorOk; +} +#endif + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // !ASMJIT_NO_COMPILER +#endif // ASMJIT_CORE_COMPILER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/compilerdefs.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/compilerdefs.h new file mode 100644 index 0000000000000000000000000000000000000000..4c74eecdfe02b2466a714e5d9b251c359a3fed71 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/compilerdefs.h @@ -0,0 +1,171 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_COMPILERDEFS_H_INCLUDED +#define ASMJIT_CORE_COMPILERDEFS_H_INCLUDED + +#include "../core/api-config.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../core/zonestring.h" + +ASMJIT_BEGIN_NAMESPACE + +class RAWorkReg; + +//! \addtogroup asmjit_compiler +//! \{ + +//! Virtual register data, managed by \ref BaseCompiler. +class VirtReg { +public: + ASMJIT_NONCOPYABLE(VirtReg) + + //! \name Members + //! \{ + + //! Virtual register signature. + OperandSignature _signature {}; + //! Virtual register id. + uint32_t _id = 0; + //! Virtual register size (can be smaller than `_signature._size`). + uint32_t _virtSize = 0; + //! Virtual register alignment (for spilling). + uint8_t _alignment = 0; + //! Type-id. + TypeId _typeId = TypeId::kVoid; + //! Virtual register weight for alloc/spill decisions. + uint8_t _weight = 1; + //! True if this is a fixed register, never reallocated. + uint8_t _isFixed : 1; + //! True if the virtual register is only used as a stack (never accessed as register). + uint8_t _isStack : 1; + //! True if this virtual register has assigned stack offset (can be only valid after register allocation pass). + uint8_t _hasStackSlot : 1; + uint8_t _reservedBits : 5; + + //! Stack offset assigned by the register allocator relative to stack pointer (can be negative as well). + int32_t _stackOffset = 0; + + //! Reserved for future use (padding). + uint32_t _reservedU32 = 0; + + //! Virtual register name (user provided or automatically generated). + ZoneString<16> _name {}; + + // The following members are used exclusively by RAPass. They are initialized when the VirtReg is created to + // null pointers and then changed during RAPass execution. RAPass sets them back to NULL before it returns. + + //! Reference to `RAWorkReg`, used during register allocation. + RAWorkReg* _workReg = nullptr; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG VirtReg(OperandSignature signature, uint32_t id, uint32_t virtSize, uint32_t alignment, TypeId typeId) noexcept + : _signature(signature), + _id(id), + _virtSize(virtSize), + _alignment(uint8_t(alignment)), + _typeId(typeId), + _isFixed(0), + _isStack(0), + _hasStackSlot(0), + _reservedBits(0) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the virtual register id. + ASMJIT_INLINE_NODEBUG uint32_t id() const noexcept { return _id; } + + //! Returns the virtual register name. + ASMJIT_INLINE_NODEBUG const char* name() const noexcept { return _name.data(); } + //! Returns the size of the virtual register name. + ASMJIT_INLINE_NODEBUG uint32_t nameSize() const noexcept { return _name.size(); } + + //! Returns a register signature of this virtual register. + ASMJIT_INLINE_NODEBUG OperandSignature signature() const noexcept { return _signature; } + //! Returns a virtual register type (maps to the physical register type as well). + ASMJIT_INLINE_NODEBUG RegType type() const noexcept { return _signature.regType(); } + //! Returns a virtual register group (maps to the physical register group as well). + ASMJIT_INLINE_NODEBUG RegGroup group() const noexcept { return _signature.regGroup(); } + + //! Returns a real size of the register this virtual register maps to. + //! + //! For example if this is a 128-bit SIMD register used for a scalar single precision floating point value then + //! its virtSize would be 4, however, the `regSize` would still say 16 (128-bits), because it's the smallest size + //! of that register type. + ASMJIT_INLINE_NODEBUG uint32_t regSize() const noexcept { return _signature.size(); } + + //! Returns the virtual register size. + //! + //! The virtual register size describes how many bytes the virtual register needs to store its content. It can be + //! smaller than the physical register size, see `regSize()`. + ASMJIT_INLINE_NODEBUG uint32_t virtSize() const noexcept { return _virtSize; } + + //! Returns the virtual register alignment. + ASMJIT_INLINE_NODEBUG uint32_t alignment() const noexcept { return _alignment; } + + //! Returns the virtual register type id. + ASMJIT_INLINE_NODEBUG TypeId typeId() const noexcept { return _typeId; } + + //! Returns the virtual register weight - the register allocator can use it as explicit hint for alloc/spill + //! decisions. + ASMJIT_INLINE_NODEBUG uint32_t weight() const noexcept { return _weight; } + //! Sets the virtual register weight (0 to 255) - the register allocator can use it as explicit hint for + //! alloc/spill decisions and initial bin-packing. + ASMJIT_INLINE_NODEBUG void setWeight(uint32_t weight) noexcept { _weight = uint8_t(weight); } + + //! Returns whether the virtual register is always allocated to a fixed physical register (and never reallocated). + //! + //! \note This is only used for special purposes and it's mostly internal. + ASMJIT_INLINE_NODEBUG bool isFixed() const noexcept { return bool(_isFixed); } + + //! Tests whether the virtual register is in fact a stack that only uses the virtual register id. + //! + //! \note It's an error if a stack is accessed as a register. + ASMJIT_INLINE_NODEBUG bool isStack() const noexcept { return bool(_isStack); } + + //! Tests whether this virtual register (or stack) has assigned a stack offset. + //! + //! If this is a virtual register that was never allocated on stack, it would return false, otherwise if + //! it's a virtual register that was spilled or explicitly allocated stack, the return value would be true. + ASMJIT_INLINE_NODEBUG bool hasStackSlot() const noexcept { return bool(_hasStackSlot); } + + //! Assigns a stack offset of this virtual register to `stackOffset` and sets `_hasStackSlot` to true. + ASMJIT_INLINE_NODEBUG void assignStackSlot(int32_t stackOffset) noexcept { + _hasStackSlot = 1; + _stackOffset = stackOffset; + } + + //! Returns a stack offset associated with a virtual register or explicit stack allocation. + //! + //! \note Always verify that the stack offset has been assigned by calling \ref hasStackSlot(). The return + //! value will be zero when the stack offset was not assigned. + ASMJIT_INLINE_NODEBUG int32_t stackOffset() const noexcept { return _stackOffset; } + + //! Tests whether the virtual register has an associated `RAWorkReg` at the moment. + ASMJIT_INLINE_NODEBUG bool hasWorkReg() const noexcept { return _workReg != nullptr; } + //! Returns an associated RAWorkReg with this virtual register (only valid during register allocation). + ASMJIT_INLINE_NODEBUG RAWorkReg* workReg() const noexcept { return _workReg; } + //! Associates a RAWorkReg with this virtual register (used by register allocator). + ASMJIT_INLINE_NODEBUG void setWorkReg(RAWorkReg* workReg) noexcept { _workReg = workReg; } + //! Reset the RAWorkReg association (used by register allocator). + ASMJIT_INLINE_NODEBUG void resetWorkReg() noexcept { _workReg = nullptr; } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_COMPILERDEFS_H_INCLUDED + diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/constpool.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/constpool.h new file mode 100644 index 0000000000000000000000000000000000000000..94c80313c6532e76e9bfc2005d355660ede64b65 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/constpool.h @@ -0,0 +1,261 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CONSTPOOL_H_INCLUDED +#define ASMJIT_CORE_CONSTPOOL_H_INCLUDED + +#include "../core/support.h" +#include "../core/zone.h" +#include "../core/zonetree.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! Constant pool scope. +enum class ConstPoolScope : uint32_t { + //! Local constant, always embedded right after the current function. + kLocal = 0, + //! Global constant, embedded at the end of the currently compiled code. + kGlobal = 1, + + //! Maximum value of `ConstPoolScope`. + kMaxValue = kGlobal +}; + +//! Constant pool. +//! +//! Constant pool is designed to hold 1, 2, 4, 8, 16, 32, and 64 byte constants. It's not designed to hold constants +//! having arbitrary length like strings and arrays. +class ConstPool { +public: + ASMJIT_NONCOPYABLE(ConstPool) + + //! \cond INTERNAL + + //! Index of a given size in const-pool table. + enum Index : uint32_t { + kIndex1 = 0, + kIndex2 = 1, + kIndex4 = 2, + kIndex8 = 3, + kIndex16 = 4, + kIndex32 = 5, + kIndex64 = 6, + kIndexCount = 7 + }; + + //! Zone-allocated const-pool gap created by two differently aligned constants. + struct Gap { + //! Pointer to the next gap + Gap* _next; + //! Offset of the gap. + size_t _offset; + //! Remaining bytes of the gap (basically a gap size). + size_t _size; + }; + + //! Zone-allocated const-pool node. + class Node : public ZoneTreeNodeT { + public: + ASMJIT_NONCOPYABLE(Node) + + //! If this constant is shared with another. + uint32_t _shared : 1; + //! Data offset from the beginning of the pool. + uint32_t _offset; + + ASMJIT_INLINE_NODEBUG Node(size_t offset, bool shared) noexcept + : ZoneTreeNodeT(), + _shared(shared), + _offset(uint32_t(offset)) {} + + ASMJIT_INLINE_NODEBUG void* data() const noexcept { + return static_cast(const_cast(this) + 1); + } + }; + + //! Data comparer used internally. + class Compare { + public: + size_t _dataSize; + + ASMJIT_INLINE_NODEBUG Compare(size_t dataSize) noexcept + : _dataSize(dataSize) {} + + ASMJIT_INLINE_NODEBUG int operator()(const Node& a, const Node& b) const noexcept { + return ::memcmp(a.data(), b.data(), _dataSize); + } + + ASMJIT_INLINE_NODEBUG int operator()(const Node& a, const void* data) const noexcept { + return ::memcmp(a.data(), data, _dataSize); + } + }; + + //! Zone-allocated const-pool tree. + struct Tree { + //! RB tree. + ZoneTree _tree; + //! Size of the tree (number of nodes). + size_t _size; + //! Size of the data. + size_t _dataSize; + + ASMJIT_INLINE_NODEBUG explicit Tree(size_t dataSize = 0) noexcept + : _tree(), + _size(0), + _dataSize(dataSize) {} + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _tree.reset(); + _size = 0; + } + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + + inline void setDataSize(size_t dataSize) noexcept { + ASMJIT_ASSERT(empty()); + _dataSize = dataSize; + } + + ASMJIT_INLINE_NODEBUG Node* get(const void* data) noexcept { + Compare cmp(_dataSize); + return _tree.get(data, cmp); + } + + ASMJIT_INLINE_NODEBUG void insert(Node* node) noexcept { + Compare cmp(_dataSize); + _tree.insert(node, cmp); + _size++; + } + + template + inline void forEach(Visitor& visitor) const noexcept { + Node* node = _tree.root(); + if (!node) return; + + Node* stack[Globals::kMaxTreeHeight]; + size_t top = 0; + + for (;;) { + Node* left = node->left(); + if (left != nullptr) { + ASMJIT_ASSERT(top != Globals::kMaxTreeHeight); + stack[top++] = node; + + node = left; + continue; + } + + for (;;) { + visitor(node); + node = node->right(); + + if (node != nullptr) + break; + + if (top == 0) + return; + + node = stack[--top]; + } + } + } + + static inline Node* _newNode(Zone* zone, const void* data, size_t size, size_t offset, bool shared) noexcept { + Node* node = zone->allocT(sizeof(Node) + size); + if (ASMJIT_UNLIKELY(!node)) return nullptr; + + node = new(Support::PlacementNew{node}) Node(offset, shared); + memcpy(node->data(), data, size); + return node; + } + }; + + //! \endcond + + //! \name Members + //! \{ + + //! Zone allocator. + Zone* _zone; + //! Tree per size. + Tree _tree[kIndexCount]; + //! Gaps per size. + Gap* _gaps[kIndexCount]; + //! Gaps pool + Gap* _gapPool; + + //! Size of the pool (in bytes). + size_t _size; + //! Required pool alignment. + size_t _alignment; + //! Minimum item size in the pool. + size_t _minItemSize; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new constant pool that would use `zone` as a memory allocator. + ASMJIT_API explicit ConstPool(Zone* zone) noexcept; + //! Destroys this constant pool. + ASMJIT_API ~ConstPool() noexcept; + + //! \} + + //! \name Reset + //! \{ + + //! Resets this constant pool and its allocator to `zone`. + ASMJIT_API void reset(Zone* zone) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the constant-pool is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + //! Returns the size of the constant-pool in bytes. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + //! Returns minimum alignment. + ASMJIT_INLINE_NODEBUG size_t alignment() const noexcept { return _alignment; } + //! Returns the minimum size of all items added to the constant pool. + ASMJIT_INLINE_NODEBUG size_t minItemSize() const noexcept { return _minItemSize; } + + //! \} + + //! \name Utilities + //! \{ + + //! Adds a constant to the constant pool. + //! + //! The constant must have known size, which is 1, 2, 4, 8, 16 or 32 bytes. The constant is added to the pool only + //! if it doesn't not exist, otherwise cached value is returned. + //! + //! AsmJit is able to subdivide added constants, so for example if you add 8-byte constant 0x1122334455667788 it + //! will create the following slots: + //! + //! 8-byte: 0x1122334455667788 + //! 4-byte: 0x11223344, 0x55667788 + //! + //! The reason is that when combining MMX/SSE/AVX code some patterns are used frequently. However, AsmJit is not + //! able to reallocate a constant that has been already added. For example if you try to add 4-byte constant and + //! then 8-byte constant having the same 4-byte pattern as the previous one, two independent slots will be used. + ASMJIT_API Error add(const void* data, size_t size, size_t& dstOffset) noexcept; + + //! Fills the destination with the content of this constant pool. + ASMJIT_API void fill(void* dst) const noexcept; +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CONSTPOOL_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/cpuinfo.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/cpuinfo.h new file mode 100644 index 0000000000000000000000000000000000000000..d4903cb9e8a28e546a4580f6c0e37c0b9174362f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/cpuinfo.h @@ -0,0 +1,1224 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_CPUINFO_H_INCLUDED +#define ASMJIT_CORE_CPUINFO_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/environment.h" +#include "../core/globals.h" +#include "../core/string.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! CPU features information. +//! +//! Each feature is represented by a single bit in an embedded bit array. +class CpuFeatures { +public: + //! \name Constants + //! \{ + + //! \cond INTERNAL + enum : uint32_t { + kMaxFeatures = 256, + kNumBitWords = kMaxFeatures / Support::kBitWordSizeInBits + }; + //! \endcond + + //! A word that is used to represents feature bits. + typedef Support::BitWord BitWord; + //! Iterator that can iterate all CPU features set. + typedef Support::BitVectorIterator Iterator; + + typedef Support::Array Bits; + + //! \} + + //! \name Data + //! \{ + + //! CPU features data. + struct Data { + //! \name Members + //! \{ + + //! Data bits. + Bits _bits; + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG bool operator==(const Data& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const Data& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns true if there are no features set. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _bits.aggregate(0) == 0; } + + //! Returns all features as array of bitwords (see \ref Support::BitWord). + ASMJIT_INLINE_NODEBUG BitWord* bits() noexcept { return _bits.data(); } + //! Returns all features as array of bitwords (const). + ASMJIT_INLINE_NODEBUG const BitWord* bits() const noexcept { return _bits.data(); } + + //! Returns the number of BitWords returned by \ref bits(). + ASMJIT_INLINE_NODEBUG size_t bitWordCount() const noexcept { return kNumBitWords; } + + //! Returns \ref Support::BitVectorIterator, that can be used to iterate over all features efficiently. + ASMJIT_INLINE_NODEBUG Iterator iterator() const noexcept { return Iterator(_bits.data(), kNumBitWords); } + + //! Tests whether the feature `featureId` is present. + template + ASMJIT_INLINE_NODEBUG bool has(const FeatureId& featureId) const noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + return bool((_bits[idx] >> bit) & 0x1); + } + + //! \cond NONE + template + ASMJIT_INLINE_NODEBUG bool hasAny(const FeatureId& featureId) const noexcept { + return has(featureId); + } + //! \endcond + + //! Tests whether any feature given is present. + //! + //! \note This is a variadic function template that can be used with multiple features. + template + ASMJIT_INLINE_NODEBUG bool hasAny(const FeatureId& featureId, Args&&... otherFeatureIds) const noexcept { + return bool(unsigned(has(featureId)) | unsigned(hasAny(std::forward(otherFeatureIds)...))); + } + + //! Tests whether all features as defined by `other` are present. + ASMJIT_INLINE_NODEBUG bool hasAll(const Data& other) const noexcept { + uint32_t result = 1; + for (uint32_t i = 0; i < kNumBitWords; i++) + result &= uint32_t((_bits[i] & other._bits[i]) == other._bits[i]); + return bool(result); + } + + //! \} + + //! \name Manipulation + //! \{ + + //! Clears all features set. + ASMJIT_INLINE_NODEBUG void reset() noexcept { _bits.fill(0); } + + //! Adds the given CPU `featureId` to the list of features. + template + ASMJIT_INLINE_NODEBUG void add(const FeatureId& featureId) noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + _bits[idx] |= BitWord(1) << bit; + } + + template + ASMJIT_INLINE_NODEBUG void add(const FeatureId& featureId, Args&&... otherFeatureIds) noexcept { + add(featureId); + add(std::forward(otherFeatureIds)...); + } + + template + ASMJIT_INLINE_NODEBUG void addIf(bool condition, const FeatureId& featureId) noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + _bits[idx] |= BitWord(condition) << bit; + } + + template + ASMJIT_INLINE_NODEBUG void addIf(bool condition, const FeatureId& featureId, Args&&... otherFeatureIds) noexcept { + addIf(condition, featureId); + addIf(condition, std::forward(otherFeatureIds)...); + } + + //! Removes the given CPU `featureId` from the list of features. + template + ASMJIT_INLINE_NODEBUG void remove(const FeatureId& featureId) noexcept { + ASMJIT_ASSERT(uint32_t(featureId) < kMaxFeatures); + + uint32_t idx = uint32_t(featureId) / Support::kBitWordSizeInBits; + uint32_t bit = uint32_t(featureId) % Support::kBitWordSizeInBits; + + _bits[idx] &= ~(BitWord(1) << bit); + } + + template + ASMJIT_INLINE_NODEBUG void remove(const FeatureId& featureId, Args&&... otherFeatureIds) noexcept { + remove(featureId); + remove(std::forward(otherFeatureIds)...); + } + + //! Tests whether this CPU features data matches `other`. + ASMJIT_INLINE_NODEBUG bool equals(const Data& other) const noexcept { return _bits == other._bits; } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use CpuFeatures::Data::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const Data& other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} + }; + + //! X86 specific features data. + struct X86 : public Data { + //! X86 CPU feature identifiers. + enum Id : uint8_t { + // @EnumValuesBegin{"enum": "CpuFeatures::X86"}@ + kNone, //!< No feature (never set, used internally). + + kMT, //!< CPU has multi-threading capabilities. + kNX, //!< CPU has Not-Execute-Bit aka DEP (data-execution prevention). + k3DNOW, //!< CPU has 3DNOW (3DNOW base instructions) {AMD} (deprecated). + k3DNOW2, //!< CPU has 3DNOW2 (enhanced 3DNOW) {AMD} (deprecated). + kADX, //!< CPU has ADX (multi-precision add-carry instruction extensions). + kAESNI, //!< CPU has AESNI (AES encode/decode instructions). + kALTMOVCR8, //!< CPU has LOCK MOV R<->CR0 (supports `MOV R<->CR8` via `LOCK MOV R<->CR0` in 32-bit mode) {AMD}. + kAMX_BF16, //!< CPU has AMX_BF16 (AMX-BF16 instructions). + kAMX_COMPLEX, //!< CPU has AMX_COMPLEX (AMX-COMPLEX instructions). + kAMX_FP16, //!< CPU has AMX_FP16 (AMX-FP16 instructions). + kAMX_INT8, //!< CPU has AMX_INT8 (AMX-INT8 instructions). + kAMX_TILE, //!< CPU has AMX_TILE (advanced matrix extensions). + kAPX_F, //!< CPU has APX_F (advanced performance extensions - 32 GP registers, REX2 prefix, ...) {X86_64}. + kAVX, //!< CPU has AVX (advanced vector extensions). + kAVX2, //!< CPU has AVX2 (advanced vector extensions 2). + kAVX512_4FMAPS, //!< CPU has AVX512_FMAPS (FMA packed single). + kAVX512_4VNNIW, //!< CPU has AVX512_VNNIW (vector NN instructions word variable precision). + kAVX512_BF16, //!< CPU has AVX512_BF16 (AVX512 BFLOAT16 support instructions). + kAVX512_BITALG, //!< CPU has AVX512_BITALG (AVX512 VPOPCNT[B|W] and VPSHUFBITQMB instructions). + kAVX512_BW, //!< CPU has AVX512_BW (AVX512 integer BYTE|WORD instructions). + kAVX512_CD, //!< CPU has AVX512_CD (AVX512 conflict detection DWORD|QWORD instructions). + kAVX512_DQ, //!< CPU has AVX512_DQ (AVX512 integer DWORD|QWORD instructions). + kAVX512_ER, //!< CPU has AVX512_ER (AVX512 exponential and reciprocal instructions). + kAVX512_F, //!< CPU has AVX512_F (AVX512 foundation). + kAVX512_FP16, //!< CPU has AVX512_FP16 (AVX512 FP16 instructions). + kAVX512_IFMA, //!< CPU has AVX512_IFMA (AVX512 integer fused-multiply-add using 52-bit precision). + kAVX512_PF, //!< CPU has AVX512_PF (AVX512 prefetch instructions). + kAVX512_VBMI, //!< CPU has AVX512_VBMI (AVX152 vector byte manipulation instructions). + kAVX512_VBMI2, //!< CPU has AVX512_VBMI2 (AVX512 vector byte manipulation instructions v2). + kAVX512_VL, //!< CPU has AVX512_VL (AVX512 vector length extensions). + kAVX512_VNNI, //!< CPU has AVX512_VNNI (AVX512 vector neural network instructions). + kAVX512_VP2INTERSECT, //!< CPU has AVX512_VP2INTERSECT + kAVX512_VPOPCNTDQ, //!< CPU has AVX512_VPOPCNTDQ (AVX512 VPOPCNT[D|Q] instructions). + kAVX_IFMA, //!< CPU has AVX_IFMA (AVX/VEX encoding of vpmadd52huq/vpmadd52luq). + kAVX_NE_CONVERT, //!< CPU has AVX_NE_CONVERT. + kAVX_VNNI, //!< CPU has AVX_VNNI (AVX/VEX encoding of vpdpbusd/vpdpbusds/vpdpwssd/vpdpwssds). + kAVX_VNNI_INT16, //!< CPU has AVX_VNNI_INT16. + kAVX_VNNI_INT8, //!< CPU has AVX_VNNI_INT8. + kBMI, //!< CPU has BMI (bit manipulation instructions #1). + kBMI2, //!< CPU has BMI2 (bit manipulation instructions #2). + kCET_IBT, //!< CPU has CET-IBT (indirect branch tracking). + kCET_SS, //!< CPU has CET-SS. + kCET_SSS, //!< CPU has CET-SSS. + kCLDEMOTE, //!< CPU has CLDEMOTE (cache line demote). + kCLFLUSH, //!< CPU has CLFUSH (cache Line flush). + kCLFLUSHOPT, //!< CPU has CLFUSHOPT (cache Line flush - optimized). + kCLWB, //!< CPU has CLWB. + kCLZERO, //!< CPU has CLZERO. + kCMOV, //!< CPU has CMOV (CMOV and FCMOV instructions). + kCMPCCXADD, //!< CPU has CMPCCXADD. + kCMPXCHG16B, //!< CPU has CMPXCHG16B (compare-exchange 16 bytes) {X86_64}. + kCMPXCHG8B, //!< CPU has CMPXCHG8B (compare-exchange 8 bytes). + kENCLV, //!< CPU has ENCLV. + kENQCMD, //!< CPU has ENQCMD (enqueue stores). + kERMS, //!< CPU has ERMS (enhanced REP MOVSB/STOSB). + kF16C, //!< CPU has F16C (AVX FP16 conversion instructions). + kFMA, //!< CPU has FMA (AVX fused-multiply-add - 3 operand form). + kFMA4, //!< CPU has FMA4 (AVX fused-multiply-add - 4 operand form) (deprecated). + kFPU, //!< CPU has FPU (FPU support). + kFSGSBASE, //!< CPU has FSGSBASE. + kFSRM, //!< CPU has FSRM (fast short REP MOVSB). + kFSRC, //!< CPU has FSRC (fast short REP CMPSB|SCASB). + kFSRS, //!< CPU has FSRS (fast short REP STOSB) + kFXSR, //!< CPU has FXSR (FXSAVE/FXRSTOR instructions). + kFXSROPT, //!< CPU has FXSROTP (FXSAVE/FXRSTOR is optimized). + kFZRM, //!< CPU has FZRM (fast zero-length REP MOVSB). + kGEODE, //!< CPU has GEODE extensions (GEODE 3DNOW additions) (deprecated). + kGFNI, //!< CPU has GFNI (galois field instructions). + kHLE, //!< CPU has HLE. + kHRESET, //!< CPU has HRESET. + kI486, //!< CPU has I486 features (I486+ support). + kINVLPGB, //!< CPU has INVLPGB. + kLAHFSAHF, //!< CPU has LAHF/SAHF (LAHF/SAHF in 64-bit mode) {X86_64}. + kLAM, //!< CPU has LAM (linear address masking) {X86_64}. + kLWP, //!< CPU has LWP (lightweight profiling) {AMD}. + kLZCNT, //!< CPU has LZCNT (LZCNT instruction). + kMCOMMIT, //!< CPU has MCOMMIT (MCOMMIT instruction). + kMMX, //!< CPU has MMX (MMX base instructions) (deprecated). + kMMX2, //!< CPU has MMX2 (MMX2 extensions or initial SSE extensions) (deprecated). + kMONITOR, //!< CPU has MONITOR (MONITOR/MWAIT instructions). + kMONITORX, //!< CPU has MONITORX (MONITORX/MWAITX instructions). + kMOVBE, //!< CPU has MOVBE (move with byte-order swap). + kMOVDIR64B, //!< CPU has MOVDIR64B (move 64 bytes as direct store). + kMOVDIRI, //!< CPU has MOVDIRI (move dword/qword as direct store). + kMPX, //!< CPU has MPX (memory protection extensions). + kMSR, //!< CPU has MSR (RDMSR/WRMSR instructions). + kMSRLIST, //!< CPU has MSRLIST. + kMSSE, //!< CPU has MSSE (misaligned SSE support). + kOSXSAVE, //!< CPU has OSXSAVE (XSAVE enabled by OS). + kOSPKE, //!< CPU has OSPKE (PKE enabled by OS). + kPCLMULQDQ, //!< CPU has PCLMULQDQ (packed carry-less multiplication). + kPCONFIG, //!< CPU has PCONFIG (PCONFIG instruction). + kPOPCNT, //!< CPU has POPCNT (POPCNT instruction). + kPREFETCHI, //!< CPU has PREFETCHI. + kPREFETCHW, //!< CPU has PREFETCHW. + kPREFETCHWT1, //!< CPU has PREFETCHWT1. + kPTWRITE, //!< CPU has PTWRITE. + kRAO_INT, //!< CPU has RAO_INT (AADD, AAND, AOR, AXOR instructions). + kRMPQUERY, //!< CPU has RMPQUERY (RMPQUERY instruction). + kRDPID, //!< CPU has RDPID (RDPID instruction). + kRDPRU, //!< CPU has RDPRU (RDPRU instruction). + kRDRAND, //!< CPU has RDRAND (RDRAND instruction). + kRDSEED, //!< CPU has RDSEED (RDSEED instruction). + kRDTSC, //!< CPU has RDTSC. + kRDTSCP, //!< CPU has RDTSCP. + kRTM, //!< CPU has RTM. + kSEAM, //!< CPU has SEAM. + kSERIALIZE, //!< CPU has SERIALIZE. + kSEV, //!< CPU has SEV (secure encrypted virtualization). + kSEV_ES, //!< CPU has SEV_ES (SEV encrypted state). + kSEV_SNP, //!< CPU has SEV_SNP (SEV secure nested paging). + kSHA, //!< CPU has SHA (SHA-1 and SHA-256 instructions). + kSHA512, //!< CPU has SHA512 (SHA-512 instructions). + kSKINIT, //!< CPU has SKINIT (SKINIT/STGI instructions) {AMD}. + kSM3, //!< CPU has SM3 (SM3 hash extensions). + kSM4, //!< CPU has SM4 (SM4 cipher extensions). + kSMAP, //!< CPU has SMAP (supervisor-mode access prevention). + kSME , //!< CPU has SME (secure memory encryption). + kSMEP, //!< CPU has SMEP (supervisor-mode execution prevention). + kSMX, //!< CPU has SMX (safer mode extensions). + kSSE, //!< CPU has SSE (SSE instructions). + kSSE2, //!< CPU has SSE2 (SSE2 instructions). + kSSE3, //!< CPU has SSE3 (SSE3 instructions). + kSSE4_1, //!< CPU has SSE4.1 (SSE4.1 instructions). + kSSE4_2, //!< CPU has SSE4.2 (SSE4.2 instructions). + kSSE4A, //!< CPU has SSE4A (SSE4.A instructions) {AMD} (deprecated). + kSSSE3, //!< CPU has SSSE3 (SSSE3 instructions). + kSVM, //!< CPU has SVM (virtualization) {AMD}. + kTBM, //!< CPU has TBM (trailing bit manipulation) {AMD}. + kTSE, //!< CPU has TSE. + kTSX, //!< CPU has TSX. + kTSXLDTRK, //!< CPU has TSXLDTRK. + kUINTR, //!< CPU has UINTR (user interrupts). + kVAES, //!< CPU has VAES (vector AES 256|512 bit support). + kVMX, //!< CPU has VMX (virtualization) {INTEL}. + kVPCLMULQDQ, //!< CPU has VPCLMULQDQ (vector PCLMULQDQ 256|512-bit support). + kWAITPKG, //!< CPU has WAITPKG (UMONITOR, UMWAIT, TPAUSE). + kWBNOINVD, //!< CPU has WBNOINVD. + kWRMSRNS, //!< CPU has WRMSRNS. + kXOP, //!< CPU has XOP (XOP instructions) {AMD} (deprecated). + kXSAVE, //!< CPU has XSAVE. + kXSAVEC, //!< CPU has XSAVEC. + kXSAVEOPT, //!< CPU has XSAVEOPT. + kXSAVES, //!< CPU has XSAVES. + // @EnumValuesEnd@ + +#ifndef ASMJIT_NO_DEPRECATED + kAVX512_CDI = kAVX512_CD, + kAVX512_ERI = kAVX512_ER, + kAVX512_PFI = kAVX512_PF, +#endif + + kMaxValue = kXSAVES + }; + + #define ASMJIT_X86_FEATURE(FEATURE) \ + /*! Tests whether FEATURE is present. */ \ + ASMJIT_INLINE_NODEBUG bool has##FEATURE() const noexcept { return has(X86::k##FEATURE); } + + ASMJIT_X86_FEATURE(MT) + ASMJIT_X86_FEATURE(NX) + ASMJIT_X86_FEATURE(3DNOW) + ASMJIT_X86_FEATURE(3DNOW2) + ASMJIT_X86_FEATURE(ADX) + ASMJIT_X86_FEATURE(AESNI) + ASMJIT_X86_FEATURE(ALTMOVCR8) + ASMJIT_X86_FEATURE(AMX_BF16) + ASMJIT_X86_FEATURE(AMX_COMPLEX) + ASMJIT_X86_FEATURE(AMX_FP16) + ASMJIT_X86_FEATURE(AMX_INT8) + ASMJIT_X86_FEATURE(AMX_TILE) + ASMJIT_X86_FEATURE(APX_F) + ASMJIT_X86_FEATURE(AVX) + ASMJIT_X86_FEATURE(AVX2) + ASMJIT_X86_FEATURE(AVX512_4FMAPS) + ASMJIT_X86_FEATURE(AVX512_4VNNIW) + ASMJIT_X86_FEATURE(AVX512_BF16) + ASMJIT_X86_FEATURE(AVX512_BITALG) + ASMJIT_X86_FEATURE(AVX512_BW) + ASMJIT_X86_FEATURE(AVX512_CD) + ASMJIT_X86_FEATURE(AVX512_DQ) + ASMJIT_X86_FEATURE(AVX512_ER) + ASMJIT_X86_FEATURE(AVX512_F) + ASMJIT_X86_FEATURE(AVX512_FP16) + ASMJIT_X86_FEATURE(AVX512_IFMA) + ASMJIT_X86_FEATURE(AVX512_PF) + ASMJIT_X86_FEATURE(AVX512_VBMI) + ASMJIT_X86_FEATURE(AVX512_VBMI2) + ASMJIT_X86_FEATURE(AVX512_VL) + ASMJIT_X86_FEATURE(AVX512_VNNI) + ASMJIT_X86_FEATURE(AVX512_VP2INTERSECT) + ASMJIT_X86_FEATURE(AVX512_VPOPCNTDQ) + ASMJIT_X86_FEATURE(AVX_IFMA) + ASMJIT_X86_FEATURE(AVX_NE_CONVERT) + ASMJIT_X86_FEATURE(AVX_VNNI) + ASMJIT_X86_FEATURE(AVX_VNNI_INT16) + ASMJIT_X86_FEATURE(AVX_VNNI_INT8) + ASMJIT_X86_FEATURE(BMI) + ASMJIT_X86_FEATURE(BMI2) + ASMJIT_X86_FEATURE(CET_IBT) + ASMJIT_X86_FEATURE(CET_SS) + ASMJIT_X86_FEATURE(CET_SSS) + ASMJIT_X86_FEATURE(CLDEMOTE) + ASMJIT_X86_FEATURE(CLFLUSH) + ASMJIT_X86_FEATURE(CLFLUSHOPT) + ASMJIT_X86_FEATURE(CLWB) + ASMJIT_X86_FEATURE(CLZERO) + ASMJIT_X86_FEATURE(CMOV) + ASMJIT_X86_FEATURE(CMPXCHG16B) + ASMJIT_X86_FEATURE(CMPXCHG8B) + ASMJIT_X86_FEATURE(ENCLV) + ASMJIT_X86_FEATURE(ENQCMD) + ASMJIT_X86_FEATURE(ERMS) + ASMJIT_X86_FEATURE(F16C) + ASMJIT_X86_FEATURE(FMA) + ASMJIT_X86_FEATURE(FMA4) + ASMJIT_X86_FEATURE(FPU) + ASMJIT_X86_FEATURE(FSGSBASE) + ASMJIT_X86_FEATURE(FSRM) + ASMJIT_X86_FEATURE(FSRC) + ASMJIT_X86_FEATURE(FSRS) + ASMJIT_X86_FEATURE(FXSR) + ASMJIT_X86_FEATURE(FXSROPT) + ASMJIT_X86_FEATURE(FZRM) + ASMJIT_X86_FEATURE(GEODE) + ASMJIT_X86_FEATURE(GFNI) + ASMJIT_X86_FEATURE(HLE) + ASMJIT_X86_FEATURE(HRESET) + ASMJIT_X86_FEATURE(I486) + ASMJIT_X86_FEATURE(INVLPGB) + ASMJIT_X86_FEATURE(LAHFSAHF) + ASMJIT_X86_FEATURE(LAM) + ASMJIT_X86_FEATURE(LWP) + ASMJIT_X86_FEATURE(LZCNT) + ASMJIT_X86_FEATURE(MCOMMIT) + ASMJIT_X86_FEATURE(MMX) + ASMJIT_X86_FEATURE(MMX2) + ASMJIT_X86_FEATURE(MONITOR) + ASMJIT_X86_FEATURE(MONITORX) + ASMJIT_X86_FEATURE(MOVBE) + ASMJIT_X86_FEATURE(MOVDIR64B) + ASMJIT_X86_FEATURE(MOVDIRI) + ASMJIT_X86_FEATURE(MPX) + ASMJIT_X86_FEATURE(MSR) + ASMJIT_X86_FEATURE(MSRLIST) + ASMJIT_X86_FEATURE(MSSE) + ASMJIT_X86_FEATURE(OSXSAVE) + ASMJIT_X86_FEATURE(OSPKE) + ASMJIT_X86_FEATURE(PCLMULQDQ) + ASMJIT_X86_FEATURE(PCONFIG) + ASMJIT_X86_FEATURE(POPCNT) + ASMJIT_X86_FEATURE(PREFETCHI) + ASMJIT_X86_FEATURE(PREFETCHW) + ASMJIT_X86_FEATURE(PREFETCHWT1) + ASMJIT_X86_FEATURE(PTWRITE) + ASMJIT_X86_FEATURE(RAO_INT) + ASMJIT_X86_FEATURE(RMPQUERY) + ASMJIT_X86_FEATURE(RDPID) + ASMJIT_X86_FEATURE(RDPRU) + ASMJIT_X86_FEATURE(RDRAND) + ASMJIT_X86_FEATURE(RDSEED) + ASMJIT_X86_FEATURE(RDTSC) + ASMJIT_X86_FEATURE(RDTSCP) + ASMJIT_X86_FEATURE(RTM) + ASMJIT_X86_FEATURE(SEAM) + ASMJIT_X86_FEATURE(SERIALIZE) + ASMJIT_X86_FEATURE(SEV) + ASMJIT_X86_FEATURE(SEV_ES) + ASMJIT_X86_FEATURE(SEV_SNP) + ASMJIT_X86_FEATURE(SHA) + ASMJIT_X86_FEATURE(SKINIT) + ASMJIT_X86_FEATURE(SMAP) + ASMJIT_X86_FEATURE(SMEP) + ASMJIT_X86_FEATURE(SMX) + ASMJIT_X86_FEATURE(SSE) + ASMJIT_X86_FEATURE(SSE2) + ASMJIT_X86_FEATURE(SSE3) + ASMJIT_X86_FEATURE(SSE4_1) + ASMJIT_X86_FEATURE(SSE4_2) + ASMJIT_X86_FEATURE(SSE4A) + ASMJIT_X86_FEATURE(SSSE3) + ASMJIT_X86_FEATURE(SVM) + ASMJIT_X86_FEATURE(TBM) + ASMJIT_X86_FEATURE(TSE) + ASMJIT_X86_FEATURE(TSX) + ASMJIT_X86_FEATURE(TSXLDTRK) + ASMJIT_X86_FEATURE(UINTR) + ASMJIT_X86_FEATURE(VAES) + ASMJIT_X86_FEATURE(VMX) + ASMJIT_X86_FEATURE(VPCLMULQDQ) + ASMJIT_X86_FEATURE(WAITPKG) + ASMJIT_X86_FEATURE(WBNOINVD) + ASMJIT_X86_FEATURE(WRMSRNS) + ASMJIT_X86_FEATURE(XOP) + ASMJIT_X86_FEATURE(XSAVE) + ASMJIT_X86_FEATURE(XSAVEC) + ASMJIT_X86_FEATURE(XSAVEOPT) + ASMJIT_X86_FEATURE(XSAVES) + +#ifndef ASMJIT_NO_DEPRECATED + ASMJIT_DEPRECATED("Use hasAVX512_CD() instead") + ASMJIT_X86_FEATURE(AVX512_CDI) + + ASMJIT_DEPRECATED("Use hasAVX512_ER() instead") + ASMJIT_X86_FEATURE(AVX512_ERI) + + ASMJIT_DEPRECATED("Use hasAVX512_PF() instead") + ASMJIT_X86_FEATURE(AVX512_PFI) +#endif + + #undef ASMJIT_X86_FEATURE + }; + + //! ARM specific features data. + //! + //! Naming reference: + //! - https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile + struct ARM : public Data { + //! ARM CPU feature identifiers. + enum Id : uint8_t { + // @EnumValuesBegin{"enum": "CpuFeatures::ARM"}@ + kNone = 0, //!< No feature (never set, used internally). + + kARMv6, //!< CPU is at least ARMv6 {A32}. + kARMv7, //!< CPU is at least ARMv7 {A32}. + kARMv8a, //!< CPU is at least ARMv8A. + kTHUMB, //!< CPU has THUMB (16-bit THUMB encoding) {A32}. + kTHUMBv2, //!< CPU has THUMBv2 (32-bit THUMB encoding) {A32}. + + kABLE, //!< CPU has ABLE (address breakpoint linking extension) {A64}. + kADERR, //!< CPU has ADERR (asynchronous device error exceptions) {A64}. + kAES, //!< CPU has AES (ASIMD AES instructions). + kAFP, //!< CPU has AFP (alternate floating-point behavior) {A64}. + kAIE, //!< CPU has AIE (memory attribute index enhancement) {A64}. + kAMU1, //!< CPU has AMUv1 (activity monitors extension version 1) {A64}. + kAMU1_1, //!< CPU has AMUv1p1 (activity monitors extension version 1.1) {A64}. + kANERR, //!< CPU has ANERR (asynchronous normal error exception) {A64}. + kASIMD, //!< CPU has ASIMD (NEON on ARM/THUMB). + kBF16, //!< CPU has BF16 (BFloat16 instructions) {A64}. + kBRBE, //!< CPU has BRBE (branch record buffer extension) {A64}. + kBTI, //!< CPU has BTI (branch target identification). + kBWE, //!< CPU has BWE (breakpoint mismatch and range extension) {A64}. + kCCIDX, //!< CPU has CCIDX (extend of the CCSIDR number of sets). + kCHK, //!< CPU has CHK (check feature status - CHKFEAT instruction) {A64}. + kCLRBHB, //!< CPU has CLRBHB (clear BHB instruction). + kCMOW, //!< CPU has CMOW (control for cache maintenance permission) {A64}. + kCONSTPACFIELD, //!< CPU has CONSTPACFIELD (PAC algorithm enhancement) {A64}. + kCPA, //!< CPU has CPA (instruction-only Checked Pointer Arithmetic) {A64}. + kCPA2, //!< CPU has CPA2 (checked Pointer Arithmetic) {A64}. + kCPUID, //!< CPU has CPUID (CPUID registers accessible in user-space). + kCRC32, //!< CPU has CRC32 (CRC32 instructions). + kCSSC, //!< CPU has CSSC (common short sequence compression) {A64}. + kCSV2, //!< CPU has CSV2 (cache speculation variant 2 version 2.1) {A64}. + kCSV2_3, //!< CPU has CSV2_3 (cache speculation variant 2 version 3) {A64}. + kCSV3, //!< CPU has CSV3 (cache speculation Variant 3) {A64}. + kD128, //!< CPU has D128 (128-bit translation tables, 56 bit PA) {A64}. + kDGH, //!< CPU has DGH (data gathering hint) {A64}. + kDIT, //!< CPU has DIT (data independent timing of instructions). + kDOTPROD, //!< CPU has DOTPROD (ASIMD Int8 dot product instructions). + kDPB, //!< CPU has DPB (DC CVAP instruction) {A64}. + kDPB2, //!< CPU has DPB2 (DC CVADP instruction) {A64}. + kEBEP, //!< CPU has EBEP (exception-based event profiling) {A64}. + kEBF16, //!< CPU has EBF16 (extended BFloat16 mode) {A64}. + kECBHB, //!< CPU has ECBHB (exploitative control using branch history information) {A64}. + kECV, //!< CPU has ECV (enhanced counter virtualization). + kEDHSR, //!< CPU has EDHSR (support for EDHSR) {A64}. + kEDSP, //!< CPU has EDSP (ARM/THUMB only). + kFAMINMAX, //!< CPU has FAMINMAX (floating-point maximum and minimum absolute value instructions) {A64}. + kFCMA, //!< CPU has FCMA (FCADD/FCMLA). + kFGT, //!< CPU has FGT (fine-grained traps). + kFGT2, //!< CPU has FGT2 (fine-grained traps 2). + kFHM, //!< CPU has FHM (half-precision floating-point FMLAL instructions). + kFLAGM, //!< CPU has FLAGM (condition flag manipulation) {A64}. + kFLAGM2, //!< CPU has FLAGM2 (condition flag manipulation version v2) {A64}. + kFMAC, //!< CPU has FMAC (ARM/THUMB only). + kFP, //!< CPU has FP (floating-point) (on 32-bit ARM this means VFPv3). + kFP16, //!< CPU has FP16 (half-precision floating-point data processing). + kFP16CONV, //!< CPU has FP16CONV (half-precision float conversion). + kFP8, //!< CPU has FP8 (FP8 convert instructions) {A64}. + kFP8DOT2, //!< CPU has FP8DOT2 (FP8 2-way dot product to half-precision instructions) {A64}. + kFP8DOT4, //!< CPU has FP8DOT4 (FP8 4-way dot product to single-precision instructions) {A64}. + kFP8FMA, //!< CPU has FP8FMA (FP8 multiply-accumulate to half-precision and single-precision instructions) {A64}. + kFPMR, //!< CPU has FPMR (floating-point Mode Register) {A64}. + kFRINTTS, //!< CPU has FRINTTS (FRINT[32|64][X|Z] instructions) {A64}. + kGCS, //!< CPU has GCS (guarded control stack extension) {A64}. + kHACDBS, //!< CPU has HACDBS (hardware accelerator for cleaning Dirty state) {A64}. + kHAFDBS, //!< CPU has HAFDBS (hardware management of the access flag and dirty state) {A64}. + kHAFT, //!< CPU has HAFT (hardware managed access flag for table descriptors) {A64}. + kHDBSS, //!< CPU has HDBSS (hardware Dirty state tracking Structure) {A64}. + kHBC, //!< CPU has HBC (hinted conditional branches) {A64}. + kHCX, //!< CPU has HCX (support for the HCRX_EL2 register) {A64}. + kHPDS, //!< CPU has HPDS (hierarchical permission disables in translation tables ) {A64}. + kHPDS2, //!< CPU has HPDS2 (hierarchical permission disables) {A64}. + kI8MM, //!< CPU has I8MM (int8 matrix multiplication) {A64}. + kIDIVA, //!< CPU has IDIV (hardware SDIV and UDIV in ARM mode). + kIDIVT, //!< CPU has IDIV (hardware SDIV and UDIV in THUMB mode). + kITE, //!< CPU has ITE (instrumentation extension) {A64}. + kJSCVT, //!< CPU has JSCVT (JavaScript FJCVTS conversion instruction) {A64}. + kLOR, //!< CPU has LOR (limited ordering regions extension). + kLRCPC, //!< CPU has LRCPC (load-acquire RCpc instructions) {A64}. + kLRCPC2, //!< CPU has LRCPC2 (load-acquire RCpc instructions v2) {A64}. + kLRCPC3, //!< CPU has LRCPC3 (load-Acquire RCpc instructions v3) {A64}. + kLS64, //!< CPU has LS64 (64 byte loads/stores without return) {A64}. + kLS64_ACCDATA, //!< CPU has LS64_ACCDATA (64-byte EL0 stores with return) {A64}. + kLS64_V, //!< CPU has LS64_V (64-byte stores with return) {A64}. + kLSE, //!< CPU has LSE (large system extensions) {A64}. + kLSE128, //!< CPU has LSE128 (128-bit atomics) {A64}. + kLSE2, //!< CPU has LSE2 (large system extensions v2) {A64}. + kLUT, //!< CPU has LUT (lookup table instructions with 2-bit and 4-bit indices) {A64}. + kLVA, //!< CPU has LVA (large VA support) {A64}. + kLVA3, //!< CPU has LVA3 (56-bit VA) {A64}. + kMEC, //!< CPU has MEC (memory encryption contexts) {A64}. + kMOPS, //!< CPU has MOPS (memcpy and memset acceleration instructions) {A64}. + kMPAM, //!< CPU has MPAM (memory system partitioning and monitoring extension) {A64}. + kMTE, //!< CPU has MTE (instruction-only memory tagging extension) {A64}. + kMTE2, //!< CPU has MTE2 (full memory tagging extension) {A64}. + kMTE3, //!< CPU has MTE3 (MTE asymmetric fault handling) {A64}. + kMTE4, //!< CPU has MTE4 (MTE v4) {A64}. + kMTE_ASYM_FAULT, //!< CPU has MTE_ASYM_FAULT (memory tagging asymmetric faults) {A64}. + kMTE_ASYNC, //!< CPU has MTE_ASYNC (memory tagging asynchronous faulting) {A64}. + kMTE_CANONICAL_TAGS, //!< CPU has MTE_CANONICAL_TAGS (canonical tag checking for untagged memory) {A64}. + kMTE_NO_ADDRESS_TAGS, //!< CPU has MTE_NO_ADDRESS_TAGS (memory tagging with address tagging disabled) {A64}. + kMTE_PERM_S1, //!< CPU has MTE_PERM_S1 (allocation tag access permission) {A64}. + kMTE_STORE_ONLY, //!< CPU has MTE_STORE_ONLY (store-only tag checking) {A64}. + kMTE_TAGGED_FAR, //!< CPU has MTE_TAGGED_FAR (FAR_ELx on a tag check fault) {A64}. + kMTPMU, //!< CPU has MTPMU (multi-threaded PMU extensions) {A64}. + kNMI, //!< CPU has NMI (non-maskable Interrupt) {A64}. + kNV, //!< CPU has NV (nested virtualization enchancement) {A64}. + kNV2, //!< CPU has NV2 (enhanced support for nested virtualization) {A64}. + kPAN, //!< CPU has PAN (privileged access-never extension) {A64}. + kPAN2, //!< CPU has PAN2 (PAN s1e1R and s1e1W variants) {A64}. + kPAN3, //!< CPU has PAN3 (support for SCTLR_ELx.EPAN) {A64}. + kPAUTH, //!< CPU has PAUTH (pointer authentication extension) {A64}. + kPFAR, //!< CPU has PFAR (physical fault address registers) {A64}. + kPMU, //!< CPU has PMU {A64}. + kPMULL, //!< CPU has PMULL (ASIMD PMULL instructions) {A64}. + kPRFMSLC, //!< CPU has PRFMSLC (PRFM instructions support the SLC target) {A64}. + kRAS, //!< CPU has RAS (reliability, availability and serviceability extensions). + kRAS1_1, //!< CPU has RASv1p1 (RAS v1.1). + kRAS2, //!< CPU has RASv2 (RAS v2). + kRASSA2, //!< CPU has RASSAv2 (RAS v2 system architecture). + kRDM, //!< CPU has RDM (rounding double multiply accumulate) {A64}. + kRME, //!< CPU has RME (memory encryption contexts extension) {A64}. + kRNG, //!< CPU has RNG (random number generation). + kRNG_TRAP, //!< CPU has RNG_TRAP (random number trap to EL3 field) {A64}. + kRPRES, //!< CPU has RPRES (increased precision of reciprocal estimate and RSQRT estimate) {A64}. + kRPRFM, //!< CPU has RPRFM (range prefetch hint instruction). + kS1PIE, //!< CPU has S1PIE (permission model enhancements) {A64}. + kS1POE, //!< CPU has S1POE (permission model enhancements) {A64}. + kS2PIE, //!< CPU has S2PIE (permission model enhancements) {A64}. + kS2POE, //!< CPU has S2POE (permission model enhancements) {A64}. + kSB, //!< CPU has SB (speculative barrier). + kSCTLR2, //!< CPU has SCTLR2 (extension to SCTLR_ELx) {A64}. + kSEBEP, //!< CPU has SEBEP (synchronous exception-based event profiling) {A64}. + kSEL2, //!< CPU has SEL2 (secure EL2) {A64}. + kSHA1, //!< CPU has SHA1 (ASIMD SHA1 instructions). + kSHA256, //!< CPU has SHA256 (ASIMD SHA256 instructions). + kSHA3, //!< CPU has SHA3 (ASIMD EOR3, RAX1, XAR, and BCAX instructions). + kSHA512, //!< CPU has SHA512 (ASIMD SHA512 instructions). + kSM3, //!< CPU has SM3 (ASIMD SM3 instructions). + kSM4, //!< CPU has SM4 (ASIMD SM4 instructions). + kSME, //!< CPU has SME (SME v1 - scalable matrix extension) {A64}. + kSME2, //!< CPU has SME2 (SME v2) {A64}. + kSME2_1, //!< CPU has SME2p1 (SME v2.1) {A64}. + kSME_B16B16, //!< CPU has SME_B16B16 (SME non-widening BFloat16 to BFloat16 arithmetic) {A64}. + kSME_B16F32, //!< CPU has SME_B16F32 (BFMOPA and BFMOPS instructions that accumulate BFloat16 outer products into single-precision tiles) {A64}. + kSME_BI32I32, //!< CPU has SME_BI32I32 (BMOPA and BMOPS instructions that accumulate 1-bit binary outer products into 32-bit integer tiles) {A64}. + kSME_F16F16, //!< CPU has SME_F16F16 (SME2.1 non-widening half-precision FP16 to FP16 arithmetic) {A64}. + kSME_F16F32, //!< CPU has SME_F16F32 {A64}. + kSME_F32F32, //!< CPU has SME_F32F32 {A64}. + kSME_F64F64, //!< CPU has SME_F64F64 {A64}. + kSME_F8F16, //!< CPU has SME_F8F16 (SME2 ZA-targeting FP8 multiply-accumulate, dot product, and outer product to half-precision instructions) {A64}. + kSME_F8F32, //!< CPU has SME_F8F32 (SME2 ZA-targeting FP8 multiply-accumulate, dot product, and outer product to single-precision instructions) {A64}. + kSME_FA64, //!< CPU has SME_FA64 {A64}. + kSME_I16I32, //!< CPU has SME_I16I32 {A64}. + kSME_I16I64, //!< CPU has SME_I16I64 {A64}. + kSME_I8I32, //!< CPU has SME_I8I32 {A64}. + kSME_LUTv2, //!< CPU has SME_LUTv2 (lookup table instructions with 4-bit indices and 8-bit elements) {A64}. + kSPE, //!< CPU has SPE (statistical profiling extension) {A64}. + kSPE1_1, //!< CPU has SPEv1p1 (statistical profiling extensions version 1.1) {A64}. + kSPE1_2, //!< CPU has SPEv1p2 (statistical profiling extensions version 1.2) {A64}. + kSPE1_3, //!< CPU has SPEv1p3 (statistical profiling extensions version 1.3) {A64}. + kSPE1_4, //!< CPU has SPEv1p4 (statistical profiling extensions version 1.4) {A64}. + kSPE_ALTCLK, //!< CPU has SPE_ALTCLK (statistical profiling alternate clock domain extension) {A64}. + kSPE_CRR, //!< CPU has SPE_CRR (statistical profiling call return branch records) {A64}. + kSPE_EFT, //!< CPU has SPE_EFT (statistical profiling extended filtering by type) {A64}. + kSPE_FDS, //!< CPU has SPE_FDS (statistical profiling data source filtering) {A64}. + kSPE_FPF, //!< CPU has SPE_FPF (statistical profiling floating-point flag extension) {A64}. + kSPE_SME, //!< CPU has SPE_SME (statistical profiling extensions for SME) {A64}. + kSPECRES, //!< CPU has SPECRES (speculation restriction instructions). + kSPECRES2, //!< CPU has SPECRES2 (clear other speculative predictions). + kSPMU, //!< CPU has SPMU (system performance monitors extension) {A64}. + kSSBS, //!< CPU has SSBS (speculative store bypass safe instruction). + kSSBS2, //!< CPU has SSBS2 (MRS and MSR instructions for SSBS). + kSSVE_FP8DOT2, //!< CPU has SSVE_FP8DOT2 (SVE2 FP8 2-way dot product to half-precision instructions in Streaming SVE mode) {A64}. + kSSVE_FP8DOT4, //!< CPU has SSVE_FP8DOT4 (SVE2 FP8 4-way dot product to single-precision instructions in Streaming SVE mode) {A64}. + kSSVE_FP8FMA, //!< CPU has SSVE_FP8FMA (SVE2 FP8 multiply-accumulate to half-precision and single-precision instructions in Streaming SVE mode) {A64}. + kSVE, //!< CPU has SVE (SVE v1 - scalable vector extension) {A64}. + kSVE2, //!< CPU has SVE2 (SVE v2) {A64}. + kSVE2_1, //!< CPU has SVE2p1 (SVE v2.1) {A64}. + kSVE_AES, //!< CPU has SVE_AES (SVE AES instructions) {A64}. + kSVE_B16B16, //!< CPU has SVE_B16B16 (SVE non-widening BFloat16 to BFloat16 arithmetic) {A64}. + kSVE_BF16, //!< CPU has SVE_BF16 (SVE BF16 instructions) {A64}. + kSVE_BITPERM, //!< CPU has SVE_BITPERM (SVE bit permute) {A64}. + kSVE_EBF16, //!< CPU has SVE_EBF16 (SVE extended BFloat16 mode) {A64}. + kSVE_F32MM, //!< CPU has SVE_F32MM (SVE single-precision floating-point matrix multiply instruction) {A64}. + kSVE_F64MM, //!< CPU has SVE_F64MM (SVE double-precision floating-point matrix multiply instruction) {A64}. + kSVE_I8MM, //!< CPU has SVE_I8MM (SVE int8 matrix multiplication) {A64}. + kSVE_PMULL128, //!< CPU has SVE_PMULL128 (SVE PMULL instructions) {A64}. + kSVE_SHA3, //!< CPU has SVE_SHA3 (SVE SHA-3 instructions) {A64}. + kSVE_SM4, //!< CPU has SVE_SM4 (SVE SM4 instructions {A64}. + kSYSINSTR128, //!< CPU has SYSINSTR128 (128-bit system instructions) {A64}. + kSYSREG128, //!< CPU has SYSREG128 (128-bit system registers) {A64}. + kTHE, //!< CPU has THE (translation hardening extension). + kTLBIOS, //!< CPU has TLBIOS (TLBI instructions in Outer Shareable domain) {A64}. + kTLBIRANGE, //!< CPU has TLBIRANGE (TLBI range instructions) {A64}. + kTLBIW, //!< CPU has TLBIW (TLBI VMALL for dirty state) {A64}. + kTME, //!< CPU has TME (transactional memory extensions). + kTRF, //!< CPU has TRF (self-hosted trace extensions). + kUAO, //!< CPU has UAO (AArch64 v8.2 UAO PState) {A64}. + kVFP_D32, //!< CPU has VFP_D32 (32 VFP-D registers) (ARM/THUMB only). + kVHE, //!< CPU has VHE (virtual host extension). + kVMID16, //!< CPU has VMID16 (16-bit VMID) {A64}. + kWFXT, //!< CPU has WFxT (WFE and WFI instructions with timeout) {A64}. + kXNX, //!< CPU has XNX (translation table stage 2 unprivileged execute-never) {A64}. + kXS, //!< CPU has XS (XS attribute in TLBI and DSB instructions) {A64}. + // @EnumValuesEnd@ + + kMaxValue = kXS + }; + + #define ASMJIT_ARM_FEATURE(FEATURE) \ + /*! Tests whether FEATURE is present. */ \ + ASMJIT_INLINE_NODEBUG bool has##FEATURE() const noexcept { return has(ARM::k##FEATURE); } + + ASMJIT_ARM_FEATURE(THUMB) + ASMJIT_ARM_FEATURE(THUMBv2) + + ASMJIT_ARM_FEATURE(ARMv6) + ASMJIT_ARM_FEATURE(ARMv7) + ASMJIT_ARM_FEATURE(ARMv8a) + + ASMJIT_ARM_FEATURE(ABLE) + ASMJIT_ARM_FEATURE(ADERR) + ASMJIT_ARM_FEATURE(AES) + ASMJIT_ARM_FEATURE(AFP) + ASMJIT_ARM_FEATURE(AIE) + ASMJIT_ARM_FEATURE(AMU1) + ASMJIT_ARM_FEATURE(AMU1_1) + ASMJIT_ARM_FEATURE(ANERR) + ASMJIT_ARM_FEATURE(ASIMD) + ASMJIT_ARM_FEATURE(BF16) + ASMJIT_ARM_FEATURE(BRBE) + ASMJIT_ARM_FEATURE(BTI) + ASMJIT_ARM_FEATURE(BWE) + ASMJIT_ARM_FEATURE(CCIDX) + ASMJIT_ARM_FEATURE(CHK) + ASMJIT_ARM_FEATURE(CLRBHB) + ASMJIT_ARM_FEATURE(CMOW) + ASMJIT_ARM_FEATURE(CONSTPACFIELD) + ASMJIT_ARM_FEATURE(CPA) + ASMJIT_ARM_FEATURE(CPA2) + ASMJIT_ARM_FEATURE(CPUID) + ASMJIT_ARM_FEATURE(CRC32) + ASMJIT_ARM_FEATURE(CSSC) + ASMJIT_ARM_FEATURE(CSV2) + ASMJIT_ARM_FEATURE(CSV2_3) + ASMJIT_ARM_FEATURE(CSV3) + ASMJIT_ARM_FEATURE(D128) + ASMJIT_ARM_FEATURE(DGH) + ASMJIT_ARM_FEATURE(DIT) + ASMJIT_ARM_FEATURE(DOTPROD) + ASMJIT_ARM_FEATURE(DPB) + ASMJIT_ARM_FEATURE(DPB2) + ASMJIT_ARM_FEATURE(EBEP) + ASMJIT_ARM_FEATURE(EBF16) + ASMJIT_ARM_FEATURE(ECBHB) + ASMJIT_ARM_FEATURE(ECV) + ASMJIT_ARM_FEATURE(EDHSR) + ASMJIT_ARM_FEATURE(EDSP) + ASMJIT_ARM_FEATURE(FAMINMAX) + ASMJIT_ARM_FEATURE(FCMA) + ASMJIT_ARM_FEATURE(FGT) + ASMJIT_ARM_FEATURE(FGT2) + ASMJIT_ARM_FEATURE(FHM) + ASMJIT_ARM_FEATURE(FLAGM) + ASMJIT_ARM_FEATURE(FLAGM2) + ASMJIT_ARM_FEATURE(FMAC) + ASMJIT_ARM_FEATURE(FP) + ASMJIT_ARM_FEATURE(FP16) + ASMJIT_ARM_FEATURE(FP16CONV) + ASMJIT_ARM_FEATURE(FP8) + ASMJIT_ARM_FEATURE(FP8DOT2) + ASMJIT_ARM_FEATURE(FP8DOT4) + ASMJIT_ARM_FEATURE(FP8FMA) + ASMJIT_ARM_FEATURE(FPMR) + ASMJIT_ARM_FEATURE(FRINTTS) + ASMJIT_ARM_FEATURE(GCS) + ASMJIT_ARM_FEATURE(HACDBS) + ASMJIT_ARM_FEATURE(HAFDBS) + ASMJIT_ARM_FEATURE(HAFT) + ASMJIT_ARM_FEATURE(HDBSS) + ASMJIT_ARM_FEATURE(HBC) + ASMJIT_ARM_FEATURE(HCX) + ASMJIT_ARM_FEATURE(HPDS) + ASMJIT_ARM_FEATURE(HPDS2) + ASMJIT_ARM_FEATURE(I8MM) + ASMJIT_ARM_FEATURE(IDIVA) + ASMJIT_ARM_FEATURE(IDIVT) + ASMJIT_ARM_FEATURE(ITE) + ASMJIT_ARM_FEATURE(JSCVT) + ASMJIT_ARM_FEATURE(LOR) + ASMJIT_ARM_FEATURE(LRCPC) + ASMJIT_ARM_FEATURE(LRCPC2) + ASMJIT_ARM_FEATURE(LRCPC3) + ASMJIT_ARM_FEATURE(LS64) + ASMJIT_ARM_FEATURE(LS64_ACCDATA) + ASMJIT_ARM_FEATURE(LS64_V) + ASMJIT_ARM_FEATURE(LSE) + ASMJIT_ARM_FEATURE(LSE128) + ASMJIT_ARM_FEATURE(LSE2) + ASMJIT_ARM_FEATURE(LUT) + ASMJIT_ARM_FEATURE(LVA) + ASMJIT_ARM_FEATURE(LVA3) + ASMJIT_ARM_FEATURE(MEC) + ASMJIT_ARM_FEATURE(MOPS) + ASMJIT_ARM_FEATURE(MPAM) + ASMJIT_ARM_FEATURE(MTE) + ASMJIT_ARM_FEATURE(MTE2) + ASMJIT_ARM_FEATURE(MTE3) + ASMJIT_ARM_FEATURE(MTE4) + ASMJIT_ARM_FEATURE(MTE_ASYM_FAULT) + ASMJIT_ARM_FEATURE(MTE_ASYNC) + ASMJIT_ARM_FEATURE(MTE_CANONICAL_TAGS) + ASMJIT_ARM_FEATURE(MTE_NO_ADDRESS_TAGS) + ASMJIT_ARM_FEATURE(MTE_PERM_S1) + ASMJIT_ARM_FEATURE(MTE_STORE_ONLY) + ASMJIT_ARM_FEATURE(MTE_TAGGED_FAR) + ASMJIT_ARM_FEATURE(MTPMU) + ASMJIT_ARM_FEATURE(NMI) + ASMJIT_ARM_FEATURE(NV) + ASMJIT_ARM_FEATURE(NV2) + ASMJIT_ARM_FEATURE(PAN) + ASMJIT_ARM_FEATURE(PAN2) + ASMJIT_ARM_FEATURE(PAN3) + ASMJIT_ARM_FEATURE(PAUTH) + ASMJIT_ARM_FEATURE(PFAR) + ASMJIT_ARM_FEATURE(PMU) + ASMJIT_ARM_FEATURE(PMULL) + ASMJIT_ARM_FEATURE(PRFMSLC) + ASMJIT_ARM_FEATURE(RAS) + ASMJIT_ARM_FEATURE(RAS1_1) + ASMJIT_ARM_FEATURE(RAS2) + ASMJIT_ARM_FEATURE(RASSA2) + ASMJIT_ARM_FEATURE(RDM) + ASMJIT_ARM_FEATURE(RME) + ASMJIT_ARM_FEATURE(RNG) + ASMJIT_ARM_FEATURE(RNG_TRAP) + ASMJIT_ARM_FEATURE(RPRES) + ASMJIT_ARM_FEATURE(RPRFM) + ASMJIT_ARM_FEATURE(S1PIE) + ASMJIT_ARM_FEATURE(S1POE) + ASMJIT_ARM_FEATURE(S2PIE) + ASMJIT_ARM_FEATURE(S2POE) + ASMJIT_ARM_FEATURE(SB) + ASMJIT_ARM_FEATURE(SCTLR2) + ASMJIT_ARM_FEATURE(SEBEP) + ASMJIT_ARM_FEATURE(SEL2) + ASMJIT_ARM_FEATURE(SHA1) + ASMJIT_ARM_FEATURE(SHA256) + ASMJIT_ARM_FEATURE(SHA3) + ASMJIT_ARM_FEATURE(SHA512) + ASMJIT_ARM_FEATURE(SM3) + ASMJIT_ARM_FEATURE(SM4) + ASMJIT_ARM_FEATURE(SME) + ASMJIT_ARM_FEATURE(SME2) + ASMJIT_ARM_FEATURE(SME2_1) + ASMJIT_ARM_FEATURE(SME_B16B16) + ASMJIT_ARM_FEATURE(SME_B16F32) + ASMJIT_ARM_FEATURE(SME_BI32I32) + ASMJIT_ARM_FEATURE(SME_F16F16) + ASMJIT_ARM_FEATURE(SME_F16F32) + ASMJIT_ARM_FEATURE(SME_F32F32) + ASMJIT_ARM_FEATURE(SME_F64F64) + ASMJIT_ARM_FEATURE(SME_F8F16) + ASMJIT_ARM_FEATURE(SME_F8F32) + ASMJIT_ARM_FEATURE(SME_FA64) + ASMJIT_ARM_FEATURE(SME_I16I32) + ASMJIT_ARM_FEATURE(SME_I16I64) + ASMJIT_ARM_FEATURE(SME_I8I32) + ASMJIT_ARM_FEATURE(SME_LUTv2) + ASMJIT_ARM_FEATURE(SPE) + ASMJIT_ARM_FEATURE(SPE1_1) + ASMJIT_ARM_FEATURE(SPE1_2) + ASMJIT_ARM_FEATURE(SPE1_3) + ASMJIT_ARM_FEATURE(SPE1_4) + ASMJIT_ARM_FEATURE(SPE_ALTCLK) + ASMJIT_ARM_FEATURE(SPE_CRR) + ASMJIT_ARM_FEATURE(SPE_EFT) + ASMJIT_ARM_FEATURE(SPE_FDS) + ASMJIT_ARM_FEATURE(SPE_FPF) + ASMJIT_ARM_FEATURE(SPE_SME) + ASMJIT_ARM_FEATURE(SPECRES) + ASMJIT_ARM_FEATURE(SPECRES2) + ASMJIT_ARM_FEATURE(SPMU) + ASMJIT_ARM_FEATURE(SSBS) + ASMJIT_ARM_FEATURE(SSBS2) + ASMJIT_ARM_FEATURE(SSVE_FP8DOT2) + ASMJIT_ARM_FEATURE(SSVE_FP8DOT4) + ASMJIT_ARM_FEATURE(SSVE_FP8FMA) + ASMJIT_ARM_FEATURE(SVE) + ASMJIT_ARM_FEATURE(SVE2) + ASMJIT_ARM_FEATURE(SVE2_1) + ASMJIT_ARM_FEATURE(SVE_AES) + ASMJIT_ARM_FEATURE(SVE_B16B16) + ASMJIT_ARM_FEATURE(SVE_BF16) + ASMJIT_ARM_FEATURE(SVE_BITPERM) + ASMJIT_ARM_FEATURE(SVE_EBF16) + ASMJIT_ARM_FEATURE(SVE_F32MM) + ASMJIT_ARM_FEATURE(SVE_F64MM) + ASMJIT_ARM_FEATURE(SVE_I8MM) + ASMJIT_ARM_FEATURE(SVE_PMULL128) + ASMJIT_ARM_FEATURE(SVE_SHA3) + ASMJIT_ARM_FEATURE(SVE_SM4) + ASMJIT_ARM_FEATURE(SYSINSTR128) + ASMJIT_ARM_FEATURE(SYSREG128) + ASMJIT_ARM_FEATURE(THE) + ASMJIT_ARM_FEATURE(TLBIOS) + ASMJIT_ARM_FEATURE(TLBIRANGE) + ASMJIT_ARM_FEATURE(TLBIW) + ASMJIT_ARM_FEATURE(TME) + ASMJIT_ARM_FEATURE(TRF) + ASMJIT_ARM_FEATURE(UAO) + ASMJIT_ARM_FEATURE(VFP_D32) + ASMJIT_ARM_FEATURE(VHE) + ASMJIT_ARM_FEATURE(VMID16) + ASMJIT_ARM_FEATURE(WFXT) + ASMJIT_ARM_FEATURE(XNX) + ASMJIT_ARM_FEATURE(XS) + + #undef ASMJIT_ARM_FEATURE + }; + + static_assert(uint32_t(X86::kMaxValue) < kMaxFeatures, "The number of X86 CPU features cannot exceed CpuFeatures::kMaxFeatures"); + static_assert(uint32_t(ARM::kMaxValue) < kMaxFeatures, "The number of ARM CPU features cannot exceed CpuFeatures::kMaxFeatures"); + + //! \} + + //! \name Members + //! \{ + + Data _data {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG CpuFeatures() noexcept {} + ASMJIT_INLINE_NODEBUG CpuFeatures(const CpuFeatures& other) noexcept = default; + ASMJIT_INLINE_NODEBUG explicit CpuFeatures(const Data& other) noexcept : _data{other._bits} {} + ASMJIT_INLINE_NODEBUG explicit CpuFeatures(Globals::NoInit_) noexcept {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG CpuFeatures& operator=(const CpuFeatures& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG bool operator==(const CpuFeatures& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const CpuFeatures& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns true if there are no features set. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _data.empty(); } + + //! Casts this base class into a derived type `T`. + template + ASMJIT_INLINE_NODEBUG T& data() noexcept { return static_cast(_data); } + + //! Casts this base class into a derived type `T` (const). + template + ASMJIT_INLINE_NODEBUG const T& data() const noexcept { return static_cast(_data); } + + //! Returns CpuFeatures::Data as \ref CpuFeatures::X86. + ASMJIT_INLINE_NODEBUG X86& x86() noexcept { return data(); } + //! Returns CpuFeatures::Data as \ref CpuFeatures::X86 (const). + ASMJIT_INLINE_NODEBUG const X86& x86() const noexcept { return data(); } + + //! Returns CpuFeatures::Data as \ref CpuFeatures::ARM. + ASMJIT_INLINE_NODEBUG ARM& arm() noexcept { return data(); } + //! Returns CpuFeatures::Data as \ref CpuFeatures::ARM (const). + ASMJIT_INLINE_NODEBUG const ARM& arm() const noexcept { return data(); } + + //! Returns all features as array of bitwords (see \ref Support::BitWord). + ASMJIT_INLINE_NODEBUG BitWord* bits() noexcept { return _data.bits(); } + //! Returns all features as array of bitwords (const). + ASMJIT_INLINE_NODEBUG const BitWord* bits() const noexcept { return _data.bits(); } + //! Returns the number of BitWords returned by \ref bits(). + ASMJIT_INLINE_NODEBUG size_t bitWordCount() const noexcept { return _data.bitWordCount(); } + + //! Returns \ref Support::BitVectorIterator, that can be used to iterate over all features efficiently. + ASMJIT_INLINE_NODEBUG Iterator iterator() const noexcept { return _data.iterator(); } + + //! Tests whether the feature `featureId` is present. + template + ASMJIT_INLINE_NODEBUG bool has(const FeatureId& featureId) const noexcept { return _data.has(featureId); } + + //! Tests whether any of the features is present. + template + ASMJIT_INLINE_NODEBUG bool hasAny(Args&&... args) const noexcept { return _data.hasAny(std::forward(args)...); } + + //! Tests whether all features as defined by `other` are present. + ASMJIT_INLINE_NODEBUG bool hasAll(const CpuFeatures& other) const noexcept { return _data.hasAll(other._data); } + + //! \} + + //! \name Manipulation + //! \{ + + //! Clears all features set. + ASMJIT_INLINE_NODEBUG void reset() noexcept { _data.reset(); } + + //! Adds the given CPU `featureId` to the list of features. + template + ASMJIT_INLINE_NODEBUG void add(Args&&... args) noexcept { return _data.add(std::forward(args)...); } + + //! Adds the given CPU `featureId` to the list of features if `condition` is true. + template + ASMJIT_INLINE_NODEBUG void addIf(bool condition, Args&&... args) noexcept { return _data.addIf(condition, std::forward(args)...); } + + //! Removes the given CPU `featureId` from the list of features. + template + ASMJIT_INLINE_NODEBUG void remove(Args&&... args) noexcept { return _data.remove(std::forward(args)...); } + + //! Tests whether this CPU features matches `other`. + ASMJIT_INLINE_NODEBUG bool equals(const CpuFeatures& other) const noexcept { return _data.equals(other._data); } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use CpuFeatures::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const CpuFeatures& other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} +}; + +//! CPU information. +class CpuInfo { +public: + //! \name Members + //! \{ + + //! Architecture. + Arch _arch {}; + //! Sub-architecture. + SubArch _subArch {}; + //! True if the CPU was detected, false if the detection failed or it's not available. + bool _wasDetected {}; + //! Reserved for future use. + uint8_t _reserved {}; + //! CPU family ID. + uint32_t _familyId {}; + //! CPU model ID. + uint32_t _modelId {}; + //! CPU brand ID. + uint32_t _brandId {}; + //! CPU stepping. + uint32_t _stepping {}; + //! Processor type. + uint32_t _processorType {}; + //! Maximum number of addressable IDs for logical processors. + uint32_t _maxLogicalProcessors {}; + //! Cache line size (in bytes). + uint32_t _cacheLineSize {}; + //! Number of hardware threads. + uint32_t _hwThreadCount {}; + + //! CPU vendor string. + FixedString<16> _vendor {}; + //! CPU brand string. + FixedString<64> _brand {}; + //! CPU features. + CpuFeatures _features {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new CpuInfo instance. + ASMJIT_INLINE_NODEBUG CpuInfo() noexcept {} + //! Creates a copy of `other` instance. + ASMJIT_INLINE_NODEBUG CpuInfo(const CpuInfo& other) noexcept = default; + + //! Creates an unitialized `CpuInfo` instance. + ASMJIT_INLINE_NODEBUG explicit CpuInfo(Globals::NoInit_) noexcept + : _features(Globals::NoInit) {}; + + //! \} + + //! \name CPU Information Detection + //! \{ + + //! Returns the host CPU information. + //! + //! \note The returned reference is global - it's setup only once and then shared. + ASMJIT_API static const CpuInfo& host() noexcept; + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Copy assignment. + ASMJIT_INLINE_NODEBUG CpuInfo& operator=(const CpuInfo& other) noexcept = default; + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! Initializes CpuInfo architecture and sub-architecture members to `arch` and `subArch`, respectively. + ASMJIT_INLINE_NODEBUG void initArch(Arch arch, SubArch subArch = SubArch::kUnknown) noexcept { + _arch = arch; + _subArch = subArch; + } + + //! Resets this \ref CpuInfo to a default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = CpuInfo{}; } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the CPU architecture this information relates to. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + + //! Returns the CPU sub-architecture this information relates to. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return _subArch; } + + //! Returns whether the CPU was detected successfully. + //! + //! If the returned value is false it means that AsmJit either failed to detect the CPU or it doesn't have + //! implementation targeting the host architecture and operating system. + ASMJIT_INLINE_NODEBUG bool wasDetected() const noexcept { return _wasDetected; } + + //! Returns the CPU family ID. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Family identifier matches the FamilyId read by using CPUID. + //! - ARM: + //! - Apple - returns Apple Family identifier returned by sysctlbyname("hw.cpufamily"). + ASMJIT_INLINE_NODEBUG uint32_t familyId() const noexcept { return _familyId; } + + //! Returns the CPU model ID. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Model identifier matches the ModelId read by using CPUID. + ASMJIT_INLINE_NODEBUG uint32_t modelId() const noexcept { return _modelId; } + + //! Returns the CPU brand id. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Brand identifier matches the BrandId read by using CPUID. + ASMJIT_INLINE_NODEBUG uint32_t brandId() const noexcept { return _brandId; } + + //! Returns the CPU stepping. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Stepping identifier matches the Stepping information read by using CPUID. + ASMJIT_INLINE_NODEBUG uint32_t stepping() const noexcept { return _stepping; } + + //! Returns the processor type. + //! + //! The information provided depends on architecture and OS: + //! - X86: + //! - Processor type identifier matches the ProcessorType read by using CPUID. + ASMJIT_INLINE_NODEBUG uint32_t processorType() const noexcept { return _processorType; } + + //! Returns the maximum number of logical processors. + ASMJIT_INLINE_NODEBUG uint32_t maxLogicalProcessors() const noexcept { return _maxLogicalProcessors; } + + //! Returns the size of a CPU cache line. + //! + //! On a multi-architecture system this should return the smallest cache line of all CPUs. + ASMJIT_INLINE_NODEBUG uint32_t cacheLineSize() const noexcept { return _cacheLineSize; } + + //! Returns number of hardware threads available. + ASMJIT_INLINE_NODEBUG uint32_t hwThreadCount() const noexcept { return _hwThreadCount; } + + //! Returns a CPU vendor string. + ASMJIT_INLINE_NODEBUG const char* vendor() const noexcept { return _vendor.str; } + //! Tests whether the CPU vendor string is equal to `s`. + ASMJIT_INLINE_NODEBUG bool isVendor(const char* s) const noexcept { return _vendor.equals(s); } + + //! Returns a CPU brand string. + ASMJIT_INLINE_NODEBUG const char* brand() const noexcept { return _brand.str; } + + //! Returns CPU features. + ASMJIT_INLINE_NODEBUG CpuFeatures& features() noexcept { return _features; } + //! Returns CPU features (const). + ASMJIT_INLINE_NODEBUG const CpuFeatures& features() const noexcept { return _features; } + + //! Tests whether the CPU has the given `feature`. + template + ASMJIT_INLINE_NODEBUG bool hasFeature(const FeatureId& featureId) const noexcept { return _features.has(featureId); } + + //! Adds the given CPU `featureId` to the list of features. + template + ASMJIT_INLINE_NODEBUG void addFeature(Args&&... args) noexcept { return _features.add(std::forward(args)...); } + + //! Removes the given CPU `featureId` from the list of features. + template + ASMJIT_INLINE_NODEBUG void removeFeature(Args&&... args) noexcept { return _features.remove(std::forward(args)...); } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_CPUINFO_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/emitter.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..272cb26c266b65a520dc8cf7a45057ef8f7e3ba9 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/emitter.h @@ -0,0 +1,817 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_EMITTER_H_INCLUDED +#define ASMJIT_CORE_EMITTER_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/codeholder.h" +#include "../core/formatter.h" +#include "../core/inst.h" +#include "../core/operand.h" +#include "../core/type.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +class ConstPool; +class FuncFrame; +class FuncArgsAssignment; + +//! Align mode, used by \ref BaseEmitter::align(). +enum class AlignMode : uint8_t { + //! Align executable code. + kCode = 0, + //! Align non-executable code. + kData = 1, + //! Align by a sequence of zeros. + kZero = 2, + + //! Maximum value of `AlignMode`. + kMaxValue = kZero +}; + +//! Emitter type used by \ref BaseEmitter. +enum class EmitterType : uint8_t { + //! Unknown or uninitialized. + kNone = 0, + //! Emitter inherits from \ref BaseAssembler. + kAssembler = 1, + //! Emitter inherits from \ref BaseBuilder. + kBuilder = 2, + //! Emitter inherits from \ref BaseCompiler. + kCompiler = 3, + + //! Maximum value of `EmitterType`. + kMaxValue = kCompiler +}; + +//! Emitter flags, used by \ref BaseEmitter. +enum class EmitterFlags : uint8_t { + //! No flags. + kNone = 0u, + //! Emitter is attached to CodeHolder. + kAttached = 0x01u, + //! The emitter must emit comments. + kLogComments = 0x08u, + //! The emitter has its own \ref Logger (not propagated from \ref CodeHolder). + kOwnLogger = 0x10u, + //! The emitter has its own \ref ErrorHandler (not propagated from \ref CodeHolder). + kOwnErrorHandler = 0x20u, + //! The emitter was finalized. + kFinalized = 0x40u, + //! The emitter was destroyed. + //! + //! This flag is used for a very short time when an emitter is being destroyed by + //! CodeHolder. + kDestroyed = 0x80u +}; +ASMJIT_DEFINE_ENUM_FLAGS(EmitterFlags) + +//! Encoding options. +enum class EncodingOptions : uint32_t { + //! No encoding options. + kNone = 0, + + //! Emit instructions that are optimized for size, if possible. + //! + //! Default: false. + //! + //! X86 Specific + //! ------------ + //! + //! When this option is set it the assembler will try to fix instructions if possible into operation equivalent + //! instructions that take less bytes by taking advantage of implicit zero extension. For example instruction + //! like `mov r64, imm` and `and r64, imm` can be translated to `mov r32, imm` and `and r32, imm` when the + //! immediate constant is lesser than `2^31`. + kOptimizeForSize = 0x00000001u, + + //! Emit optimized code-alignment sequences. + //! + //! Default: false. + //! + //! X86 Specific + //! ------------ + //! + //! Default align sequence used by X86 architecture is one-byte (0x90) opcode that is often shown by disassemblers + //! as NOP. However there are more optimized align sequences for 2-11 bytes that may execute faster on certain CPUs. + //! If this feature is enabled AsmJit will generate specialized sequences for alignment between 2 to 11 bytes. + kOptimizedAlign = 0x00000002u, + + //! Emit jump-prediction hints. + //! + //! Default: false. + //! + //! X86 Specific + //! ------------ + //! + //! Jump prediction is usually based on the direction of the jump. If the jump is backward it is usually predicted as + //! taken; and if the jump is forward it is usually predicted as not-taken. The reason is that loops generally use + //! backward jumps and conditions usually use forward jumps. However this behavior can be overridden by using + //! instruction prefixes. If this option is enabled these hints will be emitted. + //! + //! This feature is disabled by default, because the only processor that used to take into consideration prediction + //! hints was P4. Newer processors implement heuristics for branch prediction and ignore static hints. This means + //! that this feature can be only used for annotation purposes. + kPredictedJumps = 0x00000010u +}; +ASMJIT_DEFINE_ENUM_FLAGS(EncodingOptions) + +//! Diagnostic options are used to tell emitters and their passes to perform diagnostics when emitting or processing +//! user code. These options control validation and extra diagnostics that can be performed by higher level emitters. +//! +//! Instruction Validation +//! ---------------------- +//! +//! \ref BaseAssembler implementation perform by default only basic checks that are necessary to identify all +//! variations of an instruction so the correct encoding can be selected. This is fine for production-ready code +//! as the assembler doesn't have to perform checks that would slow it down. However, sometimes these checks are +//! beneficial especially when the project that uses AsmJit is in a development phase, in which mistakes happen +//! often. To make the experience of using AsmJit seamless it offers validation features that can be controlled +//! by \ref DiagnosticOptions. +//! +//! Compiler Diagnostics +//! -------------------- +//! +//! Diagnostic options work with \ref BaseCompiler passes (precisely with its register allocation pass). These options +//! can be used to enable logging of all operations that the Compiler does. +enum class DiagnosticOptions : uint32_t { + //! No validation options. + kNone = 0, + + //! Perform strict validation in \ref BaseAssembler::emit() implementations. + //! + //! This flag ensures that each instruction is checked before it's encoded into a binary representation. This flag + //! is only relevant for \ref BaseAssembler implementations, but can be set in any other emitter type, in that case + //! if that emitter needs to create an assembler on its own, for the purpose of \ref BaseEmitter::finalize() it + //! would propagate this flag to such assembler so all instructions passed to it are explicitly validated. + //! + //! Default: false. + kValidateAssembler = 0x00000001u, + + //! Perform strict validation in \ref BaseBuilder::emit() and \ref BaseCompiler::emit() implementations. + //! + //! This flag ensures that each instruction is checked before an \ref InstNode representing the instruction is + //! created by \ref BaseBuilder or \ref BaseCompiler. This option could be more useful than \ref kValidateAssembler + //! in cases in which there is an invalid instruction passed to an assembler, which was invalid much earlier, most + //! likely when such instruction was passed to Builder/Compiler. + //! + //! This is a separate option that was introduced, because it's possible to manipulate the instruction stream + //! emitted by \ref BaseBuilder and \ref BaseCompiler - this means that it's allowed to emit invalid instructions + //! (for example with missing operands) that will be fixed later before finalizing it. + //! + //! Default: false. + kValidateIntermediate = 0x00000002u, + + //! Annotate all nodes processed by register allocator (Compiler/RA). + //! + //! \note Annotations don't need debug options, however, some debug options like `kRADebugLiveness` may influence + //! their output (for example the mentioned option would add liveness information to per-instruction annotation). + kRAAnnotate = 0x00000080u, + + //! Debug CFG generation and other related algorithms / operations (Compiler/RA). + kRADebugCFG = 0x00000100u, + + //! Debug liveness analysis (Compiler/RA). + kRADebugLiveness = 0x00000200u, + + //! Debug register allocation assignment (Compiler/RA). + kRADebugAssignment = 0x00000400u, + + //! Debug the removal of code part of unreachable blocks. + kRADebugUnreachable = 0x00000800u, + + //! Enable all debug options (Compiler/RA). + kRADebugAll = 0x0000FF00u, +}; +ASMJIT_DEFINE_ENUM_FLAGS(DiagnosticOptions) + +//! Provides a base foundation to emitting code - specialized by \ref BaseAssembler and \ref BaseBuilder. +class ASMJIT_VIRTAPI BaseEmitter { +public: + ASMJIT_BASE_CLASS(BaseEmitter) + ASMJIT_NONCOPYABLE(BaseEmitter) + + //! \name Members + //! \{ + + //! See \ref EmitterType. + EmitterType _emitterType = EmitterType::kNone; + //! See \ref EmitterFlags. + EmitterFlags _emitterFlags = EmitterFlags::kNone; + //! Instruction alignment. + uint8_t _instructionAlignment = 0u; + //! \cond + uint8_t _reservedBaseEmitter = 0u; + //! \endcond + //! Validation flags in case validation is used. + //! + //! \note Validation flags are specific to the emitter and they are setup at construction time and then never + //! changed. + ValidationFlags _validationFlags = ValidationFlags::kNone; + //! Validation options. + DiagnosticOptions _diagnosticOptions = DiagnosticOptions::kNone; + + //! All supported architectures in a bit-mask, where LSB is the bit with a zero index. + uint64_t _archMask = 0; + + //! Encoding options. + EncodingOptions _encodingOptions = EncodingOptions::kNone; + + //! Forced instruction options, combined with \ref _instOptions by \ref emit(). + InstOptions _forcedInstOptions = InstOptions::kReserved; + //! Internal private data used freely by any emitter. + uint32_t _privateData = 0; + + //! CodeHolder the emitter is attached to. + CodeHolder* _code = nullptr; + //! Attached \ref Logger. + Logger* _logger = nullptr; + //! Attached \ref ErrorHandler. + ErrorHandler* _errorHandler = nullptr; + + //! Describes the target environment, matches \ref CodeHolder::environment(). + Environment _environment {}; + //! Native GP register signature and signature related information. + OperandSignature _gpSignature {}; + + //! Emitter state that can be used to specify options and inline comment of a next node or instruction. + struct State { + InstOptions options; + RegOnly extraReg; + const char* comment; + }; + + //! Next instruction options (affects the next instruction). + InstOptions _instOptions = InstOptions::kNone; + //! Extra register (op-mask {k} on AVX-512) (affects the next instruction). + RegOnly _extraReg {}; + //! Inline comment of the next instruction (affects the next instruction). + const char* _inlineComment = nullptr; + + //! Function callbacks used by emitter implementation. + //! + //! These are typically shared between Assembler/Builder/Compiler of a single backend. + struct Funcs { + typedef Error (ASMJIT_CDECL* EmitProlog)(BaseEmitter* emitter, const FuncFrame& frame); + typedef Error (ASMJIT_CDECL* EmitEpilog)(BaseEmitter* emitter, const FuncFrame& frame); + typedef Error (ASMJIT_CDECL* EmitArgsAssignment)(BaseEmitter* emitter, const FuncFrame& frame, const FuncArgsAssignment& args); + + typedef Error (ASMJIT_CDECL* FormatInstruction)( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + const BaseInst& inst, const Operand_* operands, size_t opCount) ASMJIT_NOEXCEPT_TYPE; + + typedef Error (ASMJIT_CDECL* ValidateFunc)(const BaseInst& inst, const Operand_* operands, size_t opCount, ValidationFlags validationFlags) ASMJIT_NOEXCEPT_TYPE; + + //! Emit prolog implementation. + EmitProlog emitProlog; + //! Emit epilog implementation. + EmitEpilog emitEpilog; + //! Emit arguments assignment implementation. + EmitArgsAssignment emitArgsAssignment; + //! Instruction formatter implementation. + FormatInstruction formatInstruction; + //! Instruction validation implementation. + ValidateFunc validate; + + //! Resets all functions to nullptr. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + emitProlog = nullptr; + emitEpilog = nullptr; + emitArgsAssignment = nullptr; + validate = nullptr; + } + }; + + Funcs _funcs {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit BaseEmitter(EmitterType emitterType) noexcept; + ASMJIT_API virtual ~BaseEmitter() noexcept; + + //! \} + + //! \name Cast + //! \{ + + template + ASMJIT_INLINE_NODEBUG T* as() noexcept { return reinterpret_cast(this); } + + template + ASMJIT_INLINE_NODEBUG const T* as() const noexcept { return reinterpret_cast(this); } + + //! \} + + //! \name Emitter Type & Flags + //! \{ + + //! Returns the type of this emitter, see `EmitterType`. + ASMJIT_INLINE_NODEBUG EmitterType emitterType() const noexcept { return _emitterType; } + //! Returns emitter flags , see `Flags`. + ASMJIT_INLINE_NODEBUG EmitterFlags emitterFlags() const noexcept { return _emitterFlags; } + + //! Tests whether the emitter inherits from `BaseAssembler`. + ASMJIT_INLINE_NODEBUG bool isAssembler() const noexcept { return _emitterType == EmitterType::kAssembler; } + //! Tests whether the emitter inherits from `BaseBuilder`. + //! + //! \note Both Builder and Compiler emitters would return `true`. + ASMJIT_INLINE_NODEBUG bool isBuilder() const noexcept { return uint32_t(_emitterType) >= uint32_t(EmitterType::kBuilder); } + //! Tests whether the emitter inherits from `BaseCompiler`. + ASMJIT_INLINE_NODEBUG bool isCompiler() const noexcept { return _emitterType == EmitterType::kCompiler; } + + //! Tests whether the emitter has the given `flag` enabled. + ASMJIT_INLINE_NODEBUG bool hasEmitterFlag(EmitterFlags flag) const noexcept { return Support::test(_emitterFlags, flag); } + //! Tests whether the emitter is finalized. + ASMJIT_INLINE_NODEBUG bool isFinalized() const noexcept { return hasEmitterFlag(EmitterFlags::kFinalized); } + //! Tests whether the emitter is destroyed (only used during destruction). + ASMJIT_INLINE_NODEBUG bool isDestroyed() const noexcept { return hasEmitterFlag(EmitterFlags::kDestroyed); } + + //! \} + + //! \cond INTERNAL + //! \name Internal Functions + //! \{ + + ASMJIT_INLINE_NODEBUG void _addEmitterFlags(EmitterFlags flags) noexcept { _emitterFlags |= flags; } + ASMJIT_INLINE_NODEBUG void _clearEmitterFlags(EmitterFlags flags) noexcept { _emitterFlags &= _emitterFlags & ~flags; } + + //! \} + //! \endcond + + //! \name Target Information + //! \{ + + //! Returns the CodeHolder this emitter is attached to. + ASMJIT_INLINE_NODEBUG CodeHolder* code() const noexcept { return _code; } + + //! Returns the target environment. + //! + //! The returned \ref Environment reference matches \ref CodeHolder::environment(). + ASMJIT_INLINE_NODEBUG const Environment& environment() const noexcept { return _environment; } + + //! Tests whether the target architecture is 32-bit. + ASMJIT_INLINE_NODEBUG bool is32Bit() const noexcept { return environment().is32Bit(); } + //! Tests whether the target architecture is 64-bit. + ASMJIT_INLINE_NODEBUG bool is64Bit() const noexcept { return environment().is64Bit(); } + + //! Returns the target architecture type. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return environment().arch(); } + //! Returns the target architecture sub-type. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return environment().subArch(); } + + //! Returns the target architecture's GP register size (4 or 8 bytes). + ASMJIT_INLINE_NODEBUG uint32_t registerSize() const noexcept { return environment().registerSize(); } + + //! Returns instruction alignment. + //! + //! The following values are returned based on the target architecture: + //! - X86 and X86_64 - instruction alignment is 1 + //! - AArch32 - instruction alignment is 4 in A32 mode and 2 in THUMB mode. + //! - AArch64 - instruction alignment is 4 + ASMJIT_INLINE_NODEBUG uint32_t instructionAlignment() const noexcept { return _instructionAlignment; } + + //! \} + + //! \name Initialization & Finalization + //! \{ + + //! Tests whether the emitter is initialized (i.e. attached to \ref CodeHolder). + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _code != nullptr; } + + //! Finalizes this emitter. + //! + //! Materializes the content of the emitter by serializing it to the attached \ref CodeHolder through an architecture + //! specific \ref BaseAssembler. This function won't do anything if the emitter inherits from \ref BaseAssembler as + //! assemblers emit directly to a \ref CodeBuffer held by \ref CodeHolder. However, if this is an emitter that + //! inherits from \ref BaseBuilder or \ref BaseCompiler then these emitters need the materialization phase as they + //! store their content in a representation not visible to \ref CodeHolder. + ASMJIT_API virtual Error finalize(); + + //! \} + + //! \name Logging + //! \{ + + //! Tests whether the emitter has a logger. + ASMJIT_INLINE_NODEBUG bool hasLogger() const noexcept { return _logger != nullptr; } + + //! Tests whether the emitter has its own logger. + //! + //! Own logger means that it overrides the possible logger that may be used by \ref CodeHolder this emitter is + //! attached to. + ASMJIT_INLINE_NODEBUG bool hasOwnLogger() const noexcept { return hasEmitterFlag(EmitterFlags::kOwnLogger); } + + //! Returns the logger this emitter uses. + //! + //! The returned logger is either the emitter's own logger or it's logger used by \ref CodeHolder this emitter + //! is attached to. + ASMJIT_INLINE_NODEBUG Logger* logger() const noexcept { return _logger; } + + //! Sets or resets the logger of the emitter. + //! + //! If the `logger` argument is non-null then the logger will be considered emitter's own logger, see \ref + //! hasOwnLogger() for more details. If the given `logger` is null then the emitter will automatically use logger + //! that is attached to the \ref CodeHolder this emitter is attached to. + ASMJIT_API void setLogger(Logger* logger) noexcept; + + //! Resets the logger of this emitter. + //! + //! The emitter will bail to using a logger attached to \ref CodeHolder this emitter is attached to, or no logger + //! at all if \ref CodeHolder doesn't have one. + ASMJIT_INLINE_NODEBUG void resetLogger() noexcept { return setLogger(nullptr); } + + //! \} + + //! \name Error Handling + //! \{ + + //! Tests whether the emitter has an error handler attached. + ASMJIT_INLINE_NODEBUG bool hasErrorHandler() const noexcept { return _errorHandler != nullptr; } + + //! Tests whether the emitter has its own error handler. + //! + //! Own error handler means that it overrides the possible error handler that may be used by \ref CodeHolder this + //! emitter is attached to. + ASMJIT_INLINE_NODEBUG bool hasOwnErrorHandler() const noexcept { return hasEmitterFlag(EmitterFlags::kOwnErrorHandler); } + + //! Returns the error handler this emitter uses. + //! + //! The returned error handler is either the emitter's own error handler or it's error handler used by + //! \ref CodeHolder this emitter is attached to. + ASMJIT_INLINE_NODEBUG ErrorHandler* errorHandler() const noexcept { return _errorHandler; } + + //! Sets or resets the error handler of the emitter. + ASMJIT_API void setErrorHandler(ErrorHandler* errorHandler) noexcept; + + //! Resets the error handler. + ASMJIT_INLINE_NODEBUG void resetErrorHandler() noexcept { setErrorHandler(nullptr); } + + //! Handles the given error in the following way: + //! 1. If the emitter has \ref ErrorHandler attached, it calls its \ref ErrorHandler::handleError() member function + //! first, and then returns the error. The `handleError()` function may throw. + //! 2. if the emitter doesn't have \ref ErrorHandler, the error is simply returned. + ASMJIT_API Error reportError(Error err, const char* message = nullptr); + + //! \} + + //! \name Encoding Options + //! \{ + + //! Returns encoding options. + ASMJIT_INLINE_NODEBUG EncodingOptions encodingOptions() const noexcept { return _encodingOptions; } + //! Tests whether the encoding `option` is set. + ASMJIT_INLINE_NODEBUG bool hasEncodingOption(EncodingOptions option) const noexcept { return Support::test(_encodingOptions, option); } + + //! Enables the given encoding `options`. + ASMJIT_INLINE_NODEBUG void addEncodingOptions(EncodingOptions options) noexcept { _encodingOptions |= options; } + //! Disables the given encoding `options`. + ASMJIT_INLINE_NODEBUG void clearEncodingOptions(EncodingOptions options) noexcept { _encodingOptions &= ~options; } + + //! \} + + //! \name Diagnostic Options + //! \{ + + //! Returns the emitter's diagnostic options. + ASMJIT_INLINE_NODEBUG DiagnosticOptions diagnosticOptions() const noexcept { return _diagnosticOptions; } + + //! Tests whether the given `option` is present in the emitter's diagnostic options. + ASMJIT_INLINE_NODEBUG bool hasDiagnosticOption(DiagnosticOptions option) const noexcept { return Support::test(_diagnosticOptions, option); } + + //! Activates the given diagnostic `options`. + //! + //! This function is used to activate explicit validation options that will be then used by all emitter + //! implementations. There are in general two possibilities: + //! + //! - Architecture specific assembler is used. In this case a \ref DiagnosticOptions::kValidateAssembler can be + //! used to turn on explicit validation that will be used before an instruction is emitted. This means that + //! internally an extra step will be performed to make sure that the instruction is correct. This is needed, + //! because by default assemblers prefer speed over strictness. + //! + //! This option should be used in debug builds as it's pretty expensive. + //! + //! - Architecture specific builder or compiler is used. In this case the user can turn on + //! \ref DiagnosticOptions::kValidateIntermediate option that adds explicit validation step before the Builder + //! or Compiler creates an \ref InstNode to represent an emitted instruction. Error will be returned if the + //! instruction is ill-formed. In addition, also \ref DiagnosticOptions::kValidateAssembler can be used, which + //! would not be consumed by Builder / Compiler directly, but it would be propagated to an architecture specific + //! \ref BaseAssembler implementation it creates during \ref BaseEmitter::finalize(). + ASMJIT_API void addDiagnosticOptions(DiagnosticOptions options) noexcept; + + //! Deactivates the given validation `options`. + //! + //! See \ref addDiagnosticOptions() and \ref DiagnosticOptions for more details. + ASMJIT_API void clearDiagnosticOptions(DiagnosticOptions options) noexcept; + + //! \} + + //! \name Instruction Options + //! \{ + + //! Returns forced instruction options. + //! + //! Forced instruction options are merged with next instruction options before the instruction is encoded. These + //! options have some bits reserved that are used by error handling, logging, and instruction validation purposes. + //! Other options are globals that affect each instruction. + ASMJIT_INLINE_NODEBUG InstOptions forcedInstOptions() const noexcept { return _forcedInstOptions; } + + //! Returns options of the next instruction. + ASMJIT_INLINE_NODEBUG InstOptions instOptions() const noexcept { return _instOptions; } + //! Returns options of the next instruction. + ASMJIT_INLINE_NODEBUG void setInstOptions(InstOptions options) noexcept { _instOptions = options; } + //! Adds options of the next instruction. + ASMJIT_INLINE_NODEBUG void addInstOptions(InstOptions options) noexcept { _instOptions |= options; } + //! Resets options of the next instruction. + ASMJIT_INLINE_NODEBUG void resetInstOptions() noexcept { _instOptions = InstOptions::kNone; } + + //! Tests whether the extra register operand is valid. + ASMJIT_INLINE_NODEBUG bool hasExtraReg() const noexcept { return _extraReg.isReg(); } + //! Returns an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG const RegOnly& extraReg() const noexcept { return _extraReg; } + //! Sets an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG void setExtraReg(const BaseReg& reg) noexcept { _extraReg.init(reg); } + //! Sets an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG void setExtraReg(const RegOnly& reg) noexcept { _extraReg.init(reg); } + //! Resets an extra operand that will be used by the next instruction (architecture specific). + ASMJIT_INLINE_NODEBUG void resetExtraReg() noexcept { _extraReg.reset(); } + + //! Returns comment/annotation of the next instruction. + ASMJIT_INLINE_NODEBUG const char* inlineComment() const noexcept { return _inlineComment; } + //! Sets comment/annotation of the next instruction. + //! + //! \note This string is set back to null by `_emit()`, but until that it has to remain valid as the Emitter is not + //! required to make a copy of it (and it would be slow to do that for each instruction). + ASMJIT_INLINE_NODEBUG void setInlineComment(const char* s) noexcept { _inlineComment = s; } + //! Resets the comment/annotation to nullptr. + ASMJIT_INLINE_NODEBUG void resetInlineComment() noexcept { _inlineComment = nullptr; } + + //! \} + + //! \name Emitter State + //! \{ + + //! Resets the emitter state, which contains instruction options, extra register, and inline comment. + //! + //! Emitter can have a state that describes instruction options and extra register used by the instruction. Most + //! instructions don't need nor use the state, however, if an instruction uses a prefix such as REX or REP prefix, + //! which is set explicitly, then the state would contain it. This allows to mimic the syntax of assemblers such + //! as X86. For example `rep().movs(...)` would map to a `REP MOVS` instuction on X86. The same applies to various + //! hints and the use of a mask register in AVX-512 mode. + ASMJIT_INLINE_NODEBUG void resetState() noexcept { + resetInstOptions(); + resetExtraReg(); + resetInlineComment(); + } + + //! \cond INTERNAL + + //! Grabs the current emitter state and resets the emitter state at the same time, returning the state the emitter + //! had before the state was reset. + ASMJIT_INLINE_NODEBUG State _grabState() noexcept { + State s{_instOptions | _forcedInstOptions, _extraReg, _inlineComment}; + resetState(); + return s; + } + //! \endcond + + //! \} + + //! \name Sections + //! \{ + + //! Switches the given `section`. + //! + //! Once switched, everything is added to the given `section`. + ASMJIT_API virtual Error section(Section* section); + + //! \} + + //! \name Labels + //! \{ + + //! Creates a new label. + ASMJIT_API virtual Label newLabel(); + //! Creates a new named label. + ASMJIT_API virtual Label newNamedLabel(const char* name, size_t nameSize = SIZE_MAX, LabelType type = LabelType::kGlobal, uint32_t parentId = Globals::kInvalidId); + + //! Creates a new anonymous label with a name, which can only be used for debugging purposes. + ASMJIT_INLINE_NODEBUG Label newAnonymousLabel(const char* name, size_t nameSize = SIZE_MAX) { return newNamedLabel(name, nameSize, LabelType::kAnonymous); } + //! Creates a new external label. + ASMJIT_INLINE_NODEBUG Label newExternalLabel(const char* name, size_t nameSize = SIZE_MAX) { return newNamedLabel(name, nameSize, LabelType::kExternal); } + + //! Returns `Label` by `name`. + //! + //! Returns invalid Label in case that the name is invalid or label was not found. + //! + //! \note This function doesn't trigger ErrorHandler in case the name is invalid or no such label exist. You must + //! always check the validity of the `Label` returned. + ASMJIT_API Label labelByName(const char* name, size_t nameSize = SIZE_MAX, uint32_t parentId = Globals::kInvalidId) noexcept; + + //! Binds the `label` to the current position of the current section. + //! + //! \note Attempt to bind the same label multiple times will return an error. + ASMJIT_API virtual Error bind(const Label& label); + + //! Tests whether the label `id` is valid (i.e. registered). + ASMJIT_API bool isLabelValid(uint32_t labelId) const noexcept; + //! Tests whether the `label` is valid (i.e. registered). + ASMJIT_INLINE_NODEBUG bool isLabelValid(const Label& label) const noexcept { return isLabelValid(label.id()); } + + //! \} + + //! \name Emit + //! \{ + + // NOTE: These `emit()` helpers are designed to address a code-bloat generated by C++ compilers to call a function + // having many arguments. Each parameter to `_emit()` requires some code to pass it, which means that if we default + // to 5 arguments in `_emit()` and instId the C++ compiler would have to generate a virtual function call having 5 + // parameters and additional `this` argument, which is quite a lot. Since by default most instructions have 2 to 3 + // operands it's better to introduce helpers that pass from 0 to 6 operands that help to reduce the size of emit(...) + // function call. + + //! Emits an instruction (internal). + ASMJIT_API Error _emitI(InstId instId); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3, const Operand_& o4); + //! \overload + ASMJIT_API Error _emitI(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_& o3, const Operand_& o4, const Operand_& o5); + + //! Emits an instruction `instId` with the given `operands`. + //! + //! This is the most universal way of emitting code, which accepts an instruction identifier and instruction + //! operands. This is called an "unchecked" API as emit doesn't provide any type checks at compile-time. This + //! allows to emit instruction with just \ref Operand instances, which could be handy in some cases - for + //! example emitting generic code where you don't know whether some operand is register, memory, or immediate. + template + ASMJIT_INLINE_NODEBUG Error emit(InstId instId, Args&&... operands) { + return _emitI(instId, Support::ForwardOp::forward(operands)...); + } + + //! Similar to \ref emit(), but uses array of `operands` instead. + ASMJIT_INLINE_NODEBUG Error emitOpArray(InstId instId, const Operand_* operands, size_t opCount) { + return _emitOpArray(instId, operands, opCount); + } + + //! Similar to \ref emit(), but emits instruction with both instruction options and extra register, followed + //! by an array of `operands`. + ASMJIT_FORCE_INLINE Error emitInst(const BaseInst& inst, const Operand_* operands, size_t opCount) { + setInstOptions(inst.options()); + setExtraReg(inst.extraReg()); + return _emitOpArray(inst.id(), operands, opCount); + } + + //! \} + + //! \cond INTERNAL + //! \name Emit Internals + //! \{ + + //! Emits an instruction - all 6 operands must be defined. + ASMJIT_API virtual Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* oExt); + //! Emits instruction having operands stored in array. + ASMJIT_API virtual Error _emitOpArray(InstId instId, const Operand_* operands, size_t opCount); + + //! \} + //! \endcond + + //! \name Emit Utilities + //! \{ + + //! Emits a function prolog described by the given function `frame`. + ASMJIT_API Error emitProlog(const FuncFrame& frame); + //! Emits a function epilog described by the given function `frame`. + ASMJIT_API Error emitEpilog(const FuncFrame& frame); + //! Emits code that reassigns function `frame` arguments to the given `args`. + ASMJIT_API Error emitArgsAssignment(const FuncFrame& frame, const FuncArgsAssignment& args); + + //! \} + + //! \name Align + //! \{ + + //! Aligns the current CodeBuffer position to the `alignment` specified. + //! + //! The sequence that is used to fill the gap between the aligned location and the current location depends on the + //! align `mode`, see \ref AlignMode. The `alignment` argument specifies alignment in bytes, so for example when + //! it's `32` it means that the code buffer will be aligned to `32` bytes. + ASMJIT_API virtual Error align(AlignMode alignMode, uint32_t alignment); + + //! \} + + //! \name Embed + //! \{ + + //! Embeds raw data into the \ref CodeBuffer. + ASMJIT_API virtual Error embed(const void* data, size_t dataSize); + + //! Embeds a typed data array. + //! + //! This is the most flexible function for embedding data as it allows to: + //! + //! - Assign a `typeId` to the data, so the emitter knows the type of items stored in `data`. Binary data should + //! use \ref TypeId::kUInt8. + //! + //! - Repeat the given data `repeatCount` times, so the data can be used as a fill pattern for example, or as a + //! pattern used by SIMD instructions. + ASMJIT_API virtual Error embedDataArray(TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1); + + //! Embeds int8_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt8(int8_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt8, &value, 1, repeatCount); } + //! Embeds uint8_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedUInt8(uint8_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt8, &value, 1, repeatCount); } + //! Embeds int16_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt16(int16_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt16, &value, 1, repeatCount); } + //! Embeds uint16_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedUInt16(uint16_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt16, &value, 1, repeatCount); } + //! Embeds int32_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt32(int32_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt32, &value, 1, repeatCount); } + //! Embeds uint32_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedUInt32(uint32_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt32, &value, 1, repeatCount); } + //! Embeds int64_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedInt64(int64_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kInt64, &value, 1, repeatCount); } + //! Embeds uint64_t `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedUInt64(uint64_t value, size_t repeatCount = 1) { return embedDataArray(TypeId::kUInt64, &value, 1, repeatCount); } + //! Embeds a floating point `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedFloat(float value, size_t repeatCount = 1) { return embedDataArray(TypeId(TypeUtils::TypeIdOfT::kTypeId), &value, 1, repeatCount); } + //! Embeds a floating point `value` repeated by `repeatCount`. + ASMJIT_INLINE_NODEBUG Error embedDouble(double value, size_t repeatCount = 1) { return embedDataArray(TypeId(TypeUtils::TypeIdOfT::kTypeId), &value, 1, repeatCount); } + + //! Embeds a constant pool at the current offset by performing the following: + //! 1. Aligns by using AlignMode::kData to the minimum `pool` alignment. + //! 2. Binds the ConstPool label so it's bound to an aligned location. + //! 3. Emits ConstPool content. + ASMJIT_API virtual Error embedConstPool(const Label& label, const ConstPool& pool); + + //! Embeds an absolute `label` address as data. + //! + //! The `dataSize` is an optional argument that can be used to specify the size of the address data. If it's zero + //! (default) the address size is deduced from the target architecture (either 4 or 8 bytes). + ASMJIT_API virtual Error embedLabel(const Label& label, size_t dataSize = 0); + + //! Embeds a delta (distance) between the `label` and `base` calculating it as `label - base`. This function was + //! designed to make it easier to embed lookup tables where each index is a relative distance of two labels. + ASMJIT_API virtual Error embedLabelDelta(const Label& label, const Label& base, size_t dataSize = 0); + + //! \} + + //! \name Comment + //! \{ + + //! Emits a comment stored in `data` with an optional `size` parameter. + ASMJIT_API virtual Error comment(const char* data, size_t size = SIZE_MAX); + + //! Emits a formatted comment specified by `fmt` and variable number of arguments. + ASMJIT_API Error commentf(const char* fmt, ...); + //! Emits a formatted comment specified by `fmt` and `ap`. + ASMJIT_API Error commentv(const char* fmt, va_list ap); + + //! \} + + //! \name Events + //! \{ + + //! Called after the emitter was attached to `CodeHolder`. + ASMJIT_API virtual Error onAttach(CodeHolder* ASMJIT_NONNULL(code)) noexcept; + //! Called after the emitter was detached from `CodeHolder`. + ASMJIT_API virtual Error onDetach(CodeHolder* ASMJIT_NONNULL(code)) noexcept; + + //! Called when \ref CodeHolder has updated an important setting, which involves the following: + //! + //! - \ref Logger has been changed (\ref CodeHolder::setLogger() has been called). + //! + //! - \ref ErrorHandler has been changed (\ref CodeHolder::setErrorHandler() has been called). + //! + //! This function ensures that the settings are properly propagated from \ref CodeHolder to the emitter. + //! + //! \note This function is virtual and can be overridden, however, if you do so, always call \ref + //! BaseEmitter::onSettingsUpdated() within your own implementation to ensure that the emitter is + //! in a consistent state. + ASMJIT_API virtual void onSettingsUpdated() noexcept; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_EMITTER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/environment.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/environment.h new file mode 100644 index 0000000000000000000000000000000000000000..a343eef3411873edac87ffaef9ba6c39e08b5891 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/environment.h @@ -0,0 +1,534 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ENVIRONMENT_H_INCLUDED +#define ASMJIT_CORE_ENVIRONMENT_H_INCLUDED + +#include "../core/archtraits.h" + +#if defined(__APPLE__) + #include +#endif + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Vendor. +//! +//! \note AsmJit doesn't use vendor information at the moment. It's provided for future use, if required. +enum class Vendor : uint8_t { + //! Unknown or uninitialized platform vendor. + kUnknown = 0, + + //! Maximum value of `Vendor`. + kMaxValue = kUnknown, + + //! Platform vendor detected at compile-time. + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#else + kUnknown +#endif +}; + +//! Platform - runtime environment or operating system. +enum class Platform : uint8_t { + //! Unknown or uninitialized platform. + kUnknown = 0, + + //! Windows OS. + kWindows, + + //! Other platform that is not Windows, most likely POSIX based. + kOther, + + //! Linux OS. + kLinux, + //! GNU/Hurd OS. + kHurd, + + //! FreeBSD OS. + kFreeBSD, + //! OpenBSD OS. + kOpenBSD, + //! NetBSD OS. + kNetBSD, + //! DragonFly BSD OS. + kDragonFlyBSD, + + //! Haiku OS. + kHaiku, + + //! Apple OSX. + kOSX, + //! Apple iOS. + kIOS, + //! Apple TVOS. + kTVOS, + //! Apple WatchOS. + kWatchOS, + + //! Emscripten platform. + kEmscripten, + + //! Maximum value of `Platform`. + kMaxValue = kEmscripten, + + //! Platform detected at compile-time (platform of the host). + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#elif defined(__EMSCRIPTEN__) + kEmscripten +#elif defined(_WIN32) + kWindows +#elif defined(__linux__) + kLinux +#elif defined(__gnu_hurd__) + kHurd +#elif defined(__FreeBSD__) + kFreeBSD +#elif defined(__OpenBSD__) + kOpenBSD +#elif defined(__NetBSD__) + kNetBSD +#elif defined(__DragonFly__) + kDragonFlyBSD +#elif defined(__HAIKU__) + kHaiku +#elif defined(__APPLE__) && TARGET_OS_OSX + kOSX +#elif defined(__APPLE__) && TARGET_OS_TV + kTVOS +#elif defined(__APPLE__) && TARGET_OS_WATCH + kWatchOS +#elif defined(__APPLE__) && TARGET_OS_IPHONE + kIOS +#else + kOther +#endif +}; + +//! Platform ABI (application binary interface). +enum class PlatformABI : uint8_t { + //! Unknown or uninitialized environment. + kUnknown = 0, + //! Microsoft ABI. + kMSVC, + //! GNU ABI. + kGNU, + //! Android Environment / ABI. + kAndroid, + //! Cygwin ABI. + kCygwin, + //! Darwin ABI. + kDarwin, + + //! Maximum value of `PlatformABI`. + kMaxValue, + + //! Host ABI detected at compile-time. + kHost = +#if defined(_DOXYGEN) + DETECTED_AT_COMPILE_TIME +#elif defined(_MSC_VER) + kMSVC +#elif defined(__CYGWIN__) + kCygwin +#elif defined(__MINGW32__) || defined(__GLIBC__) + kGNU +#elif defined(__ANDROID__) + kAndroid +#elif defined(__APPLE__) + kDarwin +#else + kUnknown +#endif +}; + +//! Floating point ABI (ARM). +enum class FloatABI : uint8_t { + kHardFloat = 0, + kSoftFloat, + + kHost = +#if ASMJIT_ARCH_ARM == 32 && defined(__SOFTFP__) + kSoftFloat +#else + kHardFloat +#endif +}; + +//! Object format. +//! +//! \note AsmJit doesn't really use anything except \ref ObjectFormat::kUnknown and \ref ObjectFormat::kJIT at +//! the moment. Object file formats are provided for future extensibility and a possibility to generate object +//! files at some point. +enum class ObjectFormat : uint8_t { + //! Unknown or uninitialized object format. + kUnknown = 0, + + //! JIT code generation object, most likely \ref JitRuntime or a custom + //! \ref Target implementation. + kJIT, + + //! Executable and linkable format (ELF). + kELF, + //! Common object file format. + kCOFF, + //! Extended COFF object format. + kXCOFF, + //! Mach object file format. + kMachO, + + //! Maximum value of `ObjectFormat`. + kMaxValue +}; + +//! Represents an environment, which is usually related to a \ref Target. +//! +//! Environment has usually an 'arch-subarch-vendor-os-abi' format, which is sometimes called "Triple" (historically +//! it used to be 3 only parts) or "Tuple", which is a convention used by Debian Linux. +//! +//! AsmJit doesn't support all possible combinations or architectures and ABIs, however, it models the environment +//! similarly to other compilers for future extensibility. +class Environment { +public: + //! \name Members + //! \{ + + //! Architecture. + Arch _arch = Arch::kUnknown; + //! Sub-architecture type. + SubArch _subArch = SubArch::kUnknown; + //! Vendor type. + Vendor _vendor = Vendor::kUnknown; + //! Platform. + Platform _platform = Platform::kUnknown; + //! Platform ABI. + PlatformABI _platformABI = PlatformABI::kUnknown; + //! Object format. + ObjectFormat _objectFormat = ObjectFormat::kUnknown; + //! Floating point ABI. + FloatABI _floatABI = FloatABI::kHardFloat; + //! Reserved for future use, must be zero. + uint8_t _reserved = 0; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default initialized environment (all values either unknown or set to safe defaults). + ASMJIT_INLINE_NODEBUG constexpr Environment() noexcept = default; + //! Creates a copy of `other` instance. + ASMJIT_INLINE_NODEBUG constexpr Environment(const Environment& other) noexcept = default; + + //! Creates \ref Environment initialized to `arch`, `subArch`, `vendor`, `platform`, `platformABI`, `objectFormat`, + //! and `floatABI`. + ASMJIT_INLINE_NODEBUG constexpr explicit Environment( + Arch arch, + SubArch subArch = SubArch::kUnknown, + Vendor vendor = Vendor::kUnknown, + Platform platform = Platform::kUnknown, + PlatformABI platformABI = PlatformABI::kUnknown, + ObjectFormat objectFormat = ObjectFormat::kUnknown, + FloatABI floatABI = FloatABI::kHardFloat) noexcept + : _arch(arch), + _subArch(subArch), + _vendor(vendor), + _platform(platform), + _platformABI(platformABI), + _objectFormat(objectFormat), + _floatABI(floatABI) {} + + //! Returns the host environment constructed from preprocessor macros defined by the compiler. + //! + //! The returned environment should precisely match the target host architecture, sub-architecture, platform, + //! and ABI. + static ASMJIT_INLINE_NODEBUG Environment host() noexcept { + return Environment(Arch::kHost, SubArch::kHost, Vendor::kHost, Platform::kHost, PlatformABI::kHost, ObjectFormat::kUnknown, FloatABI::kHost); + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Environment& operator=(const Environment& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG bool operator==(const Environment& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const Environment& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the environment is not set up. + //! + //! Returns true if all members are zero, and thus unknown. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { + // Unfortunately compilers won't optimize fields are checked one by one... + return _packed() == 0; + } + + //! Tests whether the environment is initialized, which means it must have + //! a valid architecture. + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { + return _arch != Arch::kUnknown; + } + + ASMJIT_INLINE_NODEBUG uint64_t _packed() const noexcept { + uint64_t x; + memcpy(&x, this, 8); + return x; + } + + //! Resets all members of the environment to zero / unknown. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = Environment{}; } + + //! Tests whether this environment is equal to `other`. + ASMJIT_INLINE_NODEBUG bool equals(const Environment& other) const noexcept { return _packed() == other._packed(); } + + //! Returns the architecture. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + //! Returns the sub-architecture. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return _subArch; } + //! Returns vendor. + ASMJIT_INLINE_NODEBUG Vendor vendor() const noexcept { return _vendor; } + //! Returns target's platform or operating system. + ASMJIT_INLINE_NODEBUG Platform platform() const noexcept { return _platform; } + //! Returns target's ABI. + ASMJIT_INLINE_NODEBUG PlatformABI platformABI() const noexcept { return _platformABI; } + //! Returns target's object format. + ASMJIT_INLINE_NODEBUG ObjectFormat objectFormat() const noexcept { return _objectFormat; } + //! Returns floating point ABI. + ASMJIT_INLINE_NODEBUG FloatABI floatABI() const noexcept { return _floatABI; } + + //! Initializes \ref Environment to `arch`, `subArch`, `vendor`, `platform`, `platformABI`, `objectFormat`, + //! and `floatABI`. + inline void init( + Arch arch, + SubArch subArch = SubArch::kUnknown, + Vendor vendor = Vendor::kUnknown, + Platform platform = Platform::kUnknown, + PlatformABI platformABI = PlatformABI::kUnknown, + ObjectFormat objectFormat = ObjectFormat::kUnknown, + FloatABI floatABI = FloatABI::kHardFloat) noexcept { + + _arch = arch; + _subArch = subArch; + _vendor = vendor; + _platform = platform; + _platformABI = platformABI; + _objectFormat = objectFormat; + _floatABI = floatABI; + _reserved = 0; + } + + //! Tests whether this environment describes a 32-bit X86. + ASMJIT_INLINE_NODEBUG bool isArchX86() const noexcept { return _arch == Arch::kX86; } + //! Tests whether this environment describes a 64-bit X86. + ASMJIT_INLINE_NODEBUG bool isArchX64() const noexcept { return _arch == Arch::kX64; } + //! Tests whether this environment describes a 32-bit ARM. + ASMJIT_INLINE_NODEBUG bool isArchARM() const noexcept { return isArchARM(_arch); } + //! Tests whether this environment describes a 32-bit ARM in THUMB mode. + ASMJIT_INLINE_NODEBUG bool isArchThumb() const noexcept { return isArchThumb(_arch); } + //! Tests whether this environment describes a 64-bit X86. + ASMJIT_INLINE_NODEBUG bool isArchAArch64() const noexcept { return isArchAArch64(_arch); } + //! Tests whether this environment describes a 32-bit MIPS. + ASMJIT_INLINE_NODEBUG bool isArchMIPS32() const noexcept { return isArchMIPS32(_arch); } + //! Tests whether this environment describes a 64-bit MIPS. + ASMJIT_INLINE_NODEBUG bool isArchMIPS64() const noexcept { return isArchMIPS64(_arch); } + //! Tests whether this environment describes a 32-bit RISC-V. + ASMJIT_INLINE_NODEBUG bool isArchRISCV32() const noexcept { return _arch == Arch::kRISCV32; } + //! Tests whether this environment describes a 64-bit RISC-V. + ASMJIT_INLINE_NODEBUG bool isArchRISCV64() const noexcept { return _arch == Arch::kRISCV64; } + + //! Tests whether the architecture is 32-bit. + ASMJIT_INLINE_NODEBUG bool is32Bit() const noexcept { return is32Bit(_arch); } + //! Tests whether the architecture is 64-bit. + ASMJIT_INLINE_NODEBUG bool is64Bit() const noexcept { return is64Bit(_arch); } + + //! Tests whether the architecture is little endian. + ASMJIT_INLINE_NODEBUG bool isLittleEndian() const noexcept { return isLittleEndian(_arch); } + //! Tests whether the architecture is big endian. + ASMJIT_INLINE_NODEBUG bool isBigEndian() const noexcept { return isBigEndian(_arch); } + + //! Tests whether this architecture is of X86 family. + ASMJIT_INLINE_NODEBUG bool isFamilyX86() const noexcept { return isFamilyX86(_arch); } + //! Tests whether this architecture family is ARM, THUMB, or AArch64. + ASMJIT_INLINE_NODEBUG bool isFamilyARM() const noexcept { return isFamilyARM(_arch); } + //! Tests whether this architecture family is AArch32 (ARM or THUMB). + ASMJIT_INLINE_NODEBUG bool isFamilyAArch32() const noexcept { return isFamilyAArch32(_arch); } + //! Tests whether this architecture family is AArch64. + ASMJIT_INLINE_NODEBUG bool isFamilyAArch64() const noexcept { return isFamilyAArch64(_arch); } + //! Tests whether this architecture family is MISP or MIPS64. + ASMJIT_INLINE_NODEBUG bool isFamilyMIPS() const noexcept { return isFamilyMIPS(_arch); } + //! Tests whether this architecture family is RISC-V (both 32-bit and 64-bit). + ASMJIT_INLINE_NODEBUG bool isFamilyRISCV() const noexcept { return isFamilyRISCV(_arch); } + + //! Tests whether the environment platform is Windows. + ASMJIT_INLINE_NODEBUG bool isPlatformWindows() const noexcept { return _platform == Platform::kWindows; } + //! Tests whether the environment platform is Linux. + ASMJIT_INLINE_NODEBUG bool isPlatformLinux() const noexcept { return _platform == Platform::kLinux; } + //! Tests whether the environment platform is Hurd. + ASMJIT_INLINE_NODEBUG bool isPlatformHurd() const noexcept { return _platform == Platform::kHurd; } + //! Tests whether the environment platform is Haiku. + ASMJIT_INLINE_NODEBUG bool isPlatformHaiku() const noexcept { return _platform == Platform::kHaiku; } + + //! Tests whether the environment platform is any BSD. + ASMJIT_INLINE_NODEBUG bool isPlatformBSD() const noexcept { + return _platform == Platform::kFreeBSD || + _platform == Platform::kOpenBSD || + _platform == Platform::kNetBSD || + _platform == Platform::kDragonFlyBSD; + } + + //! Tests whether the environment platform is any Apple platform (OSX, iOS, TVOS, WatchOS). + ASMJIT_INLINE_NODEBUG bool isPlatformApple() const noexcept { + return _platform == Platform::kOSX || + _platform == Platform::kIOS || + _platform == Platform::kTVOS || + _platform == Platform::kWatchOS; + } + + //! Tests whether the ABI is MSVC. + ASMJIT_INLINE_NODEBUG bool isMSVC() const noexcept { return _platformABI == PlatformABI::kMSVC; } + //! Tests whether the ABI is GNU. + ASMJIT_INLINE_NODEBUG bool isGNU() const noexcept { return _platformABI == PlatformABI::kGNU; } + //! Tests whether the ABI is GNU. + ASMJIT_INLINE_NODEBUG bool isDarwin() const noexcept { return _platformABI == PlatformABI::kDarwin; } + + //! Returns a calculated stack alignment for this environment. + ASMJIT_API uint32_t stackAlignment() const noexcept; + + //! Returns a native register size of this architecture. + ASMJIT_INLINE_NODEBUG uint32_t registerSize() const noexcept { return registerSizeFromArch(_arch); } + + //! Sets the architecture to `arch`. + ASMJIT_INLINE_NODEBUG void setArch(Arch arch) noexcept { _arch = arch; } + //! Sets the sub-architecture to `subArch`. + ASMJIT_INLINE_NODEBUG void setSubArch(SubArch subArch) noexcept { _subArch = subArch; } + //! Sets the vendor to `vendor`. + ASMJIT_INLINE_NODEBUG void setVendor(Vendor vendor) noexcept { _vendor = vendor; } + //! Sets the platform to `platform`. + ASMJIT_INLINE_NODEBUG void setPlatform(Platform platform) noexcept { _platform = platform; } + //! Sets the ABI to `platformABI`. + ASMJIT_INLINE_NODEBUG void setPlatformABI(PlatformABI platformABI) noexcept { _platformABI = platformABI; } + //! Sets the object format to `objectFormat`. + ASMJIT_INLINE_NODEBUG void setObjectFormat(ObjectFormat objectFormat) noexcept { _objectFormat = objectFormat; } + + //! Sets floating point ABI to `floatABI`. + ASMJIT_INLINE_NODEBUG void setFloatABI(FloatABI floatABI) noexcept { _floatABI = floatABI; } + + //! \} + + //! \name Static Utilities + //! \{ + + static ASMJIT_INLINE_NODEBUG bool isDefinedArch(Arch arch) noexcept { + return uint32_t(arch) <= uint32_t(Arch::kMaxValue); + } + + static ASMJIT_INLINE_NODEBUG bool isValidArch(Arch arch) noexcept { + return arch != Arch::kUnknown && uint32_t(arch) <= uint32_t(Arch::kMaxValue); + } + + //! Tests whether the given architecture `arch` is 32-bit. + static ASMJIT_INLINE_NODEBUG bool is32Bit(Arch arch) noexcept { + return (uint32_t(arch) & uint32_t(Arch::k32BitMask)) == uint32_t(Arch::k32BitMask); + } + + //! Tests whether the given architecture `arch` is 64-bit. + static ASMJIT_INLINE_NODEBUG bool is64Bit(Arch arch) noexcept { + return (uint32_t(arch) & uint32_t(Arch::k32BitMask)) == 0; + } + + //! Tests whether the given architecture `arch` is little endian. + static ASMJIT_INLINE_NODEBUG bool isLittleEndian(Arch arch) noexcept { + return uint32_t(arch) < uint32_t(Arch::kBigEndian); + } + + //! Tests whether the given architecture `arch` is big endian. + static ASMJIT_INLINE_NODEBUG bool isBigEndian(Arch arch) noexcept { + return uint32_t(arch) >= uint32_t(Arch::kBigEndian); + } + + //! Tests whether the given architecture is Thumb or Thumb_BE. + static ASMJIT_INLINE_NODEBUG bool isArchThumb(Arch arch) noexcept { + return arch == Arch::kThumb || arch == Arch::kThumb_BE; + } + + //! Tests whether the given architecture is ARM or ARM_BE. + static ASMJIT_INLINE_NODEBUG bool isArchARM(Arch arch) noexcept { + return arch == Arch::kARM || arch == Arch::kARM_BE; + } + + //! Tests whether the given architecture is AArch64 or AArch64_BE. + static ASMJIT_INLINE_NODEBUG bool isArchAArch64(Arch arch) noexcept { + return arch == Arch::kAArch64 || arch == Arch::kAArch64_BE; + } + + //! Tests whether the given architecture is MIPS32_LE or MIPS32_BE. + static ASMJIT_INLINE_NODEBUG bool isArchMIPS32(Arch arch) noexcept { + return arch == Arch::kMIPS32_LE || arch == Arch::kMIPS32_BE; + } + + //! Tests whether the given architecture is MIPS64_LE or MIPS64_BE. + static ASMJIT_INLINE_NODEBUG bool isArchMIPS64(Arch arch) noexcept { + return arch == Arch::kMIPS64_LE || arch == Arch::kMIPS64_BE; + } + + //! Tests whether the given architecture family is X86 or X64. + static ASMJIT_INLINE_NODEBUG bool isFamilyX86(Arch arch) noexcept { + return arch == Arch::kX86 || arch == Arch::kX64; + } + + //! Tests whether the given architecture family is AArch32 (ARM or THUMB). + static ASMJIT_INLINE_NODEBUG bool isFamilyAArch32(Arch arch) noexcept { + return isArchARM(arch) || isArchThumb(arch); + } + + //! Tests whether the given architecture family is AArch64. + static ASMJIT_INLINE_NODEBUG bool isFamilyAArch64(Arch arch) noexcept { + return isArchAArch64(arch); + } + + //! Tests whether the given architecture family is ARM, THUMB, or AArch64. + static ASMJIT_INLINE_NODEBUG bool isFamilyARM(Arch arch) noexcept { + return isFamilyAArch32(arch) || isFamilyAArch64(arch); + } + + //! Tests whether the given architecture family is MIPS or MIPS64. + static ASMJIT_INLINE_NODEBUG bool isFamilyMIPS(Arch arch) noexcept { + return isArchMIPS32(arch) || isArchMIPS64(arch); + } + + //! Tests whether the given architecture family is RISC-V (both 32-bit and 64-bit). + static ASMJIT_INLINE_NODEBUG bool isFamilyRISCV(Arch arch) noexcept { + return arch == Arch::kRISCV32 || arch == Arch::kRISCV64; + } + + //! Returns a native general purpose register size from the given architecture. + static ASMJIT_INLINE_NODEBUG uint32_t registerSizeFromArch(Arch arch) noexcept { + return is32Bit(arch) ? 4u : 8u; + } + + //! \} +}; + +static_assert(sizeof(Environment) == 8, + "Environment must occupy exactly 8 bytes."); + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ENVIRONMENT_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/errorhandler.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/errorhandler.h new file mode 100644 index 0000000000000000000000000000000000000000..581b5e2d3ce2e00d7ae5e460d36ef127bed1b625 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/errorhandler.h @@ -0,0 +1,228 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ERRORHANDLER_H_INCLUDED +#define ASMJIT_CORE_ERRORHANDLER_H_INCLUDED + +#include "../core/globals.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_error_handling +//! \{ + +class BaseEmitter; + +//! Error handler can be used to override the default behavior of error handling. +//! +//! It's available to all classes that inherit `BaseEmitter`. Override \ref ErrorHandler::handleError() to implement +//! your own error handler. +//! +//! The following use-cases are supported: +//! +//! - Record the error and continue code generation. This is the simplest approach that can be used to at least log +//! possible errors. +//! - Throw an exception. AsmJit doesn't use exceptions and is completely exception-safe, but it's perfectly legal +//! to throw an exception from the error handler. +//! - Use plain old C's `setjmp()` and `longjmp()`. Asmjit always puts Assembler, Builder and Compiler to +//! a consistent state before calling \ref handleError(), so `longjmp()` can be used without issues to cancel the +//! code generation if an error occurred. This method can be used if exception handling in your project is turned +//! off and you still want some comfort. In most cases it should be safe as AsmJit uses \ref Zone memory and the +//! ownership of memory it allocates always ends with the instance that allocated it. If using this approach please +//! never jump outside the life-time of \ref CodeHolder and \ref BaseEmitter. +//! +//! \ref ErrorHandler can be attached to \ref CodeHolder or \ref BaseEmitter, which has a priority. The example below +//! uses error handler that just prints the error, but lets AsmJit continue: +//! +//! ``` +//! // Error Handling #1 - Logging and returning Error. +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Error handler that just prints the error and lets AsmJit ignore it. +//! class SimpleErrorHandler : public ErrorHandler { +//! public: +//! Error err; +//! +//! inline SimpleErrorHandler() : err(kErrorOk) {} +//! +//! void handleError(Error err, const char* message, BaseEmitter* origin) override { +//! this->err = err; +//! fprintf(stderr, "ERROR: %s\n", message); +//! } +//! }; +//! +//! int main() { +//! JitRuntime rt; +//! SimpleErrorHandler eh; +//! +//! CodeHolder code; +//! code.init(rt.environment(), rt.cpuFeatures()); +//! code.setErrorHandler(&eh); +//! +//! // Try to emit instruction that doesn't exist. +//! x86::Assembler a(&code); +//! a.emit(x86::Inst::kIdMov, x86::xmm0, x86::xmm1); +//! +//! if (eh.err) { +//! // Assembler failed! +//! return 1; +//! } +//! +//! return 0; +//! } +//! ``` +//! +//! If error happens during instruction emitting / encoding the assembler behaves transactionally - the output buffer +//! won't advance if encoding failed, thus either a fully encoded instruction or nothing is emitted. The error handling +//! shown above is useful, but it's still not the best way of dealing with errors in AsmJit. The following example +//! shows how to use exception handling to handle errors in a more C++ way: +//! +//! ``` +//! // Error Handling #2 - Throwing an exception. +//! #include +//! #include +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Error handler that throws a user-defined `AsmJitException`. +//! class AsmJitException : public std::exception { +//! public: +//! Error err; +//! std::string message; +//! +//! AsmJitException(Error err, const char* message) noexcept +//! : err(err), +//! message(message) {} +//! +//! const char* what() const noexcept override { return message.c_str(); } +//! }; +//! +//! class ThrowableErrorHandler : public ErrorHandler { +//! public: +//! // Throw is possible, functions that use ErrorHandler are never 'noexcept'. +//! void handleError(Error err, const char* message, BaseEmitter* origin) override { +//! throw AsmJitException(err, message); +//! } +//! }; +//! +//! int main() { +//! JitRuntime rt; +//! ThrowableErrorHandler eh; +//! +//! CodeHolder code; +//! code.init(rt.environment(), rt.cpuFeatures()); +//! code.setErrorHandler(&eh); +//! +//! x86::Assembler a(&code); +//! +//! // Try to emit instruction that doesn't exist. +//! try { +//! a.emit(x86::Inst::kIdMov, x86::xmm0, x86::xmm1); +//! } +//! catch (const AsmJitException& ex) { +//! printf("EXCEPTION THROWN: %s\n", ex.what()); +//! return 1; +//! } +//! +//! return 0; +//! } +//! ``` +//! +//! If C++ exceptions are not what you like or your project turns off them completely there is still a way of reducing +//! the error handling to a minimum by using a standard setjmp/longjmp approach. AsmJit is exception-safe and cleans +//! up everything before calling the ErrorHandler, so any approach is safe. You can simply jump from the error handler +//! without causing any side-effects or memory leaks. The following example demonstrates how it could be done: +//! +//! ``` +//! // Error Handling #3 - Using setjmp/longjmp if exceptions are not allowed. +//! #include +//! #include +//! #include +//! +//! class LongJmpErrorHandler : public asmjit::ErrorHandler { +//! public: +//! inline LongJmpErrorHandler() : err(asmjit::kErrorOk) {} +//! +//! void handleError(asmjit::Error err, const char* message, asmjit::BaseEmitter* origin) override { +//! this->err = err; +//! longjmp(state, 1); +//! } +//! +//! jmp_buf state; +//! asmjit::Error err; +//! }; +//! +//! int main(int argc, char* argv[]) { +//! using namespace asmjit; +//! +//! JitRuntime rt; +//! LongJmpErrorHandler eh; +//! +//! CodeHolder code; +//! code.init(rt.environment(), rt.cpuFeatures()); +//! code.setErrorHandler(&eh); +//! +//! x86::Assembler a(&code); +//! +//! if (!setjmp(eh.state)) { +//! // Try to emit instruction that doesn't exist. +//! a.emit(x86::Inst::kIdMov, x86::xmm0, x86::xmm1); +//! } +//! else { +//! Error err = eh.err; +//! printf("ASMJIT ERROR: 0x%08X [%s]\n", err, DebugUtils::errorAsString(err)); +//! } +//! +//! return 0; +//! } +//! ``` +class ASMJIT_VIRTAPI ErrorHandler { +public: + ASMJIT_BASE_CLASS(ErrorHandler) + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `ErrorHandler` instance. + ASMJIT_API ErrorHandler() noexcept; + //! Destroys the `ErrorHandler` instance. + ASMJIT_API virtual ~ErrorHandler() noexcept; + + //! \} + + //! \name Interface + //! \{ + + //! Error handler (must be reimplemented). + //! + //! Error handler is called after an error happened and before it's propagated to the caller. There are multiple + //! ways how the error handler can be used: + //! + //! 1. User-based error handling without throwing exception or using C's`longjmp()`. This is for users that don't + //! use exceptions and want customized error handling. + //! + //! 2. Throwing an exception. AsmJit doesn't use exceptions and is completely exception-safe, but you can throw + //! exception from your error handler if this way is the preferred way of handling errors in your project. + //! + //! 3. Using plain old C's `setjmp()` and `longjmp()`. Asmjit always puts `BaseEmitter` to a consistent state before + //! calling `handleError()` so `longjmp()` can be used without any issues to cancel the code generation if an + //! error occurred. There is no difference between exceptions and `longjmp()` from AsmJit's perspective, however, + //! never jump outside of `CodeHolder` and `BaseEmitter` scope as you would leak memory. + ASMJIT_API virtual void handleError(Error err, const char* message, BaseEmitter* origin); + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ERRORHANDLER_H_INCLUDED + diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/formatter.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/formatter.h new file mode 100644 index 0000000000000000000000000000000000000000..624fd691cd68be5c36cc2e5063a8c94af2ab9485 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/formatter.h @@ -0,0 +1,249 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_FORMATTER_H_INCLUDED +#define ASMJIT_CORE_FORMATTER_H_INCLUDED + +#include "../core/globals.h" +#include "../core/inst.h" +#include "../core/string.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_logging +//! \{ + +class BaseBuilder; +class BaseEmitter; +class BaseNode; +struct Operand_; + +//! Format flags used by \ref Logger and \ref FormatOptions. +enum class FormatFlags : uint32_t { + //! No formatting flags. + kNone = 0u, + + //! Show also binary form of each logged instruction (Assembler). + kMachineCode = 0x00000001u, + //! Show a text explanation of some immediate values. + kExplainImms = 0x00000002u, + //! Use hexadecimal notation of immediate values. + kHexImms = 0x00000004u, + //! Use hexadecimal notation of addresses and offsets in addresses. + kHexOffsets = 0x00000008u, + //! Show casts between virtual register types (Compiler output). + kRegCasts = 0x00000010u, + //! Show positions associated with nodes (Compiler output). + kPositions = 0x00000020u, + //! Always format a register type (Compiler output). + kRegType = 0x00000040u +}; +ASMJIT_DEFINE_ENUM_FLAGS(FormatFlags) + +//! Format indentation group, used by \ref FormatOptions. +enum class FormatIndentationGroup : uint32_t { + //! Indentation used for instructions and directives. + kCode = 0u, + //! Indentation used for labels and function nodes. + kLabel = 1u, + //! Indentation used for comments (not inline comments). + kComment = 2u, + + //! \cond INTERNAL + //! Reserved for future use. + kReserved = 3u, + //! \endcond + + //! Maximum value of `FormatIndentationGroup`. + kMaxValue = kReserved +}; + +//! Format padding group, used by \ref FormatOptions. +enum class FormatPaddingGroup : uint32_t { + //! Describes padding of a regular line, which can represent instruction, data, or assembler directives. + kRegularLine = 0, + //! Describes padding of machine code dump that is visible next to the instruction, if enabled. + kMachineCode = 1, + + //! Maximum value of `FormatPaddingGroup`. + kMaxValue = kMachineCode +}; + +//! Formatting options used by \ref Logger and \ref Formatter. +class FormatOptions { +public: + //! \name Members + //! \{ + + //! Format flags. + FormatFlags _flags = FormatFlags::kNone; + //! Indentations for each indentation group. + Support::Array _indentation {}; + //! Paddings for each padding group. + Support::Array _padding {}; + + //! \} + + //! \name Reset + //! \{ + + //! Resets FormatOptions to its default initialized state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _flags = FormatFlags::kNone; + _indentation.fill(uint8_t(0)); + _padding.fill(uint16_t(0)); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns format flags. + ASMJIT_INLINE_NODEBUG FormatFlags flags() const noexcept { return _flags; } + //! Tests whether the given `flag` is set in format flags. + ASMJIT_INLINE_NODEBUG bool hasFlag(FormatFlags flag) const noexcept { return Support::test(_flags, flag); } + + //! Resets all format flags to `flags`. + ASMJIT_INLINE_NODEBUG void setFlags(FormatFlags flags) noexcept { _flags = flags; } + //! Adds `flags` to format flags. + ASMJIT_INLINE_NODEBUG void addFlags(FormatFlags flags) noexcept { _flags |= flags; } + //! Removes `flags` from format flags. + ASMJIT_INLINE_NODEBUG void clearFlags(FormatFlags flags) noexcept { _flags &= ~flags; } + + //! Returns indentation for the given indentation `group`. + ASMJIT_INLINE_NODEBUG uint8_t indentation(FormatIndentationGroup group) const noexcept { return _indentation[group]; } + //! Sets indentation for the given indentation `group`. + ASMJIT_INLINE_NODEBUG void setIndentation(FormatIndentationGroup group, uint32_t n) noexcept { _indentation[group] = uint8_t(n); } + //! Resets indentation for the given indentation `group` to zero. + ASMJIT_INLINE_NODEBUG void resetIndentation(FormatIndentationGroup group) noexcept { _indentation[group] = uint8_t(0); } + + //! Returns padding for the given padding `group`. + ASMJIT_INLINE_NODEBUG size_t padding(FormatPaddingGroup group) const noexcept { return _padding[group]; } + //! Sets padding for the given padding `group`. + ASMJIT_INLINE_NODEBUG void setPadding(FormatPaddingGroup group, size_t n) noexcept { _padding[group] = uint16_t(n); } + //! Resets padding for the given padding `group` to zero, which means that a default padding will be used + //! based on the target architecture properties. + ASMJIT_INLINE_NODEBUG void resetPadding(FormatPaddingGroup group) noexcept { _padding[group] = uint16_t(0); } + + //! \} +}; + +//! Provides formatting functionality to format operands, instructions, and nodes. +namespace Formatter { + +#ifndef ASMJIT_NO_LOGGING + +//! Appends a formatted `typeId` to the output string `sb`. +ASMJIT_API Error formatTypeId( + String& sb, + TypeId typeId) noexcept; + +//! Appends a formatted `featureId` to the output string `sb`. +//! +//! See \ref CpuFeatures. +ASMJIT_API Error formatFeature( + String& sb, + Arch arch, + uint32_t featureId) noexcept; + +//! Appends a formatted register to the output string `sb`. +//! +//! \note Emitter is optional, but it's required to format virtual registers, which won't be formatted properly +//! if the `emitter` is not provided. +ASMJIT_API Error formatRegister( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + RegType regType, + uint32_t regId) noexcept; + +//! Appends a formatted label to the output string `sb`. +//! +//! \note Emitter is optional, but it's required to format named labels properly, otherwise the formatted as +//! it is an anonymous label. +ASMJIT_API Error formatLabel( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + uint32_t labelId) noexcept; + +//! Appends a formatted operand to the output string `sb`. +//! +//! \note Emitter is optional, but it's required to format named labels and virtual registers. See +//! \ref formatRegister() and \ref formatLabel() for more details. +ASMJIT_API Error formatOperand( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + const Operand_& op) noexcept; + +//! Appends a formatted data-type to the output string `sb`. +ASMJIT_API Error formatDataType( + String& sb, + FormatFlags formatFlags, + Arch arch, + TypeId typeId) noexcept; + +//! Appends a formatted data to the output string `sb`. +ASMJIT_API Error formatData( + String& sb, + FormatFlags formatFlags, + Arch arch, + TypeId typeId, const void* data, size_t itemCount, size_t repeatCount = 1) noexcept; + +//! Appends a formatted instruction to the output string `sb`. +//! +//! \note Emitter is optional, but it's required to format named labels and virtual registers. See +//! \ref formatRegister() and \ref formatLabel() for more details. +ASMJIT_API Error formatInstruction( + String& sb, + FormatFlags formatFlags, + const BaseEmitter* emitter, + Arch arch, + const BaseInst& inst, const Operand_* operands, size_t opCount) noexcept; + +#ifndef ASMJIT_NO_BUILDER +//! Appends a formatted node to the output string `sb`. +//! +//! The `node` must belong to the provided `builder`. +ASMJIT_API Error formatNode( + String& sb, + const FormatOptions& formatOptions, + const BaseBuilder* builder, + const BaseNode* node) noexcept; + +//! Appends formatted nodes to the output string `sb`. +//! +//! All nodes that are part of the given `builder` will be appended. +ASMJIT_API Error formatNodeList( + String& sb, + const FormatOptions& formatOptions, + const BaseBuilder* builder) noexcept; + +//! Appends formatted nodes to the output string `sb`. +//! +//! This function works the same as \ref formatNode(), but appends more nodes to the output string, +//! separating each node with a newline '\n' character. +ASMJIT_API Error formatNodeList( + String& sb, + const FormatOptions& formatOptions, + const BaseBuilder* builder, + const BaseNode* begin, + const BaseNode* end) noexcept; +#endif + +#endif + +} // {Formatter} + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_FORMATTER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/func.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/func.h new file mode 100644 index 0000000000000000000000000000000000000000..68159b7bf7cc09b90724be44fa82da797b7fe3ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/func.h @@ -0,0 +1,1595 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_FUNC_H_INCLUDED +#define ASMJIT_CORE_FUNC_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/environment.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_function +//! \{ + +//! Calling convention id. +//! +//! Calling conventions can be divided into the following groups: +//! +//! - Universal - calling conventions are applicable to any target. They will be converted to a target dependent +//! calling convention at runtime by \ref CallConv::init() with some help from \ref Environment. The purpose of +//! these calling conventions is to make using functions less target dependent and closer to C and C++. +//! +//! - Target specific - calling conventions that are used by a particular architecture and ABI. For example +//! Windows 64-bit calling convention and AMD64 SystemV calling convention. +enum class CallConvId : uint8_t { + // Universal Calling Conventions + // ----------------------------- + + //! Standard function call or explicit `__cdecl` where it can be specified. + //! + //! This is a universal calling convention, which is used to initialize specific calling conventions based on + //! architecture, platform, and its ABI. + kCDecl = 0, + + //! `__stdcall` on targets that support this calling convention (X86). + //! + //! \note This calling convention is only supported on 32-bit X86. If used on environment that doesn't support + //! this calling convention it will be replaced by \ref CallConvId::kCDecl. + kStdCall = 1, + + //! `__fastcall` on targets that support this calling convention (X86). + //! + //! \note This calling convention is only supported on 32-bit X86. If used on environment that doesn't support + //! this calling convention it will be replaced by \ref CallConvId::kCDecl. + kFastCall = 2, + + //! `__vectorcall` on targets that support this calling convention (X86/X64). + //! + //! \note This calling convention is only supported on 32-bit and 64-bit X86 architecture on Windows platform. + //! If used on environment that doesn't support this calling it will be replaced by \ref CallConvId::kCDecl. + kVectorCall = 3, + + //! `__thiscall` on targets that support this calling convention (X86). + //! + //! \note This calling convention is only supported on 32-bit X86 Windows platform. If used on environment that + //! doesn't support this calling convention it will be replaced by \ref CallConvId::kCDecl. + kThisCall = 4, + + //! `__attribute__((regparm(1)))` convention (GCC and Clang). + kRegParm1 = 5, + //! `__attribute__((regparm(2)))` convention (GCC and Clang). + kRegParm2 = 6, + //! `__attribute__((regparm(3)))` convention (GCC and Clang). + kRegParm3 = 7, + + //! AsmJit specific calling convention designed for calling functions inside a multimedia code that don't use many + //! registers internally, but are long enough to be called and not inlined. These functions are usually used to + //! calculate trigonometric functions, logarithms, etc... + kLightCall2 = 16, + kLightCall3 = 17, + kLightCall4 = 18, + + // ABI-Specific Calling Conventions + // -------------------------------- + + //! Soft-float calling convention (AArch32). + //! + //! Floating point arguments are passed via general purpose registers. + kSoftFloat = 30, + + //! Hard-float calling convention (AArch32). + //! + //! Floating point arguments are passed via SIMD registers. + kHardFloat = 31, + + //! X64 System-V calling convention. + kX64SystemV = 32, + //! X64 Windows calling convention. + kX64Windows = 33, + + //! Maximum value of `CallConvId`. + kMaxValue = kX64Windows + + // Deprecated Aliases + // ------------------ + +#if !defined(ASMJIT_NO_DEPRECATED) + , + kNone = kCDecl, + kHost = kCDecl +#endif // !ASMJIT_NO_DEPRECATED +}; + +//! Strategy used by calling conventions to assign registers to function arguments. +//! +//! Calling convention strategy describes how AsmJit should convert function arguments used by \ref FuncSignature +//! into register identifiers and stack offsets. The \ref CallConvStrategy::kDefault strategy assigns registers +//! and then stack whereas \ref CallConvStrategy::kX64Windows strategy does register shadowing as defined by WIN64 +//! calling convention, which is only used by 64-bit Windows. +enum class CallConvStrategy : uint8_t { + //! Default register assignment strategy. + kDefault = 0, + //! Windows 64-bit ABI register assignment strategy. + kX64Windows = 1, + //! Windows 64-bit __vectorcall register assignment strategy. + kX64VectorCall = 2, + //! Apple's AArch64 calling convention (differs compared to AArch64 calling convention used by Linux). + kAArch64Apple = 3, + + //! Maximum value of `CallConvStrategy`. + kMaxValue = kX64VectorCall +}; + +//! Calling convention flags. +enum class CallConvFlags : uint32_t { + //! No flags. + kNone = 0, + //! Callee is responsible for cleaning up the stack. + kCalleePopsStack = 0x0001u, + //! Pass vector arguments indirectly (as a pointer). + kIndirectVecArgs = 0x0002u, + //! Pass F32 and F64 arguments via VEC128 register. + kPassFloatsByVec = 0x0004u, + //! Pass MMX and vector arguments via stack if the function has variable arguments. + kPassVecByStackIfVA = 0x0008u, + //! MMX registers are passed and returned via GP registers. + kPassMmxByGp = 0x0010u, + //! MMX registers are passed and returned via XMM registers. + kPassMmxByXmm = 0x0020u, + //! Calling convention can be used with variable arguments. + kVarArgCompatible = 0x0080u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CallConvFlags) + +//! Function calling convention. +//! +//! Function calling convention is a scheme that defines how function parameters are passed and how function +//! returns its result. AsmJit defines a variety of architecture and OS specific calling conventions and also +//! provides a compile time detection to make the code-generation easier. +struct CallConv { + //! \name Constants + //! \{ + + //! Maximum number of register arguments per register group. + //! + //! \note This is not really AsmJit's limitation, it's just the number that makes sense considering all common + //! calling conventions. Usually even conventions that use registers to pass function arguments are limited to 8 + //! and less arguments passed via registers per group. + static constexpr uint32_t kMaxRegArgsPerGroup = 16; + + //! \} + + //! \name Members + //! \{ + + //! Target architecture. + Arch _arch; + //! Calling convention id. + CallConvId _id; + //! Register assignment strategy. + CallConvStrategy _strategy; + + //! Red zone size (AMD64 == 128 bytes). + uint8_t _redZoneSize; + //! Spill zone size (WIN-X64 == 32 bytes). + uint8_t _spillZoneSize; + //! Natural stack alignment as defined by OS/ABI. + uint8_t _naturalStackAlignment; + + //! \cond INTERNAL + //! Reserved for future use. + uint8_t _reserved[2]; + //! \endcond + + //! Calling convention flags. + CallConvFlags _flags; + + //! Size to save/restore per register group. + Support::Array _saveRestoreRegSize; + //! Alignment of save/restore groups. + Support::Array _saveRestoreAlignment; + + //! Mask of all passed registers, per group. + Support::Array _passedRegs; + //! Mask of all preserved registers, per group. + Support::Array _preservedRegs; + + //! Passed registers' order. + union RegOrder { + //! Passed registers, ordered. + uint8_t id[kMaxRegArgsPerGroup]; + //! Packed IDs in `uint32_t` array. + uint32_t packed[(kMaxRegArgsPerGroup + 3) / 4]; + }; + + //! Passed registers' order, per register group. + Support::Array _passedOrder; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Initializes this calling convention to the given `ccId` based on the `environment`. + //! + //! See \ref CallConvId and \ref Environment for more details. + ASMJIT_API Error init(CallConvId ccId, const Environment& environment) noexcept; + + //! Resets this CallConv struct into a defined state. + //! + //! It's recommended to reset the \ref CallConv struct in case you would like create a custom calling convention + //! as it prevents from using an uninitialized data (CallConv doesn't have a constructor that would initialize it, + //! it's just a struct). + ASMJIT_INLINE_NODEBUG void reset() noexcept { + *this = CallConv{}; + memset(_passedOrder.data(), 0xFF, sizeof(_passedOrder)); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the target architecture of this calling convention. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + //! Sets the target architecture of this calling convention. + ASMJIT_INLINE_NODEBUG void setArch(Arch arch) noexcept { _arch = arch; } + + //! Returns the calling convention id. + ASMJIT_INLINE_NODEBUG CallConvId id() const noexcept { return _id; } + //! Sets the calling convention id. + ASMJIT_INLINE_NODEBUG void setId(CallConvId ccId) noexcept { _id = ccId; } + + //! Returns the strategy used to assign registers to arguments. + ASMJIT_INLINE_NODEBUG CallConvStrategy strategy() const noexcept { return _strategy; } + //! Sets the strategy used to assign registers to arguments. + ASMJIT_INLINE_NODEBUG void setStrategy(CallConvStrategy ccStrategy) noexcept { _strategy = ccStrategy; } + + //! Tests whether the calling convention has the given `flag` set. + ASMJIT_INLINE_NODEBUG bool hasFlag(CallConvFlags flag) const noexcept { return Support::test(_flags, flag); } + //! Returns the calling convention flags, see `Flags`. + ASMJIT_INLINE_NODEBUG CallConvFlags flags() const noexcept { return _flags; } + //! Adds the calling convention flags, see `Flags`. + ASMJIT_INLINE_NODEBUG void setFlags(CallConvFlags flag) noexcept { _flags = flag; }; + //! Adds the calling convention flags, see `Flags`. + ASMJIT_INLINE_NODEBUG void addFlags(CallConvFlags flags) noexcept { _flags |= flags; }; + + //! Tests whether this calling convention specifies 'RedZone'. + ASMJIT_INLINE_NODEBUG bool hasRedZone() const noexcept { return _redZoneSize != 0; } + //! Tests whether this calling convention specifies 'SpillZone'. + ASMJIT_INLINE_NODEBUG bool hasSpillZone() const noexcept { return _spillZoneSize != 0; } + + //! Returns size of 'RedZone'. + ASMJIT_INLINE_NODEBUG uint32_t redZoneSize() const noexcept { return _redZoneSize; } + //! Returns size of 'SpillZone'. + ASMJIT_INLINE_NODEBUG uint32_t spillZoneSize() const noexcept { return _spillZoneSize; } + + //! Sets size of 'RedZone'. + ASMJIT_INLINE_NODEBUG void setRedZoneSize(uint32_t size) noexcept { _redZoneSize = uint8_t(size); } + //! Sets size of 'SpillZone'. + ASMJIT_INLINE_NODEBUG void setSpillZoneSize(uint32_t size) noexcept { _spillZoneSize = uint8_t(size); } + + //! Returns a natural stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t naturalStackAlignment() const noexcept { return _naturalStackAlignment; } + //! Sets a natural stack alignment. + //! + //! This function can be used to override the default stack alignment in case that you know that it's alignment is + //! different. For example it allows to implement custom calling conventions that guarantee higher stack alignment. + ASMJIT_INLINE_NODEBUG void setNaturalStackAlignment(uint32_t value) noexcept { _naturalStackAlignment = uint8_t(value); } + + //! Returns the size of a register (or its part) to be saved and restored of the given `group`. + ASMJIT_INLINE_NODEBUG uint32_t saveRestoreRegSize(RegGroup group) const noexcept { return _saveRestoreRegSize[group]; } + //! Sets the size of a vector register (or its part) to be saved and restored. + ASMJIT_INLINE_NODEBUG void setSaveRestoreRegSize(RegGroup group, uint32_t size) noexcept { _saveRestoreRegSize[group] = uint8_t(size); } + + //! Returns the alignment of a save-restore area of the given `group`. + ASMJIT_INLINE_NODEBUG uint32_t saveRestoreAlignment(RegGroup group) const noexcept { return _saveRestoreAlignment[group]; } + //! Sets the alignment of a save-restore area of the given `group`. + ASMJIT_INLINE_NODEBUG void setSaveRestoreAlignment(RegGroup group, uint32_t alignment) noexcept { _saveRestoreAlignment[group] = uint8_t(alignment); } + + //! Returns the order of passed registers of the given `group`. + inline const uint8_t* passedOrder(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _passedOrder[size_t(group)].id; + } + + //! Returns the mask of passed registers of the given `group`. + inline RegMask passedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _passedRegs[size_t(group)]; + } + + inline void _setPassedPacked(RegGroup group, uint32_t p0, uint32_t p1, uint32_t p2, uint32_t p3) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + + _passedOrder[group].packed[0] = p0; + _passedOrder[group].packed[1] = p1; + _passedOrder[group].packed[2] = p2; + _passedOrder[group].packed[3] = p3; + } + + //! Resets the order and mask of passed registers. + inline void setPassedToNone(RegGroup group) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + + _setPassedPacked(group, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu, 0xFFFFFFFFu); + _passedRegs[size_t(group)] = 0u; + } + + //! Sets the order and mask of passed registers. + inline void setPassedOrder(RegGroup group, uint32_t a0, uint32_t a1 = 0xFF, uint32_t a2 = 0xFF, uint32_t a3 = 0xFF, uint32_t a4 = 0xFF, uint32_t a5 = 0xFF, uint32_t a6 = 0xFF, uint32_t a7 = 0xFF) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + + // NOTE: This should always be called with all arguments known at compile time, so even if it looks scary it + // should be translated into few instructions. + _setPassedPacked(group, Support::bytepack32_4x8(a0, a1, a2, a3), + Support::bytepack32_4x8(a4, a5, a6, a7), + 0xFFFFFFFFu, + 0xFFFFFFFFu); + + _passedRegs[group] = (a0 != 0xFF ? 1u << a0 : 0u) | + (a1 != 0xFF ? 1u << a1 : 0u) | + (a2 != 0xFF ? 1u << a2 : 0u) | + (a3 != 0xFF ? 1u << a3 : 0u) | + (a4 != 0xFF ? 1u << a4 : 0u) | + (a5 != 0xFF ? 1u << a5 : 0u) | + (a6 != 0xFF ? 1u << a6 : 0u) | + (a7 != 0xFF ? 1u << a7 : 0u) ; + } + + //! Returns preserved register mask of the given `group`. + inline RegMask preservedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _preservedRegs[group]; + } + + //! Sets preserved register mask of the given `group`. + inline void setPreservedRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _preservedRegs[group] = regs; + } + + //! \} +}; + +//! Function signature. +//! +//! Contains information about a function return type, count of arguments, and their TypeIds. Function signature +//! is a low level structure which doesn't contain platform specific or calling convention specific information. +//! It's typically used to describe function arguments in a C-API like form, which is then used to calculate a +//! \ref FuncDetail instance, which then maps function signature into a platform and calling convention specific +//! format. +//! +//! Function signature can be built either dynamically by using \ref addArg() and \ref addArgT() functionality, +//! or dynamically by using a template-based \ref FuncSignature::build() function, which maps template types +//! into a function signature. +struct FuncSignature { + //! \name Constants + //! \{ + + //! Doesn't have variable number of arguments (`...`). + static constexpr uint8_t kNoVarArgs = 0xFFu; + + //! \} + + //! \name Members + //! \{ + + //! Calling convention id. + CallConvId _ccId = CallConvId::kCDecl; + //! Count of arguments. + uint8_t _argCount = 0; + //! Index of a first VA or `kNoVarArgs`. + uint8_t _vaIndex = kNoVarArgs; + //! Return value TypeId. + TypeId _ret = TypeId::kVoid; + //! Reserved for future use. + uint8_t _reserved[4] {}; + //! Function argument TypeIds. + TypeId _args[Globals::kMaxFuncArgs] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Default constructed function signature, initialized to \ref CallConvId::kCDecl, having no return value and no arguments. + ASMJIT_FORCE_INLINE constexpr FuncSignature() = default; + + //! Copy constructor, which is initialized to the same function signature as `other`. + ASMJIT_FORCE_INLINE constexpr FuncSignature(const FuncSignature& other) = default; + + //! Initializes the function signature with calling convention id `ccId` and variable argument's index `vaIndex`. + ASMJIT_FORCE_INLINE constexpr FuncSignature(CallConvId ccId, uint32_t vaIndex = kNoVarArgs) noexcept + : _ccId(ccId), + _vaIndex(uint8_t(vaIndex)) {} + + //! Initializes the function signature with calling convention id `ccId`, `vaIndex`, return value, and function arguments. + template + ASMJIT_FORCE_INLINE constexpr FuncSignature(CallConvId ccId, uint32_t vaIndex, TypeId ret, Args&&...args) noexcept + : _ccId(ccId), + _argCount(uint8_t(sizeof...(args))), + _vaIndex(uint8_t(vaIndex)), + _ret(ret), + _args{std::forward(args)...} {} + + //! Builds a function signature based on `RetValueAndArgs`. The first template argument is a function return type, + //! and function arguments follow. + //! + //! \note This function returns a new function signature, which can be passed to functions where it's required. It's + //! a convenience function that allows to build function signature statically based on types known at compile time, + //! which is common in JIT code generation. + template + static ASMJIT_INLINE_NODEBUG constexpr FuncSignature build(CallConvId ccId = CallConvId::kCDecl, uint32_t vaIndex = kNoVarArgs) noexcept { + return FuncSignature(ccId, vaIndex, (TypeId(TypeUtils::TypeIdOfT::kTypeId))... ); + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Copy assignment - function signature can be copied by value. + ASMJIT_FORCE_INLINE FuncSignature& operator=(const FuncSignature& other) noexcept = default; + + //! Compares this function signature with `other` for equality.. + ASMJIT_FORCE_INLINE bool operator==(const FuncSignature& other) const noexcept { return equals(other); } + //! Compares this function signature with `other` for inequality.. + ASMJIT_FORCE_INLINE bool operator!=(const FuncSignature& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! Resets this function signature to a default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = FuncSignature{}; } + + //! \} + + //! \name Equality & Comparison + //! \{ + + //! Compares this function signature with `other` for equality.. + ASMJIT_INLINE_NODEBUG bool equals(const FuncSignature& other) const noexcept { + return _ccId == other._ccId && + _argCount == other._argCount && + _vaIndex == other._vaIndex && + _ret == other._ret && + memcmp(_args, other._args, sizeof(_args)) == 0; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the calling convention. + ASMJIT_INLINE_NODEBUG CallConvId callConvId() const noexcept { return _ccId; } + //! Sets the calling convention to `ccId`; + ASMJIT_INLINE_NODEBUG void setCallConvId(CallConvId ccId) noexcept { _ccId = ccId; } + + //! Tests whether the function signature has a return value. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return _ret != TypeId::kVoid; } + //! Returns the type of the return value. + ASMJIT_INLINE_NODEBUG TypeId ret() const noexcept { return _ret; } + //! Sets the return type to `retType`. + ASMJIT_INLINE_NODEBUG void setRet(TypeId retType) noexcept { _ret = retType; } + //! Sets the return type based on `T`. + template + ASMJIT_INLINE_NODEBUG void setRetT() noexcept { setRet(TypeId(TypeUtils::TypeIdOfT::kTypeId)); } + + + //! Returns the array of function arguments' types. + ASMJIT_INLINE_NODEBUG const TypeId* args() const noexcept { return _args; } + //! Returns the number of function arguments. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _argCount; } + + //! Returns the type of the argument at index `i`. + inline TypeId arg(uint32_t i) const noexcept { + ASMJIT_ASSERT(i < _argCount); + return _args[i]; + } + + //! Sets the argument at index `index` to `argType`. + inline void setArg(uint32_t index, TypeId argType) noexcept { + ASMJIT_ASSERT(index < _argCount); + _args[index] = argType; + } + //! Sets the argument at index `i` to the type based on `T`. + template + inline void setArgT(uint32_t index) noexcept { setArg(index, TypeId(TypeUtils::TypeIdOfT::kTypeId)); } + + //! Tests whether an argument can be added to the signature, use before calling \ref addArg() and \ref addArgT(). + //! + //! \note If you know that you are not adding more arguments than \ref Globals::kMaxFuncArgs then it's not necessary + //! to use this function. However, if you are adding arguments based on user input, for example, then either check + //! the number of arguments before using function signature or use \ref canAddArg() before actually adding them to + //! the function signature. + inline bool canAddArg() const noexcept { return _argCount < Globals::kMaxFuncArgs; } + + //! Appends an argument of `type` to the function prototype. + inline void addArg(TypeId type) noexcept { + ASMJIT_ASSERT(_argCount < Globals::kMaxFuncArgs); + _args[_argCount++] = type; + } + + //! Appends an argument of type based on `T` to the function prototype. + template + inline void addArgT() noexcept { addArg(TypeId(TypeUtils::TypeIdOfT::kTypeId)); } + + //! Tests whether the function has variable number of arguments (...). + ASMJIT_INLINE_NODEBUG bool hasVarArgs() const noexcept { return _vaIndex != kNoVarArgs; } + //! Returns the variable arguments (...) index, `kNoVarArgs` if none. + ASMJIT_INLINE_NODEBUG uint32_t vaIndex() const noexcept { return _vaIndex; } + //! Sets the variable arguments (...) index to `index`. + ASMJIT_INLINE_NODEBUG void setVaIndex(uint32_t index) noexcept { _vaIndex = uint8_t(index); } + //! Resets the variable arguments index (making it a non-va function). + ASMJIT_INLINE_NODEBUG void resetVaIndex() noexcept { _vaIndex = kNoVarArgs; } + + //! \} +}; + +#if !defined(ASMJIT_NO_DEPRECATED) +template +class FuncSignatureT : public FuncSignature { +public: + ASMJIT_DEPRECATED("Use FuncSignature::build() instead") + ASMJIT_INLINE_NODEBUG constexpr FuncSignatureT(CallConvId ccId = CallConvId::kCDecl, uint32_t vaIndex = kNoVarArgs) noexcept + : FuncSignature(ccId, vaIndex, (TypeId(TypeUtils::TypeIdOfT::kTypeId))... ) {} +}; + +ASMJIT_DEPRECATED("Use FuncSignature instead of FuncSignatureBuilder") +typedef FuncSignature FuncSignatureBuilder; +#endif // !ASMJIT_NO_DEPRECATED + +//! Argument or return value (or its part) as defined by `FuncSignature`, but with register or stack address +//! (and other metadata) assigned. +struct FuncValue { + //! \name Constants + //! \{ + + enum Bits : uint32_t { + kTypeIdShift = 0, //!< TypeId shift. + kTypeIdMask = 0x000000FFu, //!< TypeId mask. + + kFlagIsReg = 0x00000100u, //!< Passed by register. + kFlagIsStack = 0x00000200u, //!< Passed by stack. + kFlagIsIndirect = 0x00000400u, //!< Passed indirectly by reference (internally a pointer). + kFlagIsDone = 0x00000800u, //!< Used internally by arguments allocator. + + kStackOffsetShift = 12, //!< Stack offset shift. + kStackOffsetMask = 0xFFFFF000u, //!< Stack offset mask (must occupy MSB bits). + + kRegIdShift = 16, //!< RegId shift. + kRegIdMask = 0x00FF0000u, //!< RegId mask. + + kRegTypeShift = 24, //!< RegType shift. + kRegTypeMask = 0xFF000000u //!< RegType mask. + }; + + //! \} + + //! \name Members + //! \{ + + uint32_t _data; + + //! \} + + //! \name Initialization & Reset + //! + //! These initialize the whole `FuncValue` to either register or stack. Useful when you know all of these + //! properties and wanna just set it up. + //! + //! \{ + + //! Initializes this `FuncValue` only to the `typeId` provided - the rest of the values will be cleared. + ASMJIT_INLINE_NODEBUG void initTypeId(TypeId typeId) noexcept { + _data = uint32_t(typeId) << kTypeIdShift; + } + + //! Initializes this `FuncValue` to a register of `regType`, `regId`, and assigns its `typeId` and `flags`. + ASMJIT_INLINE_NODEBUG void initReg(RegType regType, uint32_t regId, TypeId typeId, uint32_t flags = 0) noexcept { + _data = (uint32_t(regType) << kRegTypeShift) | (regId << kRegIdShift) | (uint32_t(typeId) << kTypeIdShift) | kFlagIsReg | flags; + } + + //! Initializes this `FuncValue` to a stack at the given `offset` and assigns its `typeId`. + ASMJIT_INLINE_NODEBUG void initStack(int32_t offset, TypeId typeId) noexcept { + _data = (uint32_t(offset) << kStackOffsetShift) | (uint32_t(typeId) << kTypeIdShift) | kFlagIsStack; + } + + //! Resets the value to its unassigned state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { _data = 0; } + + //! \} + + //! \name Assign + //! + //! These initialize only part of `FuncValue`, useful when building `FuncValue` incrementally. The caller + //! should first init the type-id by calling `initTypeId` and then continue building either register or stack. + //! + //! \{ + + //! Assigns a register of `regType` and `regId`. + inline void assignRegData(RegType regType, uint32_t regId) noexcept { + ASMJIT_ASSERT((_data & (kRegTypeMask | kRegIdMask)) == 0); + _data |= (uint32_t(regType) << kRegTypeShift) | (regId << kRegIdShift) | kFlagIsReg; + } + + //! Assigns a stack location at `offset`. + inline void assignStackOffset(int32_t offset) noexcept { + ASMJIT_ASSERT((_data & kStackOffsetMask) == 0); + _data |= (uint32_t(offset) << kStackOffsetShift) | kFlagIsStack; + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns true if the value is initialized (explicit bool cast). + ASMJIT_INLINE_NODEBUG explicit operator bool() const noexcept { return _data != 0; } + + //! \cond INTERNAL + ASMJIT_INLINE_NODEBUG void _replaceValue(uint32_t mask, uint32_t value) noexcept { _data = (_data & ~mask) | value; } + //! \endcond + + //! Tests whether the `FuncValue` has a flag `flag` set. + ASMJIT_INLINE_NODEBUG bool hasFlag(uint32_t flag) const noexcept { return Support::test(_data, flag); } + //! Adds `flags` to `FuncValue`. + ASMJIT_INLINE_NODEBUG void addFlags(uint32_t flags) noexcept { _data |= flags; } + //! Clears `flags` of `FuncValue`. + ASMJIT_INLINE_NODEBUG void clearFlags(uint32_t flags) noexcept { _data &= ~flags; } + + //! Tests whether the value is initialized (i.e. contains a valid data). + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _data != 0; } + //! Tests whether the argument is passed by register. + ASMJIT_INLINE_NODEBUG bool isReg() const noexcept { return hasFlag(kFlagIsReg); } + //! Tests whether the argument is passed by stack. + ASMJIT_INLINE_NODEBUG bool isStack() const noexcept { return hasFlag(kFlagIsStack); } + //! Tests whether the argument is passed by register. + ASMJIT_INLINE_NODEBUG bool isAssigned() const noexcept { return hasFlag(kFlagIsReg | kFlagIsStack); } + //! Tests whether the argument is passed through a pointer (used by WIN64 to pass XMM|YMM|ZMM). + ASMJIT_INLINE_NODEBUG bool isIndirect() const noexcept { return hasFlag(kFlagIsIndirect); } + + //! Tests whether the argument was already processed (used internally). + ASMJIT_INLINE_NODEBUG bool isDone() const noexcept { return hasFlag(kFlagIsDone); } + + //! Returns a register type of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG RegType regType() const noexcept { return RegType((_data & kRegTypeMask) >> kRegTypeShift); } + //! Sets a register type of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG void setRegType(RegType regType) noexcept { _replaceValue(kRegTypeMask, uint32_t(regType) << kRegTypeShift); } + + //! Returns a physical id of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG uint32_t regId() const noexcept { return (_data & kRegIdMask) >> kRegIdShift; } + //! Sets a physical id of the register used to pass function argument or return value. + ASMJIT_INLINE_NODEBUG void setRegId(uint32_t regId) noexcept { _replaceValue(kRegIdMask, regId << kRegIdShift); } + + //! Returns a stack offset of this argument. + ASMJIT_INLINE_NODEBUG int32_t stackOffset() const noexcept { return int32_t(_data & kStackOffsetMask) >> kStackOffsetShift; } + //! Sets a stack offset of this argument. + ASMJIT_INLINE_NODEBUG void setStackOffset(int32_t offset) noexcept { _replaceValue(kStackOffsetMask, uint32_t(offset) << kStackOffsetShift); } + + //! Tests whether the argument or return value has associated `TypeId`. + ASMJIT_INLINE_NODEBUG bool hasTypeId() const noexcept { return Support::test(_data, kTypeIdMask); } + //! Returns a TypeId of this argument or return value. + ASMJIT_INLINE_NODEBUG TypeId typeId() const noexcept { return TypeId((_data & kTypeIdMask) >> kTypeIdShift); } + //! Sets a TypeId of this argument or return value. + ASMJIT_INLINE_NODEBUG void setTypeId(TypeId typeId) noexcept { _replaceValue(kTypeIdMask, uint32_t(typeId) << kTypeIdShift); } + + //! \} +}; + +//! Contains multiple `FuncValue` instances in an array so functions that use multiple registers for arguments or +//! return values can represent all inputs and outputs. +struct FuncValuePack { +public: + //! \name Members + //! \{ + + //! Values of the pack. + FuncValue _values[Globals::kMaxValuePack]; + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! Resets all values in the pack. + inline void reset() noexcept { + for (size_t i = 0; i < Globals::kMaxValuePack; i++) + _values[i].reset(); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Calculates how many values are in the pack, checking for non-values from the end. + inline uint32_t count() const noexcept { + uint32_t n = Globals::kMaxValuePack; + while (n && !_values[n - 1]) + n--; + return n; + } + + //! Returns values in this value in the pack. + //! + //! \note The returned array has exactly \ref Globals::kMaxValuePack elements. + ASMJIT_INLINE_NODEBUG FuncValue* values() noexcept { return _values; } + //! \overload + ASMJIT_INLINE_NODEBUG const FuncValue* values() const noexcept { return _values; } + + //! Resets a value at the given `index` in the pack, which makes it unassigned. + inline void resetValue(size_t index) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + _values[index].reset(); + } + + //! Tests whether the value at the given `index` in the pack is assigned. + inline bool hasValue(size_t index) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + return _values[index].isInitialized(); + } + + //! Assigns a register at the given `index` to `reg` and an optional `typeId`. + inline void assignReg(size_t index, const BaseReg& reg, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + ASMJIT_ASSERT(reg.isPhysReg()); + _values[index].initReg(reg.type(), reg.id(), typeId); + } + + //! Assigns a register at the given `index` to `regType`, `regId`, and an optional `typeId`. + inline void assignReg(size_t index, RegType regType, uint32_t regId, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + _values[index].initReg(regType, regId, typeId); + } + + //! Assigns a stack location at the given `index` to `offset` and an optional `typeId`. + inline void assignStack(size_t index, int32_t offset, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + _values[index].initStack(offset, typeId); + } + + //! Accesses the value in the pack at the given `index`. + //! + //! \note The maximum index value is `Globals::kMaxValuePack - 1`. + inline FuncValue& operator[](size_t index) { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + return _values[index]; + } + //! \overload + inline const FuncValue& operator[](size_t index) const { + ASMJIT_ASSERT(index < Globals::kMaxValuePack); + return _values[index]; + } + + //! \} +}; + +//! Attributes are designed in a way that all are initially false, and user or \ref FuncFrame finalizer adds +//! them when necessary. +enum class FuncAttributes : uint32_t { + //! No attributes. + kNoAttributes = 0, + + //! Function has variable number of arguments. + kHasVarArgs = 0x00000001u, + //! Preserve frame pointer (don't omit FP). + kHasPreservedFP = 0x00000010u, + //! Function calls other functions (is not leaf). + kHasFuncCalls = 0x00000020u, + //! Function has aligned save/restore of vector registers. + kAlignedVecSR = 0x00000040u, + //! Function must begin with an instruction that marks a start of a branch or function. + //! + //! * `ENDBR32/ENDBR64` instruction is inserted at the beginning of the function (X86, X86_64). + //! * `BTI` instruction is inserted at the beginning of the function (AArch64) + kIndirectBranchProtection = 0x00000080u, + //! FuncFrame is finalized and can be used by prolog/epilog inserter (PEI). + kIsFinalized = 0x00000800u, + + // X86 Specific Attributes + // ----------------------- + + //! Enables the use of AVX within the function's body, prolog, and epilog (X86). + //! + //! This flag instructs prolog and epilog emitter to use AVX instead of SSE for manipulating XMM registers. + kX86_AVXEnabled = 0x00010000u, + + //! Enables the use of AVX-512 within the function's body, prolog, and epilog (X86). + //! + //! This flag instructs Compiler register allocator to use additional 16 registers introduced by AVX-512. + //! Additionally, if the functions saves full width of ZMM registers (custom calling conventions only) then + //! the prolog/epilog inserter would use AVX-512 move instructions to emit the save and restore sequence. + kX86_AVX512Enabled = 0x00020000u, + + //! This flag instructs the epilog writer to emit EMMS instruction before RET (X86). + kX86_MMXCleanup = 0x00040000u, + + //! This flag instructs the epilog writer to emit VZEROUPPER instruction before RET (X86). + kX86_AVXCleanup = 0x00080000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(FuncAttributes) + +//! Function detail - \ref CallConv and expanded \ref FuncSignature. +//! +//! Function detail is architecture and OS dependent representation of a function. It contains a materialized +//! calling convention and expanded function signature so all arguments have assigned either register type/id +//! or stack address. +class FuncDetail { +public: + //! \name Constants + //! \{ + + //! Function doesn't have a variable number of arguments (`...`). + static constexpr uint8_t kNoVarArgs = 0xFFu; + + //! \} + + //! \name Members + //! \{ + + //! Calling convention. + CallConv _callConv {}; + //! Number of function arguments. + uint8_t _argCount = 0; + //! Variable arguments index of `kNoVarArgs`. + uint8_t _vaIndex = 0; + //! Reserved for future use. + uint16_t _reserved = 0; + //! Registers that contain arguments. + Support::Array _usedRegs {}; + //! Size of arguments passed by stack. + uint32_t _argStackSize = 0; + //! Function return value(s). + FuncValuePack _rets {}; + //! Function arguments. + FuncValuePack _args[Globals::kMaxFuncArgs] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default constructed \ref FuncDetail. + ASMJIT_INLINE_NODEBUG FuncDetail() noexcept {} + + //! Copy constructor. + //! + //! Function details are copyable. + ASMJIT_INLINE_NODEBUG FuncDetail(const FuncDetail& other) noexcept = default; + + //! Initializes this `FuncDetail` to the given signature. + ASMJIT_API Error init(const FuncSignature& signature, const Environment& environment) noexcept; + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Assignment operator, copies `other` to this \ref FuncDetail. + ASMJIT_INLINE_NODEBUG FuncDetail& operator=(const FuncDetail& other) noexcept = default; + + //! \} + + //! \name Reset + //! \{ + + //! Resets the function detail to its default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = FuncDetail{}; } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the function's calling convention, see `CallConv`. + ASMJIT_INLINE_NODEBUG const CallConv& callConv() const noexcept { return _callConv; } + + //! Returns the associated calling convention flags, see `CallConv::Flags`. + ASMJIT_INLINE_NODEBUG CallConvFlags flags() const noexcept { return _callConv.flags(); } + //! Checks whether a CallConv `flag` is set, see `CallConv::Flags`. + ASMJIT_INLINE_NODEBUG bool hasFlag(CallConvFlags ccFlag) const noexcept { return _callConv.hasFlag(ccFlag); } + + //! Tests whether the function has a return value. + ASMJIT_INLINE_NODEBUG bool hasRet() const noexcept { return bool(_rets[0]); } + //! Returns the number of function arguments. + ASMJIT_INLINE_NODEBUG uint32_t argCount() const noexcept { return _argCount; } + + //! Returns function return values. + ASMJIT_INLINE_NODEBUG FuncValuePack& retPack() noexcept { return _rets; } + //! Returns function return values. + ASMJIT_INLINE_NODEBUG const FuncValuePack& retPack() const noexcept { return _rets; } + + //! Returns a function return value associated with the given `valueIndex`. + ASMJIT_INLINE_NODEBUG FuncValue& ret(size_t valueIndex = 0) noexcept { return _rets[valueIndex]; } + //! Returns a function return value associated with the given `valueIndex` (const). + ASMJIT_INLINE_NODEBUG const FuncValue& ret(size_t valueIndex = 0) const noexcept { return _rets[valueIndex]; } + + //! Returns function argument packs array. + ASMJIT_INLINE_NODEBUG FuncValuePack* argPacks() noexcept { return _args; } + //! Returns function argument packs array (const). + ASMJIT_INLINE_NODEBUG const FuncValuePack* argPacks() const noexcept { return _args; } + + //! Returns function argument pack at the given `argIndex`. + inline FuncValuePack& argPack(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex]; + } + + //! Returns function argument pack at the given `argIndex` (const). + inline const FuncValuePack& argPack(size_t argIndex) const noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex]; + } + + //! Returns an argument at `valueIndex` from the argument pack at the given `argIndex`. + inline FuncValue& arg(size_t argIndex, size_t valueIndex = 0) noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex][valueIndex]; + } + + //! Returns an argument at `valueIndex` from the argument pack at the given `argIndex` (const). + inline const FuncValue& arg(size_t argIndex, size_t valueIndex = 0) const noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + return _args[argIndex][valueIndex]; + } + + //! Resets an argument at the given `argIndex`. + //! + //! If the argument is a parameter pack (has multiple values) all values are reset. + inline void resetArg(size_t argIndex) noexcept { + ASMJIT_ASSERT(argIndex < Globals::kMaxFuncArgs); + _args[argIndex].reset(); + } + + //! Tests whether the function has variable arguments. + ASMJIT_INLINE_NODEBUG bool hasVarArgs() const noexcept { return _vaIndex != kNoVarArgs; } + //! Returns an index of a first variable argument. + ASMJIT_INLINE_NODEBUG uint32_t vaIndex() const noexcept { return _vaIndex; } + + //! Tests whether the function passes one or more argument by stack. + ASMJIT_INLINE_NODEBUG bool hasStackArgs() const noexcept { return _argStackSize != 0; } + //! Returns stack size needed for function arguments passed on the stack. + ASMJIT_INLINE_NODEBUG uint32_t argStackSize() const noexcept { return _argStackSize; } + + //! Returns red zone size. + ASMJIT_INLINE_NODEBUG uint32_t redZoneSize() const noexcept { return _callConv.redZoneSize(); } + //! Returns spill zone size. + ASMJIT_INLINE_NODEBUG uint32_t spillZoneSize() const noexcept { return _callConv.spillZoneSize(); } + //! Returns natural stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t naturalStackAlignment() const noexcept { return _callConv.naturalStackAlignment(); } + + //! Returns a mask of all passed registers of the given register `group`. + ASMJIT_INLINE_NODEBUG RegMask passedRegs(RegGroup group) const noexcept { return _callConv.passedRegs(group); } + //! Returns a mask of all preserved registers of the given register `group`. + ASMJIT_INLINE_NODEBUG RegMask preservedRegs(RegGroup group) const noexcept { return _callConv.preservedRegs(group); } + + //! Returns a mask of all used registers of the given register `group`. + inline RegMask usedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _usedRegs[size_t(group)]; + } + + //! Adds `regs` to the mask of used registers of the given register `group`. + inline void addUsedRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _usedRegs[size_t(group)] |= regs; + } + + //! \} +}; + +//! Function frame. +//! +//! Function frame is used directly by prolog and epilog insertion (PEI) utils. It provides information necessary to +//! insert a proper and ABI conforming prolog and epilog. Function frame calculation is based on `CallConv` and +//! other function attributes. +//! +//! SSE vs AVX vs AVX-512 +//! --------------------- +//! +//! Function frame provides a way to tell prolog/epilog inserter to use AVX instructions instead of SSE. Use +//! `setAvxEnabled()` and `setAvx512Enabled()` to enable AVX and/or AVX-512, respectively. Enabling AVX-512 +//! is mostly for Compiler as it would use 32 SIMD registers instead of 16 when enabled. +//! +//! \note If your code uses AVX instructions and AVX is not enabled there would be a performance hit in case that +//! some registers had to be saved/restored in function's prolog/epilog, respectively. Thus, it's recommended to +//! always let the function frame know about the use of AVX. +//! +//! Function Frame Structure +//! ------------------------ +//! +//! Various properties can contribute to the size and structure of the function frame. The function frame in most +//! cases won't use all of the properties illustrated (for example Spill Zone and Red Zone are never used together). +//! +//! ``` +//! +-----------------------------+ +//! | Arguments Passed by Stack | +//! +-----------------------------+ +//! | Spill Zone | +//! +-----------------------------+ <- Stack offset (args) starts from here. +//! | Return Address, if Pushed | +//! +-----------------------------+ <- Stack pointer (SP) upon entry. +//! | Save/Restore Stack. | +//! +-----------------------------+-----------------------------+ +//! | Local Stack | | +//! +-----------------------------+ Final Stack | +//! | Call Stack | | +//! +-----------------------------+-----------------------------+ <- SP after prolog. +//! | Red Zone | +//! +-----------------------------+ +//! ``` +class FuncFrame { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + //! Tag used to inform that some offset is invalid. + kTagInvalidOffset = 0xFFFFFFFFu + }; + + //! \} + + //! \name Members + //! \{ + + //! Function attributes. + FuncAttributes _attributes {}; + + //! Target architecture. + Arch _arch {}; + //! SP register ID (to access call stack and local stack). + uint8_t _spRegId = uint8_t(BaseReg::kIdBad); + //! SA register ID (to access stack arguments). + uint8_t _saRegId = uint8_t(BaseReg::kIdBad); + + //! Red zone size (copied from CallConv). + uint8_t _redZoneSize = 0; + //! Spill zone size (copied from CallConv). + uint8_t _spillZoneSize = 0; + //! Natural stack alignment (copied from CallConv). + uint8_t _naturalStackAlignment = 0; + //! Minimum stack alignment to turn on dynamic alignment. + uint8_t _minDynamicAlignment = 0; + + //! Call stack alignment. + uint8_t _callStackAlignment = 0; + //! Local stack alignment. + uint8_t _localStackAlignment = 0; + //! Final stack alignment. + uint8_t _finalStackAlignment = 0; + + //! Adjustment of the stack before returning (X86-STDCALL). + uint16_t _calleeStackCleanup = 0; + + //! Call stack size. + uint32_t _callStackSize = 0; + //! Local stack size. + uint32_t _localStackSize = 0; + //! Final stack size (sum of call stack and local stack). + uint32_t _finalStackSize = 0; + + //! Local stack offset (non-zero only if call stack is used). + uint32_t _localStackOffset = 0; + //! Offset relative to SP that contains previous SP (before alignment). + uint32_t _daOffset = 0; + //! Offset of the first stack argument relative to SP. + uint32_t _saOffsetFromSP = 0; + //! Offset of the first stack argument relative to SA (_saRegId or FP). + uint32_t _saOffsetFromSA = 0; + + //! Local stack adjustment in prolog/epilog. + uint32_t _stackAdjustment = 0; + + //! Registers that are dirty. + Support::Array _dirtyRegs {}; + //! Registers that must be preserved (copied from CallConv). + Support::Array _preservedRegs {}; + //! Size to save/restore per register group. + Support::Array _saveRestoreRegSize {}; + //! Alignment of save/restore area per register group. + Support::Array _saveRestoreAlignment {}; + + //! Stack size required to save registers with push/pop. + uint16_t _pushPopSaveSize = 0; + //! Stack size required to save extra registers that cannot use push/pop. + uint16_t _extraRegSaveSize = 0; + //! Offset where registers saved/restored via push/pop are stored + uint32_t _pushPopSaveOffset = 0; + //! Offset where extra registers that cannot use push/pop are stored. + uint32_t _extraRegSaveOffset = 0; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default constructed function frame, which has initialized all members to their default values. + ASMJIT_INLINE_NODEBUG FuncFrame() noexcept = default; + //! Creates a copy of `other` function frame. + ASMJIT_INLINE_NODEBUG FuncFrame(const FuncFrame& other) noexcept = default; + + //! \} + + //! \name Initialization & Reset + //! \{ + + //! Initializes the function frame based on `func` detail. + ASMJIT_API Error init(const FuncDetail& func) noexcept; + //! Resets the function frame into its default constructed state. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = FuncFrame{}; } + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Copy assignment - function frame is copy assignable. + ASMJIT_INLINE_NODEBUG FuncFrame& operator=(const FuncFrame& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the target architecture of the function frame. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _arch; } + + //! Returns function frame attributes, see `Attributes`. + ASMJIT_INLINE_NODEBUG FuncAttributes attributes() const noexcept { return _attributes; } + //! Checks whether the FuncFame contains an attribute `attr`. + ASMJIT_INLINE_NODEBUG bool hasAttribute(FuncAttributes attr) const noexcept { return Support::test(_attributes, attr); } + //! Adds attributes `attrs` to the FuncFrame. + ASMJIT_INLINE_NODEBUG void addAttributes(FuncAttributes attrs) noexcept { _attributes |= attrs; } + //! Clears attributes `attrs` from the FrameFrame. + ASMJIT_INLINE_NODEBUG void clearAttributes(FuncAttributes attrs) noexcept { _attributes &= ~attrs; } + + //! Tests whether the function has variable number of arguments. + ASMJIT_INLINE_NODEBUG bool hasVarArgs() const noexcept { return hasAttribute(FuncAttributes::kHasVarArgs); } + //! Sets the variable arguments flag. + ASMJIT_INLINE_NODEBUG void setVarArgs() noexcept { addAttributes(FuncAttributes::kHasVarArgs); } + //! Resets variable arguments flag. + ASMJIT_INLINE_NODEBUG void resetVarArgs() noexcept { clearAttributes(FuncAttributes::kHasVarArgs); } + + //! Tests whether the function preserves frame pointer (EBP|ESP on X86). + ASMJIT_INLINE_NODEBUG bool hasPreservedFP() const noexcept { return hasAttribute(FuncAttributes::kHasPreservedFP); } + //! Enables preserved frame pointer. + ASMJIT_INLINE_NODEBUG void setPreservedFP() noexcept { addAttributes(FuncAttributes::kHasPreservedFP); } + //! Disables preserved frame pointer. + ASMJIT_INLINE_NODEBUG void resetPreservedFP() noexcept { clearAttributes(FuncAttributes::kHasPreservedFP); } + + //! Tests whether the function calls other functions. + ASMJIT_INLINE_NODEBUG bool hasFuncCalls() const noexcept { return hasAttribute(FuncAttributes::kHasFuncCalls); } + //! Sets `FuncAttributes::kHasFuncCalls` to true. + ASMJIT_INLINE_NODEBUG void setFuncCalls() noexcept { addAttributes(FuncAttributes::kHasFuncCalls); } + //! Sets `FuncAttributes::kHasFuncCalls` to false. + ASMJIT_INLINE_NODEBUG void resetFuncCalls() noexcept { clearAttributes(FuncAttributes::kHasFuncCalls); } + + //! Tests whether the function uses indirect branch protection, see \ref FuncAttributes::kIndirectBranchProtection. + ASMJIT_INLINE_NODEBUG bool hasIndirectBranchProtection() const noexcept { return hasAttribute(FuncAttributes::kIndirectBranchProtection); } + //! Enabled indirect branch protection (sets `FuncAttributes::kIndirectBranchProtection` attribute to true). + ASMJIT_INLINE_NODEBUG void setIndirectBranchProtection() noexcept { addAttributes(FuncAttributes::kIndirectBranchProtection); } + //! Disables indirect branch protection (sets `FuncAttributes::kIndirectBranchProtection` attribute to false). + ASMJIT_INLINE_NODEBUG void resetIndirectBranchProtection() noexcept { clearAttributes(FuncAttributes::kIndirectBranchProtection); } + + //! Tests whether the function has AVX enabled. + ASMJIT_INLINE_NODEBUG bool isAvxEnabled() const noexcept { return hasAttribute(FuncAttributes::kX86_AVXEnabled); } + //! Enables AVX use. + ASMJIT_INLINE_NODEBUG void setAvxEnabled() noexcept { addAttributes(FuncAttributes::kX86_AVXEnabled); } + //! Disables AVX use. + ASMJIT_INLINE_NODEBUG void resetAvxEnabled() noexcept { clearAttributes(FuncAttributes::kX86_AVXEnabled); } + + //! Tests whether the function has AVX-512 enabled. + ASMJIT_INLINE_NODEBUG bool isAvx512Enabled() const noexcept { return hasAttribute(FuncAttributes::kX86_AVX512Enabled); } + //! Enables AVX-512 use. + ASMJIT_INLINE_NODEBUG void setAvx512Enabled() noexcept { addAttributes(FuncAttributes::kX86_AVX512Enabled); } + //! Disables AVX-512 use. + ASMJIT_INLINE_NODEBUG void resetAvx512Enabled() noexcept { clearAttributes(FuncAttributes::kX86_AVX512Enabled); } + + //! Tests whether the function has MMX cleanup - 'emms' instruction in epilog. + ASMJIT_INLINE_NODEBUG bool hasMmxCleanup() const noexcept { return hasAttribute(FuncAttributes::kX86_MMXCleanup); } + //! Enables MMX cleanup. + ASMJIT_INLINE_NODEBUG void setMmxCleanup() noexcept { addAttributes(FuncAttributes::kX86_MMXCleanup); } + //! Disables MMX cleanup. + ASMJIT_INLINE_NODEBUG void resetMmxCleanup() noexcept { clearAttributes(FuncAttributes::kX86_MMXCleanup); } + + //! Tests whether the function has AVX cleanup - 'vzeroupper' instruction in epilog. + ASMJIT_INLINE_NODEBUG bool hasAvxCleanup() const noexcept { return hasAttribute(FuncAttributes::kX86_AVXCleanup); } + //! Enables AVX cleanup. + ASMJIT_INLINE_NODEBUG void setAvxCleanup() noexcept { addAttributes(FuncAttributes::kX86_AVXCleanup); } + //! Disables AVX cleanup. + ASMJIT_INLINE_NODEBUG void resetAvxCleanup() noexcept { clearAttributes(FuncAttributes::kX86_AVXCleanup); } + + //! Tests whether the function uses call stack. + ASMJIT_INLINE_NODEBUG bool hasCallStack() const noexcept { return _callStackSize != 0; } + //! Tests whether the function uses local stack. + ASMJIT_INLINE_NODEBUG bool hasLocalStack() const noexcept { return _localStackSize != 0; } + //! Tests whether vector registers can be saved and restored by using aligned reads and writes. + ASMJIT_INLINE_NODEBUG bool hasAlignedVecSR() const noexcept { return hasAttribute(FuncAttributes::kAlignedVecSR); } + //! Tests whether the function has to align stack dynamically. + ASMJIT_INLINE_NODEBUG bool hasDynamicAlignment() const noexcept { return _finalStackAlignment >= _minDynamicAlignment; } + + //! Tests whether the calling convention specifies 'RedZone'. + ASMJIT_INLINE_NODEBUG bool hasRedZone() const noexcept { return _redZoneSize != 0; } + //! Tests whether the calling convention specifies 'SpillZone'. + ASMJIT_INLINE_NODEBUG bool hasSpillZone() const noexcept { return _spillZoneSize != 0; } + + //! Returns the size of 'RedZone'. + ASMJIT_INLINE_NODEBUG uint32_t redZoneSize() const noexcept { return _redZoneSize; } + //! Returns the size of 'SpillZone'. + ASMJIT_INLINE_NODEBUG uint32_t spillZoneSize() const noexcept { return _spillZoneSize; } + + //! Resets the size of red zone, which would disable it entirely. + //! + //! \note Red zone is currently only used by an AMD64 SystemV calling convention, which expects 128 + //! bytes of stack to be accessible below stack pointer. These bytes are then accessible within the + //! function and Compiler can use this space as a spill area. However, sometimes it's better to + //! disallow the use of red zone in case that a user wants to use this stack for a custom purpose. + ASMJIT_INLINE_NODEBUG void resetRedZone() noexcept { _redZoneSize = 0; } + + //! Returns natural stack alignment (guaranteed stack alignment upon entry). + ASMJIT_INLINE_NODEBUG uint32_t naturalStackAlignment() const noexcept { return _naturalStackAlignment; } + //! Returns natural stack alignment (guaranteed stack alignment upon entry). + ASMJIT_INLINE_NODEBUG uint32_t minDynamicAlignment() const noexcept { return _minDynamicAlignment; } + + //! Tests whether the callee must adjust SP before returning (X86-STDCALL only) + ASMJIT_INLINE_NODEBUG bool hasCalleeStackCleanup() const noexcept { return _calleeStackCleanup != 0; } + //! Returns home many bytes of the stack the callee must adjust before returning (X86-STDCALL only) + ASMJIT_INLINE_NODEBUG uint32_t calleeStackCleanup() const noexcept { return _calleeStackCleanup; } + + //! Returns call stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t callStackAlignment() const noexcept { return _callStackAlignment; } + //! Returns local stack alignment. + ASMJIT_INLINE_NODEBUG uint32_t localStackAlignment() const noexcept { return _localStackAlignment; } + //! Returns final stack alignment (the maximum value of call, local, and natural stack alignments). + ASMJIT_INLINE_NODEBUG uint32_t finalStackAlignment() const noexcept { return _finalStackAlignment; } + + //! Sets call stack alignment. + //! + //! \note This also updates the final stack alignment. + inline void setCallStackAlignment(uint32_t alignment) noexcept { + _callStackAlignment = uint8_t(alignment); + _finalStackAlignment = Support::max(_naturalStackAlignment, _callStackAlignment, _localStackAlignment); + } + + //! Sets local stack alignment. + //! + //! \note This also updates the final stack alignment. + inline void setLocalStackAlignment(uint32_t value) noexcept { + _localStackAlignment = uint8_t(value); + _finalStackAlignment = Support::max(_naturalStackAlignment, _callStackAlignment, _localStackAlignment); + } + + //! Combines call stack alignment with `alignment`, updating it to the greater value. + //! + //! \note This also updates the final stack alignment. + inline void updateCallStackAlignment(uint32_t alignment) noexcept { + _callStackAlignment = uint8_t(Support::max(_callStackAlignment, alignment)); + _finalStackAlignment = Support::max(_finalStackAlignment, _callStackAlignment); + } + + //! Combines local stack alignment with `alignment`, updating it to the greater value. + //! + //! \note This also updates the final stack alignment. + inline void updateLocalStackAlignment(uint32_t alignment) noexcept { + _localStackAlignment = uint8_t(Support::max(_localStackAlignment, alignment)); + _finalStackAlignment = Support::max(_finalStackAlignment, _localStackAlignment); + } + + //! Returns call stack size. + ASMJIT_INLINE_NODEBUG uint32_t callStackSize() const noexcept { return _callStackSize; } + //! Returns local stack size. + ASMJIT_INLINE_NODEBUG uint32_t localStackSize() const noexcept { return _localStackSize; } + + //! Sets call stack size. + ASMJIT_INLINE_NODEBUG void setCallStackSize(uint32_t size) noexcept { _callStackSize = size; } + //! Sets local stack size. + ASMJIT_INLINE_NODEBUG void setLocalStackSize(uint32_t size) noexcept { _localStackSize = size; } + + //! Combines call stack size with `size`, updating it to the greater value. + ASMJIT_INLINE_NODEBUG void updateCallStackSize(uint32_t size) noexcept { _callStackSize = Support::max(_callStackSize, size); } + //! Combines local stack size with `size`, updating it to the greater value. + ASMJIT_INLINE_NODEBUG void updateLocalStackSize(uint32_t size) noexcept { _localStackSize = Support::max(_localStackSize, size); } + + //! Returns final stack size (only valid after the FuncFrame is finalized). + ASMJIT_INLINE_NODEBUG uint32_t finalStackSize() const noexcept { return _finalStackSize; } + + //! Returns an offset to access the local stack (non-zero only if call stack is used). + ASMJIT_INLINE_NODEBUG uint32_t localStackOffset() const noexcept { return _localStackOffset; } + + //! Tests whether the function prolog/epilog requires a memory slot for storing unaligned SP. + ASMJIT_INLINE_NODEBUG bool hasDAOffset() const noexcept { return _daOffset != kTagInvalidOffset; } + //! Returns a memory offset used to store DA (dynamic alignment) slot (relative to SP). + ASMJIT_INLINE_NODEBUG uint32_t daOffset() const noexcept { return _daOffset; } + + ASMJIT_INLINE_NODEBUG uint32_t saOffset(uint32_t regId) const noexcept { + return regId == _spRegId ? saOffsetFromSP() + : saOffsetFromSA(); + } + + ASMJIT_INLINE_NODEBUG uint32_t saOffsetFromSP() const noexcept { return _saOffsetFromSP; } + ASMJIT_INLINE_NODEBUG uint32_t saOffsetFromSA() const noexcept { return _saOffsetFromSA; } + + //! Returns mask of registers of the given register `group` that are modified by the function. The engine would + //! then calculate which registers must be saved & restored by the function by using the data provided by the + //! calling convention. + inline RegMask dirtyRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _dirtyRegs[group]; + } + + //! Sets which registers (as a mask) are modified by the function. + //! + //! \remarks Please note that this will completely overwrite the existing register mask, use `addDirtyRegs()` + //! to modify the existing register mask. + inline void setDirtyRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _dirtyRegs[group] = regs; + } + + //! Adds which registers (as a mask) are modified by the function. + inline void addDirtyRegs(RegGroup group, RegMask regs) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _dirtyRegs[group] |= regs; + } + + //! \overload + inline void addDirtyRegs(const BaseReg& reg) noexcept { + ASMJIT_ASSERT(reg.id() < Globals::kMaxPhysRegs); + addDirtyRegs(reg.group(), Support::bitMask(reg.id())); + } + + //! \overload + template + inline void addDirtyRegs(const BaseReg& reg, Args&&... args) noexcept { + addDirtyRegs(reg); + addDirtyRegs(std::forward(args)...); + } + + //! A helper function to set all registers from all register groups dirty. + //! + //! \note This should not be used in general as it's the most pessimistic case. However, it can be used for testing + //! or in cases in which all registers are considered clobbered. + ASMJIT_INLINE_NODEBUG void setAllDirty() noexcept { + for (size_t i = 0; i < ASMJIT_ARRAY_SIZE(_dirtyRegs); i++) + _dirtyRegs[i] = 0xFFFFFFFFu; + } + + //! A helper function to set all registers from the given register `group` dirty. + inline void setAllDirty(RegGroup group) noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + _dirtyRegs[group] = 0xFFFFFFFFu; + } + + //! Returns a calculated mask of registers of the given `group` that will be saved and restored in the function's + //! prolog and epilog, respectively. The register mask is calculated from both `dirtyRegs` (provided by user) and + //! `preservedMask` (provided by the calling convention). + inline RegMask savedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _dirtyRegs[group] & _preservedRegs[group]; + } + + //! Returns the mask of preserved registers of the given register `group`. + //! + //! Preserved registers are those that must survive the function call unmodified. The function can only modify + //! preserved registers it they are saved and restored in function's prolog and epilog, respectively. + inline RegMask preservedRegs(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _preservedRegs[group]; + } + + //! Returns the size of a save-restore are for the required register `group`. + inline uint32_t saveRestoreRegSize(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _saveRestoreRegSize[group]; + } + + inline uint32_t saveRestoreAlignment(RegGroup group) const noexcept { + ASMJIT_ASSERT(group <= RegGroup::kMaxVirt); + return _saveRestoreAlignment[group]; + } + + ASMJIT_INLINE_NODEBUG bool hasSARegId() const noexcept { return _saRegId != BaseReg::kIdBad; } + ASMJIT_INLINE_NODEBUG uint32_t saRegId() const noexcept { return _saRegId; } + ASMJIT_INLINE_NODEBUG void setSARegId(uint32_t regId) { _saRegId = uint8_t(regId); } + ASMJIT_INLINE_NODEBUG void resetSARegId() { setSARegId(BaseReg::kIdBad); } + + //! Returns stack size required to save/restore registers via push/pop. + ASMJIT_INLINE_NODEBUG uint32_t pushPopSaveSize() const noexcept { return _pushPopSaveSize; } + //! Returns an offset to the stack where registers are saved via push/pop. + ASMJIT_INLINE_NODEBUG uint32_t pushPopSaveOffset() const noexcept { return _pushPopSaveOffset; } + + //! Returns stack size required to save/restore extra registers that don't use push/pop/ + //! + //! \note On X86 this covers all registers except GP registers, on other architectures it can be always + //! zero (for example AArch64 saves all registers via push/pop like instructions, so this would be zero). + ASMJIT_INLINE_NODEBUG uint32_t extraRegSaveSize() const noexcept { return _extraRegSaveSize; } + //! Returns an offset to the stack where extra registers are saved. + ASMJIT_INLINE_NODEBUG uint32_t extraRegSaveOffset() const noexcept { return _extraRegSaveOffset; } + + //! Tests whether the functions contains stack adjustment. + ASMJIT_INLINE_NODEBUG bool hasStackAdjustment() const noexcept { return _stackAdjustment != 0; } + //! Returns function's stack adjustment used in function's prolog and epilog. + //! + //! If the returned value is zero it means that the stack is not adjusted. This can mean both that the stack + //! is not used and/or the stack is only adjusted by instructions that pust/pop registers into/from stack. + ASMJIT_INLINE_NODEBUG uint32_t stackAdjustment() const noexcept { return _stackAdjustment; } + + //! \} + + //! \name Finalization + //! \{ + + ASMJIT_API Error finalize() noexcept; + + //! \} +}; + +//! A helper class that can be used to assign a physical register for each function argument. Use with +//! `BaseEmitter::emitArgsAssignment()`. +class FuncArgsAssignment { +public: + //! \name Members + //! \{ + + //! Function detail. + const FuncDetail* _funcDetail {}; + //! Register that can be used to access arguments passed by stack. + uint8_t _saRegId = uint8_t(BaseReg::kIdBad); + //! Reserved for future use. + uint8_t _reserved[3] {}; + //! Mapping of each function argument. + FuncValuePack _argPacks[Globals::kMaxFuncArgs] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates either a default initialized `FuncArgsAssignment` or to assignment that links to `fd`, if non-null. + ASMJIT_INLINE_NODEBUG explicit FuncArgsAssignment(const FuncDetail* fd = nullptr) noexcept { reset(fd); } + + //! Copy constructor. + ASMJIT_INLINE_NODEBUG FuncArgsAssignment(const FuncArgsAssignment& other) noexcept = default; + + //! Resets this `FuncArgsAssignment` to either default constructed state or to assignment that links to `fd`, + //! if non-null. + inline void reset(const FuncDetail* fd = nullptr) noexcept { + _funcDetail = fd; + _saRegId = uint8_t(BaseReg::kIdBad); + memset(_reserved, 0, sizeof(_reserved)); + memset(_argPacks, 0, sizeof(_argPacks)); + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Copy assignment. + ASMJIT_INLINE_NODEBUG FuncArgsAssignment& operator=(const FuncArgsAssignment& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the associated \ref FuncDetail of this `FuncArgsAssignment`. + ASMJIT_INLINE_NODEBUG const FuncDetail* funcDetail() const noexcept { return _funcDetail; } + //! Associates \ref FuncDetails with this `FuncArgsAssignment`. + ASMJIT_INLINE_NODEBUG void setFuncDetail(const FuncDetail* fd) noexcept { _funcDetail = fd; } + + ASMJIT_INLINE_NODEBUG bool hasSARegId() const noexcept { return _saRegId != BaseReg::kIdBad; } + ASMJIT_INLINE_NODEBUG uint32_t saRegId() const noexcept { return _saRegId; } + ASMJIT_INLINE_NODEBUG void setSARegId(uint32_t regId) { _saRegId = uint8_t(regId); } + ASMJIT_INLINE_NODEBUG void resetSARegId() { _saRegId = uint8_t(BaseReg::kIdBad); } + + //! Returns assigned argument at `argIndex` and `valueIndex`. + //! + //! \note `argIndex` refers to he function argument and `valueIndex` refers to a value pack (in case multiple + //! values are passed as a single argument). + inline FuncValue& arg(size_t argIndex, size_t valueIndex) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + return _argPacks[argIndex][valueIndex]; + } + //! \overload + inline const FuncValue& arg(size_t argIndex, size_t valueIndex) const noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + return _argPacks[argIndex][valueIndex]; + } + + //! Tests whether argument at `argIndex` and `valueIndex` has been assigned. + inline bool isAssigned(size_t argIndex, size_t valueIndex) const noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + return _argPacks[argIndex][valueIndex].isAssigned(); + } + + //! Assigns register at `argIndex` and value index of 0 to `reg` and an optional `typeId`. + inline void assignReg(size_t argIndex, const BaseReg& reg, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + ASMJIT_ASSERT(reg.isPhysReg()); + _argPacks[argIndex][0].initReg(reg.type(), reg.id(), typeId); + } + + //! Assigns register at `argIndex` and value index of 0 to `regType`, `regId`, and an optional `typeId`. + inline void assignReg(size_t argIndex, RegType regType, uint32_t regId, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][0].initReg(regType, regId, typeId); + } + + //! Assigns stack at `argIndex` and value index of 0 to `offset` and an optional `typeId`. + inline void assignStack(size_t argIndex, int32_t offset, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][0].initStack(offset, typeId); + } + + //! Assigns register at `argIndex` and `valueIndex` to `reg` and an optional `typeId`. + inline void assignRegInPack(size_t argIndex, size_t valueIndex, const BaseReg& reg, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + ASMJIT_ASSERT(reg.isPhysReg()); + _argPacks[argIndex][valueIndex].initReg(reg.type(), reg.id(), typeId); + } + + //! Assigns register at `argIndex` and `valueIndex` to `regType`, `regId`, and an optional `typeId`. + inline void assignRegInPack(size_t argIndex, size_t valueIndex, RegType regType, uint32_t regId, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][valueIndex].initReg(regType, regId, typeId); + } + + //! Assigns stack at `argIndex` and `valueIndex` to `offset` and an optional `typeId`. + inline void assignStackInPack(size_t argIndex, size_t valueIndex, int32_t offset, TypeId typeId = TypeId::kVoid) noexcept { + ASMJIT_ASSERT(argIndex < ASMJIT_ARRAY_SIZE(_argPacks)); + _argPacks[argIndex][valueIndex].initStack(offset, typeId); + } + + // NOTE: All `assignAll()` methods are shortcuts to assign all arguments at once, however, since registers are + // passed all at once these initializers don't provide any way to pass TypeId and/or to keep any argument between + // the arguments passed unassigned. + inline void _assignAllInternal(size_t argIndex, const BaseReg& reg) noexcept { + assignReg(argIndex, reg); + } + + template + inline void _assignAllInternal(size_t argIndex, const BaseReg& reg, Args&&... args) noexcept { + assignReg(argIndex, reg); + _assignAllInternal(argIndex + 1, std::forward(args)...); + } + + //! Assigns all argument at once. + //! + //! \note This function can be only used if the arguments don't contain value packs (multiple values per argument). + template + inline void assignAll(Args&&... args) noexcept { + _assignAllInternal(0, std::forward(args)...); + } + + //! \} + + //! \name Utilities + //! \{ + + //! Update `FuncFrame` based on function's arguments assignment. + //! + //! \note This function must be called in order to use `BaseEmitter::emitArgsAssignment()`, otherwise the \ref FuncFrame + //! would not contain the information necessary to assign all arguments into the registers and/or stack specified. + ASMJIT_API Error updateFuncFrame(FuncFrame& frame) const noexcept; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_FUNC_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/globals.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/globals.h new file mode 100644 index 0000000000000000000000000000000000000000..1066bb830a0ccf1100032953cac82bf219e40e51 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/globals.h @@ -0,0 +1,421 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_GLOBALS_H_INCLUDED +#define ASMJIT_CORE_GLOBALS_H_INCLUDED + +#include "../core/api-config.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \cond INTERNAL +//! \addtogroup asmjit_utilities +//! \{ +namespace Support { + +//! Cast designed to cast between function and void* pointers. +template +static inline Dst ptr_cast_impl(Src p) noexcept { return (Dst)p; } + +//! Helper to implement placement new/delete without relying on `` header. +struct PlacementNew { void* ptr; }; + +} // {Support} + +#if defined(ASMJIT_NO_STDCXX) +namespace Support { + ASMJIT_FORCE_INLINE void* operatorNew(size_t n) noexcept { return malloc(n); } + ASMJIT_FORCE_INLINE void operatorDelete(void* p) noexcept { if (p) free(p); } +} // {Support} + +#define ASMJIT_BASE_CLASS(TYPE) \ + ASMJIT_FORCE_INLINE void* operator new(size_t n) noexcept { return Support::operatorNew(n); } \ + ASMJIT_FORCE_INLINE void operator delete(void* ptr) noexcept { Support::operatorDelete(ptr); } \ + \ + ASMJIT_FORCE_INLINE void* operator new(size_t, void* ptr) noexcept { return ptr; } \ + ASMJIT_FORCE_INLINE void operator delete(void*, void*) noexcept {} \ + \ + ASMJIT_FORCE_INLINE void* operator new(size_t, Support::PlacementNew ptr) noexcept { return ptr.ptr; } \ + ASMJIT_FORCE_INLINE void operator delete(void*, Support::PlacementNew) noexcept {} +#else +#define ASMJIT_BASE_CLASS(TYPE) +#endif + +//! \} +//! \endcond + +//! \addtogroup asmjit_core +//! \{ + +//! Byte order. +enum class ByteOrder { + //! Little endian. + kLE = 0, + //! Big endian. + kBE = 1, + //! Native byte order of the target architecture. + kNative = ASMJIT_ARCH_LE ? kLE : kBE, + //! Swapped byte order of the target architecture. + kSwapped = ASMJIT_ARCH_LE ? kBE : kLE +}; + +//! A policy that can be used with some `reset()` member functions. +enum class ResetPolicy : uint32_t { + //! Soft reset, doesn't deallocate memory (default). + kSoft = 0, + //! Hard reset, releases all memory used, if any. + kHard = 1 +}; + +//! Contains typedefs, constants, and variables used globally by AsmJit. +namespace Globals { + +//! Host memory allocator overhead. +static constexpr uint32_t kAllocOverhead = uint32_t(sizeof(intptr_t) * 4); + +//! Host memory allocator alignment. +static constexpr uint32_t kAllocAlignment = 8; + +//! Aggressive growing strategy threshold. +static constexpr uint32_t kGrowThreshold = 1024 * 1024 * 16; + +//! Maximum depth of RB-Tree is: +//! +//! `2 * log2(n + 1)` +//! +//! Size of RB node is at least two pointers (without data), so a theoretical architecture limit would be: +//! +//! `2 * log2(addressableMemorySize / sizeof(Node) + 1)` +//! +//! Which yields 30 on 32-bit arch and 61 on 64-bit arch. The final value was adjusted by +1 for safety reasons. +static constexpr uint32_t kMaxTreeHeight = (ASMJIT_ARCH_BITS == 32 ? 30 : 61) + 1; + +//! Maximum number of operands per a single instruction. +static constexpr uint32_t kMaxOpCount = 6; + +//! Maximum arguments of a function supported by the Compiler / Function API. +static constexpr uint32_t kMaxFuncArgs = 32; + +//! The number of values that can be assigned to a single function argument or return value. +static constexpr uint32_t kMaxValuePack = 4; + +//! Maximum number of physical registers AsmJit can use per register group. +static constexpr uint32_t kMaxPhysRegs = 32; + +//! Maximum alignment. +static constexpr uint32_t kMaxAlignment = 64; + +//! Maximum label or symbol size in bytes. +static constexpr uint32_t kMaxLabelNameSize = 2048; + +//! Maximum section name size. +static constexpr uint32_t kMaxSectionNameSize = 35; + +//! Maximum size of comment. +static constexpr uint32_t kMaxCommentSize = 1024; + +//! Invalid identifier. +static constexpr uint32_t kInvalidId = 0xFFFFFFFFu; + +//! Returned by `indexOf()` and similar when working with containers that use 32-bit index/size. +static constexpr uint32_t kNotFound = 0xFFFFFFFFu; + +//! Invalid base address. +static constexpr uint64_t kNoBaseAddress = ~uint64_t(0); + +//! Number of virtual register groups. +static constexpr uint32_t kNumVirtGroups = 4; + +struct Init_ {}; +struct NoInit_ {}; + +//! A decorator used to initialize. +static const constexpr Init_ Init {}; +//! A decorator used to not initialize. +static const constexpr NoInit_ NoInit {}; + +} // {Globals} + +//! Casts a `void*` pointer `func` to a function pointer `Func`. +template +static ASMJIT_INLINE_NODEBUG Func ptr_as_func(void* func) noexcept { return Support::ptr_cast_impl(func); } + +//! Casts a function pointer `func` to a void pointer `void*`. +template +static ASMJIT_INLINE_NODEBUG void* func_as_ptr(Func func) noexcept { return Support::ptr_cast_impl(func); } + +//! \} + +//! \addtogroup asmjit_error_handling +//! \{ + +//! AsmJit error type (uint32_t). +typedef uint32_t Error; + +//! AsmJit error codes. +enum ErrorCode : uint32_t { + // @EnumValuesBegin{"enum": "ErrorCode"}@ + + //! No error (success). + kErrorOk = 0, + + //! Out of memory. + kErrorOutOfMemory, + + //! Invalid argument. + kErrorInvalidArgument, + + //! Invalid state. + //! + //! If this error is returned it means that either you are doing something wrong or AsmJit caught itself by + //! doing something wrong. This error should never be ignored. + kErrorInvalidState, + + //! Invalid or incompatible architecture. + kErrorInvalidArch, + + //! The object is not initialized. + kErrorNotInitialized, + //! The object is already initialized. + kErrorAlreadyInitialized, + + //! Either a built-in feature was disabled at compile time and it's not available or the feature is not + //! available on the target platform. + //! + //! For example trying to allocate large pages on unsupported platform would return this error. + kErrorFeatureNotEnabled, + + //! Too many handles (Windows) or file descriptors (Unix/Posix). + kErrorTooManyHandles, + //! Code generated is larger than allowed. + kErrorTooLarge, + + //! No code generated. + //! + //! Returned by runtime if the \ref CodeHolder contains no code. + kErrorNoCodeGenerated, + + //! Invalid directive. + kErrorInvalidDirective, + //! Attempt to use uninitialized label. + kErrorInvalidLabel, + //! Label index overflow - a single \ref BaseAssembler instance can hold almost 2^32 (4 billion) labels. If + //! there is an attempt to create more labels then this error is returned. + kErrorTooManyLabels, + //! Label is already bound. + kErrorLabelAlreadyBound, + //! Label is already defined (named labels). + kErrorLabelAlreadyDefined, + //! Label name is too long. + kErrorLabelNameTooLong, + //! Label must always be local if it's anonymous (without a name). + kErrorInvalidLabelName, + //! Parent id passed to \ref CodeHolder::newNamedLabelEntry() was either invalid or parent is not supported + //! by the requested `LabelType`. + kErrorInvalidParentLabel, + + //! Invalid section. + kErrorInvalidSection, + //! Too many sections (section index overflow). + kErrorTooManySections, + //! Invalid section name (most probably too long). + kErrorInvalidSectionName, + + //! Relocation index overflow (too many relocations). + kErrorTooManyRelocations, + //! Invalid relocation entry. + kErrorInvalidRelocEntry, + //! Reloc entry contains address that is out of range (unencodable). + kErrorRelocOffsetOutOfRange, + + //! Invalid assignment to a register, function argument, or function return value. + kErrorInvalidAssignment, + //! Invalid instruction. + kErrorInvalidInstruction, + //! Invalid register type. + kErrorInvalidRegType, + //! Invalid register group. + kErrorInvalidRegGroup, + //! Invalid physical register id. + kErrorInvalidPhysId, + //! Invalid virtual register id. + kErrorInvalidVirtId, + //! Invalid element index (ARM). + kErrorInvalidElementIndex, + //! Invalid prefix combination (X86|X64). + kErrorInvalidPrefixCombination, + //! Invalid LOCK prefix (X86|X64). + kErrorInvalidLockPrefix, + //! Invalid XACQUIRE prefix (X86|X64). + kErrorInvalidXAcquirePrefix, + //! Invalid XRELEASE prefix (X86|X64). + kErrorInvalidXReleasePrefix, + //! Invalid REP prefix (X86|X64). + kErrorInvalidRepPrefix, + //! Invalid REX prefix (X86|X64). + kErrorInvalidRexPrefix, + //! Invalid {...} register (X86|X64). + kErrorInvalidExtraReg, + //! Invalid {k} use (not supported by the instruction) (X86|X64). + kErrorInvalidKMaskUse, + //! Invalid {k}{z} use (not supported by the instruction) (X86|X64). + kErrorInvalidKZeroUse, + //! Invalid broadcast - Currently only related to invalid use of AVX-512 {1tox} (X86|X64). + kErrorInvalidBroadcast, + //! Invalid 'embedded-rounding' {er} or 'suppress-all-exceptions' {sae} (AVX-512) (X86|X64). + kErrorInvalidEROrSAE, + //! Invalid address used (not encodable). + kErrorInvalidAddress, + //! Invalid index register used in memory address (not encodable). + kErrorInvalidAddressIndex, + //! Invalid address scale (not encodable). + kErrorInvalidAddressScale, + //! Invalid use of 64-bit address. + kErrorInvalidAddress64Bit, + //! Invalid use of 64-bit address that require 32-bit zero-extension (X64). + kErrorInvalidAddress64BitZeroExtension, + //! Invalid displacement (not encodable). + kErrorInvalidDisplacement, + //! Invalid segment (X86). + kErrorInvalidSegment, + + //! Invalid immediate (out of bounds on X86 and invalid pattern on ARM). + kErrorInvalidImmediate, + + //! Invalid operand size. + kErrorInvalidOperandSize, + //! Ambiguous operand size (memory has zero size while it's required to determine the operation type. + kErrorAmbiguousOperandSize, + //! Mismatching operand size (size of multiple operands doesn't match the operation size). + kErrorOperandSizeMismatch, + + //! Invalid option. + kErrorInvalidOption, + //! Option already defined. + kErrorOptionAlreadyDefined, + + //! Invalid TypeId. + kErrorInvalidTypeId, + //! Invalid use of a 8-bit GPB-HIGH register. + kErrorInvalidUseOfGpbHi, + //! Invalid use of a 64-bit GPQ register in 32-bit mode. + kErrorInvalidUseOfGpq, + //! Invalid use of an 80-bit float (\ref TypeId::kFloat80). + kErrorInvalidUseOfF80, + //! Instruction requires the use of consecutive registers, but registers in operands weren't (AVX512, ASIMD load/store, etc...). + kErrorNotConsecutiveRegs, + //! Failed to allocate consecutive registers - allocable registers either too restricted or a bug in RW info. + kErrorConsecutiveRegsAllocation, + + //! Illegal virtual register - reported by instruction validation. + kErrorIllegalVirtReg, + //! AsmJit cannot create more virtual registers. + kErrorTooManyVirtRegs, + + //! AsmJit requires a physical register, but no one is available. + kErrorNoMorePhysRegs, + //! A variable has been assigned more than once to a function argument (BaseCompiler). + kErrorOverlappedRegs, + //! Invalid register to hold stack arguments offset. + kErrorOverlappingStackRegWithRegArg, + + //! Unbound label cannot be evaluated by expression. + kErrorExpressionLabelNotBound, + //! Arithmetic overflow during expression evaluation. + kErrorExpressionOverflow, + + //! Failed to open anonymous memory handle or file descriptor. + kErrorFailedToOpenAnonymousMemory, + + //! Failed to open a file. + //! + //! \note This is a generic error that is used by internal filesystem API. + kErrorFailedToOpenFile, + + //! Protection failure can be returned from a virtual memory allocator or when trying to change memory access + //! permissions. + kErrorProtectionFailure, + + // @EnumValuesEnd@ + + //! Count of AsmJit error codes. + kErrorCount +}; + +//! Debugging utilities. +namespace DebugUtils { + +//! \cond INTERNAL +//! Used to silence warnings about unused arguments or variables. +template +static ASMJIT_INLINE_NODEBUG void unused(Args&&...) noexcept {} +//! \endcond + +//! Returns the error `err` passed. +//! +//! Provided for debugging purposes. Putting a breakpoint inside `errored` can help with tracing the origin of any +//! error reported / returned by AsmJit. +static constexpr Error errored(Error err) noexcept { return err; } + +//! Returns a printable version of `asmjit::Error` code. +ASMJIT_API const char* errorAsString(Error err) noexcept; + +//! Called to output debugging message(s). +ASMJIT_API void debugOutput(const char* str) noexcept; + +//! Called on assertion failure. +//! +//! \param file Source file name where it happened. +//! \param line Line in the source file. +//! \param msg Message to display. +//! +//! If you have problems with assertion failures a breakpoint can be put at \ref assertionFailed() function +//! (asmjit/core/globals.cpp). A call stack will be available when such assertion failure is triggered. AsmJit +//! always returns errors on failures, assertions are a last resort and usually mean unrecoverable state due to out +//! of range array access or totally invalid arguments like nullptr where a valid pointer should be provided, etc... +ASMJIT_API void ASMJIT_NORETURN assertionFailed(const char* file, int line, const char* msg) noexcept; + +} // {DebugUtils} + +//! \def ASMJIT_ASSERT(...) +//! +//! AsmJit's own assert macro used in AsmJit code-base. +#if defined(ASMJIT_BUILD_DEBUG) +#define ASMJIT_ASSERT(...) \ + do { \ + if (ASMJIT_LIKELY(__VA_ARGS__)) \ + break; \ + ::asmjit::DebugUtils::assertionFailed(__FILE__, __LINE__, #__VA_ARGS__); \ + } while (0) +#else +#define ASMJIT_ASSERT(...) ((void)0) +#endif + +//! \def ASMJIT_PROPAGATE(...) +//! +//! Propagates a possible `Error` produced by `...` to the caller by returning the error immediately. Used by AsmJit +//! internally, but kept public for users that want to use the same technique to propagate errors to the caller. +#define ASMJIT_PROPAGATE(...) \ + do { \ + ::asmjit::Error _err = __VA_ARGS__; \ + if (ASMJIT_UNLIKELY(_err)) \ + return _err; \ + } while (0) + +//! \} + +ASMJIT_END_NAMESPACE + +//! Implementation of a placement new so we don't have to depend on ``. +ASMJIT_INLINE_NODEBUG void* operator new(size_t, const asmjit::Support::PlacementNew& p) noexcept { +#if defined(_MSC_VER) && !defined(__clang__) + __assume(p.ptr != nullptr); // Otherwise MSVC would emit a nullptr check. +#endif + return p.ptr; +} + +ASMJIT_INLINE_NODEBUG void operator delete(void*, const asmjit::Support::PlacementNew&) noexcept {} + +#endif // ASMJIT_CORE_GLOBALS_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/inst.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/inst.h new file mode 100644 index 0000000000000000000000000000000000000000..e3da1af313ff5a6fb3daf26f154745ee3d82f06c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/inst.h @@ -0,0 +1,804 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_INST_H_INCLUDED +#define ASMJIT_CORE_INST_H_INCLUDED + +#include "../core/cpuinfo.h" +#include "../core/operand.h" +#include "../core/string.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_instruction_db +//! \{ + +//! Describes an instruction id and modifiers used together with the id. +//! +//! Each architecture has a set of valid instructions indexed from 0. Instruction with 0 id is, however, a special +//! instruction that describes a "no instruction" or "invalid instruction". Different architectures can assign a. +//! different instruction to the same id, each architecture typically has its own instructions indexed from 1. +//! +//! Instruction identifiers listed by architecture: +//! +//! - \ref x86::Inst (X86 and X86_64) +//! - \ref a64::Inst (AArch64) +typedef uint32_t InstId; + +//! Instruction id parts. +//! +//! A mask that specifies a bit-layout of \ref InstId. +enum class InstIdParts : uint32_t { + // Common Masks + // ------------ + + //! Real id without any modifiers (always 16 least significant bits). + kRealId = 0x0000FFFFu, + //! Instruction is abstract (or virtual, IR, etc...). + kAbstract = 0x80000000u, + + // ARM Specific + // ------------ + + //! AArch32 first data type, used by ASIMD instructions (`inst.dt.dt2`). + kA32_DT = 0x000F0000u, + //! AArch32 second data type, used by ASIMD instructions (`inst.dt.dt2`). + kA32_DT2 = 0x00F00000u, + //! AArch32/AArch64 condition code. + kARM_Cond = 0x78000000u +}; + +//! Instruction options. +//! +//! Instruction options complement instruction identifier and attributes. +enum class InstOptions : uint32_t { + //! No options. + kNone = 0, + + //! Used internally by emitters for handling errors and rare cases. + kReserved = 0x00000001u, + + //! Prevents following a jump during compilation (Compiler). + kUnfollow = 0x00000002u, + + //! Overwrite the destination operand(s) (Compiler). + //! + //! Hint that is important for register liveness analysis. It tells the compiler that the destination operand will + //! be overwritten now or by adjacent instructions. Compiler knows when a register is completely overwritten by a + //! single instruction, for example you don't have to mark "movaps" or "pxor x, x", however, if a pair of + //! instructions is used and the first of them doesn't completely overwrite the content of the destination, + //! Compiler fails to mark that register as dead. + //! + //! X86 Specific + //! ------------ + //! + //! - All instructions that always overwrite at least the size of the register the virtual-register uses, for + //! example "mov", "movq", "movaps" don't need the overwrite option to be used - conversion, shuffle, and + //! other miscellaneous instructions included. + //! + //! - All instructions that clear the destination register if all operands are the same, for example "xor x, x", + //! "pcmpeqb x x", etc... + //! + //! - Consecutive instructions that partially overwrite the variable until there is no old content require + //! `BaseCompiler::overwrite()` to be used. Some examples (not always the best use cases thought): + //! + //! - `movlps xmm0, ?` followed by `movhps xmm0, ?` and vice versa + //! - `movlpd xmm0, ?` followed by `movhpd xmm0, ?` and vice versa + //! - `mov al, ?` followed by `and ax, 0xFF` + //! - `mov al, ?` followed by `mov ah, al` + //! - `pinsrq xmm0, ?, 0` followed by `pinsrq xmm0, ?, 1` + //! + //! - If the allocated virtual register is used temporarily for scalar operations. For example if you allocate a + //! full vector like `x86::Compiler::newXmm()` and then use that vector for scalar operations you should use + //! `overwrite()` directive: + //! + //! - `sqrtss x, y` - only LO element of `x` is changed, if you don't + //! use HI elements, use `compiler.overwrite().sqrtss(x, y)`. + kOverwrite = 0x00000004u, + + //! Emit short-form of the instruction. + kShortForm = 0x00000010u, + //! Emit long-form of the instruction. + kLongForm = 0x00000020u, + + //! Conditional jump is likely to be taken. + kTaken = 0x00000040u, + //! Conditional jump is unlikely to be taken. + kNotTaken = 0x00000080u, + + // X86 & X64 Options + // ----------------- + + //! Use ModMR instead of ModRM if applicable. + kX86_ModMR = 0x00000100u, + //! Use ModRM instead of ModMR if applicable. + kX86_ModRM = 0x00000200u, + //! Use 3-byte VEX prefix if possible (AVX) (must be 0x00000400). + kX86_Vex3 = 0x00000400u, + //! Use VEX prefix when both VEX|EVEX prefixes are available (HINT: AVX_VNNI). + kX86_Vex = 0x00000800u, + //! Use 4-byte EVEX prefix if possible (AVX-512) (must be 0x00001000). + kX86_Evex = 0x00001000u, + + //! LOCK prefix (lock-enabled instructions only). + kX86_Lock = 0x00002000u, + //! REP prefix (string instructions only). + kX86_Rep = 0x00004000u, + //! REPNE prefix (string instructions only). + kX86_Repne = 0x00008000u, + + //! XACQUIRE prefix (only allowed instructions). + kX86_XAcquire = 0x00010000u, + //! XRELEASE prefix (only allowed instructions). + kX86_XRelease = 0x00020000u, + + //! AVX-512: embedded-rounding {er} and implicit {sae}. + kX86_ER = 0x00040000u, + //! AVX-512: suppress-all-exceptions {sae}. + kX86_SAE = 0x00080000u, + //! AVX-512: round-to-nearest (even) {rn-sae} (bits 00). + kX86_RN_SAE = 0x00000000u, + //! AVX-512: round-down (toward -inf) {rd-sae} (bits 01). + kX86_RD_SAE = 0x00200000u, + //! AVX-512: round-up (toward +inf) {ru-sae} (bits 10). + kX86_RU_SAE = 0x00400000u, + //! AVX-512: round-toward-zero (truncate) {rz-sae} (bits 11). + kX86_RZ_SAE = 0x00600000u, + //! AVX-512: Use zeroing {k}{z} instead of merging {k}. + kX86_ZMask = 0x00800000u, + + //! AVX-512: Mask to get embedded rounding bits (2 bits). + kX86_ERMask = kX86_RZ_SAE, + //! AVX-512: Mask of all possible AVX-512 options except EVEX prefix flag. + kX86_AVX512Mask = 0x00FC0000u, + + //! Force REX.B and/or VEX.B field (X64 only). + kX86_OpCodeB = 0x01000000u, + //! Force REX.X and/or VEX.X field (X64 only). + kX86_OpCodeX = 0x02000000u, + //! Force REX.R and/or VEX.R field (X64 only). + kX86_OpCodeR = 0x04000000u, + //! Force REX.W and/or VEX.W field (X64 only). + kX86_OpCodeW = 0x08000000u, + //! Force REX prefix (X64 only). + kX86_Rex = 0x40000000u, + //! Invalid REX prefix (set by X86 or when AH|BH|CH|DH regs are used on X64). + kX86_InvalidRex = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstOptions) + +//! Instruction control flow. +enum class InstControlFlow : uint32_t { + //! Regular instruction. + kRegular = 0u, + //! Unconditional jump. + kJump = 1u, + //! Conditional jump (branch). + kBranch = 2u, + //! Function call. + kCall = 3u, + //! Function return. + kReturn = 4u, + + //! Maximum value of `InstType`. + kMaxValue = kReturn +}; + +//! Hint that is used when both input operands to the instruction are the same. +//! +//! Provides hints to the instruction RW query regarding special cases in which two or more operands are the same +//! registers. This is required by instructions such as XOR, AND, OR, SUB, etc... These hints will influence the +//! RW operations query. +enum class InstSameRegHint : uint8_t { + //! No special handling. + kNone = 0, + //! Operands become read-only, the operation doesn't change the content - `X & X` and similar. + kRO = 1, + //! Operands become write-only, the content of the input(s) don't matter - `X ^ X`, `X - X`, and similar. + kWO = 2 +}; + +//! Instruction id, options, and extraReg in a single structure. This structure exists mainly to simplify analysis +//! and validation API that requires `BaseInst` and `Operand[]` array. +class BaseInst { +public: + //! \name Members + //! \{ + + //! Instruction id with modifiers. + InstId _id; + //! Instruction options. + InstOptions _options; + //! Extra register used by the instruction (either REP register or AVX-512 selector). + RegOnly _extraReg; + + enum Id : uint32_t { + //! Invalid or uninitialized instruction id. + kIdNone = 0x00000000u, + //! Abstract instruction (BaseBuilder and BaseCompiler). + kIdAbstract = 0x80000000u + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new BaseInst instance with `id` and `options` set. + //! + //! Default values of `id` and `options` are zero, which means 'none' instruction. Such instruction is guaranteed + //! to never exist for any architecture supported by AsmJit. + ASMJIT_INLINE_NODEBUG explicit BaseInst(InstId instId = 0, InstOptions options = InstOptions::kNone) noexcept + : _id(instId), + _options(options), + _extraReg() {} + + ASMJIT_INLINE_NODEBUG BaseInst(InstId instId, InstOptions options, const RegOnly& extraReg) noexcept + : _id(instId), + _options(options), + _extraReg(extraReg) {} + + ASMJIT_INLINE_NODEBUG BaseInst(InstId instId, InstOptions options, const BaseReg& extraReg) noexcept + : _id(instId), + _options(options), + _extraReg { extraReg.signature(), extraReg.id() } {} + + //! \} + + //! \name Instruction id and modifiers + //! \{ + + //! Returns the instruction id with modifiers. + ASMJIT_INLINE_NODEBUG InstId id() const noexcept { return _id; } + //! Sets the instruction id and modiiers from `id`. + ASMJIT_INLINE_NODEBUG void setId(InstId id) noexcept { _id = id; } + //! Resets the instruction id and modifiers to zero, see \ref kIdNone. + ASMJIT_INLINE_NODEBUG void resetId() noexcept { _id = 0; } + + //! Returns a real instruction id that doesn't contain any modifiers. + ASMJIT_INLINE_NODEBUG InstId realId() const noexcept { return _id & uint32_t(InstIdParts::kRealId); } + + template + ASMJIT_INLINE_NODEBUG uint32_t getInstIdPart() const noexcept { + return (uint32_t(_id) & uint32_t(kPart)) >> Support::ConstCTZ::value; + } + + template + ASMJIT_INLINE_NODEBUG void setInstIdPart(uint32_t value) noexcept { + _id = (_id & ~uint32_t(kPart)) | (value << Support::ConstCTZ::value); + } + + //! \} + + //! \name Instruction Options + //! \{ + + ASMJIT_INLINE_NODEBUG InstOptions options() const noexcept { return _options; } + ASMJIT_INLINE_NODEBUG bool hasOption(InstOptions option) const noexcept { return Support::test(_options, option); } + ASMJIT_INLINE_NODEBUG void setOptions(InstOptions options) noexcept { _options = options; } + ASMJIT_INLINE_NODEBUG void addOptions(InstOptions options) noexcept { _options |= options; } + ASMJIT_INLINE_NODEBUG void clearOptions(InstOptions options) noexcept { _options &= ~options; } + ASMJIT_INLINE_NODEBUG void resetOptions() noexcept { _options = InstOptions::kNone; } + + //! \} + + //! \name Extra Register + //! \{ + + ASMJIT_INLINE_NODEBUG bool hasExtraReg() const noexcept { return _extraReg.isReg(); } + ASMJIT_INLINE_NODEBUG RegOnly& extraReg() noexcept { return _extraReg; } + ASMJIT_INLINE_NODEBUG const RegOnly& extraReg() const noexcept { return _extraReg; } + ASMJIT_INLINE_NODEBUG void setExtraReg(const BaseReg& reg) noexcept { _extraReg.init(reg); } + ASMJIT_INLINE_NODEBUG void setExtraReg(const RegOnly& reg) noexcept { _extraReg.init(reg); } + ASMJIT_INLINE_NODEBUG void resetExtraReg() noexcept { _extraReg.reset(); } + + //! \} + + //! \name ARM Specific + //! \{ + + ASMJIT_INLINE_NODEBUG arm::CondCode armCondCode() const noexcept { return (arm::CondCode)getInstIdPart(); } + ASMJIT_INLINE_NODEBUG void setArmCondCode(arm::CondCode cc) noexcept { setInstIdPart(uint32_t(cc)); } + + ASMJIT_INLINE_NODEBUG a32::DataType armDt() const noexcept { return (a32::DataType)getInstIdPart(); } + ASMJIT_INLINE_NODEBUG a32::DataType armDt2() const noexcept { return (a32::DataType)getInstIdPart(); } + + //! \} + + //! \name Statics + //! \{ + + static ASMJIT_INLINE_NODEBUG constexpr InstId composeARMInstId(uint32_t id, arm::CondCode cc) noexcept { + return id | (uint32_t(cc) << Support::ConstCTZ::value); + } + + static ASMJIT_INLINE_NODEBUG constexpr InstId composeARMInstId(uint32_t id, a32::DataType dt, arm::CondCode cc = arm::CondCode::kAL) noexcept { + return id | (uint32_t(dt) << Support::ConstCTZ::value) + | (uint32_t(cc) << Support::ConstCTZ::value); + } + + static ASMJIT_INLINE_NODEBUG constexpr InstId composeARMInstId(uint32_t id, a32::DataType dt, a32::DataType dt2, arm::CondCode cc = arm::CondCode::kAL) noexcept { + return id | (uint32_t(dt) << Support::ConstCTZ::value) + | (uint32_t(dt2) << Support::ConstCTZ::value) + | (uint32_t(cc) << Support::ConstCTZ::value); + } + + static ASMJIT_INLINE_NODEBUG constexpr InstId extractRealId(uint32_t id) noexcept { + return id & uint32_t(InstIdParts::kRealId); + } + + static ASMJIT_INLINE_NODEBUG constexpr arm::CondCode extractARMCondCode(uint32_t id) noexcept { + return (arm::CondCode)((uint32_t(id) & uint32_t(InstIdParts::kARM_Cond)) >> Support::ConstCTZ::value); + } + + //! \} +}; + +//! CPU read/write flags used by \ref InstRWInfo. +//! +//! These flags can be used to get a basic overview about CPU specifics flags used by instructions. +enum class CpuRWFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + + // Common RW Flags (0x000000FF) + // ---------------------------- + + //! Signed overflow flag. + kOF = 0x00000001u, + //! Carry flag. + kCF = 0x00000002u, + //! Zero and/or equality flag (1 if zero/equal). + kZF = 0x00000004u, + //! Sign flag (negative/sign, if set). + kSF = 0x00000008u, + + // X86 Specific RW Flags + // ---------------------------------- + + //! Carry flag (X86, X86_64). + kX86_CF = kCF, + //! Overflow flag (X86, X86_64). + kX86_OF = kOF, + //! Sign flag (X86, X86_64). + kX86_SF = kSF, + //! Zero flag (X86, X86_64). + kX86_ZF = kZF, + + //! Adjust flag (X86, X86_64). + kX86_AF = 0x00000100u, + //! Parity flag (X86, X86_64). + kX86_PF = 0x00000200u, + //! Direction flag (X86, X86_64). + kX86_DF = 0x00000400u, + //! Interrupt enable flag (X86, X86_64). + kX86_IF = 0x00000800u, + + //! Alignment check flag (X86, X86_64). + kX86_AC = 0x00001000u, + + //! FPU C0 status flag (X86, X86_64). + kX86_C0 = 0x00010000u, + //! FPU C1 status flag (X86, X86_64). + kX86_C1 = 0x00020000u, + //! FPU C2 status flag (X86, X86_64). + kX86_C2 = 0x00040000u, + //! FPU C3 status flag (X86, X86_64). + kX86_C3 = 0x00080000u, + + // ARM Specific RW Flags + // ---------------------------------- + + kARM_V = kOF, + kARM_C = kCF, + kARM_Z = kZF, + kARM_N = kSF, + kARM_Q = 0x00000100u, + kARM_GE = 0x00000200u +}; +ASMJIT_DEFINE_ENUM_FLAGS(CpuRWFlags) + +//! Operand read/write flags describe how the operand is accessed and some additional features. +enum class OpRWFlags : uint32_t { + //! No flags. + kNone = 0, + + //! Operand is read. + kRead = 0x00000001u, + + //! Operand is written. + kWrite = 0x00000002u, + + //! Operand is both read and written. + kRW = 0x00000003u, + + //! Register operand can be replaced by a memory operand. + kRegMem = 0x00000004u, + + //! The register must be allocated to the index of the previous register + 1. + //! + //! This flag is used by all architectures to describe instructions that use consecutive registers, where only the + //! first one is encoded in the instruction, and the others are just a sequence that starts with the first one. On + //! X86/X86_64 architecture this is used by instructions such as V4FMADDPS, V4FMADDSS, V4FNMADDPS, V4FNMADDSS, + //! VP4DPWSSD, VP4DPWSSDS, VP2INTERSECTD, and VP2INTERSECTQ. On ARM/AArch64 this is used by vector load and store + //! instructions that can load or store multiple registers at once. + kConsecutive = 0x00000008u, + + //! The `extendByteMask()` represents a zero extension. + kZExt = 0x00000010u, + + //! The register must have assigned a unique physical ID, which cannot be assigned to any other register. + kUnique = 0x00000080u, + + //! Register operand must use \ref OpRWInfo::physId(). + kRegPhysId = 0x00000100u, + //! Base register of a memory operand must use \ref OpRWInfo::physId(). + kMemPhysId = 0x00000200u, + + //! This memory operand is only used to encode registers and doesn't access memory. + //! + //! X86 Specific + //! ------------ + //! + //! Instructions that use such feature include BNDLDX, BNDSTX, and LEA. + kMemFake = 0x000000400u, + + //! Base register of the memory operand will be read. + kMemBaseRead = 0x00001000u, + //! Base register of the memory operand will be written. + kMemBaseWrite = 0x00002000u, + //! Base register of the memory operand will be read & written. + kMemBaseRW = 0x00003000u, + + //! Index register of the memory operand will be read. + kMemIndexRead = 0x00004000u, + //! Index register of the memory operand will be written. + kMemIndexWrite = 0x00008000u, + //! Index register of the memory operand will be read & written. + kMemIndexRW = 0x0000C000u, + + //! Base register of the memory operand will be modified before the operation. + kMemBasePreModify = 0x00010000u, + //! Base register of the memory operand will be modified after the operation. + kMemBasePostModify = 0x00020000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(OpRWFlags) + +// Don't remove these asserts. Read/Write flags are used extensively +// by Compiler and they must always be compatible with constants below. +static_assert(uint32_t(OpRWFlags::kRead) == 0x1, "OpRWFlags::kRead flag must be 0x1"); +static_assert(uint32_t(OpRWFlags::kWrite) == 0x2, "OpRWFlags::kWrite flag must be 0x2"); +static_assert(uint32_t(OpRWFlags::kRegMem) == 0x4, "OpRWFlags::kRegMem flag must be 0x4"); + +//! Read/Write information related to a single operand, used by \ref InstRWInfo. +struct OpRWInfo { + //! \name Members + //! \{ + + //! Read/Write flags. + OpRWFlags _opFlags; + //! Physical register index, if required. + uint8_t _physId; + //! Size of a possible memory operand that can replace a register operand. + uint8_t _rmSize; + //! If non-zero, then this is a consecutive lead register, and the value describes how many registers follow. + uint8_t _consecutiveLeadCount; + //! Reserved for future use. + uint8_t _reserved[1]; + //! Read bit-mask where each bit represents one byte read from Reg/Mem. + uint64_t _readByteMask; + //! Write bit-mask where each bit represents one byte written to Reg/Mem. + uint64_t _writeByteMask; + //! Zero/Sign extend bit-mask where each bit represents one byte written to Reg/Mem. + uint64_t _extendByteMask; + + //! \} + + //! \name Reset + //! \{ + + //! Resets this operand information to all zeros. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = OpRWInfo{}; } + + //! Resets this operand info (resets all members) and set common information + //! to the given `opFlags`, `regSize`, and possibly `physId`. + inline void reset(OpRWFlags opFlags, uint32_t regSize, uint32_t physId = BaseReg::kIdBad) noexcept { + _opFlags = opFlags; + _physId = uint8_t(physId); + _rmSize = Support::test(opFlags, OpRWFlags::kRegMem) ? uint8_t(regSize) : uint8_t(0); + _consecutiveLeadCount = 0; + _resetReserved(); + + uint64_t mask = Support::lsbMask(Support::min(regSize, 64)); + + _readByteMask = Support::test(opFlags, OpRWFlags::kRead) ? mask : uint64_t(0); + _writeByteMask = Support::test(opFlags, OpRWFlags::kWrite) ? mask : uint64_t(0); + _extendByteMask = 0; + } + + ASMJIT_INLINE_NODEBUG void _resetReserved() noexcept { + _reserved[0] = 0; + } + + //! \} + + //! \name Operand Flags + //! \{ + + //! Returns operand flags. + ASMJIT_INLINE_NODEBUG OpRWFlags opFlags() const noexcept { return _opFlags; } + //! Tests whether operand flags contain the given `flag`. + ASMJIT_INLINE_NODEBUG bool hasOpFlag(OpRWFlags flag) const noexcept { return Support::test(_opFlags, flag); } + + //! Adds the given `flags` to operand flags. + ASMJIT_INLINE_NODEBUG void addOpFlags(OpRWFlags flags) noexcept { _opFlags |= flags; } + //! Removes the given `flags` from operand flags. + ASMJIT_INLINE_NODEBUG void clearOpFlags(OpRWFlags flags) noexcept { _opFlags &= ~flags; } + + //! Tests whether this operand is read from. + ASMJIT_INLINE_NODEBUG bool isRead() const noexcept { return hasOpFlag(OpRWFlags::kRead); } + //! Tests whether this operand is written to. + ASMJIT_INLINE_NODEBUG bool isWrite() const noexcept { return hasOpFlag(OpRWFlags::kWrite); } + //! Tests whether this operand is both read and write. + ASMJIT_INLINE_NODEBUG bool isReadWrite() const noexcept { return (_opFlags & OpRWFlags::kRW) == OpRWFlags::kRW; } + //! Tests whether this operand is read only. + ASMJIT_INLINE_NODEBUG bool isReadOnly() const noexcept { return (_opFlags & OpRWFlags::kRW) == OpRWFlags::kRead; } + //! Tests whether this operand is write only. + ASMJIT_INLINE_NODEBUG bool isWriteOnly() const noexcept { return (_opFlags & OpRWFlags::kRW) == OpRWFlags::kWrite; } + + //! Returns the type of a lead register, which is followed by consecutive registers. + ASMJIT_INLINE_NODEBUG uint32_t consecutiveLeadCount() const noexcept { return _consecutiveLeadCount; } + + //! Tests whether this operand is Reg/Mem + //! + //! Reg/Mem operands can use either register or memory. + ASMJIT_INLINE_NODEBUG bool isRm() const noexcept { return hasOpFlag(OpRWFlags::kRegMem); } + + //! Tests whether the operand will be zero extended. + ASMJIT_INLINE_NODEBUG bool isZExt() const noexcept { return hasOpFlag(OpRWFlags::kZExt); } + + //! Tests whether the operand must have allocated a unique physical id that cannot be shared with other register + //! operands. + ASMJIT_INLINE_NODEBUG bool isUnique() const noexcept { return hasOpFlag(OpRWFlags::kUnique); } + + //! \} + + //! \name Memory Flags + //! \{ + + //! Tests whether this is a fake memory operand, which is only used, because of encoding. Fake memory operands do + //! not access any memory, they are only used to encode registers. + ASMJIT_INLINE_NODEBUG bool isMemFake() const noexcept { return hasOpFlag(OpRWFlags::kMemFake); } + + //! Tests whether the instruction's memory BASE register is used. + ASMJIT_INLINE_NODEBUG bool isMemBaseUsed() const noexcept { return hasOpFlag(OpRWFlags::kMemBaseRW); } + //! Tests whether the instruction reads from its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseRead() const noexcept { return hasOpFlag(OpRWFlags::kMemBaseRead); } + //! Tests whether the instruction writes to its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseWrite() const noexcept { return hasOpFlag(OpRWFlags::kMemBaseWrite); } + //! Tests whether the instruction reads and writes from/to its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseReadWrite() const noexcept { return (_opFlags & OpRWFlags::kMemBaseRW) == OpRWFlags::kMemBaseRW; } + //! Tests whether the instruction only reads from its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseReadOnly() const noexcept { return (_opFlags & OpRWFlags::kMemBaseRW) == OpRWFlags::kMemBaseRead; } + //! Tests whether the instruction only writes to its BASE registers. + ASMJIT_INLINE_NODEBUG bool isMemBaseWriteOnly() const noexcept { return (_opFlags & OpRWFlags::kMemBaseRW) == OpRWFlags::kMemBaseWrite; } + + //! Tests whether the instruction modifies the BASE register before it uses it to calculate the target address. + ASMJIT_INLINE_NODEBUG bool isMemBasePreModify() const noexcept { return hasOpFlag(OpRWFlags::kMemBasePreModify); } + //! Tests whether the instruction modifies the BASE register after it uses it to calculate the target address. + ASMJIT_INLINE_NODEBUG bool isMemBasePostModify() const noexcept { return hasOpFlag(OpRWFlags::kMemBasePostModify); } + + //! Tests whether the instruction's memory INDEX register is used. + ASMJIT_INLINE_NODEBUG bool isMemIndexUsed() const noexcept { return hasOpFlag(OpRWFlags::kMemIndexRW); } + //! Tests whether the instruction reads the INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexRead() const noexcept { return hasOpFlag(OpRWFlags::kMemIndexRead); } + //! Tests whether the instruction writes to its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexWrite() const noexcept { return hasOpFlag(OpRWFlags::kMemIndexWrite); } + //! Tests whether the instruction reads and writes from/to its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexReadWrite() const noexcept { return (_opFlags & OpRWFlags::kMemIndexRW) == OpRWFlags::kMemIndexRW; } + //! Tests whether the instruction only reads from its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexReadOnly() const noexcept { return (_opFlags & OpRWFlags::kMemIndexRW) == OpRWFlags::kMemIndexRead; } + //! Tests whether the instruction only writes to its INDEX registers. + ASMJIT_INLINE_NODEBUG bool isMemIndexWriteOnly() const noexcept { return (_opFlags & OpRWFlags::kMemIndexRW) == OpRWFlags::kMemIndexWrite; } + + //! \} + + //! \name Physical Register ID + //! \{ + + //! Returns a physical id of the register that is fixed for this operand. + //! + //! Returns \ref BaseReg::kIdBad if any register can be used. + ASMJIT_INLINE_NODEBUG uint32_t physId() const noexcept { return _physId; } + //! Tests whether \ref physId() would return a valid physical register id. + ASMJIT_INLINE_NODEBUG bool hasPhysId() const noexcept { return _physId != BaseReg::kIdBad; } + //! Sets physical register id, which would be fixed for this operand. + ASMJIT_INLINE_NODEBUG void setPhysId(uint32_t physId) noexcept { _physId = uint8_t(physId); } + + //! \} + + //! \name Reg/Mem Information + //! \{ + + //! Returns Reg/Mem size of the operand. + ASMJIT_INLINE_NODEBUG uint32_t rmSize() const noexcept { return _rmSize; } + //! Sets Reg/Mem size of the operand. + ASMJIT_INLINE_NODEBUG void setRmSize(uint32_t rmSize) noexcept { _rmSize = uint8_t(rmSize); } + + //! \} + + //! \name Read & Write Masks + //! \{ + + //! Returns read mask. + ASMJIT_INLINE_NODEBUG uint64_t readByteMask() const noexcept { return _readByteMask; } + //! Returns write mask. + ASMJIT_INLINE_NODEBUG uint64_t writeByteMask() const noexcept { return _writeByteMask; } + //! Returns extend mask. + ASMJIT_INLINE_NODEBUG uint64_t extendByteMask() const noexcept { return _extendByteMask; } + + //! Sets read mask. + ASMJIT_INLINE_NODEBUG void setReadByteMask(uint64_t mask) noexcept { _readByteMask = mask; } + //! Sets write mask. + ASMJIT_INLINE_NODEBUG void setWriteByteMask(uint64_t mask) noexcept { _writeByteMask = mask; } + //! Sets extend mask. + ASMJIT_INLINE_NODEBUG void setExtendByteMask(uint64_t mask) noexcept { _extendByteMask = mask; } + + //! \} +}; + +//! Flags used by \ref InstRWInfo. +enum class InstRWFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + + //! Describes a move operation. + //! + //! This flag is used by RA to eliminate moves that are guaranteed to be moves only. + kMovOp = 0x00000001u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstRWFlags) + +//! Read/Write information of an instruction. +struct InstRWInfo { + //! \name Members + //! \{ + + //! Instruction flags (there are no flags at the moment, this field is reserved). + InstRWFlags _instFlags; + //! CPU flags read. + CpuRWFlags _readFlags; + //! CPU flags written. + CpuRWFlags _writeFlags; + //! Count of operands. + uint8_t _opCount; + //! CPU feature required for replacing register operand with memory operand. + uint8_t _rmFeature; + //! Reserved for future use. + uint8_t _reserved[18]; + //! Read/Write info of extra register (rep{} or kz{}). + OpRWInfo _extraReg; + //! Read/Write info of instruction operands. + OpRWInfo _operands[Globals::kMaxOpCount]; + + //! \} + + //! \name Commons + //! \{ + + //! Resets this RW information to all zeros. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = InstRWInfo{}; } + + //! \} + + //! \name Instruction Flags + //! \{ + + //! Returns flags associated with the instruction, see \ref InstRWFlags. + ASMJIT_INLINE_NODEBUG InstRWFlags instFlags() const noexcept { return _instFlags; } + + //! Tests whether the instruction flags contain `flag`. + ASMJIT_INLINE_NODEBUG bool hasInstFlag(InstRWFlags flag) const noexcept { return Support::test(_instFlags, flag); } + + //! Tests whether the instruction flags contain \ref InstRWFlags::kMovOp. + ASMJIT_INLINE_NODEBUG bool isMovOp() const noexcept { return hasInstFlag(InstRWFlags::kMovOp); } + + //! \} + + //! \name CPU Flags Information + //! \{ + + //! Returns a mask of CPU flags read. + ASMJIT_INLINE_NODEBUG CpuRWFlags readFlags() const noexcept { return _readFlags; } + //! Returns a mask of CPU flags written. + ASMJIT_INLINE_NODEBUG CpuRWFlags writeFlags() const noexcept { return _writeFlags; } + + //! \} + + //! \name Reg/Mem Information + //! \{ + + //! Returns the CPU feature required to replace a register operand with memory operand. If the returned feature is + //! zero (none) then this instruction either doesn't provide memory operand combination or there is no extra CPU + //! feature required. + //! + //! X86 Specific + //! ------------ + //! + //! Some AVX+ instructions may require extra features for replacing registers with memory operands, for example + //! VPSLLDQ instruction only supports `vpslldq reg, reg, imm` combination on AVX/AVX2 capable CPUs and requires + //! AVX-512 for `vpslldq reg, mem, imm` combination. + ASMJIT_INLINE_NODEBUG uint32_t rmFeature() const noexcept { return _rmFeature; } + + //! \} + + //! \name Operand Read/Write Information + //! \{ + + //! Returns RW information of extra register operand (extraReg). + ASMJIT_INLINE_NODEBUG const OpRWInfo& extraReg() const noexcept { return _extraReg; } + + //! Returns RW information of all instruction's operands. + ASMJIT_INLINE_NODEBUG const OpRWInfo* operands() const noexcept { return _operands; } + + //! Returns RW information of the operand at the given `index`. + inline const OpRWInfo& operand(size_t index) const noexcept { + ASMJIT_ASSERT(index < Globals::kMaxOpCount); + return _operands[index]; + } + + //! Returns the number of operands this instruction has. + ASMJIT_INLINE_NODEBUG uint32_t opCount() const noexcept { return _opCount; } + + //! \} +}; + +//! Validation flags that can be used with \ref InstAPI::validate(). +enum class ValidationFlags : uint32_t { + //! No flags. + kNone = 0, + //! Allow virtual registers in the instruction. + kEnableVirtRegs = 0x01u +}; +ASMJIT_DEFINE_ENUM_FLAGS(ValidationFlags) + +//! Instruction API. +namespace InstAPI { + +#ifndef ASMJIT_NO_TEXT +//! Appends the name of the instruction specified by `instId` and `instOptions` into the `output` string. +//! +//! \note Instruction options would only affect instruction prefix & suffix, other options would be ignored. +//! If `instOptions` is zero then only raw instruction name (without any additional text) will be appended. +ASMJIT_API Error instIdToString(Arch arch, InstId instId, String& output) noexcept; + +//! Parses an instruction name in the given string `s`. Length is specified by `len` argument, which can be +//! `SIZE_MAX` if `s` is known to be null terminated. +//! +//! Returns the parsed instruction id or \ref BaseInst::kIdNone if no such instruction exists. +ASMJIT_API InstId stringToInstId(Arch arch, const char* s, size_t len) noexcept; +#endif // !ASMJIT_NO_TEXT + +#ifndef ASMJIT_NO_VALIDATION +//! Validates the given instruction considering the given `validationFlags`. +ASMJIT_API Error validate(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, ValidationFlags validationFlags = ValidationFlags::kNone) noexcept; +#endif // !ASMJIT_NO_VALIDATION + +#ifndef ASMJIT_NO_INTROSPECTION +//! Gets Read/Write information of the given instruction. +ASMJIT_API Error queryRWInfo(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, InstRWInfo* out) noexcept; + +//! Gets CPU features required by the given instruction. +ASMJIT_API Error queryFeatures(Arch arch, const BaseInst& inst, const Operand_* operands, size_t opCount, CpuFeatures* out) noexcept; +#endif // !ASMJIT_NO_INTROSPECTION + +} // {InstAPI} + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_INST_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/jitallocator.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/jitallocator.h new file mode 100644 index 0000000000000000000000000000000000000000..5a2fbdf2d1fffc7d914a992d55437f8de973a41e --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/jitallocator.h @@ -0,0 +1,570 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_JITALLOCATOR_H_INCLUDED +#define ASMJIT_CORE_JITALLOCATOR_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_JIT + +#include "../core/globals.h" +#include "../core/support.h" +#include "../core/virtmem.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_virtual_memory +//! \{ + +//! Options used by \ref JitAllocator. +enum class JitAllocatorOptions : uint32_t { + //! No options. + kNone = 0, + + //! Enables the use of an anonymous memory-mapped memory that is mapped into two buffers having a different pointer. + //! The first buffer has read and execute permissions and the second buffer has read+write permissions. + //! + //! See \ref VirtMem::allocDualMapping() for more details about this feature. + //! + //! \remarks Dual mapping would be automatically turned on by \ref JitAllocator in case of hardened runtime that + //! enforces `W^X` policy, so specifying this flag is essentially forcing to use dual mapped pages even when RWX + //! pages can be allocated and dual mapping is not necessary. + kUseDualMapping = 0x00000001u, + + //! Enables the use of multiple pools with increasing granularity instead of a single pool. This flag would enable + //! 3 internal pools in total having 64, 128, and 256 bytes granularity. + //! + //! This feature is only recommended for users that generate a lot of code and would like to minimize the overhead + //! of `JitAllocator` itself by having blocks of different allocation granularities. Using this feature only for + //! few allocations won't pay off as the allocator may need to create more blocks initially before it can take the + //! advantage of variable block granularity. + kUseMultiplePools = 0x00000002u, + + //! Always fill reserved memory by a fill-pattern. + //! + //! Causes a new block to be cleared by the fill pattern and freshly released memory to be cleared before making + //! it ready for another use. + kFillUnusedMemory = 0x00000004u, + + //! When this flag is set the allocator would immediately release unused blocks during `release()` or `reset()`. + //! When this flag is not set the allocator would keep one empty block in each pool to prevent excessive virtual + //! memory allocations and deallocations in border cases, which involve constantly allocating and deallocating a + //! single block caused by repetitive calling `alloc()` and `release()` when the allocator has either no blocks + //! or have all blocks fully occupied. + kImmediateRelease = 0x00000008u, + + //! This flag enables placing functions (or allocating memory) at the very beginning of each memory mapped region. + //! + //! Initially, this was the default behavior. However, LLVM developers working on undefined behavior sanitizer + //! (UBSAN) decided that they want to store metadata before each function and to access such metadata before an + //! indirect function call. This means that the instrumented code always reads from `[fnPtr - 8]` to decode whether + //! the function has his metadata present. However, reading 8 bytes below a function means that if a function is + //! placed at the very beginning of a memory mapped region, it could try to read bytes that are inaccessible. And + //! since AsmJit can be compiled as a shared library and used by applications instrumented by UBSAN, it's not + //! possible to conditionally compile the support only when necessary. + //! + //! \remarks This flag controls a workaround to make it possible to use LLVM UBSAN with AsmJit's \ref JitAllocator. + //! There is no undefined behavior even when `kDisableInitialPadding` is used, however, that doesn't really matter + //! as LLVM's UBSAN introduces one, and according to LLVM developers it's a "trade-off". This flag is safe to use + //! when the code is not instrumented with LLVM's UBSAN. + kDisableInitialPadding = 0x00000010u, + + //! Enables the use of large pages, if they are supported and the process can actually allocate them. + //! + //! \remarks This flag is a hint - if large pages can be allocated, JitAllocator would try to allocate them. + //! However, if the allocation fails, it will still try to fallback to use regular pages as \ref JitAllocator + //! is designed to minimize allocation failures, so a regular page is better than no page at all. Also, if a + //! block \ref JitAllocator wants to allocate is too small to consume a whole large page, regular page(s) will + //! be allocated as well. + kUseLargePages = 0x00000020u, + + //! Forces \ref JitAllocator to always align block size to be at least as big as a large page, if large pages are + //! enabled. This option does nothing if large pages are disabled. + //! + //! \remarks If \ref kUseLargePages option is used, the allocator would prefer large pages only when allocating a + //! block that has a sufficient size. Usually the allocator first allocates smaller block and when more requests + //! come it will start increasing the block size of next allocations. This option makes it sure that even the first + //! allocation would be the same as a minimum large page when large pages are enabled and can be allocated. + kAlignBlockSizeToLargePage = 0x00000040u, + + //! Use a custom fill pattern, must be combined with `kFlagFillUnusedMemory`. + kCustomFillPattern = 0x10000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(JitAllocatorOptions) + +//! A simple implementation of memory manager that uses `asmjit::VirtMem` +//! functions to manage virtual memory for JIT compiled code. +//! +//! Implementation notes: +//! +//! - Granularity of allocated blocks is different than granularity for a typical C malloc. In addition, the allocator +//! can use several memory pools having a different granularity to minimize the maintenance overhead. Multiple pools +//! feature requires `kFlagUseMultiplePools` flag to be set. +//! +//! - The allocator doesn't store any information in executable memory, instead, the implementation uses two +//! bit-vectors to manage allocated memory of each allocator-block. The first bit-vector called 'used' is used to +//! track used memory (where each bit represents memory size defined by granularity) and the second bit vector called +//! 'stop' is used as a sentinel to mark where the allocated area ends. +//! +//! - Internally, the allocator also uses RB tree to keep track of all blocks across all pools. Each inserted block is +//! added to the tree so it can be matched fast during `release()` and `shrink()`. +class JitAllocator { +public: + ASMJIT_NONCOPYABLE(JitAllocator) + + //! Visible \ref JitAllocator implementation data. + struct Impl { + //! Allocator options. + JitAllocatorOptions options; + //! Base block size (0 if the allocator is not initialized). + uint32_t blockSize; + //! Base granularity (0 if the allocator is not initialized). + uint32_t granularity; + //! A pattern that is used to fill unused memory if secure mode is enabled. + uint32_t fillPattern; + }; + + //! \name Members + //! \{ + + //! Allocator implementation (private). + Impl* _impl; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Parameters that can be passed to `JitAllocator` constructor. + //! + //! Use it like this: + //! + //! ``` + //! // Zero initialize (zero means the default value) and change what you need. + //! JitAllocator::CreateParams params {}; + //! params.blockSize = 1024 * 1024; + //! + //! // Create the allocator. + //! JitAllocator allocator(¶ms); + //! ``` + struct CreateParams { + //! Allocator options. + //! + //! No options are used by default. + JitAllocatorOptions options = JitAllocatorOptions::kNone; + + //! Base size of a single block in bytes (default 64kB). + //! + //! \remarks Block size must be equal to or greater than page size and must be power of 2. If the input is not + //! valid then the default block size will be used instead. + uint32_t blockSize = 0; + + //! Base granularity (and also natural alignment) of allocations in bytes (default 64). + //! + //! Since the `JitAllocator` uses bit-arrays to mark used memory the granularity also specifies how many bytes + //! correspond to a single bit in such bit-array. Higher granularity means more waste of virtual memory (as it + //! increases the natural alignment), but smaller bit-arrays as less bits would be required per a single block. + uint32_t granularity = 0; + + //! Patter to use to fill unused memory. + //! + //! Only used if \ref JitAllocatorOptions::kCustomFillPattern is set. + uint32_t fillPattern = 0; + + // Reset the content of `CreateParams`. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = CreateParams{}; } + }; + + //! Creates a `JitAllocator` instance. + ASMJIT_API explicit JitAllocator(const CreateParams* params = nullptr) noexcept; + //! Destroys the `JitAllocator` instance and release all blocks held. + ASMJIT_API ~JitAllocator() noexcept; + + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _impl->blockSize == 0; } + + //! Free all allocated memory - makes all pointers returned by `alloc()` invalid. + //! + //! \remarks This function is not thread-safe as it's designed to be used when nobody else is using allocator. + //! The reason is that there is no point of calling `reset()` when the allocator is still in use. + ASMJIT_API void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns allocator options, see `Flags`. + ASMJIT_INLINE_NODEBUG JitAllocatorOptions options() const noexcept { return _impl->options; } + //! Tests whether the allocator has the given `option` set. + ASMJIT_INLINE_NODEBUG bool hasOption(JitAllocatorOptions option) const noexcept { return uint32_t(_impl->options & option) != 0; } + + //! Returns a base block size (a minimum size of block that the allocator would allocate). + ASMJIT_INLINE_NODEBUG uint32_t blockSize() const noexcept { return _impl->blockSize; } + //! Returns granularity of the allocator. + ASMJIT_INLINE_NODEBUG uint32_t granularity() const noexcept { return _impl->granularity; } + //! Returns pattern that is used to fill unused memory if `kFlagUseFillPattern` is set. + ASMJIT_INLINE_NODEBUG uint32_t fillPattern() const noexcept { return _impl->fillPattern; } + + //! \} + + //! \name Alloc & Release + //! \{ + + //! A memory reference returned by \ref JitAllocator::alloc(). + //! + //! Span contains everything needed to actually write new code to the memory chunk it references. + class Span { + public: + //! \name Constants + //! \{ + + //! Span flags + enum class Flags : uint32_t { + //! No flags. + kNone = 0u, + + //! The process has never executed the region of the span. + //! + //! If this flag is set on a \ref Span it would mean that the allocator can avoid flushing + //! instruction cache after a code has been written to it. + kInstructionCacheClean = 0x00000001u + }; + + //! \} + + //! \name Members + //! \{ + + //! Address of memory that has Read and Execute permissions. + void* _rx = nullptr; + + //! Address of memory that has Read and Write permissions. + void* _rw = nullptr; + + //! Size of the span in bytes (rounded up to the allocation granularity). + size_t _size = 0; + + //! Pointer that references a memory block maintained by \ref JitAllocator. + //! + //! This pointer is considered private and should never be used nor inspected outside of AsmJit. + void* _block = nullptr; + + //! Span flags. + Flags _flags = Flags::kNone; + + //! Reserved for future use. + uint32_t _reserved = 0; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns a pointer having Read & Execute permissions (references executable memory). + //! + //! This pointer is never NULL if the allocation succeeded, it points to an executable memory. + ASMJIT_INLINE_NODEBUG void* rx() const noexcept { return _rx; } + + //! Returns a pointer having Read & Write permissions (references writable memory). + //! + //! Depending on the type of the allocation strategy this could either be: + //! + //! - the same address as returned by `rx()` if the allocator uses RWX mapping (pages have all of Read, Write, + //! and Execute permissions) or MAP_JIT, which requires either \ref VirtMem::ProtectJitReadWriteScope or to + //! call \ref VirtMem::protectJitMemory() manually. + //! - a valid pointer, but not the same as `rx` - this would be valid if dual mapping is used. + //! - NULL pointer, in case that the allocation strategy doesn't use RWX, MAP_JIT, or dual mapping. In this + //! case only \ref JitAllocator can copy new code into the executable memory referenced by \ref Span. + //! + //! \note If `rw()` returns a non-null pointer it's important to use either VirtMem::protectJitMemory() or + //! \ref VirtMem::ProtectJitReadWriteScope to guard the write, because in case of `MAP_JIT` it would temporarily + //! switch the permissions of the pointer to RW (that's per thread permissions). + //! + //! If \ref VirtMem::ProtectJitReadWriteScope is not used it's important to clear the instruction cache via + //! \ref VirtMem::flushInstructionCache() after the write is done. + ASMJIT_INLINE_NODEBUG void* rw() const noexcept { return _rw; } + + //! Returns size of this span, aligned to the allocator granularity. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + + //! Returns span flags. + ASMJIT_INLINE_NODEBUG Flags flags() const noexcept { return _flags; } + + //! Shrinks this span to `newSize`. + //! + //! \note This is the only function that is able to change the size of a span, and it's only use case is to + //! shrink the span size during \ref JitAllocator::write(). When the writer detects that the span size shrunk, + //! it will automatically shrink the memory used by the span, and propagate the new aligned size to the caller. + ASMJIT_INLINE_NODEBUG void shrink(size_t newSize) noexcept { _size = Support::min(_size, newSize); } + + //! Returns whether \ref rw() returns a non-null pointer. + ASMJIT_INLINE_NODEBUG bool isDirectlyWritable() const noexcept { return _rw != nullptr; } + + //! \} + }; + + //! Allocates a new memory span of the requested `size`. + ASMJIT_API Error alloc(Span& out, size_t size) noexcept; + + //! Releases a memory block returned by `alloc()`. + //! + //! \remarks This function is thread-safe. + ASMJIT_API Error release(void* rx) noexcept; + + //! Frees extra memory allocated with `rx` by shrinking it to the given `newSize`. + //! + //! \remarks This function is thread-safe. + ASMJIT_API Error shrink(Span& span, size_t newSize) noexcept; + + //! Queries information about an allocated memory block that contains the given `rx`, and writes it to `out`. + //! + //! If the pointer is matched, the function returns `kErrorOk` and fills `out` with the corresponding span. + ASMJIT_API Error query(Span& out, void* rx) const noexcept; + +#if !defined(ASMJIT_NO_DEPRECATED) + //! Allocates a new memory block of the requested `size`. + ASMJIT_DEPRECATED("Use alloc(Span& out, size_t size) instead") + ASMJIT_FORCE_INLINE Error alloc(void** rxPtrOut, void** rwPtrOut, size_t size) noexcept { + Span span; + Error err = alloc(span, size); + *rwPtrOut = span.rw(); + *rxPtrOut = span.rx(); + return err; + } + + ASMJIT_DEPRECATED("Use shrink(Span& span, size_t newSize) instead") + ASMJIT_FORCE_INLINE Error shrink(void* rxPtr, size_t newSize) noexcept { + Span span; + ASMJIT_PROPAGATE(query(span, rxPtr)); + return (span.size() > newSize) ? shrink(span, newSize) : Error(kErrorOk); + } + + ASMJIT_DEPRECATED("Use query(Span& out, void* rx) instead") + ASMJIT_FORCE_INLINE Error query(void* rxPtr, void** rxPtrOut, void** rwPtrOut, size_t* sizeOut) const noexcept { + Span span; + Error err = query(span, rxPtr); + *rxPtrOut = span.rx(); + *rwPtrOut = span.rw(); + *sizeOut = span.size(); + return err; + } +#endif + + //! \} + + //! \name Write Operations + //! \{ + + typedef Error (ASMJIT_CDECL* WriteFunc)(Span& span, void* userData) ASMJIT_NOEXCEPT_TYPE; + + ASMJIT_API Error write( + Span& span, + size_t offset, + const void* src, + size_t size, + VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept; + + ASMJIT_API Error write( + Span& span, + WriteFunc writeFunc, + void* userData, + VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept; + + template + ASMJIT_FORCE_INLINE Error write( + Span& span, + Lambda&& lambdaFunc, + VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept { + + WriteFunc wrapperFunc = [](Span& span, void* userData) noexcept -> Error { + Lambda& lambdaFunc = *static_cast(userData); + return lambdaFunc(span); + }; + return write(span, wrapperFunc, (void*)(&lambdaFunc), policy); + } + + //! \} + + //! \name Write Operations with Scope + //! \{ + + //! \cond INTERNAL + + //! Write scope data. + //! + //! This is mostly for internal purposes, please use \ref WriteScope instead. + struct WriteScopeData { + //! \name Members + //! \{ + + //! Link to the allocator. + JitAllocator* _allocator; + //! Cache policy passed to \ref JitAllocator::beginWriteScope(). + VirtMem::CachePolicy _policy; + //! Internal flags used by the implementation. + uint32_t _flags; + //! Internal data used by the implementation. + size_t _data[64]; + + //! \} + }; + + //! Begins a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope constructor instead. + ASMJIT_API Error beginWriteScope(WriteScopeData& scope, VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept; + + //! Ends a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope destructor instead. + ASMJIT_API Error endWriteScope(WriteScopeData& scope) noexcept; + + //! Flushes accumulated changes in a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope destructor or \ref WriteScope::flush() instead. + ASMJIT_API Error flushWriteScope(WriteScopeData& scope) noexcept; + + //! Alternative to `JitAllocator::write(span, offset, src, size)`, but under a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope::write() instead. + ASMJIT_API Error scopedWrite(WriteScopeData& scope, Span& span, size_t offset, const void* src, size_t size) noexcept; + + //! Alternative to `JitAllocator::write(span, writeFunc, userData)`, but under a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope::write() instead. + ASMJIT_API Error scopedWrite(WriteScopeData& scope, Span& span, WriteFunc writeFunc, void* userData) noexcept; + + //! Alternative to `JitAllocator::write(span, [lambda])`, but under a write `scope`. + //! + //! This is mostly for internal purposes, please use \ref WriteScope::write() instead. + template + inline Error scopedWrite(WriteScopeData& scope, Span& span, Lambda&& lambdaFunc) noexcept { + WriteFunc wrapperFunc = [](Span& span, void* userData) noexcept -> Error { + Lambda& lambdaFunc = *static_cast(userData); + return lambdaFunc(span); + }; + return scopedWrite(scope, span, wrapperFunc, (void*)(&lambdaFunc)); + } + + //! \endcond + + //! Write scope can be used to create a single scope that is optimized for writing multiple spans. + class WriteScope : public WriteScopeData { + public: + ASMJIT_NONCOPYABLE(WriteScope) + + //! \name Construction & Destruction + //! \{ + + // Begins a write scope. + inline explicit WriteScope(JitAllocator* allocator, VirtMem::CachePolicy policy = VirtMem::CachePolicy::kDefault) noexcept { + allocator->beginWriteScope(*this, policy); + } + + // Ends a write scope. + inline ~WriteScope() noexcept { + if (_allocator) + _allocator->endWriteScope(*this); + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG JitAllocator* allocator() const noexcept { return _allocator; } + ASMJIT_INLINE_NODEBUG VirtMem::CachePolicy policy() const noexcept { return _policy; } + + //! \} + + //! \name Operations + //! \{ + + //! Similar to `JitAllocator::write(span, offset, src, size)`, but under a write scope. + ASMJIT_INLINE_NODEBUG Error write(Span& span, size_t offset, const void* src, size_t size) noexcept { + return _allocator->scopedWrite(*this, span, offset, src, size); + } + + //! Similar to `JitAllocator::write(span, writeFunc, userData)`, but under a write scope. + ASMJIT_INLINE_NODEBUG Error write(Span& span, WriteFunc writeFunc, void* userData) noexcept { + return _allocator->scopedWrite(*this, span, writeFunc, userData); + } + + //! Similar to `JitAllocator::write(span, )`, but under a write scope. + template + ASMJIT_INLINE_NODEBUG Error write(Span& span, Lambda&& lambdaFunc) noexcept { + return _allocator->scopedWrite(*this, span, lambdaFunc); + } + + //! Flushes accumulated changes in this write scope. + ASMJIT_INLINE_NODEBUG Error flush() noexcept { + return _allocator->flushWriteScope(*this); + } + + //! \} + }; + + //! \} + + //! \name Statistics + //! \{ + + //! Statistics about `JitAllocator`. + struct Statistics { + //! Number of blocks `JitAllocator` maintains. + size_t _blockCount; + //! Number of active allocations. + size_t _allocationCount; + //! How many bytes are currently used / allocated. + size_t _usedSize; + //! How many bytes are currently reserved by the allocator. + size_t _reservedSize; + //! Allocation overhead (in bytes) required to maintain all blocks. + size_t _overheadSize; + + //! Resets the statistics to all zeros. + ASMJIT_INLINE_NODEBUG void reset() noexcept { *this = Statistics{}; } + + //! Returns count of blocks managed by `JitAllocator` at the moment. + ASMJIT_INLINE_NODEBUG size_t blockCount() const noexcept { return _blockCount; } + //! Returns the number of active allocations. + ASMJIT_INLINE_NODEBUG size_t allocationCount() const noexcept { return _allocationCount; } + + //! Returns how many bytes are currently used. + ASMJIT_INLINE_NODEBUG size_t usedSize() const noexcept { return _usedSize; } + //! Returns the number of bytes unused by the allocator at the moment. + ASMJIT_INLINE_NODEBUG size_t unusedSize() const noexcept { return _reservedSize - _usedSize; } + //! Returns the total number of bytes reserved by the allocator (sum of sizes of all blocks). + ASMJIT_INLINE_NODEBUG size_t reservedSize() const noexcept { return _reservedSize; } + //! Returns the number of bytes the allocator needs to manage the allocated memory. + ASMJIT_INLINE_NODEBUG size_t overheadSize() const noexcept { return _overheadSize; } + + ASMJIT_INLINE_NODEBUG double usedSizeAsPercent() const noexcept { + return (double(usedSize()) / (double(reservedSize()) + 1e-16)) * 100.0; + } + + ASMJIT_INLINE_NODEBUG double unusedSizeAsPercent() const noexcept { + return (double(unusedSize()) / (double(reservedSize()) + 1e-16)) * 100.0; + } + + ASMJIT_INLINE_NODEBUG double overheadSizeAsPercent() const noexcept { + return (double(overheadSize()) / (double(reservedSize()) + 1e-16)) * 100.0; + } + }; + + //! Returns JIT allocator statistics. + //! + //! \remarks This function is thread-safe. + ASMJIT_API Statistics statistics() const noexcept; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif +#endif diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/jitruntime.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/jitruntime.h new file mode 100644 index 0000000000000000000000000000000000000000..5150eb79863c697311d26dca89a5f72b2d927736 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/jitruntime.h @@ -0,0 +1,101 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_JITRUNTIME_H_INCLUDED +#define ASMJIT_CORE_JITRUNTIME_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_JIT + +#include "../core/codeholder.h" +#include "../core/jitallocator.h" +#include "../core/target.h" + +ASMJIT_BEGIN_NAMESPACE + +class CodeHolder; + +//! \addtogroup asmjit_virtual_memory +//! \{ + +//! JIT execution runtime is a special `Target` that is designed to store and execute a generated code. +//! +//! JIT runtime is the easiest way of using AsmJit as it abstracts allocation and deallocation of virtual memory +//! where executable code can be placed and from which it can be executed as well. +class ASMJIT_VIRTAPI JitRuntime : public Target { +public: + ASMJIT_NONCOPYABLE(JitRuntime) + + //! Virtual memory allocator. + JitAllocator _allocator; + + //! \name Construction & Destruction + //! \{ + + //! Creates a `JitRuntime` instance. + ASMJIT_API explicit JitRuntime(const JitAllocator::CreateParams* params = nullptr) noexcept; + //! Destroys the `JitRuntime` instance. + ASMJIT_API ~JitRuntime() noexcept override; + + //! \} + + //! \name Accessors + //! \{ + + //! Resets the \ref JitRuntime, freeing everything that was allocated by it. + //! + //! Depending on `resetPolicy` the currently held memory can be either freed entirely when ResetPolicy::kHard is used, + //! or the allocator can keep some of it for next allocations when ResetPolicy::kSoft is used, which is the default + //! behavior. + ASMJIT_INLINE_NODEBUG void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept { + _allocator.reset(resetPolicy); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the associated `JitAllocator`. + ASMJIT_INLINE_NODEBUG JitAllocator* allocator() const noexcept { return const_cast(&_allocator); } + + //! \} + + //! \name Utilities + //! \{ + + // NOTE: To allow passing function pointers to `add()` and `release()` the + // virtual methods are prefixed with `_` and called from templates instead. + + //! Allocates memory needed for a code stored in the `CodeHolder` and relocates the code to the pointer allocated. + //! + //! The beginning of the memory allocated for the function is returned in `dst`. If failed `Error` code is returned + //! and `dst` is explicitly set to `nullptr` (this means that you don't have to set it to null before calling `add()`). + template + ASMJIT_INLINE_NODEBUG Error add(Func* dst, CodeHolder* code) noexcept { + return _add(Support::ptr_cast_impl(dst), code); + } + + //! Releases `p` which was obtained by calling `add()`. + template + ASMJIT_INLINE_NODEBUG Error release(Func p) noexcept { + return _release(Support::ptr_cast_impl(p)); + } + + //! Type-unsafe version of `add()`. + ASMJIT_API virtual Error _add(void** dst, CodeHolder* code) noexcept; + + //! Type-unsafe version of `release()`. + ASMJIT_API virtual Error _release(void* p) noexcept; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif +#endif diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/logger.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/logger.h new file mode 100644 index 0000000000000000000000000000000000000000..f37c72c671e000bb4188cbe3794a942b701a2489 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/logger.h @@ -0,0 +1,198 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_LOGGING_H_INCLUDED +#define ASMJIT_CORE_LOGGING_H_INCLUDED + +#include "../core/inst.h" +#include "../core/string.h" +#include "../core/formatter.h" + +#ifndef ASMJIT_NO_LOGGING + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_logging +//! \{ + +//! Logging interface. +//! +//! This class can be inherited and reimplemented to fit into your own logging needs. When reimplementing a logger +//! use \ref Logger::_log() method to log customize the output. +//! +//! There are two `Logger` implementations offered by AsmJit: +//! - \ref FileLogger - logs into a `FILE*`. +//! - \ref StringLogger - concatenates all logs into a \ref String. +class ASMJIT_VIRTAPI Logger { +public: + ASMJIT_BASE_CLASS(Logger) + ASMJIT_NONCOPYABLE(Logger) + + //! Format options. + FormatOptions _options; + + //! \name Construction & Destruction + //! \{ + + //! Creates a `Logger` instance. + ASMJIT_API Logger() noexcept; + //! Destroys the `Logger` instance. + ASMJIT_API virtual ~Logger() noexcept; + + //! \} + + //! \name Format Options + //! \{ + + //! Returns \ref FormatOptions of this logger. + ASMJIT_INLINE_NODEBUG FormatOptions& options() noexcept { return _options; } + //! \overload + ASMJIT_INLINE_NODEBUG const FormatOptions& options() const noexcept { return _options; } + //! Sets formatting options of this Logger to `options`. + ASMJIT_INLINE_NODEBUG void setOptions(const FormatOptions& options) noexcept { _options = options; } + //! Resets formatting options of this Logger to defaults. + ASMJIT_INLINE_NODEBUG void resetOptions() noexcept { _options.reset(); } + + //! Returns formatting flags. + ASMJIT_INLINE_NODEBUG FormatFlags flags() const noexcept { return _options.flags(); } + //! Tests whether the logger has the given `flag` enabled. + ASMJIT_INLINE_NODEBUG bool hasFlag(FormatFlags flag) const noexcept { return _options.hasFlag(flag); } + //! Sets formatting flags to `flags`. + ASMJIT_INLINE_NODEBUG void setFlags(FormatFlags flags) noexcept { _options.setFlags(flags); } + //! Enables the given formatting `flags`. + ASMJIT_INLINE_NODEBUG void addFlags(FormatFlags flags) noexcept { _options.addFlags(flags); } + //! Disables the given formatting `flags`. + ASMJIT_INLINE_NODEBUG void clearFlags(FormatFlags flags) noexcept { _options.clearFlags(flags); } + + //! Returns indentation of a given indentation `group`. + ASMJIT_INLINE_NODEBUG uint32_t indentation(FormatIndentationGroup type) const noexcept { return _options.indentation(type); } + //! Sets indentation of the given indentation `group` to `n` spaces. + ASMJIT_INLINE_NODEBUG void setIndentation(FormatIndentationGroup type, uint32_t n) noexcept { _options.setIndentation(type, n); } + //! Resets indentation of the given indentation `group` to 0 spaces. + ASMJIT_INLINE_NODEBUG void resetIndentation(FormatIndentationGroup type) noexcept { _options.resetIndentation(type); } + + //! Returns padding of a given padding `group`. + ASMJIT_INLINE_NODEBUG size_t padding(FormatPaddingGroup type) const noexcept { return _options.padding(type); } + //! Sets padding of a given padding `group` to `n`. + ASMJIT_INLINE_NODEBUG void setPadding(FormatPaddingGroup type, uint32_t n) noexcept { _options.setPadding(type, n); } + //! Resets padding of a given padding `group` to 0, which means that a default will be used. + ASMJIT_INLINE_NODEBUG void resetPadding(FormatPaddingGroup type) noexcept { _options.resetPadding(type); } + + //! \} + + //! \name Logging Interface + //! \{ + + //! Logs `str` - must be reimplemented. + //! + //! The function can accept either a null terminated string if `size` is `SIZE_MAX` or a non-null terminated + //! string of the given `size`. The function cannot assume that the data is null terminated and must handle + //! non-null terminated inputs. + ASMJIT_API virtual Error _log(const char* data, size_t size) noexcept; + + //! Logs string `str`, which is either null terminated or having size `size`. + ASMJIT_INLINE_NODEBUG Error log(const char* data, size_t size = SIZE_MAX) noexcept { return _log(data, size); } + //! Logs content of a string `str`. + ASMJIT_INLINE_NODEBUG Error log(const String& str) noexcept { return _log(str.data(), str.size()); } + + //! Formats the message by using `snprintf()` and then passes the formatted string to \ref _log(). + ASMJIT_API Error logf(const char* fmt, ...) noexcept; + + //! Formats the message by using `vsnprintf()` and then passes the formatted string to \ref _log(). + ASMJIT_API Error logv(const char* fmt, va_list ap) noexcept; + + //! \} +}; + +//! Logger that can log to a `FILE*`. +class ASMJIT_VIRTAPI FileLogger : public Logger { +public: + ASMJIT_NONCOPYABLE(FileLogger) + + FILE* _file; + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `FileLogger` that logs to `FILE*`. + ASMJIT_API FileLogger(FILE* file = nullptr) noexcept; + //! Destroys the `FileLogger`. + ASMJIT_API ~FileLogger() noexcept override; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the logging output stream or null if the logger has no output stream. + ASMJIT_INLINE_NODEBUG FILE* file() const noexcept { return _file; } + + //! Sets the logging output stream to `stream` or null. + //! + //! \note If the `file` is null the logging will be disabled. When a logger is attached to `CodeHolder` or any + //! emitter the logging API will always be called regardless of the output file. This means that if you really + //! want to disable logging at emitter level you must not attach a logger to it. + ASMJIT_INLINE_NODEBUG void setFile(FILE* file) noexcept { _file = file; } + + //! \} + + ASMJIT_API Error _log(const char* data, size_t size = SIZE_MAX) noexcept override; +}; + +//! Logger that stores everything in an internal string buffer. +class ASMJIT_VIRTAPI StringLogger : public Logger { +public: + ASMJIT_NONCOPYABLE(StringLogger) + + //! Logger data as string. + String _content; + + //! \name Construction & Destruction + //! \{ + + //! Create new `StringLogger`. + ASMJIT_API StringLogger() noexcept; + //! Destroys the `StringLogger`. + ASMJIT_API ~StringLogger() noexcept override; + + //! \} + + //! \name Logger Data Accessors + //! \{ + + //! Returns the content of the logger as \ref String. + //! + //! It can be moved, if desired. + ASMJIT_INLINE_NODEBUG String& content() noexcept { return _content; } + //! \overload + ASMJIT_INLINE_NODEBUG const String& content() const noexcept { return _content; } + + //! Returns aggregated logger data as `char*` pointer. + //! + //! The pointer is owned by `StringLogger`, it can't be modified or freed. + ASMJIT_INLINE_NODEBUG const char* data() const noexcept { return _content.data(); } + //! Returns size of the data returned by `data()`. + ASMJIT_INLINE_NODEBUG size_t dataSize() const noexcept { return _content.size(); } + + //! \} + + //! \name Logger Data Manipulation + //! \{ + + //! Clears the accumulated logger data. + ASMJIT_INLINE_NODEBUG void clear() noexcept { _content.clear(); } + + //! \} + + ASMJIT_API Error _log(const char* data, size_t size = SIZE_MAX) noexcept override; +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif + +#endif // ASMJIT_CORE_LOGGER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/operand.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/operand.h new file mode 100644 index 0000000000000000000000000000000000000000..8b025e4f4c78c4503f64b67eb45133e296237bab --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/operand.h @@ -0,0 +1,1889 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_OPERAND_H_INCLUDED +#define ASMJIT_CORE_OPERAND_H_INCLUDED + +#include "../core/archcommons.h" +#include "../core/support.h" +#include "../core/type.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_assembler +//! \{ + +//! Operand type used by \ref Operand_. +enum class OperandType : uint32_t { + //! Not an operand or not initialized. + kNone = 0, + //! Operand is a register. + kReg = 1, + //! Operand is a memory. + kMem = 2, + //! Operand is a register-list. + kRegList = 3, + //! Operand is an immediate value. + kImm = 4, + //! Operand is a label. + kLabel = 5, + + //! Maximum value of `OperandType`. + kMaxValue = kRegList +}; + +static_assert(uint32_t(OperandType::kMem) == uint32_t(OperandType::kReg) + 1, + "AsmJit requires that `OperandType::kMem` equals to `OperandType::kReg + 1`"); + +//! Register mask is a convenience typedef that describes a mask where each bit describes a physical register id +//! in the same \ref RegGroup. At the moment 32 bits are enough as AsmJit doesn't support any architecture that +//! would provide more than 32 registers for a register group. +typedef uint32_t RegMask; + +//! Register type. +//! +//! Provides a unique type that can be used to identify a register or its view. +enum class RegType : uint8_t { + //! No register - unused, invalid, multiple meanings. + kNone = 0, + + //! This is not a register type. This value is reserved for a \ref Label that's used in \ref BaseMem as a base. + //! + //! Label tag is used as a sub-type, forming a unique signature across all operand types as 0x1 is never associated + //! with any register type. This means that a memory operand's BASE register can be constructed from virtually any + //! operand (register vs. label) by just assigning its type (register type or label-tag) and operand id. + kLabelTag = 1, + + //! Universal type describing program counter (PC) or instruction pointer (IP) register, if the target architecture + //! actually exposes it as a separate register type, which most modern architectures do. + kPC = 2, + + //! 8-bit low general purpose register (X86). + kGp8Lo = 3, + //! 8-bit high general purpose register (X86). + kGp8Hi = 4, + //! 16-bit general purpose register (X86). + kGp16 = 5, + //! 32-bit general purpose register (X86|AArch32|AArch64). + kGp32 = 6, + //! 64-bit general purpose register (X86|AArch64). + kGp64 = 7, + //! 8-bit view of a vector register (AArch64). + kVec8 = 8, + //! 16-bit view of a vector register (AArch64). + kVec16 = 9, + //! 32-bit view of a vector register (AArch32|AArch64). + kVec32 = 10, + //! 64-bit view of a vector register (AArch32|AArch64). + //! + //! \note This is never used for MMX registers on X86, MMX registers have its own category. + kVec64 = 11, + //! 128-bit view of a vector register (X86|AArch32|AArch64). + kVec128 = 12, + //! 256-bit view of a vector register (X86). + kVec256 = 13, + //! 512-bit view of a vector register (X86). + kVec512 = 14, + //! 1024-bit view of a vector register (future). + kVec1024 = 15, + //! View of a vector register, which width is implementation specific (AArch64). + kVecNLen = 16, + + //! Mask register (X86). + kMask = 17, + + //! Start of architecture dependent register types. + kExtra = 18, + + // X86 Specific Register Types + // --------------------------- + + //! Instruction pointer (RIP), only addressable in \ref x86::Mem in 64-bit targets. + kX86_Rip = kPC, + //! Low GPB register (AL, BL, CL, DL, ...). + kX86_GpbLo = kGp8Lo, + //! High GPB register (AH, BH, CH, DH only). + kX86_GpbHi = kGp8Hi, + //! GPW register. + kX86_Gpw = kGp16, + //! GPD register. + kX86_Gpd = kGp32, + //! GPQ register (64-bit). + kX86_Gpq = kGp64, + //! XMM register (SSE+). + kX86_Xmm = kVec128, + //! YMM register (AVX+). + kX86_Ymm = kVec256, + //! ZMM register (AVX512+). + kX86_Zmm = kVec512, + //! K register (AVX512+). + kX86_KReg = kMask, + //! MMX register. + kX86_Mm = kExtra + 0, + //! Segment register (None, ES, CS, SS, DS, FS, GS). + kX86_SReg = kExtra + 1, + //! Control register (CR). + kX86_CReg = kExtra + 2, + //! Debug register (DR). + kX86_DReg = kExtra + 3, + //! FPU (x87) register. + kX86_St = kExtra + 4, + //! Bound register (BND). + kX86_Bnd = kExtra + 5, + //! TMM register (AMX_TILE) + kX86_Tmm = kExtra + 6, + + // ARM Specific Register Types + // --------------------------- + + //! Program pointer (PC) register (AArch64). + kARM_PC = kPC, + //! 32-bit general purpose register (R or W). + kARM_GpW = kGp32, + //! 64-bit general purpose register (X). + kARM_GpX = kGp64, + //! 8-bit view of VFP/ASIMD register (B). + kARM_VecB = kVec8, + //! 16-bit view of VFP/ASIMD register (H). + kARM_VecH = kVec16, + //! 32-bit view of VFP/ASIMD register (S). + kARM_VecS = kVec32, + //! 64-bit view of VFP/ASIMD register (D). + kARM_VecD = kVec64, + //! 128-bit view of VFP/ASIMD register (Q). + kARM_VecQ = kVec128, + //! 128-bit view of VFP/ASIMD register (V). + kARM_VecV = kVec128, + + //! Maximum value of `RegType`. + kMaxValue = 31 +}; +ASMJIT_DEFINE_ENUM_COMPARE(RegType) + +//! Register group. +//! +//! Provides a unique value that identifies groups of registers and their views. +enum class RegGroup : uint8_t { + //! General purpose register group compatible with all backends. + kGp = 0, + //! Vector register group compatible with all backends. + //! + //! Describes X86 XMM|YMM|ZMM registers ARM/AArch64 V registers. + kVec = 1, + + //! Mask register group compatible with all backends that can use masking. + kMask = 2, + //! Extra virtual group #3 that can be used by Compiler for register allocation. + kExtraVirt3 = 3, + + //! Program counter group. + kPC = 4, + + //! Extra non-virtual group that can be used by registers not managed by Compiler. + kExtraNonVirt = 5, + + // X86 Specific Register Groups + // ---------------------------- + + //! K register group (KReg) - maps to \ref RegGroup::kMask (X86, X86_64). + kX86_K = kMask, + //! MMX register group (MM) - maps to \ref RegGroup::kExtraVirt3 (X86, X86_64). + kX86_MM = kExtraVirt3, + + //! Instruction pointer (X86, X86_64). + kX86_Rip = kPC, + //! Segment register group (X86, X86_64). + kX86_SReg = kExtraNonVirt + 0, + //! CR register group (X86, X86_64). + kX86_CReg = kExtraNonVirt + 1, + //! DR register group (X86, X86_64). + kX86_DReg = kExtraNonVirt + 2, + //! FPU register group (X86, X86_64). + kX86_St = kExtraNonVirt + 3, + //! BND register group (X86, X86_64). + kX86_Bnd = kExtraNonVirt + 4, + //! TMM register group (X86, X86_64). + kX86_Tmm = kExtraNonVirt + 5, + + //! First group - only used in loops. + k0 = 0, + //! Last value of a virtual register that is managed by \ref BaseCompiler. + kMaxVirt = Globals::kNumVirtGroups - 1, + //! Maximum value of `RegGroup`. + kMaxValue = 15 +}; +ASMJIT_DEFINE_ENUM_COMPARE(RegGroup) + +typedef Support::EnumValues RegGroupVirtValues; + +//! Operand signature is a 32-bit number describing \ref Operand and some of its payload. +//! +//! In AsmJit operand signature is used to store additional payload of register, memory, and immediate operands. +//! In practice the biggest pressure on OperandSignature is from \ref BaseMem and architecture specific memory +//! operands that need to store additional payload that cannot be stored elsewhere as values of all other members +//! are fully specified by \ref BaseMem. +struct OperandSignature { + //! \name Constants + //! \{ + + enum : uint32_t { + // Operand type (3 least significant bits). + // |........|........|........|.....XXX| + kOpTypeShift = 0, + kOpTypeMask = 0x07u << kOpTypeShift, + + // Register type (5 bits). + // |........|........|........|XXXXX...| + kRegTypeShift = 3, + kRegTypeMask = 0x1Fu << kRegTypeShift, + + // Register group (4 bits). + // |........|........|....XXXX|........| + kRegGroupShift = 8, + kRegGroupMask = 0x0Fu << kRegGroupShift, + + // Memory base type (5 bits). + // |........|........|........|XXXXX...| + kMemBaseTypeShift = 3, + kMemBaseTypeMask = 0x1Fu << kMemBaseTypeShift, + + // Memory index type (5 bits). + // |........|........|...XXXXX|........| + kMemIndexTypeShift = 8, + kMemIndexTypeMask = 0x1Fu << kMemIndexTypeShift, + + // Memory base+index combined (10 bits). + // |........|........|...XXXXX|XXXXX...| + kMemBaseIndexShift = 3, + kMemBaseIndexMask = 0x3FFu << kMemBaseIndexShift, + + // This memory operand represents a home-slot or stack (Compiler) (1 bit). + // |........|........|..X.....|........| + kMemRegHomeShift = 13, + kMemRegHomeFlag = 0x01u << kMemRegHomeShift, + + // Immediate type (1 bit). + // |........|........|........|....X...| + kImmTypeShift = 3, + kImmTypeMask = 0x01u << kImmTypeShift, + + // Predicate used by either registers or immediate values (4 bits). + // |........|XXXX....|........|........| + kPredicateShift = 20, + kPredicateMask = 0x0Fu << kPredicateShift, + + // Operand size (8 most significant bits). + // |XXXXXXXX|........|........|........| + kSizeShift = 24, + kSizeMask = 0xFFu << kSizeShift + }; + + //! \} + + //! \name Members + //! \{ + + uint32_t _bits; + + //! \} + + //! \name Overloaded Operators + //! + //! Overloaded operators make `OperandSignature` behave like regular integer. + //! + //! \{ + + ASMJIT_INLINE_NODEBUG constexpr bool operator!() const noexcept { return _bits == 0; } + ASMJIT_INLINE_NODEBUG constexpr explicit operator bool() const noexcept { return _bits != 0; } + + ASMJIT_INLINE_NODEBUG OperandSignature& operator|=(uint32_t x) noexcept { _bits |= x; return *this; } + ASMJIT_INLINE_NODEBUG OperandSignature& operator&=(uint32_t x) noexcept { _bits &= x; return *this; } + ASMJIT_INLINE_NODEBUG OperandSignature& operator^=(uint32_t x) noexcept { _bits ^= x; return *this; } + + ASMJIT_INLINE_NODEBUG OperandSignature& operator|=(const OperandSignature& other) noexcept { return operator|=(other._bits); } + ASMJIT_INLINE_NODEBUG OperandSignature& operator&=(const OperandSignature& other) noexcept { return operator&=(other._bits); } + ASMJIT_INLINE_NODEBUG OperandSignature& operator^=(const OperandSignature& other) noexcept { return operator^=(other._bits); } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator~() const noexcept { return OperandSignature{~_bits}; } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator|(uint32_t x) const noexcept { return OperandSignature{_bits | x}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator&(uint32_t x) const noexcept { return OperandSignature{_bits & x}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator^(uint32_t x) const noexcept { return OperandSignature{_bits ^ x}; } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator|(const OperandSignature& other) const noexcept { return OperandSignature{_bits | other._bits}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator&(const OperandSignature& other) const noexcept { return OperandSignature{_bits & other._bits}; } + ASMJIT_INLINE_NODEBUG constexpr OperandSignature operator^(const OperandSignature& other) const noexcept { return OperandSignature{_bits ^ other._bits}; } + + ASMJIT_INLINE_NODEBUG constexpr bool operator==(uint32_t x) const noexcept { return _bits == x; } + ASMJIT_INLINE_NODEBUG constexpr bool operator!=(uint32_t x) const noexcept { return _bits != x; } + + ASMJIT_INLINE_NODEBUG constexpr bool operator==(const OperandSignature& other) const noexcept { return _bits == other._bits; } + ASMJIT_INLINE_NODEBUG constexpr bool operator!=(const OperandSignature& other) const noexcept { return _bits != other._bits; } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG void reset() noexcept { _bits = 0; } + + ASMJIT_INLINE_NODEBUG constexpr uint32_t bits() const noexcept { return _bits; } + ASMJIT_INLINE_NODEBUG void setBits(uint32_t bits) noexcept { _bits = bits; } + + template + ASMJIT_INLINE_NODEBUG constexpr bool hasField() const noexcept { + return (_bits & kFieldMask) != 0; + } + + template + ASMJIT_INLINE_NODEBUG constexpr bool hasField(uint32_t value) const noexcept { + return (_bits & kFieldMask) != value << Support::ConstCTZ::value; + } + + template + ASMJIT_INLINE_NODEBUG constexpr uint32_t getField() const noexcept { + return (_bits >> Support::ConstCTZ::value) & (kFieldMask >> Support::ConstCTZ::value); + } + + template + ASMJIT_INLINE_NODEBUG void setField(uint32_t value) noexcept { + ASMJIT_ASSERT(((value << Support::ConstCTZ::value) & ~kFieldMask) == 0); + _bits = (_bits & ~kFieldMask) | (value << Support::ConstCTZ::value); + } + + ASMJIT_INLINE_NODEBUG constexpr OperandSignature subset(uint32_t mask) const noexcept { return OperandSignature{_bits & mask}; } + + template::value> + ASMJIT_INLINE_NODEBUG constexpr OperandSignature replacedValue(uint32_t value) const noexcept { return OperandSignature{(_bits & ~kFieldMask) | (value << kFieldShift)}; } + + template + ASMJIT_INLINE_NODEBUG constexpr bool matchesSignature(const OperandSignature& signature) const noexcept { + return (_bits & kFieldMask) == signature._bits; + } + + template + ASMJIT_INLINE_NODEBUG constexpr bool matchesFields(uint32_t bits) const noexcept { + return (_bits & kFieldMask) == bits; + } + + template + ASMJIT_INLINE_NODEBUG constexpr bool matchesFields(const OperandSignature& fields) const noexcept { + return (_bits & kFieldMask) == fields._bits; + } + + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return _bits != 0; } + + ASMJIT_INLINE_NODEBUG constexpr OperandType opType() const noexcept { return (OperandType)getField(); } + + ASMJIT_INLINE_NODEBUG constexpr RegType regType() const noexcept { return (RegType)getField(); } + ASMJIT_INLINE_NODEBUG constexpr RegGroup regGroup() const noexcept { return (RegGroup)getField(); } + + ASMJIT_INLINE_NODEBUG constexpr RegType memBaseType() const noexcept { return (RegType)getField(); } + ASMJIT_INLINE_NODEBUG constexpr RegType memIndexType() const noexcept { return (RegType)getField(); } + + ASMJIT_INLINE_NODEBUG constexpr uint32_t predicate() const noexcept { return getField(); } + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return getField(); } + + ASMJIT_INLINE_NODEBUG void setOpType(OperandType opType) noexcept { setField(uint32_t(opType)); } + ASMJIT_INLINE_NODEBUG void setRegType(RegType regType) noexcept { setField(uint32_t(regType)); } + ASMJIT_INLINE_NODEBUG void setRegGroup(RegGroup regGroup) noexcept { setField(uint32_t(regGroup)); } + + ASMJIT_INLINE_NODEBUG void setMemBaseType(RegType baseType) noexcept { setField(uint32_t(baseType)); } + ASMJIT_INLINE_NODEBUG void setMemIndexType(RegType indexType) noexcept { setField(uint32_t(indexType)); } + + ASMJIT_INLINE_NODEBUG void setPredicate(uint32_t predicate) noexcept { setField(predicate); } + ASMJIT_INLINE_NODEBUG void setSize(uint32_t size) noexcept { setField(size); } + + //! \} + + //! \name Static Constructors + //! \{ + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromBits(uint32_t bits) noexcept { + return OperandSignature{bits}; + } + + template + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromValue(const T& value) noexcept { + return OperandSignature{uint32_t(value) << Support::ConstCTZ::value}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromOpType(OperandType opType) noexcept { + return OperandSignature{uint32_t(opType) << kOpTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromRegType(RegType regType) noexcept { + return OperandSignature{uint32_t(regType) << kRegTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromRegGroup(RegGroup regGroup) noexcept { + return OperandSignature{uint32_t(regGroup) << kRegGroupShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromRegTypeAndGroup(RegType regType, RegGroup regGroup) noexcept { + return fromRegType(regType) | fromRegGroup(regGroup); + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromMemBaseType(RegType baseType) noexcept { + return OperandSignature{uint32_t(baseType) << kMemBaseTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromMemIndexType(RegType indexType) noexcept { + return OperandSignature{uint32_t(indexType) << kMemIndexTypeShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromPredicate(uint32_t predicate) noexcept { + return OperandSignature{predicate << kPredicateShift}; + } + + static ASMJIT_INLINE_NODEBUG constexpr OperandSignature fromSize(uint32_t size) noexcept { + return OperandSignature{size << kSizeShift}; + } + + //! \} +}; + +//! Base class representing an operand in AsmJit (non-default constructed version). +//! +//! Contains no initialization code and can be used safely to define an array of operands that won't be initialized. +//! This is a \ref Operand base structure designed to be statically initialized, static const, or to be used by user +//! code to define an array of operands without having them default initialized at construction time. +//! +//! The key difference between \ref Operand and \ref Operand_ is: +//! +//! ``` +//! Operand_ xArray[10]; // Not initialized, contains garbage. +//! Operand_ yArray[10] {}; // All operands initialized to none explicitly (zero initialized). +//! Operand yArray[10]; // All operands initialized to none implicitly (zero initialized). +//! ``` +struct Operand_ { + //! \name Types + //! \{ + + typedef OperandSignature Signature; + + //! \} + + //! \name Constants + //! \{ + + // Indexes to `_data` array. + enum DataIndex : uint32_t { + kDataMemIndexId = 0, + kDataMemOffsetLo = 1, + + kDataImmValueLo = ASMJIT_ARCH_LE ? 0 : 1, + kDataImmValueHi = ASMJIT_ARCH_LE ? 1 : 0 + }; + + //! Constants useful for VirtId <-> Index translation. + enum VirtIdConstants : uint32_t { + //! Minimum valid packed-id. + kVirtIdMin = 256, + //! Maximum valid packed-id, excludes Globals::kInvalidId. + kVirtIdMax = Globals::kInvalidId - 1, + //! Count of valid packed-ids. + kVirtIdCount = uint32_t(kVirtIdMax - kVirtIdMin + 1) + }; + + //! \} + + //! \name Members + //! \{ + + //! Provides operand type and additional payload. + Signature _signature; + //! Either base id as used by memory operand or any id as used by others. + uint32_t _baseId; + + //! Data specific to the operand type. + //! + //! The reason we don't use union is that we have `constexpr` constructors that construct operands and other + //!`constexpr` functions that return whether another Operand or something else. These cannot generally work with + //! unions so we also cannot use `union` if we want to be standard compliant. + uint32_t _data[2]; + + //! \} + + //! Tests whether the given `id` is a valid virtual register id. Since AsmJit supports both physical and virtual + //! registers it must be able to distinguish between these two. The idea is that physical registers are always + //! limited in size, so virtual identifiers start from `kVirtIdMin` and end at `kVirtIdMax`. + static ASMJIT_INLINE_NODEBUG bool isVirtId(uint32_t id) noexcept { return id - kVirtIdMin < uint32_t(kVirtIdCount); } + //! Converts a real-id into a packed-id that can be stored in Operand. + static ASMJIT_INLINE_NODEBUG uint32_t indexToVirtId(uint32_t id) noexcept { return id + kVirtIdMin; } + //! Converts a packed-id back to real-id. + static ASMJIT_INLINE_NODEBUG uint32_t virtIdToIndex(uint32_t id) noexcept { return id - kVirtIdMin; } + + //! \name Construction & Destruction + //! \{ + + //! \cond INTERNAL + //! Initializes a `BaseReg` operand from `signature` and register `id`. + ASMJIT_INLINE_NODEBUG void _initReg(const Signature& signature, uint32_t id) noexcept { + _signature = signature; + _baseId = id; + _data[0] = 0; + _data[1] = 0; + } + //! \endcond + + //! Initializes the operand from `other` operand (used by operator overloads). + ASMJIT_INLINE_NODEBUG void copyFrom(const Operand_& other) noexcept { + _signature._bits = other._signature._bits; + _baseId = other._baseId; + _data[0] = other._data[0]; + _data[1] = other._data[1]; + } + + //! Resets the `Operand` to none. + //! + //! None operand is defined the following way: + //! - Its signature is zero (OperandType::kNone, and the rest zero as well). + //! - Its id is `0`. + //! - The reserved8_4 field is set to `0`. + //! - The reserved12_4 field is set to zero. + //! + //! In other words, reset operands have all members set to zero. Reset operand must match the Operand state + //! right after its construction. Alternatively, if you have an array of operands, you can simply use `memset()`. + //! + //! ``` + //! using namespace asmjit; + //! + //! Operand a; + //! Operand b; + //! assert(a == b); + //! + //! b = x86::eax; + //! assert(a != b); + //! + //! b.reset(); + //! assert(a == b); + //! + //! memset(&b, 0, sizeof(Operand)); + //! assert(a == b); + //! ``` + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _signature.reset(); + _baseId = 0; + _data[0] = 0; + _data[1] = 0; + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Tests whether this operand is the same as `other`. + ASMJIT_INLINE_NODEBUG constexpr bool operator==(const Operand_& other) const noexcept { return equals(other); } + //! Tests whether this operand is not the same as `other`. + ASMJIT_INLINE_NODEBUG constexpr bool operator!=(const Operand_& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Cast + //! \{ + + //! Casts this operand to `T` type. + template + ASMJIT_INLINE_NODEBUG T& as() noexcept { return static_cast(*this); } + + //! Casts this operand to `T` type (const). + template + ASMJIT_INLINE_NODEBUG const T& as() const noexcept { return static_cast(*this); } + + //! \} + + //! \name Equality + //! \{ + + //! Tests whether the operand is 100% equal to `other` operand. + //! + //! \note This basically performs a binary comparison, if aby bit is + //! different the operands are not equal. + ASMJIT_INLINE_NODEBUG constexpr bool equals(const Operand_& other) const noexcept { + return bool(unsigned(_signature == other._signature) & + unsigned(_baseId == other._baseId ) & + unsigned(_data[0] == other._data[0] ) & + unsigned(_data[1] == other._data[1] )); + } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the operand's signature matches the signature of the `other` operand. + ASMJIT_INLINE_NODEBUG constexpr bool hasSignature(const Operand_& other) const noexcept { return _signature == other._signature; } + //! Tests whether the operand's signature matches the given signature `sign`. + ASMJIT_INLINE_NODEBUG constexpr bool hasSignature(const Signature& other) const noexcept { return _signature == other; } + + //! Returns operand signature as unsigned 32-bit integer. + //! + //! Signature is first 4 bytes of the operand data. It's used mostly for operand checking as it's + //! much faster to check packed 4 bytes at once than having to check these bytes individually. + ASMJIT_INLINE_NODEBUG constexpr Signature signature() const noexcept { return _signature; } + + //! Sets the operand signature, see `signature()`. + //! + //! \note Improper use of `setSignature()` can lead to hard-to-debug errors. + ASMJIT_INLINE_NODEBUG void setSignature(const Signature& signature) noexcept { _signature = signature; } + //! \overload + ASMJIT_INLINE_NODEBUG void setSignature(uint32_t signature) noexcept { _signature._bits = signature; } + + //! Returns the type of the operand, see `OpType`. + ASMJIT_INLINE_NODEBUG constexpr OperandType opType() const noexcept { return _signature.opType(); } + //! Tests whether the operand is none (`OperandType::kNone`). + ASMJIT_INLINE_NODEBUG constexpr bool isNone() const noexcept { return _signature == Signature::fromBits(0); } + //! Tests whether the operand is a register (`OperandType::kReg`). + ASMJIT_INLINE_NODEBUG constexpr bool isReg() const noexcept { return opType() == OperandType::kReg; } + //! Tests whether the operand is a register-list. + //! + //! \note Register-list is currently only used by 32-bit ARM architecture. + ASMJIT_INLINE_NODEBUG constexpr bool isRegList() const noexcept { return opType() == OperandType::kRegList; } + //! Tests whether the operand is a memory location (`OperandType::kMem`). + ASMJIT_INLINE_NODEBUG constexpr bool isMem() const noexcept { return opType() == OperandType::kMem; } + //! Tests whether the operand is an immediate (`OperandType::kImm`). + ASMJIT_INLINE_NODEBUG constexpr bool isImm() const noexcept { return opType() == OperandType::kImm; } + //! Tests whether the operand is a label (`OperandType::kLabel`). + ASMJIT_INLINE_NODEBUG constexpr bool isLabel() const noexcept { return opType() == OperandType::kLabel; } + + //! Tests whether the operand is a physical register. + ASMJIT_INLINE_NODEBUG constexpr bool isPhysReg() const noexcept { return isReg() && _baseId < 0xFFu; } + //! Tests whether the operand is a virtual register. + ASMJIT_INLINE_NODEBUG constexpr bool isVirtReg() const noexcept { return isReg() && _baseId > 0xFFu; } + + //! Returns the operand id. + //! + //! The value returned should be interpreted accordingly to the operand type: + //! * None - Should be `0`. + //! * Reg - Physical or virtual register id. + //! * Mem - Multiple meanings - BASE address (register or label id), or high value of a 64-bit absolute address. + //! * Imm - Should be `0`. + //! * Label - Label id if it was created by using `newLabel()` or `Globals::kInvalidId` if the label is invalid or + //! not initialized. + ASMJIT_INLINE_NODEBUG constexpr uint32_t id() const noexcept { return _baseId; } + + //! Tests whether the operand is a register matching the given register `type`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType type) const noexcept { + return _signature.subset(Signature::kOpTypeMask | Signature::kRegTypeMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegType(type)); + } + + //! Tests whether the operand is a register of the provided register group `regGroup`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegGroup regGroup) const noexcept { + return _signature.subset(Signature::kOpTypeMask | Signature::kRegGroupMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(regGroup)); + } + + //! Tests whether the operand is register and of register type `regType` and `regId`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType regType, uint32_t regId) const noexcept { return isReg(regType) && _baseId == regId; } + //! Tests whether the operand is register and of register group `regGroup` and `regId`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegGroup regGroup, uint32_t regId) const noexcept { return isReg(regGroup) && _baseId == regId; } + + //! Tests whether the register is a general purpose register (any size). + ASMJIT_INLINE_NODEBUG constexpr bool isGp() const noexcept { return isReg(RegGroup::kGp); } + //! Tests whether the register is a 32-bit general purpose register. + ASMJIT_INLINE_NODEBUG constexpr bool isGp32() const noexcept { return isReg(RegType::kGp32); } + //! Tests whether the register is a 64-bit general purpose register. + ASMJIT_INLINE_NODEBUG constexpr bool isGp64() const noexcept { return isReg(RegType::kGp64); } + + //! Tests whether the register is a vector register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isVec() const noexcept { return isReg(RegGroup::kVec); } + //! Tests whether the register is an 8-bit vector register or view (AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec8() const noexcept { return isReg(RegType::kVec8); } + //! Tests whether the register is a 16-bit vector register or view (AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec16() const noexcept { return isReg(RegType::kVec16); } + //! Tests whether the register is a 32-bit vector register or view (AArch32, AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec32() const noexcept { return isReg(RegType::kVec32); } + //! Tests whether the register is a 64-bit vector register or view (AArch32, AArch64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec64() const noexcept { return isReg(RegType::kVec64); } + //! Tests whether the register is a 128-bit vector register or view (AArch32, AArch64, X86, X86_64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec128() const noexcept { return isReg(RegType::kVec128); } + //! Tests whether the register is a 256-bit vector register or view (X86, X86_64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec256() const noexcept { return isReg(RegType::kVec256); } + //! Tests whether the register is a 512-bit vector register or view (X86, X86_64). + ASMJIT_INLINE_NODEBUG constexpr bool isVec512() const noexcept { return isReg(RegType::kVec512); } + + //! Tests whether the register is a mask register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isMask() const noexcept { return isReg(RegGroup::kMask); } + + //! Tests whether the operand is a register matching the given register `type`. + ASMJIT_INLINE_NODEBUG constexpr bool isRegList(RegType type) const noexcept { + return _signature.subset(Signature::kOpTypeMask | Signature::kRegTypeMask) == (Signature::fromOpType(OperandType::kRegList) | Signature::fromRegType(type)); + } + + //! Tests whether the operand is a register or memory. + //! + //! \note This is useful on X86 and X86_64 architectures as many instructions support Reg/Mem operand combination. + //! So if the user code works with just \ref Operand, it's possible to check whether the operand is either a register + //! or memory location with a single check. + ASMJIT_INLINE_NODEBUG constexpr bool isRegOrMem() const noexcept { + return Support::isBetween(uint32_t(opType()), uint32_t(OperandType::kReg), uint32_t(OperandType::kMem)); + } + + //! Tests whether the operand is a register, register-list, or memory. + //! + //! \note This is useful on 32-bit ARM architecture to check whether an operand references a register. It can be + //! used in other architectures too, but it would work identically to \ref isRegOrMem() as other architectures + //! don't provide register lists. + ASMJIT_INLINE_NODEBUG constexpr bool isRegOrRegListOrMem() const noexcept { + return Support::isBetween(uint32_t(opType()), uint32_t(OperandType::kReg), uint32_t(OperandType::kRegList)); + } + + //! \} + + //! \name Accessors (X86 Specific) + //! \{ + + //! Returns a size of a register or an X86 memory operand. + //! + //! At the moment only X86 and X86_64 memory operands have a size - other memory operands can use bits that represent + //! size as an additional payload. This means that memory size is architecture specific and should be accessed via + //! \ref x86::Mem::size(). Sometimes when the user knows that the operand is either a register or memory operand this + //! function can be helpful as it avoids casting. + ASMJIT_INLINE_NODEBUG constexpr uint32_t x86RmSize() const noexcept { + return _signature.size(); + } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("hasSize() is no longer portable - use x86RmSize() instead, if your target is X86/X86_64") + ASMJIT_INLINE_NODEBUG constexpr bool hasSize() const noexcept { return x86RmSize() != 0u; } + + ASMJIT_DEPRECATED("hasSize() is no longer portable - use x86RmSize() instead, if your target is X86/X86_64") + ASMJIT_INLINE_NODEBUG constexpr bool hasSize(uint32_t s) const noexcept { return x86RmSize() == s; } + + ASMJIT_DEPRECATED("size() is no longer portable - use x86RmSize() instead, if your target is X86/X86_64") + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } +#endif + + //! \} +}; + +//! Base class representing an operand in AsmJit (default constructed version). +class Operand : public Operand_ { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates `kOpNone` operand having all members initialized to zero. + ASMJIT_INLINE_NODEBUG constexpr Operand() noexcept + : Operand_{ Signature::fromOpType(OperandType::kNone), 0u, { 0u, 0u }} {} + + //! Creates a cloned `other` operand. + ASMJIT_INLINE_NODEBUG constexpr Operand(const Operand& other) noexcept = default; + + //! Creates a cloned `other` operand. + ASMJIT_INLINE_NODEBUG constexpr explicit Operand(const Operand_& other) + : Operand_(other) {} + + //! Creates an operand initialized to raw `[u0, u1, u2, u3]` values. + ASMJIT_INLINE_NODEBUG constexpr Operand(Globals::Init_, const Signature& u0, uint32_t u1, uint32_t u2, uint32_t u3) noexcept + : Operand_{{u0._bits}, u1, {u2, u3}} {} + + //! Creates an uninitialized operand (dangerous). + ASMJIT_INLINE_NODEBUG explicit Operand(Globals::NoInit_) noexcept {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Operand& operator=(const Operand& other) noexcept = default; + ASMJIT_INLINE_NODEBUG Operand& operator=(const Operand_& other) noexcept { return operator=(static_cast(other)); } + + //! \} + + //! \name Clone + //! \{ + + //! Clones this operand and returns its copy. + ASMJIT_INLINE_NODEBUG constexpr Operand clone() const noexcept { return Operand(*this); } + + //! \} +}; + +static_assert(sizeof(Operand) == 16, "asmjit::Operand must be exactly 16 bytes long"); + +//! Label (jump target or data location). +//! +//! Label represents a location in code typically used as a jump target, but may be also a reference to some data or +//! a static variable. Label has to be explicitly created by BaseEmitter. +//! +//! Example of using labels: +//! +//! ``` +//! // Create some emitter (for example x86::Assembler). +//! x86::Assembler a; +//! +//! // Create Label instance. +//! Label L1 = a.newLabel(); +//! +//! // ... your code ... +//! +//! // Using label. +//! a.jump(L1); +//! +//! // ... your code ... +//! +//! // Bind label to the current position, see `BaseEmitter::bind()`. +//! a.bind(L1); +//! ``` +class Label : public Operand { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates a label operand without ID (you must set the ID to make it valid). + ASMJIT_INLINE_NODEBUG constexpr Label() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kLabel), Globals::kInvalidId, 0, 0) {} + + //! Creates a cloned label operand of `other`. + ASMJIT_INLINE_NODEBUG constexpr Label(const Label& other) noexcept + : Operand(other) {} + + //! Creates a label operand of the given `id`. + ASMJIT_INLINE_NODEBUG constexpr explicit Label(uint32_t id) noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kLabel), id, 0, 0) {} + + ASMJIT_INLINE_NODEBUG explicit Label(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! Resets the label, will reset all properties and set its ID to `Globals::kInvalidId`. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _signature = Signature::fromOpType(OperandType::kLabel); + _baseId = Globals::kInvalidId; + _data[0] = 0; + _data[1] = 0; + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Label& operator=(const Label& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the label was created by CodeHolder and/or an attached emitter. + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return _baseId != Globals::kInvalidId; } + //! Sets the label `id`. + ASMJIT_INLINE_NODEBUG void setId(uint32_t id) noexcept { _baseId = id; } + + //! \} +}; + +//! \cond INTERNAL +//! Default register traits. +struct BaseRegTraits { + enum : uint32_t { + //! \ref TypeId representing this register type, could be \ref TypeId::kVoid if such type doesn't exist. + kTypeId = uint32_t(TypeId::kVoid), + //! RegType is not valid by default. + kValid = 0, + + //! Zero type by default (defaults to None). + kType = uint32_t(RegType::kNone), + //! Zero group by default (defaults to GP). + kGroup = uint32_t(RegGroup::kGp), + //! No size by default. + kSize = 0, + + //! Empty signature by default (not even having operand type set to register). + kSignature = 0 + }; +}; +//! \endcond + +//! Physical or virtual register operand (base). +class BaseReg : public Operand { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + //! None or any register (mostly internal). + kIdBad = 0xFFu, + + kBaseSignatureMask = + Signature::kOpTypeMask | + Signature::kRegTypeMask | + Signature::kRegGroupMask | + Signature::kSizeMask, + + kTypeNone = uint32_t(RegType::kNone), + kSignature = Signature::fromOpType(OperandType::kReg).bits() + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a dummy register operand. + ASMJIT_INLINE_NODEBUG constexpr BaseReg() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kReg), kIdBad, 0, 0) {} + + //! Creates a new register operand which is the same as `other` . + ASMJIT_INLINE_NODEBUG constexpr BaseReg(const BaseReg& other) noexcept + : Operand(other) {} + + //! Creates a new register operand compatible with `other`, but with a different `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseReg(const BaseReg& other, uint32_t id) noexcept + : Operand(Globals::Init, other._signature, id, 0, 0) {} + + //! Creates a register initialized to the given `signature` and `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseReg(const Signature& signature, uint32_t id) noexcept + : Operand(Globals::Init, signature, id, 0, 0) {} + + ASMJIT_INLINE_NODEBUG explicit BaseReg(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG BaseReg& operator=(const BaseReg& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns base signature of the register associated with each register type. + //! + //! Base signature only contains the operand type, register type, register group, and register size. It doesn't + //! contain element type, predicate, or other architecture-specific data. Base signature is a signature that is + //! provided by architecture-specific `RegTraits`, like \ref x86::RegTraits. + ASMJIT_INLINE_NODEBUG constexpr OperandSignature baseSignature() const noexcept { return _signature & kBaseSignatureMask; } + + //! Tests whether the operand's base signature matches the given signature `sign`. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseSignature(uint32_t signature) const noexcept { return baseSignature() == signature; } + //! Tests whether the operand's base signature matches the given signature `sign`. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseSignature(const OperandSignature& signature) const noexcept { return baseSignature() == signature; } + //! Tests whether the operand's base signature matches the base signature of the `other` operand. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseSignature(const BaseReg& other) const noexcept { return baseSignature() == other.baseSignature(); } + + //! Tests whether this register is the same as `other`. + //! + //! This is just an optimization. Registers by default only use the first 8 bytes of Operand data, so this method + //! takes advantage of this knowledge and only compares these 8 bytes. If both operands were created correctly + //! both \ref equals() and \ref isSame() should give the same answer, however, if any of these two contains garbage + //! or other metadata in the upper 8 bytes then \ref isSame() may return `true` in cases in which \ref equals() + //! returns false. + ASMJIT_INLINE_NODEBUG constexpr bool isSame(const BaseReg& other) const noexcept { + return (_signature == other._signature) & (_baseId == other._baseId); + } + + //! Tests whether the register is valid (either virtual or physical). + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return bool(unsigned(_signature != 0) & unsigned(_baseId != kIdBad)); } + + //! Tests whether this is a physical register. + ASMJIT_INLINE_NODEBUG constexpr bool isPhysReg() const noexcept { return _baseId < kIdBad; } + //! Tests whether this is a virtual register. + ASMJIT_INLINE_NODEBUG constexpr bool isVirtReg() const noexcept { return _baseId > kIdBad; } + + //! Tests whether the register type matches `type` - same as `isReg(type)`, provided for convenience. + ASMJIT_INLINE_NODEBUG constexpr bool isType(RegType type) const noexcept { return _signature.subset(Signature::kRegTypeMask) == Signature::fromRegType(type); } + //! Tests whether the register group matches `group`. + ASMJIT_INLINE_NODEBUG constexpr bool isGroup(RegGroup group) const noexcept { return _signature.subset(Signature::kRegGroupMask) == Signature::fromRegGroup(group); } + + //! Tests whether the register is a general purpose register (any size). + ASMJIT_INLINE_NODEBUG constexpr bool isGp() const noexcept { return isGroup(RegGroup::kGp); } + //! Tests whether the register is a vector register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isVec() const noexcept { return isGroup(RegGroup::kVec); } + //! Tests whether the register is a mask register of any size. + ASMJIT_INLINE_NODEBUG constexpr bool isMask() const noexcept { return isGroup(RegGroup::kMask); } + + using Operand_::isReg; + + //! Same as `isType()`, provided for convenience. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType rType) const noexcept { return isType(rType); } + //! Tests whether the register type matches `type` and register id matches `id`. + ASMJIT_INLINE_NODEBUG constexpr bool isReg(RegType rType, uint32_t id) const noexcept { return isType(rType) && this->id() == id; } + + //! Returns the register type. + ASMJIT_INLINE_NODEBUG constexpr RegType type() const noexcept { return _signature.regType(); } + //! Returns the register group. + ASMJIT_INLINE_NODEBUG constexpr RegGroup group() const noexcept { return _signature.regGroup(); } + + //! Tests whether the register specifies a size (i.e. the size is not zero). + ASMJIT_INLINE_NODEBUG constexpr bool hasSize() const noexcept { return _signature.hasField(); } + //! Tests whether the register size matches size `s`. + ASMJIT_INLINE_NODEBUG constexpr bool hasSize(uint32_t s) const noexcept { return size() == s; } + + //! Returns the size of the register in bytes. If the register size depends on architecture (like `x86::CReg` and + //! `x86::DReg`) the size returned should be the greatest possible (so it should return 64-bit size in such case). + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } + + //! Returns operation predicate of the register (ARM/AArch64). + //! + //! The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the register. + ASMJIT_INLINE_NODEBUG constexpr uint32_t predicate() const noexcept { return _signature.getField(); } + + //! Sets operation predicate of the register to `predicate` (ARM/AArch64). + //! + //! The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the register. + ASMJIT_INLINE_NODEBUG void setPredicate(uint32_t predicate) noexcept { _signature.setField(predicate); } + + //! Resets shift operation type of the register to the default value (ARM/AArch64). + ASMJIT_INLINE_NODEBUG void resetPredicate() noexcept { _signature.setField(0); } + + //! Clones the register operand. + ASMJIT_INLINE_NODEBUG constexpr BaseReg clone() const noexcept { return BaseReg(*this); } + + //! Casts this register to `RegT` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegT cloneAs() const noexcept { return RegT(Signature(RegT::kSignature), id()); } + + //! Casts this register to `other` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegT cloneAs(const RegT& other) const noexcept { return RegT(other.signature(), id()); } + + //! Sets the register id to `id`. + ASMJIT_INLINE_NODEBUG void setId(uint32_t id) noexcept { _baseId = id; } + + //! Sets a 32-bit operand signature based on traits of `RegT`. + template + ASMJIT_INLINE_NODEBUG void setSignatureT() noexcept { _signature = RegT::kSignature; } + + //! Sets the register `signature` and `id`. + ASMJIT_INLINE_NODEBUG void setSignatureAndId(const OperandSignature& signature, uint32_t id) noexcept { + _signature = signature; + _baseId = id; + } + + //! \} + + //! \name Static Functions + //! \{ + + //! Tests whether the `op` operand is a general purpose register. + static ASMJIT_INLINE_NODEBUG bool isGp(const Operand_& op) noexcept { + // Check operand type and register group. Not interested in register type and size. + return op.signature().subset(Signature::kOpTypeMask | Signature::kRegGroupMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(RegGroup::kGp)); + } + + //! Tests whether the `op` operand is a vector register. + static ASMJIT_INLINE_NODEBUG bool isVec(const Operand_& op) noexcept { + // Check operand type and register group. Not interested in register type and size. + return op.signature().subset(Signature::kOpTypeMask | Signature::kRegGroupMask) == (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(RegGroup::kVec)); + } + + //! Tests whether the `op` is a general purpose register of the given `id`. + static ASMJIT_INLINE_NODEBUG bool isGp(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isGp(op)) & unsigned(op.id() == id)); } + //! Tests whether the `op` is a vector register of the given `id`. + static ASMJIT_INLINE_NODEBUG bool isVec(const Operand_& op, uint32_t id) noexcept { return bool(unsigned(isVec(op)) & unsigned(op.id() == id)); } + + //! \} +}; + +//! RegOnly is 8-byte version of `BaseReg` that allows to store either register or nothing. +//! +//! It's designed to decrease the space consumed by an extra "operand" in \ref BaseEmitter and \ref InstNode. +struct RegOnly { + //! \name Types + //! \{ + + typedef OperandSignature Signature; + + //! \} + + //! Operand signature - only \ref OperandType::kNone and \ref OperandType::kReg are supported. + Signature _signature; + //! Physical or virtual register id. + uint32_t _id; + + //! \name Construction & Destruction + //! \{ + + //! Initializes the `RegOnly` instance to hold register `signature` and `id`. + ASMJIT_INLINE_NODEBUG void init(const OperandSignature& signature, uint32_t id) noexcept { + _signature = signature; + _id = id; + } + + ASMJIT_INLINE_NODEBUG void init(const BaseReg& reg) noexcept { init(reg.signature(), reg.id()); } + ASMJIT_INLINE_NODEBUG void init(const RegOnly& reg) noexcept { init(reg.signature(), reg.id()); } + + //! Resets the `RegOnly` members to zeros (none). + ASMJIT_INLINE_NODEBUG void reset() noexcept { init(Signature::fromBits(0), 0); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether this ExtraReg is none (same as calling `Operand_::isNone()`). + ASMJIT_INLINE_NODEBUG constexpr bool isNone() const noexcept { return _signature == 0; } + //! Tests whether the register is valid (either virtual or physical). + ASMJIT_INLINE_NODEBUG constexpr bool isReg() const noexcept { return _signature != 0; } + + //! Tests whether this is a physical register. + ASMJIT_INLINE_NODEBUG constexpr bool isPhysReg() const noexcept { return _id < BaseReg::kIdBad; } + //! Tests whether this is a virtual register (used by `BaseCompiler`). + ASMJIT_INLINE_NODEBUG constexpr bool isVirtReg() const noexcept { return _id > BaseReg::kIdBad; } + + //! Returns the register signature or 0 if no register is assigned. + ASMJIT_INLINE_NODEBUG constexpr OperandSignature signature() const noexcept { return _signature; } + //! Returns the register id. + //! + //! \note Always check whether the register is assigned before using the returned identifier as + //! non-assigned `RegOnly` instance would return zero id, which is still a valid register id. + ASMJIT_INLINE_NODEBUG constexpr uint32_t id() const noexcept { return _id; } + + //! Sets the register id. + ASMJIT_INLINE_NODEBUG void setId(uint32_t id) noexcept { _id = id; } + + //! Returns the register type. + ASMJIT_INLINE_NODEBUG constexpr RegType type() const noexcept { return _signature.regType(); } + //! Returns the register group. + ASMJIT_INLINE_NODEBUG constexpr RegGroup group() const noexcept { return _signature.regGroup(); } + + //! \} + + //! \name Utilities + //! \{ + + //! Converts this ExtraReg to a real `RegT` operand. + template + ASMJIT_INLINE_NODEBUG constexpr RegT toReg() const noexcept { return RegT(_signature, _id); } + + //! \} +}; + +//! \cond INTERNAL +//! Adds a template specialization for `REG_TYPE` into the local `RegTraits`. +#define ASMJIT_DEFINE_REG_TRAITS(REG_TYPE, GROUP, SIZE, TYPE_ID) \ +template<> \ +struct RegTraits { \ + static constexpr uint32_t kValid = 1; \ + static constexpr RegType kType = REG_TYPE; \ + static constexpr RegGroup kGroup = GROUP; \ + static constexpr uint32_t kSize = SIZE; \ + static constexpr TypeId kTypeId = TYPE_ID; \ + \ + static constexpr uint32_t kSignature = \ + (OperandSignature::fromOpType(OperandType::kReg) | \ + OperandSignature::fromRegType(kType) | \ + OperandSignature::fromRegGroup(kGroup) | \ + OperandSignature::fromSize(kSize)).bits(); \ + \ +} + +//! Adds constructors and member functions to a class that implements abstract register. Abstract register is register +//! that doesn't have type or signature yet, it's a base class like `x86::Reg` or `arm::Reg`. +#define ASMJIT_DEFINE_ABSTRACT_REG(REG, BASE) \ +public: \ + /*! Default constructor that only setups basics. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG() noexcept \ + : BASE(Signature{kSignature}, kIdBad) {} \ + \ + /*! Makes a copy of the `other` register operand. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG(const REG& other) noexcept \ + : BASE(other) {} \ + \ + /*! Makes a copy of the `other` register having id set to `id` */ \ + ASMJIT_INLINE_NODEBUG constexpr REG(const BaseReg& other, uint32_t id) noexcept \ + : BASE(other, id) {} \ + \ + /*! Creates a register based on `signature` and `id`. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG(const OperandSignature& sgn, uint32_t id) noexcept \ + : BASE(sgn, id) {} \ + \ + /*! Creates a completely uninitialized REG register operand (garbage). */ \ + ASMJIT_INLINE_NODEBUG explicit REG(Globals::NoInit_) noexcept \ + : BASE(Globals::NoInit) {} \ + \ + /*! Creates a new register from register type and id. */ \ + static ASMJIT_INLINE_NODEBUG REG fromTypeAndId(RegType type, uint32_t id) noexcept { \ + return REG(signatureOf(type), id); \ + } \ + \ + /*! Clones the register operand. */ \ + ASMJIT_INLINE_NODEBUG constexpr REG clone() const noexcept { return REG(*this); } \ + \ + ASMJIT_INLINE_NODEBUG REG& operator=(const REG& other) noexcept = default; + +//! Adds constructors and member functions to a class that implements final register. Final registers MUST HAVE a valid +//! signature. +#define ASMJIT_DEFINE_FINAL_REG(REG, BASE, TRAITS) \ +public: \ + static constexpr RegType kThisType = TRAITS::kType; \ + static constexpr RegGroup kThisGroup = TRAITS::kGroup; \ + static constexpr uint32_t kThisSize = TRAITS::kSize; \ + static constexpr uint32_t kSignature = TRAITS::kSignature; \ + \ + ASMJIT_DEFINE_ABSTRACT_REG(REG, BASE) \ + \ + /*! Creates a register operand having its id set to `id`. */ \ + ASMJIT_INLINE_NODEBUG constexpr explicit REG(uint32_t id) noexcept \ + : BASE(Signature{kSignature}, id) {} +//! \endcond + +//! List of physical registers (base). +//! +//! \note List of registers is only used by some ARM instructions at the moment. +class BaseRegList : public Operand { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + kSignature = Signature::fromOpType(OperandType::kRegList).bits() + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a dummy register operand. + ASMJIT_INLINE_NODEBUG constexpr BaseRegList() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kRegList), 0, 0, 0) {} + + //! Creates a new register operand which is the same as `other` . + ASMJIT_INLINE_NODEBUG constexpr BaseRegList(const BaseRegList& other) noexcept + : Operand(other) {} + + //! Creates a new register operand compatible with `other`, but with a different `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseRegList(const BaseRegList& other, RegMask regMask) noexcept + : Operand(Globals::Init, other._signature, regMask, 0, 0) {} + + //! Creates a register initialized to the given `signature` and `id`. + ASMJIT_INLINE_NODEBUG constexpr BaseRegList(const Signature& signature, RegMask regMask) noexcept + : Operand(Globals::Init, signature, regMask, 0, 0) {} + + ASMJIT_INLINE_NODEBUG explicit BaseRegList(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG BaseRegList& operator=(const BaseRegList& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the register-list is valid, which means it has a type and at least a single register in the list. + ASMJIT_INLINE_NODEBUG constexpr bool isValid() const noexcept { return bool(unsigned(_signature != 0u) & unsigned(_baseId != 0u)); } + + //! Tests whether the register type matches `type` - same as `isReg(type)`, provided for convenience. + ASMJIT_INLINE_NODEBUG constexpr bool isType(RegType type) const noexcept { return _signature.subset(Signature::kRegTypeMask) == Signature::fromRegType(type); } + //! Tests whether the register group matches `group`. + ASMJIT_INLINE_NODEBUG constexpr bool isGroup(RegGroup group) const noexcept { return _signature.subset(Signature::kRegGroupMask) == Signature::fromRegGroup(group); } + + //! Tests whether the register is a general purpose register (any size). + ASMJIT_INLINE_NODEBUG constexpr bool isGp() const noexcept { return isGroup(RegGroup::kGp); } + //! Tests whether the register is a vector register. + ASMJIT_INLINE_NODEBUG constexpr bool isVec() const noexcept { return isGroup(RegGroup::kVec); } + + //! Returns the register type. + ASMJIT_INLINE_NODEBUG constexpr RegType type() const noexcept { return _signature.regType(); } + //! Returns the register group. + ASMJIT_INLINE_NODEBUG constexpr RegGroup group() const noexcept { return _signature.regGroup(); } + //! Returns the size of a single register in this register-list or 0 if unspecified. + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } + + //! Returns the register list as a mask, where each bit represents one physical register. + ASMJIT_INLINE_NODEBUG constexpr RegMask list() const noexcept { return _baseId; } + //! Sets the register list to `mask`. + ASMJIT_INLINE_NODEBUG void setList(RegMask mask) noexcept { _baseId = mask; } + //! Remoes all registers from the register-list by making the underlying register-mask zero. + ASMJIT_INLINE_NODEBUG void resetList() noexcept { _baseId = 0; } + + //! Adds registers passed by a register `mask` to the register-list. + ASMJIT_INLINE_NODEBUG void addList(RegMask mask) noexcept { _baseId |= mask; } + //! Removes registers passed by a register `mask` to the register-list. + ASMJIT_INLINE_NODEBUG void clearList(RegMask mask) noexcept { _baseId &= ~mask; } + //! Uses AND operator to combine the current register-list with other register `mask`. + ASMJIT_INLINE_NODEBUG void andList(RegMask mask) noexcept { _baseId &= mask; } + //! Uses XOR operator to combine the current register-list with other register `mask`. + ASMJIT_INLINE_NODEBUG void xorList(RegMask mask) noexcept { _baseId ^= mask; } + + //! Checks whether a physical register `physId` is in the register-list. + ASMJIT_INLINE_NODEBUG bool hasReg(uint32_t physId) const noexcept { return physId < 32u ? (_baseId & (1u << physId)) != 0 : false; } + //! Adds a physical register `physId` to the register-list. + ASMJIT_INLINE_NODEBUG void addReg(uint32_t physId) noexcept { addList(1u << physId); } + //! Removes a physical register `physId` from the register-list. + ASMJIT_INLINE_NODEBUG void clearReg(uint32_t physId) noexcept { clearList(1u << physId); } + + //! Clones the register-list operand. + ASMJIT_INLINE_NODEBUG constexpr BaseRegList clone() const noexcept { return BaseRegList(*this); } + + //! Casts this register to `RegT` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegListT cloneAs() const noexcept { return RegListT(Signature(RegListT::kSignature), list()); } + + //! Casts this register to `other` by also changing its signature. + //! + //! \note Improper use of `cloneAs()` can lead to hard-to-debug errors. + template + ASMJIT_INLINE_NODEBUG constexpr RegListT cloneAs(const RegListT& other) const noexcept { return RegListT(other.signature(), list()); } + + //! \} +}; + +template +class RegListT : public BaseRegList { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates a dummy register operand. + ASMJIT_INLINE_NODEBUG constexpr RegListT() noexcept + : BaseRegList() {} + + //! Creates a new register operand which is the same as `other` . + ASMJIT_INLINE_NODEBUG constexpr RegListT(const RegListT& other) noexcept + : BaseRegList(other) {} + + //! Creates a new register operand compatible with `other`, but with a different `id`. + ASMJIT_INLINE_NODEBUG constexpr RegListT(const RegListT& other, RegMask regMask) noexcept + : BaseRegList(other, regMask) {} + + //! Creates a register initialized to the given `signature` and `id`. + ASMJIT_INLINE_NODEBUG constexpr RegListT(const Signature& signature, RegMask regMask) noexcept + : BaseRegList(signature, regMask) {} + + //! Creates a register initialized to the given `signature` and `regs`. + ASMJIT_INLINE_NODEBUG RegListT(const Signature& signature, std::initializer_list regs) noexcept + : BaseRegList(signature, RegMask(0)) { addRegs(regs); } + + ASMJIT_INLINE_NODEBUG explicit RegListT(Globals::NoInit_) noexcept + : BaseRegList(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG RegListT& operator=(const RegListT& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + using BaseRegList::addList; + using BaseRegList::clearList; + using BaseRegList::andList; + using BaseRegList::xorList; + + //! Adds registers to this register-list as provided by `other` register-list. + ASMJIT_INLINE_NODEBUG void addList(const RegListT& other) noexcept { addList(other.list()); } + //! Removes registers contained in `other` register-list. + ASMJIT_INLINE_NODEBUG void clearList(const RegListT& other) noexcept { clearList(other.list()); } + //! Uses AND operator to combine the current register-list with `other` register-list. + ASMJIT_INLINE_NODEBUG void andList(const RegListT& other) noexcept { andList(other.list()); } + //! Uses XOR operator to combine the current register-list with `other` register-list. + ASMJIT_INLINE_NODEBUG void xorList(const RegListT& other) noexcept { xorList(other.list()); } + + using BaseRegList::addReg; + using BaseRegList::clearReg; + + ASMJIT_INLINE_NODEBUG void addReg(const RegT& reg) noexcept { + if (reg.id() < 32u) + addReg(reg.id()); + } + + ASMJIT_INLINE_NODEBUG void addRegs(std::initializer_list regs) noexcept { + for (const RegT& reg : regs) + addReg(reg); + } + + ASMJIT_INLINE_NODEBUG void clearReg(const RegT& reg) noexcept { + if (reg.id() < 32u) + clearReg(reg.id()); + } + + ASMJIT_INLINE_NODEBUG void clearRegs(std::initializer_list regs) noexcept { + for (const RegT& reg : regs) + clearReg(reg); + } + + //! \} +}; + +//! Base class for all memory operands. +//! +//! The data is split into the following parts: +//! +//! - BASE - Base register or label - requires 36 bits total. 4 bits are used to encode the type of the BASE operand +//! (label vs. register type) and the remaining 32 bits define the BASE id, which can be a physical or virtual +//! register index. If BASE type is zero, which is never used as a register type and label doesn't use it as well +//! then BASE field contains a high DWORD of a possible 64-bit absolute address, which is possible on X64. +//! +//! - INDEX - Index register (or theoretically Label, which doesn't make sense). Encoding is similar to BASE - it +//! also requires 36 bits and splits the encoding to INDEX type (4 bits defining the register type) and 32-bit id. +//! +//! - OFFSET - A relative offset of the address. Basically if BASE is specified the relative displacement adjusts +//! BASE and an optional INDEX. if BASE is not specified then the OFFSET should be considered as ABSOLUTE address +//! (at least on X86). In that case its low 32 bits are stored in DISPLACEMENT field and the remaining high 32 +//! bits are stored in BASE. +//! +//! - OTHER - There is rest 8 bits that can be used for whatever purpose. For example \ref x86::Mem operand uses +//! these bits to store segment override prefix and index shift (or scale). +class BaseMem : public Operand { +public: + //! \name Construction & Destruction + //! \{ + + //! Creates a default `BaseMem` operand, that points to [0]. + ASMJIT_INLINE_NODEBUG constexpr BaseMem() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kMem), 0, 0, 0) {} + + //! Creates a `BaseMem` operand that is a clone of `other`. + ASMJIT_INLINE_NODEBUG constexpr BaseMem(const BaseMem& other) noexcept + : Operand(other) {} + + //! Creates a `BaseMem` operand from `baseReg` and `offset`. + //! + //! \note This is an architecture independent constructor that can be used to create an architecture + //! independent memory operand to be used in portable code that can handle multiple architectures. + ASMJIT_INLINE_NODEBUG constexpr explicit BaseMem(const BaseReg& baseReg, int32_t offset = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kMem) | Signature::fromMemBaseType(baseReg.type()), + baseReg.id(), + 0, + uint32_t(offset)) {} + + //! \cond INTERNAL + //! Creates a `BaseMem` operand from 4 integers as used by `Operand_` struct. + ASMJIT_INLINE_NODEBUG constexpr BaseMem(const OperandSignature& u0, uint32_t baseId, uint32_t indexId, int32_t offset) noexcept + : Operand(Globals::Init, u0, baseId, indexId, uint32_t(offset)) {} + //! \endcond + + //! Creates a completely uninitialized `BaseMem` operand. + ASMJIT_INLINE_NODEBUG explicit BaseMem(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! Resets the memory operand - after the reset the memory points to [0]. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _signature = Signature::fromOpType(OperandType::kMem); + _baseId = 0; + _data[0] = 0; + _data[1] = 0; + } + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG BaseMem& operator=(const BaseMem& other) noexcept { copyFrom(other); return *this; } + + //! \} + + //! \name Accessors + //! \{ + + //! Clones the memory operand. + ASMJIT_INLINE_NODEBUG constexpr BaseMem clone() const noexcept { return BaseMem(*this); } + + //! Creates a new copy of this memory operand adjusted by `off`. + ASMJIT_INLINE_NODEBUG BaseMem cloneAdjusted(int64_t off) const noexcept { + BaseMem result(*this); + result.addOffset(off); + return result; + } + + //! Tests whether this memory operand is a register home (only used by \ref asmjit_compiler) + ASMJIT_INLINE_NODEBUG constexpr bool isRegHome() const noexcept { return _signature.hasField(); } + //! Mark this memory operand as register home (only used by \ref asmjit_compiler). + ASMJIT_INLINE_NODEBUG void setRegHome() noexcept { _signature |= Signature::kMemRegHomeFlag; } + //! Marks this operand to not be a register home (only used by \ref asmjit_compiler). + ASMJIT_INLINE_NODEBUG void clearRegHome() noexcept { _signature &= ~Signature::kMemRegHomeFlag; } + + //! Tests whether the memory operand has a BASE register or label specified. + ASMJIT_INLINE_NODEBUG constexpr bool hasBase() const noexcept { + return (_signature & Signature::kMemBaseTypeMask) != 0; + } + + //! Tests whether the memory operand has an INDEX register specified. + ASMJIT_INLINE_NODEBUG constexpr bool hasIndex() const noexcept { + return (_signature & Signature::kMemIndexTypeMask) != 0; + } + + //! Tests whether the memory operand has BASE or INDEX register. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseOrIndex() const noexcept { + return (_signature & Signature::kMemBaseIndexMask) != 0; + } + + //! Tests whether the memory operand has BASE and INDEX register. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseAndIndex() const noexcept { + return (_signature & Signature::kMemBaseTypeMask) != 0 && (_signature & Signature::kMemIndexTypeMask) != 0; + } + + //! Tests whether the BASE operand is a label. + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseLabel() const noexcept { + return _signature.subset(Signature::kMemBaseTypeMask) == Signature::fromMemBaseType(RegType::kLabelTag); + } + + //! Tests whether the BASE operand is a register (registers start after `RegType::kLabelTag`). + ASMJIT_INLINE_NODEBUG constexpr bool hasBaseReg() const noexcept { + return _signature.subset(Signature::kMemBaseTypeMask).bits() > Signature::fromMemBaseType(RegType::kLabelTag).bits(); + } + + //! Tests whether the INDEX operand is a register (registers start after `RegType::kLabelTag`). + ASMJIT_INLINE_NODEBUG constexpr bool hasIndexReg() const noexcept { + return _signature.subset(Signature::kMemIndexTypeMask).bits() > Signature::fromMemIndexType(RegType::kLabelTag).bits(); + } + + //! Returns the type of the BASE register (0 if this memory operand doesn't use the BASE register). + //! + //! \note If the returned type is one (a value never associated to a register type) the BASE is not register, but it + //! is a label. One equals to `kLabelTag`. You should always check `hasBaseLabel()` before using `baseId()` result. + ASMJIT_INLINE_NODEBUG constexpr RegType baseType() const noexcept { return _signature.memBaseType(); } + + //! Returns the type of an INDEX register (0 if this memory operand doesn't + //! use the INDEX register). + ASMJIT_INLINE_NODEBUG constexpr RegType indexType() const noexcept { return _signature.memIndexType(); } + + //! This is used internally for BASE+INDEX validation. + ASMJIT_INLINE_NODEBUG constexpr uint32_t baseAndIndexTypes() const noexcept { return _signature.getField(); } + + //! Returns both BASE (4:0 bits) and INDEX (9:5 bits) types combined into a single value. + //! + //! \remarks Returns id of the BASE register or label (if the BASE was specified as label). + ASMJIT_INLINE_NODEBUG constexpr uint32_t baseId() const noexcept { return _baseId; } + + //! Returns the id of the INDEX register. + ASMJIT_INLINE_NODEBUG constexpr uint32_t indexId() const noexcept { return _data[kDataMemIndexId]; } + + //! Sets the id of the BASE register (without modifying its type). + ASMJIT_INLINE_NODEBUG void setBaseId(uint32_t id) noexcept { _baseId = id; } + //! Sets the register type of the BASE register (without modifying its id). + ASMJIT_INLINE_NODEBUG void setBaseType(RegType regType) noexcept { _signature.setMemBaseType(regType); } + + //! Sets the id of the INDEX register (without modifying its type). + ASMJIT_INLINE_NODEBUG void setIndexId(uint32_t id) noexcept { _data[kDataMemIndexId] = id; } + //! Sets the register type of the INDEX register (without modifying its id). + ASMJIT_INLINE_NODEBUG void setIndexType(RegType regType) noexcept { _signature.setMemIndexType(regType); } + + //! Sets the base register to type and id of the given `base` operand. + ASMJIT_INLINE_NODEBUG void setBase(const BaseReg& base) noexcept { return _setBase(base.type(), base.id()); } + //! Sets the index register to type and id of the given `index` operand. + ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index) noexcept { return _setIndex(index.type(), index.id()); } + + //! \cond INTERNAL + ASMJIT_INLINE_NODEBUG void _setBase(RegType type, uint32_t id) noexcept { + _signature.setField(uint32_t(type)); + _baseId = id; + } + + ASMJIT_INLINE_NODEBUG void _setIndex(RegType type, uint32_t id) noexcept { + _signature.setField(uint32_t(type)); + _data[kDataMemIndexId] = id; + } + //! \endcond + + //! Resets the memory operand's BASE register or label. + ASMJIT_INLINE_NODEBUG void resetBase() noexcept { _setBase(RegType::kNone, 0); } + //! Resets the memory operand's INDEX register. + ASMJIT_INLINE_NODEBUG void resetIndex() noexcept { _setIndex(RegType::kNone, 0); } + + //! Sets the memory operand size (in bytes). + ASMJIT_INLINE_NODEBUG void setSize(uint32_t size) noexcept { _signature.setField(size); } + + //! Tests whether the memory operand has a 64-bit offset or absolute address. + //! + //! If this is true then `hasBase()` must always report false. + ASMJIT_INLINE_NODEBUG constexpr bool isOffset64Bit() const noexcept { return baseType() == RegType::kNone; } + + //! Tests whether the memory operand has a non-zero offset or absolute address. + ASMJIT_INLINE_NODEBUG constexpr bool hasOffset() const noexcept { + return (_data[kDataMemOffsetLo] | uint32_t(_baseId & Support::bitMaskFromBool(isOffset64Bit()))) != 0; + } + + //! Returns either relative offset or absolute address as 64-bit integer. + ASMJIT_INLINE_NODEBUG constexpr int64_t offset() const noexcept { + return isOffset64Bit() ? int64_t(uint64_t(_data[kDataMemOffsetLo]) | (uint64_t(_baseId) << 32)) + : int64_t(int32_t(_data[kDataMemOffsetLo])); // Sign extend 32-bit offset. + } + + //! Returns a 32-bit low part of a 64-bit offset or absolute address. + ASMJIT_INLINE_NODEBUG constexpr int32_t offsetLo32() const noexcept { return int32_t(_data[kDataMemOffsetLo]); } + //! Returns a 32-but high part of a 64-bit offset or absolute address. + //! + //! \note This function is UNSAFE and returns garbage if `isOffset64Bit()` + //! returns false. Never use it blindly without checking it first. + ASMJIT_INLINE_NODEBUG constexpr int32_t offsetHi32() const noexcept { return int32_t(_baseId); } + + //! Sets a 64-bit offset or an absolute address to `offset`. + //! + //! \note This functions attempts to set both high and low parts of a 64-bit offset, however, if the operand has + //! a BASE register it will store only the low 32 bits of the offset / address as there is no way to store both + //! BASE and 64-bit offset, and there is currently no architecture that has such capability targeted by AsmJit. + inline void setOffset(int64_t offset) noexcept { + uint32_t lo = uint32_t(uint64_t(offset) & 0xFFFFFFFFu); + uint32_t hi = uint32_t(uint64_t(offset) >> 32); + uint32_t hiMsk = Support::bitMaskFromBool(isOffset64Bit()); + + _data[kDataMemOffsetLo] = lo; + _baseId = (hi & hiMsk) | (_baseId & ~hiMsk); + } + //! Sets a low 32-bit offset to `offset` (don't use without knowing how BaseMem works). + inline void setOffsetLo32(int32_t offset) noexcept { _data[kDataMemOffsetLo] = uint32_t(offset); } + + //! Adjusts the offset by `offset`. + //! + //! \note This is a fast function that doesn't use the HI 32-bits of a 64-bit offset. Use it only if you know that + //! there is a BASE register and the offset is only 32 bits anyway. + + //! Adjusts the memory operand offset by a `offset`. + inline void addOffset(int64_t offset) noexcept { + if (isOffset64Bit()) { + int64_t result = offset + int64_t(uint64_t(_data[kDataMemOffsetLo]) | (uint64_t(_baseId) << 32)); + _data[kDataMemOffsetLo] = uint32_t(uint64_t(result) & 0xFFFFFFFFu); + _baseId = uint32_t(uint64_t(result) >> 32); + } + else { + _data[kDataMemOffsetLo] += uint32_t(uint64_t(offset) & 0xFFFFFFFFu); + } + } + + //! Adds `offset` to a low 32-bit offset part (don't use without knowing how BaseMem works). + ASMJIT_INLINE_NODEBUG void addOffsetLo32(int32_t offset) noexcept { _data[kDataMemOffsetLo] += uint32_t(offset); } + + //! Resets the memory offset to zero. + ASMJIT_INLINE_NODEBUG void resetOffset() noexcept { setOffset(0); } + + //! Resets the lo part of the memory offset to zero (don't use without knowing how BaseMem works). + ASMJIT_INLINE_NODEBUG void resetOffsetLo32() noexcept { setOffsetLo32(0); } + + //! \} +}; + +//! Type of the an immediate value. +enum class ImmType : uint32_t { + //! Immediate is integer. + kInt = 0, + //! Immediate is a floating point stored as double-precision. + kDouble = 1 +}; + +//! Immediate operands are encoded with instruction data. +class Imm : public Operand { +public: + //! \cond INTERNAL + template + struct IsConstexprConstructibleAsImmType + : public std::integral_constant::value || + std::is_pointer::value || + std::is_integral::value || + std::is_function::value> {}; + + template + struct IsConvertibleToImmType + : public std::integral_constant::value || + std::is_floating_point::value> {}; + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Creates a new immediate value (initial value is 0). + ASMJIT_INLINE_NODEBUG constexpr Imm() noexcept + : Operand(Globals::Init, Signature::fromOpType(OperandType::kImm), 0, 0, 0) {} + + //! Creates a new immediate value from `other`. + ASMJIT_INLINE_NODEBUG constexpr Imm(const Imm& other) noexcept + : Operand(other) {} + + //! Creates a new immediate value from ARM/AArch64 specific `shift`. + ASMJIT_INLINE_NODEBUG constexpr Imm(const arm::Shift& shift) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(uint32_t(shift.op())), + 0, + Support::unpackU32At0(shift.value()), + Support::unpackU32At1(shift.value())) {} + + //! Creates a new signed immediate value, assigning the value to `val` and an architecture-specific predicate + //! to `predicate`. + //! + //! \note Predicate is currently only used by ARM architectures. + template::type>::value>::type> + ASMJIT_INLINE_NODEBUG constexpr Imm(const T& val, const uint32_t predicate = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(predicate), + 0, + Support::unpackU32At0(int64_t(val)), + Support::unpackU32At1(int64_t(val))) {} + + ASMJIT_INLINE_NODEBUG Imm(const float& val, const uint32_t predicate = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(predicate), + 0, + 0, + 0) { setValue(val); } + + ASMJIT_INLINE_NODEBUG Imm(const double& val, const uint32_t predicate = 0) noexcept + : Operand(Globals::Init, + Signature::fromOpType(OperandType::kImm) | Signature::fromPredicate(predicate), + 0, + 0, + 0) { setValue(val); } + + ASMJIT_INLINE_NODEBUG explicit Imm(Globals::NoInit_) noexcept + : Operand(Globals::NoInit) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + //! Assigns the value of the `other` operand to this immediate. + ASMJIT_INLINE_NODEBUG Imm& operator=(const Imm& other) noexcept { copyFrom(other); return *this; } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns immediate type. + ASMJIT_INLINE_NODEBUG constexpr ImmType type() const noexcept { return (ImmType)_signature.getField(); } + //! Sets the immediate type to `type`. + ASMJIT_INLINE_NODEBUG void setType(ImmType type) noexcept { _signature.setField(uint32_t(type)); } + //! Resets immediate type to \ref ImmType::kInt. + ASMJIT_INLINE_NODEBUG void resetType() noexcept { setType(ImmType::kInt); } + + //! Returns operation predicate of the immediate. + //! + //! The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the immediate. + ASMJIT_INLINE_NODEBUG constexpr uint32_t predicate() const noexcept { return _signature.getField(); } + + //! Sets operation predicate of the immediate to `predicate`. + //! + //! The meaning depends on architecture, for example on ARM hardware this describes \ref arm::ShiftOp + //! of the immediate. + ASMJIT_INLINE_NODEBUG void setPredicate(uint32_t predicate) noexcept { _signature.setField(predicate); } + + //! Resets the shift operation type of the immediate to the default value (no operation). + ASMJIT_INLINE_NODEBUG void resetPredicate() noexcept { _signature.setField(0); } + + //! Returns the immediate value as `int64_t`, which is the internal format Imm uses. + ASMJIT_INLINE_NODEBUG constexpr int64_t value() const noexcept { + return int64_t((uint64_t(_data[kDataImmValueHi]) << 32) | _data[kDataImmValueLo]); + } + + //! Tests whether this immediate value is integer of any size. + ASMJIT_INLINE_NODEBUG constexpr uint32_t isInt() const noexcept { return type() == ImmType::kInt; } + //! Tests whether this immediate value is a double precision floating point value. + ASMJIT_INLINE_NODEBUG constexpr uint32_t isDouble() const noexcept { return type() == ImmType::kDouble; } + + //! Tests whether the immediate can be casted to 8-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr bool isInt8() const noexcept { return type() == ImmType::kInt && Support::isInt8(value()); } + //! Tests whether the immediate can be casted to 8-bit unsigned integer. + ASMJIT_INLINE_NODEBUG constexpr bool isUInt8() const noexcept { return type() == ImmType::kInt && Support::isUInt8(value()); } + //! Tests whether the immediate can be casted to 16-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr bool isInt16() const noexcept { return type() == ImmType::kInt && Support::isInt16(value()); } + //! Tests whether the immediate can be casted to 16-bit unsigned integer. + ASMJIT_INLINE_NODEBUG constexpr bool isUInt16() const noexcept { return type() == ImmType::kInt && Support::isUInt16(value()); } + //! Tests whether the immediate can be casted to 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr bool isInt32() const noexcept { return type() == ImmType::kInt && Support::isInt32(value()); } + //! Tests whether the immediate can be casted to 32-bit unsigned integer. + ASMJIT_INLINE_NODEBUG constexpr bool isUInt32() const noexcept { return type() == ImmType::kInt && _data[kDataImmValueHi] == 0; } + + //! Returns the immediate value casted to `T`. + //! + //! The value is masked before it's casted to `T` so the returned value is simply the representation of `T` + //! considering the original value's lowest bits. + template + ASMJIT_INLINE_NODEBUG T valueAs() const noexcept { return Support::immediateToT(value()); } + + //! Returns low 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr int32_t int32Lo() const noexcept { return int32_t(_data[kDataImmValueLo]); } + //! Returns high 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr int32_t int32Hi() const noexcept { return int32_t(_data[kDataImmValueHi]); } + //! Returns low 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr uint32_t uint32Lo() const noexcept { return _data[kDataImmValueLo]; } + //! Returns high 32-bit signed integer. + ASMJIT_INLINE_NODEBUG constexpr uint32_t uint32Hi() const noexcept { return _data[kDataImmValueHi]; } + + //! Sets immediate value to `val`, the value is casted to a signed 64-bit integer. + template + ASMJIT_INLINE_NODEBUG void setValue(const T& val) noexcept { + _setValueInternal(Support::immediateFromT(val), std::is_floating_point::value ? ImmType::kDouble : ImmType::kInt); + } + + ASMJIT_INLINE_NODEBUG void _setValueInternal(int64_t val, ImmType type) noexcept { + setType(type); + _data[kDataImmValueHi] = uint32_t(uint64_t(val) >> 32); + _data[kDataImmValueLo] = uint32_t(uint64_t(val) & 0xFFFFFFFFu); + } + + //! \} + + //! \name Utilities + //! \{ + + //! Clones the immediate operand. + ASMJIT_INLINE_NODEBUG constexpr Imm clone() const noexcept { return Imm(*this); } + + ASMJIT_INLINE_NODEBUG void signExtend8Bits() noexcept { setValue(int64_t(valueAs())); } + ASMJIT_INLINE_NODEBUG void signExtend16Bits() noexcept { setValue(int64_t(valueAs())); } + ASMJIT_INLINE_NODEBUG void signExtend32Bits() noexcept { setValue(int64_t(valueAs())); } + + ASMJIT_INLINE_NODEBUG void zeroExtend8Bits() noexcept { setValue(valueAs()); } + ASMJIT_INLINE_NODEBUG void zeroExtend16Bits() noexcept { setValue(valueAs()); } + ASMJIT_INLINE_NODEBUG void zeroExtend32Bits() noexcept { _data[kDataImmValueHi] = 0u; } + + //! \} +}; + +//! Creates a new immediate operand. +template +static ASMJIT_INLINE_NODEBUG constexpr Imm imm(const T& val) noexcept { return Imm(val); } + +//! \} + +namespace Globals { + //! \ingroup asmjit_assembler + //! + //! A default-constructed operand of `Operand_::kOpNone` type. + static constexpr const Operand none; +} + +//! \cond INTERNAL +namespace Support { + +template +struct ForwardOpImpl { + static ASMJIT_INLINE_NODEBUG const T& forward(const T& value) noexcept { return value; } +}; + +template +struct ForwardOpImpl { + static ASMJIT_INLINE_NODEBUG Imm forward(const T& value) noexcept { return Imm(value); } +}; + +//! Either forwards operand T or returns a new operand that wraps it if T is a type convertible to operand. +//! At the moment this is only used to convert integers, floats, and enumarations to \ref Imm operands. +template +struct ForwardOp : public ForwardOpImpl::type>::value> {}; + +} // {Support} +//! \endcond + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_OPERAND_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/osutils.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/osutils.h new file mode 100644 index 0000000000000000000000000000000000000000..44bdfb652afa6fc4dcd20f8d2fc527b9dda759ba --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/osutils.h @@ -0,0 +1,54 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_OSUTILS_H_INCLUDED +#define ASMJIT_CORE_OSUTILS_H_INCLUDED + +#include "../core/globals.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! \cond INTERNAL +//! Lock. +//! +//! Lock is internal, it cannot be used outside of AsmJit, however, its internal +//! layout is exposed as it's used by some other classes, which are public. +class Lock { +public: + ASMJIT_NONCOPYABLE(Lock) + +#if defined(_WIN32) +#pragma pack(push, 8) + struct ASMJIT_MAY_ALIAS Handle { + void* DebugInfo; + long LockCount; + long RecursionCount; + void* OwningThread; + void* LockSemaphore; + unsigned long* SpinCount; + }; + Handle _handle; +#pragma pack(pop) +#elif !defined(__EMSCRIPTEN__) + typedef pthread_mutex_t Handle; + Handle _handle; +#endif + + ASMJIT_INLINE_NODEBUG Lock() noexcept; + ASMJIT_INLINE_NODEBUG ~Lock() noexcept; + + ASMJIT_INLINE_NODEBUG void lock() noexcept; + ASMJIT_INLINE_NODEBUG void unlock() noexcept; +}; +//! \endcond + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_OSUTILS_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/string.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/string.h new file mode 100644 index 0000000000000000000000000000000000000000..1f090e05719f8c3533ee633910a2a3e07daf081c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/string.h @@ -0,0 +1,383 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_STRING_H_INCLUDED +#define ASMJIT_CORE_STRING_H_INCLUDED + +#include "../core/support.h" +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! Format flags used by \ref String API. +enum class StringFormatFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + //! Show sign. + kShowSign = 0x00000001u, + //! Show space. + kShowSpace = 0x00000002u, + //! Alternate form (use 0x when formatting HEX number). + kAlternate = 0x00000004u, + //! The input is signed. + kSigned = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(StringFormatFlags) + +//! Fixed string - only useful for strings that would never exceed `N - 1` characters; always null-terminated. +template +union FixedString { + //! \name Constants + //! \{ + + // This cannot be constexpr as GCC 4.8 refuses constexpr members of unions. + enum : uint32_t { + kNumUInt32Words = uint32_t((N + sizeof(uint32_t) - 1) / sizeof(uint32_t)) + }; + + //! \} + + //! \name Members + //! \{ + + char str[kNumUInt32Words * sizeof(uint32_t)]; + uint32_t u32[kNumUInt32Words]; + + //! \} + + //! \name Utilities + //! \{ + + inline bool equals(const char* other) const noexcept { return strcmp(str, other) == 0; } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use FixedString::equals() instead") + inline bool eq(const char* other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} +}; + +//! A simple non-reference counted string that uses small string optimization (SSO). +//! +//! This string has 3 allocation possibilities: +//! +//! 1. Small - embedded buffer is used for up to `kSSOCapacity` characters. This should handle most small +//! strings and thus avoid dynamic memory allocation for most use-cases. +//! +//! 2. Large - string that doesn't fit into an embedded buffer (or string that was truncated from a larger +//! buffer) and is owned by AsmJit. When you destroy the string AsmJit would automatically +//! release the large buffer. +//! +//! 3. External - like Large (2), however, the large buffer is not owned by AsmJit and won't be released when +//! the string is destroyed or reallocated. This is mostly useful for working with larger temporary +//! strings allocated on stack or with immutable strings. +class String { +public: + ASMJIT_NONCOPYABLE(String) + + //! String operation. + enum class ModifyOp : uint32_t { + //! Assignment - a new content replaces the current one. + kAssign = 0, + //! Append - a new content is appended to the string. + kAppend = 1 + }; + + //! \cond INTERNAL + enum : uint32_t { + kLayoutSize = 32, + kSSOCapacity = kLayoutSize - 2 + }; + + //! String type. + enum Type : uint8_t { + //! Large string (owned by String). + kTypeLarge = 0x1Fu, + //! External string (zone allocated or not owned by String). + kTypeExternal = 0x20u + }; + + union Raw { + uint8_t u8[kLayoutSize]; + uint64_t u64[kLayoutSize / sizeof(uint64_t)]; + uintptr_t uptr[kLayoutSize / sizeof(uintptr_t)]; + }; + + struct Small { + uint8_t type; + char data[kSSOCapacity + 1u]; + }; + + struct Large { + uint8_t type; + uint8_t reserved[sizeof(uintptr_t) - 1]; + size_t size; + size_t capacity; + char* data; + }; + + union { + uint8_t _type; + Raw _raw; + Small _small; + Large _large; + }; + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Creates a default-initialized string if zero length. + ASMJIT_INLINE_NODEBUG String() noexcept + : _small {} {} + + //! Creates a string that takes ownership of the content of the `other` string. + ASMJIT_INLINE_NODEBUG String(String&& other) noexcept { + _raw = other._raw; + other._resetInternal(); + } + + ASMJIT_INLINE_NODEBUG ~String() noexcept { + reset(); + } + + //! Reset the string into a construction state. + ASMJIT_API Error reset() noexcept; + + //! \} + + //! \name Overloaded Operators + //! \{ + + inline String& operator=(String&& other) noexcept { + swap(other); + other.reset(); + return *this; + } + + ASMJIT_INLINE_NODEBUG bool operator==(const char* other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const char* other) const noexcept { return !equals(other); } + + ASMJIT_INLINE_NODEBUG bool operator==(const String& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const String& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool isExternal() const noexcept { return _type == kTypeExternal; } + ASMJIT_INLINE_NODEBUG bool isLargeOrExternal() const noexcept { return _type >= kTypeLarge; } + + //! Tests whether the string is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return size() == 0; } + //! Returns the size of the string. + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return isLargeOrExternal() ? size_t(_large.size) : size_t(_type); } + //! Returns the capacity of the string. + ASMJIT_INLINE_NODEBUG size_t capacity() const noexcept { return isLargeOrExternal() ? _large.capacity : size_t(kSSOCapacity); } + + //! Returns the data of the string. + ASMJIT_INLINE_NODEBUG char* data() noexcept { return isLargeOrExternal() ? _large.data : _small.data; } + //! \overload + ASMJIT_INLINE_NODEBUG const char* data() const noexcept { return isLargeOrExternal() ? _large.data : _small.data; } + + ASMJIT_INLINE_NODEBUG char* start() noexcept { return data(); } + ASMJIT_INLINE_NODEBUG const char* start() const noexcept { return data(); } + + ASMJIT_INLINE_NODEBUG char* end() noexcept { return data() + size(); } + ASMJIT_INLINE_NODEBUG const char* end() const noexcept { return data() + size(); } + + //! \} + + //! \name String Operations + //! \{ + + //! Swaps the content of this string with `other`. + ASMJIT_INLINE_NODEBUG void swap(String& other) noexcept { + std::swap(_raw, other._raw); + } + + //! Clears the content of the string. + ASMJIT_API Error clear() noexcept; + + ASMJIT_API char* prepare(ModifyOp op, size_t size) noexcept; + + ASMJIT_API Error _opString(ModifyOp op, const char* str, size_t size = SIZE_MAX) noexcept; + ASMJIT_API Error _opChar(ModifyOp op, char c) noexcept; + ASMJIT_API Error _opChars(ModifyOp op, char c, size_t n) noexcept; + ASMJIT_API Error _opNumber(ModifyOp op, uint64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept; + ASMJIT_API Error _opHex(ModifyOp op, const void* data, size_t size, char separator = '\0') noexcept; + ASMJIT_API Error _opFormat(ModifyOp op, const char* fmt, ...) noexcept; + ASMJIT_API Error _opVFormat(ModifyOp op, const char* fmt, va_list ap) noexcept; + + //! Replaces the current of the string with `data` of the given `size`. + //! + //! Null terminated strings can set `size` to `SIZE_MAX`. + ASMJIT_API Error assign(const char* data, size_t size = SIZE_MAX) noexcept; + + //! Replaces the current of the string with `other` string. + ASMJIT_INLINE_NODEBUG Error assign(const String& other) noexcept { + return assign(other.data(), other.size()); + } + + //! Replaces the current of the string by a single `c` character. + ASMJIT_INLINE_NODEBUG Error assign(char c) noexcept { + return _opChar(ModifyOp::kAssign, c); + } + + //! Replaces the current of the string by a `c` character, repeated `n` times. + ASMJIT_INLINE_NODEBUG Error assignChars(char c, size_t n) noexcept { + return _opChars(ModifyOp::kAssign, c, n); + } + + //! Replaces the current of the string by a formatted integer `i` (signed). + ASMJIT_INLINE_NODEBUG Error assignInt(int64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAssign, uint64_t(i), base, width, flags | StringFormatFlags::kSigned); + } + + //! Replaces the current of the string by a formatted integer `i` (unsigned). + ASMJIT_INLINE_NODEBUG Error assignUInt(uint64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAssign, i, base, width, flags); + } + + //! Replaces the current of the string by the given `data` converted to a HEX string. + ASMJIT_INLINE_NODEBUG Error assignHex(const void* data, size_t size, char separator = '\0') noexcept { + return _opHex(ModifyOp::kAssign, data, size, separator); + } + + //! Replaces the current of the string by a formatted string `fmt`. + template + ASMJIT_INLINE_NODEBUG Error assignFormat(const char* fmt, Args&&... args) noexcept { + return _opFormat(ModifyOp::kAssign, fmt, std::forward(args)...); + } + + //! Replaces the current of the string by a formatted string `fmt` (va_list version). + ASMJIT_INLINE_NODEBUG Error assignVFormat(const char* fmt, va_list ap) noexcept { + return _opVFormat(ModifyOp::kAssign, fmt, ap); + } + + //! Appends `str` having the given size `size` to the string. + //! + //! Null terminated strings can set `size` to `SIZE_MAX`. + ASMJIT_INLINE_NODEBUG Error append(const char* str, size_t size = SIZE_MAX) noexcept { + return _opString(ModifyOp::kAppend, str, size); + } + + //! Appends `other` string to this string. + ASMJIT_INLINE_NODEBUG Error append(const String& other) noexcept { + return append(other.data(), other.size()); + } + + //! Appends a single `c` character. + ASMJIT_INLINE_NODEBUG Error append(char c) noexcept { + return _opChar(ModifyOp::kAppend, c); + } + + //! Appends `c` character repeated `n` times. + ASMJIT_INLINE_NODEBUG Error appendChars(char c, size_t n) noexcept { + return _opChars(ModifyOp::kAppend, c, n); + } + + //! Appends a formatted integer `i` (signed). + ASMJIT_INLINE_NODEBUG Error appendInt(int64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAppend, uint64_t(i), base, width, flags | StringFormatFlags::kSigned); + } + + //! Appends a formatted integer `i` (unsigned). + ASMJIT_INLINE_NODEBUG Error appendUInt(uint64_t i, uint32_t base = 0, size_t width = 0, StringFormatFlags flags = StringFormatFlags::kNone) noexcept { + return _opNumber(ModifyOp::kAppend, i, base, width, flags); + } + + //! Appends the given `data` converted to a HEX string. + ASMJIT_INLINE_NODEBUG Error appendHex(const void* data, size_t size, char separator = '\0') noexcept { + return _opHex(ModifyOp::kAppend, data, size, separator); + } + + //! Appends a formatted string `fmt` with `args`. + template + ASMJIT_INLINE_NODEBUG Error appendFormat(const char* fmt, Args&&... args) noexcept { + return _opFormat(ModifyOp::kAppend, fmt, std::forward(args)...); + } + + //! Appends a formatted string `fmt` (va_list version). + ASMJIT_INLINE_NODEBUG Error appendVFormat(const char* fmt, va_list ap) noexcept { + return _opVFormat(ModifyOp::kAppend, fmt, ap); + } + + ASMJIT_API Error padEnd(size_t n, char c = ' ') noexcept; + + //! Truncate the string length into `newSize`. + ASMJIT_API Error truncate(size_t newSize) noexcept; + + ASMJIT_API bool equals(const char* other, size_t size = SIZE_MAX) const noexcept; + ASMJIT_INLINE_NODEBUG bool equals(const String& other) const noexcept { return equals(other.data(), other.size()); } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use String::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const char* other, size_t size = SIZE_MAX) const noexcept { return equals(other, size); } + + ASMJIT_DEPRECATED("Use String::equals() instead") + ASMJIT_INLINE_NODEBUG bool eq(const String& other) const noexcept { return equals(other.data(), other.size()); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} + + //! \name Internal Functions + //! \{ + + //! Resets string to embedded and makes it empty (zero length, zero first char) + //! + //! \note This is always called internally after an external buffer was released as it zeroes all bytes + //! used by String's embedded storage. + inline void _resetInternal() noexcept { + for (size_t i = 0; i < ASMJIT_ARRAY_SIZE(_raw.uptr); i++) + _raw.uptr[i] = 0; + } + + inline void _setSize(size_t newSize) noexcept { + if (isLargeOrExternal()) + _large.size = newSize; + else + _small.type = uint8_t(newSize); + } + + //! \} +}; + +//! Temporary string builder, has statically allocated `N` bytes. +template +class StringTmp : public String { +public: + ASMJIT_NONCOPYABLE(StringTmp) + + //! Embedded data. + char _embeddedData[Support::alignUp(N + 1, sizeof(size_t))]; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG StringTmp() noexcept { + _resetToTemporary(); + } + + inline void _resetToTemporary() noexcept { + _large.type = kTypeExternal; + _large.capacity = ASMJIT_ARRAY_SIZE(_embeddedData) - 1; + _large.data = _embeddedData; + _embeddedData[0] = '\0'; + } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_STRING_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/support.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/support.h new file mode 100644 index 0000000000000000000000000000000000000000..c6e70fb09a28baeb99c1807742182e79006494cd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/support.h @@ -0,0 +1,1818 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_SUPPORT_H_INCLUDED +#define ASMJIT_CORE_SUPPORT_H_INCLUDED + +#include "../core/globals.h" + +#if defined(_MSC_VER) + #include +#endif + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_utilities +//! \{ + +//! Contains support classes and functions that may be used by AsmJit source and header files. Anything defined +//! here is considered internal and should not be used outside of AsmJit and related projects like AsmTK. +namespace Support { + +// Support - Basic Traits +// ====================== + +#if ASMJIT_ARCH_X86 +typedef uint8_t FastUInt8; +#else +typedef uint32_t FastUInt8; +#endif + +//! \cond INTERNAL +namespace Internal { + template + struct AliasedUInt {}; + + template<> struct AliasedUInt { typedef uint16_t ASMJIT_MAY_ALIAS T; }; + template<> struct AliasedUInt { typedef uint32_t ASMJIT_MAY_ALIAS T; }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS T; }; + + template<> struct AliasedUInt { typedef uint16_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 1); }; + template<> struct AliasedUInt { typedef uint32_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 1); }; + template<> struct AliasedUInt { typedef uint32_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 2); }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 1); }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 2); }; + template<> struct AliasedUInt { typedef uint64_t ASMJIT_MAY_ALIAS ASMJIT_ALIGN_TYPE(T, 4); }; + + // StdInt - Make an int-type by size (signed or unsigned) that is the + // same as types defined by . + // Int32Or64 - Make an int-type that has at least 32 bits: [u]int[32|64]_t. + + template + struct StdInt {}; // Fail if not specialized. + + template<> struct StdInt<1, 0> { typedef int8_t Type; }; + template<> struct StdInt<1, 1> { typedef uint8_t Type; }; + template<> struct StdInt<2, 0> { typedef int16_t Type; }; + template<> struct StdInt<2, 1> { typedef uint16_t Type; }; + template<> struct StdInt<4, 0> { typedef int32_t Type; }; + template<> struct StdInt<4, 1> { typedef uint32_t Type; }; + template<> struct StdInt<8, 0> { typedef int64_t Type; }; + template<> struct StdInt<8, 1> { typedef uint64_t Type; }; + + template::value> + struct Int32Or64 : public StdInt {}; +} +//! \endcond + +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUnsigned() noexcept { return std::is_unsigned::value; } + +//! Casts an integer `x` to either `int32_t` or `int64_t` depending on `T`. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::Int32Or64::Type asInt(const T& x) noexcept { + return (typename Internal::Int32Or64::Type)x; +} + +//! Casts an integer `x` to either `uint32_t` or `uint64_t` depending on `T`. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::Int32Or64::Type asUInt(const T& x) noexcept { + return (typename Internal::Int32Or64::Type)x; +} + +//! Casts an integer `x` to either `int32_t`, uint32_t`, `int64_t`, or `uint64_t` depending on `T`. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::Int32Or64::Type asNormalized(const T& x) noexcept { + return (typename Internal::Int32Or64::Type)x; +} + +//! Casts an integer `x` to the same type as defined by ``. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::StdInt()>::Type asStdInt(const T& x) noexcept { + return (typename Internal::StdInt()>::Type)x; +} + +//! A helper class that can be used to iterate over enum values. +template +struct EnumValues { + typedef typename std::underlying_type::type ValueType; + + struct Iterator { + ValueType value; + + ASMJIT_INLINE_NODEBUG T operator*() const { return (T)value; } + ASMJIT_INLINE_NODEBUG void operator++() { ++value; } + + ASMJIT_INLINE_NODEBUG bool operator==(const Iterator& other) const noexcept { return value == other.value; } + ASMJIT_INLINE_NODEBUG bool operator!=(const Iterator& other) const noexcept { return value != other.value; } + }; + + ASMJIT_INLINE_NODEBUG Iterator begin() const noexcept { return Iterator{ValueType(from)}; } + ASMJIT_INLINE_NODEBUG Iterator end() const noexcept { return Iterator{ValueType(to) + 1}; } +}; + +// Support - BitCast +// ================= + +//! \cond +namespace Internal { + template + union BitCastUnion { + ASMJIT_INLINE_NODEBUG BitCastUnion(SrcT src) noexcept : src(src) {} + SrcT src; + DstT dst; + }; +} +//! \endcond + +//! Bit-casts from `Src` type to `Dst` type. +//! +//! Useful to bit-cast between integers and floating points. +template +static ASMJIT_INLINE_NODEBUG Dst bitCast(const Src& x) noexcept { return Internal::BitCastUnion(x).dst; } + +// Support - BitOps +// ================ + +//! Storage used to store a pack of bits (should by compatible with a machine word). +typedef Internal::StdInt::Type BitWord; + +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bitSizeOf() noexcept { return uint32_t(sizeof(T) * 8u); } + +//! Number of bits stored in a single `BitWord`. +static constexpr uint32_t kBitWordSizeInBits = bitSizeOf(); + +//! Returns `0 - x` in a safe way (no undefined behavior), works for unsigned numbers as well. +template +static ASMJIT_INLINE_NODEBUG constexpr T neg(const T& x) noexcept { + typedef typename std::make_unsigned::type U; + return T(U(0) - U(x)); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr T allOnes() noexcept { return neg(T(1)); } + +//! Returns `x << y` (shift left logical) by explicitly casting `x` to an unsigned type and back. +template +static ASMJIT_INLINE_NODEBUG constexpr X shl(const X& x, const Y& y) noexcept { + typedef typename std::make_unsigned::type U; + return X(U(x) << y); +} + +//! Returns `x >> y` (shift right logical) by explicitly casting `x` to an unsigned type and back. +template +static ASMJIT_INLINE_NODEBUG constexpr X shr(const X& x, const Y& y) noexcept { + typedef typename std::make_unsigned::type U; + return X(U(x) >> y); +} + +//! Returns `x >> y` (shift right arithmetic) by explicitly casting `x` to a signed type and back. +template +static ASMJIT_INLINE_NODEBUG constexpr X sar(const X& x, const Y& y) noexcept { + typedef typename std::make_signed::type S; + return X(S(x) >> y); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr X ror(const X& x, const Y& y) noexcept { + typedef typename std::make_unsigned::type U; + return X((U(x) >> y) | (U(x) << (bitSizeOf() - U(y)))); +} + +//! Returns `x | (x >> y)` - helper used by some bit manipulation helpers. +template +static ASMJIT_INLINE_NODEBUG constexpr X or_shr(const X& x, const Y& y) noexcept { return X(x | shr(x, y)); } + +//! Returns `x & -x` - extracts lowest set isolated bit (like BLSI instruction). +template +static ASMJIT_INLINE_NODEBUG constexpr T blsi(T x) noexcept { + typedef typename std::make_unsigned::type U; + return T(U(x) & neg(U(x))); +} + +//! Tests whether the given value `x` has `n`th bit set. +template +static ASMJIT_INLINE_NODEBUG constexpr bool bitTest(T x, IndexT n) noexcept { + typedef typename std::make_unsigned::type U; + return (U(x) & (U(1) << asStdInt(n))) != 0; +} + +// Tests whether the given `value` is a consecutive mask of bits that starts at +// the least significant bit. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isLsbMask(const T& value) { + typedef typename std::make_unsigned::type U; + return value && ((U(value) + 1u) & U(value)) == 0; +} + +// Tests whether the given value contains at least one bit or whether it's a +// bit-mask of consecutive bits. +// +// This function is similar to \ref isLsbMask(), but the mask doesn't have to +// start at a least significant bit. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isConsecutiveMask(const T& value) { + typedef typename std::make_unsigned::type U; + return value && isLsbMask((U(value) - 1u) | U(value)); +} + +//! Generates a trailing bit-mask that has `n` least significant (trailing) bits set. +template +static ASMJIT_INLINE_NODEBUG constexpr T lsbMask(const CountT& n) noexcept { + typedef typename std::make_unsigned::type U; + return (sizeof(U) < sizeof(uintptr_t)) + // Prevent undefined behavior by using a larger type than T. + ? T(U((uintptr_t(1) << n) - uintptr_t(1))) + // Prevent undefined behavior by checking `n` before shift. + : n ? T(shr(allOnes(), bitSizeOf() - size_t(n))) : T(0); +} + +//! Generates a leading bit-mask that has `n` most significant (leading) bits set. +template +static ASMJIT_INLINE_NODEBUG constexpr T msbMask(const CountT& n) noexcept { + typedef typename std::make_unsigned::type U; + return (sizeof(U) < sizeof(uintptr_t)) + // Prevent undefined behavior by using a larger type than T. + ? T(allOnes() >> (bitSizeOf() - n)) + // Prevent undefined behavior by performing `n & (nBits - 1)` so it's always within the range. + : T(sar(U(n != 0) << (bitSizeOf() - 1), n ? uint32_t(n - 1) : uint32_t(0))); +} + +//! Returns a bit-mask that has `x` bit set. +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bitMask(const Index& x) noexcept { return (1u << asUInt(x)); } + +//! Returns a bit-mask that has `x` bit set (multiple arguments). +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bitMask(const Index& x, Args... args) noexcept { return bitMask(x) | bitMask(args...); } + +//! Converts a boolean value `b` to zero or full mask (all bits set). +template +static ASMJIT_INLINE_NODEBUG constexpr DstT bitMaskFromBool(SrcT b) noexcept { + typedef typename std::make_unsigned::type U; + return DstT(U(0) - U(b)); +} + +//! Tests whether `a & b` is non-zero. +template +static inline constexpr bool test(A a, B b) noexcept { return (asUInt(a) & asUInt(b)) != 0; } + +//! \cond +namespace Internal { + // Fills all trailing bits right from the first most significant bit set. + static ASMJIT_INLINE_NODEBUG constexpr uint8_t fillTrailingBitsImpl(uint8_t x) noexcept { return or_shr(or_shr(or_shr(x, 1), 2), 4); } + // Fills all trailing bits right from the first most significant bit set. + static ASMJIT_INLINE_NODEBUG constexpr uint16_t fillTrailingBitsImpl(uint16_t x) noexcept { return or_shr(or_shr(or_shr(or_shr(x, 1), 2), 4), 8); } + // Fills all trailing bits right from the first most significant bit set. + static ASMJIT_INLINE_NODEBUG constexpr uint32_t fillTrailingBitsImpl(uint32_t x) noexcept { return or_shr(or_shr(or_shr(or_shr(or_shr(x, 1), 2), 4), 8), 16); } + // Fills all trailing bits right from the first most significant bit set. + static ASMJIT_INLINE_NODEBUG constexpr uint64_t fillTrailingBitsImpl(uint64_t x) noexcept { return or_shr(or_shr(or_shr(or_shr(or_shr(or_shr(x, 1), 2), 4), 8), 16), 32); } +} +//! \endcond + +// Fills all trailing bits right from the first most significant bit set. +template +static ASMJIT_INLINE_NODEBUG constexpr T fillTrailingBits(const T& x) noexcept { + typedef typename std::make_unsigned::type U; + return T(Internal::fillTrailingBitsImpl(U(x))); +} + +// Support - Count Leading/Trailing Zeros +// ====================================== + +//! \cond +namespace Internal { +namespace { + +template +struct BitScanData { T x; uint32_t n; }; + +template +struct BitScanCalc { + static ASMJIT_INLINE_NODEBUG constexpr BitScanData advanceLeft(const BitScanData& data, uint32_t n) noexcept { + return BitScanData { data.x << n, data.n + n }; + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData advanceRight(const BitScanData& data, uint32_t n) noexcept { + return BitScanData { data.x >> n, data.n + n }; + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData clz(const BitScanData& data) noexcept { + return BitScanCalc::clz(advanceLeft(data, data.x & (allOnes() << (bitSizeOf() - N)) ? uint32_t(0) : N)); + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData ctz(const BitScanData& data) noexcept { + return BitScanCalc::ctz(advanceRight(data, data.x & (allOnes() >> (bitSizeOf() - N)) ? uint32_t(0) : N)); + } +}; + +template +struct BitScanCalc { + static ASMJIT_INLINE_NODEBUG constexpr BitScanData clz(const BitScanData& ctx) noexcept { + return BitScanData { 0, ctx.n - uint32_t(ctx.x >> (bitSizeOf() - 1)) }; + } + + static ASMJIT_INLINE_NODEBUG constexpr BitScanData ctz(const BitScanData& ctx) noexcept { + return BitScanData { 0, ctx.n - uint32_t(ctx.x & 0x1) }; + } +}; + +template +ASMJIT_INLINE_NODEBUG constexpr uint32_t clzFallback(const T& x) noexcept { + return BitScanCalc() / 2u>::clz(BitScanData{x, 1}).n; +} + +template +ASMJIT_INLINE_NODEBUG constexpr uint32_t ctzFallback(const T& x) noexcept { + return BitScanCalc() / 2u>::ctz(BitScanData{x, 1}).n; +} + +template ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const T& x) noexcept { return clzFallback(asUInt(x)); } +template ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const T& x) noexcept { return ctzFallback(asUInt(x)); } + +#if !defined(ASMJIT_NO_INTRINSICS) +# if defined(__GNUC__) +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint32_t& x) noexcept { return uint32_t(__builtin_clz(x)); } +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint64_t& x) noexcept { return uint32_t(__builtin_clzll(x)); } +template<> ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const uint32_t& x) noexcept { return uint32_t(__builtin_ctz(x)); } +template<> ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const uint64_t& x) noexcept { return uint32_t(__builtin_ctzll(x)); } +# elif defined(_MSC_VER) +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint32_t& x) noexcept { unsigned long i; _BitScanReverse(&i, x); return uint32_t(i ^ 31); } +template<> ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const uint32_t& x) noexcept { unsigned long i; _BitScanForward(&i, x); return uint32_t(i); } +# if ASMJIT_ARCH_X86 == 64 || ASMJIT_ARCH_ARM == 64 +template<> ASMJIT_INLINE_NODEBUG uint32_t clzImpl(const uint64_t& x) noexcept { unsigned long i; _BitScanReverse64(&i, x); return uint32_t(i ^ 63); } +template<> ASMJIT_INLINE_NODEBUG uint32_t ctzImpl(const uint64_t& x) noexcept { unsigned long i; _BitScanForward64(&i, x); return uint32_t(i); } +# endif +# endif +#endif + +} // {anonymous} +} // {Internal} +//! \endcond + +//! Count leading zeros in `x` (returns a position of a first bit set in `x`). +//! +//! \note The input MUST NOT be zero, otherwise the result is undefined. +template +static ASMJIT_INLINE_NODEBUG uint32_t clz(T x) noexcept { return Internal::clzImpl(asUInt(x)); } + +//! Count trailing zeros in `x` (returns a position of a first bit set in `x`). +//! +//! \note The input MUST NOT be zero, otherwise the result is undefined. +template +static ASMJIT_INLINE_NODEBUG uint32_t ctz(T x) noexcept { return Internal::ctzImpl(asUInt(x)); } + +template +struct ConstCTZ { + static constexpr uint32_t value = + (kInput & (uint64_t(1) << 0)) ? 0 : + (kInput & (uint64_t(1) << 1)) ? 1 : + (kInput & (uint64_t(1) << 2)) ? 2 : + (kInput & (uint64_t(1) << 3)) ? 3 : + (kInput & (uint64_t(1) << 4)) ? 4 : + (kInput & (uint64_t(1) << 5)) ? 5 : + (kInput & (uint64_t(1) << 6)) ? 6 : + (kInput & (uint64_t(1) << 7)) ? 7 : + (kInput & (uint64_t(1) << 8)) ? 8 : + (kInput & (uint64_t(1) << 9)) ? 9 : + (kInput & (uint64_t(1) << 10)) ? 10 : + (kInput & (uint64_t(1) << 11)) ? 11 : + (kInput & (uint64_t(1) << 12)) ? 12 : + (kInput & (uint64_t(1) << 13)) ? 13 : + (kInput & (uint64_t(1) << 14)) ? 14 : + (kInput & (uint64_t(1) << 15)) ? 15 : + (kInput & (uint64_t(1) << 16)) ? 16 : + (kInput & (uint64_t(1) << 17)) ? 17 : + (kInput & (uint64_t(1) << 18)) ? 18 : + (kInput & (uint64_t(1) << 19)) ? 19 : + (kInput & (uint64_t(1) << 20)) ? 20 : + (kInput & (uint64_t(1) << 21)) ? 21 : + (kInput & (uint64_t(1) << 22)) ? 22 : + (kInput & (uint64_t(1) << 23)) ? 23 : + (kInput & (uint64_t(1) << 24)) ? 24 : + (kInput & (uint64_t(1) << 25)) ? 25 : + (kInput & (uint64_t(1) << 26)) ? 26 : + (kInput & (uint64_t(1) << 27)) ? 27 : + (kInput & (uint64_t(1) << 28)) ? 28 : + (kInput & (uint64_t(1) << 29)) ? 29 : + (kInput & (uint64_t(1) << 30)) ? 30 : + (kInput & (uint64_t(1) << 31)) ? 31 : + (kInput & (uint64_t(1) << 32)) ? 32 : + (kInput & (uint64_t(1) << 33)) ? 33 : + (kInput & (uint64_t(1) << 34)) ? 34 : + (kInput & (uint64_t(1) << 35)) ? 35 : + (kInput & (uint64_t(1) << 36)) ? 36 : + (kInput & (uint64_t(1) << 37)) ? 37 : + (kInput & (uint64_t(1) << 38)) ? 38 : + (kInput & (uint64_t(1) << 39)) ? 39 : + (kInput & (uint64_t(1) << 40)) ? 40 : + (kInput & (uint64_t(1) << 41)) ? 41 : + (kInput & (uint64_t(1) << 42)) ? 42 : + (kInput & (uint64_t(1) << 43)) ? 43 : + (kInput & (uint64_t(1) << 44)) ? 44 : + (kInput & (uint64_t(1) << 45)) ? 45 : + (kInput & (uint64_t(1) << 46)) ? 46 : + (kInput & (uint64_t(1) << 47)) ? 47 : + (kInput & (uint64_t(1) << 48)) ? 48 : + (kInput & (uint64_t(1) << 49)) ? 49 : + (kInput & (uint64_t(1) << 50)) ? 50 : + (kInput & (uint64_t(1) << 51)) ? 51 : + (kInput & (uint64_t(1) << 52)) ? 52 : + (kInput & (uint64_t(1) << 53)) ? 53 : + (kInput & (uint64_t(1) << 54)) ? 54 : + (kInput & (uint64_t(1) << 55)) ? 55 : + (kInput & (uint64_t(1) << 56)) ? 56 : + (kInput & (uint64_t(1) << 57)) ? 57 : + (kInput & (uint64_t(1) << 58)) ? 58 : + (kInput & (uint64_t(1) << 59)) ? 59 : + (kInput & (uint64_t(1) << 60)) ? 60 : + (kInput & (uint64_t(1) << 61)) ? 61 : + (kInput & (uint64_t(1) << 62)) ? 62 : + (kInput & (uint64_t(1) << 63)) ? 63 : 64; +}; + +// Support - PopCnt +// ================ + +// Based on the following resource: +// http://graphics.stanford.edu/~seander/bithacks.html +// +// Alternatively, for a very small number of bits in `x`: +// uint32_t n = 0; +// while (x) { +// x &= x - 1; +// n++; +// } +// return n; + +//! \cond +namespace Internal { + static ASMJIT_INLINE_NODEBUG uint32_t constPopcntImpl(uint32_t x) noexcept { + x = x - ((x >> 1) & 0x55555555u); + x = (x & 0x33333333u) + ((x >> 2) & 0x33333333u); + return (((x + (x >> 4)) & 0x0F0F0F0Fu) * 0x01010101u) >> 24; + } + + static ASMJIT_INLINE_NODEBUG uint32_t constPopcntImpl(uint64_t x) noexcept { +#if ASMJIT_ARCH_BITS >= 64 + x = x - ((x >> 1) & 0x5555555555555555u); + x = (x & 0x3333333333333333u) + ((x >> 2) & 0x3333333333333333u); + return uint32_t((((x + (x >> 4)) & 0x0F0F0F0F0F0F0F0Fu) * 0x0101010101010101u) >> 56); +#else + return constPopcntImpl(uint32_t(x >> 32)) + + constPopcntImpl(uint32_t(x & 0xFFFFFFFFu)); +#endif + } + + static ASMJIT_INLINE_NODEBUG uint32_t popcntImpl(uint32_t x) noexcept { +#if defined(__GNUC__) + return uint32_t(__builtin_popcount(x)); +#else + return constPopcntImpl(asUInt(x)); +#endif + } + + static ASMJIT_INLINE_NODEBUG uint32_t popcntImpl(uint64_t x) noexcept { +#if defined(__GNUC__) + return uint32_t(__builtin_popcountll(x)); +#else + return constPopcntImpl(asUInt(x)); +#endif + } +} +//! \endcond + +//! Calculates count of bits in `x`. +template +static ASMJIT_INLINE_NODEBUG uint32_t popcnt(T x) noexcept { return Internal::popcntImpl(asUInt(x)); } + +//! Calculates count of bits in `x` (useful in constant expressions). +template +static ASMJIT_INLINE_NODEBUG uint32_t constPopcnt(T x) noexcept { return Internal::constPopcntImpl(asUInt(x)); } + +// Support - Min/Max +// ================= + +// NOTE: These are constexpr `min()` and `max()` implementations that are not +// exactly the same as `std::min()` and `std::max()`. The return value is not +// a reference to `a` or `b` but it's a new value instead. + +template +static ASMJIT_INLINE_NODEBUG constexpr T min(const T& a, const T& b) noexcept { return b < a ? b : a; } + +template +static ASMJIT_INLINE_NODEBUG constexpr T min(const T& a, const T& b, Args&&... args) noexcept { return min(min(a, b), std::forward(args)...); } + +template +static ASMJIT_INLINE_NODEBUG constexpr T max(const T& a, const T& b) noexcept { return a < b ? b : a; } + +template +static ASMJIT_INLINE_NODEBUG constexpr T max(const T& a, const T& b, Args&&... args) noexcept { return max(max(a, b), std::forward(args)...); } + +// Support - Immediate Helpers +// =========================== + +namespace Internal { + template + struct ImmConv { + static ASMJIT_INLINE_NODEBUG int64_t fromT(const T& x) noexcept { return int64_t(x); } + static ASMJIT_INLINE_NODEBUG T toT(int64_t x) noexcept { return T(uint64_t(x) & Support::allOnes::type>()); } + }; + + template + struct ImmConv { + static ASMJIT_INLINE_NODEBUG int64_t fromT(const T& x) noexcept { return int64_t(bitCast(double(x))); } + static ASMJIT_INLINE_NODEBUG T toT(int64_t x) noexcept { return T(bitCast(x)); } + }; +} + +template +static ASMJIT_INLINE_NODEBUG int64_t immediateFromT(const T& x) noexcept { return Internal::ImmConv::value>::fromT(x); } + +template +static ASMJIT_INLINE_NODEBUG T immediateToT(int64_t x) noexcept { return Internal::ImmConv::value>::toT(x); } + +// Support - Overflow Arithmetic +// ============================= + +//! \cond +namespace Internal { + template + inline T addOverflowFallback(T x, T y, FastUInt8* of) noexcept { + typedef typename std::make_unsigned::type U; + + U result = U(x) + U(y); + *of = FastUInt8(*of | FastUInt8(isUnsigned() ? result < U(x) : T((U(x) ^ ~U(y)) & (U(x) ^ result)) < 0)); + return T(result); + } + + template + inline T subOverflowFallback(T x, T y, FastUInt8* of) noexcept { + typedef typename std::make_unsigned::type U; + + U result = U(x) - U(y); + *of = FastUInt8(*of | FastUInt8(isUnsigned() ? result > U(x) : T((U(x) ^ U(y)) & (U(x) ^ result)) < 0)); + return T(result); + } + + template + inline T mulOverflowFallback(T x, T y, FastUInt8* of) noexcept { + typedef typename Internal::StdInt()>::Type I; + typedef typename std::make_unsigned::type U; + + U mask = allOnes(); + if (std::is_signed::value) { + U prod = U(I(x)) * U(I(y)); + *of = FastUInt8(*of | FastUInt8(I(prod) < I(std::numeric_limits::lowest()) || I(prod) > I(std::numeric_limits::max()))); + return T(I(prod & mask)); + } + else { + U prod = U(x) * U(y); + *of = FastUInt8(*of | FastUInt8((prod & ~mask) != 0)); + return T(prod & mask); + } + } + + template<> + inline int64_t mulOverflowFallback(int64_t x, int64_t y, FastUInt8* of) noexcept { + int64_t result = int64_t(uint64_t(x) * uint64_t(y)); + *of = FastUInt8(*of | FastUInt8(x && (result / x != y))); + return result; + } + + template<> + inline uint64_t mulOverflowFallback(uint64_t x, uint64_t y, FastUInt8* of) noexcept { + uint64_t result = x * y; + *of = FastUInt8(*of | FastUInt8(y != 0 && allOnes() / y < x)); + return result; + } + + // These can be specialized. + template inline T addOverflowImpl(const T& x, const T& y, FastUInt8* of) noexcept { return addOverflowFallback(x, y, of); } + template inline T subOverflowImpl(const T& x, const T& y, FastUInt8* of) noexcept { return subOverflowFallback(x, y, of); } + template inline T mulOverflowImpl(const T& x, const T& y, FastUInt8* of) noexcept { return mulOverflowFallback(x, y, of); } + +#if defined(__GNUC__) && !defined(ASMJIT_NO_INTRINSICS) +#if defined(__clang__) || __GNUC__ >= 5 +#define ASMJIT_ARITH_OVERFLOW_SPECIALIZE(FUNC, T, RESULT_T, BUILTIN) \ + template<> \ + inline T FUNC(const T& x, const T& y, FastUInt8* of) noexcept { \ + RESULT_T result; \ + *of = FastUInt8(*of | (BUILTIN((RESULT_T)x, (RESULT_T)y, &result))); \ + return T(result); \ + } + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, int32_t , int , __builtin_sadd_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint32_t, unsigned int , __builtin_uadd_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, int64_t , long long , __builtin_saddll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint64_t, unsigned long long, __builtin_uaddll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, int32_t , int , __builtin_ssub_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint32_t, unsigned int , __builtin_usub_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, int64_t , long long , __builtin_ssubll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint64_t, unsigned long long, __builtin_usubll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, int32_t , int , __builtin_smul_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, uint32_t, unsigned int , __builtin_umul_overflow ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, int64_t , long long , __builtin_smulll_overflow) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(mulOverflowImpl, uint64_t, unsigned long long, __builtin_umulll_overflow) +#undef ASMJIT_ARITH_OVERFLOW_SPECIALIZE +#endif +#endif + + // There is a bug in MSVC that makes these specializations unusable, maybe in the future... +#if defined(_MSC_VER) && 0 +#define ASMJIT_ARITH_OVERFLOW_SPECIALIZE(FUNC, T, ALT_T, BUILTIN) \ + template<> \ + inline T FUNC(T x, T y, FastUInt8* of) noexcept { \ + ALT_T result; \ + *of = FastUInt8(*of | BUILTIN(0, (ALT_T)x, (ALT_T)y, &result)); \ + return T(result); \ + } + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint32_t, unsigned int , _addcarry_u32 ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint32_t, unsigned int , _subborrow_u32) +#if ARCH_BITS >= 64 + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(addOverflowImpl, uint64_t, unsigned __int64 , _addcarry_u64 ) + ASMJIT_ARITH_OVERFLOW_SPECIALIZE(subOverflowImpl, uint64_t, unsigned __int64 , _subborrow_u64) +#endif +#undef ASMJIT_ARITH_OVERFLOW_SPECIALIZE +#endif +} // {Internal} +//! \endcond + +template +static inline T addOverflow(const T& x, const T& y, FastUInt8* of) noexcept { return T(Internal::addOverflowImpl(asStdInt(x), asStdInt(y), of)); } + +template +static inline T subOverflow(const T& x, const T& y, FastUInt8* of) noexcept { return T(Internal::subOverflowImpl(asStdInt(x), asStdInt(y), of)); } + +template +static inline T mulOverflow(const T& x, const T& y, FastUInt8* of) noexcept { return T(Internal::mulOverflowImpl(asStdInt(x), asStdInt(y), of)); } + +// Support - Alignment +// =================== + +template +static ASMJIT_INLINE_NODEBUG constexpr bool isAligned(X base, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return ((U)base % (U)alignment) == 0; +} + +//! Tests whether the `x` is a power of two (only one bit is set). +template +static ASMJIT_INLINE_NODEBUG constexpr bool isPowerOf2(T x) noexcept { + typedef typename std::make_unsigned::type U; + return x && !(U(x) & (U(x) - U(1))); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr X alignUp(X x, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return (X)( ((U)x + ((U)(alignment) - 1u)) & ~((U)(alignment) - 1u) ); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr T alignUpPowerOf2(T x) noexcept { + typedef typename Internal::StdInt::Type U; + return (T)(fillTrailingBits(U(x) - 1u) + 1u); +} + +//! Returns either zero or a positive difference between `base` and `base` when +//! aligned to `alignment`. +template +static ASMJIT_INLINE_NODEBUG constexpr typename Internal::StdInt::Type alignUpDiff(X base, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return alignUp(U(base), alignment) - U(base); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr X alignDown(X x, Y alignment) noexcept { + typedef typename Internal::StdInt::Type U; + return (X)( (U)x & ~((U)(alignment) - 1u) ); +} + +// Support - NumGranularized +// ========================= + +//! Calculates the number of elements that would be required if `base` is +//! granularized by `granularity`. This function can be used to calculate +//! the number of BitWords to represent N bits, for example. +template +static ASMJIT_INLINE_NODEBUG constexpr X numGranularized(X base, Y granularity) noexcept { + typedef typename Internal::StdInt::Type U; + return X((U(base) + U(granularity) - 1) / U(granularity)); +} + +// Support - IsBetween +// =================== + +//! Checks whether `x` is greater than or equal to `a` and lesser than or equal to `b`. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isBetween(const T& x, const T& a, const T& b) noexcept { + return x >= a && x <= b; +} + +// Support - IsInt & IsUInt +// ======================== + +//! Checks whether the given integer `x` can be casted to a 4-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt4(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? isBetween(S(x), -8, 7) : U(x) <= U(7u); +} + +//! Checks whether the given integer `x` can be casted to a 7-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt7(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? isBetween(S(x), -64, 63) : U(x) <= U(63u); +} + +//! Checks whether the given integer `x` can be casted to an 8-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt8(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 1 || isBetween(S(x), -128, 127) : U(x) <= U(127u); +} + +//! Checks whether the given integer `x` can be casted to a 9-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt9(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 1 || isBetween(S(x), -256, 255) + : sizeof(T) <= 1 || U(x) <= U(255u); +} + +//! Checks whether the given integer `x` can be casted to a 10-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt10(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 1 || isBetween(S(x), -512, 511) + : sizeof(T) <= 1 || U(x) <= U(511u); +} + +//! Checks whether the given integer `x` can be casted to a 16-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt16(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 2 || isBetween(S(x), -32768, 32767) + : sizeof(T) <= 1 || U(x) <= U(32767u); +} + +//! Checks whether the given integer `x` can be casted to a 32-bit signed integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isInt32(T x) noexcept { + typedef typename std::make_signed::type S; + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? sizeof(T) <= 4 || isBetween(S(x), -2147483647 - 1, 2147483647) + : sizeof(T) <= 2 || U(x) <= U(2147483647u); +} + +//! Checks whether the given integer `x` can be casted to a 4-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt4(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? x >= T(0) && x <= T(15) + : U(x) <= U(15u); +} + +//! Checks whether the given integer `x` can be casted to an 8-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt8(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? (sizeof(T) <= 1 || T(x) <= T(255)) && x >= T(0) + : (sizeof(T) <= 1 || U(x) <= U(255u)); +} + +//! Checks whether the given integer `x` can be casted to a 12-bit unsigned integer (ARM specific). +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt12(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? (sizeof(T) <= 1 || T(x) <= T(4095)) && x >= T(0) + : (sizeof(T) <= 1 || U(x) <= U(4095u)); +} + +//! Checks whether the given integer `x` can be casted to a 16-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt16(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? (sizeof(T) <= 2 || T(x) <= T(65535)) && x >= T(0) + : (sizeof(T) <= 2 || U(x) <= U(65535u)); +} + +//! Checks whether the given integer `x` can be casted to a 32-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt32(T x) noexcept { + typedef typename std::make_unsigned::type U; + + return std::is_signed::value ? (sizeof(T) <= 4 || T(x) <= T(4294967295u)) && x >= T(0) + : (sizeof(T) <= 4 || U(x) <= U(4294967295u)); +} + +//! Checks whether the given integer `x` can be casted to a 32-bit unsigned integer. +template +static ASMJIT_INLINE_NODEBUG constexpr bool isIntOrUInt32(T x) noexcept { + return sizeof(T) <= 4 ? true : (uint32_t(uint64_t(x) >> 32) + 1u) <= 1u; +} + +static bool ASMJIT_INLINE_NODEBUG isEncodableOffset32(int32_t offset, uint32_t nBits) noexcept { + uint32_t nRev = 32 - nBits; + return Support::sar(Support::shl(offset, nRev), nRev) == offset; +} + +static bool ASMJIT_INLINE_NODEBUG isEncodableOffset64(int64_t offset, uint32_t nBits) noexcept { + uint32_t nRev = 64 - nBits; + return Support::sar(Support::shl(offset, nRev), nRev) == offset; +} + +// Support - ByteSwap +// ================== + +static ASMJIT_INLINE_NODEBUG uint16_t byteswap16(uint16_t x) noexcept { + return uint16_t(((x >> 8) & 0xFFu) | ((x & 0xFFu) << 8)); +} + +static ASMJIT_INLINE_NODEBUG uint32_t byteswap32(uint32_t x) noexcept { + return (x << 24) | (x >> 24) | ((x << 8) & 0x00FF0000u) | ((x >> 8) & 0x0000FF00); +} + +static ASMJIT_INLINE_NODEBUG uint64_t byteswap64(uint64_t x) noexcept { +#if (defined(__GNUC__) || defined(__clang__)) && !defined(ASMJIT_NO_INTRINSICS) + return uint64_t(__builtin_bswap64(uint64_t(x))); +#elif defined(_MSC_VER) && !defined(ASMJIT_NO_INTRINSICS) + return uint64_t(_byteswap_uint64(uint64_t(x))); +#else + return (uint64_t(byteswap32(uint32_t(uint64_t(x) >> 32 ))) ) | + (uint64_t(byteswap32(uint32_t(uint64_t(x) & 0xFFFFFFFFu))) << 32) ; +#endif +} + +// Support - BytePack & Unpack +// =========================== + +//! Pack four 8-bit integer into a 32-bit integer as it is an array of `{b0,b1,b2,b3}`. +static ASMJIT_INLINE_NODEBUG constexpr uint32_t bytepack32_4x8(uint32_t a, uint32_t b, uint32_t c, uint32_t d) noexcept { + return ASMJIT_ARCH_LE ? (a | (b << 8) | (c << 16) | (d << 24)) + : (d | (c << 8) | (b << 16) | (a << 24)); +} + +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t unpackU32At0(T x) noexcept { return ASMJIT_ARCH_LE ? uint32_t(uint64_t(x) & 0xFFFFFFFFu) : uint32_t(uint64_t(x) >> 32); } +template +static ASMJIT_INLINE_NODEBUG constexpr uint32_t unpackU32At1(T x) noexcept { return ASMJIT_ARCH_BE ? uint32_t(uint64_t(x) & 0xFFFFFFFFu) : uint32_t(uint64_t(x) >> 32); } + +// Support - Position of byte (in bit-shift) +// ========================================= + +static ASMJIT_INLINE_NODEBUG uint32_t byteShiftOfDWordStruct(uint32_t index) noexcept { + return ASMJIT_ARCH_LE ? index * 8 : (uint32_t(sizeof(uint32_t)) - 1u - index) * 8; +} + +// Support - String Utilities +// ========================== + +template +static ASMJIT_INLINE_NODEBUG constexpr T asciiToLower(T c) noexcept { return T(c ^ T(T(c >= T('A') && c <= T('Z')) << 5)); } + +template +static ASMJIT_INLINE_NODEBUG constexpr T asciiToUpper(T c) noexcept { return T(c ^ T(T(c >= T('a') && c <= T('z')) << 5)); } + +static ASMJIT_INLINE_NODEBUG size_t strLen(const char* s, size_t maxSize) noexcept { + size_t i = 0; + while (i < maxSize && s[i] != '\0') + i++; + return i; +} + +static ASMJIT_INLINE_NODEBUG constexpr uint32_t hashRound(uint32_t hash, uint32_t c) noexcept { return hash * 65599 + c; } + +// Gets a hash of the given string `data` of size `size`. Size must be valid +// as this function doesn't check for a null terminator and allows it in the +// middle of the string. +static ASMJIT_INLINE_NODEBUG uint32_t hashString(const char* data, size_t size) noexcept { + uint32_t hashCode = 0; + for (uint32_t i = 0; i < size; i++) + hashCode = hashRound(hashCode, uint8_t(data[i])); + return hashCode; +} + +static ASMJIT_INLINE_NODEBUG const char* findPackedString(const char* p, uint32_t id) noexcept { + uint32_t i = 0; + while (i < id) { + while (p[0]) + p++; + p++; + i++; + } + return p; +} + +//! Compares two string views. +static ASMJIT_FORCE_INLINE int compareStringViews(const char* aData, size_t aSize, const char* bData, size_t bSize) noexcept { + size_t size = Support::min(aSize, bSize); + + for (size_t i = 0; i < size; i++) { + int c = int(uint8_t(aData[i])) - int(uint8_t(bData[i])); + if (c != 0) + return c; + } + + return int(aSize) - int(bSize); +} + +// Support - Memory Read Access - 8 Bits +// ===================================== + +static ASMJIT_INLINE_NODEBUG uint8_t readU8(const void* p) noexcept { return static_cast(p)[0]; } +static ASMJIT_INLINE_NODEBUG int8_t readI8(const void* p) noexcept { return static_cast(p)[0]; } + +// Support - Memory Read Access - 16 Bits +// ====================================== + +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16x(const void* p) noexcept { + typedef typename Internal::AliasedUInt::T U16AlignedToN; + uint16_t x = static_cast(p)[0]; + return BO == ByteOrder::kNative ? x : byteswap16(x); +} + +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16u(const void* p) noexcept { return readU16x(p); } +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16uLE(const void* p) noexcept { return readU16x(p); } +template +static ASMJIT_INLINE_NODEBUG uint16_t readU16uBE(const void* p) noexcept { return readU16x(p); } + +static ASMJIT_INLINE_NODEBUG uint16_t readU16a(const void* p) noexcept { return readU16x(p); } +static ASMJIT_INLINE_NODEBUG uint16_t readU16aLE(const void* p) noexcept { return readU16x(p); } +static ASMJIT_INLINE_NODEBUG uint16_t readU16aBE(const void* p) noexcept { return readU16x(p); } + +template +static ASMJIT_INLINE_NODEBUG int16_t readI16x(const void* p) noexcept { return int16_t(readU16x(p)); } + +template +static ASMJIT_INLINE_NODEBUG int16_t readI16u(const void* p) noexcept { return int16_t(readU16x(p)); } +template +static ASMJIT_INLINE_NODEBUG int16_t readI16uLE(const void* p) noexcept { return int16_t(readU16x(p)); } +template +static ASMJIT_INLINE_NODEBUG int16_t readI16uBE(const void* p) noexcept { return int16_t(readU16x(p)); } + +static ASMJIT_INLINE_NODEBUG int16_t readI16a(const void* p) noexcept { return int16_t(readU16x(p)); } +static ASMJIT_INLINE_NODEBUG int16_t readI16aLE(const void* p) noexcept { return int16_t(readU16x(p)); } +static ASMJIT_INLINE_NODEBUG int16_t readI16aBE(const void* p) noexcept { return int16_t(readU16x(p)); } + +// Support - Memory Read Access - 24 Bits +// ====================================== + +template +static inline uint32_t readU24u(const void* p) noexcept { + uint32_t b0 = readU8(static_cast(p) + (BO == ByteOrder::kLE ? 2u : 0u)); + uint32_t b1 = readU8(static_cast(p) + 1u); + uint32_t b2 = readU8(static_cast(p) + (BO == ByteOrder::kLE ? 0u : 2u)); + return (b0 << 16) | (b1 << 8) | b2; +} + +static inline uint32_t readU24uLE(const void* p) noexcept { return readU24u(p); } +static inline uint32_t readU24uBE(const void* p) noexcept { return readU24u(p); } + +// Support - Memory Read Access - 32 Bits +// ====================================== + +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32x(const void* p) noexcept { + typedef typename Internal::AliasedUInt::T U32AlignedToN; + uint32_t x = static_cast(p)[0]; + return BO == ByteOrder::kNative ? x : byteswap32(x); +} + +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32u(const void* p) noexcept { return readU32x(p); } +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32uLE(const void* p) noexcept { return readU32x(p); } +template +static ASMJIT_INLINE_NODEBUG uint32_t readU32uBE(const void* p) noexcept { return readU32x(p); } + +static ASMJIT_INLINE_NODEBUG uint32_t readU32a(const void* p) noexcept { return readU32x(p); } +static ASMJIT_INLINE_NODEBUG uint32_t readU32aLE(const void* p) noexcept { return readU32x(p); } +static ASMJIT_INLINE_NODEBUG uint32_t readU32aBE(const void* p) noexcept { return readU32x(p); } + +template +static ASMJIT_INLINE_NODEBUG uint32_t readI32x(const void* p) noexcept { return int32_t(readU32x(p)); } + +template +static ASMJIT_INLINE_NODEBUG int32_t readI32u(const void* p) noexcept { return int32_t(readU32x(p)); } +template +static ASMJIT_INLINE_NODEBUG int32_t readI32uLE(const void* p) noexcept { return int32_t(readU32x(p)); } +template +static ASMJIT_INLINE_NODEBUG int32_t readI32uBE(const void* p) noexcept { return int32_t(readU32x(p)); } + +static ASMJIT_INLINE_NODEBUG int32_t readI32a(const void* p) noexcept { return int32_t(readU32x(p)); } +static ASMJIT_INLINE_NODEBUG int32_t readI32aLE(const void* p) noexcept { return int32_t(readU32x(p)); } +static ASMJIT_INLINE_NODEBUG int32_t readI32aBE(const void* p) noexcept { return int32_t(readU32x(p)); } + +// Support - Memory Read Access - 64 Bits +// ====================================== + +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64x(const void* p) noexcept { + typedef typename Internal::AliasedUInt::T U64AlignedToN; + uint64_t x = static_cast(p)[0]; + return BO == ByteOrder::kNative ? x : byteswap64(x); +} + +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64u(const void* p) noexcept { return readU64x(p); } +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64uLE(const void* p) noexcept { return readU64x(p); } +template +static ASMJIT_INLINE_NODEBUG uint64_t readU64uBE(const void* p) noexcept { return readU64x(p); } + +static ASMJIT_INLINE_NODEBUG uint64_t readU64a(const void* p) noexcept { return readU64x(p); } +static ASMJIT_INLINE_NODEBUG uint64_t readU64aLE(const void* p) noexcept { return readU64x(p); } +static ASMJIT_INLINE_NODEBUG uint64_t readU64aBE(const void* p) noexcept { return readU64x(p); } + +template +static ASMJIT_INLINE_NODEBUG int64_t readI64x(const void* p) noexcept { return int64_t(readU64x(p)); } + +template +static ASMJIT_INLINE_NODEBUG int64_t readI64u(const void* p) noexcept { return int64_t(readU64x(p)); } +template +static ASMJIT_INLINE_NODEBUG int64_t readI64uLE(const void* p) noexcept { return int64_t(readU64x(p)); } +template +static ASMJIT_INLINE_NODEBUG int64_t readI64uBE(const void* p) noexcept { return int64_t(readU64x(p)); } + +static ASMJIT_INLINE_NODEBUG int64_t readI64a(const void* p) noexcept { return int64_t(readU64x(p)); } +static ASMJIT_INLINE_NODEBUG int64_t readI64aLE(const void* p) noexcept { return int64_t(readU64x(p)); } +static ASMJIT_INLINE_NODEBUG int64_t readI64aBE(const void* p) noexcept { return int64_t(readU64x(p)); } + +// Support - Memory Write Access - 8 Bits +// ====================================== + +static ASMJIT_INLINE_NODEBUG void writeU8(void* p, uint8_t x) noexcept { static_cast(p)[0] = x; } +static ASMJIT_INLINE_NODEBUG void writeI8(void* p, int8_t x) noexcept { static_cast(p)[0] = x; } + +// Support - Memory Write Access - 16 Bits +// ======================================= + +template +static ASMJIT_INLINE_NODEBUG void writeU16x(void* p, uint16_t x) noexcept { + typedef typename Internal::AliasedUInt::T U16AlignedToN; + static_cast(p)[0] = BO == ByteOrder::kNative ? x : byteswap16(x); +} + +template +static ASMJIT_INLINE_NODEBUG void writeU16uLE(void* p, uint16_t x) noexcept { writeU16x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU16uBE(void* p, uint16_t x) noexcept { writeU16x(p, x); } + +static ASMJIT_INLINE_NODEBUG void writeU16a(void* p, uint16_t x) noexcept { writeU16x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU16aLE(void* p, uint16_t x) noexcept { writeU16x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU16aBE(void* p, uint16_t x) noexcept { writeU16x(p, x); } + + +template +static ASMJIT_INLINE_NODEBUG void writeI16x(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } + +template +static ASMJIT_INLINE_NODEBUG void writeI16uLE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI16uBE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } + +static ASMJIT_INLINE_NODEBUG void writeI16a(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI16aLE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI16aBE(void* p, int16_t x) noexcept { writeU16x(p, uint16_t(x)); } + +// Support - Memory Write Access - 24 Bits +// ======================================= + +template +static inline void writeU24u(void* p, uint32_t v) noexcept { + static_cast(p)[0] = uint8_t((v >> (BO == ByteOrder::kLE ? 0 : 16)) & 0xFFu); + static_cast(p)[1] = uint8_t((v >> 8) & 0xFFu); + static_cast(p)[2] = uint8_t((v >> (BO == ByteOrder::kLE ? 16 : 0)) & 0xFFu); +} + +static inline void writeU24uLE(void* p, uint32_t v) noexcept { writeU24u(p, v); } +static inline void writeU24uBE(void* p, uint32_t v) noexcept { writeU24u(p, v); } + +// Support - Memory Write Access - 32 Bits +// ======================================= + +template +static ASMJIT_INLINE_NODEBUG void writeU32x(void* p, uint32_t x) noexcept { + typedef typename Internal::AliasedUInt::T U32AlignedToN; + static_cast(p)[0] = (BO == ByteOrder::kNative) ? x : Support::byteswap32(x); +} + +template +static ASMJIT_INLINE_NODEBUG void writeU32u(void* p, uint32_t x) noexcept { writeU32x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU32uLE(void* p, uint32_t x) noexcept { writeU32x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU32uBE(void* p, uint32_t x) noexcept { writeU32x(p, x); } + +static ASMJIT_INLINE_NODEBUG void writeU32a(void* p, uint32_t x) noexcept { writeU32x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU32aLE(void* p, uint32_t x) noexcept { writeU32x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU32aBE(void* p, uint32_t x) noexcept { writeU32x(p, x); } + +template +static ASMJIT_INLINE_NODEBUG void writeI32x(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } + +template +static ASMJIT_INLINE_NODEBUG void writeI32u(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI32uLE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI32uBE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } + +static ASMJIT_INLINE_NODEBUG void writeI32a(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI32aLE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI32aBE(void* p, int32_t x) noexcept { writeU32x(p, uint32_t(x)); } + +// Support - Memory Write Access - 64 Bits +// ======================================= + +template +static ASMJIT_INLINE_NODEBUG void writeU64x(void* p, uint64_t x) noexcept { + typedef typename Internal::AliasedUInt::T U64AlignedToN; + static_cast(p)[0] = BO == ByteOrder::kNative ? x : byteswap64(x); +} + +template +static ASMJIT_INLINE_NODEBUG void writeU64u(void* p, uint64_t x) noexcept { writeU64x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU64uLE(void* p, uint64_t x) noexcept { writeU64x(p, x); } +template +static ASMJIT_INLINE_NODEBUG void writeU64uBE(void* p, uint64_t x) noexcept { writeU64x(p, x); } + +static ASMJIT_INLINE_NODEBUG void writeU64a(void* p, uint64_t x) noexcept { writeU64x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU64aLE(void* p, uint64_t x) noexcept { writeU64x(p, x); } +static ASMJIT_INLINE_NODEBUG void writeU64aBE(void* p, uint64_t x) noexcept { writeU64x(p, x); } + +template +static ASMJIT_INLINE_NODEBUG void writeI64x(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } + +template +static ASMJIT_INLINE_NODEBUG void writeI64u(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI64uLE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +template +static ASMJIT_INLINE_NODEBUG void writeI64uBE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } + +static ASMJIT_INLINE_NODEBUG void writeI64a(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI64aLE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } +static ASMJIT_INLINE_NODEBUG void writeI64aBE(void* p, int64_t x) noexcept { writeU64x(p, uint64_t(x)); } + +// Support - Operators +// =================== + +//! \cond INTERNAL +struct Set { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { DebugUtils::unused(x); return y; } }; +struct SetNot { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { DebugUtils::unused(x); return ~y; } }; +struct And { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x & y; } }; +struct AndNot { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x & ~y; } }; +struct NotAnd { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return ~x & y; } }; +struct Or { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x | y; } }; +struct Xor { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x ^ y; } }; +struct Add { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x + y; } }; +struct Sub { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return x - y; } }; +struct Min { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return min(x, y); } }; +struct Max { template static ASMJIT_INLINE_NODEBUG T op(T x, T y) noexcept { return max(x, y); } }; +//! \endcond + +// Support - BitWordIterator +// ========================= + +//! Iterates over each bit in a number which is set to 1. +//! +//! Example of use: +//! +//! ``` +//! uint32_t bitsToIterate = 0x110F; +//! Support::BitWordIterator it(bitsToIterate); +//! +//! while (it.hasNext()) { +//! uint32_t bitIndex = it.next(); +//! std::printf("Bit at %u is set\n", unsigned(bitIndex)); +//! } +//! ``` +template +class BitWordIterator { +public: + ASMJIT_INLINE_NODEBUG explicit BitWordIterator(T bitWord) noexcept + : _bitWord(bitWord) {} + + ASMJIT_INLINE_NODEBUG void init(T bitWord) noexcept { _bitWord = bitWord; } + ASMJIT_INLINE_NODEBUG bool hasNext() const noexcept { return _bitWord != 0; } + + ASMJIT_FORCE_INLINE uint32_t next() noexcept { + ASMJIT_ASSERT(_bitWord != 0); + uint32_t index = ctz(_bitWord); + _bitWord &= T(_bitWord - 1); + return index; + } + + T _bitWord; +}; + +// Support - BitVectorOps +// ====================== + +//! \cond +namespace Internal { + template + static ASMJIT_FORCE_INLINE void bitVectorOp(T* buf, size_t index, size_t count) noexcept { + if (count == 0) + return; + + const size_t kTSizeInBits = bitSizeOf(); + size_t vecIndex = index / kTSizeInBits; // T[] + size_t bitIndex = index % kTSizeInBits; // T[][] + + buf += vecIndex; + + // The first BitWord requires special handling to preserve bits outside the fill region. + const T kFillMask = allOnes(); + size_t firstNBits = min(kTSizeInBits - bitIndex, count); + + buf[0] = OperatorT::op(buf[0], (kFillMask >> (kTSizeInBits - firstNBits)) << bitIndex); + buf++; + count -= firstNBits; + + // All bits between the first and last affected BitWords can be just filled. + while (count >= kTSizeInBits) { + buf[0] = FullWordOpT::op(buf[0], kFillMask); + buf++; + count -= kTSizeInBits; + } + + // The last BitWord requires special handling as well + if (count) + buf[0] = OperatorT::op(buf[0], kFillMask >> (kTSizeInBits - count)); + } +} +//! \endcond + +//! Sets bit in a bit-vector `buf` at `index`. +template +static ASMJIT_INLINE_NODEBUG bool bitVectorGetBit(T* buf, size_t index) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + + size_t vecIndex = index / kTSizeInBits; + size_t bitIndex = index % kTSizeInBits; + + return bool((buf[vecIndex] >> bitIndex) & 0x1u); +} + +//! Sets bit in a bit-vector `buf` at `index` to `value`. +template +static ASMJIT_INLINE_NODEBUG void bitVectorSetBit(T* buf, size_t index, bool value) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + + size_t vecIndex = index / kTSizeInBits; + size_t bitIndex = index % kTSizeInBits; + + T bitMask = T(1u) << bitIndex; + if (value) + buf[vecIndex] |= bitMask; + else + buf[vecIndex] &= ~bitMask; +} + +//! Sets bit in a bit-vector `buf` at `index` to `value`. +template +static ASMJIT_INLINE_NODEBUG void bitVectorFlipBit(T* buf, size_t index) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + + size_t vecIndex = index / kTSizeInBits; + size_t bitIndex = index % kTSizeInBits; + + T bitMask = T(1u) << bitIndex; + buf[vecIndex] ^= bitMask; +} + +//! Fills `count` bits in bit-vector `buf` starting at bit-index `index`. +template +static ASMJIT_INLINE_NODEBUG void bitVectorFill(T* buf, size_t index, size_t count) noexcept { Internal::bitVectorOp(buf, index, count); } + +//! Clears `count` bits in bit-vector `buf` starting at bit-index `index`. +template +static ASMJIT_INLINE_NODEBUG void bitVectorClear(T* buf, size_t index, size_t count) noexcept { Internal::bitVectorOp(buf, index, count); } + +template +static ASMJIT_FORCE_INLINE size_t bitVectorIndexOf(T* buf, size_t start, bool value) noexcept { + const size_t kTSizeInBits = bitSizeOf(); + size_t vecIndex = start / kTSizeInBits; // T[] + size_t bitIndex = start % kTSizeInBits; // T[][] + + T* p = buf + vecIndex; + + // We always look for zeros, if value is `true` we have to flip all bits before the search. + const T kFillMask = allOnes(); + const T kFlipMask = value ? T(0) : kFillMask; + + // The first BitWord requires special handling as there are some bits we want to ignore. + T bits = (*p ^ kFlipMask) & (kFillMask << bitIndex); + for (;;) { + if (bits) + return (size_t)(p - buf) * kTSizeInBits + ctz(bits); + bits = *++p ^ kFlipMask; + } +} + +// Support - BitVectorIterator +// =========================== + +template +class BitVectorIterator { +public: + const T* _ptr; + size_t _idx; + size_t _end; + T _current; + + ASMJIT_INLINE_NODEBUG BitVectorIterator(const BitVectorIterator& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG BitVectorIterator(const T* data, size_t numBitWords, size_t start = 0) noexcept { + init(data, numBitWords, start); + } + + ASMJIT_FORCE_INLINE void init(const T* data, size_t numBitWords, size_t start = 0) noexcept { + const T* ptr = data + (start / bitSizeOf()); + size_t idx = alignDown(start, bitSizeOf()); + size_t end = numBitWords * bitSizeOf(); + + T bitWord = T(0); + if (idx < end) { + bitWord = *ptr++ & (allOnes() << (start % bitSizeOf())); + while (!bitWord && (idx += bitSizeOf()) < end) + bitWord = *ptr++; + } + + _ptr = ptr; + _idx = idx; + _end = end; + _current = bitWord; + } + + ASMJIT_INLINE_NODEBUG bool hasNext() const noexcept { + return _current != T(0); + } + + ASMJIT_FORCE_INLINE size_t next() noexcept { + T bitWord = _current; + ASMJIT_ASSERT(bitWord != T(0)); + + uint32_t bit = ctz(bitWord); + bitWord &= T(bitWord - 1u); + + size_t n = _idx + bit; + while (!bitWord && (_idx += bitSizeOf()) < _end) + bitWord = *_ptr++; + + _current = bitWord; + return n; + } + + ASMJIT_FORCE_INLINE size_t peekNext() const noexcept { + ASMJIT_ASSERT(_current != T(0)); + return _idx + ctz(_current); + } +}; + +// Support - BitVectorOpIterator +// ============================= + +template +class BitVectorOpIterator { +public: + enum : uint32_t { + kTSizeInBits = bitSizeOf() + }; + + const T* _aPtr; + const T* _bPtr; + size_t _idx; + size_t _end; + T _current; + + ASMJIT_INLINE_NODEBUG BitVectorOpIterator(const T* aData, const T* bData, size_t numBitWords, size_t start = 0) noexcept { + init(aData, bData, numBitWords, start); + } + + ASMJIT_FORCE_INLINE void init(const T* aData, const T* bData, size_t numBitWords, size_t start = 0) noexcept { + const T* aPtr = aData + (start / bitSizeOf()); + const T* bPtr = bData + (start / bitSizeOf()); + size_t idx = alignDown(start, bitSizeOf()); + size_t end = numBitWords * bitSizeOf(); + + T bitWord = T(0); + if (idx < end) { + bitWord = OperatorT::op(*aPtr++, *bPtr++) & (allOnes() << (start % bitSizeOf())); + while (!bitWord && (idx += kTSizeInBits) < end) + bitWord = OperatorT::op(*aPtr++, *bPtr++); + } + + _aPtr = aPtr; + _bPtr = bPtr; + _idx = idx; + _end = end; + _current = bitWord; + } + + ASMJIT_INLINE_NODEBUG bool hasNext() noexcept { + return _current != T(0); + } + + ASMJIT_FORCE_INLINE size_t next() noexcept { + T bitWord = _current; + ASMJIT_ASSERT(bitWord != T(0)); + + uint32_t bit = ctz(bitWord); + bitWord &= T(bitWord - 1u); + + size_t n = _idx + bit; + while (!bitWord && (_idx += kTSizeInBits) < _end) + bitWord = OperatorT::op(*_aPtr++, *_bPtr++); + + _current = bitWord; + return n; + } +}; + +// Support - Sorting +// ================= + +//! Sort order. +enum class SortOrder : uint32_t { + //!< Ascending order. + kAscending = 0, + //!< Descending order. + kDescending = 1 +}; + +//! A helper class that provides comparison of any user-defined type that +//! implements `<` and `>` operators (primitive types are supported as well). +template +struct Compare { + template + ASMJIT_INLINE_NODEBUG int operator()(const A& a, const B& b) const noexcept { + return kOrder == SortOrder::kAscending ? int(a > b) - int(a < b) : int(a < b) - int(a > b); + } +}; + +//! Insertion sort. +template> +static inline void iSort(T* base, size_t size, const CompareT& cmp = CompareT()) noexcept { + for (T* pm = base + 1; pm < base + size; pm++) + for (T* pl = pm; pl > base && cmp(pl[-1], pl[0]) > 0; pl--) + std::swap(pl[-1], pl[0]); +} + +//! \cond +namespace Internal { + //! Quick-sort implementation. + template + struct QSortImpl { + enum : size_t { + kStackSize = 64 * 2, + kISortThreshold = 7 + }; + + // Based on "PDCLib - Public Domain C Library" and rewritten to C++. + static void sort(T* base, size_t size, const CompareT& cmp) noexcept { + T* end = base + size; + T* stack[kStackSize]; + T** stackptr = stack; + + for (;;) { + if ((size_t)(end - base) > kISortThreshold) { + // We work from second to last - first will be pivot element. + T* pi = base + 1; + T* pj = end - 1; + std::swap(base[(size_t)(end - base) / 2], base[0]); + + if (cmp(*pi , *pj ) > 0) std::swap(*pi , *pj ); + if (cmp(*base, *pj ) > 0) std::swap(*base, *pj ); + if (cmp(*pi , *base) > 0) std::swap(*pi , *base); + + // Now we have the median for pivot element, entering main loop. + for (;;) { + while (pi < pj && cmp(*++pi, *base) < 0) continue; // Move `i` right until `*i >= pivot`. + while (pj > base && cmp(*--pj, *base) > 0) continue; // Move `j` left until `*j <= pivot`. + + if (pi > pj) break; + std::swap(*pi, *pj); + } + + // Move pivot into correct place. + std::swap(*base, *pj); + + // Larger subfile base / end to stack, sort smaller. + if (pj - base > end - pi) { + // Left is larger. + *stackptr++ = base; + *stackptr++ = pj; + base = pi; + } + else { + // Right is larger. + *stackptr++ = pi; + *stackptr++ = end; + end = pj; + } + ASMJIT_ASSERT(stackptr <= stack + kStackSize); + } + else { + // UB sanitizer doesn't like applying offset to a nullptr base. + if (base != end) + iSort(base, (size_t)(end - base), cmp); + + if (stackptr == stack) + break; + + end = *--stackptr; + base = *--stackptr; + } + } + } + }; +} +//! \endcond + +//! Quick sort implementation. +//! +//! The main reason to provide a custom qsort implementation is that we needed something that will +//! never throw `bad_alloc` exception. This implementation doesn't use dynamic memory allocation. +template> +static ASMJIT_INLINE_NODEBUG void qSort(T* base, size_t size, const CompareT& cmp = CompareT()) noexcept { + Internal::QSortImpl::sort(base, size, cmp); +} + +// Support - ReverseIterator +// ========================= + +//! Reverse iterator to avoid including `` header for iteration over arrays, specialized for +//! AsmJit use (noexcept by design). +template +class ArrayReverseIterator { +public: + //! \name Members + //! \{ + + T* _ptr {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG constexpr ArrayReverseIterator() noexcept = default; + ASMJIT_INLINE_NODEBUG constexpr ArrayReverseIterator(const ArrayReverseIterator& other) noexcept = default; + ASMJIT_INLINE_NODEBUG constexpr ArrayReverseIterator(T* ptr) noexcept : _ptr(ptr) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator=(const ArrayReverseIterator& other) noexcept = default; + + ASMJIT_INLINE_NODEBUG bool operator==(const T* other) const noexcept { return _ptr == other; } + ASMJIT_INLINE_NODEBUG bool operator==(const ArrayReverseIterator& other) const noexcept { return _ptr == other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator!=(const T* other) const noexcept { return _ptr != other; } + ASMJIT_INLINE_NODEBUG bool operator!=(const ArrayReverseIterator& other) const noexcept { return _ptr != other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator<(const T* other) const noexcept { return _ptr < other; } + ASMJIT_INLINE_NODEBUG bool operator<(const ArrayReverseIterator& other) const noexcept { return _ptr < other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator<=(const T* other) const noexcept { return _ptr <= other; } + ASMJIT_INLINE_NODEBUG bool operator<=(const ArrayReverseIterator& other) const noexcept { return _ptr <= other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator>(const T* other) const noexcept { return _ptr > other; } + ASMJIT_INLINE_NODEBUG bool operator>(const ArrayReverseIterator& other) const noexcept { return _ptr > other._ptr; } + + ASMJIT_INLINE_NODEBUG bool operator>=(const T* other) const noexcept { return _ptr >= other; } + ASMJIT_INLINE_NODEBUG bool operator>=(const ArrayReverseIterator& other) const noexcept { return _ptr >= other._ptr; } + + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator++() noexcept { _ptr--; return *this; } + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator++(int) noexcept { ArrayReverseIterator prev(*this); _ptr--; return prev; } + + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator--() noexcept { _ptr++; return *this; } + ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator--(int) noexcept { ArrayReverseIterator prev(*this); _ptr++; return prev; } + + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator operator+(const Diff& n) noexcept { return ArrayReverseIterator(_ptr -= n); } + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator operator-(const Diff& n) noexcept { return ArrayReverseIterator(_ptr += n); } + + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator+=(const Diff& n) noexcept { _ptr -= n; return *this; } + template ASMJIT_INLINE_NODEBUG ArrayReverseIterator& operator-=(const Diff& n) noexcept { _ptr += n; return *this; } + + ASMJIT_INLINE_NODEBUG constexpr T& operator*() const noexcept { return _ptr[-1]; } + ASMJIT_INLINE_NODEBUG constexpr T* operator->() const noexcept { return &_ptr[-1]; } + + template ASMJIT_INLINE_NODEBUG T& operator[](const Diff& n) noexcept { return *(_ptr - n - 1); } + + ASMJIT_INLINE_NODEBUG operator T*() const noexcept { return _ptr; } + + //! \} +}; + +// Support - Array +// =============== + +//! Array type, similar to std::array, with the possibility to use enums in operator[]. +//! +//! \note The array has C semantics - the elements in the array are not initialized. +template +struct Array { + //! \name Members + //! \{ + + //! The underlying array data, use \ref data() to access it. + T _data[N]; + + //! \} + + //! \cond + // std compatibility. + typedef T value_type; + typedef size_t size_type; + typedef ptrdiff_t difference_type; + + typedef value_type& reference; + typedef const value_type& const_reference; + + typedef value_type* pointer; + typedef const value_type* const_pointer; + + typedef pointer iterator; + typedef const_pointer const_iterator; + //! \endcond + + //! \name Overloaded Operators + //! \{ + + template + inline T& operator[](const Index& index) noexcept { + typedef typename Internal::StdInt::Type U; + ASMJIT_ASSERT(U(index) < N); + return _data[U(index)]; + } + + template + inline const T& operator[](const Index& index) const noexcept { + typedef typename Internal::StdInt::Type U; + ASMJIT_ASSERT(U(index) < N); + return _data[U(index)]; + } + + inline bool operator==(const Array& other) const noexcept { + for (size_t i = 0; i < N; i++) + if (_data[i] != other._data[i]) + return false; + return true; + } + + inline bool operator!=(const Array& other) const noexcept { + return !operator==(other); + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return false; } + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return N; } + + ASMJIT_INLINE_NODEBUG T* data() noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const T* data() const noexcept { return _data; } + + ASMJIT_INLINE_NODEBUG T& front() noexcept { return _data[0]; } + ASMJIT_INLINE_NODEBUG const T& front() const noexcept { return _data[0]; } + + ASMJIT_INLINE_NODEBUG T& back() noexcept { return _data[N - 1]; } + ASMJIT_INLINE_NODEBUG const T& back() const noexcept { return _data[N - 1]; } + + ASMJIT_INLINE_NODEBUG T* begin() noexcept { return _data; } + ASMJIT_INLINE_NODEBUG T* end() noexcept { return _data + N; } + + ASMJIT_INLINE_NODEBUG const T* begin() const noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const T* end() const noexcept { return _data + N; } + + ASMJIT_INLINE_NODEBUG const T* cbegin() const noexcept { return _data; } + ASMJIT_INLINE_NODEBUG const T* cend() const noexcept { return _data + N; } + + //! \} + + //! \name Utilities + //! \{ + + inline void swap(Array& other) noexcept { + for (size_t i = 0; i < N; i++) + std::swap(_data[i], other._data[i]); + } + + inline void fill(const T& value) noexcept { + for (size_t i = 0; i < N; i++) + _data[i] = value; + } + + inline void copyFrom(const Array& other) noexcept { + for (size_t i = 0; i < N; i++) + _data[i] = other._data[i]; + } + + template + inline void combine(const Array& other) noexcept { + for (size_t i = 0; i < N; i++) + _data[i] = Operator::op(_data[i], other._data[i]); + } + + template + inline T aggregate(T initialValue = T()) const noexcept { + T value = initialValue; + for (size_t i = 0; i < N; i++) + value = Operator::op(value, _data[i]); + return value; + } + + template + inline void forEach(Fn&& fn) noexcept { + for (size_t i = 0; i < N; i++) + fn(_data[i]); + } + //! \} +}; + +// Support::Temporary +// ================== + +//! Used to pass a temporary buffer to: +//! +//! - Containers that use user-passed buffer as an initial storage (still can grow). +//! - Zone allocator that would use the temporary buffer as a first block. +struct Temporary { + //! \name Members + //! \{ + + void* _data; + size_t _size; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG constexpr Temporary(const Temporary& other) noexcept = default; + ASMJIT_INLINE_NODEBUG constexpr Temporary(void* data, size_t size) noexcept + : _data(data), + _size(size) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Temporary& operator=(const Temporary& other) noexcept = default; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the data storage. + template + ASMJIT_INLINE_NODEBUG constexpr T* data() const noexcept { return static_cast(_data); } + //! Returns the data storage size in bytes. + ASMJIT_INLINE_NODEBUG constexpr size_t size() const noexcept { return _size; } + + //! \} +}; + +} // {Support} + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_SUPPORT_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/target.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/target.h new file mode 100644 index 0000000000000000000000000000000000000000..ebff5a15e692519bacf1ca224d429e09429c0f63 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/target.h @@ -0,0 +1,59 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_TARGET_H_INCLUDED +#define ASMJIT_CORE_TARGET_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/cpuinfo.h" +#include "../core/func.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Target is an abstract class that describes a machine code target. +class ASMJIT_VIRTAPI Target { +public: + ASMJIT_BASE_CLASS(Target) + ASMJIT_NONCOPYABLE(Target) + + //! Target environment information. + Environment _environment; + //! Target CPU features. + CpuFeatures _cpuFeatures; + + //! \name Construction & Destruction + //! \{ + + //! Creates a `Target` instance. + ASMJIT_API Target() noexcept; + //! Destroys the `Target` instance. + ASMJIT_API virtual ~Target() noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns target's environment. + ASMJIT_INLINE_NODEBUG const Environment& environment() const noexcept { return _environment; } + //! Returns the target architecture. + ASMJIT_INLINE_NODEBUG Arch arch() const noexcept { return _environment.arch(); } + //! Returns the target sub-architecture. + ASMJIT_INLINE_NODEBUG SubArch subArch() const noexcept { return _environment.subArch(); } + + //! Returns target CPU features. + ASMJIT_INLINE_NODEBUG const CpuFeatures& cpuFeatures() const noexcept { return _cpuFeatures; } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_TARGET_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/type.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/type.h new file mode 100644 index 0000000000000000000000000000000000000000..985c223f891eb4d11aea39e69bb5fb75d16fbad5 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/type.h @@ -0,0 +1,443 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_TYPE_H_INCLUDED +#define ASMJIT_CORE_TYPE_H_INCLUDED + +#include "../core/globals.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_core +//! \{ + +//! Type identifier provides a minimalist type system used across AsmJit library. +//! +//! This is an additional information that can be used to describe a value-type of physical or virtual register. It's +//! used mostly by BaseCompiler to describe register representation (the group of data stored in the register and the +//! width used) and it's also used by APIs that allow to describe and work with function signatures. +enum class TypeId : uint8_t { + //! Void type. + kVoid = 0, + + _kBaseStart = 32, + _kBaseEnd = 44, + + _kIntStart = 32, + _kIntEnd = 41, + + //! Abstract signed integer type that has a native size. + kIntPtr = 32, + //! Abstract unsigned integer type that has a native size. + kUIntPtr = 33, + + //! 8-bit signed integer type. + kInt8 = 34, + //! 8-bit unsigned integer type. + kUInt8 = 35, + //! 16-bit signed integer type. + kInt16 = 36, + //! 16-bit unsigned integer type. + kUInt16 = 37, + //! 32-bit signed integer type. + kInt32 = 38, + //! 32-bit unsigned integer type. + kUInt32 = 39, + //! 64-bit signed integer type. + kInt64 = 40, + //! 64-bit unsigned integer type. + kUInt64 = 41, + + _kFloatStart = 42, + _kFloatEnd = 44, + + //! 32-bit floating point type. + kFloat32 = 42, + //! 64-bit floating point type. + kFloat64 = 43, + //! 80-bit floating point type. + kFloat80 = 44, + + _kMaskStart = 45, + _kMaskEnd = 48, + + //! 8-bit opmask register (K). + kMask8 = 45, + //! 16-bit opmask register (K). + kMask16 = 46, + //! 32-bit opmask register (K). + kMask32 = 47, + //! 64-bit opmask register (K). + kMask64 = 48, + + _kMmxStart = 49, + _kMmxEnd = 50, + + //! 64-bit MMX register only used for 32 bits. + kMmx32 = 49, + //! 64-bit MMX register. + kMmx64 = 50, + + _kVec32Start = 51, + _kVec32End = 60, + + kInt8x4 = 51, + kUInt8x4 = 52, + kInt16x2 = 53, + kUInt16x2 = 54, + kInt32x1 = 55, + kUInt32x1 = 56, + kFloat32x1 = 59, + + _kVec64Start = 61, + _kVec64End = 70, + + kInt8x8 = 61, + kUInt8x8 = 62, + kInt16x4 = 63, + kUInt16x4 = 64, + kInt32x2 = 65, + kUInt32x2 = 66, + kInt64x1 = 67, + kUInt64x1 = 68, + kFloat32x2 = 69, + kFloat64x1 = 70, + + _kVec128Start = 71, + _kVec128End = 80, + + kInt8x16 = 71, + kUInt8x16 = 72, + kInt16x8 = 73, + kUInt16x8 = 74, + kInt32x4 = 75, + kUInt32x4 = 76, + kInt64x2 = 77, + kUInt64x2 = 78, + kFloat32x4 = 79, + kFloat64x2 = 80, + + _kVec256Start = 81, + _kVec256End = 90, + + kInt8x32 = 81, + kUInt8x32 = 82, + kInt16x16 = 83, + kUInt16x16 = 84, + kInt32x8 = 85, + kUInt32x8 = 86, + kInt64x4 = 87, + kUInt64x4 = 88, + kFloat32x8 = 89, + kFloat64x4 = 90, + + _kVec512Start = 91, + _kVec512End = 100, + + kInt8x64 = 91, + kUInt8x64 = 92, + kInt16x32 = 93, + kUInt16x32 = 94, + kInt32x16 = 95, + kUInt32x16 = 96, + kInt64x8 = 97, + kUInt64x8 = 98, + kFloat32x16 = 99, + kFloat64x8 = 100, + + kLastAssigned = kFloat64x8, + + kMaxValue = 255 +}; +ASMJIT_DEFINE_ENUM_COMPARE(TypeId) + +//! Type identifier utilities. +namespace TypeUtils { + +struct TypeData { + TypeId scalarOf[uint32_t(TypeId::kMaxValue) + 1]; + uint8_t sizeOf[uint32_t(TypeId::kMaxValue) + 1]; +}; +ASMJIT_VARAPI const TypeData _typeData; + +//! Returns the scalar type of `typeId`. +static ASMJIT_INLINE_NODEBUG TypeId scalarOf(TypeId typeId) noexcept { return _typeData.scalarOf[uint32_t(typeId)]; } + +//! Returns the size [in bytes] of `typeId`. +static ASMJIT_INLINE_NODEBUG uint32_t sizeOf(TypeId typeId) noexcept { return _typeData.sizeOf[uint32_t(typeId)]; } + +//! Tests whether a given type `typeId` is between `a` and `b`. +static ASMJIT_INLINE_NODEBUG constexpr bool isBetween(TypeId typeId, TypeId a, TypeId b) noexcept { + return Support::isBetween(uint32_t(typeId), uint32_t(a), uint32_t(b)); +} + +//! Tests whether a given type `typeId` is \ref TypeId::kVoid. +static ASMJIT_INLINE_NODEBUG constexpr bool isVoid(TypeId typeId) noexcept { return typeId == TypeId::kVoid; } +//! Tests whether a given type `typeId` is a valid non-void type. +static ASMJIT_INLINE_NODEBUG constexpr bool isValid(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kIntStart, TypeId::_kVec512End); } +//! Tests whether a given type `typeId` is scalar (has no vector part). +static ASMJIT_INLINE_NODEBUG constexpr bool isScalar(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kBaseStart, TypeId::_kBaseEnd); } +//! Tests whether a given type `typeId` is abstract, which means that its size depends on register size. +static ASMJIT_INLINE_NODEBUG constexpr bool isAbstract(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kIntPtr, TypeId::kUIntPtr); } + +//! Tests whether a given type is a scalar integer (signed or unsigned) of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isInt(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kIntStart, TypeId::_kIntEnd); } +//! Tests whether a given type is a scalar 8-bit integer (signed). +static ASMJIT_INLINE_NODEBUG constexpr bool isInt8(TypeId typeId) noexcept { return typeId == TypeId::kInt8; } +//! Tests whether a given type is a scalar 8-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt8(TypeId typeId) noexcept { return typeId == TypeId::kUInt8; } +//! Tests whether a given type is a scalar 16-bit integer (signed). +static ASMJIT_INLINE_NODEBUG constexpr bool isInt16(TypeId typeId) noexcept { return typeId == TypeId::kInt16; } +//! Tests whether a given type is a scalar 16-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt16(TypeId typeId) noexcept { return typeId == TypeId::kUInt16; } +//! Tests whether a given type is a scalar 32-bit integer (signed). +static ASMJIT_INLINE_NODEBUG constexpr bool isInt32(TypeId typeId) noexcept { return typeId == TypeId::kInt32; } +//! Tests whether a given type is a scalar 32-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt32(TypeId typeId) noexcept { return typeId == TypeId::kUInt32; } +//! Tests whether a given type is a scalar 64-bit integer (signed). +static ASMJIT_INLINE_NODEBUG constexpr bool isInt64(TypeId typeId) noexcept { return typeId == TypeId::kInt64; } +//! Tests whether a given type is a scalar 64-bit integer (unsigned). +static ASMJIT_INLINE_NODEBUG constexpr bool isUInt64(TypeId typeId) noexcept { return typeId == TypeId::kUInt64; } + +//! Tests whether a given type is an 8-bit general purpose register representing either signed or unsigned 8-bit integer. +static ASMJIT_INLINE_NODEBUG constexpr bool isGp8(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt8, TypeId::kUInt8); } +//! Tests whether a given type is a 16-bit general purpose register representing either signed or unsigned 16-bit integer +static ASMJIT_INLINE_NODEBUG constexpr bool isGp16(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt16, TypeId::kUInt16); } +//! Tests whether a given type is a 32-bit general purpose register representing either signed or unsigned 32-bit integer +static ASMJIT_INLINE_NODEBUG constexpr bool isGp32(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt32, TypeId::kUInt32); } +//! Tests whether a given type is a 64-bit general purpose register representing either signed or unsigned 64-bit integer +static ASMJIT_INLINE_NODEBUG constexpr bool isGp64(TypeId typeId) noexcept { return isBetween(typeId, TypeId::kInt64, TypeId::kUInt64); } + +//! Tests whether a given type is a scalar floating point of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isFloat(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kFloatStart, TypeId::_kFloatEnd); } +//! Tests whether a given type is a scalar 32-bit float. +static ASMJIT_INLINE_NODEBUG constexpr bool isFloat32(TypeId typeId) noexcept { return typeId == TypeId::kFloat32; } +//! Tests whether a given type is a scalar 64-bit float. +static ASMJIT_INLINE_NODEBUG constexpr bool isFloat64(TypeId typeId) noexcept { return typeId == TypeId::kFloat64; } +//! Tests whether a given type is a scalar 80-bit float. +static ASMJIT_INLINE_NODEBUG constexpr bool isFloat80(TypeId typeId) noexcept { return typeId == TypeId::kFloat80; } + +//! Tests whether a given type is a mask register of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kMaskStart, TypeId::_kMaskEnd); } +//! Tests whether a given type is an 8-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask8(TypeId typeId) noexcept { return typeId == TypeId::kMask8; } +//! Tests whether a given type is an 16-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask16(TypeId typeId) noexcept { return typeId == TypeId::kMask16; } +//! Tests whether a given type is an 32-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask32(TypeId typeId) noexcept { return typeId == TypeId::kMask32; } +//! Tests whether a given type is an 64-bit mask register. +static ASMJIT_INLINE_NODEBUG constexpr bool isMask64(TypeId typeId) noexcept { return typeId == TypeId::kMask64; } + +//! Tests whether a given type is an MMX register. +//! +//! \note MMX functionality is in general deprecated on X86 architecture. AsmJit provides it just for completeness. +static ASMJIT_INLINE_NODEBUG constexpr bool isMmx(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kMmxStart, TypeId::_kMmxEnd); } +//! Tests whether a given type is an MMX register, which only uses the low 32 bits of data (only specific cases). +//! +//! \note MMX functionality is in general deprecated on X86 architecture. AsmJit provides it just for completeness. +static ASMJIT_INLINE_NODEBUG constexpr bool isMmx32(TypeId typeId) noexcept { return typeId == TypeId::kMmx32; } +//! Tests whether a given type is an MMX register, which uses 64 bits of data (default). +//! +//! \note MMX functionality is in general deprecated on X86 architecture. AsmJit provides it just for completeness. +static ASMJIT_INLINE_NODEBUG constexpr bool isMmx64(TypeId typeId) noexcept { return typeId == TypeId::kMmx64; } + +//! Tests whether a given type is a vector register of any size. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec32Start, TypeId::_kVec512End); } +//! Tests whether a given type is a 32-bit or 32-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec32(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec32Start, TypeId::_kVec32End); } +//! Tests whether a given type is a 64-bit or 64-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec64(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec64Start, TypeId::_kVec64End); } +//! Tests whether a given type is a 128-bit or 128-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec128(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec128Start, TypeId::_kVec128End); } +//! Tests whether a given type is a 256-bit or 256-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec256(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec256Start, TypeId::_kVec256End); } +//! Tests whether a given type is a 512-bit or 512-bit view of a vector register. +static ASMJIT_INLINE_NODEBUG constexpr bool isVec512(TypeId typeId) noexcept { return isBetween(typeId, TypeId::_kVec512Start, TypeId::_kVec512End); } + +//! \cond +enum TypeCategory : uint32_t { + kTypeCategoryUnknown = 0, + kTypeCategoryEnum = 1, + kTypeCategoryIntegral = 2, + kTypeCategoryFloatingPoint = 3, + kTypeCategoryFunction = 4 +}; + +template +struct TypeIdOfT_ByCategory {}; // Fails if not specialized. + +template +struct TypeIdOfT_ByCategory { + enum : uint32_t { + kTypeId = uint32_t( + (sizeof(T) == 1 && std::is_signed::value) ? TypeId::kInt8 : + (sizeof(T) == 1 && !std::is_signed::value) ? TypeId::kUInt8 : + (sizeof(T) == 2 && std::is_signed::value) ? TypeId::kInt16 : + (sizeof(T) == 2 && !std::is_signed::value) ? TypeId::kUInt16 : + (sizeof(T) == 4 && std::is_signed::value) ? TypeId::kInt32 : + (sizeof(T) == 4 && !std::is_signed::value) ? TypeId::kUInt32 : + (sizeof(T) == 8 && std::is_signed::value) ? TypeId::kInt64 : + (sizeof(T) == 8 && !std::is_signed::value) ? TypeId::kUInt64 : TypeId::kVoid) + }; +}; + +template +struct TypeIdOfT_ByCategory { + enum : uint32_t { + kTypeId = uint32_t( + (sizeof(T) == 4 ) ? TypeId::kFloat32 : + (sizeof(T) == 8 ) ? TypeId::kFloat64 : + (sizeof(T) >= 10) ? TypeId::kFloat80 : TypeId::kVoid) + }; +}; + +template +struct TypeIdOfT_ByCategory + : public TypeIdOfT_ByCategory::type, kTypeCategoryIntegral> {}; + +template +struct TypeIdOfT_ByCategory { + enum : uint32_t { + kTypeId = uint32_t(TypeId::kUIntPtr) + }; +}; +//! \endcond + +//! TypeIdOfT<> template allows to get a TypeId from a C++ type `T`. +#ifdef _DOXYGEN +template +struct TypeIdOfT { + //! TypeId of C++ type `T`. + static constexpr TypeId kTypeId = _TypeIdDeducedAtCompileTime_; +}; +#else +template +struct TypeIdOfT + : public TypeIdOfT_ByCategory::value ? kTypeCategoryEnum : + std::is_integral::value ? kTypeCategoryIntegral : + std::is_floating_point::value ? kTypeCategoryFloatingPoint : + std::is_function::value ? kTypeCategoryFunction : kTypeCategoryUnknown> {}; +#endif + +//! \cond +template +struct TypeIdOfT { + enum : uint32_t { + kTypeId = uint32_t(TypeId::kUIntPtr) + }; +}; + +template +struct TypeIdOfT { + enum : uint32_t { + kTypeId = uint32_t(TypeId::kUIntPtr) + }; +}; +//! \endcond + +//! Returns a corresponding \ref TypeId of `T` type. +template +static ASMJIT_INLINE_NODEBUG constexpr TypeId typeIdOfT() noexcept { return TypeId(TypeIdOfT::kTypeId); } + +//! Returns offset needed to convert a `kIntPtr` and `kUIntPtr` TypeId into a type that matches `registerSize` +//! (general-purpose register size). If you find such TypeId it's then only about adding the offset to it. +//! +//! For example: +//! +//! ``` +//! uint32_t registerSize = /* 4 or 8 */; +//! uint32_t deabstractDelta = TypeUtils::deabstractDeltaOfSize(registerSize); +//! +//! TypeId typeId = 'some type-id'; +//! +//! // Normalize some typeId into a non-abstract typeId. +//! if (TypeUtils::isAbstract(typeId)) typeId += deabstractDelta; +//! +//! // The same, but by using TypeUtils::deabstract() function. +//! typeId = TypeUtils::deabstract(typeId, deabstractDelta); +//! ``` +static ASMJIT_INLINE_NODEBUG constexpr uint32_t deabstractDeltaOfSize(uint32_t registerSize) noexcept { + return registerSize >= 8 ? uint32_t(TypeId::kInt64) - uint32_t(TypeId::kIntPtr) + : uint32_t(TypeId::kInt32) - uint32_t(TypeId::kIntPtr); +} + +//! Deabstracts a given `typeId` into a native type by using `deabstractDelta`, which was previously +//! calculated by calling \ref deabstractDeltaOfSize() with a target native register size. +static ASMJIT_INLINE_NODEBUG constexpr TypeId deabstract(TypeId typeId, uint32_t deabstractDelta) noexcept { + return isAbstract(typeId) ? TypeId(uint32_t(typeId) + deabstractDelta) : typeId; +} + +static ASMJIT_INLINE_NODEBUG constexpr TypeId scalarToVector(TypeId scalarTypeId, TypeId vecStartId) noexcept { + return TypeId(uint32_t(vecStartId) + uint32_t(scalarTypeId) - uint32_t(TypeId::kInt8)); +} + +} // {TypeUtils} + +//! Provides type identifiers that can be used in templates instead of native types. +namespace Type { + +//! bool as C++ type-name. +struct Bool {}; +//! int8_t as C++ type-name. +struct Int8 {}; +//! uint8_t as C++ type-name. +struct UInt8 {}; +//! int16_t as C++ type-name. +struct Int16 {}; +//! uint16_t as C++ type-name. +struct UInt16 {}; +//! int32_t as C++ type-name. +struct Int32 {}; +//! uint32_t as C++ type-name. +struct UInt32 {}; +//! int64_t as C++ type-name. +struct Int64 {}; +//! uint64_t as C++ type-name. +struct UInt64 {}; +//! intptr_t as C++ type-name. +struct IntPtr {}; +//! uintptr_t as C++ type-name. +struct UIntPtr {}; +//! float as C++ type-name. +struct Float32 {}; +//! double as C++ type-name. +struct Float64 {}; + +} // {Type} + +//! \cond +#define ASMJIT_DEFINE_TYPE_ID(T, TYPE_ID) \ +namespace TypeUtils { \ + template<> \ + struct TypeIdOfT { \ + enum : uint32_t { \ + kTypeId = uint32_t(TYPE_ID) \ + }; \ + }; \ +} + +ASMJIT_DEFINE_TYPE_ID(void , TypeId::kVoid); +ASMJIT_DEFINE_TYPE_ID(Type::Bool , TypeId::kUInt8); +ASMJIT_DEFINE_TYPE_ID(Type::Int8 , TypeId::kInt8); +ASMJIT_DEFINE_TYPE_ID(Type::UInt8 , TypeId::kUInt8); +ASMJIT_DEFINE_TYPE_ID(Type::Int16 , TypeId::kInt16); +ASMJIT_DEFINE_TYPE_ID(Type::UInt16 , TypeId::kUInt16); +ASMJIT_DEFINE_TYPE_ID(Type::Int32 , TypeId::kInt32); +ASMJIT_DEFINE_TYPE_ID(Type::UInt32 , TypeId::kUInt32); +ASMJIT_DEFINE_TYPE_ID(Type::Int64 , TypeId::kInt64); +ASMJIT_DEFINE_TYPE_ID(Type::UInt64 , TypeId::kUInt64); +ASMJIT_DEFINE_TYPE_ID(Type::IntPtr , TypeId::kIntPtr); +ASMJIT_DEFINE_TYPE_ID(Type::UIntPtr, TypeId::kUIntPtr); +ASMJIT_DEFINE_TYPE_ID(Type::Float32, TypeId::kFloat32); +ASMJIT_DEFINE_TYPE_ID(Type::Float64, TypeId::kFloat64); +//! \endcond + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_TYPE_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/virtmem.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/virtmem.h new file mode 100644 index 0000000000000000000000000000000000000000..a4a1359584e4f72f4c10784075d7b4193b5561c3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/virtmem.h @@ -0,0 +1,327 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_VIRTMEM_H_INCLUDED +#define ASMJIT_CORE_VIRTMEM_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_JIT + +#include "../core/globals.h" +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_virtual_memory +//! \{ + +//! Virtual memory management. +namespace VirtMem { + +//! Describes whether instruction cache should be flushed after a write operation. +enum class CachePolicy : uint32_t { + //! Default policy. + //! + //! In some places this would mean `kFlushAfterWrite` and in some places it would mean `kNeverFlush`. + //! For example if it's known that an address has never been used before to execute code. + kDefault = 0, + + //! Flush instruction cache after a write operation. + kFlushAfterWrite = 1, + + //! Avoid flushing instruction cache after a write operation. + kNeverFlush = 2 +}; + +//! Flushes instruction cache in the given region. +//! +//! Only useful on non-x86 architectures, however, it's a good practice to call it on any platform to make your +//! code more portable. +ASMJIT_API void flushInstructionCache(void* p, size_t size) noexcept; + +//! Virtual memory information. +struct Info { + //! Virtual memory page size. + uint32_t pageSize; + //! Virtual memory page granularity. + uint32_t pageGranularity; +}; + +//! Returns virtual memory information, see `VirtMem::Info` for more details. +ASMJIT_API Info info() noexcept; + +//! Returns the size of the smallest large page supported. +//! +//! AsmJit only uses the smallest large page at the moment as these are usually perfectly sized for executable +//! memory allocation (standard size is 2MB, but different sizes are possible). +//! +//! Returns either the detected large page size or 0, if large page support is either not supported by AsmJit +//! or not accessible to the process. +ASMJIT_API size_t largePageSize() noexcept; + +//! Virtual memory access and mmap-specific flags. +enum class MemoryFlags : uint32_t { + //! No flags. + kNone = 0, + + //! Memory is readable. + kAccessRead = 0x00000001u, + + //! Memory is writable. + kAccessWrite = 0x00000002u, + + //! Memory is executable. + kAccessExecute = 0x00000004u, + + //! A combination of \ref kAccessRead and \ref kAccessWrite. + kAccessReadWrite = kAccessRead | kAccessWrite, + + //! A combination of \ref kAccessRead, \ref kAccessWrite. + kAccessRW = kAccessRead | kAccessWrite, + + //! A combination of \ref kAccessRead and \ref kAccessExecute. + kAccessRX = kAccessRead | kAccessExecute, + + //! A combination of \ref kAccessRead, \ref kAccessWrite, and \ref kAccessExecute. + kAccessRWX = kAccessRead | kAccessWrite | kAccessExecute, + + //! Use a `MAP_JIT` flag available on Apple platforms (introduced by Mojave), which allows JIT code to be + //! executed in a MAC bundle. + //! + //! This flag may be turned on by the allocator if there is no other way of allocating executable memory. + //! + //! \note This flag can only be used with \ref VirtMem::alloc(), `MAP_JIT` only works on OSX and not on iOS. + //! When a process uses `fork()` the child process has no access to the pages mapped with `MAP_JIT`. + kMMapEnableMapJit = 0x00000010u, + + //! Pass `PROT_MAX(PROT_READ)` or `PROT_MPROTECT(PROT_READ)` to `mmap()` on platforms that support it. + //! + //! This flag allows to set a "maximum access" that the memory page can get during its lifetime. Use + //! \ref VirtMem::protect() to change the access flags. + //! + //! \note This flag can only be used with \ref VirtMem::alloc() and \ref VirtMem::allocDualMapping(). + //! However \ref VirtMem::allocDualMapping() may automatically use this if \ref kAccessRead is used. + kMMapMaxAccessRead = 0x00000020u, + + //! Pass `PROT_MAX(PROT_WRITE)` or `PROT_MPROTECT(PROT_WRITE)` to `mmap()` on platforms that support it. + //! + //! This flag allows to set a "maximum access" that the memory page can get during its lifetime. Use + //! \ref VirtMem::protect() to change the access flags. + //! + //! \note This flag can only be used with \ref VirtMem::alloc() and \ref VirtMem::allocDualMapping(). + //! However \ref VirtMem::allocDualMapping() may automatically use this if \ref kAccessWrite is used. + kMMapMaxAccessWrite = 0x00000040u, + + //! Pass `PROT_MAX(PROT_EXEC)` or `PROT_MPROTECT(PROT_EXEC)` to `mmap()` on platforms that support it. + //! + //! This flag allows to set a "maximum access" that the memory page can get during its lifetime. Use + //! \ref VirtMem::protect() to change the access flags. + //! + //! \note This flag can only be used with \ref VirtMem::alloc() and \ref VirtMem::allocDualMapping(). + //! However \ref VirtMem::allocDualMapping() may automatically use this if \ref kAccessExecute is used. + kMMapMaxAccessExecute = 0x00000080u, + + //! A combination of \ref kMMapMaxAccessRead and \ref kMMapMaxAccessWrite. + kMMapMaxAccessReadWrite = kMMapMaxAccessRead | kMMapMaxAccessWrite, + + //! A combination of \ref kMMapMaxAccessRead and \ref kMMapMaxAccessWrite. + kMMapMaxAccessRW = kMMapMaxAccessRead | kMMapMaxAccessWrite, + + //! A combination of \ref kMMapMaxAccessRead and \ref kMMapMaxAccessExecute. + kMMapMaxAccessRX = kMMapMaxAccessRead | kMMapMaxAccessExecute, + + //! A combination of \ref kMMapMaxAccessRead, \ref kMMapMaxAccessWrite, \ref kMMapMaxAccessExecute. + kMMapMaxAccessRWX = kMMapMaxAccessRead | kMMapMaxAccessWrite | kMMapMaxAccessExecute, + + //! Use `MAP_SHARED` when calling mmap(). + //! + //! \note In some cases `MAP_SHARED` may be set automatically. For example, some dual mapping implementations must + //! use `MAP_SHARED` instead of `MAP_PRIVATE` to ensure that the OS would not apply copy on write on RW page, which + //! would cause RX page not having the updated content. + kMapShared = 0x00000100u, + + //! Request large memory mapped pages. + //! + //! \remarks If this option is used and large page(s) cannot be mapped, the allocation will fail. Fallback to + //! regular pages must be done by the user in this case. Higher level API such as \ref JitAllocator provides an + //! additional mechanism to allocate regular page(s) when large page(s) allocation fails. + kMMapLargePages = 0x00000200u, + + //! Not an access flag, only used by `allocDualMapping()` to override the default allocation strategy to always use + //! a 'tmp' directory instead of "/dev/shm" (on POSIX platforms). Please note that this flag will be ignored if the + //! operating system allows to allocate an executable memory by a different API than `open()` or `shm_open()`. For + //! example on Linux `memfd_create()` is preferred and on BSDs `shm_open(SHM_ANON, ...)` is used if SHM_ANON is + //! defined. + //! + //! \note This flag can only be used with \ref VirtMem::alloc(). + kMappingPreferTmp = 0x80000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(MemoryFlags) + +//! Allocates virtual memory by either using `mmap()` (POSIX) or `VirtualAlloc()` (Windows). +//! +//! \note `size` should be aligned to page size, use \ref VirtMem::info() to obtain it. Invalid size will not be +//! corrected by the implementation and the allocation would not succeed in such case. +ASMJIT_API Error alloc(void** p, size_t size, MemoryFlags flags) noexcept; + +//! Releases virtual memory previously allocated by \ref VirtMem::alloc(). +//! +//! \note The size must be the same as used by \ref VirtMem::alloc(). If the size is not the same value the call +//! will fail on any POSIX system, but pass on Windows, because it's implemented differently. +ASMJIT_API Error release(void* p, size_t size) noexcept; + +//! A cross-platform wrapper around `mprotect()` (POSIX) and `VirtualProtect()` (Windows). +ASMJIT_API Error protect(void* p, size_t size, MemoryFlags flags) noexcept; + +//! Dual memory mapping used to map an anonymous memory into two memory regions where one region is read-only, but +//! executable, and the second region is read+write, but not executable. See \ref VirtMem::allocDualMapping() for +//! more details. +struct DualMapping { + //! Pointer to data with 'Read+Execute' access (this memory is not writable). + void* rx; + //! Pointer to data with 'Read+Write' access (this memory is not executable). + void* rw; +}; + +//! Allocates virtual memory and creates two views of it where the first view has no write access. This is an addition +//! to the API that should be used in cases in which the operating system either enforces W^X security policy or the +//! application wants to use this policy by default to improve security and prevent an accidental (or purposed) +//! self-modifying code. +//! +//! The memory returned in the `dm` are two independent mappings of the same shared memory region. You must use +//! \ref VirtMem::releaseDualMapping() to release it when it's no longer needed. Never use `VirtMem::release()` to +//! release the memory returned by `allocDualMapping()` as that would fail on Windows. +//! +//! \remarks Both pointers in `dm` would be set to `nullptr` if the function fails. +ASMJIT_API Error allocDualMapping(DualMapping* dm, size_t size, MemoryFlags flags) noexcept; + +//! Releases virtual memory mapping previously allocated by \ref VirtMem::allocDualMapping(). +//! +//! \remarks Both pointers in `dm` would be set to `nullptr` if the function succeeds. +ASMJIT_API Error releaseDualMapping(DualMapping* dm, size_t size) noexcept; + +//! Hardened runtime flags. +enum class HardenedRuntimeFlags : uint32_t { + //! No flags. + kNone = 0, + + //! Hardened runtime is enabled - it's not possible to have "Write & Execute" memory protection. The runtime + //! enforces W^X (either write or execute). + //! + //! \note If the runtime is hardened it means that an operating system specific protection is used. For example + //! on Apple OSX it's possible to allocate memory with MAP_JIT flag and then use `pthread_jit_write_protect_np()` + //! to temporarily swap access permissions for the current thread. Dual mapping is also a possibility on X86/X64 + //! architecture. + kEnabled = 0x00000001u, + + //! Read+Write+Execute can only be allocated with MAP_JIT flag (Apple specific, only available on Apple platforms). + kMapJit = 0x00000002u, + + //! Read+Write+Execute can be allocated with dual mapping approach (one region with RW and the other with RX). + kDualMapping = 0x00000004u +}; +ASMJIT_DEFINE_ENUM_FLAGS(HardenedRuntimeFlags) + +//! Hardened runtime information. +struct HardenedRuntimeInfo { + //! \name Members + //! \{ + + //! Hardened runtime flags. + HardenedRuntimeFlags flags; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the hardened runtime `flag` is set. + ASMJIT_INLINE_NODEBUG bool hasFlag(HardenedRuntimeFlags flag) const noexcept { return Support::test(flags, flag); } + + //! \} +}; + +//! Returns runtime features provided by the OS. +ASMJIT_API HardenedRuntimeInfo hardenedRuntimeInfo() noexcept; + +//! Values that can be used with `protectJitMemory()` function. +enum class ProtectJitAccess : uint32_t { + //! Protect JIT memory with Read+Write permissions. + kReadWrite = 0, + //! Protect JIT memory with Read+Execute permissions. + kReadExecute = 1 +}; + +//! Protects access of memory mapped with MAP_JIT flag for the current thread. +//! +//! \note This feature is only available on Apple hardware (AArch64) at the moment and uses a non-portable +//! `pthread_jit_write_protect_np()` call when available. +//! +//! This function must be called before and after a memory mapped with MAP_JIT flag is modified. Example: +//! +//! ``` +//! void* codePtr = ...; +//! size_t codeSize = ...; +//! +//! VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadWrite); +//! memcpy(codePtr, source, codeSize); +//! VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadExecute); +//! VirtMem::flushInstructionCache(codePtr, codeSize); +//! ``` +//! +//! See \ref ProtectJitReadWriteScope, which makes it simpler than the code above. +ASMJIT_API void protectJitMemory(ProtectJitAccess access) noexcept; + +//! JIT protection scope that prepares the given memory block to be written to in the current thread. +//! +//! It calls `VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadWrite)` at construction time and +//! `VirtMem::protectJitMemory(VirtMem::ProtectJitAccess::kReadExecute)` combined with `flushInstructionCache()` +//! in destructor. The purpose of this class is to make writing to JIT memory easier. +class ProtectJitReadWriteScope { +public: + ASMJIT_NONCOPYABLE(ProtectJitReadWriteScope) + + //! \name Members + //! \{ + + void* _rxPtr; + size_t _size; + CachePolicy _policy; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Makes the given memory block RW protected. + ASMJIT_FORCE_INLINE ProtectJitReadWriteScope( + void* rxPtr, + size_t size, + CachePolicy policy = CachePolicy::kDefault) noexcept + : _rxPtr(rxPtr), + _size(size), + _policy(policy) { + protectJitMemory(ProtectJitAccess::kReadWrite); + } + + //! Makes the memory block RX protected again and flushes instruction cache. + ASMJIT_FORCE_INLINE ~ProtectJitReadWriteScope() noexcept { + protectJitMemory(ProtectJitAccess::kReadExecute); + + if (_policy != CachePolicy::kNeverFlush) + flushInstructionCache(_rxPtr, _size); + } + + //! \} +}; + +} // VirtMem + +//! \} + +ASMJIT_END_NAMESPACE + +#endif +#endif // ASMJIT_CORE_VIRTMEM_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/zone.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/zone.h new file mode 100644 index 0000000000000000000000000000000000000000..0adabd5d241ad70fe8e02e7f7a6c6844998b65af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/zone.h @@ -0,0 +1,611 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONE_H_INCLUDED +#define ASMJIT_CORE_ZONE_H_INCLUDED + +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Zone memory. +//! +//! Zone is an incremental memory allocator that allocates memory by simply incrementing a pointer. It allocates +//! blocks of memory by using C's `malloc()`, but divides these blocks into smaller segments requested by calling +//! `Zone::alloc()` and friends. +//! +//! Zone has no function to release the allocated memory. It has to be released all at once by calling `reset()`. +//! If you need a more friendly allocator that also supports `release()`, consider using `Zone` with `ZoneAllocator`. +class Zone { +public: + ASMJIT_NONCOPYABLE(Zone) + + //! \cond INTERNAL + + //! A single block of memory managed by `Zone`. + struct Block { + inline uint8_t* data() const noexcept { + return const_cast(reinterpret_cast(this) + sizeof(*this)); + } + + //! Link to the previous block. + Block* prev; + //! Link to the next block. + Block* next; + //! Size of the block. + size_t size; + }; + + enum Limits : size_t { + kBlockSize = sizeof(Block), + kBlockOverhead = Globals::kAllocOverhead + kBlockSize, + + kMinBlockSize = 64, // The number is ridiculously small, but still possible. + kMaxBlockSize = size_t(1) << (sizeof(size_t) * 8 - 4 - 1), + kMinAlignment = 1, + kMaxAlignment = 64 + }; + + //! Pointer in the current block. + uint8_t* _ptr; + //! End of the current block. + uint8_t* _end; + //! Current block. + Block* _block; + + union { + struct { + //! Default block size. + size_t _blockSize : Support::bitSizeOf() - 4; + //! First block is temporary (ZoneTmp). + size_t _isTemporary : 1; + //! Block alignment (1 << alignment). + size_t _blockAlignmentShift : 3; + }; + size_t _packedData; + }; + + static ASMJIT_API const Block _zeroBlock; + + //! \endcond + + //! \name Construction & Destruction + //! \{ + + //! Creates a new Zone. + //! + //! The `blockSize` parameter describes the default size of the block. If the `size` parameter passed to `alloc()` + //! is greater than the default size `Zone` will allocate and use a larger block, but it will not change the + //! default `blockSize`. + //! + //! It's not required, but it's good practice to set `blockSize` to a reasonable value that depends on the usage + //! of `Zone`. Greater block sizes are generally safer and perform better than unreasonably low block sizes. + ASMJIT_INLINE_NODEBUG explicit Zone(size_t blockSize, size_t blockAlignment = 1) noexcept { + _init(blockSize, blockAlignment, nullptr); + } + + //! Creates a new Zone with a first block pointing to a `temporary` memory. + ASMJIT_INLINE_NODEBUG Zone(size_t blockSize, size_t blockAlignment, const Support::Temporary& temporary) noexcept { + _init(blockSize, blockAlignment, &temporary); + } + + //! \overload + ASMJIT_INLINE_NODEBUG Zone(size_t blockSize, size_t blockAlignment, const Support::Temporary* temporary) noexcept { + _init(blockSize, blockAlignment, temporary); + } + + //! Moves an existing `Zone`. + //! + //! \note You cannot move an existing `ZoneTmp` as it uses embedded storage. Attempting to move `ZoneTmp` would + //! result in assertion failure in debug mode and undefined behavior in release mode. + inline Zone(Zone&& other) noexcept + : _ptr(other._ptr), + _end(other._end), + _block(other._block), + _packedData(other._packedData) { + ASMJIT_ASSERT(!other.isTemporary()); + other._block = const_cast(&_zeroBlock); + other._ptr = other._block->data(); + other._end = other._block->data(); + } + + //! Destroys the `Zone` instance. + //! + //! This will destroy the `Zone` instance and release all blocks of memory allocated by it. It performs implicit + //! `reset(ResetPolicy::kHard)`. + ASMJIT_INLINE_NODEBUG ~Zone() noexcept { reset(ResetPolicy::kHard); } + + ASMJIT_API void _init(size_t blockSize, size_t blockAlignment, const Support::Temporary* temporary) noexcept; + + //! Resets the `Zone` invalidating all blocks allocated. + //! + //! See `Globals::ResetPolicy` for more details. + ASMJIT_API void reset(ResetPolicy resetPolicy = ResetPolicy::kSoft) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether this `Zone` is actually a `ZoneTmp` that uses temporary memory. + ASMJIT_INLINE_NODEBUG bool isTemporary() const noexcept { return _isTemporary != 0; } + + //! Returns the default block size. + ASMJIT_INLINE_NODEBUG size_t blockSize() const noexcept { return _blockSize; } + //! Returns the default block alignment. + ASMJIT_INLINE_NODEBUG size_t blockAlignment() const noexcept { return size_t(1) << _blockAlignmentShift; } + //! Returns remaining size of the current block. + ASMJIT_INLINE_NODEBUG size_t remainingSize() const noexcept { return (size_t)(_end - _ptr); } + + //! Returns the current zone cursor (dangerous). + //! + //! This is a function that can be used to get exclusive access to the current block's memory buffer. + template + ASMJIT_INLINE_NODEBUG T* ptr() noexcept { return reinterpret_cast(_ptr); } + + //! Returns the end of the current zone block, only useful if you use `ptr()`. + template + ASMJIT_INLINE_NODEBUG T* end() noexcept { return reinterpret_cast(_end); } + + //! Sets the current zone pointer to `ptr` (must be within the current block). + template + inline void setPtr(T* ptr) noexcept { + uint8_t* p = reinterpret_cast(ptr); + ASMJIT_ASSERT(p >= _ptr && p <= _end); + _ptr = p; + } + + //! Sets the end zone pointer to `end` (must be within the current block). + template + inline void setEnd(T* end) noexcept { + uint8_t* p = reinterpret_cast(end); + ASMJIT_ASSERT(p >= _ptr && p <= _end); + _end = p; + } + + //! \} + + //! \name Utilities + //! \{ + + inline void swap(Zone& other) noexcept { + // This could lead to a disaster. + ASMJIT_ASSERT(!this->isTemporary()); + ASMJIT_ASSERT(!other.isTemporary()); + + std::swap(_ptr, other._ptr); + std::swap(_end, other._end); + std::swap(_block, other._block); + std::swap(_packedData, other._packedData); + } + + //! Aligns the current pointer to `alignment`. + ASMJIT_INLINE_NODEBUG void align(size_t alignment) noexcept { + _ptr = Support::min(Support::alignUp(_ptr, alignment), _end); + } + + //! Ensures the remaining size is at least equal or greater than `size`. + //! + //! \note This function doesn't respect any alignment. If you need to ensure there is enough room for an aligned + //! allocation you need to call `align()` before calling `ensure()`. + ASMJIT_INLINE_NODEBUG Error ensure(size_t size) noexcept { + if (size <= remainingSize()) + return kErrorOk; + else + return _alloc(0, 1) ? kErrorOk : DebugUtils::errored(kErrorOutOfMemory); + } + + inline void _assignBlock(Block* block) noexcept { + size_t alignment = blockAlignment(); + _ptr = Support::alignUp(block->data(), alignment); + _end = Support::alignDown(block->data() + block->size, alignment); + _block = block; + } + + inline void _assignZeroBlock() noexcept { + Block* block = const_cast(&_zeroBlock); + _ptr = block->data(); + _end = block->data(); + _block = block; + } + + //! \} + + //! \name Allocation + //! \{ + + //! Allocates the requested memory specified by `size`. + //! + //! Pointer returned is valid until the `Zone` instance is destroyed or reset by calling `reset()`. If you plan to + //! make an instance of C++ from the given pointer use placement `new` and `delete` operators: + //! + //! ``` + //! using namespace asmjit; + //! + //! class Object { ... }; + //! + //! // Create Zone with default block size of approximately 65536 bytes. + //! Zone zone(65536 - Zone::kBlockOverhead); + //! + //! // Create your objects using zone object allocating, for example: + //! Object* obj = static_cast( zone.alloc(sizeof(Object)) ); + //! + //! if (!obj) { + //! // Handle out of memory error. + //! } + //! + //! // Placement `new` and `delete` operators can be used to instantiate it. + //! new(obj) Object(); + //! + //! // ... lifetime of your objects ... + //! + //! // To destroy the instance (if required). + //! obj->~Object(); + //! + //! // Reset or destroy `Zone`. + //! zone.reset(); + //! ``` + inline void* alloc(size_t size) noexcept { + if (ASMJIT_UNLIKELY(size > remainingSize())) + return _alloc(size, 1); + + uint8_t* ptr = _ptr; + _ptr += size; + return static_cast(ptr); + } + + //! Allocates the requested memory specified by `size` and `alignment`. + inline void* alloc(size_t size, size_t alignment) noexcept { + ASMJIT_ASSERT(Support::isPowerOf2(alignment)); + uint8_t* ptr = Support::alignUp(_ptr, alignment); + + if (ptr >= _end || size > (size_t)(_end - ptr)) + return _alloc(size, alignment); + + _ptr = ptr + size; + return static_cast(ptr); + } + + //! Allocates the requested memory specified by `size` without doing any checks. + //! + //! Can only be called if `remainingSize()` returns size at least equal to `size`. + inline void* allocNoCheck(size_t size) noexcept { + ASMJIT_ASSERT(remainingSize() >= size); + + uint8_t* ptr = _ptr; + _ptr += size; + return static_cast(ptr); + } + + //! Allocates the requested memory specified by `size` and `alignment` without doing any checks. + //! + //! Performs the same operation as `Zone::allocNoCheck(size)` with `alignment` applied. + inline void* allocNoCheck(size_t size, size_t alignment) noexcept { + ASMJIT_ASSERT(Support::isPowerOf2(alignment)); + + uint8_t* ptr = Support::alignUp(_ptr, alignment); + ASMJIT_ASSERT(size <= (size_t)(_end - ptr)); + + _ptr = ptr + size; + return static_cast(ptr); + } + + //! Allocates `size` bytes of zeroed memory. See `alloc()` for more details. + ASMJIT_API void* allocZeroed(size_t size, size_t alignment = 1) noexcept; + + //! Like `alloc()`, but the return pointer is casted to `T*`. + template + inline T* allocT(size_t size = sizeof(T), size_t alignment = alignof(T)) noexcept { + return static_cast(alloc(size, alignment)); + } + + //! Like `allocNoCheck()`, but the return pointer is casted to `T*`. + template + inline T* allocNoCheckT(size_t size = sizeof(T), size_t alignment = alignof(T)) noexcept { + return static_cast(allocNoCheck(size, alignment)); + } + + //! Like `allocZeroed()`, but the return pointer is casted to `T*`. + template + inline T* allocZeroedT(size_t size = sizeof(T), size_t alignment = alignof(T)) noexcept { + return static_cast(allocZeroed(size, alignment)); + } + + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT() noexcept { + void* p = alloc(sizeof(T), alignof(T)); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(); + } + + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT(Args&&... args) noexcept { + void* p = alloc(sizeof(T), alignof(T)); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(std::forward(args)...); + } + + //! \cond INTERNAL + //! + //! Internal alloc function used by other inlines. + ASMJIT_API void* _alloc(size_t size, size_t alignment) noexcept; + //! \endcond + + //! Helper to duplicate data. + ASMJIT_API void* dup(const void* data, size_t size, bool nullTerminate = false) noexcept; + + //! Helper to duplicate data. + inline void* dupAligned(const void* data, size_t size, size_t alignment, bool nullTerminate = false) noexcept { + align(alignment); + return dup(data, size, nullTerminate); + } + + //! Helper to duplicate a formatted string, maximum size is 256 bytes. + ASMJIT_API char* sformat(const char* str, ...) noexcept; + + //! \} +}; + +//! \ref Zone with `N` bytes of a static storage, used for the initial block. +//! +//! Temporary zones are used in cases where it's known that some memory will be required, but in many cases it won't +//! exceed N bytes, so the whole operation can be performed without a dynamic memory allocation. +template +class ZoneTmp : public Zone { +public: + ASMJIT_NONCOPYABLE(ZoneTmp) + + //! Temporary storage, embedded after \ref Zone. + struct Storage { + char data[N]; + } _storage; + + //! Creates a temporary zone. Dynamic block size is specified by `blockSize`. + inline explicit ZoneTmp(size_t blockSize, size_t blockAlignment = 1) noexcept + : Zone(blockSize, blockAlignment, Support::Temporary(_storage.data, N)) {} +}; + +//! Zone-based memory allocator that uses an existing `Zone` and provides a `release()` functionality on top of it. +//! It uses `Zone` only for chunks that can be pooled, and uses libc `malloc()` for chunks that are large. +//! +//! The advantage of ZoneAllocator is that it can allocate small chunks of memory really fast, and these chunks, +//! when released, will be reused by consecutive calls to `alloc()`. Also, since ZoneAllocator uses `Zone`, you can +//! turn any `Zone` into a `ZoneAllocator`, and use it in your `Pass` when necessary. +//! +//! ZoneAllocator is used by AsmJit containers to make containers having only few elements fast (and lightweight) +//! and to allow them to grow and use dynamic blocks when require more storage. +class ZoneAllocator { +public: + ASMJIT_NONCOPYABLE(ZoneAllocator) + + //! \cond INTERNAL + + // In short, we pool chunks of these sizes: + // [32, 64, 96, 128, 192, 256, 320, 384, 448, 512] + + enum : uint32_t { + //! How many bytes per a low granularity pool (has to be at least 16). + kLoGranularity = 32, + //! Number of slots of a low granularity pool. + kLoCount = 4, + //! Maximum size of a block that can be allocated in a low granularity pool. + kLoMaxSize = kLoGranularity * kLoCount, + + //! How many bytes per a high granularity pool. + kHiGranularity = 64, + //! Number of slots of a high granularity pool. + kHiCount = 6, + //! Maximum size of a block that can be allocated in a high granularity pool. + kHiMaxSize = kLoMaxSize + kHiGranularity * kHiCount, + + //! Alignment of every pointer returned by `alloc()`. + kBlockAlignment = kLoGranularity + }; + + //! Single-linked list used to store unused chunks. + struct Slot { + //! Link to a next slot in a single-linked list. + Slot* next; + }; + + //! A block of memory that has been allocated dynamically and is not part of block-list used by the allocator. + //! This is used to keep track of all these blocks so they can be freed by `reset()` if not freed explicitly. + struct DynamicBlock { + DynamicBlock* prev; + DynamicBlock* next; + }; + + //! \endcond + + //! \name Members + //! \{ + + //! Zone used to allocate memory that fits into slots. + Zone* _zone {}; + //! Indexed slots containing released memory. + Slot* _slots[kLoCount + kHiCount] {}; + //! Dynamic blocks for larger allocations (no slots). + DynamicBlock* _dynamicBlocks {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a new `ZoneAllocator`. + //! + //! \note To use it, you must first `init()` it. + ASMJIT_INLINE_NODEBUG ZoneAllocator() noexcept {} + + //! Creates a new `ZoneAllocator` initialized to use `zone`. + ASMJIT_INLINE_NODEBUG explicit ZoneAllocator(Zone* zone) noexcept + : _zone(zone) {} + + //! Destroys the `ZoneAllocator`. + ASMJIT_INLINE_NODEBUG ~ZoneAllocator() noexcept { reset(); } + + //! Tests whether the `ZoneAllocator` is initialized (i.e. has `Zone`). + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _zone != nullptr; } + + //! Convenience function to initialize the `ZoneAllocator` with `zone`. + //! + //! It's the same as calling `reset(zone)`. + ASMJIT_INLINE_NODEBUG void init(Zone* zone) noexcept { reset(zone); } + + //! Resets this `ZoneAllocator` and also forget about the current `Zone` which is attached (if any). Reset + //! optionally attaches a new `zone` passed, or keeps the `ZoneAllocator` in an uninitialized state, if + //! `zone` is null. + ASMJIT_API void reset(Zone* zone = nullptr) noexcept; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns the assigned `Zone` of this allocator or null if this `ZoneAllocator` is not initialized. + ASMJIT_INLINE_NODEBUG Zone* zone() const noexcept { return _zone; } + + //! \} + + //! \cond + //! \name Internals + //! \{ + + //! Returns the slot index to be used for `size`. Returns `true` if a valid slot has been written to `slot` and + //! `allocatedSize` has been filled with slot exact size (`allocatedSize` can be equal or slightly greater than + //! `size`). + static inline bool _getSlotIndex(size_t size, uint32_t& slot) noexcept { + ASMJIT_ASSERT(size > 0); + if (size > kHiMaxSize) + return false; + + if (size <= kLoMaxSize) + slot = uint32_t((size - 1) / kLoGranularity); + else + slot = uint32_t((size - kLoMaxSize - 1) / kHiGranularity) + kLoCount; + + return true; + } + + //! \overload + static inline bool _getSlotIndex(size_t size, uint32_t& slot, size_t& allocatedSize) noexcept { + ASMJIT_ASSERT(size > 0); + if (size > kHiMaxSize) + return false; + + if (size <= kLoMaxSize) { + slot = uint32_t((size - 1) / kLoGranularity); + allocatedSize = Support::alignUp(size, kLoGranularity); + } + else { + slot = uint32_t((size - kLoMaxSize - 1) / kHiGranularity) + kLoCount; + allocatedSize = Support::alignUp(size, kHiGranularity); + } + + return true; + } + + //! \} + //! \endcond + + //! \name Allocation + //! \{ + + //! \cond INTERNAL + ASMJIT_API void* _alloc(size_t size, size_t& allocatedSize) noexcept; + ASMJIT_API void* _allocZeroed(size_t size, size_t& allocatedSize) noexcept; + ASMJIT_API void _releaseDynamic(void* p, size_t size) noexcept; + //! \endcond + + //! Allocates `size` bytes of memory, ideally from an available pool. + //! + //! \note `size` can't be zero, it will assert in debug mode in such case. + inline void* alloc(size_t size) noexcept { + ASMJIT_ASSERT(isInitialized()); + size_t allocatedSize; + return _alloc(size, allocatedSize); + } + + //! Like `alloc(size)`, but provides a second argument `allocatedSize` that provides a way to know how big + //! the block returned actually is. This is useful for containers to prevent growing too early. + inline void* alloc(size_t size, size_t& allocatedSize) noexcept { + ASMJIT_ASSERT(isInitialized()); + return _alloc(size, allocatedSize); + } + + //! Like `alloc()`, but the return pointer is casted to `T*`. + template + inline T* allocT(size_t size = sizeof(T)) noexcept { + return static_cast(alloc(size)); + } + + //! Like `alloc(size)`, but returns zeroed memory. + inline void* allocZeroed(size_t size) noexcept { + ASMJIT_ASSERT(isInitialized()); + size_t allocatedSize; + return _allocZeroed(size, allocatedSize); + } + + //! Like `alloc(size, allocatedSize)`, but returns zeroed memory. + inline void* allocZeroed(size_t size, size_t& allocatedSize) noexcept { + ASMJIT_ASSERT(isInitialized()); + return _allocZeroed(size, allocatedSize); + } + + //! Like `allocZeroed()`, but the return pointer is casted to `T*`. + template + inline T* allocZeroedT(size_t size = sizeof(T)) noexcept { + return static_cast(allocZeroed(size)); + } + + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT() noexcept { + void* p = allocT(); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(); + } + //! Like `new(std::nothrow) T(...)`, but allocated by `Zone`. + template + inline T* newT(Args&&... args) noexcept { + void* p = allocT(); + if (ASMJIT_UNLIKELY(!p)) + return nullptr; + return new(Support::PlacementNew{p}) T(std::forward(args)...); + } + + //! Releases the memory previously allocated by `alloc()`. The `size` argument has to be the same as used to call + //! `alloc()` or `allocatedSize` returned by `alloc()`. + inline void release(void* p, size_t size) noexcept { + ASMJIT_ASSERT(isInitialized()); + ASMJIT_ASSERT(p != nullptr); + ASMJIT_ASSERT(size != 0); + + uint32_t slot; + if (_getSlotIndex(size, slot)) { + static_cast(p)->next = static_cast(_slots[slot]); + _slots[slot] = static_cast(p); + } + else { + _releaseDynamic(p, size); + } + } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONE_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/zonehash.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonehash.h new file mode 100644 index 0000000000000000000000000000000000000000..df36d56dffef82f21f2afdd36b9e32bdacbb0039 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonehash.h @@ -0,0 +1,186 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONEHASH_H_INCLUDED +#define ASMJIT_CORE_ZONEHASH_H_INCLUDED + +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Node used by \ref ZoneHash template. +//! +//! You must provide function `bool eq(const Key& key)` in order to make `ZoneHash::get()` working. +class ZoneHashNode { +public: + ASMJIT_NONCOPYABLE(ZoneHashNode) + + inline ZoneHashNode(uint32_t hashCode = 0) noexcept + : _hashNext(nullptr), + _hashCode(hashCode), + _customData(0) {} + + //! Next node in the chain, null if it terminates the chain. + ZoneHashNode* _hashNext; + //! Precalculated hash-code of key. + uint32_t _hashCode; + //! Padding, can be reused by any Node that inherits `ZoneHashNode`. + uint32_t _customData; +}; + +//! Base class used by \ref ZoneHash template +class ZoneHashBase { +public: + ASMJIT_NONCOPYABLE(ZoneHashBase) + + //! Buckets data. + ZoneHashNode** _data; + //! Count of records inserted into the hash table. + size_t _size; + //! Count of hash buckets. + uint32_t _bucketsCount; + //! When buckets array should grow (only checked after insertion). + uint32_t _bucketsGrow; + //! Reciprocal value of `_bucketsCount`. + uint32_t _rcpValue; + //! How many bits to shift right when hash is multiplied with `_rcpValue`. + uint8_t _rcpShift; + //! Prime value index in internal prime array. + uint8_t _primeIndex; + + //! Embedded data, used by empty hash tables. + ZoneHashNode* _embedded[1]; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneHashBase() noexcept { + reset(); + } + + inline ZoneHashBase(ZoneHashBase&& other) noexcept { + _data = other._data; + _size = other._size; + _bucketsCount = other._bucketsCount; + _bucketsGrow = other._bucketsGrow; + _rcpValue = other._rcpValue; + _rcpShift = other._rcpShift; + _primeIndex = other._primeIndex; + _embedded[0] = other._embedded[0]; + + if (_data == other._embedded) _data = _embedded; + } + + inline void reset() noexcept { + _data = _embedded; + _size = 0; + _bucketsCount = 1; + _bucketsGrow = 1; + _rcpValue = 1; + _rcpShift = 0; + _primeIndex = 0; + _embedded[0] = nullptr; + } + + inline void release(ZoneAllocator* allocator) noexcept { + ZoneHashNode** oldData = _data; + if (oldData != _embedded) + allocator->release(oldData, _bucketsCount * sizeof(ZoneHashNode*)); + reset(); + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + ASMJIT_INLINE_NODEBUG size_t size() const noexcept { return _size; } + + //! \} + + //! \name Utilities + //! \{ + + inline void _swap(ZoneHashBase& other) noexcept { + std::swap(_data, other._data); + std::swap(_size, other._size); + std::swap(_bucketsCount, other._bucketsCount); + std::swap(_bucketsGrow, other._bucketsGrow); + std::swap(_rcpValue, other._rcpValue); + std::swap(_rcpShift, other._rcpShift); + std::swap(_primeIndex, other._primeIndex); + std::swap(_embedded[0], other._embedded[0]); + + if (_data == other._embedded) _data = _embedded; + if (other._data == _embedded) other._data = other._embedded; + } + + //! \cond INTERNAL + inline uint32_t _calcMod(uint32_t hash) const noexcept { + uint32_t x = uint32_t((uint64_t(hash) * _rcpValue) >> _rcpShift); + return hash - x * _bucketsCount; + } + + ASMJIT_API void _rehash(ZoneAllocator* allocator, uint32_t newCount) noexcept; + ASMJIT_API ZoneHashNode* _insert(ZoneAllocator* allocator, ZoneHashNode* node) noexcept; + ASMJIT_API ZoneHashNode* _remove(ZoneAllocator* allocator, ZoneHashNode* node) noexcept; + //! \endcond + + //! \} +}; + +//! Low-level hash table specialized for storing string keys and POD values. +//! +//! This hash table allows duplicates to be inserted (the API is so low level that it's up to you if you allow it or +//! not, as you should first `get()` the node and then modify it or insert a new node by using `insert()`, depending +//! on the intention). +template +class ZoneHash : public ZoneHashBase { +public: + ASMJIT_NONCOPYABLE(ZoneHash) + + typedef NodeT Node; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneHash() noexcept + : ZoneHashBase() {} + + ASMJIT_INLINE_NODEBUG ZoneHash(ZoneHash&& other) noexcept + : ZoneHash(other) {} + + //! \} + + //! \name Utilities + //! \{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneHash& other) noexcept { ZoneHashBase::_swap(other); } + + template + inline NodeT* get(const KeyT& key) const noexcept { + uint32_t hashMod = _calcMod(key.hashCode()); + NodeT* node = static_cast(_data[hashMod]); + + while (node && !key.matches(node)) + node = static_cast(node->_hashNext); + return node; + } + + ASMJIT_INLINE_NODEBUG NodeT* insert(ZoneAllocator* allocator, NodeT* node) noexcept { return static_cast(_insert(allocator, node)); } + ASMJIT_INLINE_NODEBUG NodeT* remove(ZoneAllocator* allocator, NodeT* node) noexcept { return static_cast(_remove(allocator, node)); } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONEHASH_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/zonelist.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonelist.h new file mode 100644 index 0000000000000000000000000000000000000000..80d84b9fe878ebe59cc2a05bcbbea0bdc43447bd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonelist.h @@ -0,0 +1,208 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONELIST_H_INCLUDED +#define ASMJIT_CORE_ZONELIST_H_INCLUDED + +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Node used by \ref ZoneList template. +template +class ZoneListNode { +public: + ASMJIT_NONCOPYABLE(ZoneListNode) + + //! \name Constants + //! \{ + + enum : size_t { + kNodeIndexPrev = 0, + kNodeIndexNext = 1 + }; + + //! \} + + //! \name Members + //! \{ + + NodeT* _listNodes[2]; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneListNode() noexcept + : _listNodes { nullptr, nullptr } {} + + ASMJIT_INLINE_NODEBUG ZoneListNode(ZoneListNode&& other) noexcept + : _listNodes { other._listNodes[0], other._listNodes[1] } {} + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool hasPrev() const noexcept { return _listNodes[kNodeIndexPrev] != nullptr; } + ASMJIT_INLINE_NODEBUG bool hasNext() const noexcept { return _listNodes[kNodeIndexNext] != nullptr; } + + ASMJIT_INLINE_NODEBUG NodeT* prev() const noexcept { return _listNodes[kNodeIndexPrev]; } + ASMJIT_INLINE_NODEBUG NodeT* next() const noexcept { return _listNodes[kNodeIndexNext]; } + + //! \} +}; + +//! Zone allocated list container that uses nodes of `NodeT` type. +template +class ZoneList { +public: + ASMJIT_NONCOPYABLE(ZoneList) + + //! \name Constants + //! \{ + + enum : size_t { + kNodeIndexFirst = 0, + kNodeIndexLast = 1 + }; + + //! \} + + //! \name Members + //! \{ + + NodeT* _nodes[2] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneList() noexcept {} + + ASMJIT_INLINE_NODEBUG ZoneList(ZoneList&& other) noexcept + : _nodes { other._nodes[0], other._nodes[1] } {} + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _nodes[0] = nullptr; + _nodes[1] = nullptr; + } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _nodes[0] == nullptr; } + ASMJIT_INLINE_NODEBUG NodeT* first() const noexcept { return _nodes[kNodeIndexFirst]; } + ASMJIT_INLINE_NODEBUG NodeT* last() const noexcept { return _nodes[kNodeIndexLast]; } + + //! \} + + //! \name Utilities + //! \{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneList& other) noexcept { + std::swap(_nodes[0], other._nodes[0]); + std::swap(_nodes[1], other._nodes[1]); + } + + // Can be used to both append and prepend. + inline void _addNode(NodeT* node, size_t dir) noexcept { + NodeT* prev = _nodes[dir]; + + node->_listNodes[!dir] = prev; + _nodes[dir] = node; + if (prev) + prev->_listNodes[dir] = node; + else + _nodes[!dir] = node; + } + + // Can be used to both append and prepend. + inline void _insertNode(NodeT* ref, NodeT* node, size_t dir) noexcept { + ASMJIT_ASSERT(ref != nullptr); + + NodeT* prev = ref; + NodeT* next = ref->_listNodes[dir]; + + prev->_listNodes[dir] = node; + if (next) + next->_listNodes[!dir] = node; + else + _nodes[dir] = node; + + node->_listNodes[!dir] = prev; + node->_listNodes[ dir] = next; + } + + ASMJIT_INLINE_NODEBUG void append(NodeT* node) noexcept { _addNode(node, kNodeIndexLast); } + ASMJIT_INLINE_NODEBUG void prepend(NodeT* node) noexcept { _addNode(node, kNodeIndexFirst); } + + ASMJIT_INLINE_NODEBUG void insertAfter(NodeT* ref, NodeT* node) noexcept { _insertNode(ref, node, NodeT::kNodeIndexNext); } + ASMJIT_INLINE_NODEBUG void insertBefore(NodeT* ref, NodeT* node) noexcept { _insertNode(ref, node, NodeT::kNodeIndexPrev); } + + inline NodeT* unlink(NodeT* node) noexcept { + NodeT* prev = node->prev(); + NodeT* next = node->next(); + + if (prev) { prev->_listNodes[1] = next; node->_listNodes[0] = nullptr; } else { _nodes[0] = next; } + if (next) { next->_listNodes[0] = prev; node->_listNodes[1] = nullptr; } else { _nodes[1] = prev; } + + node->_listNodes[0] = nullptr; + node->_listNodes[1] = nullptr; + + return node; + } + + inline NodeT* popFirst() noexcept { + NodeT* node = _nodes[0]; + ASMJIT_ASSERT(node != nullptr); + + NodeT* next = node->next(); + _nodes[0] = next; + + if (next) { + next->_listNodes[0] = nullptr; + node->_listNodes[1] = nullptr; + } + else { + _nodes[1] = nullptr; + } + + return node; + } + + inline NodeT* pop() noexcept { + NodeT* node = _nodes[1]; + ASMJIT_ASSERT(node != nullptr); + + NodeT* prev = node->prev(); + _nodes[1] = prev; + + if (prev) { + prev->_listNodes[1] = nullptr; + node->_listNodes[0] = nullptr; + } + else { + _nodes[0] = nullptr; + } + + return node; + } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONELIST_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/zonestack.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonestack.h new file mode 100644 index 0000000000000000000000000000000000000000..b939ef0021b8b2dda64bcee73ed2f849542b0f66 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonestack.h @@ -0,0 +1,235 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONESTACK_H_INCLUDED +#define ASMJIT_CORE_ZONESTACK_H_INCLUDED + +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Base class used by \ref ZoneStack. +class ZoneStackBase { +public: + ASMJIT_NONCOPYABLE(ZoneStackBase) + + //! \name Constants + //! \{ + + enum : size_t { + kBlockIndexPrev = 0, + kBlockIndexNext = 1, + + kBlockIndexFirst = 0, + kBlockIndexLast = 1, + + kBlockSize = ZoneAllocator::kHiMaxSize + }; + + //! \} + + //! \name Types + //! \{ + + struct Block { + //! Next and previous blocks. + Block* _link[2]; + //! Pointer to the start of the array. + void* _start; + //! Pointer to the end of the array. + void* _end; + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _start == _end; } + ASMJIT_INLINE_NODEBUG Block* prev() const noexcept { return _link[kBlockIndexPrev]; } + ASMJIT_INLINE_NODEBUG Block* next() const noexcept { return _link[kBlockIndexNext]; } + + ASMJIT_INLINE_NODEBUG void setPrev(Block* block) noexcept { _link[kBlockIndexPrev] = block; } + ASMJIT_INLINE_NODEBUG void setNext(Block* block) noexcept { _link[kBlockIndexNext] = block; } + + template + ASMJIT_INLINE_NODEBUG T* start() const noexcept { return static_cast(_start); } + template + ASMJIT_INLINE_NODEBUG void setStart(T* start) noexcept { _start = static_cast(start); } + + template + ASMJIT_INLINE_NODEBUG T* end() const noexcept { return (T*)_end; } + template + ASMJIT_INLINE_NODEBUG void setEnd(T* end) noexcept { _end = (void*)end; } + + template + ASMJIT_INLINE_NODEBUG T* data() const noexcept { return (T*)((uint8_t*)(this) + sizeof(Block)); } + + template + ASMJIT_INLINE_NODEBUG bool canPrepend() const noexcept { return _start > data(); } + + template + ASMJIT_INLINE_NODEBUG bool canAppend() const noexcept { + size_t kNumBlockItems = (kBlockSize - sizeof(Block)) / sizeof(T); + size_t kStartBlockIndex = sizeof(Block); + size_t kEndBlockIndex = kStartBlockIndex + kNumBlockItems * sizeof(T); + + return (uintptr_t)_end <= ((uintptr_t)this + kEndBlockIndex - sizeof(T)); + } + }; + + //! \} + + //! \name Members + //! \{ + + //! Allocator used to allocate data. + ZoneAllocator* _allocator {}; + //! First and last blocks. + Block* _block[2] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneStackBase() noexcept {} + ASMJIT_INLINE_NODEBUG ~ZoneStackBase() noexcept { reset(); } + + ASMJIT_INLINE_NODEBUG bool isInitialized() const noexcept { return _allocator != nullptr; } + ASMJIT_API Error _init(ZoneAllocator* allocator, size_t middleIndex) noexcept; + ASMJIT_INLINE_NODEBUG Error reset() noexcept { return _init(nullptr, 0); } + + //! \} + + //! \name Accessors + //! \{ + + //! Returns `ZoneAllocator` attached to this container. + ASMJIT_INLINE_NODEBUG ZoneAllocator* allocator() const noexcept { return _allocator; } + + inline bool empty() const noexcept { + ASMJIT_ASSERT(isInitialized()); + return _block[0]->start() == _block[1]->end(); + } + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + ASMJIT_API Error _prepareBlock(uint32_t side, size_t initialIndex) noexcept; + ASMJIT_API void _cleanupBlock(uint32_t side, size_t middleIndex) noexcept; + + //! \} + //! \endcond +}; + +//! Zone allocated stack container. +template +class ZoneStack : public ZoneStackBase { +public: + ASMJIT_NONCOPYABLE(ZoneStack) + + //! \name Constants + //! \{ + + enum : uint32_t { + kNumBlockItems = uint32_t((kBlockSize - sizeof(Block)) / sizeof(T)), + kStartBlockIndex = uint32_t(sizeof(Block)), + kMidBlockIndex = uint32_t(kStartBlockIndex + (kNumBlockItems / 2) * sizeof(T)), + kEndBlockIndex = uint32_t(kStartBlockIndex + (kNumBlockItems ) * sizeof(T)) + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + inline ZoneStack() noexcept {} + inline ~ZoneStack() noexcept {} + + inline Error init(ZoneAllocator* allocator) noexcept { return _init(allocator, kMidBlockIndex); } + + //! \} + + //! \name Utilities + //! \{ + + inline Error prepend(T item) noexcept { + ASMJIT_ASSERT(isInitialized()); + Block* block = _block[kBlockIndexFirst]; + + if (!block->canPrepend()) { + ASMJIT_PROPAGATE(_prepareBlock(kBlockIndexFirst, kEndBlockIndex)); + block = _block[kBlockIndexFirst]; + } + + T* ptr = block->start() - 1; + ASMJIT_ASSERT(ptr >= block->data() && ptr <= block->data() + (kNumBlockItems - 1)); + *ptr = item; + block->setStart(ptr); + return kErrorOk; + } + + inline Error append(T item) noexcept { + ASMJIT_ASSERT(isInitialized()); + Block* block = _block[kBlockIndexLast]; + + if (!block->canAppend()) { + ASMJIT_PROPAGATE(_prepareBlock(kBlockIndexLast, kStartBlockIndex)); + block = _block[kBlockIndexLast]; + } + + T* ptr = block->end(); + ASMJIT_ASSERT(ptr >= block->data() && ptr <= block->data() + (kNumBlockItems - 1)); + + *ptr++ = item; + block->setEnd(ptr); + return kErrorOk; + } + + inline T popFirst() noexcept { + ASMJIT_ASSERT(isInitialized()); + ASMJIT_ASSERT(!empty()); + + Block* block = _block[kBlockIndexFirst]; + ASMJIT_ASSERT(!block->empty()); + + T* ptr = block->start(); + T item = *ptr++; + + block->setStart(ptr); + if (block->empty()) + _cleanupBlock(kBlockIndexFirst, kMidBlockIndex); + + return item; + } + + inline T pop() noexcept { + ASMJIT_ASSERT(isInitialized()); + ASMJIT_ASSERT(!empty()); + + Block* block = _block[kBlockIndexLast]; + ASMJIT_ASSERT(!block->empty()); + + T* ptr = block->end(); + T item = *--ptr; + ASMJIT_ASSERT(ptr >= block->data()); + ASMJIT_ASSERT(ptr >= block->start()); + + block->setEnd(ptr); + if (block->empty()) + _cleanupBlock(kBlockIndexLast, kMidBlockIndex); + + return item; + } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONESTACK_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/zonestring.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonestring.h new file mode 100644 index 0000000000000000000000000000000000000000..e72d05fb0cabc8756d814decedb0db1fde4ce733 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonestring.h @@ -0,0 +1,120 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONESTRING_H_INCLUDED +#define ASMJIT_CORE_ZONESTRING_H_INCLUDED + +#include "../core/globals.h" +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! A helper class used by \ref ZoneString implementation. +struct ZoneStringBase { + union { + struct { + uint32_t _size; + char _embedded[sizeof(void*) * 2 - 4]; + }; + struct { + void* _dummy; + char* _external; + }; + }; + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _dummy = nullptr; + _external = nullptr; + } + + Error setData(Zone* zone, uint32_t maxEmbeddedSize, const char* str, size_t size) noexcept { + if (size == SIZE_MAX) + size = strlen(str); + + if (size <= maxEmbeddedSize) { + memcpy(_embedded, str, size); + _embedded[size] = '\0'; + } + else { + char* external = static_cast(zone->dup(str, size, true)); + if (ASMJIT_UNLIKELY(!external)) + return DebugUtils::errored(kErrorOutOfMemory); + _external = external; + } + + _size = uint32_t(size); + return kErrorOk; + } +}; + +//! A string template that can be zone allocated. +//! +//! Helps with creating strings that can be either statically allocated if they are small, or externally allocated +//! in case their size exceeds the limit. The `N` represents the size of the whole `ZoneString` structure, based on +//! that size the maximum size of the internal buffer is determined. +template +class ZoneString { +public: + //! \name Constants + //! \{ + + enum : uint32_t { + kWholeSize = (N > sizeof(ZoneStringBase)) ? uint32_t(N) : uint32_t(sizeof(ZoneStringBase)), + kMaxEmbeddedSize = kWholeSize - 5 + }; + + //! \} + + //! \name Members + //! \{ + + union { + ZoneStringBase _base; + char _wholeData[kWholeSize]; + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneString() noexcept { reset(); } + ASMJIT_INLINE_NODEBUG void reset() noexcept { _base.reset(); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the string is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _base._size == 0; } + + //! Returns the string data. + ASMJIT_INLINE_NODEBUG const char* data() const noexcept { return _base._size <= kMaxEmbeddedSize ? _base._embedded : _base._external; } + //! Returns the string size. + ASMJIT_INLINE_NODEBUG uint32_t size() const noexcept { return _base._size; } + + //! Tests whether the string is embedded (e.g. no dynamically allocated). + ASMJIT_INLINE_NODEBUG bool isEmbedded() const noexcept { return _base._size <= kMaxEmbeddedSize; } + + //! Copies a new `data` of the given `size` to the string. + //! + //! If the `size` exceeds the internal buffer the given `zone` will be used to duplicate the data, otherwise + //! the internal buffer will be used as a storage. + ASMJIT_INLINE_NODEBUG Error setData(Zone* zone, const char* data, size_t size) noexcept { + return _base.setData(zone, kMaxEmbeddedSize, data, size); + } + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONESTRING_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/zonetree.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonetree.h new file mode 100644 index 0000000000000000000000000000000000000000..3ef8e6d7580a1c41fb0f10d784ecb5fa24dc6044 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonetree.h @@ -0,0 +1,376 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONETREE_H_INCLUDED +#define ASMJIT_CORE_ZONETREE_H_INCLUDED + +#include "../core/support.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! RB-Tree node. +//! +//! The color is stored in a least significant bit of the `left` node. +//! +//! WARNING: Always use accessors to access left and right children. +class ZoneTreeNode { +public: + ASMJIT_NONCOPYABLE(ZoneTreeNode) + + //! \name Constants + //! \{ + + enum : uintptr_t { + kRedMask = 0x1, + kPtrMask = ~kRedMask + }; + + //! \} + + //! \name Members + //! \{ + + uintptr_t _rbNodeData[2] {}; + + //! \} + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTreeNode() noexcept {} + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool isRed() const noexcept { return static_cast(_rbNodeData[0] & kRedMask); } + + ASMJIT_INLINE_NODEBUG bool hasChild(size_t i) const noexcept { return _rbNodeData[i] > kRedMask; } + ASMJIT_INLINE_NODEBUG bool hasLeft() const noexcept { return _rbNodeData[0] > kRedMask; } + ASMJIT_INLINE_NODEBUG bool hasRight() const noexcept { return _rbNodeData[1] != 0; } + + template + ASMJIT_INLINE_NODEBUG T* child(size_t i) const noexcept { return static_cast(_getChild(i)); } + template + ASMJIT_INLINE_NODEBUG T* left() const noexcept { return static_cast(_getLeft()); } + template + ASMJIT_INLINE_NODEBUG T* right() const noexcept { return static_cast(_getRight()); } + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTreeNode* _getChild(size_t i) const noexcept { return (ZoneTreeNode*)(_rbNodeData[i] & kPtrMask); } + ASMJIT_INLINE_NODEBUG ZoneTreeNode* _getLeft() const noexcept { return (ZoneTreeNode*)(_rbNodeData[0] & kPtrMask); } + ASMJIT_INLINE_NODEBUG ZoneTreeNode* _getRight() const noexcept { return (ZoneTreeNode*)(_rbNodeData[1]); } + + ASMJIT_INLINE_NODEBUG void _setChild(size_t i, ZoneTreeNode* node) noexcept { _rbNodeData[i] = (_rbNodeData[i] & kRedMask) | (uintptr_t)node; } + ASMJIT_INLINE_NODEBUG void _setLeft(ZoneTreeNode* node) noexcept { _rbNodeData[0] = (_rbNodeData[0] & kRedMask) | (uintptr_t)node; } + ASMJIT_INLINE_NODEBUG void _setRight(ZoneTreeNode* node) noexcept { _rbNodeData[1] = (uintptr_t)node; } + + ASMJIT_INLINE_NODEBUG void _makeRed() noexcept { _rbNodeData[0] |= kRedMask; } + ASMJIT_INLINE_NODEBUG void _makeBlack() noexcept { _rbNodeData[0] &= kPtrMask; } + + //! Tests whether the node is RED (RED node must be non-null and must have RED flag set). + static ASMJIT_INLINE_NODEBUG bool _isValidRed(ZoneTreeNode* node) noexcept { return node && node->isRed(); } + + //! \} + //! \endcond +}; + +//! RB-Tree node casted to `NodeT`. +template +class ZoneTreeNodeT : public ZoneTreeNode { +public: + ASMJIT_NONCOPYABLE(ZoneTreeNodeT) + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTreeNodeT() noexcept + : ZoneTreeNode() {} + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG NodeT* child(size_t i) const noexcept { return static_cast(_getChild(i)); } + ASMJIT_INLINE_NODEBUG NodeT* left() const noexcept { return static_cast(_getLeft()); } + ASMJIT_INLINE_NODEBUG NodeT* right() const noexcept { return static_cast(_getRight()); } + + //! \} +}; + +//! RB-Tree. +template +class ZoneTree { +public: + ASMJIT_NONCOPYABLE(ZoneTree) + + typedef NodeT Node; + NodeT* _root {}; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneTree() noexcept {} + ASMJIT_INLINE_NODEBUG ZoneTree(ZoneTree&& other) noexcept + : _root(other._root) {} + ASMJIT_INLINE_NODEBUG void reset() noexcept { _root = nullptr; } + + //! \} + + //! \name Accessors + //! \{ + + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _root == nullptr; } + ASMJIT_INLINE_NODEBUG NodeT* root() const noexcept { return static_cast(_root); } + + //! \} + + //! \name Utilities + //! \{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneTree& other) noexcept { + std::swap(_root, other._root); + } + + template> + void insert(NodeT* ASMJIT_NONNULL(node), const CompareT& cmp = CompareT()) noexcept { + // Node to insert must not contain garbage. + ASMJIT_ASSERT(!node->hasLeft()); + ASMJIT_ASSERT(!node->hasRight()); + ASMJIT_ASSERT(!node->isRed()); + + if (!_root) { + _root = node; + return; + } + + ZoneTreeNode head; // False root node, + head._setRight(_root); // having root on the right. + + ZoneTreeNode* g = nullptr; // Grandparent. + ZoneTreeNode* p = nullptr; // Parent. + ZoneTreeNode* t = &head; // Iterator. + ZoneTreeNode* q = _root; // Query. + + size_t dir = 0; // Direction for accessing child nodes. + size_t last = 0; // Not needed to initialize, but makes some tools happy. + + node->_makeRed(); // New nodes are always red and violations fixed appropriately. + + // Search down the tree. + for (;;) { + if (!q) { + // Insert new node at the bottom. + q = node; + p->_setChild(dir, node); + } + else if (_isValidRed(q->_getLeft()) && _isValidRed(q->_getRight())) { + // Color flip. + q->_makeRed(); + q->_getLeft()->_makeBlack(); + q->_getRight()->_makeBlack(); + } + + // Fix red violation. + if (_isValidRed(q) && _isValidRed(p)) { + ASMJIT_ASSUME(g != nullptr); + ASMJIT_ASSUME(p != nullptr); + t->_setChild(t->_getRight() == g, + q == p->_getChild(last) ? _singleRotate(g, !last) : _doubleRotate(g, !last)); + } + + // Stop if found. + if (q == node) + break; + + last = dir; + dir = cmp(*static_cast(q), *static_cast(node)) < 0; + + // Update helpers. + if (g) t = g; + + g = p; + p = q; + q = q->_getChild(dir); + } + + // Update root and make it black. + _root = static_cast(head._getRight()); + _root->_makeBlack(); + } + + //! Remove node from RBTree. + template> + void remove(ZoneTreeNode* ASMJIT_NONNULL(node), const CompareT& cmp = CompareT()) noexcept { + ZoneTreeNode head; // False root node, + head._setRight(_root); // having root on the right. + + ZoneTreeNode* g = nullptr; // Grandparent. + ZoneTreeNode* p = nullptr; // Parent. + ZoneTreeNode* q = &head; // Query. + + ZoneTreeNode* f = nullptr; // Found item. + ZoneTreeNode* gf = nullptr; // Found grandparent. + size_t dir = 1; // Direction (0 or 1). + + // Search and push a red down. + while (q->hasChild(dir)) { + size_t last = dir; + + // Update helpers. + g = p; + p = q; + q = q->_getChild(dir); + dir = cmp(*static_cast(q), *static_cast(node)) < 0; + + // Save found node. + if (q == node) { + f = q; + gf = g; + } + + // Push the red node down. + if (!_isValidRed(q) && !_isValidRed(q->_getChild(dir))) { + if (_isValidRed(q->_getChild(!dir))) { + ZoneTreeNode* child = _singleRotate(q, dir); + p->_setChild(last, child); + p = child; + } + else if (!_isValidRed(q->_getChild(!dir)) && p->_getChild(!last)) { + ZoneTreeNode* s = p->_getChild(!last); + if (!_isValidRed(s->_getChild(!last)) && !_isValidRed(s->_getChild(last))) { + // Color flip. + p->_makeBlack(); + s->_makeRed(); + q->_makeRed(); + } + else { + ASMJIT_ASSUME(g != nullptr); + ASMJIT_ASSUME(s != nullptr); + + size_t dir2 = g->_getRight() == p; + ZoneTreeNode* child = g->_getChild(dir2); + + if (_isValidRed(s->_getChild(last))) { + child = _doubleRotate(p, last); + g->_setChild(dir2, child); + } + else if (_isValidRed(s->_getChild(!last))) { + child = _singleRotate(p, last); + g->_setChild(dir2, child); + } + + // Ensure correct coloring. + q->_makeRed(); + child->_makeRed(); + child->_getLeft()->_makeBlack(); + child->_getRight()->_makeBlack(); + } + } + } + } + + // Replace and remove. + ASMJIT_ASSERT(f != nullptr); + ASMJIT_ASSERT(f != &head); + ASMJIT_ASSERT(q != &head); + + p->_setChild(p->_getRight() == q, + q->_getChild(q->_getLeft() == nullptr)); + + // NOTE: The original algorithm used a trick to just copy 'key/value' to `f` and mark `q` for deletion. But this + // is unacceptable here as we really want to destroy the passed `node`. So, we have to make sure that we have + // really removed `f` and not `q`. + if (f != q) { + ASMJIT_ASSERT(f != &head); + ASMJIT_ASSERT(f != gf); + + ZoneTreeNode* n = gf ? gf : &head; + dir = (n == &head) ? 1 : cmp(*static_cast(n), *static_cast(node)) < 0; + + for (;;) { + if (n->_getChild(dir) == f) { + n->_setChild(dir, q); + // RAW copy, including the color. + q->_rbNodeData[0] = f->_rbNodeData[0]; + q->_rbNodeData[1] = f->_rbNodeData[1]; + break; + } + + n = n->_getChild(dir); + + // Cannot be true as we know that it must reach `f` in few iterations. + ASMJIT_ASSERT(n != nullptr); + dir = cmp(*static_cast(n), *static_cast(node)) < 0; + } + } + + // Update root and make it black. + _root = static_cast(head._getRight()); + if (_root) _root->_makeBlack(); + } + + template> + inline NodeT* get(const KeyT& key, const CompareT& cmp = CompareT()) const noexcept { + ZoneTreeNode* node = _root; + while (node) { + auto result = cmp(*static_cast(node), key); + if (result == 0) break; + + // Go left or right depending on the `result`. + node = node->_getChild(result < 0); + } + return static_cast(node); + } + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + static inline bool _isValidRed(ZoneTreeNode* node) noexcept { return ZoneTreeNode::_isValidRed(node); } + + //! Single rotation. + static inline ZoneTreeNode* _singleRotate(ZoneTreeNode* ASMJIT_NONNULL(root), size_t dir) noexcept { + ZoneTreeNode* save = root->_getChild(!dir); + ASMJIT_ASSUME(save != nullptr); + + ZoneTreeNode* saveChild = save->_getChild(dir); + root->_setChild(!dir, saveChild); + save->_setChild( dir, root); + root->_makeRed(); + save->_makeBlack(); + return save; + } + + //! Double rotation. + static inline ZoneTreeNode* _doubleRotate(ZoneTreeNode* ASMJIT_NONNULL(root), size_t dir) noexcept { + ZoneTreeNode* child = root->_getChild(!dir); + ASMJIT_ASSUME(child != nullptr); + + root->_setChild(!dir, _singleRotate(child, !dir)); + return _singleRotate(root, dir); + } + + //! \} + //! \endcond +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONETREE_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/core/zonevector.h b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonevector.h new file mode 100644 index 0000000000000000000000000000000000000000..68273527a8cb4de095ed9f6f8fbb1b0cf035b85a --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/core/zonevector.h @@ -0,0 +1,729 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_CORE_ZONEVECTOR_H_INCLUDED +#define ASMJIT_CORE_ZONEVECTOR_H_INCLUDED + +#include "../core/support.h" +#include "../core/zone.h" + +ASMJIT_BEGIN_NAMESPACE + +//! \addtogroup asmjit_zone +//! \{ + +//! Base class used by \ref ZoneVector template. +class ZoneVectorBase { +public: + ASMJIT_NONCOPYABLE(ZoneVectorBase) + + // STL compatibility; + typedef uint32_t size_type; + typedef ptrdiff_t difference_type; + + //! Vector data (untyped). + void* _data = nullptr; + //! Size of the vector. + size_type _size = 0; + //! Capacity of the vector. + size_type _capacity = 0; + +protected: + //! \name Construction & Destruction + //! \{ + + //! Creates a new instance of `ZoneVectorBase`. + inline ZoneVectorBase() noexcept {} + + inline ZoneVectorBase(ZoneVectorBase&& other) noexcept + : _data(other._data), + _size(other._size), + _capacity(other._capacity) {} + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + inline void _release(ZoneAllocator* allocator, uint32_t sizeOfT) noexcept { + if (_data != nullptr) { + allocator->release(_data, _capacity * sizeOfT); + reset(); + } + } + + ASMJIT_API Error _grow(ZoneAllocator* allocator, uint32_t sizeOfT, uint32_t n) noexcept; + ASMJIT_API Error _resize(ZoneAllocator* allocator, uint32_t sizeOfT, uint32_t n) noexcept; + ASMJIT_API Error _reserve(ZoneAllocator* allocator, uint32_t sizeOfT, uint32_t n) noexcept; + + inline void _swap(ZoneVectorBase& other) noexcept { + std::swap(_data, other._data); + std::swap(_size, other._size); + std::swap(_capacity, other._capacity); + } + + //! \} + //! \endcond + +public: + //! \name Accessors + //! \{ + + //! Tests whether the vector is empty. + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + //! Returns the vector size. + ASMJIT_INLINE_NODEBUG size_type size() const noexcept { return _size; } + //! Returns the vector capacity. + ASMJIT_INLINE_NODEBUG size_type capacity() const noexcept { return _capacity; } + + //! \} + + //! \name Utilities + //! \{ + + //! Makes the vector empty (won't change the capacity or data pointer). + ASMJIT_INLINE_NODEBUG void clear() noexcept { _size = 0; } + //! Resets the vector data and set its `size` to zero. + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _data = nullptr; + _size = 0; + _capacity = 0; + } + + //! Truncates the vector to at most `n` items. + ASMJIT_INLINE_NODEBUG void truncate(size_type n) noexcept { + _size = Support::min(_size, n); + } + + //! Sets size of the vector to `n`. Used internally by some algorithms. + inline void _setSize(size_type n) noexcept { + ASMJIT_ASSERT(n <= _capacity); + _size = n; + } + + //! \} +}; + +//! Template used to store and manage array of Zone allocated data. +//! +//! This template has these advantages over other std::vector<>: +//! - Always non-copyable (designed to be non-copyable, we want it). +//! - Optimized for working only with POD types. +//! - Uses ZoneAllocator, thus small vectors are almost for free. +//! - Explicit allocation, ZoneAllocator is not part of the data. +template +class ZoneVector : public ZoneVectorBase { +public: + ASMJIT_NONCOPYABLE(ZoneVector) + + // STL compatibility; + typedef T value_type; + typedef T* pointer; + typedef const T* const_pointer; + typedef T& reference; + typedef const T& const_reference; + + typedef T* iterator; + typedef const T* const_iterator; + typedef Support::ArrayReverseIterator reverse_iterator; + typedef Support::ArrayReverseIterator const_reverse_iterator; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneVector() noexcept : ZoneVectorBase() {} + ASMJIT_INLINE_NODEBUG ZoneVector(ZoneVector&& other) noexcept : ZoneVector(other) {} + + //! \} + + //! \name Accessors + //! \{ + + //! Returns vector data. + ASMJIT_INLINE_NODEBUG T* data() noexcept { return static_cast(_data); } + //! Returns vector data (const) + ASMJIT_INLINE_NODEBUG const T* data() const noexcept { return static_cast(_data); } + + //! Returns item at the given index `i` (const). + inline const T& at(size_t i) const noexcept { + ASMJIT_ASSERT(i < _size); + return data()[i]; + } + + inline void _setEndPtr(T* p) noexcept { + ASMJIT_ASSERT(p >= data() && p <= data() + _capacity); + _setSize(uint32_t((uintptr_t)(p - data()))); + } + + //! \} + + //! \name STL Compatibility (Iterators) + //! \{ + + ASMJIT_INLINE_NODEBUG iterator begin() noexcept { return iterator(data()); }; + ASMJIT_INLINE_NODEBUG const_iterator begin() const noexcept { return const_iterator(data()); }; + + ASMJIT_INLINE_NODEBUG iterator end() noexcept { return iterator(data() + _size); }; + ASMJIT_INLINE_NODEBUG const_iterator end() const noexcept { return const_iterator(data() + _size); }; + + ASMJIT_INLINE_NODEBUG reverse_iterator rbegin() noexcept { return reverse_iterator(end()); }; + ASMJIT_INLINE_NODEBUG const_reverse_iterator rbegin() const noexcept { return const_reverse_iterator(end()); }; + + ASMJIT_INLINE_NODEBUG reverse_iterator rend() noexcept { return reverse_iterator(begin()); }; + ASMJIT_INLINE_NODEBUG const_reverse_iterator rend() const noexcept { return const_reverse_iterator(begin()); }; + + ASMJIT_INLINE_NODEBUG const_iterator cbegin() const noexcept { return const_iterator(data()); }; + ASMJIT_INLINE_NODEBUG const_iterator cend() const noexcept { return const_iterator(data() + _size); }; + + ASMJIT_INLINE_NODEBUG const_reverse_iterator crbegin() const noexcept { return const_reverse_iterator(cend()); }; + ASMJIT_INLINE_NODEBUG const_reverse_iterator crend() const noexcept { return const_reverse_iterator(cbegin()); }; + + //! \} + + //! \name Utilities + //! \{ + + //! Swaps this vector with `other`. + ASMJIT_FORCE_INLINE void swap(ZoneVector& other) noexcept { _swap(other); } + + //! Prepends `item` to the vector. + ASMJIT_FORCE_INLINE Error prepend(ZoneAllocator* allocator, const T& item) noexcept { + if (ASMJIT_UNLIKELY(_size == _capacity)) + ASMJIT_PROPAGATE(grow(allocator, 1)); + + memmove(static_cast(static_cast(_data) + 1), + static_cast(_data), + size_t(_size) * sizeof(T)); + + memcpy(static_cast(_data), + static_cast(&item), + sizeof(T)); + + _size++; + return kErrorOk; + } + + //! Inserts an `item` at the specified `index`. + ASMJIT_FORCE_INLINE Error insert(ZoneAllocator* allocator, size_t index, const T& item) noexcept { + ASMJIT_ASSERT(index <= _size); + + if (ASMJIT_UNLIKELY(_size == _capacity)) + ASMJIT_PROPAGATE(grow(allocator, 1)); + + T* dst = static_cast(_data) + index; + memmove(static_cast(dst + 1), + static_cast(dst), + size_t(_size - index) * sizeof(T)); + + memcpy(static_cast(dst), + static_cast(&item), + sizeof(T)); + + _size++; + return kErrorOk; + } + + //! Appends `item` to the vector. + ASMJIT_FORCE_INLINE Error append(ZoneAllocator* allocator, const T& item) noexcept { + if (ASMJIT_UNLIKELY(_size == _capacity)) + ASMJIT_PROPAGATE(grow(allocator, 1)); + + memcpy(static_cast(static_cast(_data) + _size), + static_cast(&item), + sizeof(T)); + + _size++; + return kErrorOk; + } + + //! Appends `other` vector at the end of this vector. + ASMJIT_FORCE_INLINE Error concat(ZoneAllocator* allocator, const ZoneVector& other) noexcept { + uint32_t size = other._size; + if (_capacity - _size < size) + ASMJIT_PROPAGATE(grow(allocator, size)); + + if (size) { + memcpy(static_cast(static_cast(_data) + _size), + static_cast(other._data), + size_t(size) * sizeof(T)); + _size += size; + } + + return kErrorOk; + } + + //! Prepends `item` to the vector (unsafe case). + //! + //! Can only be used together with `willGrow()`. If `willGrow(N)` returns `kErrorOk` then N elements + //! can be added to the vector without checking if there is a place for them. Used mostly internally. + ASMJIT_FORCE_INLINE void prependUnsafe(const T& item) noexcept { + ASMJIT_ASSERT(_size < _capacity); + T* data = static_cast(_data); + + if (_size) { + memmove(static_cast(data + 1), + static_cast(data), + size_t(_size) * sizeof(T)); + } + + memcpy(static_cast(data), + static_cast(&item), + sizeof(T)); + _size++; + } + + //! Append s`item` to the vector (unsafe case). + //! + //! Can only be used together with `willGrow()`. If `willGrow(N)` returns `kErrorOk` then N elements + //! can be added to the vector without checking if there is a place for them. Used mostly internally. + ASMJIT_FORCE_INLINE void appendUnsafe(const T& item) noexcept { + ASMJIT_ASSERT(_size < _capacity); + + memcpy(static_cast(static_cast(_data) + _size), + static_cast(&item), + sizeof(T)); + _size++; + } + + //! Inserts an `item` at the specified `index` (unsafe case). + ASMJIT_FORCE_INLINE void insertUnsafe(size_t index, const T& item) noexcept { + ASMJIT_ASSERT(_size < _capacity); + ASMJIT_ASSERT(index <= _size); + + T* dst = static_cast(_data) + index; + memmove(static_cast(dst + 1), + static_cast(dst), + size_t(_size - index) * sizeof(T)); + + memcpy(static_cast(dst), + static_cast(&item), + sizeof(T)); + + _size++; + } + + //! Concatenates all items of `other` at the end of the vector. + ASMJIT_FORCE_INLINE void concatUnsafe(const ZoneVector& other) noexcept { + uint32_t size = other._size; + ASMJIT_ASSERT(_capacity - _size >= size); + + if (size) { + memcpy(static_cast(static_cast(_data) + _size), + static_cast(other._data), + size_t(size) * sizeof(T)); + _size += size; + } + } + + //! Returns index of the given `val` or `Globals::kNotFound` if it doesn't exist. + ASMJIT_FORCE_INLINE uint32_t indexOf(const T& val) const noexcept { + const T* data = static_cast(_data); + uint32_t size = _size; + + for (uint32_t i = 0; i < size; i++) + if (data[i] == val) + return i; + return Globals::kNotFound; + } + + //! Tests whether the vector contains `val`. + inline bool contains(const T& val) const noexcept { + return indexOf(val) != Globals::kNotFound; + } + + //! Removes item at index `i`. + inline void removeAt(size_t i) noexcept { + ASMJIT_ASSERT(i < _size); + + T* data = static_cast(_data) + i; + size_t size = --_size - i; + + if (size) { + memmove(static_cast(data), + static_cast(data + 1), + size_t(size) * sizeof(T)); + } + } + + //! Pops the last element from the vector and returns it. + inline T pop() noexcept { + ASMJIT_ASSERT(_size > 0); + + uint32_t index = --_size; + return data()[index]; + } + + template> + inline void sort(const CompareT& cmp = CompareT()) noexcept { + Support::qSort(data(), size(), cmp); + } + + //! Returns item at index `i`. + inline T& operator[](size_t i) noexcept { + ASMJIT_ASSERT(i < _size); + return data()[i]; + } + + //! Returns item at index `i`. + inline const T& operator[](size_t i) const noexcept { + ASMJIT_ASSERT(i < _size); + return data()[i]; + } + + //! Returns a reference to the first element of the vector. + //! + //! \note The vector must have at least one element. Attempting to use `first()` on empty vector will trigger + //! an assertion failure in debug builds. + ASMJIT_INLINE_NODEBUG T& first() noexcept { return operator[](0); } + //! \overload + ASMJIT_INLINE_NODEBUG const T& first() const noexcept { return operator[](0); } + + //! Returns a reference to the last element of the vector. + //! + //! \note The vector must have at least one element. Attempting to use `last()` on empty vector will trigger + //! an assertion failure in debug builds. + inline T& last() noexcept { return operator[](_size - 1); } + //! \overload + inline const T& last() const noexcept { return operator[](_size - 1); } + + //! \} + + //! \name Memory Management + //! \{ + + //! Releases the memory held by `ZoneVector` back to the `allocator`. + inline void release(ZoneAllocator* allocator) noexcept { + _release(allocator, sizeof(T)); + } + + //! Called to grow the buffer to fit at least `n` elements more. + inline Error grow(ZoneAllocator* allocator, uint32_t n) noexcept { + return ZoneVectorBase::_grow(allocator, sizeof(T), n); + } + + //! Resizes the vector to hold `n` elements. + //! + //! If `n` is greater than the current size then the additional elements' content will be initialized to zero. + //! If `n` is less than the current size then the vector will be truncated to exactly `n` elements. + inline Error resize(ZoneAllocator* allocator, uint32_t n) noexcept { + return ZoneVectorBase::_resize(allocator, sizeof(T), n); + } + + //! Reallocates the internal array to fit at least `n` items. + inline Error reserve(ZoneAllocator* allocator, uint32_t n) noexcept { + return n > _capacity ? ZoneVectorBase::_reserve(allocator, sizeof(T), n) : Error(kErrorOk); + } + + inline Error willGrow(ZoneAllocator* allocator, uint32_t n = 1) noexcept { + return _capacity - _size < n ? grow(allocator, n) : Error(kErrorOk); + } + + //! \} +}; + +//! Zone-allocated bit vector. +class ZoneBitVector { +public: + typedef Support::BitWord BitWord; + + ASMJIT_NONCOPYABLE(ZoneBitVector) + + //! \name Constants + //! \{ + + enum : uint32_t { + kBitWordSizeInBits = Support::kBitWordSizeInBits + }; + + //! \} + + //! \name Members + //! \{ + + //! Bits. + BitWord* _data = nullptr; + //! Size of the bit-vector (in bits). + uint32_t _size = 0; + //! Capacity of the bit-vector (in bits). + uint32_t _capacity = 0; + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + static ASMJIT_INLINE_NODEBUG uint32_t _wordsPerBits(uint32_t nBits) noexcept { + return ((nBits + kBitWordSizeInBits - 1) / kBitWordSizeInBits); + } + + static ASMJIT_INLINE_NODEBUG void _zeroBits(BitWord* dst, uint32_t nBitWords) noexcept { + for (uint32_t i = 0; i < nBitWords; i++) + dst[i] = 0; + } + + static ASMJIT_INLINE_NODEBUG void _fillBits(BitWord* dst, uint32_t nBitWords) noexcept { + for (uint32_t i = 0; i < nBitWords; i++) + dst[i] = ~BitWord(0); + } + + static ASMJIT_INLINE_NODEBUG void _copyBits(BitWord* dst, const BitWord* src, uint32_t nBitWords) noexcept { + for (uint32_t i = 0; i < nBitWords; i++) + dst[i] = src[i]; + } + + //! \} + //! \endcond + + //! \name Construction & Destruction + //! \{ + + ASMJIT_INLINE_NODEBUG ZoneBitVector() noexcept {} + + ASMJIT_INLINE_NODEBUG ZoneBitVector(ZoneBitVector&& other) noexcept + : _data(other._data), + _size(other._size), + _capacity(other._capacity) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG bool operator==(const ZoneBitVector& other) const noexcept { return equals(other); } + ASMJIT_INLINE_NODEBUG bool operator!=(const ZoneBitVector& other) const noexcept { return !equals(other); } + + //! \} + + //! \name Accessors + //! \{ + + //! Tests whether the bit-vector is empty (has no bits). + ASMJIT_INLINE_NODEBUG bool empty() const noexcept { return _size == 0; } + //! Returns the size of this bit-vector (in bits). + ASMJIT_INLINE_NODEBUG uint32_t size() const noexcept { return _size; } + //! Returns the capacity of this bit-vector (in bits). + ASMJIT_INLINE_NODEBUG uint32_t capacity() const noexcept { return _capacity; } + + //! Returns the size of the `BitWord[]` array in `BitWord` units. + ASMJIT_INLINE_NODEBUG uint32_t sizeInBitWords() const noexcept { return _wordsPerBits(_size); } + //! Returns the capacity of the `BitWord[]` array in `BitWord` units. + ASMJIT_INLINE_NODEBUG uint32_t capacityInBitWords() const noexcept { return _wordsPerBits(_capacity); } + + //! Returns bit-vector data as `BitWord[]`. + ASMJIT_INLINE_NODEBUG BitWord* data() noexcept { return _data; } + //! \overload + ASMJIT_INLINE_NODEBUG const BitWord* data() const noexcept { return _data; } + + //! \} + + //! \name Utilities + //! \{ + + ASMJIT_INLINE_NODEBUG void swap(ZoneBitVector& other) noexcept { + std::swap(_data, other._data); + std::swap(_size, other._size); + std::swap(_capacity, other._capacity); + } + + ASMJIT_INLINE_NODEBUG void clear() noexcept { + _size = 0; + } + + ASMJIT_INLINE_NODEBUG void reset() noexcept { + _data = nullptr; + _size = 0; + _capacity = 0; + } + + ASMJIT_INLINE_NODEBUG void truncate(uint32_t newSize) noexcept { + _size = Support::min(_size, newSize); + _clearUnusedBits(); + } + + inline bool bitAt(uint32_t index) const noexcept { + ASMJIT_ASSERT(index < _size); + return Support::bitVectorGetBit(_data, index); + } + + inline void setBit(uint32_t index, bool value) noexcept { + ASMJIT_ASSERT(index < _size); + Support::bitVectorSetBit(_data, index, value); + } + + inline void flipBit(uint32_t index) noexcept { + ASMJIT_ASSERT(index < _size); + Support::bitVectorFlipBit(_data, index); + } + + ASMJIT_FORCE_INLINE Error append(ZoneAllocator* allocator, bool value) noexcept { + uint32_t index = _size; + if (ASMJIT_UNLIKELY(index >= _capacity)) + return _append(allocator, value); + + uint32_t idx = index / kBitWordSizeInBits; + uint32_t bit = index % kBitWordSizeInBits; + + if (bit == 0) + _data[idx] = BitWord(value) << bit; + else + _data[idx] |= BitWord(value) << bit; + + _size++; + return kErrorOk; + } + + ASMJIT_API Error copyFrom(ZoneAllocator* allocator, const ZoneBitVector& other) noexcept; + + ASMJIT_FORCE_INLINE void clearAll() noexcept { + _zeroBits(_data, _wordsPerBits(_size)); + } + + ASMJIT_FORCE_INLINE void fillAll() noexcept { + _fillBits(_data, _wordsPerBits(_size)); + _clearUnusedBits(); + } + + ASMJIT_FORCE_INLINE void clearBits(uint32_t start, uint32_t count) noexcept { + ASMJIT_ASSERT(start <= _size); + ASMJIT_ASSERT(_size - start >= count); + + Support::bitVectorClear(_data, start, count); + } + + ASMJIT_FORCE_INLINE void fillBits(uint32_t start, uint32_t count) noexcept { + ASMJIT_ASSERT(start <= _size); + ASMJIT_ASSERT(_size - start >= count); + + Support::bitVectorFill(_data, start, count); + } + + //! Performs a logical bitwise AND between bits specified in this array and bits in `other`. If `other` has less + //! bits than `this` then all remaining bits are set to zero. + //! + //! \note The size of the BitVector is unaffected by this operation. + ASMJIT_FORCE_INLINE void and_(const ZoneBitVector& other) noexcept { + BitWord* dst = _data; + const BitWord* src = other._data; + + uint32_t thisBitWordCount = sizeInBitWords(); + uint32_t otherBitWordCount = other.sizeInBitWords(); + uint32_t commonBitWordCount = Support::min(thisBitWordCount, otherBitWordCount); + + uint32_t i = 0; + while (i < commonBitWordCount) { + dst[i] = dst[i] & src[i]; + i++; + } + + while (i < thisBitWordCount) { + dst[i] = 0; + i++; + } + } + + //! Performs a logical bitwise AND between bits specified in this array and negated bits in `other`. If `other` + //! has less bits than `this` then all remaining bits are kept intact. + //! + //! \note The size of the BitVector is unaffected by this operation. + ASMJIT_FORCE_INLINE void andNot(const ZoneBitVector& other) noexcept { + BitWord* dst = _data; + const BitWord* src = other._data; + + uint32_t commonBitWordCount = _wordsPerBits(Support::min(_size, other._size)); + for (uint32_t i = 0; i < commonBitWordCount; i++) + dst[i] = dst[i] & ~src[i]; + } + + //! Performs a logical bitwise OP between bits specified in this array and bits in `other`. If `other` has less + //! bits than `this` then all remaining bits are kept intact. + //! + //! \note The size of the BitVector is unaffected by this operation. + ASMJIT_FORCE_INLINE void or_(const ZoneBitVector& other) noexcept { + BitWord* dst = _data; + const BitWord* src = other._data; + + uint32_t commonBitWordCount = _wordsPerBits(Support::min(_size, other._size)); + for (uint32_t i = 0; i < commonBitWordCount; i++) + dst[i] = dst[i] | src[i]; + _clearUnusedBits(); + } + + ASMJIT_FORCE_INLINE void _clearUnusedBits() noexcept { + uint32_t idx = _size / kBitWordSizeInBits; + uint32_t bit = _size % kBitWordSizeInBits; + + if (!bit) + return; + _data[idx] &= (BitWord(1) << bit) - 1u; + } + + ASMJIT_FORCE_INLINE bool equals(const ZoneBitVector& other) const noexcept { + if (_size != other._size) + return false; + + const BitWord* aData = _data; + const BitWord* bData = other._data; + uint32_t numBitWords = _wordsPerBits(_size); + + for (uint32_t i = 0; i < numBitWords; i++) + if (aData[i] != bData[i]) + return false; + return true; + } + +#if !defined(ASMJIT_NO_DEPRECATED) + ASMJIT_DEPRECATED("Use ZoneVector::equals() instead") + ASMJIT_FORCE_INLINE bool eq(const ZoneBitVector& other) const noexcept { return equals(other); } +#endif // !ASMJIT_NO_DEPRECATED + + //! \} + + //! \name Memory Management + //! \{ + + inline void release(ZoneAllocator* allocator) noexcept { + if (!_data) + return; + allocator->release(_data, _capacity / 8); + reset(); + } + + ASMJIT_INLINE_NODEBUG Error resize(ZoneAllocator* allocator, uint32_t newSize, bool newBitsValue = false) noexcept { + return _resize(allocator, newSize, newSize, newBitsValue); + } + + ASMJIT_API Error _resize(ZoneAllocator* allocator, uint32_t newSize, uint32_t idealCapacity, bool newBitsValue) noexcept; + ASMJIT_API Error _append(ZoneAllocator* allocator, bool value) noexcept; + + //! \} + + //! \name Iterators + //! \{ + + class ForEachBitSet : public Support::BitVectorIterator { + public: + inline explicit ForEachBitSet(const ZoneBitVector& bitVector) noexcept + : Support::BitVectorIterator(bitVector.data(), bitVector.sizeInBitWords()) {} + }; + + template + class ForEachBitOp : public Support::BitVectorOpIterator { + public: + inline ForEachBitOp(const ZoneBitVector& a, const ZoneBitVector& b) noexcept + : Support::BitVectorOpIterator(a.data(), b.data(), a.sizeInBitWords()) { + ASMJIT_ASSERT(a.size() == b.size()); + } + }; + + //! \} +}; + +//! \} + +ASMJIT_END_NAMESPACE + +#endif // ASMJIT_CORE_ZONEVECTOR_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86.h new file mode 100644 index 0000000000000000000000000000000000000000..671c5de4a2912bb16e9edb72406737df7a902a41 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86.h @@ -0,0 +1,93 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_H_INCLUDED +#define ASMJIT_X86_H_INCLUDED + +//! \addtogroup asmjit_x86 +//! +//! ### Namespace +//! +//! - \ref x86 - x86 namespace provides support for X86/X64 code generation. +//! +//! ### Emitters +//! +//! - \ref x86::Assembler - X86/X64 assembler (must read, provides examples). +//! - \ref x86::Builder - X86/X64 builder. +//! - \ref x86::Compiler - X86/X64 compiler. +//! - \ref x86::Emitter - X86/X64 emitter (abstract). +//! +//! ### Supported Instructions +//! +//! - Emitters: +//! - \ref x86::EmitterExplicitT - Provides all instructions that use explicit operands, provides also utility +//! functions. The member functions provided are part of all X86 emitters. +//! - \ref x86::EmitterImplicitT - Provides all instructions that use implicit operands, these cannot be used +//! with \ref x86::Compiler. +//! +//! - Instruction representation: +//! - \ref x86::Inst::Id - Provides instruction identifiers for both X86/X86_64 architectures. +//! - \ref InstOptions - Provides generic and X86/X86_64 specific options. +//! +//! ### Register Operands +//! +//! - \ref x86::Reg - Base class for any X86 register. +//! - \ref x86::Gp - General purpose register: +//! - \ref x86::GpbLo - 8-bit low register. +//! - \ref x86::GpbHi - 8-bit high register. +//! - \ref x86::Gpw - 16-bit register. +//! - \ref x86::Gpd - 32-bit register. +//! - \ref x86::Gpq - 64-bit register (X64 only). +//! - \ref x86::Vec - Vector (SIMD) register: +//! - \ref x86::Xmm - 128-bit SIMD register (SSE+). +//! - \ref x86::Ymm - 256-bit SIMD register (AVX+). +//! - \ref x86::Zmm - 512-bit SIMD register (AVX512+). +//! - \ref x86::Mm - 64-bit MMX register. +//! - \ref x86::St - 80-bit FPU register. +//! - \ref x86::KReg - opmask registers (AVX512+). +//! - \ref x86::SReg - segment register. +//! - \ref x86::CReg - control register. +//! - \ref x86::DReg - debug register. +//! - \ref x86::Bnd - bound register (discontinued). +//! - \ref x86::Rip - relative instruction pointer. +//! +//! ### Memory Operands +//! +//! - \ref x86::Mem - X86/X64 memory operand that provides support for all X86 and X64 addressing features +//! including absolute addresses, index scales, and segment override prefixes. +//! +//! ### Status and Control Words +//! +//! - \ref x86::FpuStatusWord - FPU status word bits / decomposition. +//! - \ref x86::FpuControlWord - FPU control word bits / decomposition. +//! +//! ### Predicates (immediate values) +//! +//! - \ref x86::CmpImm - `CMP[PD|PS|SD|SS]` predicate (SSE+). +//! - \ref x86::PCmpStrImm - `[V]PCMP[I|E]STR[I|M]` predicate (SSE4.1+, AVX+). +//! - \ref x86::RoundImm - `[V]ROUND[PD|PS|SD|SS]` predicate (SSE+, AVX+). +//! - \ref x86::VCmpImm - `VCMP[PD|PS|SD|SS]` predicate (AVX+). +//! - \ref x86::VFixupImm - `VFIXUPIMM[PD|PS|SD|SS]` predicate (AVX512+). +//! - \ref x86::VFPClassImm - `VFPCLASS[PD|PS|SD|SS]` predicate (AVX512+). +//! - \ref x86::VGetMantImm - `VGETMANT[PD|PS|SD|SS]` predicate (AVX512+). +//! - \ref x86::VPCmpImm - `VPCMP[U][B|W|D|Q]` predicate (AVX512+). +//! - \ref x86::VPComImm - `VPCOM[U][B|W|D|Q]` predicate (XOP). +//! - \ref x86::VRangeImm - `VRANGE[PD|PS|SD|SS]` predicate (AVX512+). +//! - \ref x86::VReduceImm - `REDUCE[PD|PS|SD|SS]` predicate (AVX512+). +//! - \ref x86::TLogImm - `VPTERNLOG[D|Q]` predicate and operations (AVX512+). + +#include "core.h" + +#include "asmjit-scope-begin.h" +#include "x86/x86assembler.h" +#include "x86/x86builder.h" +#include "x86/x86compiler.h" +#include "x86/x86emitter.h" +#include "x86/x86globals.h" +#include "x86/x86instdb.h" +#include "x86/x86operand.h" +#include "asmjit-scope-end.h" + +#endif // ASMJIT_X86_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86assembler.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86assembler.h new file mode 100644 index 0000000000000000000000000000000000000000..1b8df4098d66b4dcc24f3e625b7bddbc727e5096 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86assembler.h @@ -0,0 +1,695 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86ASSEMBLER_H_INCLUDED +#define ASMJIT_X86_X86ASSEMBLER_H_INCLUDED + +#include "../core/assembler.h" +#include "../x86/x86emitter.h" +#include "../x86/x86operand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! X86/X64 assembler implementation. +//! +//! x86::Assembler is a code emitter that emits machine code directly into the \ref CodeBuffer. The assembler is capable +//! of targeting both 32-bit and 64-bit instruction sets, the instruction set can be configured through \ref CodeHolder. +//! +//! ### Basics +//! +//! The following example shows a basic use of `x86::Assembler`, how to generate a function that works in both 32-bit +//! and 64-bit modes, and how to connect \ref JitRuntime, \ref CodeHolder, and `x86::Assembler`. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef int (*SumFunc)(const int* arr, size_t count); +//! +//! int main() { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Decide between 32-bit CDECL, WIN64, and SysV64 calling conventions: +//! // 32-BIT - passed all arguments by stack. +//! // WIN64 - passes first 4 arguments by RCX, RDX, R8, and R9. +//! // UNIX64 - passes first 6 arguments by RDI, RSI, RDX, RCX, R8, and R9. +//! x86::Gp arr, cnt; +//! x86::Gp sum = x86::eax; // Use EAX as 'sum' as it's a return register. +//! +//! if (ASMJIT_ARCH_BITS == 64) { +//! #if defined(_WIN32) +//! arr = x86::rcx; // First argument (array ptr). +//! cnt = x86::rdx; // Second argument (number of elements) +//! #else +//! arr = x86::rdi; // First argument (array ptr). +//! cnt = x86::rsi; // Second argument (number of elements) +//! #endif +//! } +//! else { +//! arr = x86::edx; // Use EDX to hold the array pointer. +//! cnt = x86::ecx; // Use ECX to hold the counter. +//! // Fetch first and second arguments from [ESP + 4] and [ESP + 8]. +//! a.mov(arr, x86::ptr(x86::esp, 4)); +//! a.mov(cnt, x86::ptr(x86::esp, 8)); +//! } +//! +//! Label Loop = a.newLabel(); // To construct the loop, we need some labels. +//! Label Exit = a.newLabel(); +//! +//! a.xor_(sum, sum); // Clear 'sum' register (shorter than 'mov'). +//! a.test(cnt, cnt); // Border case: +//! a.jz(Exit); // If 'cnt' is zero jump to 'Exit' now. +//! +//! a.bind(Loop); // Start of a loop iteration. +//! a.add(sum, x86::dword_ptr(arr)); // Add int at [arr] to 'sum'. +//! a.add(arr, 4); // Increment 'arr' pointer. +//! a.dec(cnt); // Decrease 'cnt'. +//! a.jnz(Loop); // If not zero jump to 'Loop'. +//! +//! a.bind(Exit); // Exit to handle the border case. +//! a.ret(); // Return from function ('sum' == 'eax'). +//! // ----> x86::Assembler is no longer needed from here and can be destroyed <---- +//! +//! SumFunc fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! static const int array[6] = { 4, 8, 15, 16, 23, 42 }; +//! +//! int result = fn(array, 6); // Execute the generated code. +//! printf("%d\n", result); // Print sum of array (108). +//! +//! rt.release(fn); // Explicitly remove the function from the runtime +//! return 0; // Everything successful... +//! } +//! ``` +//! +//! The example should be self-explanatory. It shows how to work with labels, how to use operands, and how to emit +//! instructions that can use different registers based on runtime selection. It implements 32-bit CDECL, WIN64, +//! and SysV64 calling conventions and will work on most X86/X64 environments. +//! +//! Although functions prologs / epilogs can be implemented manually, AsmJit provides utilities that can be used +//! to create function prologs and epilogs automatically, see \ref asmjit_function for more details. +//! +//! ### Instruction Validation +//! +//! Assembler prefers speed over strictness by default. The implementation checks the type of operands and fails +//! if the signature of types is invalid, however, it does only basic checks regarding registers and their groups +//! used in instructions. It's possible to pass operands that don't form any valid signature to the implementation +//! and succeed. This is usually not a problem as Assembler provides typed API so operand types are normally checked +//! by C++ compiler at compile time, however, Assembler is fully dynamic and its \ref emit() function can be called +//! with any instruction id, options, and operands. Moreover, it's also possible to form instructions that will be +//! accepted by the typed API, for example by calling `mov(x86::eax, x86::al)` - the C++ compiler won't see a problem +//! as both EAX and AL are \ref Gp registers. +//! +//! To help with common mistakes AsmJit allows to activate instruction validation. This feature instruments +//! the Assembler to call \ref InstAPI::validate() before it attempts to encode any instruction. +//! +//! The example below illustrates how validation can be turned on: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Enable strict validation. +//! a.addDiagnosticOptions(DiagnosticOptions::kValidateAssembler); +//! +//! // Try to encode invalid or ill-formed instructions. +//! Error err; +//! +//! // Invalid instruction. +//! err = a.mov(x86::eax, x86::al); +//! printf("Status: %s\n", DebugUtils::errorAsString(err)); +//! +//! // Invalid instruction. +//! err = a.emit(x86::Inst::kIdMovss, x86::eax, x86::xmm0); +//! printf("Status: %s\n", DebugUtils::errorAsString(err)); +//! +//! // Ambiguous operand size - the pointer requires size. +//! err = a.inc(x86::ptr(x86::rax)); +//! printf("Status: %s\n", DebugUtils::errorAsString(err)); +//! +//! return 0; +//! } +//! ``` +//! +//! ### Native Registers +//! +//! All emitters provide functions to construct machine-size registers depending on the target. This feature is +//! for users that want to write code targeting both 32-bit and 64-bit architectures at the same time. In AsmJit +//! terminology such registers have prefix `z`, so for example on X86 architecture the following native registers +//! are provided: +//! +//! - `zax` - mapped to either `eax` or `rax` +//! - `zbx` - mapped to either `ebx` or `rbx` +//! - `zcx` - mapped to either `ecx` or `rcx` +//! - `zdx` - mapped to either `edx` or `rdx` +//! - `zsp` - mapped to either `esp` or `rsp` +//! - `zbp` - mapped to either `ebp` or `rbp` +//! - `zsi` - mapped to either `esi` or `rsi` +//! - `zdi` - mapped to either `edi` or `rdi` +//! +//! They are accessible through \ref x86::Assembler, \ref x86::Builder, and \ref x86::Compiler. The example below +//! illustrates how to use this feature: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef int (*Func)(void); +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Let's get these registers from x86::Assembler. +//! x86::Gp zbp = a.zbp(); +//! x86::Gp zsp = a.zsp(); +//! +//! int stackSize = 32; +//! +//! // Function prolog. +//! a.push(zbp); +//! a.mov(zbp, zsp); +//! a.sub(zsp, stackSize); +//! +//! // ... emit some code (this just sets return value to zero) ... +//! a.xor_(x86::eax, x86::eax); +//! +//! // Function epilog and return. +//! a.mov(zsp, zbp); +//! a.pop(zbp); +//! a.ret(); +//! +//! // To make the example complete let's call it. +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! +//! int result = fn(); // Execute the generated code. +//! printf("%d\n", result); // Print the resulting "0". +//! +//! rt.release(fn); // Remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! The example just returns `0`, but the function generated contains a standard prolog and epilog sequence and the +//! function itself reserves 32 bytes of local stack. The advantage is clear - a single code-base can handle multiple +//! targets easily. If you want to create a register of native size dynamically by specifying its id it's also possible: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void example(x86::Assembler& a) { +//! x86::Gp zax = a.gpz(x86::Gp::kIdAx); +//! x86::Gp zbx = a.gpz(x86::Gp::kIdBx); +//! x86::Gp zcx = a.gpz(x86::Gp::kIdCx); +//! x86::Gp zdx = a.gpz(x86::Gp::kIdDx); +//! +//! // You can also change register's id easily. +//! x86::Gp zsp = zax; +//! zsp.setId(4); // or x86::Gp::kIdSp. +//! } +//! ``` +//! +//! ### Data Embedding +//! +//! x86::Assembler extends the standard \ref BaseAssembler with X86/X64 specific conventions that are often used by +//! assemblers to embed data next to the code. The following functions can be used to embed data: +//! +//! - \ref BaseAssembler::embedInt8() - embeds int8_t (portable naming). +//! - \ref BaseAssembler::embedUInt8() - embeds uint8_t (portable naming). +//! - \ref BaseAssembler::embedInt16() - embeds int16_t (portable naming). +//! - \ref BaseAssembler::embedUInt16() - embeds uint16_t (portable naming). +//! - \ref BaseAssembler::embedInt32() - embeds int32_t (portable naming). +//! - \ref BaseAssembler::embedUInt32() - embeds uint32_t (portable naming). +//! - \ref BaseAssembler::embedInt64() - embeds int64_t (portable naming). +//! - \ref BaseAssembler::embedUInt64() - embeds uint64_t (portable naming). +//! - \ref BaseAssembler::embedFloat() - embeds float (portable naming). +//! - \ref BaseAssembler::embedDouble() - embeds double (portable naming). +//! +//! - \ref x86::Assembler::db() - embeds byte (8 bits) (x86 naming). +//! - \ref x86::Assembler::dw() - embeds word (16 bits) (x86 naming). +//! - \ref x86::Assembler::dd() - embeds dword (32 bits) (x86 naming). +//! - \ref x86::Assembler::dq() - embeds qword (64 bits) (x86 naming). +//! +//! The following example illustrates how embed works: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void embedData(x86::Assembler& a) { +//! a.db(0xFF); // Embeds 0xFF byte. +//! a.dw(0xFF00); // Embeds 0xFF00 word (little-endian). +//! a.dd(0xFF000000); // Embeds 0xFF000000 dword (little-endian). +//! a.embedFloat(0.4f); // Embeds 0.4f (32-bit float, little-endian). +//! } +//! ``` +//! +//! Sometimes it's required to read the data that is embedded after code, for example. This can be done through +//! \ref Label as shown below: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void processData(x86::Assembler& a, const Label& L_Data) { +//! x86::Gp addr = a.zax(); // EAX or RAX. +//! x86::Gp val = x86::edi; // Where to store some value... +//! +//! // Approach 1 - Load the address to register through LEA. This approach +//! // is flexible as the address can be then manipulated, for +//! // example if you have a data array, which would need index. +//! a.lea(addr, x86::ptr(L_Data)); +//! a.mov(val, x86::dword_ptr(addr)); +//! +//! // Approach 2 - Load the data directly by using L_Data in address. It's +//! // worth noting that this doesn't work with indexes in X64 +//! // mode. It will use absolute address in 32-bit mode and +//! // relative address (RIP) in 64-bit mode. +//! a.mov(val, x86::dword_ptr(L_Data)); +//! } +//! ``` +//! +//! ### Label Embedding +//! +//! It's also possible to embed labels. In general AsmJit provides the following options: +//! +//! - \ref BaseEmitter::embedLabel() - Embeds absolute address of a label. This is target dependent and would +//! embed either 32-bit or 64-bit data that embeds absolute label address. This kind of embedding cannot be +//! used in a position independent code. +//! +//! - \ref BaseEmitter::embedLabelDelta() - Embeds a difference between two labels. The size of the difference +//! can be specified so it's possible to embed 8-bit, 16-bit, 32-bit, and 64-bit difference, which is sufficient +//! for most purposes. +//! +//! The following example demonstrates how to embed labels and their differences: +//! +//! ``` +//! #include +//! using namespace asmjit; +//! +//! void embedLabel(x86::Assembler& a, const Label& L_Data) { +//! // [1] Embed L_Data - the size of the data will be dependent on the target. +//! a.embedLabel(L_Data); +//! +//! // [2] Embed a 32-bit difference of two labels. +//! Label L_Here = a.newLabel(); +//! a.bind(L_Here); +//! // Embeds int32_t(L_Data - L_Here). +//! a.embedLabelDelta(L_Data, L_Here, 4); +//! } +//! ``` +//! +//! ### Using FuncFrame and FuncDetail with x86::Assembler +//! +//! The example below demonstrates how \ref FuncFrame and \ref FuncDetail can be used together with \ref x86::Assembler +//! to generate a function that will use platform dependent calling conventions automatically depending on the target: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef void (*SumIntsFunc)(int* dst, const int* a, const int* b); +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create JIT Runtime. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Decide which registers will be mapped to function arguments. Try changing +//! // registers of dst, src_a, and src_b and see what happens in function's +//! // prolog and epilog. +//! x86::Gp dst = a.zax(); +//! x86::Gp src_a = a.zcx(); +//! x86::Gp src_b = a.zdx(); +//! +//! x86::Xmm vec0 = x86::xmm0; +//! x86::Xmm vec1 = x86::xmm1; +//! +//! // Create/initialize FuncDetail and FuncFrame. +//! FuncDetail func; +//! func.init(FuncSignature::build(), +//! rt.environment()); +//! +//! FuncFrame frame; +//! frame.init(func); +//! +//! // Make XMM0 and XMM1 dirty - RegGroup::kVec describes XMM|YMM|ZMM registers. +//! frame.setDirtyRegs(RegGroup::kVec, Support::bitMask(0, 1)); +//! +//! // Alternatively, if you don't want to use register masks you can pass BaseReg +//! // to addDirtyRegs(). The following code would add both xmm0 and xmm1. +//! frame.addDirtyRegs(x86::xmm0, x86::xmm1); +//! +//! FuncArgsAssignment args(&func); // Create arguments assignment context. +//! args.assignAll(dst, src_a, src_b);// Assign our registers to arguments. +//! args.updateFuncFrame(frame); // Reflect our args in FuncFrame. +//! frame.finalize(); // Finalize the FuncFrame (updates it). +//! +//! a.emitProlog(frame); // Emit function prolog. +//! a.emitArgsAssignment(frame, args);// Assign arguments to registers. +//! a.movdqu(vec0, x86::ptr(src_a)); // Load 4 ints from [src_a] to XMM0. +//! a.movdqu(vec1, x86::ptr(src_b)); // Load 4 ints from [src_b] to XMM1. +//! a.paddd(vec0, vec1); // Add 4 ints in XMM1 to XMM0. +//! a.movdqu(x86::ptr(dst), vec0); // Store the result to [dst]. +//! a.emitEpilog(frame); // Emit function epilog and return. +//! +//! SumIntsFunc fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error case. +//! +//! // Execute the generated function. +//! int inA[4] = { 4, 3, 2, 1 }; +//! int inB[4] = { 1, 5, 2, 8 }; +//! int out[4]; +//! fn(out, inA, inB); +//! +//! // Prints {5 8 4 9} +//! printf("{%d %d %d %d}\n", out[0], out[1], out[2], out[3]); +//! +//! rt.release(fn); +//! return 0; +//! } +//! ``` +//! +//! ### Using x86::Assembler as Code-Patcher +//! +//! This is an advanced topic that is sometimes unavoidable. AsmJit by default appends machine code it generates +//! into a \ref CodeBuffer, however, it also allows to set the offset in \ref CodeBuffer explicitly and to overwrite +//! its content. This technique is extremely dangerous as X86 instructions have variable length (see below), so you +//! should in general only patch code to change instruction's immediate values or some other details not known the +//! at a time the instruction was emitted. A typical scenario that requires code-patching is when you start emitting +//! function and you don't know how much stack you want to reserve for it. +//! +//! Before we go further it's important to introduce instruction options, because they can help with code-patching +//! (and not only patching, but that will be explained in AVX-512 section): +//! +//! - Many general-purpose instructions (especially arithmetic ones) on X86 have multiple encodings - in AsmJit +//! this is usually called 'short form' and 'long form'. +//! +//! - AsmJit always tries to use 'short form' as it makes the resulting machine-code smaller, which is always +//! good - this decision is used by majority of assemblers out there. +//! +//! - AsmJit allows to override the default decision by using `short_()` and `long_()` instruction options to force +//! short or long form, respectively. The most useful is `long_()` as it basically forces AsmJit to always emit +//! the longest form. The `short_()` is not that useful as it's automatic (except jumps to non-bound labels). Note +//! that the underscore after each function name avoids collision with built-in C++ types. +//! +//! To illustrate what short form and long form means in binary let's assume we want to emit "add esp, 16" instruction, +//! which has two possible binary encodings: +//! +//! - `83C410` - This is a short form aka `short add esp, 16` - You can see opcode byte (0x8C), MOD/RM byte (0xC4) +//! and an 8-bit immediate value representing `16`. +//! +//! - `81C410000000` - This is a long form aka `long add esp, 16` - You can see a different opcode byte (0x81), the +//! same Mod/RM byte (0xC4) and a 32-bit immediate in little-endian representing `16`. +//! +//! It should be obvious that patching an existing instruction into an instruction having a different size may create +//! various problems. So it's recommended to be careful and to only patch instructions into instructions having the +//! same size. The example below demonstrates how instruction options can be used to guarantee the size of an +//! instruction by forcing the assembler to use long-form encoding: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef int (*Func)(void); +//! +//! int main(int argc, char* argv[]) { +//! JitRuntime rt; // Create a runtime specialized for JIT. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Assembler a(&code); // Create and attach x86::Assembler to code. +//! +//! // Let's get these registers from x86::Assembler. +//! x86::Gp zbp = a.zbp(); +//! x86::Gp zsp = a.zsp(); +//! +//! // Function prolog. +//! a.push(zbp); +//! a.mov(zbp, zsp); +//! +//! // This is where we are gonna patch the code later, so let's get the offset +//! // (the current location) from the beginning of the code-buffer. +//! size_t patchOffset = a.offset(); +//! // Let's just emit 'sub zsp, 0' for now, but don't forget to use LONG form. +//! a.long_().sub(zsp, 0); +//! +//! // ... emit some code (this just sets return value to zero) ... +//! a.xor_(x86::eax, x86::eax); +//! +//! // Function epilog and return. +//! a.mov(zsp, zbp); +//! a.pop(zbp); +//! a.ret(); +//! +//! // Now we know how much stack size we want to reserve. I have chosen 128 +//! // bytes on purpose as it's encodable only in long form that we have used. +//! +//! int stackSize = 128; // Number of bytes to reserve on the stack. +//! a.setOffset(patchOffset); // Move the current cursor to `patchOffset`. +//! a.long_().sub(zsp, stackSize); // Patch the code; don't forget to use LONG form. +//! +//! // Now the code is ready to be called +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! +//! int result = fn(); // Execute the generated code. +//! printf("%d\n", result); // Print the resulting "0". +//! +//! rt.release(fn); // Remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! If you run the example it will just work, because both instructions have the same size. As an experiment you can +//! try removing `long_()` form to see what happens when wrong code is generated. +//! +//! ### Code Patching and REX Prefix +//! +//! In 64-bit mode there is one more thing to worry about when patching code: REX prefix. It's a single byte prefix +//! designed to address registers with ids from 9 to 15 and to override the default width of operation from 32 to 64 +//! bits. AsmJit, like other assemblers, only emits REX prefix when it's necessary. If the patched code only changes +//! the immediate value as shown in the previous example then there is nothing to worry about as it doesn't change +//! the logic behind emitting REX prefix, however, if the patched code changes register id or overrides the operation +//! width then it's important to take care of REX prefix as well. +//! +//! AsmJit contains another instruction option that controls (forces) REX prefix - `rex()`. If you use it the +//! instruction emitted will always use REX prefix even when it's encodable without it. The following list contains +//! some instructions and their binary representations to illustrate when it's emitted: +//! +//! - `__83C410` - `add esp, 16` - 32-bit operation in 64-bit mode doesn't require REX prefix. +//! - `4083C410` - `rex add esp, 16` - 32-bit operation in 64-bit mode with forced REX prefix (0x40). +//! - `4883C410` - `add rsp, 16` - 64-bit operation in 64-bit mode requires REX prefix (0x48). +//! - `4183C410` - `add r12d, 16` - 32-bit operation in 64-bit mode using R12D requires REX prefix (0x41). +//! - `4983C410` - `add r12, 16` - 64-bit operation in 64-bit mode using R12 requires REX prefix (0x49). +//! +//! ### More Prefixes +//! +//! X86 architecture is known for its prefixes. AsmJit supports all prefixes +//! that can affect how the instruction is encoded: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void prefixesExample(x86::Assembler& a) { +//! // Lock prefix for implementing atomics: +//! // lock add dword ptr [rdi], 1 +//! a.lock().add(x86::dword_ptr(x86::rdi), 1); +//! +//! // Similarly, XAcquire/XRelease prefixes are also available: +//! // xacquire add dword ptr [rdi], 1 +//! a.xacquire().add(x86::dword_ptr(x86::rdi), 1); +//! +//! // Rep prefix (see also repe/repz and repne/repnz): +//! // rep movs byte ptr [rdi], byte ptr [rsi] +//! a.rep().movs(x86::byte_ptr(x86::rdi), x86::byte_ptr(x86::rsi)); +//! +//! // Forcing REX prefix in 64-bit mode. +//! // rex mov eax, 1 +//! a.rex().mov(x86::eax, 1); +//! +//! // AVX instruction without forced prefix uses the shortest encoding: +//! // vaddpd xmm0, xmm1, xmm2 -> [C5|F1|58|C2] +//! a.vaddpd(x86::xmm0, x86::xmm1, x86::xmm2); +//! +//! // Forcing VEX3 prefix (AVX): +//! // vex3 vaddpd xmm0, xmm1, xmm2 -> [C4|E1|71|58|C2] +//! a.vex3().vaddpd(x86::xmm0, x86::xmm1, x86::xmm2); +//! +//! // Forcing EVEX prefix (AVX512): +//! // evex vaddpd xmm0, xmm1, xmm2 -> [62|F1|F5|08|58|C2] +//! a.evex().vaddpd(x86::xmm0, x86::xmm1, x86::xmm2); +//! +//! // Some instructions accept prefixes not originally intended to: +//! // rep ret +//! a.rep().ret(); +//! } +//! ``` +//! +//! It's important to understand that prefixes are part of instruction options. When a member function that involves +//! adding a prefix is called the prefix is combined with existing instruction options, which will affect the next +//! instruction generated. +//! +//! ### Generating AVX512 code. +//! +//! x86::Assembler can generate AVX512+ code including the use of opmask registers. Opmask can be specified through +//! \ref x86::Assembler::k() function, which stores it as an extra register, which will be used by the next +//! instruction. AsmJit uses such concept for manipulating instruction options as well. +//! +//! The following AVX512 features are supported: +//! +//! - Opmask selector {k} and zeroing {z}. +//! - Rounding modes {rn|rd|ru|rz} and suppress-all-exceptions {sae} option. +//! - AVX512 broadcasts {1toN}. +//! +//! The following example demonstrates how AVX512 features can be used: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! void generateAVX512Code(x86::Assembler& a) { +//! using namespace x86; +//! +//! // Opmask Selectors +//! // ---------------- +//! // +//! // - Opmask / zeroing is part of the instruction options / extraReg. +//! // - k(reg) is like {kreg} in Intel syntax. +//! // - z() is like {z} in Intel syntax. +//! +//! // vaddpd zmm {k1} {z}, zmm1, zmm2 +//! a.k(k1).z().vaddpd(zmm0, zmm1, zmm2); +//! +//! // Memory Broadcasts +//! // ----------------- +//! // +//! // - Broadcast data is part of memory operand. +//! // - Use x86::Mem::_1to2(), x86::Mem::_1to4(), etc..., which returns a new x86::Mem operand with broadcast. +//! +//! // vaddpd zmm0 {k1} {z}, zmm1, [rcx] {1to8} +//! a.k(k1).z().vaddpd(zmm0, zmm1, x86::ptr(rcx)._1to8()); +//! +//! // Embedded Rounding & Suppress-All-Exceptions +//! // ------------------------------------------- +//! // +//! // - Rounding mode and {sae} are part of instruction options. +//! // - Use sae() to enable exception suppression. +//! // - Use rn_sae(), rd_sae(), ru_sae(), and rz_sae() - to enable rounding. +//! // - Embedded rounding implicitly sets {sae} as well, that's why the API +//! // also has sae() suffix, to make it clear. +//! +//! // vcmppd k1, zmm1, zmm2, 0x00 {sae} +//! a.sae().vcmppd(k1, zmm1, zmm2, 0); +//! +//! // vaddpd zmm0, zmm1, zmm2 {rz} +//! a.rz_sae().vaddpd(zmm0, zmm1, zmm2); +//! } +//! ``` +class ASMJIT_VIRTAPI Assembler + : public BaseAssembler, + public EmitterImplicitT { +public: + ASMJIT_NONCOPYABLE(Assembler) + typedef BaseAssembler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Assembler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Assembler() noexcept override; + + //! \} + + //! \cond INTERNAL + //! \name Internal + //! \{ + + // NOTE: x86::Assembler uses _privateData to store 'address-override' bit that is used to decide whether to emit + // address-override (67H) prefix based on the memory BASE+INDEX registers. It's either `kX86MemInfo_67H_X86` or + // `kX86MemInfo_67H_X64`. + ASMJIT_INLINE_NODEBUG uint32_t _addressOverrideMask() const noexcept { return _privateData; } + ASMJIT_INLINE_NODEBUG void _setAddressOverrideMask(uint32_t m) noexcept { _privateData = m; } + + //! \} + //! \endcond + + //! \cond INTERNAL + //! \name Emit + //! \{ + + ASMJIT_API Error _emit(InstId instId, const Operand_& o0, const Operand_& o1, const Operand_& o2, const Operand_* opExt) override; + + //! \} + //! \endcond + + //! \name Align + //! \{ + + ASMJIT_API Error align(AlignMode alignMode, uint32_t alignment) override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86ASSEMBLER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86builder.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86builder.h new file mode 100644 index 0000000000000000000000000000000000000000..b79abfc1a481fa43629af14c9482151c65be237f --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86builder.h @@ -0,0 +1,354 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86BUILDER_H_INCLUDED +#define ASMJIT_X86_X86BUILDER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_BUILDER + +#include "../core/builder.h" +#include "../x86/x86emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! X86/X64 builder implementation. +//! +//! The code representation used by \ref BaseBuilder is compatible with everything AsmJit provides. Each instruction +//! is stored as \ref InstNode, which contains instruction id, options, and operands. Each instruction emitted will +//! create a new \ref InstNode instance and add it to the current cursor in the double-linked list of nodes. Since +//! the instruction stream used by \ref BaseBuilder can be manipulated, we can rewrite the SumInts example from +//! \ref asmjit_assembler into the following: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! typedef void (*SumIntsFunc)(int* dst, const int* a, const int* b); +//! +//! // Small helper function to print the current content of `cb`. +//! static void dumpCode(BaseBuilder& builder, const char* phase) { +//! String sb; +//! formatOptions formatOptions {}; +//! +//! Formatter::formatNodeList(sb, formatOptions, &builder); +//! printf("%s:\n%s\n", phase, sb.data()); +//! } +//! +//! int main() { +//! JitRuntime rt; // Create JIT Runtime. +//! CodeHolder code; // Create a CodeHolder. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Builder cb(&code); // Create and attach x86::Builder to `code`. +//! +//! // Decide which registers will be mapped to function arguments. Try changing registers +//! // of `dst`, `srcA`, and `srcB` and see what happens in function's prolog and epilog. +//! x86::Gp dst = cb.zax(); +//! x86::Gp srcA = cb.zcx(); +//! x86::Gp srcB = cb.zdx(); +//! +//! X86::Xmm vec0 = x86::xmm0; +//! X86::Xmm vec1 = x86::xmm1; +//! +//! // Create and initialize `FuncDetail`. +//! FuncDetail func; +//! func.init(FuncSignature::build()); +//! +//! // Remember prolog insertion point. +//! BaseNode* prologInsertionPoint = cb.cursor(); +//! +//! // Emit function body: +//! cb.movdqu(vec0, x86::ptr(srcA)); // Load 4 ints from [srcA] to XMM0. +//! cb.movdqu(vec1, x86::ptr(srcB)); // Load 4 ints from [srcB] to XMM1. +//! cb.paddd(vec0, vec1); // Add 4 ints in XMM1 to XMM0. +//! cb.movdqu(x86::ptr(dst), vec0); // Store the result to [dst]. +//! +//! // Remember epilog insertion point. +//! BaseNode* epilogInsertionPoint = cb.cursor(); +//! +//! // Let's see what we have now. +//! dumpCode(cb, "Raw Function"); +//! +//! // Now, after we emitted the function body, we can insert the prolog, arguments +//! // allocation, and epilog. This is not possible with using pure x86::Assembler. +//! FuncFrame frame; +//! frame.init(func); +//! +//! // Make XMM0 and XMM1 dirty; RegGroup::kVec describes XMM|YMM|ZMM registers. +//! frame.setDirtyRegs(RegGroup::kVec, IntUtils::mask(0, 1)); +//! +//! FuncArgsAssignment args(&func); // Create arguments assignment context. +//! args.assignAll(dst, srcA, srcB); // Assign our registers to arguments. +//! args.updateFrame(frame); // Reflect our args in FuncFrame. +//! frame.finalize(); // Finalize the FuncFrame (updates it). +//! +//! // Insert function prolog and allocate arguments to registers. +//! cb.setCursor(prologInsertionPoint); +//! cb.emitProlog(frame); +//! cb.emitArgsAssignment(frame, args); +//! +//! // Insert function epilog. +//! cb.setCursor(epilogInsertionPoint); +//! cb.emitEpilog(frame); +//! +//! // Let's see how the function's prolog and epilog looks. +//! dumpCode(cb, "Prolog & Epilog"); +//! +//! // IMPORTANT: Builder requires finalize() to be called to serialize its +//! // code to the Assembler (it automatically creates one if not attached). +//! cb.finalize(); +//! +//! SumIntsFunc fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error case. +//! +//! // Execute the generated function. +//! int inA[4] = { 4, 3, 2, 1 }; +//! int inB[4] = { 1, 5, 2, 8 }; +//! int out[4]; +//! fn(out, inA, inB); +//! +//! // Prints {5 8 4 9} +//! printf("{%d %d %d %d}\n", out[0], out[1], out[2], out[3]); +//! +//! rt.release(fn); // Explicitly remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! When the example is executed it should output the following (this one using AMD64-SystemV ABI): +//! +//! ``` +//! Raw Function: +//! movdqu xmm0, [rcx] +//! movdqu xmm1, [rdx] +//! paddd xmm0, xmm1 +//! movdqu [rax], xmm0 +//! +//! Prolog & Epilog: +//! mov rax, rdi +//! mov rcx, rsi +//! movdqu xmm0, [rcx] +//! movdqu xmm1, [rdx] +//! paddd xmm0, xmm1 +//! movdqu [rax], xmm0 +//! ret +//! +//! {5 8 4 9} +//! ``` +//! +//! The number of use-cases of \ref BaseBuilder is not limited and highly depends on your creativity and experience. +//! The previous example can be easily improved to collect all dirty registers inside the function programmatically +//! and to pass them to \ref FuncFrame::setDirtyRegs(). +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! // NOTE: This function doesn't cover all possible constructs. It ignores instructions that write +//! // to implicit registers that are not part of the operand list. It also counts read-only registers. +//! // Real implementation would be a bit more complicated, but still relatively easy to implement. +//! static void collectDirtyRegs(const BaseNode* first, +//! const BaseNode* last, +//! Support::Array& regMask) { +//! const BaseNode* node = first; +//! while (node) { +//! if (node->actsAsInst()) { +//! const InstNode* inst = node->as(); +//! const Operand* opArray = inst->operands(); +//! +//! for (uint32_t i = 0, opCount = inst->opCount(); i < opCount; i++) { +//! const Operand& op = opArray[i]; +//! if (op.isReg()) { +//! const x86::Reg& reg = op.as(); +//! if (reg.group() <= RegGroup::kMaxVirt) { +//! regMask[reg.group()] |= 1u << reg.id(); +//! } +//! } +//! } +//! } +//! +//! if (node == last) +//! break; +//! node = node->next(); +//! } +//! } +//! +//! static void setDirtyRegsOfFuncFrame(const x86::Builder& builder, FuncFrame& frame) { +//! Support::Array regMask {}; +//! collectDirtyRegs(builder.firstNode(), builder.lastNode(), regMask); +//! +//! // X86/X64 ABIs only require to save GP/XMM registers: +//! frame.setDirtyRegs(RegGroup::kGp, regMask[RegGroup::kGp]); +//! frame.setDirtyRegs(RegGroup::kVec, regMask[RegGroup::kVec]); +//! } +//! ``` +//! +//! ### Casting Between Various Emitters +//! +//! Even when \ref BaseAssembler and \ref BaseBuilder provide the same interface as defined by \ref BaseEmitter their +//! platform dependent variants like \ref x86::Assembler and \ref x86::Builder cannot be interchanged or casted to each +//! other by using a C++ `static_cast<>`. The main reason is the inheritance graph of these classes is different and +//! cast-incompatible, as illustrated below: +//! +//! ``` +//! +--------------+ +=========================+ +//! +----------------------->| x86::Emitter |<--+--# x86::EmitterImplicitT<> #<--+ +//! | +--------------+ | +=========================+ | +//! | (abstract) | (mixin) | +//! | +--------------+ +~~~~~~~~~~~~~~+ | | +//! +-->| BaseAssembler|---->|x86::Assembler|<--+ | +//! | +--------------+ +~~~~~~~~~~~~~~+ | | +//! | (abstract) (final) | | +//! +===============+ | +--------------+ +~~~~~~~~~~~~~~+ | | +//! # BaseEmitter #--+-->| BaseBuilder |--+->| x86::Builder |<--+ | +//! +===============+ +--------------+ | +~~~~~~~~~~~~~~+ | +//! (abstract) (abstract) | (final) | +//! +---------------------+ | +//! | | +//! | +--------------+ +~~~~~~~~~~~~~~+ +=========================+ | +//! +-->| BaseCompiler |---->| x86::Compiler|<-----# x86::EmitterExplicitT<> #---+ +//! +--------------+ +~~~~~~~~~~~~~~+ +=========================+ +//! (abstract) (final) (mixin) +//! ``` +//! +//! The graph basically shows that it's not possible to cast between \ref x86::Assembler and \ref x86::Builder. +//! However, since both share the base interface (\ref BaseEmitter) it's possible to cast them to a class that +//! cannot be instantiated, but defines the same interface - the class is called \ref x86::Emitter and was +//! introduced to make it possible to write a function that can emit to both \ref x86::Assembler and \ref +//! x86::Builder. Note that \ref x86::Emitter cannot be created, it's abstract and has private constructors and +//! destructors; it was only designed to be casted to and used as an interface. +//! +//! Each architecture-specific emitter implements a member function called +//! `as()`, which casts the instance to the architecture +//! specific emitter as illustrated below: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! static void emitSomething(x86::Emitter* e) { +//! e->mov(x86::eax, x86::ebx); +//! } +//! +//! static void assemble(CodeHolder& code, bool useAsm) { +//! if (useAsm) { +//! x86::Assembler assembler(&code); +//! emitSomething(assembler.as()); +//! } +//! else { +//! x86::Builder builder(&code); +//! emitSomething(builder.as()); +//! +//! // NOTE: Builder requires `finalize()` to be called to serialize its +//! // content to Assembler (it automatically creates one if not attached). +//! builder.finalize(); +//! } +//! } +//! ``` +//! +//! The example above shows how to create a function that can emit code to either \ref x86::Assembler or \ref +//! x86::Builder through \ref x86::Emitter, which provides emitter-neutral functionality. \ref x86::Emitter, +//! however, doesn't provide any emitter-specific functionality like `setCursor()`. +//! +//! ### Code Injection and Manipulation +//! +//! \ref BaseBuilder emitter stores its nodes in a double-linked list, which makes it easy to manipulate that +//! list during the code generation or afterwards. Each node is always emitted next to the current cursor and +//! the cursor is advanced to that newly emitted node. The cursor can be retrieved and changed by \ref +//! BaseBuilder::cursor() and \ref BaseBuilder::setCursor(), respectively. +//! +//! The example below demonstrates how to remember a node and inject something +//! next to it. +//! +//! ``` +//! static void example(x86::Builder& builder) { +//! // Emit something, after it returns the cursor would point at the last +//! // emitted node. +//! builder.mov(x86::rax, x86::rdx); // [1] +//! +//! // We can retrieve the node. +//! BaseNode* node = builder.cursor(); +//! +//! // Change the instruction we just emitted, just for fun... +//! if (node->isInst()) { +//! InstNode* inst = node->as(); +//! // Changes the operands at index [1] to RCX. +//! inst->setOp(1, x86::rcx); +//! } +//! +//! // ------------------------- Generate Some Code ------------------------- +//! builder.add(x86::rax, x86::rdx); // [2] +//! builder.shr(x86::rax, 3); // [3] +//! // ---------------------------------------------------------------------- +//! +//! // Now, we know where our node is, and we can simply change the cursor +//! // and start emitting something after it. The setCursor() function +//! // returns the previous cursor, and it's always a good practice to remember +//! // it, because you never know if you are not already injecting the code +//! // somewhere else... +//! BaseNode* oldCursor = builder.setCursor(node); +//! +//! builder.mul(x86::rax, 8); // [4] +//! +//! // Restore the cursor +//! builder.setCursor(oldCursor); +//! } +//! ``` +//! +//! The function above would actually emit the following: +//! +//! ``` +//! mov rax, rcx ; [1] Patched at the beginning. +//! mul rax, 8 ; [4] Injected. +//! add rax, rdx ; [2] Followed [1] initially. +//! shr rax, 3 ; [3] Follows [2]. +//! ``` +class ASMJIT_VIRTAPI Builder + : public BaseBuilder, + public EmitterImplicitT { +public: + ASMJIT_NONCOPYABLE(Builder) + typedef BaseBuilder Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Builder(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Builder() noexcept override; + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_BUILDER +#endif // ASMJIT_X86_X86BUILDER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86compiler.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86compiler.h new file mode 100644 index 0000000000000000000000000000000000000000..a80993c5d531131e79672ad67b37d49fe1076484 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86compiler.h @@ -0,0 +1,726 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86COMPILER_H_INCLUDED +#define ASMJIT_X86_X86COMPILER_H_INCLUDED + +#include "../core/api-config.h" +#ifndef ASMJIT_NO_COMPILER + +#include "../core/compiler.h" +#include "../core/type.h" +#include "../x86/x86emitter.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! X86/X64 compiler implementation. +//! +//! ### Compiler Basics +//! +//! The first \ref x86::Compiler example shows how to generate a function that simply returns an integer value. It's +//! an analogy to the first Assembler example: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef int (*Func)(void); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! cc.addFunc(FuncSignature::build()); // Begin a function of `int fn(void)` signature. +//! +//! x86::Gp vReg = cc.newGpd(); // Create a 32-bit general purpose register. +//! cc.mov(vReg, 1); // Move one to our virtual register `vReg`. +//! cc.ret(vReg); // Return `vReg` from the function. +//! +//! cc.endFunc(); // End of the function body. +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! int result = fn(); // Execute the generated code. +//! printf("%d\n", result); // Print the resulting "1". +//! +//! rt.release(fn); // Explicitly remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! The \ref BaseCompiler::addFunc() and \ref BaseCompiler::endFunc() functions are used to define the function and +//! its end. Both must be called per function, but the body doesn't have to be generated in sequence. An example of +//! generating two functions will be shown later. The next example shows more complicated code that contain a loop +//! and generates a simple memory copy function that uses `uint32_t` items: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef void (*MemCpy32)(uint32_t* dst, const uint32_t* src, size_t count); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! FuncNode* funcNode = cc.addFunc ( // Begin the function of the following signature: +//! FuncSignature::build()); // 3rd argument - size_t (machine reg-size). +//! +//! Label L_Loop = cc.newLabel(); // Start of the loop. +//! Label L_Exit = cc.newLabel(); // Used to exit early. +//! +//! x86::Gp dst = cc.newIntPtr("dst"); // Create `dst` register (destination pointer). +//! x86::Gp src = cc.newIntPtr("src"); // Create `src` register (source pointer). +//! x86::Gp i = cc.newUIntPtr("i"); // Create `i` register (loop counter). +//! +//! funcNode->setArg(0, dst); // Assign `dst` argument. +//! funcNode->setArg(1, src); // Assign `src` argument. +//! funcNode->setArg(2, i); // Assign `i` argument. +//! +//! cc.test(i, i); // Early exit if length is zero. +//! cc.jz(L_Exit); +//! +//! cc.bind(L_Loop); // Bind the beginning of the loop here. +//! +//! x86::Gp tmp = cc.newInt32("tmp"); // Copy a single dword (4 bytes). +//! cc.mov(tmp, x86::dword_ptr(src)); // Load DWORD from [src] address. +//! cc.mov(x86::dword_ptr(dst), tmp); // Store DWORD to [dst] address. +//! +//! cc.add(src, 4); // Increment `src`. +//! cc.add(dst, 4); // Increment `dst`. +//! +//! cc.dec(i); // Loop until `i` is non-zero. +//! cc.jnz(L_Loop); +//! +//! cc.bind(L_Exit); // Label used by early exit. +//! cc.endFunc(); // End of the function body. +//! +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! // Add the generated code to the runtime. +//! MemCpy32 memcpy32; +//! Error err = rt.add(&memcpy32, &code); +//! +//! // Handle a possible error returned by AsmJit. +//! if (err) +//! return 1; +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! // Test the generated code. +//! uint32_t input[6] = { 1, 2, 3, 5, 8, 13 }; +//! uint32_t output[6]; +//! memcpy32(output, input, 6); +//! +//! for (uint32_t i = 0; i < 6; i++) +//! printf("%d\n", output[i]); +//! +//! rt.release(memcpy32); +//! return 0; +//! } +//! ``` +//! +//! ### AVX and AVX-512 +//! +//! AVX and AVX-512 code generation must be explicitly enabled via \ref FuncFrame to work properly. If it's not setup +//! correctly then Prolog & Epilog would use SSE instead of AVX instructions to work with SIMD registers. In addition, +//! Compiler requires explicitly enable AVX-512 via \ref FuncFrame in order to use all 32 SIMD registers. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef void (*Func)(void*); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! FuncNode* funcNode = cc.addFunc(FuncSignature::build()); +//! +//! // Use the following to enable AVX and/or AVX-512. +//! funcNode->frame().setAvxEnabled(); +//! funcNode->frame().setAvx512Enabled(); +//! +//! // Do something with the input pointer. +//! x86::Gp addr = cc.newIntPtr("addr"); +//! x86::Zmm vreg = cc.newZmm("vreg"); +//! +//! funcNode->setArg(0, addr); +//! +//! cc.vmovdqu32(vreg, x86::ptr(addr)); +//! cc.vpaddq(vreg, vreg, vreg); +//! cc.vmovdqu32(x86::ptr(addr), vreg); +//! +//! cc.endFunc(); // End of the function body. +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Func fn; +//! Error err = rt.add(&fn, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! // Execute the generated code and print some output. +//! uint64_t data[] = { 1, 2, 3, 4, 5, 6, 7, 8 }; +//! fn(data); +//! printf("%llu\n", (unsigned long long)data[0]); +//! +//! rt.release(fn); // Explicitly remove the function from the runtime. +//! return 0; +//! } +//! ``` +//! +//! ### Recursive Functions +//! +//! It's possible to create more functions by using the same \ref x86::Compiler instance and make links between them. +//! In such case it's important to keep the pointer to \ref FuncNode. +//! +//! The example below creates a simple Fibonacci function that calls itself recursively: +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef uint32_t (*Fibonacci)(uint32_t x); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! FuncNode* funcNode = cc.addFunc( // Begin of the Fibonacci function, addFunc() +//! FuncSignature::build()); // Returns a pointer to the FuncNode node. +//! +//! Label L_Exit = cc.newLabel(); // Exit label. +//! x86::Gp x = cc.newUInt32(); // Function x argument. +//! x86::Gp y = cc.newUInt32(); // Temporary. +//! +//! funcNode->setArg(0, x); +//! +//! cc.cmp(x, 3); // Return x if less than 3. +//! cc.jb(L_Exit); +//! +//! cc.mov(y, x); // Make copy of the original x. +//! cc.dec(x); // Decrease x. +//! +//! InvokeNode* invokeNode; // Function invocation: +//! cc.invoke(&invokeNode, // - InvokeNode (output). +//! funcNode->label(), // - Function address or Label. +//! FuncSignature::build()); // - Function signature. +//! +//! invokeNode->setArg(0, x); // Assign x as the first argument. +//! invokeNode->setRet(0, x); // Assign x as a return value as well. +//! +//! cc.add(x, y); // Combine the return value with y. +//! +//! cc.bind(L_Exit); +//! cc.ret(x); // Return x. +//! cc.endFunc(); // End of the function body. +//! +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Fibonacci fib; +//! Error err = rt.add(&fib, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! // Test the generated code. +//! printf("Fib(%u) -> %u\n", 8, fib(8)); +//! +//! rt.release(fib); +//! return 0; +//! } +//! ``` +//! +//! ### Stack Management +//! +//! Function's stack-frame is managed automatically, which is used by the register allocator to spill virtual +//! registers. It also provides an interface to allocate user-defined block of the stack, which can be used as +//! a temporary storage by the generated function. In the following example a stack of 256 bytes size is allocated, +//! filled by bytes starting from 0 to 255 and then iterated again to sum all the values. +//! +//! ``` +//! #include +//! #include +//! +//! using namespace asmjit; +//! +//! // Signature of the generated function. +//! typedef int (*Func)(void); +//! +//! int main() { +//! JitRuntime rt; // Runtime specialized for JIT code execution. +//! CodeHolder code; // Holds code and relocation information. +//! +//! code.init(rt.environment(), // Initialize code to match the JIT environment. +//! rt.cpuFeatures()); +//! x86::Compiler cc(&code); // Create and attach x86::Compiler to code. +//! +//! cc.addFunc(FuncSignature::build()); // Create a function that returns int. +//! +//! x86::Gp p = cc.newIntPtr("p"); +//! x86::Gp i = cc.newIntPtr("i"); +//! +//! // Allocate 256 bytes on the stack aligned to 4 bytes. +//! x86::Mem stack = cc.newStack(256, 4); +//! +//! x86::Mem stackIdx(stack); // Copy of stack with i added. +//! stackIdx.setIndex(i); // stackIdx <- stack[i]. +//! stackIdx.setSize(1); // stackIdx <- byte ptr stack[i]. +//! +//! // Load a stack address to `p`. This step is purely optional and shows +//! // that `lea` is useful to load a memory operands address (even absolute) +//! // to a general purpose register. +//! cc.lea(p, stack); +//! +//! // Clear i (xor is a C++ keyword, hence 'xor_' is used instead). +//! cc.xor_(i, i); +//! +//! Label L1 = cc.newLabel(); +//! Label L2 = cc.newLabel(); +//! +//! cc.bind(L1); // First loop, fill the stack. +//! cc.mov(stackIdx, i.r8()); // stack[i] = uint8_t(i). +//! +//! cc.inc(i); // i++; +//! cc.cmp(i, 256); // if (i < 256) +//! cc.jb(L1); // goto L1; +//! +//! // Second loop, sum all bytes stored in `stack`. +//! x86::Gp sum = cc.newInt32("sum"); +//! x86::Gp val = cc.newInt32("val"); +//! +//! cc.xor_(i, i); +//! cc.xor_(sum, sum); +//! +//! cc.bind(L2); +//! +//! cc.movzx(val, stackIdx); // val = uint32_t(stack[i]); +//! cc.add(sum, val); // sum += val; +//! +//! cc.inc(i); // i++; +//! cc.cmp(i, 256); // if (i < 256) +//! cc.jb(L2); // goto L2; +//! +//! cc.ret(sum); // Return the `sum` of all values. +//! cc.endFunc(); // End of the function body. +//! +//! cc.finalize(); // Translate and assemble the whole 'cc' content. +//! // ----> x86::Compiler is no longer needed from here and can be destroyed <---- +//! +//! Func func; +//! Error err = rt.add(&func, &code); // Add the generated code to the runtime. +//! if (err) return 1; // Handle a possible error returned by AsmJit. +//! // ----> CodeHolder is no longer needed from here and can be destroyed <---- +//! +//! printf("Func() -> %d\n", func()); // Test the generated code. +//! +//! rt.release(func); +//! return 0; +//! } +//! ``` +//! +//! ### Constant Pool +//! +//! Compiler provides two constant pools for a general purpose code generation: +//! +//! - Local constant pool - Part of \ref FuncNode, can be only used by a single function and added after the +//! function epilog sequence (after `ret` instruction). +//! +//! - Global constant pool - Part of \ref BaseCompiler, flushed at the end of the generated code by \ref +//! BaseEmitter::finalize(). +//! +//! The example below illustrates how a built-in constant pool can be used: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! static void exampleUseOfConstPool(x86::Compiler& cc) { +//! cc.addFunc(FuncSignature::build()); +//! +//! x86::Gp v0 = cc.newGpd("v0"); +//! x86::Gp v1 = cc.newGpd("v1"); +//! +//! x86::Mem c0 = cc.newInt32Const(ConstPoolScope::kLocal, 200); +//! x86::Mem c1 = cc.newInt32Const(ConstPoolScope::kLocal, 33); +//! +//! cc.mov(v0, c0); +//! cc.mov(v1, c1); +//! cc.add(v0, v1); +//! +//! cc.ret(v0); +//! cc.endFunc(); +//! } +//! ``` +//! +//! ### Jump Tables +//! +//! x86::Compiler supports `jmp` instruction with reg/mem operand, which is a commonly used pattern to implement +//! indirect jumps within a function, for example to implement `switch()` statement in a programming languages. +//! By default AsmJit assumes that every basic block can be a possible jump target as it's unable to deduce targets +//! from instruction's operands. This is a very pessimistic default that should be avoided if possible as it's costly +//! and very unfriendly to liveness analysis and register allocation. +//! +//! Instead of relying on such pessimistic default behavior, let's use \ref JumpAnnotation to annotate a jump where +//! all targets are known: +//! +//! ``` +//! #include +//! +//! using namespace asmjit; +//! +//! static void exampleUseOfIndirectJump(x86::Compiler& cc) { +//! FuncNode* funcNode = cc.addFunc(FuncSignature::build()); +//! +//! // Function arguments +//! x86::Xmm a = cc.newXmmSs("a"); +//! x86::Xmm b = cc.newXmmSs("b"); +//! x86::Gp op = cc.newUInt32("op"); +//! +//! x86::Gp target = cc.newIntPtr("target"); +//! x86::Gp offset = cc.newIntPtr("offset"); +//! +//! Label L_Table = cc.newLabel(); +//! Label L_Add = cc.newLabel(); +//! Label L_Sub = cc.newLabel(); +//! Label L_Mul = cc.newLabel(); +//! Label L_Div = cc.newLabel(); +//! Label L_End = cc.newLabel(); +//! +//! funcNode->setArg(0, a); +//! funcNode->setArg(1, b); +//! funcNode->setArg(2, op); +//! +//! // Jump annotation is a building block that allows to annotate all possible targets where `jmp()` can +//! // jump. It then drives the CFG construction and liveness analysis, which impacts register allocation. +//! JumpAnnotation* annotation = cc.newJumpAnnotation(); +//! annotation->addLabel(L_Add); +//! annotation->addLabel(L_Sub); +//! annotation->addLabel(L_Mul); +//! annotation->addLabel(L_Div); +//! +//! // Most likely not the common indirect jump approach, but it +//! // doesn't really matter how final address is calculated. The +//! // most important path using JumpAnnotation with `jmp()`. +//! cc.lea(offset, x86::ptr(L_Table)); +//! if (cc.is64Bit()) +//! cc.movsxd(target, x86::dword_ptr(offset, op.cloneAs(offset), 2)); +//! else +//! cc.mov(target, x86::dword_ptr(offset, op.cloneAs(offset), 2)); +//! cc.add(target, offset); +//! cc.jmp(target, annotation); +//! +//! // Acts like a switch() statement in C. +//! cc.bind(L_Add); +//! cc.addss(a, b); +//! cc.jmp(L_End); +//! +//! cc.bind(L_Sub); +//! cc.subss(a, b); +//! cc.jmp(L_End); +//! +//! cc.bind(L_Mul); +//! cc.mulss(a, b); +//! cc.jmp(L_End); +//! +//! cc.bind(L_Div); +//! cc.divss(a, b); +//! +//! cc.bind(L_End); +//! cc.ret(a); +//! +//! cc.endFunc(); +//! +//! // Relative int32_t offsets of `L_XXX - L_Table`. +//! cc.bind(L_Table); +//! cc.embedLabelDelta(L_Add, L_Table, 4); +//! cc.embedLabelDelta(L_Sub, L_Table, 4); +//! cc.embedLabelDelta(L_Mul, L_Table, 4); +//! cc.embedLabelDelta(L_Div, L_Table, 4); +//! } +//! ``` +class ASMJIT_VIRTAPI Compiler + : public BaseCompiler, + public EmitterExplicitT { +public: + ASMJIT_NONCOPYABLE(Compiler) + typedef BaseCompiler Base; + + //! \name Construction & Destruction + //! \{ + + ASMJIT_API explicit Compiler(CodeHolder* code = nullptr) noexcept; + ASMJIT_API ~Compiler() noexcept override; + + //! \} + + //! \name Virtual Registers + //! \{ + +#ifndef ASMJIT_NO_LOGGING +# define ASMJIT_NEW_REG_FMT(OUT, PARAM, FORMAT, ARGS) \ + _newRegFmt(&OUT, PARAM, FORMAT, ARGS) +#else +# define ASMJIT_NEW_REG_FMT(OUT, PARAM, FORMAT, ARGS) \ + DebugUtils::unused(FORMAT); \ + DebugUtils::unused(std::forward(args)...); \ + _newReg(&OUT, PARAM) +#endif + +#define ASMJIT_NEW_REG_CUSTOM(FUNC, REG) \ + ASMJIT_INLINE_NODEBUG REG FUNC(TypeId typeId) { \ + REG reg(Globals::NoInit); \ + _newReg(®, typeId); \ + return reg; \ + } \ + \ + template \ + ASMJIT_INLINE_NODEBUG REG FUNC(TypeId typeId, const char* fmt, Args&&... args) { \ + REG reg(Globals::NoInit); \ + ASMJIT_NEW_REG_FMT(reg, typeId, fmt, std::forward(args)...); \ + return reg; \ + } + +#define ASMJIT_NEW_REG_TYPED(FUNC, REG, TYPE_ID) \ + ASMJIT_INLINE_NODEBUG REG FUNC() { \ + REG reg(Globals::NoInit); \ + _newReg(®, TYPE_ID); \ + return reg; \ + } \ + \ + template \ + ASMJIT_INLINE_NODEBUG REG FUNC(const char* fmt, Args&&... args) { \ + REG reg(Globals::NoInit); \ + ASMJIT_NEW_REG_FMT(reg, TYPE_ID, fmt, std::forward(args)...); \ + return reg; \ + } + + template + ASMJIT_INLINE_NODEBUG RegT newSimilarReg(const RegT& ref) { + RegT reg(Globals::NoInit); + _newReg(®, ref); + return reg; + } + + template + ASMJIT_INLINE_NODEBUG RegT newSimilarReg(const RegT& ref, const char* fmt, Args&&... args) { + RegT reg(Globals::NoInit); + ASMJIT_NEW_REG_FMT(reg, ref, fmt, std::forward(args)...); + return reg; + } + + ASMJIT_NEW_REG_CUSTOM(newReg , Reg ) + ASMJIT_NEW_REG_CUSTOM(newGp , Gp ) + ASMJIT_NEW_REG_CUSTOM(newVec , Vec ) + ASMJIT_NEW_REG_CUSTOM(newK , KReg) + + ASMJIT_NEW_REG_TYPED(newInt8 , Gp , TypeId::kInt8) + ASMJIT_NEW_REG_TYPED(newUInt8 , Gp , TypeId::kUInt8) + ASMJIT_NEW_REG_TYPED(newInt16 , Gp , TypeId::kInt16) + ASMJIT_NEW_REG_TYPED(newUInt16 , Gp , TypeId::kUInt16) + ASMJIT_NEW_REG_TYPED(newInt32 , Gp , TypeId::kInt32) + ASMJIT_NEW_REG_TYPED(newUInt32 , Gp , TypeId::kUInt32) + ASMJIT_NEW_REG_TYPED(newInt64 , Gp , TypeId::kInt64) + ASMJIT_NEW_REG_TYPED(newUInt64 , Gp , TypeId::kUInt64) + ASMJIT_NEW_REG_TYPED(newIntPtr , Gp , TypeId::kIntPtr) + ASMJIT_NEW_REG_TYPED(newUIntPtr, Gp , TypeId::kUIntPtr) + + ASMJIT_NEW_REG_TYPED(newGpb , Gp , TypeId::kUInt8) + ASMJIT_NEW_REG_TYPED(newGpw , Gp , TypeId::kUInt16) + ASMJIT_NEW_REG_TYPED(newGpd , Gp , TypeId::kUInt32) + ASMJIT_NEW_REG_TYPED(newGpq , Gp , TypeId::kUInt64) + ASMJIT_NEW_REG_TYPED(newGpz , Gp , TypeId::kUIntPtr) + ASMJIT_NEW_REG_TYPED(newXmm , Xmm , TypeId::kInt32x4) + ASMJIT_NEW_REG_TYPED(newXmmSs , Xmm , TypeId::kFloat32x1) + ASMJIT_NEW_REG_TYPED(newXmmSd , Xmm , TypeId::kFloat64x1) + ASMJIT_NEW_REG_TYPED(newXmmPs , Xmm , TypeId::kFloat32x4) + ASMJIT_NEW_REG_TYPED(newXmmPd , Xmm , TypeId::kFloat64x2) + ASMJIT_NEW_REG_TYPED(newYmm , Ymm , TypeId::kInt32x8) + ASMJIT_NEW_REG_TYPED(newYmmPs , Ymm , TypeId::kFloat32x8) + ASMJIT_NEW_REG_TYPED(newYmmPd , Ymm , TypeId::kFloat64x4) + ASMJIT_NEW_REG_TYPED(newZmm , Zmm , TypeId::kInt32x16) + ASMJIT_NEW_REG_TYPED(newZmmPs , Zmm , TypeId::kFloat32x16) + ASMJIT_NEW_REG_TYPED(newZmmPd , Zmm , TypeId::kFloat64x8) + ASMJIT_NEW_REG_TYPED(newMm , Mm , TypeId::kMmx64) + ASMJIT_NEW_REG_TYPED(newKb , KReg, TypeId::kMask8) + ASMJIT_NEW_REG_TYPED(newKw , KReg, TypeId::kMask16) + ASMJIT_NEW_REG_TYPED(newKd , KReg, TypeId::kMask32) + ASMJIT_NEW_REG_TYPED(newKq , KReg, TypeId::kMask64) + +#undef ASMJIT_NEW_REG_TYPED +#undef ASMJIT_NEW_REG_CUSTOM +#undef ASMJIT_NEW_REG_FMT + + //! \} + + //! \name Stack + //! \{ + + //! Creates a new memory chunk allocated on the current function's stack. + ASMJIT_INLINE_NODEBUG Mem newStack(uint32_t size, uint32_t alignment, const char* name = nullptr) { + Mem m(Globals::NoInit); + _newStack(&m, size, alignment, name); + return m; + } + + //! \} + + //! \name Constants + //! \{ + + //! Put data to a constant-pool and get a memory reference to it. + ASMJIT_INLINE_NODEBUG Mem newConst(ConstPoolScope scope, const void* data, size_t size) { + Mem m(Globals::NoInit); + _newConst(&m, scope, data, size); + return m; + } + + //! Put a BYTE `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newByteConst(ConstPoolScope scope, uint8_t val) noexcept { return newConst(scope, &val, 1); } + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newWordConst(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newDWordConst(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newQWordConst(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt16Const(ConstPoolScope scope, int16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a WORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt16Const(ConstPoolScope scope, uint16_t val) noexcept { return newConst(scope, &val, 2); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt32Const(ConstPoolScope scope, int32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a DWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt32Const(ConstPoolScope scope, uint32_t val) noexcept { return newConst(scope, &val, 4); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newInt64Const(ConstPoolScope scope, int64_t val) noexcept { return newConst(scope, &val, 8); } + //! Put a QWORD `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newUInt64Const(ConstPoolScope scope, uint64_t val) noexcept { return newConst(scope, &val, 8); } + + //! Put a SP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newFloatConst(ConstPoolScope scope, float val) noexcept { return newConst(scope, &val, 4); } + //! Put a DP-FP `val` to a constant-pool. + ASMJIT_INLINE_NODEBUG Mem newDoubleConst(ConstPoolScope scope, double val) noexcept { return newConst(scope, &val, 8); } + + //! \} + + //! \name Instruction Options + //! \{ + + //! Force the compiler to not follow the conditional or unconditional jump. + ASMJIT_INLINE_NODEBUG Compiler& unfollow() noexcept { addInstOptions(InstOptions::kUnfollow); return *this; } + //! Tell the compiler that the destination variable will be overwritten. + ASMJIT_INLINE_NODEBUG Compiler& overwrite() noexcept { addInstOptions(InstOptions::kOverwrite); return *this; } + + //! \} + + //! \name Function Call & Ret Intrinsics + //! \{ + + //! Invoke a function call without `target` type enforcement. + ASMJIT_INLINE_NODEBUG Error invoke_(InvokeNode** out, const Operand_& target, const FuncSignature& signature) { + return addInvokeNode(out, Inst::kIdCall, target, signature); + } + + //! Invoke a function call of the given `target` and `signature` and store the added node to `out`. + //! + //! Creates a new \ref InvokeNode, initializes all the necessary members to match the given function `signature`, + //! adds the node to the compiler, and stores its pointer to `out`. The operation is atomic, if anything fails + //! nullptr is stored in `out` and error code is returned. + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Gp& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Mem& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Label& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, const Imm& target, const FuncSignature& signature) { return invoke_(out, target, signature); } + //! \overload + ASMJIT_INLINE_NODEBUG Error invoke(InvokeNode** out, uint64_t target, const FuncSignature& signature) { return invoke_(out, Imm(int64_t(target)), signature); } + + //! Return from function. + ASMJIT_INLINE_NODEBUG Error ret() { return addRet(Operand(), Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0) { return addRet(o0, Operand()); } + //! \overload + ASMJIT_INLINE_NODEBUG Error ret(const BaseReg& o0, const BaseReg& o1) { return addRet(o0, o1); } + + //! \} + + //! \name Jump Tables Support + //! \{ + + using EmitterExplicitT::jmp; + + //! Adds a jump to the given `target` with the provided jump `annotation`. + ASMJIT_INLINE_NODEBUG Error jmp(const BaseReg& target, JumpAnnotation* annotation) { return emitAnnotatedJump(Inst::kIdJmp, target, annotation); } + //! \overload + ASMJIT_INLINE_NODEBUG Error jmp(const BaseMem& target, JumpAnnotation* annotation) { return emitAnnotatedJump(Inst::kIdJmp, target, annotation); } + + //! \} + + //! \name Events + //! \{ + + ASMJIT_API Error onAttach(CodeHolder* code) noexcept override; + ASMJIT_API Error onDetach(CodeHolder* code) noexcept override; + + //! \} + + //! \name Finalize + //! \{ + + ASMJIT_API Error finalize() override; + + //! \} +}; + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // !ASMJIT_NO_COMPILER +#endif // ASMJIT_X86_X86COMPILER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86emitter.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86emitter.h new file mode 100644 index 0000000000000000000000000000000000000000..8722406170cb9a3dd1a8b7e62603ca407d67c272 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86emitter.h @@ -0,0 +1,4493 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86EMITTER_H_INCLUDED +#define ASMJIT_X86_X86EMITTER_H_INCLUDED + +#include "../core/emitter.h" +#include "../core/support.h" +#include "../x86/x86globals.h" +#include "../x86/x86operand.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +#define ASMJIT_INST_0x(NAME, ID) \ + inline Error NAME() { return _emitter()->_emitI(Inst::kId##ID); } + +#define ASMJIT_INST_1x(NAME, ID, T0) \ + inline Error NAME(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID, o0); } + +#define ASMJIT_INST_1c(NAME, ID, CONV, T0) \ + inline Error NAME(CondCode cc, const T0& o0) { return _emitter()->_emitI(CONV(cc), o0); } \ + inline Error NAME##a(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##a, o0); } \ + inline Error NAME##ae(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ae, o0); } \ + inline Error NAME##b(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##b, o0); } \ + inline Error NAME##be(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##be, o0); } \ + inline Error NAME##c(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##c, o0); } \ + inline Error NAME##e(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##e, o0); } \ + inline Error NAME##g(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##g, o0); } \ + inline Error NAME##ge(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ge, o0); } \ + inline Error NAME##l(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##l, o0); } \ + inline Error NAME##le(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##le, o0); } \ + inline Error NAME##na(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##na, o0); } \ + inline Error NAME##nae(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nae, o0); } \ + inline Error NAME##nb(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nb, o0); } \ + inline Error NAME##nbe(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nbe, o0); } \ + inline Error NAME##nc(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nc, o0); } \ + inline Error NAME##ne(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ne, o0); } \ + inline Error NAME##ng(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ng, o0); } \ + inline Error NAME##nge(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nge, o0); } \ + inline Error NAME##nl(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nl, o0); } \ + inline Error NAME##nle(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nle, o0); } \ + inline Error NAME##no(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##no, o0); } \ + inline Error NAME##np(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##np, o0); } \ + inline Error NAME##ns(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##ns, o0); } \ + inline Error NAME##nz(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##nz, o0); } \ + inline Error NAME##o(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##o, o0); } \ + inline Error NAME##p(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##p, o0); } \ + inline Error NAME##pe(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##pe, o0); } \ + inline Error NAME##po(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##po, o0); } \ + inline Error NAME##s(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##s, o0); } \ + inline Error NAME##z(const T0& o0) { return _emitter()->_emitI(Inst::kId##ID##z, o0); } + +#define ASMJIT_INST_2x(NAME, ID, T0, T1) \ + inline Error NAME(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID, o0, o1); } + +#define ASMJIT_INST_2c(NAME, ID, CONV, T0, T1) \ + inline Error NAME(CondCode cc, const T0& o0, const T1& o1) { return _emitter()->_emitI(CONV(cc), o0, o1); } \ + inline Error NAME##a(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##a, o0, o1); } \ + inline Error NAME##ae(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ae, o0, o1); } \ + inline Error NAME##b(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##b, o0, o1); } \ + inline Error NAME##be(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##be, o0, o1); } \ + inline Error NAME##c(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##c, o0, o1); } \ + inline Error NAME##e(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##e, o0, o1); } \ + inline Error NAME##g(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##g, o0, o1); } \ + inline Error NAME##ge(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ge, o0, o1); } \ + inline Error NAME##l(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##l, o0, o1); } \ + inline Error NAME##le(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##le, o0, o1); } \ + inline Error NAME##na(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##na, o0, o1); } \ + inline Error NAME##nae(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nae, o0, o1); } \ + inline Error NAME##nb(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nb, o0, o1); } \ + inline Error NAME##nbe(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nbe, o0, o1); } \ + inline Error NAME##nc(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nc, o0, o1); } \ + inline Error NAME##ne(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ne, o0, o1); } \ + inline Error NAME##ng(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ng, o0, o1); } \ + inline Error NAME##nge(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nge, o0, o1); } \ + inline Error NAME##nl(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nl, o0, o1); } \ + inline Error NAME##nle(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nle, o0, o1); } \ + inline Error NAME##no(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##no, o0, o1); } \ + inline Error NAME##np(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##np, o0, o1); } \ + inline Error NAME##ns(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##ns, o0, o1); } \ + inline Error NAME##nz(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##nz, o0, o1); } \ + inline Error NAME##o(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##o, o0, o1); } \ + inline Error NAME##p(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##p, o0, o1); } \ + inline Error NAME##pe(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##pe, o0, o1); } \ + inline Error NAME##po(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##po, o0, o1); } \ + inline Error NAME##s(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##s, o0, o1); } \ + inline Error NAME##z(const T0& o0, const T1& o1) { return _emitter()->_emitI(Inst::kId##ID##z, o0, o1); } + +#define ASMJIT_INST_3x(NAME, ID, T0, T1, T2) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2); } + +#define ASMJIT_INST_4x(NAME, ID, T0, T1, T2, T3) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3); } + +#define ASMJIT_INST_5x(NAME, ID, T0, T1, T2, T3, T4) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4); } + +#define ASMJIT_INST_6x(NAME, ID, T0, T1, T2, T3, T4, T5) \ + inline Error NAME(const T0& o0, const T1& o1, const T2& o2, const T3& o3, const T4& o4, const T5& o5) { return _emitter()->_emitI(Inst::kId##ID, o0, o1, o2, o3, o4, o5); } + +//! \addtogroup asmjit_x86 +//! \{ + +//! Emitter (X86 - explicit). +template +struct EmitterExplicitT { + //! \cond + + // These typedefs are used to describe implicit operands passed explicitly. + typedef Gp Gp_AL; + typedef Gp Gp_AH; + typedef Gp Gp_CL; + typedef Gp Gp_AX; + typedef Gp Gp_DX; + + typedef Gp Gp_EAX; + typedef Gp Gp_EBX; + typedef Gp Gp_ECX; + typedef Gp Gp_EDX; + + typedef Gp Gp_RAX; + typedef Gp Gp_RBX; + typedef Gp Gp_RCX; + typedef Gp Gp_RDX; + + typedef Gp Gp_ZAX; + typedef Gp Gp_ZBX; + typedef Gp Gp_ZCX; + typedef Gp Gp_ZDX; + + typedef Mem DS_ZAX; // ds:[zax] + typedef Mem DS_ZDI; // ds:[zdi] + typedef Mem ES_ZDI; // es:[zdi] + typedef Mem DS_ZSI; // ds:[zsi] + + typedef Xmm XMM0; + + // These two are unfortunately reported by the sanitizer. We know what we do, however, the sanitizer doesn't. + // I have tried to use reinterpret_cast instead, but that would generate bad code when compiled by MSC. + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG This* _emitter() noexcept { return static_cast(this); } + ASMJIT_ATTRIBUTE_NO_SANITIZE_UNDEF ASMJIT_INLINE_NODEBUG const This* _emitter() const noexcept { return static_cast(this); } + + //! \endcond + + //! \name Native Registers + //! \{ + + //! Returns either GPD or GPQ register of the given `id` depending on the emitter's architecture. + inline Gp gpz(uint32_t id) const noexcept { return Gp(_emitter()->_gpSignature, id); } + + inline Gp zax() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdAx); } + inline Gp zcx() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdCx); } + inline Gp zdx() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdDx); } + inline Gp zbx() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdBx); } + inline Gp zsp() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdSp); } + inline Gp zbp() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdBp); } + inline Gp zsi() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdSi); } + inline Gp zdi() const noexcept { return Gp(_emitter()->_gpSignature, Gp::kIdDi); } + + //! \} + + //! \name Native Pointers + //! \{ + + //! Creates a target dependent pointer of which base register's id is `baseId`. + inline Mem ptr_base(uint32_t baseId, int32_t off = 0, uint32_t size = 0) const noexcept { + return Mem(OperandSignature::fromOpType(OperandType::kMem) | + OperandSignature::fromMemBaseType(_emitter()->_gpSignature.regType()) | + OperandSignature::fromSize(size), + baseId, 0, off); + } + + inline Mem ptr_zax(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdAx, off, size); } + inline Mem ptr_zcx(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdCx, off, size); } + inline Mem ptr_zdx(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdDx, off, size); } + inline Mem ptr_zbx(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdBx, off, size); } + inline Mem ptr_zsp(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdSp, off, size); } + inline Mem ptr_zbp(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdBp, off, size); } + inline Mem ptr_zsi(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdSi, off, size); } + inline Mem ptr_zdi(int32_t off = 0, uint32_t size = 0) const noexcept { return ptr_base(Gp::kIdDi, off, size); } + + //! Creates an `intptr_t` memory operand depending on the current architecture. + inline Mem intptr_ptr(const Gp& base, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Gp& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Gp& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Label& base, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Label& base, const Gp& index, uint32_t shift, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Label& base, const Vec& index, uint32_t shift, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(const Rip& rip, int32_t offset = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(rip, offset, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(uint64_t base) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr(uint64_t base, const Gp& index, uint32_t shift = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, nativeGpSize); + } + //! \overload + inline Mem intptr_ptr_abs(uint64_t base) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, nativeGpSize, OperandSignature::fromValue(Mem::AddrType::kAbs)); + } + //! \overload + inline Mem intptr_ptr_abs(uint64_t base, const Gp& index, uint32_t shift = 0) const noexcept { + uint32_t nativeGpSize = _emitter()->registerSize(); + return Mem(base, index, shift, nativeGpSize, OperandSignature::fromValue(Mem::AddrType::kRel)); + } + + //! \} + + //! \name Embed + //! \{ + + //! Embeds 8-bit integer data. + inline Error db(uint8_t x, size_t repeatCount = 1) { return _emitter()->embedUInt8(x, repeatCount); } + //! Embeds 16-bit integer data. + inline Error dw(uint16_t x, size_t repeatCount = 1) { return _emitter()->embedUInt16(x, repeatCount); } + //! Embeds 32-bit integer data. + inline Error dd(uint32_t x, size_t repeatCount = 1) { return _emitter()->embedUInt32(x, repeatCount); } + //! Embeds 64-bit integer data. + inline Error dq(uint64_t x, size_t repeatCount = 1) { return _emitter()->embedUInt64(x, repeatCount); } + + //! Adds data in a given structure instance to the CodeBuffer. + template + inline Error dstruct(const T& x) { return _emitter()->embed(&x, uint32_t(sizeof(T))); } + + //! \} + +protected: + //! \cond + inline This& _addInstOptions(InstOptions options) noexcept { + _emitter()->addInstOptions(options); + return *_emitter(); + } + //! \endcond + +public: + //! \name Short/Long Form Options + //! \{ + + //! Force short form of jmp/jcc instruction. + inline This& short_() noexcept { return _addInstOptions(InstOptions::kShortForm); } + //! Force long form of jmp/jcc instruction. + inline This& long_() noexcept { return _addInstOptions(InstOptions::kLongForm); } + + //! \} + + //! \name Encoding Options + //! \{ + + //! Prefer MOD/RM encoding when both MOD/RM and MOD/MR forms are applicable. + inline This& mod_rm() noexcept { return _addInstOptions(InstOptions::kX86_ModRM); } + + //! Prefer MOD/MR encoding when both MOD/RM and MOD/MR forms are applicable. + inline This& mod_mr() noexcept { return _addInstOptions(InstOptions::kX86_ModMR); } + + //! \} + + //! \name Prefix Options + //! \{ + + //! Condition is likely to be taken (has only benefit on P4). + inline This& taken() noexcept { return _addInstOptions(InstOptions::kTaken); } + //! Condition is unlikely to be taken (has only benefit on P4). + inline This& notTaken() noexcept { return _addInstOptions(InstOptions::kNotTaken); } + + //! Use LOCK prefix. + inline This& lock() noexcept { return _addInstOptions(InstOptions::kX86_Lock); } + //! Use XACQUIRE prefix. + inline This& xacquire() noexcept { return _addInstOptions(InstOptions::kX86_XAcquire); } + //! Use XRELEASE prefix. + inline This& xrelease() noexcept { return _addInstOptions(InstOptions::kX86_XRelease); } + + //! Use BND/REPNE prefix. + //! + //! \note This is the same as using `repne()` or `repnz()` prefix. + inline This& bnd() noexcept { return _addInstOptions(InstOptions::kX86_Repne); } + + //! Use REP/REPZ prefix. + //! + //! \note This is the same as using `repe()` or `repz()` prefix. + inline This& rep(const Gp& zcx) noexcept { + _emitter()->_extraReg.init(zcx); + return _addInstOptions(InstOptions::kX86_Rep); + } + + //! Use REP/REPE prefix. + //! + //! \note This is the same as using `rep()` or `repz()` prefix. + inline This& repe(const Gp& zcx) noexcept { return rep(zcx); } + + //! Use REP/REPE prefix. + //! + //! \note This is the same as using `rep()` or `repe()` prefix. + inline This& repz(const Gp& zcx) noexcept { return rep(zcx); } + + //! Use REPNE prefix. + //! + //! \note This is the same as using `bnd()` or `repnz()` prefix. + inline This& repne(const Gp& zcx) noexcept { + _emitter()->_extraReg.init(zcx); + return _addInstOptions(InstOptions::kX86_Repne); + } + + //! Use REPNE prefix. + //! + //! \note This is the same as using `bnd()` or `repne()` prefix. + inline This& repnz(const Gp& zcx) noexcept { return repne(zcx); } + + //! \} + + //! \name REX Options + //! \{ + + //! Force REX prefix to be emitted even when it's not needed (X86_64). + //! + //! \note Don't use when using high 8-bit registers as REX prefix makes them inaccessible and `x86::Assembler` + //! would fail to encode such instruction. + inline This& rex() noexcept { return _addInstOptions(InstOptions::kX86_Rex); } + + //! Force REX.B prefix (X64) [It exists for special purposes only]. + inline This& rex_b() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeB); } + //! Force REX.X prefix (X64) [It exists for special purposes only]. + inline This& rex_x() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeX); } + //! Force REX.R prefix (X64) [It exists for special purposes only]. + inline This& rex_r() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeR); } + //! Force REX.W prefix (X64) [It exists for special purposes only]. + inline This& rex_w() noexcept { return _addInstOptions(InstOptions::kX86_OpCodeW); } + + //! \} + + //! \name VEX and EVEX Options + //! \{ + + //! Use VEX prefix instead of EVEX prefix (useful to select AVX_VNNI instruction instead of AVX512_VNNI). + inline This& vex() noexcept { return _addInstOptions(InstOptions::kX86_Vex); } + //! Force 3-byte VEX prefix (AVX+). + inline This& vex3() noexcept { return _addInstOptions(InstOptions::kX86_Vex3); } + //! Force 4-byte EVEX prefix (AVX512+). + inline This& evex() noexcept { return _addInstOptions(InstOptions::kX86_Evex); } + + //! \} + + //! \name AVX-512 Options & Masking + //! \{ + + //! Use masking {k} (AVX512+). + inline This& k(const KReg& kreg) noexcept { + _emitter()->_extraReg.init(kreg); + return *_emitter(); + } + + //! Use zeroing instead of merging (AVX512+). + inline This& z() noexcept { return _addInstOptions(InstOptions::kX86_ZMask); } + + //! Suppress all exceptions (AVX512+). + inline This& sae() noexcept { return _addInstOptions(InstOptions::kX86_SAE); } + //! Static rounding mode {rn} (round-to-nearest even) and {sae} (AVX512+). + inline This& rn_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RN_SAE); } + //! Static rounding mode {rd} (round-down, toward -inf) and {sae} (AVX512+). + inline This& rd_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RD_SAE); } + //! Static rounding mode {ru} (round-up, toward +inf) and {sae} (AVX512+). + inline This& ru_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RU_SAE); } + //! Static rounding mode {rz} (round-toward-zero, truncate) and {sae} (AVX512+). + inline This& rz_sae() noexcept { return _addInstOptions(InstOptions::kX86_ER | InstOptions::kX86_RZ_SAE); } + + //! \} + + //! \name Core Instructions + //! \{ + + ASMJIT_INST_2x(adc, Adc, Gp, Gp) // ANY + ASMJIT_INST_2x(adc, Adc, Gp, Mem) // ANY + ASMJIT_INST_2x(adc, Adc, Gp, Imm) // ANY + ASMJIT_INST_2x(adc, Adc, Mem, Gp) // ANY + ASMJIT_INST_2x(adc, Adc, Mem, Imm) // ANY + ASMJIT_INST_2x(add, Add, Gp, Gp) // ANY + ASMJIT_INST_2x(add, Add, Gp, Mem) // ANY + ASMJIT_INST_2x(add, Add, Gp, Imm) // ANY + ASMJIT_INST_2x(add, Add, Mem, Gp) // ANY + ASMJIT_INST_2x(add, Add, Mem, Imm) // ANY + ASMJIT_INST_2x(and_, And, Gp, Gp) // ANY + ASMJIT_INST_2x(and_, And, Gp, Mem) // ANY + ASMJIT_INST_2x(and_, And, Gp, Imm) // ANY + ASMJIT_INST_2x(and_, And, Mem, Gp) // ANY + ASMJIT_INST_2x(and_, And, Mem, Imm) // ANY + ASMJIT_INST_2x(bound, Bound, Gp, Mem) // X86 + ASMJIT_INST_2x(bsf, Bsf, Gp, Gp) // ANY + ASMJIT_INST_2x(bsf, Bsf, Gp, Mem) // ANY + ASMJIT_INST_2x(bsr, Bsr, Gp, Gp) // ANY + ASMJIT_INST_2x(bsr, Bsr, Gp, Mem) // ANY + ASMJIT_INST_1x(bswap, Bswap, Gp) // ANY + ASMJIT_INST_2x(bt, Bt, Gp, Gp) // ANY + ASMJIT_INST_2x(bt, Bt, Gp, Imm) // ANY + ASMJIT_INST_2x(bt, Bt, Mem, Gp) // ANY + ASMJIT_INST_2x(bt, Bt, Mem, Imm) // ANY + ASMJIT_INST_2x(btc, Btc, Gp, Gp) // ANY + ASMJIT_INST_2x(btc, Btc, Gp, Imm) // ANY + ASMJIT_INST_2x(btc, Btc, Mem, Gp) // ANY + ASMJIT_INST_2x(btc, Btc, Mem, Imm) // ANY + ASMJIT_INST_2x(btr, Btr, Gp, Gp) // ANY + ASMJIT_INST_2x(btr, Btr, Gp, Imm) // ANY + ASMJIT_INST_2x(btr, Btr, Mem, Gp) // ANY + ASMJIT_INST_2x(btr, Btr, Mem, Imm) // ANY + ASMJIT_INST_2x(bts, Bts, Gp, Gp) // ANY + ASMJIT_INST_2x(bts, Bts, Gp, Imm) // ANY + ASMJIT_INST_2x(bts, Bts, Mem, Gp) // ANY + ASMJIT_INST_2x(bts, Bts, Mem, Imm) // ANY + ASMJIT_INST_1x(cbw, Cbw, Gp_AX) // ANY [EXPLICIT] AX <- Sign Extend AL + ASMJIT_INST_2x(cdq, Cdq, Gp_EDX, Gp_EAX) // ANY [EXPLICIT] EDX:EAX <- Sign Extend EAX + ASMJIT_INST_1x(cdqe, Cdqe, Gp_EAX) // X64 [EXPLICIT] RAX <- Sign Extend EAX + ASMJIT_INST_2x(cqo, Cqo, Gp_RDX, Gp_RAX) // X64 [EXPLICIT] RDX:RAX <- Sign Extend RAX + ASMJIT_INST_2x(cwd, Cwd, Gp_DX, Gp_AX) // ANY [EXPLICIT] DX:AX <- Sign Extend AX + ASMJIT_INST_1x(cwde, Cwde, Gp_EAX) // ANY [EXPLICIT] EAX <- Sign Extend AX + ASMJIT_INST_1x(call, Call, Gp) // ANY + ASMJIT_INST_1x(call, Call, Mem) // ANY + ASMJIT_INST_1x(call, Call, Label) // ANY + ASMJIT_INST_1x(call, Call, Imm) // ANY + ASMJIT_INST_2c(cmov, Cmov, Inst::cmovccFromCond, Gp, Gp) // CMOV + ASMJIT_INST_2c(cmov, Cmov, Inst::cmovccFromCond, Gp, Mem) // CMOV + ASMJIT_INST_2x(cmp, Cmp, Gp, Gp) // ANY + ASMJIT_INST_2x(cmp, Cmp, Gp, Mem) // ANY + ASMJIT_INST_2x(cmp, Cmp, Gp, Imm) // ANY + ASMJIT_INST_2x(cmp, Cmp, Mem, Gp) // ANY + ASMJIT_INST_2x(cmp, Cmp, Mem, Imm) // ANY + ASMJIT_INST_2x(cmps, Cmps, DS_ZSI, ES_ZDI) // ANY [EXPLICIT] + ASMJIT_INST_3x(cmpxchg, Cmpxchg, Gp, Gp, Gp_ZAX) // I486 [EXPLICIT] + ASMJIT_INST_3x(cmpxchg, Cmpxchg, Mem, Gp, Gp_ZAX) // I486 [EXPLICIT] + ASMJIT_INST_5x(cmpxchg16b, Cmpxchg16b, Mem, Gp_RDX, Gp_RAX, Gp_RCX, Gp_RBX); // CMPXCHG16B [EXPLICIT] m == EDX:EAX ? m <- ECX:EBX + ASMJIT_INST_5x(cmpxchg8b, Cmpxchg8b, Mem, Gp_EDX, Gp_EAX, Gp_ECX, Gp_EBX); // CMPXCHG8B [EXPLICIT] m == RDX:RAX ? m <- RCX:RBX + ASMJIT_INST_1x(dec, Dec, Gp) // ANY + ASMJIT_INST_1x(dec, Dec, Mem) // ANY + ASMJIT_INST_2x(div, Div, Gp, Gp) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / r8 + ASMJIT_INST_2x(div, Div, Gp, Mem) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / m8 + ASMJIT_INST_3x(div, Div, Gp, Gp, Gp) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / r16|r32|r64 + ASMJIT_INST_3x(div, Div, Gp, Gp, Mem) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / m16|m32|m64 + ASMJIT_INST_2x(idiv, Idiv, Gp, Gp) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / r8 + ASMJIT_INST_2x(idiv, Idiv, Gp, Mem) // ANY [EXPLICIT] AH[Rem]: AL[Quot] <- AX / m8 + ASMJIT_INST_3x(idiv, Idiv, Gp, Gp, Gp) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / r16|r32|r64 + ASMJIT_INST_3x(idiv, Idiv, Gp, Gp, Mem) // ANY [EXPLICIT] xDX[Rem]:xAX[Quot] <- xDX:xAX / m16|m32|m64 + ASMJIT_INST_2x(imul, Imul, Gp, Gp) // ANY [EXPLICIT] AX <- AL * r8 | ra <- ra * rb + ASMJIT_INST_2x(imul, Imul, Gp, Mem) // ANY [EXPLICIT] AX <- AL * m8 | ra <- ra * m16|m32|m64 + ASMJIT_INST_3x(imul, Imul, Gp, Gp, Imm) // ANY + ASMJIT_INST_3x(imul, Imul, Gp, Mem, Imm) // ANY + ASMJIT_INST_3x(imul, Imul, Gp, Gp, Gp) // ANY [EXPLICIT] xDX:xAX <- xAX * r16|r32|r64 + ASMJIT_INST_3x(imul, Imul, Gp, Gp, Mem) // ANY [EXPLICIT] xDX:xAX <- xAX * m16|m32|m64 + ASMJIT_INST_1x(inc, Inc, Gp) // ANY + ASMJIT_INST_1x(inc, Inc, Mem) // ANY + ASMJIT_INST_1c(j, J, Inst::jccFromCond, Label) // ANY + ASMJIT_INST_1c(j, J, Inst::jccFromCond, Imm) // ANY + ASMJIT_INST_2x(jecxz, Jecxz, Gp, Label) // ANY [EXPLICIT] Short jump if CX/ECX/RCX is zero. + ASMJIT_INST_2x(jecxz, Jecxz, Gp, Imm) // ANY [EXPLICIT] Short jump if CX/ECX/RCX is zero. + ASMJIT_INST_1x(jmp, Jmp, Gp) // ANY + ASMJIT_INST_1x(jmp, Jmp, Mem) // ANY + ASMJIT_INST_1x(jmp, Jmp, Label) // ANY + ASMJIT_INST_1x(jmp, Jmp, Imm) // ANY + ASMJIT_INST_2x(lcall, Lcall, Imm, Imm) // ANY + ASMJIT_INST_1x(lcall, Lcall, Mem) // ANY + ASMJIT_INST_2x(lea, Lea, Gp, Mem) // ANY + ASMJIT_INST_2x(ljmp, Ljmp, Imm, Imm) // ANY + ASMJIT_INST_1x(ljmp, Ljmp, Mem) // ANY + ASMJIT_INST_2x(lods, Lods, Gp_ZAX, DS_ZSI) // ANY [EXPLICIT] + ASMJIT_INST_2x(loop, Loop, Gp_ZCX, Label) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_2x(loop, Loop, Gp_ZCX, Imm) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_2x(loope, Loope, Gp_ZCX, Label) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_2x(loope, Loope, Gp_ZCX, Imm) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_2x(loopne, Loopne, Gp_ZCX, Label) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. + ASMJIT_INST_2x(loopne, Loopne, Gp_ZCX, Imm) // ANY [EXPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. + ASMJIT_INST_2x(mov, Mov, Gp, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, Mem) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, Imm) // ANY + ASMJIT_INST_2x(mov, Mov, Mem, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Mem, Imm) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, CReg) // ANY + ASMJIT_INST_2x(mov, Mov, CReg, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, DReg) // ANY + ASMJIT_INST_2x(mov, Mov, DReg, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, Gp, SReg) // ANY + ASMJIT_INST_2x(mov, Mov, Mem, SReg) // ANY + ASMJIT_INST_2x(mov, Mov, SReg, Gp) // ANY + ASMJIT_INST_2x(mov, Mov, SReg, Mem) // ANY + ASMJIT_INST_2x(movabs, Movabs, Gp, Mem) // X64 + ASMJIT_INST_2x(movabs, Movabs, Gp, Imm) // X64 + ASMJIT_INST_2x(movabs, Movabs, Mem, Gp) // X64 + ASMJIT_INST_2x(movnti, Movnti, Mem, Gp) // SSE2 + ASMJIT_INST_2x(movs, Movs, ES_ZDI, DS_ZSI) // ANY [EXPLICIT] + ASMJIT_INST_2x(movsx, Movsx, Gp, Gp) // ANY + ASMJIT_INST_2x(movsx, Movsx, Gp, Mem) // ANY + ASMJIT_INST_2x(movsxd, Movsxd, Gp, Gp) // X64 + ASMJIT_INST_2x(movsxd, Movsxd, Gp, Mem) // X64 + ASMJIT_INST_2x(movzx, Movzx, Gp, Gp) // ANY + ASMJIT_INST_2x(movzx, Movzx, Gp, Mem) // ANY + ASMJIT_INST_2x(mul, Mul, Gp_AX, Gp) // ANY [EXPLICIT] AX <- AL * r8 + ASMJIT_INST_2x(mul, Mul, Gp_AX, Mem) // ANY [EXPLICIT] AX <- AL * m8 + ASMJIT_INST_3x(mul, Mul, Gp_ZDX, Gp_ZAX, Gp) // ANY [EXPLICIT] xDX:xAX <- xAX * r16|r32|r64 + ASMJIT_INST_3x(mul, Mul, Gp_ZDX, Gp_ZAX, Mem) // ANY [EXPLICIT] xDX:xAX <- xAX * m16|m32|m64 + ASMJIT_INST_1x(neg, Neg, Gp) // ANY + ASMJIT_INST_1x(neg, Neg, Mem) // ANY + ASMJIT_INST_0x(nop, Nop) // ANY + ASMJIT_INST_1x(nop, Nop, Gp) // ANY + ASMJIT_INST_1x(nop, Nop, Mem) // ANY + ASMJIT_INST_2x(nop, Nop, Gp, Gp) // ANY + ASMJIT_INST_2x(nop, Nop, Mem, Gp) // ANY + ASMJIT_INST_1x(not_, Not, Gp) // ANY + ASMJIT_INST_1x(not_, Not, Mem) // ANY + ASMJIT_INST_2x(or_, Or, Gp, Gp) // ANY + ASMJIT_INST_2x(or_, Or, Gp, Mem) // ANY + ASMJIT_INST_2x(or_, Or, Gp, Imm) // ANY + ASMJIT_INST_2x(or_, Or, Mem, Gp) // ANY + ASMJIT_INST_2x(or_, Or, Mem, Imm) // ANY + ASMJIT_INST_1x(pop, Pop, Gp) // ANY + ASMJIT_INST_1x(pop, Pop, Mem) // ANY + ASMJIT_INST_1x(pop, Pop, SReg); // ANY + ASMJIT_INST_0x(popa, Popa) // X86 + ASMJIT_INST_0x(popad, Popad) // X86 + ASMJIT_INST_0x(popf, Popf) // ANY + ASMJIT_INST_0x(popfd, Popfd) // X86 + ASMJIT_INST_0x(popfq, Popfq) // X64 + ASMJIT_INST_1x(push, Push, Gp) // ANY + ASMJIT_INST_1x(push, Push, Mem) // ANY + ASMJIT_INST_1x(push, Push, SReg) // ANY + ASMJIT_INST_1x(push, Push, Imm) // ANY + ASMJIT_INST_0x(pusha, Pusha) // X86 + ASMJIT_INST_0x(pushad, Pushad) // X86 + ASMJIT_INST_0x(pushf, Pushf) // ANY + ASMJIT_INST_0x(pushfd, Pushfd) // X86 + ASMJIT_INST_0x(pushfq, Pushfq) // X64 + ASMJIT_INST_2x(rcl, Rcl, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(rcl, Rcl, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(rcl, Rcl, Gp, Imm) // ANY + ASMJIT_INST_2x(rcl, Rcl, Mem, Imm) // ANY + ASMJIT_INST_2x(rcr, Rcr, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(rcr, Rcr, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(rcr, Rcr, Gp, Imm) // ANY + ASMJIT_INST_2x(rcr, Rcr, Mem, Imm) // ANY + ASMJIT_INST_2x(rol, Rol, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(rol, Rol, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(rol, Rol, Gp, Imm) // ANY + ASMJIT_INST_2x(rol, Rol, Mem, Imm) // ANY + ASMJIT_INST_2x(ror, Ror, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(ror, Ror, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(ror, Ror, Gp, Imm) // ANY + ASMJIT_INST_2x(ror, Ror, Mem, Imm) // ANY + ASMJIT_INST_2x(sbb, Sbb, Gp, Gp) // ANY + ASMJIT_INST_2x(sbb, Sbb, Gp, Mem) // ANY + ASMJIT_INST_2x(sbb, Sbb, Gp, Imm) // ANY + ASMJIT_INST_2x(sbb, Sbb, Mem, Gp) // ANY + ASMJIT_INST_2x(sbb, Sbb, Mem, Imm) // ANY + ASMJIT_INST_2x(sal, Sal, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(sal, Sal, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(sal, Sal, Gp, Imm) // ANY + ASMJIT_INST_2x(sal, Sal, Mem, Imm) // ANY + ASMJIT_INST_2x(sar, Sar, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(sar, Sar, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(sar, Sar, Gp, Imm) // ANY + ASMJIT_INST_2x(sar, Sar, Mem, Imm) // ANY + ASMJIT_INST_2x(scas, Scas, Gp_ZAX, ES_ZDI) // ANY [EXPLICIT] + ASMJIT_INST_1c(set, Set, Inst::setccFromCond, Gp) // ANY + ASMJIT_INST_1c(set, Set, Inst::setccFromCond, Mem) // ANY + ASMJIT_INST_2x(shl, Shl, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(shl, Shl, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(shl, Shl, Gp, Imm) // ANY + ASMJIT_INST_2x(shl, Shl, Mem, Imm) // ANY + ASMJIT_INST_2x(shr, Shr, Gp, Gp_CL) // ANY + ASMJIT_INST_2x(shr, Shr, Mem, Gp_CL) // ANY + ASMJIT_INST_2x(shr, Shr, Gp, Imm) // ANY + ASMJIT_INST_2x(shr, Shr, Mem, Imm) // ANY + ASMJIT_INST_3x(shld, Shld, Gp, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shld, Shld, Mem, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shld, Shld, Gp, Gp, Imm) // ANY + ASMJIT_INST_3x(shld, Shld, Mem, Gp, Imm) // ANY + ASMJIT_INST_3x(shrd, Shrd, Gp, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shrd, Shrd, Mem, Gp, Gp_CL) // ANY + ASMJIT_INST_3x(shrd, Shrd, Gp, Gp, Imm) // ANY + ASMJIT_INST_3x(shrd, Shrd, Mem, Gp, Imm) // ANY + ASMJIT_INST_2x(stos, Stos, ES_ZDI, Gp_ZAX) // ANY [EXPLICIT] + ASMJIT_INST_2x(sub, Sub, Gp, Gp) // ANY + ASMJIT_INST_2x(sub, Sub, Gp, Mem) // ANY + ASMJIT_INST_2x(sub, Sub, Gp, Imm) // ANY + ASMJIT_INST_2x(sub, Sub, Mem, Gp) // ANY + ASMJIT_INST_2x(sub, Sub, Mem, Imm) // ANY + ASMJIT_INST_2x(test, Test, Gp, Gp) // ANY + ASMJIT_INST_2x(test, Test, Gp, Imm) // ANY + ASMJIT_INST_2x(test, Test, Mem, Gp) // ANY + ASMJIT_INST_2x(test, Test, Mem, Imm) // ANY + ASMJIT_INST_2x(ud0, Ud0, Gp, Gp) // ANY + ASMJIT_INST_2x(ud0, Ud0, Gp, Mem) // ANY + ASMJIT_INST_2x(ud1, Ud1, Gp, Gp) // ANY + ASMJIT_INST_2x(ud1, Ud1, Gp, Mem) // ANY + ASMJIT_INST_0x(ud2, Ud2) // ANY + ASMJIT_INST_2x(xadd, Xadd, Gp, Gp) // ANY + ASMJIT_INST_2x(xadd, Xadd, Mem, Gp) // ANY + ASMJIT_INST_2x(xchg, Xchg, Gp, Gp) // ANY + ASMJIT_INST_2x(xchg, Xchg, Mem, Gp) // ANY + ASMJIT_INST_2x(xchg, Xchg, Gp, Mem) // ANY + ASMJIT_INST_2x(xor_, Xor, Gp, Gp) // ANY + ASMJIT_INST_2x(xor_, Xor, Gp, Mem) // ANY + ASMJIT_INST_2x(xor_, Xor, Gp, Imm) // ANY + ASMJIT_INST_2x(xor_, Xor, Mem, Gp) // ANY + ASMJIT_INST_2x(xor_, Xor, Mem, Imm) // ANY + + //! \} + + //! \name Core Instructions (Aliases) + //! \{ + + //! The `imul(Gp, Imm)` instruction is an alias of `imul(Gp, Gp, Imm)` instruction. + inline Error imul(const Gp& o0, const Imm& o1) { return _emitter()->_emitI(Inst::kIdImul, o0, o0, o1); } + + //! \} + + //! \name Deprecated 32-bit Instructions + //! \{ + + ASMJIT_INST_1x(aaa, Aaa, Gp) // X86 [EXPLICIT] + ASMJIT_INST_2x(aad, Aad, Gp, Imm) // X86 [EXPLICIT] + ASMJIT_INST_2x(aam, Aam, Gp, Imm) // X86 [EXPLICIT] + ASMJIT_INST_1x(aas, Aas, Gp) // X86 [EXPLICIT] + ASMJIT_INST_1x(daa, Daa, Gp) // X86 [EXPLICIT] + ASMJIT_INST_1x(das, Das, Gp) // X86 [EXPLICIT] + + //! \} + + //! \name ENTER/LEAVE Instructions + //! \{ + + ASMJIT_INST_2x(enter, Enter, Imm, Imm) // ANY + ASMJIT_INST_0x(leave, Leave) // ANY + + //! \} + + //! \name IN/OUT Instructions + //! \{ + + // NOTE: For some reason Doxygen is messed up here and thinks we are in cond. + + ASMJIT_INST_2x(in, In, Gp_ZAX, Imm) // ANY + ASMJIT_INST_2x(in, In, Gp_ZAX, Gp_DX) // ANY + ASMJIT_INST_2x(ins, Ins, ES_ZDI, Gp_DX) // ANY + ASMJIT_INST_2x(out, Out, Imm, Gp_ZAX) // ANY + ASMJIT_INST_2x(out, Out, Gp_DX, Gp_ZAX) // ANY + ASMJIT_INST_2x(outs, Outs, Gp_DX, DS_ZSI) // ANY + + //! \} + + //! \name Clear/Set CF/DF Instructions + //! \{ + + ASMJIT_INST_0x(clc, Clc) // ANY + ASMJIT_INST_0x(cld, Cld) // ANY + ASMJIT_INST_0x(cmc, Cmc) // ANY + ASMJIT_INST_0x(stc, Stc) // ANY + ASMJIT_INST_0x(std, Std) // ANY + + //! \} + + //! \name ADX Instructions + //! \{ + + ASMJIT_INST_2x(adcx, Adcx, Gp, Gp) // ADX + ASMJIT_INST_2x(adcx, Adcx, Gp, Mem) // ADX + ASMJIT_INST_2x(adox, Adox, Gp, Gp) // ADX + ASMJIT_INST_2x(adox, Adox, Gp, Mem) // ADX + + //! \} + + //! \name CPUID Instruction + //! \{ + + ASMJIT_INST_4x(cpuid, Cpuid, Gp_EAX, Gp_EBX, Gp_ECX, Gp_EDX) // I486 [EXPLICIT] EAX:EBX:ECX:EDX <- CPUID[EAX:ECX] + + //! \} + + //! \name LAHF/SAHF Instructions + //! \{ + + ASMJIT_INST_1x(lahf, Lahf, Gp_AH) // LAHFSAHF [EXPLICIT] AH <- EFL + ASMJIT_INST_1x(sahf, Sahf, Gp_AH) // LAHFSAHF [EXPLICIT] EFL <- AH + + //! \} + + //! \name BMI Instructions + //! \{ + + ASMJIT_INST_3x(andn, Andn, Gp, Gp, Gp) // BMI + ASMJIT_INST_3x(andn, Andn, Gp, Gp, Mem) // BMI + ASMJIT_INST_3x(bextr, Bextr, Gp, Gp, Gp) // BMI + ASMJIT_INST_3x(bextr, Bextr, Gp, Mem, Gp) // BMI + ASMJIT_INST_2x(blsi, Blsi, Gp, Gp) // BMI + ASMJIT_INST_2x(blsi, Blsi, Gp, Mem) // BMI + ASMJIT_INST_2x(blsmsk, Blsmsk, Gp, Gp) // BMI + ASMJIT_INST_2x(blsmsk, Blsmsk, Gp, Mem) // BMI + ASMJIT_INST_2x(blsr, Blsr, Gp, Gp) // BMI + ASMJIT_INST_2x(blsr, Blsr, Gp, Mem) // BMI + ASMJIT_INST_2x(tzcnt, Tzcnt, Gp, Gp) // BMI + ASMJIT_INST_2x(tzcnt, Tzcnt, Gp, Mem) // BMI + + //! \} + + //! \name BMI2 Instructions + //! \{ + + ASMJIT_INST_3x(bzhi, Bzhi, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(bzhi, Bzhi, Gp, Mem, Gp) // BMI2 + ASMJIT_INST_4x(mulx, Mulx, Gp, Gp, Gp, Gp_ZDX) // BMI2 [EXPLICIT] + ASMJIT_INST_4x(mulx, Mulx, Gp, Gp, Mem, Gp_ZDX) // BMI2 [EXPLICIT] + ASMJIT_INST_3x(pdep, Pdep, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(pdep, Pdep, Gp, Gp, Mem) // BMI2 + ASMJIT_INST_3x(pext, Pext, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(pext, Pext, Gp, Gp, Mem) // BMI2 + ASMJIT_INST_3x(rorx, Rorx, Gp, Gp, Imm) // BMI2 + ASMJIT_INST_3x(rorx, Rorx, Gp, Mem, Imm) // BMI2 + ASMJIT_INST_3x(sarx, Sarx, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(sarx, Sarx, Gp, Mem, Gp) // BMI2 + ASMJIT_INST_3x(shlx, Shlx, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(shlx, Shlx, Gp, Mem, Gp) // BMI2 + ASMJIT_INST_3x(shrx, Shrx, Gp, Gp, Gp) // BMI2 + ASMJIT_INST_3x(shrx, Shrx, Gp, Mem, Gp) // BMI2 + + //! \} + + //! \name CMPCCXADD Instructions + //! \{ + + ASMJIT_INST_3x(cmpbexadd, Cmpbexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpbxadd, Cmpbxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmplexadd, Cmplexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmplxadd, Cmplxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnbexadd, Cmpnbexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnbxadd, Cmpnbxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnlexadd, Cmpnlexadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnlxadd, Cmpnlxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnoxadd, Cmpnoxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnpxadd, Cmpnpxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnsxadd, Cmpnsxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpnzxadd, Cmpnzxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpoxadd, Cmpoxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmppxadd, Cmppxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpsxadd, Cmpsxadd, Mem, Gp, Gp) + ASMJIT_INST_3x(cmpzxadd, Cmpzxadd, Mem, Gp, Gp) + + //! \} + + //! \name CacheLine Instructions + //! \{ + + ASMJIT_INST_1x(cldemote, Cldemote, Mem) // CLDEMOTE + ASMJIT_INST_1x(clflush, Clflush, Mem) // CLFLUSH + ASMJIT_INST_1x(clflushopt, Clflushopt, Mem) // CLFLUSH_OPT + ASMJIT_INST_1x(clwb, Clwb, Mem) // CLWB + ASMJIT_INST_1x(clzero, Clzero, DS_ZAX) // CLZERO [EXPLICIT] + + //! \} + + //! \name CRC32 Instructions (SSE4.2) + //! \{ + + ASMJIT_INST_2x(crc32, Crc32, Gp, Gp) // SSE4_2 + ASMJIT_INST_2x(crc32, Crc32, Gp, Mem) // SSE4_2 + + //! \} + + //! \name FENCE Instructions (SSE and SSE2) + //! \{ + + ASMJIT_INST_0x(lfence, Lfence) // SSE2 + ASMJIT_INST_0x(mfence, Mfence) // SSE2 + ASMJIT_INST_0x(sfence, Sfence) // SSE + + //! \} + + //! \name LZCNT Instructions + //! \{ + + ASMJIT_INST_2x(lzcnt, Lzcnt, Gp, Gp) // LZCNT + ASMJIT_INST_2x(lzcnt, Lzcnt, Gp, Mem) // LZCNT + + //! \} + + //! \name MOVBE Instructions + //! \{ + + ASMJIT_INST_2x(movbe, Movbe, Gp, Mem) // MOVBE + ASMJIT_INST_2x(movbe, Movbe, Mem, Gp) // MOVBE + + //! \} + + //! \name MOVDIRI & MOVDIR64B Instructions + //! \{ + + ASMJIT_INST_2x(movdiri, Movdiri, Mem, Gp) // MOVDIRI + ASMJIT_INST_2x(movdir64b, Movdir64b, Mem, Mem) // MOVDIR64B + + //! \} + + //! \name MXCSR Instructions (SSE) + //! \{ + + ASMJIT_INST_1x(ldmxcsr, Ldmxcsr, Mem) // SSE + ASMJIT_INST_1x(stmxcsr, Stmxcsr, Mem) // SSE + + //! \} + + //! \name POPCNT Instructions + //! \{ + + ASMJIT_INST_2x(popcnt, Popcnt, Gp, Gp) // POPCNT + ASMJIT_INST_2x(popcnt, Popcnt, Gp, Mem) // POPCNT + + //! \} + + //! \name PREFETCH Instructions + //! \{ + + ASMJIT_INST_1x(prefetch, Prefetch, Mem) // 3DNOW + ASMJIT_INST_1x(prefetchnta, Prefetchnta, Mem) // SSE + ASMJIT_INST_1x(prefetcht0, Prefetcht0, Mem) // SSE + ASMJIT_INST_1x(prefetcht1, Prefetcht1, Mem) // SSE + ASMJIT_INST_1x(prefetcht2, Prefetcht2, Mem) // SSE + ASMJIT_INST_1x(prefetchw, Prefetchw, Mem) // PREFETCHW + ASMJIT_INST_1x(prefetchwt1, Prefetchwt1, Mem) // PREFETCHW1 + + //! \} + + //! \name PREFETCHI Instructions + //! \{ + + ASMJIT_INST_1x(prefetchit0, Prefetchit0, Mem) + ASMJIT_INST_1x(prefetchit1, Prefetchit1, Mem) + + //! \} + + //! \name RAO_INT Instructions + //! \{ + + ASMJIT_INST_2x(aadd, Aadd, Mem, Gp) + ASMJIT_INST_2x(aand, Aand, Mem, Gp) + ASMJIT_INST_2x(aor, Aor, Mem, Gp) + ASMJIT_INST_2x(axor, Axor, Mem, Gp) + + //! \} + + //! \name RDPID Instruction + //! \{ + + ASMJIT_INST_1x(rdpid, Rdpid, Gp) // RDPID + + //! \} + + //! \name RDPRU/RDPKRU Instructions + //! \{ + + ASMJIT_INST_3x(rdpru, Rdpru, Gp_EDX, Gp_EAX, Gp_ECX) // RDPRU [EXPLICIT] EDX:EAX <- PRU[ECX] + ASMJIT_INST_3x(rdpkru, Rdpkru, Gp_EDX, Gp_EAX, Gp_ECX) // RDPKRU [EXPLICIT] EDX:EAX <- PKRU[ECX] + + //! \} + + //! \name RDTSC/RDTSCP Instructions + //! \{ + + ASMJIT_INST_2x(rdtsc, Rdtsc, Gp_EDX, Gp_EAX) // RDTSC [EXPLICIT] EDX:EAX <- Counter + ASMJIT_INST_3x(rdtscp, Rdtscp, Gp_EDX, Gp_EAX, Gp_ECX) // RDTSCP [EXPLICIT] EDX:EAX:EXC <- Counter + + //! \} + + //! \name SERIALIZE Instruction + //! \{ + + ASMJIT_INST_0x(serialize, Serialize) // SERIALIZE + + //! \} + + //! \name TBM Instructions + //! \{ + + ASMJIT_INST_2x(blcfill, Blcfill, Gp, Gp) // TBM + ASMJIT_INST_2x(blcfill, Blcfill, Gp, Mem) // TBM + ASMJIT_INST_2x(blci, Blci, Gp, Gp) // TBM + ASMJIT_INST_2x(blci, Blci, Gp, Mem) // TBM + ASMJIT_INST_2x(blcic, Blcic, Gp, Gp) // TBM + ASMJIT_INST_2x(blcic, Blcic, Gp, Mem) // TBM + ASMJIT_INST_2x(blcmsk, Blcmsk, Gp, Gp) // TBM + ASMJIT_INST_2x(blcmsk, Blcmsk, Gp, Mem) // TBM + ASMJIT_INST_2x(blcs, Blcs, Gp, Gp) // TBM + ASMJIT_INST_2x(blcs, Blcs, Gp, Mem) // TBM + ASMJIT_INST_2x(blsfill, Blsfill, Gp, Gp) // TBM + ASMJIT_INST_2x(blsfill, Blsfill, Gp, Mem) // TBM + ASMJIT_INST_2x(blsic, Blsic, Gp, Gp) // TBM + ASMJIT_INST_2x(blsic, Blsic, Gp, Mem) // TBM + ASMJIT_INST_2x(t1mskc, T1mskc, Gp, Gp) // TBM + ASMJIT_INST_2x(t1mskc, T1mskc, Gp, Mem) // TBM + ASMJIT_INST_2x(tzmsk, Tzmsk, Gp, Gp) // TBM + ASMJIT_INST_2x(tzmsk, Tzmsk, Gp, Mem) // TBM + + //! \} + + //! \name Other User-Mode Instructions + //! \{ + + ASMJIT_INST_2x(arpl, Arpl, Gp, Gp) // X86 + ASMJIT_INST_2x(arpl, Arpl, Mem, Gp) // X86 + ASMJIT_INST_0x(cli, Cli) // ANY + ASMJIT_INST_0x(getsec, Getsec) // SMX + ASMJIT_INST_1x(int_, Int, Imm) // ANY + ASMJIT_INST_0x(int3, Int3) // ANY + ASMJIT_INST_0x(into, Into) // ANY + ASMJIT_INST_2x(lar, Lar, Gp, Gp) // ANY + ASMJIT_INST_2x(lar, Lar, Gp, Mem) // ANY + ASMJIT_INST_2x(lds, Lds, Gp, Mem) // X86 + ASMJIT_INST_2x(les, Les, Gp, Mem) // X86 + ASMJIT_INST_2x(lfs, Lfs, Gp, Mem) // ANY + ASMJIT_INST_2x(lgs, Lgs, Gp, Mem) // ANY + ASMJIT_INST_2x(lsl, Lsl, Gp, Gp) // ANY + ASMJIT_INST_2x(lsl, Lsl, Gp, Mem) // ANY + ASMJIT_INST_2x(lss, Lss, Gp, Mem) // ANY + ASMJIT_INST_0x(pause, Pause) // SSE2 + ASMJIT_INST_0x(rsm, Rsm) // X86 + ASMJIT_INST_1x(sgdt, Sgdt, Mem) // ANY + ASMJIT_INST_1x(sidt, Sidt, Mem) // ANY + ASMJIT_INST_1x(sldt, Sldt, Gp) // ANY + ASMJIT_INST_1x(sldt, Sldt, Mem) // ANY + ASMJIT_INST_1x(smsw, Smsw, Gp) // ANY + ASMJIT_INST_1x(smsw, Smsw, Mem) // ANY + ASMJIT_INST_0x(sti, Sti) // ANY + ASMJIT_INST_1x(str, Str, Gp) // ANY + ASMJIT_INST_1x(str, Str, Mem) // ANY + ASMJIT_INST_1x(verr, Verr, Gp) // ANY + ASMJIT_INST_1x(verr, Verr, Mem) // ANY + ASMJIT_INST_1x(verw, Verw, Gp) // ANY + ASMJIT_INST_1x(verw, Verw, Mem) // ANY + + //! \} + + //! \name FSGSBASE Instructions + //! \{ + + ASMJIT_INST_1x(rdfsbase, Rdfsbase, Gp) // FSGSBASE + ASMJIT_INST_1x(rdgsbase, Rdgsbase, Gp) // FSGSBASE + ASMJIT_INST_1x(wrfsbase, Wrfsbase, Gp) // FSGSBASE + ASMJIT_INST_1x(wrgsbase, Wrgsbase, Gp) // FSGSBASE + + //! \} + + //! \name FXSR Instructions + //! \{ + + ASMJIT_INST_1x(fxrstor, Fxrstor, Mem) // FXSR + ASMJIT_INST_1x(fxrstor64, Fxrstor64, Mem) // FXSR + ASMJIT_INST_1x(fxsave, Fxsave, Mem) // FXSR + ASMJIT_INST_1x(fxsave64, Fxsave64, Mem) // FXSR + + //! \} + + //! \name XSAVE Instructions + //! \{ + + ASMJIT_INST_3x(xgetbv, Xgetbv, Gp_EDX, Gp_EAX, Gp_ECX) // XSAVE [EXPLICIT] EDX:EAX <- XCR[ECX] + ASMJIT_INST_3x(xrstor, Xrstor, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xrstor64, Xrstor64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xrstors, Xrstors, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xrstors64, Xrstors64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsave, Xsave, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsave64, Xsave64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsavec, Xsavec, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsavec64, Xsavec64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsaveopt, Xsaveopt, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsaveopt64, Xsaveopt64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + ASMJIT_INST_3x(xsaves, Xsaves, Mem, Gp_EDX, Gp_EAX) // XSAVE [EXPLICIT] + ASMJIT_INST_3x(xsaves64, Xsaves64, Mem, Gp_EDX, Gp_EAX) // XSAVE+X64 [EXPLICIT] + + //! \} + + //! \name MPX Extensions + //! \{ + + ASMJIT_INST_2x(bndcl, Bndcl, Bnd, Gp) // MPX + ASMJIT_INST_2x(bndcl, Bndcl, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndcn, Bndcn, Bnd, Gp) // MPX + ASMJIT_INST_2x(bndcn, Bndcn, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndcu, Bndcu, Bnd, Gp) // MPX + ASMJIT_INST_2x(bndcu, Bndcu, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndldx, Bndldx, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndmk, Bndmk, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndmov, Bndmov, Bnd, Bnd) // MPX + ASMJIT_INST_2x(bndmov, Bndmov, Bnd, Mem) // MPX + ASMJIT_INST_2x(bndmov, Bndmov, Mem, Bnd) // MPX + ASMJIT_INST_2x(bndstx, Bndstx, Mem, Bnd) // MPX + + //! \} + + //! \name MONITORX Instructions + //! \{ + + ASMJIT_INST_3x(monitorx, Monitorx, Mem, Gp, Gp) // MONITORX + ASMJIT_INST_3x(mwaitx, Mwaitx, Gp, Gp, Gp) // MONITORX + + //! \} + + //! \name MCOMMIT Instruction + //! \{ + + ASMJIT_INST_0x(mcommit, Mcommit) // MCOMMIT + + //! \} + + //! \name PTWRITE Instruction + //! \{ + + ASMJIT_INST_1x(ptwrite, Ptwrite, Gp) // PTWRITE + ASMJIT_INST_1x(ptwrite, Ptwrite, Mem) // PTWRITE + + //! \} + + //! \name ENQCMD Instructions + //! \{ + + ASMJIT_INST_2x(enqcmd, Enqcmd, Mem, Mem) // ENQCMD + ASMJIT_INST_2x(enqcmds, Enqcmds, Mem, Mem) // ENQCMD + + //! \} + + //! \name WAITPKG Instructions + //! \{ + + ASMJIT_INST_3x(tpause, Tpause, Gp, Gp, Gp) // WAITPKG + ASMJIT_INST_1x(umonitor, Umonitor, Mem) // WAITPKG + ASMJIT_INST_3x(umwait, Umwait, Gp, Gp, Gp) // WAITPKG + + //! \} + + //! \name RDRAND & RDSEED Instructions + //! \{ + + ASMJIT_INST_1x(rdrand, Rdrand, Gp) // RDRAND + ASMJIT_INST_1x(rdseed, Rdseed, Gp) // RDSEED + + //! \} + + //! \name LWP Instructions + //! \{ + + ASMJIT_INST_1x(llwpcb, Llwpcb, Gp) // LWP + ASMJIT_INST_3x(lwpins, Lwpins, Gp, Gp, Imm) // LWP + ASMJIT_INST_3x(lwpins, Lwpins, Gp, Mem, Imm) // LWP + ASMJIT_INST_3x(lwpval, Lwpval, Gp, Gp, Imm) // LWP + ASMJIT_INST_3x(lwpval, Lwpval, Gp, Mem, Imm) // LWP + ASMJIT_INST_1x(slwpcb, Slwpcb, Gp) // LWP + + //! \} + + //! \name RTM & TSX Instructions + //! \{ + + ASMJIT_INST_1x(xabort, Xabort, Imm) // RTM + ASMJIT_INST_1x(xbegin, Xbegin, Label) // RTM + ASMJIT_INST_1x(xbegin, Xbegin, Imm) // RTM + ASMJIT_INST_0x(xend, Xend) // RTM + ASMJIT_INST_0x(xtest, Xtest) // TSX + + //! \} + + //! \name TSXLDTRK Instructions + //! \{ + + ASMJIT_INST_0x(xresldtrk, Xresldtrk) // TSXLDTRK + ASMJIT_INST_0x(xsusldtrk, Xsusldtrk) // TSXLDTRK + + //! \} + + //! \name CET-IBT Instructions + //! \{ + + ASMJIT_INST_0x(endbr32, Endbr32) // CET_IBT + ASMJIT_INST_0x(endbr64, Endbr64) // CET_IBT + + //! \} + + //! \name CET-SS Instructions + //! \{ + + ASMJIT_INST_1x(clrssbsy, Clrssbsy, Mem) // CET_SS + ASMJIT_INST_0x(setssbsy, Setssbsy) // CET_SS + + ASMJIT_INST_1x(rstorssp, Rstorssp, Mem) // CET_SS + ASMJIT_INST_0x(saveprevssp, Saveprevssp) // CET_SS + + ASMJIT_INST_1x(incsspd, Incsspd, Gp) // CET_SS + ASMJIT_INST_1x(incsspq, Incsspq, Gp) // CET_SS + ASMJIT_INST_1x(rdsspd, Rdsspd, Gp) // CET_SS + ASMJIT_INST_1x(rdsspq, Rdsspq, Gp) // CET_SS + ASMJIT_INST_2x(wrssd, Wrssd, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrssd, Wrssd, Mem, Gp) // CET_SS + ASMJIT_INST_2x(wrssq, Wrssq, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrssq, Wrssq, Mem, Gp) // CET_SS + ASMJIT_INST_2x(wrussd, Wrussd, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrussd, Wrussd, Mem, Gp) // CET_SS + ASMJIT_INST_2x(wrussq, Wrussq, Gp, Gp) // CET_SS + ASMJIT_INST_2x(wrussq, Wrussq, Mem, Gp) // CET_SS + + //! \} + + //! \name HRESET Instructions + //! \{ + + ASMJIT_INST_2x(hreset, Hreset, Imm, Gp) // HRESET + + //! \} + + //! \name UINTR Instructions + //! \{ + + ASMJIT_INST_0x(clui, Clui) // UINTR + ASMJIT_INST_1x(senduipi, Senduipi, Gp) // UINTR + ASMJIT_INST_0x(testui, Testui) // UINTR + ASMJIT_INST_0x(stui, Stui) // UINTR + ASMJIT_INST_0x(uiret, Uiret) // UINTR + + //! \} + + //! \name Core Privileged Instructions + //! \{ + + ASMJIT_INST_0x(clts, Clts) // ANY + ASMJIT_INST_0x(hlt, Hlt) // ANY + ASMJIT_INST_0x(invd, Invd) // ANY + ASMJIT_INST_1x(invlpg, Invlpg, Mem) // ANY + ASMJIT_INST_2x(invpcid, Invpcid, Gp, Mem) // ANY + ASMJIT_INST_1x(lgdt, Lgdt, Mem) // ANY + ASMJIT_INST_1x(lidt, Lidt, Mem) // ANY + ASMJIT_INST_1x(lldt, Lldt, Gp) // ANY + ASMJIT_INST_1x(lldt, Lldt, Mem) // ANY + ASMJIT_INST_1x(lmsw, Lmsw, Gp) // ANY + ASMJIT_INST_1x(lmsw, Lmsw, Mem) // ANY + ASMJIT_INST_1x(ltr, Ltr, Gp) // ANY + ASMJIT_INST_1x(ltr, Ltr, Mem) // ANY + ASMJIT_INST_3x(rdmsr, Rdmsr, Gp_EDX, Gp_EAX, Gp_ECX) // MSR [EXPLICIT] RDX:EAX <- MSR[ECX] + ASMJIT_INST_3x(rdpmc, Rdpmc, Gp_EDX, Gp_EAX, Gp_ECX) // ANY [EXPLICIT] RDX:EAX <- PMC[ECX] + ASMJIT_INST_0x(swapgs, Swapgs) // X64 + ASMJIT_INST_0x(wbinvd, Wbinvd) // ANY + ASMJIT_INST_0x(wbnoinvd, Wbnoinvd) // WBNOINVD + ASMJIT_INST_3x(wrmsr, Wrmsr, Gp_EDX, Gp_EAX, Gp_ECX) // MSR [EXPLICIT] RDX:EAX -> MSR[ECX] + ASMJIT_INST_3x(xsetbv, Xsetbv, Gp_EDX, Gp_EAX, Gp_ECX) // XSAVE [EXPLICIT] XCR[ECX] <- EDX:EAX + + //! \} + + //! \name INVLPGB Instructions + //! \{ + + ASMJIT_INST_3x(invlpgb, Invlpgb, Gp_EAX, Gp_EDX, Gp_ECX) + ASMJIT_INST_0x(tlbsync, Tlbsync) + + //! \} + + //! \name MONITOR Instructions (Privileged) + //! \{ + + ASMJIT_INST_3x(monitor, Monitor, Mem, Gp, Gp) // MONITOR + ASMJIT_INST_2x(mwait, Mwait, Gp, Gp) // MONITOR + + //! \} + + //! \name SMAP Instructions (Privileged) + //! \{ + + ASMJIT_INST_0x(clac, Clac) // SMAP + ASMJIT_INST_0x(stac, Stac) // SMAP + + //! \} + + //! \name SKINIT Instructions (Privileged) + //! \{ + + ASMJIT_INST_1x(skinit, Skinit, Gp) // SKINIT [EXPLICIT] + ASMJIT_INST_0x(stgi, Stgi) // SKINIT + + //! \} + + //! \name SNP Instructions (Privileged) + //! \{ + + ASMJIT_INST_0x(psmash, Psmash) // SNP + ASMJIT_INST_0x(pvalidate, Pvalidate) // SNP + ASMJIT_INST_0x(rmpadjust, Rmpadjust) // SNP + ASMJIT_INST_0x(rmpupdate, Rmpupdate) // SNP + + //! \} + + //! \name VMX Instructions (All privileged except vmfunc) + //! \{ + + ASMJIT_INST_2x(invept, Invept, Gp, Mem) // VMX + ASMJIT_INST_2x(invvpid, Invvpid, Gp, Mem) // VMX + ASMJIT_INST_0x(vmcall, Vmcall) // VMX + ASMJIT_INST_1x(vmclear, Vmclear, Mem) // VMX + ASMJIT_INST_0x(vmfunc, Vmfunc) // VMX + ASMJIT_INST_0x(vmlaunch, Vmlaunch) // VMX + ASMJIT_INST_1x(vmptrld, Vmptrld, Mem) // VMX + ASMJIT_INST_1x(vmptrst, Vmptrst, Mem) // VMX + ASMJIT_INST_2x(vmread, Vmread, Gp, Gp) // VMX + ASMJIT_INST_2x(vmread, Vmread, Mem, Gp) // VMX + ASMJIT_INST_0x(vmresume, Vmresume) // VMX + ASMJIT_INST_2x(vmwrite, Vmwrite, Gp, Mem) // VMX + ASMJIT_INST_2x(vmwrite, Vmwrite, Gp, Gp) // VMX + ASMJIT_INST_0x(vmxoff, Vmxoff) // VMX + ASMJIT_INST_1x(vmxon, Vmxon, Mem) // VMX + + //! \} + + //! \name SVM Instructions (All privileged except vmmcall) + //! \{ + + ASMJIT_INST_0x(clgi, Clgi) // SVM + ASMJIT_INST_2x(invlpga, Invlpga, Gp, Gp) // SVM [EXPLICIT] + ASMJIT_INST_1x(vmload, Vmload, Gp) // SVM [EXPLICIT] + ASMJIT_INST_0x(vmmcall, Vmmcall) // SVM + ASMJIT_INST_1x(vmrun, Vmrun, Gp) // SVM [EXPLICIT] + ASMJIT_INST_1x(vmsave, Vmsave, Gp) // SVM [EXPLICIT] + + //! \} + + //! \name SEV_ES Instructions + //! \{ + + ASMJIT_INST_0x(vmgexit, Vmgexit) + + //! \} + + //! \name FPU Instructions + //! \{ + + ASMJIT_INST_0x(f2xm1, F2xm1) // FPU + ASMJIT_INST_0x(fabs, Fabs) // FPU + ASMJIT_INST_2x(fadd, Fadd, St, St) // FPU + ASMJIT_INST_1x(fadd, Fadd, Mem) // FPU + ASMJIT_INST_1x(faddp, Faddp, St) // FPU + ASMJIT_INST_0x(faddp, Faddp) // FPU + ASMJIT_INST_1x(fbld, Fbld, Mem) // FPU + ASMJIT_INST_1x(fbstp, Fbstp, Mem) // FPU + ASMJIT_INST_0x(fchs, Fchs) // FPU + ASMJIT_INST_0x(fclex, Fclex) // FPU + ASMJIT_INST_1x(fcmovb, Fcmovb, St) // FPU + ASMJIT_INST_1x(fcmovbe, Fcmovbe, St) // FPU + ASMJIT_INST_1x(fcmove, Fcmove, St) // FPU + ASMJIT_INST_1x(fcmovnb, Fcmovnb, St) // FPU + ASMJIT_INST_1x(fcmovnbe, Fcmovnbe, St) // FPU + ASMJIT_INST_1x(fcmovne, Fcmovne, St) // FPU + ASMJIT_INST_1x(fcmovnu, Fcmovnu, St) // FPU + ASMJIT_INST_1x(fcmovu, Fcmovu, St) // FPU + ASMJIT_INST_1x(fcom, Fcom, St) // FPU + ASMJIT_INST_0x(fcom, Fcom) // FPU + ASMJIT_INST_1x(fcom, Fcom, Mem) // FPU + ASMJIT_INST_1x(fcomp, Fcomp, St) // FPU + ASMJIT_INST_0x(fcomp, Fcomp) // FPU + ASMJIT_INST_1x(fcomp, Fcomp, Mem) // FPU + ASMJIT_INST_0x(fcompp, Fcompp) // FPU + ASMJIT_INST_1x(fcomi, Fcomi, St) // FPU + ASMJIT_INST_1x(fcomip, Fcomip, St) // FPU + ASMJIT_INST_0x(fcos, Fcos) // FPU + ASMJIT_INST_0x(fdecstp, Fdecstp) // FPU + ASMJIT_INST_2x(fdiv, Fdiv, St, St) // FPU + ASMJIT_INST_1x(fdiv, Fdiv, Mem) // FPU + ASMJIT_INST_1x(fdivp, Fdivp, St) // FPU + ASMJIT_INST_0x(fdivp, Fdivp) // FPU + ASMJIT_INST_2x(fdivr, Fdivr, St, St) // FPU + ASMJIT_INST_1x(fdivr, Fdivr, Mem) // FPU + ASMJIT_INST_1x(fdivrp, Fdivrp, St) // FPU + ASMJIT_INST_0x(fdivrp, Fdivrp) // FPU + ASMJIT_INST_1x(ffree, Ffree, St) // FPU + ASMJIT_INST_1x(fiadd, Fiadd, Mem) // FPU + ASMJIT_INST_1x(ficom, Ficom, Mem) // FPU + ASMJIT_INST_1x(ficomp, Ficomp, Mem) // FPU + ASMJIT_INST_1x(fidiv, Fidiv, Mem) // FPU + ASMJIT_INST_1x(fidivr, Fidivr, Mem) // FPU + ASMJIT_INST_1x(fild, Fild, Mem) // FPU + ASMJIT_INST_1x(fimul, Fimul, Mem) // FPU + ASMJIT_INST_0x(fincstp, Fincstp) // FPU + ASMJIT_INST_0x(finit, Finit) // FPU + ASMJIT_INST_1x(fisub, Fisub, Mem) // FPU + ASMJIT_INST_1x(fisubr, Fisubr, Mem) // FPU + ASMJIT_INST_0x(fninit, Fninit) // FPU + ASMJIT_INST_1x(fist, Fist, Mem) // FPU + ASMJIT_INST_1x(fistp, Fistp, Mem) // FPU + ASMJIT_INST_1x(fisttp, Fisttp, Mem) // FPU+SSE3 + ASMJIT_INST_1x(fld, Fld, Mem) // FPU + ASMJIT_INST_1x(fld, Fld, St) // FPU + ASMJIT_INST_0x(fld1, Fld1) // FPU + ASMJIT_INST_0x(fldl2t, Fldl2t) // FPU + ASMJIT_INST_0x(fldl2e, Fldl2e) // FPU + ASMJIT_INST_0x(fldpi, Fldpi) // FPU + ASMJIT_INST_0x(fldlg2, Fldlg2) // FPU + ASMJIT_INST_0x(fldln2, Fldln2) // FPU + ASMJIT_INST_0x(fldz, Fldz) // FPU + ASMJIT_INST_1x(fldcw, Fldcw, Mem) // FPU + ASMJIT_INST_1x(fldenv, Fldenv, Mem) // FPU + ASMJIT_INST_2x(fmul, Fmul, St, St) // FPU + ASMJIT_INST_1x(fmul, Fmul, Mem) // FPU + ASMJIT_INST_1x(fmulp, Fmulp, St) // FPU + ASMJIT_INST_0x(fmulp, Fmulp) // FPU + ASMJIT_INST_0x(fnclex, Fnclex) // FPU + ASMJIT_INST_0x(fnop, Fnop) // FPU + ASMJIT_INST_1x(fnsave, Fnsave, Mem) // FPU + ASMJIT_INST_1x(fnstenv, Fnstenv, Mem) // FPU + ASMJIT_INST_1x(fnstcw, Fnstcw, Mem) // FPU + ASMJIT_INST_0x(fpatan, Fpatan) // FPU + ASMJIT_INST_0x(fprem, Fprem) // FPU + ASMJIT_INST_0x(fprem1, Fprem1) // FPU + ASMJIT_INST_0x(fptan, Fptan) // FPU + ASMJIT_INST_0x(frndint, Frndint) // FPU + ASMJIT_INST_1x(frstor, Frstor, Mem) // FPU + ASMJIT_INST_1x(fsave, Fsave, Mem) // FPU + ASMJIT_INST_0x(fscale, Fscale) // FPU + ASMJIT_INST_0x(fsin, Fsin) // FPU + ASMJIT_INST_0x(fsincos, Fsincos) // FPU + ASMJIT_INST_0x(fsqrt, Fsqrt) // FPU + ASMJIT_INST_1x(fst, Fst, Mem) // FPU + ASMJIT_INST_1x(fst, Fst, St) // FPU + ASMJIT_INST_1x(fstp, Fstp, Mem) // FPU + ASMJIT_INST_1x(fstp, Fstp, St) // FPU + ASMJIT_INST_1x(fstcw, Fstcw, Mem) // FPU + ASMJIT_INST_1x(fstenv, Fstenv, Mem) // FPU + ASMJIT_INST_2x(fsub, Fsub, St, St) // FPU + ASMJIT_INST_1x(fsub, Fsub, Mem) // FPU + ASMJIT_INST_1x(fsubp, Fsubp, St) // FPU + ASMJIT_INST_0x(fsubp, Fsubp) // FPU + ASMJIT_INST_2x(fsubr, Fsubr, St, St) // FPU + ASMJIT_INST_1x(fsubr, Fsubr, Mem) // FPU + ASMJIT_INST_1x(fsubrp, Fsubrp, St) // FPU + ASMJIT_INST_0x(fsubrp, Fsubrp) // FPU + ASMJIT_INST_0x(ftst, Ftst) // FPU + ASMJIT_INST_1x(fucom, Fucom, St) // FPU + ASMJIT_INST_0x(fucom, Fucom) // FPU + ASMJIT_INST_1x(fucomi, Fucomi, St) // FPU + ASMJIT_INST_1x(fucomip, Fucomip, St) // FPU + ASMJIT_INST_1x(fucomp, Fucomp, St) // FPU + ASMJIT_INST_0x(fucomp, Fucomp) // FPU + ASMJIT_INST_0x(fucompp, Fucompp) // FPU + ASMJIT_INST_0x(fwait, Fwait) // FPU + ASMJIT_INST_0x(fxam, Fxam) // FPU + ASMJIT_INST_1x(fxch, Fxch, St) // FPU + ASMJIT_INST_0x(fxtract, Fxtract) // FPU + ASMJIT_INST_0x(fyl2x, Fyl2x) // FPU + ASMJIT_INST_0x(fyl2xp1, Fyl2xp1) // FPU + ASMJIT_INST_1x(fstsw, Fstsw, Gp) // FPU + ASMJIT_INST_1x(fstsw, Fstsw, Mem) // FPU + ASMJIT_INST_1x(fnstsw, Fnstsw, Gp) // FPU + ASMJIT_INST_1x(fnstsw, Fnstsw, Mem) // FPU + + //! \} + + //! \name MMX & SSE+ Instructions + //! \{ + + ASMJIT_INST_2x(addpd, Addpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(addpd, Addpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(addps, Addps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(addps, Addps, Xmm, Mem) // SSE + ASMJIT_INST_2x(addsd, Addsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(addsd, Addsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(addss, Addss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(addss, Addss, Xmm, Mem) // SSE + ASMJIT_INST_2x(addsubpd, Addsubpd, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(addsubpd, Addsubpd, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(addsubps, Addsubps, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(addsubps, Addsubps, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(andnpd, Andnpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(andnpd, Andnpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(andnps, Andnps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(andnps, Andnps, Xmm, Mem) // SSE + ASMJIT_INST_2x(andpd, Andpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(andpd, Andpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(andps, Andps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(andps, Andps, Xmm, Mem) // SSE + ASMJIT_INST_3x(blendpd, Blendpd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(blendpd, Blendpd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(blendps, Blendps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(blendps, Blendps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(blendvpd, Blendvpd, Xmm, Xmm, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(blendvpd, Blendvpd, Xmm, Mem, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(blendvps, Blendvps, Xmm, Xmm, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(blendvps, Blendvps, Xmm, Mem, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(cmppd, Cmppd, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(cmppd, Cmppd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(cmpps, Cmpps, Xmm, Xmm, Imm) // SSE + ASMJIT_INST_3x(cmpps, Cmpps, Xmm, Mem, Imm) // SSE + ASMJIT_INST_3x(cmpsd, Cmpsd, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(cmpsd, Cmpsd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(cmpss, Cmpss, Xmm, Xmm, Imm) // SSE + ASMJIT_INST_3x(cmpss, Cmpss, Xmm, Mem, Imm) // SSE + ASMJIT_INST_2x(comisd, Comisd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(comisd, Comisd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(comiss, Comiss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(comiss, Comiss, Xmm, Mem) // SSE + ASMJIT_INST_2x(cvtdq2pd, Cvtdq2pd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtdq2pd, Cvtdq2pd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtdq2ps, Cvtdq2ps, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtdq2ps, Cvtdq2ps, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpd2dq, Cvtpd2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtpd2dq, Cvtpd2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpd2pi, Cvtpd2pi, Mm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtpd2pi, Cvtpd2pi, Mm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpd2ps, Cvtpd2ps, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtpd2ps, Cvtpd2ps, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpi2pd, Cvtpi2pd, Xmm, Mm) // SSE2 + ASMJIT_INST_2x(cvtpi2pd, Cvtpi2pd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtpi2ps, Cvtpi2ps, Xmm, Mm) // SSE + ASMJIT_INST_2x(cvtpi2ps, Cvtpi2ps, Xmm, Mem) // SSE + ASMJIT_INST_2x(cvtps2dq, Cvtps2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtps2dq, Cvtps2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtps2pd, Cvtps2pd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtps2pd, Cvtps2pd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtps2pi, Cvtps2pi, Mm, Xmm) // SSE + ASMJIT_INST_2x(cvtps2pi, Cvtps2pi, Mm, Mem) // SSE + ASMJIT_INST_2x(cvtsd2si, Cvtsd2si, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(cvtsd2si, Cvtsd2si, Gp, Mem) // SSE2 + ASMJIT_INST_2x(cvtsd2ss, Cvtsd2ss, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtsd2ss, Cvtsd2ss, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtsi2sd, Cvtsi2sd, Xmm, Gp) // SSE2 + ASMJIT_INST_2x(cvtsi2sd, Cvtsi2sd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtsi2ss, Cvtsi2ss, Xmm, Gp) // SSE + ASMJIT_INST_2x(cvtsi2ss, Cvtsi2ss, Xmm, Mem) // SSE + ASMJIT_INST_2x(cvtss2sd, Cvtss2sd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvtss2sd, Cvtss2sd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvtss2si, Cvtss2si, Gp, Xmm) // SSE + ASMJIT_INST_2x(cvtss2si, Cvtss2si, Gp, Mem) // SSE + ASMJIT_INST_2x(cvttpd2pi, Cvttpd2pi, Mm, Xmm) // SSE2 + ASMJIT_INST_2x(cvttpd2pi, Cvttpd2pi, Mm, Mem) // SSE2 + ASMJIT_INST_2x(cvttpd2dq, Cvttpd2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvttpd2dq, Cvttpd2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvttps2dq, Cvttps2dq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(cvttps2dq, Cvttps2dq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(cvttps2pi, Cvttps2pi, Mm, Xmm) // SSE + ASMJIT_INST_2x(cvttps2pi, Cvttps2pi, Mm, Mem) // SSE + ASMJIT_INST_2x(cvttsd2si, Cvttsd2si, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(cvttsd2si, Cvttsd2si, Gp, Mem) // SSE2 + ASMJIT_INST_2x(cvttss2si, Cvttss2si, Gp, Xmm) // SSE + ASMJIT_INST_2x(cvttss2si, Cvttss2si, Gp, Mem) // SSE + ASMJIT_INST_2x(divpd, Divpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(divpd, Divpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(divps, Divps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(divps, Divps, Xmm, Mem) // SSE + ASMJIT_INST_2x(divsd, Divsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(divsd, Divsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(divss, Divss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(divss, Divss, Xmm, Mem) // SSE + ASMJIT_INST_3x(dppd, Dppd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(dppd, Dppd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(dpps, Dpps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(dpps, Dpps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(extractps, Extractps, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(extractps, Extractps, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_2x(extrq, Extrq, Xmm, Xmm) // SSE4A + ASMJIT_INST_3x(extrq, Extrq, Xmm, Imm, Imm) // SSE4A + ASMJIT_INST_2x(haddpd, Haddpd, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(haddpd, Haddpd, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(haddps, Haddps, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(haddps, Haddps, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(hsubpd, Hsubpd, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(hsubpd, Hsubpd, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(hsubps, Hsubps, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(hsubps, Hsubps, Xmm, Mem) // SSE3 + ASMJIT_INST_3x(insertps, Insertps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(insertps, Insertps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_2x(insertq, Insertq, Xmm, Xmm) // SSE4A + ASMJIT_INST_4x(insertq, Insertq, Xmm, Xmm, Imm, Imm) // SSE4A + ASMJIT_INST_2x(lddqu, Lddqu, Xmm, Mem) // SSE3 + ASMJIT_INST_3x(maskmovq, Maskmovq, Mm, Mm, DS_ZDI) // SSE [EXPLICIT] + ASMJIT_INST_3x(maskmovdqu, Maskmovdqu, Xmm, Xmm, DS_ZDI) // SSE2 [EXPLICIT] + ASMJIT_INST_2x(maxpd, Maxpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(maxpd, Maxpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(maxps, Maxps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(maxps, Maxps, Xmm, Mem) // SSE + ASMJIT_INST_2x(maxsd, Maxsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(maxsd, Maxsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(maxss, Maxss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(maxss, Maxss, Xmm, Mem) // SSE + ASMJIT_INST_2x(minpd, Minpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(minpd, Minpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(minps, Minps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(minps, Minps, Xmm, Mem) // SSE + ASMJIT_INST_2x(minsd, Minsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(minsd, Minsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(minss, Minss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(minss, Minss, Xmm, Mem) // SSE + ASMJIT_INST_2x(movapd, Movapd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movapd, Movapd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movapd, Movapd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movaps, Movaps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movaps, Movaps, Xmm, Mem) // SSE + ASMJIT_INST_2x(movaps, Movaps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movd, Movd, Mem, Mm) // MMX + ASMJIT_INST_2x(movd, Movd, Mem, Xmm) // SSE + ASMJIT_INST_2x(movd, Movd, Gp, Mm) // MMX + ASMJIT_INST_2x(movd, Movd, Gp, Xmm) // SSE + ASMJIT_INST_2x(movd, Movd, Mm, Mem) // MMX + ASMJIT_INST_2x(movd, Movd, Xmm, Mem) // SSE + ASMJIT_INST_2x(movd, Movd, Mm, Gp) // MMX + ASMJIT_INST_2x(movd, Movd, Xmm, Gp) // SSE + ASMJIT_INST_2x(movddup, Movddup, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(movddup, Movddup, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(movdq2q, Movdq2q, Mm, Xmm) // SSE2 + ASMJIT_INST_2x(movdqa, Movdqa, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movdqa, Movdqa, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movdqa, Movdqa, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movdqu, Movdqu, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movdqu, Movdqu, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movdqu, Movdqu, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movhlps, Movhlps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movhpd, Movhpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movhpd, Movhpd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movhps, Movhps, Xmm, Mem) // SSE + ASMJIT_INST_2x(movhps, Movhps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movlhps, Movlhps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movlpd, Movlpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movlpd, Movlpd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movlps, Movlps, Xmm, Mem) // SSE + ASMJIT_INST_2x(movlps, Movlps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movmskps, Movmskps, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(movmskpd, Movmskpd, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(movntdq, Movntdq, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movntdqa, Movntdqa, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(movntpd, Movntpd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movntps, Movntps, Mem, Xmm) // SSE + ASMJIT_INST_2x(movntsd, Movntsd, Mem, Xmm) // SSE4A + ASMJIT_INST_2x(movntss, Movntss, Mem, Xmm) // SSE4A + ASMJIT_INST_2x(movntq, Movntq, Mem, Mm) // SSE + ASMJIT_INST_2x(movq, Movq, Mm, Mm) // MMX + ASMJIT_INST_2x(movq, Movq, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movq, Movq, Mem, Mm) // MMX + ASMJIT_INST_2x(movq, Movq, Mem, Xmm) // SSE + ASMJIT_INST_2x(movq, Movq, Mm, Mem) // MMX + ASMJIT_INST_2x(movq, Movq, Xmm, Mem) // SSE + ASMJIT_INST_2x(movq, Movq, Gp, Mm) // MMX + ASMJIT_INST_2x(movq, Movq, Gp, Xmm) // SSE+X64. + ASMJIT_INST_2x(movq, Movq, Mm, Gp) // MMX + ASMJIT_INST_2x(movq, Movq, Xmm, Gp) // SSE+X64. + ASMJIT_INST_2x(movq2dq, Movq2dq, Xmm, Mm) // SSE2 + ASMJIT_INST_2x(movsd, Movsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movsd, Movsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movsd, Movsd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movshdup, Movshdup, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(movshdup, Movshdup, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(movsldup, Movsldup, Xmm, Xmm) // SSE3 + ASMJIT_INST_2x(movsldup, Movsldup, Xmm, Mem) // SSE3 + ASMJIT_INST_2x(movss, Movss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movss, Movss, Xmm, Mem) // SSE + ASMJIT_INST_2x(movss, Movss, Mem, Xmm) // SSE + ASMJIT_INST_2x(movupd, Movupd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(movupd, Movupd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(movupd, Movupd, Mem, Xmm) // SSE2 + ASMJIT_INST_2x(movups, Movups, Xmm, Xmm) // SSE + ASMJIT_INST_2x(movups, Movups, Xmm, Mem) // SSE + ASMJIT_INST_2x(movups, Movups, Mem, Xmm) // SSE + ASMJIT_INST_3x(mpsadbw, Mpsadbw, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(mpsadbw, Mpsadbw, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_2x(mulpd, Mulpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(mulpd, Mulpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(mulps, Mulps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(mulps, Mulps, Xmm, Mem) // SSE + ASMJIT_INST_2x(mulsd, Mulsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(mulsd, Mulsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(mulss, Mulss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(mulss, Mulss, Xmm, Mem) // SSE + ASMJIT_INST_2x(orpd, Orpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(orpd, Orpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(orps, Orps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(orps, Orps, Xmm, Mem) // SSE + ASMJIT_INST_2x(packssdw, Packssdw, Mm, Mm) // MMX + ASMJIT_INST_2x(packssdw, Packssdw, Mm, Mem) // MMX + ASMJIT_INST_2x(packssdw, Packssdw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(packssdw, Packssdw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(packsswb, Packsswb, Mm, Mm) // MMX + ASMJIT_INST_2x(packsswb, Packsswb, Mm, Mem) // MMX + ASMJIT_INST_2x(packsswb, Packsswb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(packsswb, Packsswb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(packusdw, Packusdw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(packusdw, Packusdw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(packuswb, Packuswb, Mm, Mm) // MMX + ASMJIT_INST_2x(packuswb, Packuswb, Mm, Mem) // MMX + ASMJIT_INST_2x(packuswb, Packuswb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(packuswb, Packuswb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pabsb, Pabsb, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pabsb, Pabsb, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsb, Pabsb, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pabsb, Pabsb, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pabsd, Pabsd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pabsw, Pabsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(paddb, Paddb, Mm, Mm) // MMX + ASMJIT_INST_2x(paddb, Paddb, Mm, Mem) // MMX + ASMJIT_INST_2x(paddb, Paddb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddb, Paddb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddd, Paddd, Mm, Mm) // MMX + ASMJIT_INST_2x(paddd, Paddd, Mm, Mem) // MMX + ASMJIT_INST_2x(paddd, Paddd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddd, Paddd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, Mm, Mm) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, Mm, Mem) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddq, Paddq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddsb, Paddsb, Mm, Mm) // MMX + ASMJIT_INST_2x(paddsb, Paddsb, Mm, Mem) // MMX + ASMJIT_INST_2x(paddsb, Paddsb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddsb, Paddsb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddsw, Paddsw, Mm, Mm) // MMX + ASMJIT_INST_2x(paddsw, Paddsw, Mm, Mem) // MMX + ASMJIT_INST_2x(paddsw, Paddsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddsw, Paddsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddusb, Paddusb, Mm, Mm) // MMX + ASMJIT_INST_2x(paddusb, Paddusb, Mm, Mem) // MMX + ASMJIT_INST_2x(paddusb, Paddusb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddusb, Paddusb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddusw, Paddusw, Mm, Mm) // MMX + ASMJIT_INST_2x(paddusw, Paddusw, Mm, Mem) // MMX + ASMJIT_INST_2x(paddusw, Paddusw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddusw, Paddusw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(paddw, Paddw, Mm, Mm) // MMX + ASMJIT_INST_2x(paddw, Paddw, Mm, Mem) // MMX + ASMJIT_INST_2x(paddw, Paddw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(paddw, Paddw, Xmm, Mem) // SSE2 + ASMJIT_INST_3x(palignr, Palignr, Mm, Mm, Imm) // SSSE3 + ASMJIT_INST_3x(palignr, Palignr, Mm, Mem, Imm) // SSSE3 + ASMJIT_INST_3x(palignr, Palignr, Xmm, Xmm, Imm) // SSSE3 + ASMJIT_INST_3x(palignr, Palignr, Xmm, Mem, Imm) // SSSE3 + ASMJIT_INST_2x(pand, Pand, Mm, Mm) // MMX + ASMJIT_INST_2x(pand, Pand, Mm, Mem) // MMX + ASMJIT_INST_2x(pand, Pand, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pand, Pand, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pandn, Pandn, Mm, Mm) // MMX + ASMJIT_INST_2x(pandn, Pandn, Mm, Mem) // MMX + ASMJIT_INST_2x(pandn, Pandn, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pandn, Pandn, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pavgb, Pavgb, Mm, Mm) // SSE + ASMJIT_INST_2x(pavgb, Pavgb, Mm, Mem) // SSE + ASMJIT_INST_2x(pavgb, Pavgb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pavgb, Pavgb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pavgw, Pavgw, Mm, Mm) // SSE + ASMJIT_INST_2x(pavgw, Pavgw, Mm, Mem) // SSE + ASMJIT_INST_2x(pavgw, Pavgw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pavgw, Pavgw, Xmm, Mem) // SSE2 + ASMJIT_INST_3x(pblendvb, Pblendvb, Xmm, Xmm, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(pblendvb, Pblendvb, Xmm, Mem, XMM0) // SSE4_1 [EXPLICIT] + ASMJIT_INST_3x(pblendw, Pblendw, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pblendw, Pblendw, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pclmulqdq, Pclmulqdq, Xmm, Xmm, Imm) // PCLMULQDQ. + ASMJIT_INST_3x(pclmulqdq, Pclmulqdq, Xmm, Mem, Imm) // PCLMULQDQ. + ASMJIT_INST_6x(pcmpestri, Pcmpestri, Xmm, Xmm, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_6x(pcmpestri, Pcmpestri, Xmm, Mem, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_6x(pcmpestrm, Pcmpestrm, Xmm, Xmm, Imm, XMM0, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_6x(pcmpestrm, Pcmpestrm, Xmm, Mem, Imm, XMM0, Gp_EAX, Gp_EDX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpeqb, Pcmpeqb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpeqd, Pcmpeqd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpeqq, Pcmpeqq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pcmpeqq, Pcmpeqq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpeqw, Pcmpeqw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpgtb, Pcmpgtb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpgtd, Pcmpgtd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pcmpgtq, Pcmpgtq, Xmm, Xmm) // SSE4_2. + ASMJIT_INST_2x(pcmpgtq, Pcmpgtq, Xmm, Mem) // SSE4_2. + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Mm, Mm) // MMX + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Mm, Mem) // MMX + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pcmpgtw, Pcmpgtw, Xmm, Mem) // SSE2 + ASMJIT_INST_4x(pcmpistri, Pcmpistri, Xmm, Xmm, Imm, Gp_ECX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_4x(pcmpistri, Pcmpistri, Xmm, Mem, Imm, Gp_ECX) // SSE4_2 [EXPLICIT] + ASMJIT_INST_4x(pcmpistrm, Pcmpistrm, Xmm, Xmm, Imm, XMM0) // SSE4_2 [EXPLICIT] + ASMJIT_INST_4x(pcmpistrm, Pcmpistrm, Xmm, Mem, Imm, XMM0) // SSE4_2 [EXPLICIT] + ASMJIT_INST_3x(pextrb, Pextrb, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrb, Pextrb, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrd, Pextrd, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrd, Pextrd, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrq, Pextrq, Gp, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrq, Pextrq, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(pextrw, Pextrw, Gp, Mm, Imm) // SSE + ASMJIT_INST_3x(pextrw, Pextrw, Gp, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(pextrw, Pextrw, Mem, Xmm, Imm) // SSE4_1 + ASMJIT_INST_2x(phaddd, Phaddd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phaddd, Phaddd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddd, Phaddd, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phaddd, Phaddd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phaddsw, Phaddsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phaddw, Phaddw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phminposuw, Phminposuw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(phminposuw, Phminposuw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(phsubd, Phsubd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phsubd, Phsubd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubd, Phsubd, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phsubd, Phsubd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phsubsw, Phsubsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(phsubw, Phsubw, Xmm, Mem) // SSSE3 + ASMJIT_INST_3x(pinsrb, Pinsrb, Xmm, Gp, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrb, Pinsrb, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrd, Pinsrd, Xmm, Gp, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrd, Pinsrd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrq, Pinsrq, Xmm, Gp, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrq, Pinsrq, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(pinsrw, Pinsrw, Mm, Gp, Imm) // SSE + ASMJIT_INST_3x(pinsrw, Pinsrw, Mm, Mem, Imm) // SSE + ASMJIT_INST_3x(pinsrw, Pinsrw, Xmm, Gp, Imm) // SSE2 + ASMJIT_INST_3x(pinsrw, Pinsrw, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pmaddubsw, Pmaddubsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Mm, Mm) // MMX + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Mm, Mem) // MMX + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmaddwd, Pmaddwd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmaxsb, Pmaxsb, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxsb, Pmaxsb, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmaxsd, Pmaxsd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxsd, Pmaxsd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Mm, Mm) // SSE + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Mm, Mem) // SSE + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmaxsw, Pmaxsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmaxub, Pmaxub, Mm, Mm) // SSE + ASMJIT_INST_2x(pmaxub, Pmaxub, Mm, Mem) // SSE + ASMJIT_INST_2x(pmaxub, Pmaxub, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmaxub, Pmaxub, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmaxud, Pmaxud, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxud, Pmaxud, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmaxuw, Pmaxuw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmaxuw, Pmaxuw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminsb, Pminsb, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminsb, Pminsb, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminsd, Pminsd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminsd, Pminsd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminsw, Pminsw, Mm, Mm) // SSE + ASMJIT_INST_2x(pminsw, Pminsw, Mm, Mem) // SSE + ASMJIT_INST_2x(pminsw, Pminsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pminsw, Pminsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pminub, Pminub, Mm, Mm) // SSE + ASMJIT_INST_2x(pminub, Pminub, Mm, Mem) // SSE + ASMJIT_INST_2x(pminub, Pminub, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pminub, Pminub, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pminud, Pminud, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminud, Pminud, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pminuw, Pminuw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pminuw, Pminuw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovmskb, Pmovmskb, Gp, Mm) // SSE + ASMJIT_INST_2x(pmovmskb, Pmovmskb, Gp, Xmm) // SSE2 + ASMJIT_INST_2x(pmovsxbd, Pmovsxbd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxbd, Pmovsxbd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxbq, Pmovsxbq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxbq, Pmovsxbq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxbw, Pmovsxbw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxbw, Pmovsxbw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxdq, Pmovsxdq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxdq, Pmovsxdq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxwd, Pmovsxwd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxwd, Pmovsxwd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovsxwq, Pmovsxwq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovsxwq, Pmovsxwq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxbd, Pmovzxbd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxbd, Pmovzxbd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxbq, Pmovzxbq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxbq, Pmovzxbq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxbw, Pmovzxbw, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxbw, Pmovzxbw, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxdq, Pmovzxdq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxdq, Pmovzxdq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxwd, Pmovzxwd, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxwd, Pmovzxwd, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmovzxwq, Pmovzxwq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmovzxwq, Pmovzxwq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmuldq, Pmuldq, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmuldq, Pmuldq, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pmulhrsw, Pmulhrsw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(pmulhw, Pmulhw, Mm, Mm) // MMX + ASMJIT_INST_2x(pmulhw, Pmulhw, Mm, Mem) // MMX + ASMJIT_INST_2x(pmulhw, Pmulhw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmulhw, Pmulhw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Mm, Mm) // SSE + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Mm, Mem) // SSE + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmulhuw, Pmulhuw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmulld, Pmulld, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(pmulld, Pmulld, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(pmullw, Pmullw, Mm, Mm) // MMX + ASMJIT_INST_2x(pmullw, Pmullw, Mm, Mem) // MMX + ASMJIT_INST_2x(pmullw, Pmullw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmullw, Pmullw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Mm, Mm) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Mm, Mem) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pmuludq, Pmuludq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(por, Por, Mm, Mm) // MMX + ASMJIT_INST_2x(por, Por, Mm, Mem) // MMX + ASMJIT_INST_2x(por, Por, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(por, Por, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psadbw, Psadbw, Mm, Mm) // SSE + ASMJIT_INST_2x(psadbw, Psadbw, Mm, Mem) // SSE + ASMJIT_INST_2x(psadbw, Psadbw, Xmm, Xmm) // SSE + ASMJIT_INST_2x(psadbw, Psadbw, Xmm, Mem) // SSE + ASMJIT_INST_2x(pslld, Pslld, Mm, Mm) // MMX + ASMJIT_INST_2x(pslld, Pslld, Mm, Mem) // MMX + ASMJIT_INST_2x(pslld, Pslld, Mm, Imm) // MMX + ASMJIT_INST_2x(pslld, Pslld, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pslld, Pslld, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pslld, Pslld, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(pslldq, Pslldq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psllq, Psllq, Mm, Mm) // MMX + ASMJIT_INST_2x(psllq, Psllq, Mm, Mem) // MMX + ASMJIT_INST_2x(psllq, Psllq, Mm, Imm) // MMX + ASMJIT_INST_2x(psllq, Psllq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psllq, Psllq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psllq, Psllq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psllw, Psllw, Mm, Mm) // MMX + ASMJIT_INST_2x(psllw, Psllw, Mm, Mem) // MMX + ASMJIT_INST_2x(psllw, Psllw, Mm, Imm) // MMX + ASMJIT_INST_2x(psllw, Psllw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psllw, Psllw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psllw, Psllw, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrad, Psrad, Mm, Mm) // MMX + ASMJIT_INST_2x(psrad, Psrad, Mm, Mem) // MMX + ASMJIT_INST_2x(psrad, Psrad, Mm, Imm) // MMX + ASMJIT_INST_2x(psrad, Psrad, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrad, Psrad, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrad, Psrad, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psraw, Psraw, Mm, Mm) // MMX + ASMJIT_INST_2x(psraw, Psraw, Mm, Mem) // MMX + ASMJIT_INST_2x(psraw, Psraw, Mm, Imm) // MMX + ASMJIT_INST_2x(psraw, Psraw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psraw, Psraw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psraw, Psraw, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(pshufb, Pshufb, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(pshufb, Pshufb, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(pshufb, Pshufb, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(pshufb, Pshufb, Xmm, Mem) // SSSE3 + ASMJIT_INST_3x(pshufd, Pshufd, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(pshufd, Pshufd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(pshufhw, Pshufhw, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(pshufhw, Pshufhw, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(pshuflw, Pshuflw, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(pshuflw, Pshuflw, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(pshufw, Pshufw, Mm, Mm, Imm) // SSE + ASMJIT_INST_3x(pshufw, Pshufw, Mm, Mem, Imm) // SSE + ASMJIT_INST_2x(psignb, Psignb, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(psignb, Psignb, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(psignb, Psignb, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(psignb, Psignb, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(psignd, Psignd, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Mm, Mm) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Mm, Mem) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Xmm, Xmm) // SSSE3 + ASMJIT_INST_2x(psignw, Psignw, Xmm, Mem) // SSSE3 + ASMJIT_INST_2x(psrld, Psrld, Mm, Mm) // MMX + ASMJIT_INST_2x(psrld, Psrld, Mm, Mem) // MMX + ASMJIT_INST_2x(psrld, Psrld, Mm, Imm) // MMX + ASMJIT_INST_2x(psrld, Psrld, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrld, Psrld, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrld, Psrld, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrldq, Psrldq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrlq, Psrlq, Mm, Mm) // MMX + ASMJIT_INST_2x(psrlq, Psrlq, Mm, Mem) // MMX + ASMJIT_INST_2x(psrlq, Psrlq, Mm, Imm) // MMX + ASMJIT_INST_2x(psrlq, Psrlq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrlq, Psrlq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrlq, Psrlq, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psrlw, Psrlw, Mm, Mm) // MMX + ASMJIT_INST_2x(psrlw, Psrlw, Mm, Mem) // MMX + ASMJIT_INST_2x(psrlw, Psrlw, Mm, Imm) // MMX + ASMJIT_INST_2x(psrlw, Psrlw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psrlw, Psrlw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psrlw, Psrlw, Xmm, Imm) // SSE2 + ASMJIT_INST_2x(psubb, Psubb, Mm, Mm) // MMX + ASMJIT_INST_2x(psubb, Psubb, Mm, Mem) // MMX + ASMJIT_INST_2x(psubb, Psubb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubb, Psubb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubd, Psubd, Mm, Mm) // MMX + ASMJIT_INST_2x(psubd, Psubd, Mm, Mem) // MMX + ASMJIT_INST_2x(psubd, Psubd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubd, Psubd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Mm, Mm) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Mm, Mem) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubq, Psubq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubsb, Psubsb, Mm, Mm) // MMX + ASMJIT_INST_2x(psubsb, Psubsb, Mm, Mem) // MMX + ASMJIT_INST_2x(psubsb, Psubsb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubsb, Psubsb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubsw, Psubsw, Mm, Mm) // MMX + ASMJIT_INST_2x(psubsw, Psubsw, Mm, Mem) // MMX + ASMJIT_INST_2x(psubsw, Psubsw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubsw, Psubsw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubusb, Psubusb, Mm, Mm) // MMX + ASMJIT_INST_2x(psubusb, Psubusb, Mm, Mem) // MMX + ASMJIT_INST_2x(psubusb, Psubusb, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubusb, Psubusb, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubusw, Psubusw, Mm, Mm) // MMX + ASMJIT_INST_2x(psubusw, Psubusw, Mm, Mem) // MMX + ASMJIT_INST_2x(psubusw, Psubusw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubusw, Psubusw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(psubw, Psubw, Mm, Mm) // MMX + ASMJIT_INST_2x(psubw, Psubw, Mm, Mem) // MMX + ASMJIT_INST_2x(psubw, Psubw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(psubw, Psubw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(ptest, Ptest, Xmm, Xmm) // SSE4_1 + ASMJIT_INST_2x(ptest, Ptest, Xmm, Mem) // SSE4_1 + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckhbw, Punpckhbw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckhdq, Punpckhdq, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckhdq, Punpckhdq, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckhdq, Punpckhdq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckhdq, Punpckhdq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckhqdq, Punpckhqdq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckhqdq, Punpckhqdq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckhwd, Punpckhwd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Mm, Mm) // MMX + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Mm, Mem) // MMX + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpcklbw, Punpcklbw, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpckldq, Punpckldq, Mm, Mm) // MMX + ASMJIT_INST_2x(punpckldq, Punpckldq, Mm, Mem) // MMX + ASMJIT_INST_2x(punpckldq, Punpckldq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpckldq, Punpckldq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpcklqdq, Punpcklqdq, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpcklqdq, Punpcklqdq, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Mm, Mm) // MMX + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Mm, Mem) // MMX + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(punpcklwd, Punpcklwd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(pxor, Pxor, Mm, Mm) // MMX + ASMJIT_INST_2x(pxor, Pxor, Mm, Mem) // MMX + ASMJIT_INST_2x(pxor, Pxor, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(pxor, Pxor, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(rcpps, Rcpps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rcpps, Rcpps, Xmm, Mem) // SSE + ASMJIT_INST_2x(rcpss, Rcpss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rcpss, Rcpss, Xmm, Mem) // SSE + ASMJIT_INST_3x(roundpd, Roundpd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(roundpd, Roundpd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(roundps, Roundps, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(roundps, Roundps, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(roundsd, Roundsd, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(roundsd, Roundsd, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_3x(roundss, Roundss, Xmm, Xmm, Imm) // SSE4_1 + ASMJIT_INST_3x(roundss, Roundss, Xmm, Mem, Imm) // SSE4_1 + ASMJIT_INST_2x(rsqrtps, Rsqrtps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rsqrtps, Rsqrtps, Xmm, Mem) // SSE + ASMJIT_INST_2x(rsqrtss, Rsqrtss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(rsqrtss, Rsqrtss, Xmm, Mem) // SSE + ASMJIT_INST_3x(shufpd, Shufpd, Xmm, Xmm, Imm) // SSE2 + ASMJIT_INST_3x(shufpd, Shufpd, Xmm, Mem, Imm) // SSE2 + ASMJIT_INST_3x(shufps, Shufps, Xmm, Xmm, Imm) // SSE + ASMJIT_INST_3x(shufps, Shufps, Xmm, Mem, Imm) // SSE + ASMJIT_INST_2x(sqrtpd, Sqrtpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(sqrtpd, Sqrtpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(sqrtps, Sqrtps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(sqrtps, Sqrtps, Xmm, Mem) // SSE + ASMJIT_INST_2x(sqrtsd, Sqrtsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(sqrtsd, Sqrtsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(sqrtss, Sqrtss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(sqrtss, Sqrtss, Xmm, Mem) // SSE + ASMJIT_INST_2x(subpd, Subpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(subpd, Subpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(subps, Subps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(subps, Subps, Xmm, Mem) // SSE + ASMJIT_INST_2x(subsd, Subsd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(subsd, Subsd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(subss, Subss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(subss, Subss, Xmm, Mem) // SSE + ASMJIT_INST_2x(ucomisd, Ucomisd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(ucomisd, Ucomisd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(ucomiss, Ucomiss, Xmm, Xmm) // SSE + ASMJIT_INST_2x(ucomiss, Ucomiss, Xmm, Mem) // SSE + ASMJIT_INST_2x(unpckhpd, Unpckhpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(unpckhpd, Unpckhpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(unpckhps, Unpckhps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(unpckhps, Unpckhps, Xmm, Mem) // SSE + ASMJIT_INST_2x(unpcklpd, Unpcklpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(unpcklpd, Unpcklpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(unpcklps, Unpcklps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(unpcklps, Unpcklps, Xmm, Mem) // SSE + ASMJIT_INST_2x(xorpd, Xorpd, Xmm, Xmm) // SSE2 + ASMJIT_INST_2x(xorpd, Xorpd, Xmm, Mem) // SSE2 + ASMJIT_INST_2x(xorps, Xorps, Xmm, Xmm) // SSE + ASMJIT_INST_2x(xorps, Xorps, Xmm, Mem) // SSE + + //! \} + + //! \name 3DNOW and GEODE Instructions (Deprecated) + //! \{ + + ASMJIT_INST_2x(pavgusb, Pavgusb, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pavgusb, Pavgusb, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pf2id, Pf2id, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pf2id, Pf2id, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pf2iw, Pf2iw, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pf2iw, Pf2iw, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfacc, Pfacc, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfacc, Pfacc, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfadd, Pfadd, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfadd, Pfadd, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfcmpeq, Pfcmpeq, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfcmpeq, Pfcmpeq, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfcmpge, Pfcmpge, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfcmpge, Pfcmpge, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfcmpgt, Pfcmpgt, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfcmpgt, Pfcmpgt, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfmax, Pfmax, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfmax, Pfmax, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfmin, Pfmin, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfmin, Pfmin, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfmul, Pfmul, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfmul, Pfmul, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfnacc, Pfnacc, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfnacc, Pfnacc, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfpnacc, Pfpnacc, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfpnacc, Pfpnacc, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcp, Pfrcp, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrcp, Pfrcp, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcpit1, Pfrcpit1, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrcpit1, Pfrcpit1, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcpit2, Pfrcpit2, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrcpit2, Pfrcpit2, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrcpv, Pfrcpv, Mm, Mm) // GEODE + ASMJIT_INST_2x(pfrcpv, Pfrcpv, Mm, Mem) // GEODE + ASMJIT_INST_2x(pfrsqit1, Pfrsqit1, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrsqit1, Pfrsqit1, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrsqrt, Pfrsqrt, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfrsqrt, Pfrsqrt, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfrsqrtv, Pfrsqrtv, Mm, Mm) // GEODE + ASMJIT_INST_2x(pfrsqrtv, Pfrsqrtv, Mm, Mem) // GEODE + ASMJIT_INST_2x(pfsub, Pfsub, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfsub, Pfsub, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pfsubr, Pfsubr, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pfsubr, Pfsubr, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pi2fd, Pi2fd, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pi2fd, Pi2fd, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pi2fw, Pi2fw, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pi2fw, Pi2fw, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pmulhrw, Pmulhrw, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pmulhrw, Pmulhrw, Mm, Mem) // 3DNOW + ASMJIT_INST_2x(pswapd, Pswapd, Mm, Mm) // 3DNOW + ASMJIT_INST_2x(pswapd, Pswapd, Mm, Mem) // 3DNOW + + //! \} + + //! \name EMMS/FEMMS Instructions + //! \{ + + ASMJIT_INST_0x(emms, Emms) // MMX + ASMJIT_INST_0x(femms, Femms) // 3DNOW + + //! \} + + //! \name AESNI Instructions + //! \{ + + ASMJIT_INST_2x(aesdec, Aesdec, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesdec, Aesdec, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesdeclast, Aesdeclast, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesdeclast, Aesdeclast, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesenc, Aesenc, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesenc, Aesenc, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesenclast, Aesenclast, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesenclast, Aesenclast, Xmm, Mem) // AESNI + ASMJIT_INST_2x(aesimc, Aesimc, Xmm, Xmm) // AESNI + ASMJIT_INST_2x(aesimc, Aesimc, Xmm, Mem) // AESNI + ASMJIT_INST_3x(aeskeygenassist, Aeskeygenassist, Xmm, Xmm, Imm) // AESNI + ASMJIT_INST_3x(aeskeygenassist, Aeskeygenassist, Xmm, Mem, Imm) // AESNI + + //! \} + + //! \name SHA Instructions + //! \{ + + ASMJIT_INST_2x(sha1msg1, Sha1msg1, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha1msg1, Sha1msg1, Xmm, Mem) // SHA + ASMJIT_INST_2x(sha1msg2, Sha1msg2, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha1msg2, Sha1msg2, Xmm, Mem) // SHA + ASMJIT_INST_2x(sha1nexte, Sha1nexte, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha1nexte, Sha1nexte, Xmm, Mem) // SHA + ASMJIT_INST_3x(sha1rnds4, Sha1rnds4, Xmm, Xmm, Imm) // SHA + ASMJIT_INST_3x(sha1rnds4, Sha1rnds4, Xmm, Mem, Imm) // SHA + ASMJIT_INST_2x(sha256msg1, Sha256msg1, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha256msg1, Sha256msg1, Xmm, Mem) // SHA + ASMJIT_INST_2x(sha256msg2, Sha256msg2, Xmm, Xmm) // SHA + ASMJIT_INST_2x(sha256msg2, Sha256msg2, Xmm, Mem) // SHA + ASMJIT_INST_3x(sha256rnds2, Sha256rnds2, Xmm, Xmm, XMM0) // SHA [EXPLICIT] + ASMJIT_INST_3x(sha256rnds2, Sha256rnds2, Xmm, Mem, XMM0) // SHA [EXPLICIT] + + //! \} + + //! \name GFNI Instructions + //! \{ + + ASMJIT_INST_3x(gf2p8affineinvqb, Gf2p8affineinvqb, Xmm, Xmm, Imm) // GFNI + ASMJIT_INST_3x(gf2p8affineinvqb, Gf2p8affineinvqb, Xmm, Mem, Imm) // GFNI + ASMJIT_INST_3x(gf2p8affineqb, Gf2p8affineqb, Xmm, Xmm, Imm) // GFNI + ASMJIT_INST_3x(gf2p8affineqb, Gf2p8affineqb, Xmm, Mem, Imm) // GFNI + ASMJIT_INST_2x(gf2p8mulb, Gf2p8mulb, Xmm, Xmm) // GFNI + ASMJIT_INST_2x(gf2p8mulb, Gf2p8mulb, Xmm, Mem) // GFNI + + //! \} + + //! \name AVX, FMA, and AVX512 Instructions + //! \{ + + ASMJIT_INST_3x(kaddb, Kaddb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kaddd, Kaddd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kaddq, Kaddq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kaddw, Kaddw, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kandb, Kandb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kandd, Kandd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandnb, Kandnb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kandnd, Kandnd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandnq, Kandnq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandnw, Kandnw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kandq, Kandq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kandw, Kandw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_2x(kmovb, Kmovb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, KReg, Mem) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, KReg, Gp) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, Mem, KReg) // AVX512_DQ + ASMJIT_INST_2x(kmovb, Kmovb, Gp, KReg) // AVX512_DQ + ASMJIT_INST_2x(kmovd, Kmovd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, KReg, Mem) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, KReg, Gp) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, Mem, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovd, Kmovd, Gp, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, KReg, Mem) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, KReg, Gp) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, Mem, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovq, Kmovq, Gp, KReg) // AVX512_BW + ASMJIT_INST_2x(kmovw, Kmovw, KReg, KReg) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, KReg, Mem) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, KReg, Gp) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, Mem, KReg) // AVX512_F + ASMJIT_INST_2x(kmovw, Kmovw, Gp, KReg) // AVX512_F + ASMJIT_INST_2x(knotb, Knotb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(knotd, Knotd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(knotq, Knotq, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(knotw, Knotw, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(korb, Korb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kord, Kord, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(korq, Korq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kortestb, Kortestb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(kortestd, Kortestd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kortestq, Kortestq, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(kortestw, Kortestw, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(korw, Korw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kshiftlb, Kshiftlb, KReg, KReg, Imm) // AVX512_DQ + ASMJIT_INST_3x(kshiftld, Kshiftld, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftlq, Kshiftlq, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftlw, Kshiftlw, KReg, KReg, Imm) // AVX512_F + ASMJIT_INST_3x(kshiftrb, Kshiftrb, KReg, KReg, Imm) // AVX512_DQ + ASMJIT_INST_3x(kshiftrd, Kshiftrd, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftrq, Kshiftrq, KReg, KReg, Imm) // AVX512_BW + ASMJIT_INST_3x(kshiftrw, Kshiftrw, KReg, KReg, Imm) // AVX512_F + ASMJIT_INST_2x(ktestb, Ktestb, KReg, KReg) // AVX512_DQ + ASMJIT_INST_2x(ktestd, Ktestd, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(ktestq, Ktestq, KReg, KReg) // AVX512_BW + ASMJIT_INST_2x(ktestw, Ktestw, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kunpckbw, Kunpckbw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kunpckdq, Kunpckdq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kunpckwd, Kunpckwd, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxnorb, Kxnorb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kxnord, Kxnord, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxnorq, Kxnorq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxnorw, Kxnorw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_3x(kxorb, Kxorb, KReg, KReg, KReg) // AVX512_DQ + ASMJIT_INST_3x(kxord, Kxord, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxorq, Kxorq, KReg, KReg, KReg) // AVX512_BW + ASMJIT_INST_3x(kxorw, Kxorw, KReg, KReg, KReg) // AVX512_F + ASMJIT_INST_6x(v4fmaddps, V4fmaddps, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(v4fmaddss, V4fmaddss, Xmm, Xmm, Xmm, Xmm, Xmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(v4fnmaddps, V4fnmaddps, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(v4fnmaddss, V4fnmaddss, Xmm, Xmm, Xmm, Xmm, Xmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_3x(vaddpd, Vaddpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vaddpd, Vaddpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vaddps, Vaddps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vaddps, Vaddps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vaddsd, Vaddsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddsd, Vaddsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddss, Vaddss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddss, Vaddss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vaddsubpd, Vaddsubpd, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vaddsubpd, Vaddsubpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vaddsubps, Vaddsubps, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vaddsubps, Vaddsubps, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vaesdec, Vaesdec, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesdec, Vaesdec, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesdeclast, Vaesdeclast, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesdeclast, Vaesdeclast, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenc, Vaesenc, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenc, Vaesenc, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenclast, Vaesenclast, Vec, Vec, Vec) // AVX+AESNI VAES + ASMJIT_INST_3x(vaesenclast, Vaesenclast, Vec, Vec, Mem) // AVX+AESNI VAES + ASMJIT_INST_2x(vaesimc, Vaesimc, Xmm, Xmm) // AVX+AESNI + ASMJIT_INST_2x(vaesimc, Vaesimc, Xmm, Mem) // AVX+AESNI + ASMJIT_INST_3x(vaeskeygenassist, Vaeskeygenassist, Xmm, Xmm, Imm) // AVX+AESNI + ASMJIT_INST_3x(vaeskeygenassist, Vaeskeygenassist, Xmm, Mem, Imm) // AVX+AESNI + ASMJIT_INST_4x(valignd, Valignd, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(valignd, Valignd, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(valignq, Valignq, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(valignq, Valignq, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vandnpd, Vandnpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandnpd, Vandnpd, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandnps, Vandnps, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vandnps, Vandnps, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vandpd, Vandpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandpd, Vandpd, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vandps, Vandps, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vandps, Vandps, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vblendmpd, Vblendmpd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vblendmpd, Vblendmpd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vblendmps, Vblendmps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vblendmps, Vblendmps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vblendpd, Vblendpd, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vblendpd, Vblendpd, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vblendps, Vblendps, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vblendps, Vblendps, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vblendvpd, Vblendvpd, Vec, Vec, Vec, Vec) // AVX + ASMJIT_INST_4x(vblendvpd, Vblendvpd, Vec, Vec, Mem, Vec) // AVX + ASMJIT_INST_4x(vblendvps, Vblendvps, Vec, Vec, Vec, Vec) // AVX + ASMJIT_INST_4x(vblendvps, Vblendvps, Vec, Vec, Mem, Vec) // AVX + ASMJIT_INST_2x(vbroadcastf128, Vbroadcastf128, Vec, Mem) // AVX + ASMJIT_INST_2x(vbroadcastf32x2, Vbroadcastf32x2, Vec, Vec) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf32x2, Vbroadcastf32x2, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf32x4, Vbroadcastf32x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastf32x8, Vbroadcastf32x8, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf64x2, Vbroadcastf64x2, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcastf64x4, Vbroadcastf64x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcasti128, Vbroadcasti128, Vec, Mem) // AVX2 + ASMJIT_INST_2x(vbroadcasti32x2, Vbroadcasti32x2, Vec, Vec) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti32x2, Vbroadcasti32x2, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti32x4, Vbroadcasti32x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcasti32x8, Vbroadcasti32x8, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti64x2, Vbroadcasti64x2, Vec, Vec) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti64x2, Vbroadcasti64x2, Vec, Mem) // AVX512_DQ{kz} + ASMJIT_INST_2x(vbroadcasti64x4, Vbroadcasti64x4, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcasti64x4, Vbroadcasti64x4, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastsd, Vbroadcastsd, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastsd, Vbroadcastsd, Vec, Xmm) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastss, Vbroadcastss, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vbroadcastss, Vbroadcastss, Vec, Xmm) // AVX2 AVX512_F{kz} + ASMJIT_INST_4x(vcmppd, Vcmppd, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vcmppd, Vcmppd, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmppd, Vcmppd, KReg, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vcmppd, Vcmppd, KReg, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vcmpps, Vcmpps, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vcmpps, Vcmpps, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmpps, Vcmpps, KReg, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vcmpps, Vcmpps, KReg, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vcmpsd, Vcmpsd, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vcmpsd, Vcmpsd, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmpsd, Vcmpsd, KReg, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vcmpsd, Vcmpsd, KReg, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vcmpss, Vcmpss, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vcmpss, Vcmpss, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_4x(vcmpss, Vcmpss, KReg, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vcmpss, Vcmpss, KReg, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_2x(vcomisd, Vcomisd, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcomisd, Vcomisd, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcomiss, Vcomiss, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcomiss, Vcomiss, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcompresspd, Vcompresspd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcompresspd, Vcompresspd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcompressps, Vcompressps, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcompressps, Vcompressps, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vcvtdq2pd, Vcvtdq2pd, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtdq2pd, Vcvtdq2pd, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtdq2ps, Vcvtdq2ps, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtdq2ps, Vcvtdq2ps, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vcvtne2ps2bf16, Vcvtne2ps2bf16, Vec, Vec, Vec) // AVX512_BF16{kz|b32} + ASMJIT_INST_3x(vcvtne2ps2bf16, Vcvtne2ps2bf16, Vec, Vec, Mem) // AVX512_BF16{kz|b32} + ASMJIT_INST_2x(vcvtneps2bf16, Vcvtneps2bf16, Vec, Vec) // AVX512_BF16{kz|b32} + ASMJIT_INST_2x(vcvtneps2bf16, Vcvtneps2bf16, Vec, Mem) // AVX512_BF16{kz|b32} + ASMJIT_INST_2x(vcvtpd2dq, Vcvtpd2dq, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2dq, Vcvtpd2dq, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2ps, Vcvtpd2ps, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2ps, Vcvtpd2ps, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2qq, Vcvtpd2qq, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtpd2qq, Vcvtpd2qq, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtpd2udq, Vcvtpd2udq, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2udq, Vcvtpd2udq, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvtpd2uqq, Vcvtpd2uqq, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtpd2uqq, Vcvtpd2uqq, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtph2ps, Vcvtph2ps, Vec, Vec) // F16C AVX512_F{kz} + ASMJIT_INST_2x(vcvtph2ps, Vcvtph2ps, Vec, Mem) // F16C AVX512_F{kz} + ASMJIT_INST_2x(vcvtps2dq, Vcvtps2dq, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2dq, Vcvtps2dq, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2pd, Vcvtps2pd, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2pd, Vcvtps2pd, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vcvtps2ph, Vcvtps2ph, Vec, Vec, Imm) // F16C AVX512_F{kz} + ASMJIT_INST_3x(vcvtps2ph, Vcvtps2ph, Mem, Vec, Imm) // F16C AVX512_F{kz} + ASMJIT_INST_2x(vcvtps2qq, Vcvtps2qq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtps2qq, Vcvtps2qq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtps2udq, Vcvtps2udq, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2udq, Vcvtps2udq, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtps2uqq, Vcvtps2uqq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtps2uqq, Vcvtps2uqq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvtqq2pd, Vcvtqq2pd, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtqq2pd, Vcvtqq2pd, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtqq2ps, Vcvtqq2ps, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtqq2ps, Vcvtqq2ps, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtsd2si, Vcvtsd2si, Gp, Xmm) // AVX AVX512_F{er} + ASMJIT_INST_2x(vcvtsd2si, Vcvtsd2si, Gp, Mem) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsd2ss, Vcvtsd2ss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vcvtsd2ss, Vcvtsd2ss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_2x(vcvtsd2usi, Vcvtsd2usi, Gp, Xmm) // AVX512_F{er} + ASMJIT_INST_2x(vcvtsd2usi, Vcvtsd2usi, Gp, Mem) // AVX512_F{er} + ASMJIT_INST_3x(vcvtsi2sd, Vcvtsi2sd, Xmm, Xmm, Gp) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsi2sd, Vcvtsi2sd, Xmm, Xmm, Mem) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsi2ss, Vcvtsi2ss, Xmm, Xmm, Gp) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtsi2ss, Vcvtsi2ss, Xmm, Xmm, Mem) // AVX AVX512_F{er} + ASMJIT_INST_3x(vcvtss2sd, Vcvtss2sd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vcvtss2sd, Vcvtss2sd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_2x(vcvtss2si, Vcvtss2si, Gp, Xmm) // AVX AVX512_F{er} + ASMJIT_INST_2x(vcvtss2si, Vcvtss2si, Gp, Mem) // AVX AVX512_F{er} + ASMJIT_INST_2x(vcvtss2usi, Vcvtss2usi, Gp, Xmm) // AVX512_F{er} + ASMJIT_INST_2x(vcvtss2usi, Vcvtss2usi, Gp, Mem) // AVX512_F{er} + ASMJIT_INST_2x(vcvttpd2dq, Vcvttpd2dq, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2dq, Vcvttpd2dq, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2qq, Vcvttpd2qq, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2qq, Vcvttpd2qq, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2udq, Vcvttpd2udq, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2udq, Vcvttpd2udq, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vcvttpd2uqq, Vcvttpd2uqq, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvttpd2uqq, Vcvttpd2uqq, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvttps2dq, Vcvttps2dq, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2dq, Vcvttps2dq, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2qq, Vcvttps2qq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttps2qq, Vcvttps2qq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttps2udq, Vcvttps2udq, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2udq, Vcvttps2udq, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvttps2uqq, Vcvttps2uqq, Vec, Vec) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttps2uqq, Vcvttps2uqq, Vec, Mem) // AVX512_DQ{kz|b32} + ASMJIT_INST_2x(vcvttsd2si, Vcvttsd2si, Gp, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttsd2si, Vcvttsd2si, Gp, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttsd2usi, Vcvttsd2usi, Gp, Xmm) // AVX512_F{sae} + ASMJIT_INST_2x(vcvttsd2usi, Vcvttsd2usi, Gp, Mem) // AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2si, Vcvttss2si, Gp, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2si, Vcvttss2si, Gp, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2usi, Vcvttss2usi, Gp, Xmm) // AVX512_F{sae} + ASMJIT_INST_2x(vcvttss2usi, Vcvttss2usi, Gp, Mem) // AVX512_F{sae} + ASMJIT_INST_2x(vcvtudq2pd, Vcvtudq2pd, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtudq2pd, Vcvtudq2pd, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtudq2ps, Vcvtudq2ps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtudq2ps, Vcvtudq2ps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vcvtuqq2pd, Vcvtuqq2pd, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtuqq2pd, Vcvtuqq2pd, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtuqq2ps, Vcvtuqq2ps, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_2x(vcvtuqq2ps, Vcvtuqq2ps, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vcvtusi2sd, Vcvtusi2sd, Xmm, Xmm, Gp) // AVX512_F{er} + ASMJIT_INST_3x(vcvtusi2sd, Vcvtusi2sd, Xmm, Xmm, Mem) // AVX512_F{er} + ASMJIT_INST_3x(vcvtusi2ss, Vcvtusi2ss, Xmm, Xmm, Gp) // AVX512_F{er} + ASMJIT_INST_3x(vcvtusi2ss, Vcvtusi2ss, Xmm, Xmm, Mem) // AVX512_F{er} + ASMJIT_INST_4x(vdbpsadbw, Vdbpsadbw, Vec, Vec, Vec, Imm) // AVX512_BW{kz} + ASMJIT_INST_4x(vdbpsadbw, Vdbpsadbw, Vec, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vdivpd, Vdivpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vdivpd, Vdivpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vdivps, Vdivps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vdivps, Vdivps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vdivsd, Vdivsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdivsd, Vdivsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdivss, Vdivss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdivss, Vdivss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vdpbf16ps, Vdpbf16ps, Vec, Vec, Vec) // AVX512_BF16{kz|b32} + ASMJIT_INST_3x(vdpbf16ps, Vdpbf16ps, Vec, Vec, Mem) // AVX512_BF16{kz|b32} + ASMJIT_INST_4x(vdppd, Vdppd, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vdppd, Vdppd, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vdpps, Vdpps, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vdpps, Vdpps, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_2x(vexp2pd, Vexp2pd, Vec, Vec) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vexp2pd, Vexp2pd, Vec, Mem) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vexp2ps, Vexp2ps, Vec, Vec) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vexp2ps, Vexp2ps, Vec, Mem) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vexpandpd, Vexpandpd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vexpandpd, Vexpandpd, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vexpandps, Vexpandps, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vexpandps, Vexpandps, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf128, Vextractf128, Vec, Vec, Imm) // AVX + ASMJIT_INST_3x(vextractf128, Vextractf128, Mem, Vec, Imm) // AVX + ASMJIT_INST_3x(vextractf32x4, Vextractf32x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf32x4, Vextractf32x4, Mem, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf32x8, Vextractf32x8, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf32x8, Vextractf32x8, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf64x2, Vextractf64x2, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf64x2, Vextractf64x2, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextractf64x4, Vextractf64x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractf64x4, Vextractf64x4, Mem, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti128, Vextracti128, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_3x(vextracti128, Vextracti128, Mem, Vec, Imm) // AVX2 + ASMJIT_INST_3x(vextracti32x4, Vextracti32x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti32x4, Vextracti32x4, Mem, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti32x8, Vextracti32x8, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti32x8, Vextracti32x8, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti64x2, Vextracti64x2, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti64x2, Vextracti64x2, Mem, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vextracti64x4, Vextracti64x4, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextracti64x4, Vextracti64x4, Mem, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_3x(vextractps, Vextractps, Gp, Xmm, Imm) // AVX AVX512_F + ASMJIT_INST_3x(vextractps, Vextractps, Mem, Xmm, Imm) // AVX AVX512_F + ASMJIT_INST_4x(vfixupimmpd, Vfixupimmpd, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vfixupimmpd, Vfixupimmpd, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vfixupimmps, Vfixupimmps, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vfixupimmps, Vfixupimmps, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vfixupimmsd, Vfixupimmsd, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vfixupimmsd, Vfixupimmsd, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vfixupimmss, Vfixupimmss, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vfixupimmss, Vfixupimmss, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vfmadd132pd, Vfmadd132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd132pd, Vfmadd132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd132ps, Vfmadd132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd132ps, Vfmadd132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd132sd, Vfmadd132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd132sd, Vfmadd132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd132ss, Vfmadd132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd132ss, Vfmadd132ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213pd, Vfmadd213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd213pd, Vfmadd213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd213ps, Vfmadd213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd213ps, Vfmadd213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd213sd, Vfmadd213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213sd, Vfmadd213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213ss, Vfmadd213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd213ss, Vfmadd213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231pd, Vfmadd231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd231pd, Vfmadd231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmadd231ps, Vfmadd231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd231ps, Vfmadd231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmadd231sd, Vfmadd231sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231sd, Vfmadd231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231ss, Vfmadd231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmadd231ss, Vfmadd231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmaddsub132pd, Vfmaddsub132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub132pd, Vfmaddsub132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub132ps, Vfmaddsub132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub132ps, Vfmaddsub132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub213pd, Vfmaddsub213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub213pd, Vfmaddsub213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub213ps, Vfmaddsub213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub213ps, Vfmaddsub213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub231pd, Vfmaddsub231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub231pd, Vfmaddsub231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmaddsub231ps, Vfmaddsub231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmaddsub231ps, Vfmaddsub231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub132pd, Vfmsub132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub132pd, Vfmsub132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub132ps, Vfmsub132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub132ps, Vfmsub132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub132sd, Vfmsub132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub132sd, Vfmsub132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub132ss, Vfmsub132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub132ss, Vfmsub132ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213pd, Vfmsub213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub213pd, Vfmsub213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub213ps, Vfmsub213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub213ps, Vfmsub213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub213sd, Vfmsub213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213sd, Vfmsub213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213ss, Vfmsub213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub213ss, Vfmsub213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231pd, Vfmsub231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub231pd, Vfmsub231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsub231ps, Vfmsub231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub231ps, Vfmsub231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsub231sd, Vfmsub231sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231sd, Vfmsub231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231ss, Vfmsub231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsub231ss, Vfmsub231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfmsubadd132pd, Vfmsubadd132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd132pd, Vfmsubadd132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd132ps, Vfmsubadd132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd132ps, Vfmsubadd132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd213pd, Vfmsubadd213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd213pd, Vfmsubadd213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd213ps, Vfmsubadd213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd213ps, Vfmsubadd213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd231pd, Vfmsubadd231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd231pd, Vfmsubadd231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfmsubadd231ps, Vfmsubadd231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfmsubadd231ps, Vfmsubadd231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd132pd, Vfnmadd132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd132pd, Vfnmadd132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd132ps, Vfnmadd132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd132ps, Vfnmadd132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd132sd, Vfnmadd132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd132sd, Vfnmadd132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd132ss, Vfnmadd132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd132ss, Vfnmadd132ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd213pd, Vfnmadd213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd213pd, Vfnmadd213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd213ps, Vfnmadd213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd213ps, Vfnmadd213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd213sd, Vfnmadd213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd213sd, Vfnmadd213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd213ss, Vfnmadd213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd213ss, Vfnmadd213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231pd, Vfnmadd231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd231pd, Vfnmadd231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmadd231ps, Vfnmadd231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd231ps, Vfnmadd231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmadd231sd, Vfnmadd231sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231sd, Vfnmadd231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231ss, Vfnmadd231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmadd231ss, Vfnmadd231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132pd, Vfnmsub132pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub132pd, Vfnmsub132pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub132ps, Vfnmsub132ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub132ps, Vfnmsub132ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub132sd, Vfnmsub132sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132sd, Vfnmsub132sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132ss, Vfnmsub132ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub132ss, Vfnmsub132ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213pd, Vfnmsub213pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub213pd, Vfnmsub213pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub213ps, Vfnmsub213ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub213ps, Vfnmsub213ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub213sd, Vfnmsub213sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213sd, Vfnmsub213sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213ss, Vfnmsub213ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub213ss, Vfnmsub213ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231pd, Vfnmsub231pd, Vec, Vec, Vec) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub231pd, Vfnmsub231pd, Vec, Vec, Mem) // FMA AVX512_F{kz|b64} + ASMJIT_INST_3x(vfnmsub231ps, Vfnmsub231ps, Vec, Vec, Vec) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub231ps, Vfnmsub231ps, Vec, Vec, Mem) // FMA AVX512_F{kz|b32} + ASMJIT_INST_3x(vfnmsub231sd, Vfnmsub231sd, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231sd, Vfnmsub231sd, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231ss, Vfnmsub231ss, Xmm, Xmm, Xmm) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfnmsub231ss, Vfnmsub231ss, Xmm, Xmm, Mem) // FMA AVX512_F{kz|er} + ASMJIT_INST_3x(vfpclasspd, Vfpclasspd, KReg, Vec, Imm) // AVX512_DQ{k|b64} + ASMJIT_INST_3x(vfpclasspd, Vfpclasspd, KReg, Mem, Imm) // AVX512_DQ{k|b64} + ASMJIT_INST_3x(vfpclassps, Vfpclassps, KReg, Vec, Imm) // AVX512_DQ{k|b32} + ASMJIT_INST_3x(vfpclassps, Vfpclassps, KReg, Mem, Imm) // AVX512_DQ{k|b32} + ASMJIT_INST_3x(vfpclasssd, Vfpclasssd, KReg, Xmm, Imm) // AVX512_DQ{k} + ASMJIT_INST_3x(vfpclasssd, Vfpclasssd, KReg, Mem, Imm) // AVX512_DQ{k} + ASMJIT_INST_3x(vfpclassss, Vfpclassss, KReg, Xmm, Imm) // AVX512_DQ{k} + ASMJIT_INST_3x(vfpclassss, Vfpclassss, KReg, Mem, Imm) // AVX512_DQ{k} + ASMJIT_INST_2x(vgatherdpd, Vgatherdpd, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherdpd, Vgatherdpd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vgatherdps, Vgatherdps, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherdps, Vgatherdps, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_1x(vgatherpf0dpd, Vgatherpf0dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf0dps, Vgatherpf0dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf0qpd, Vgatherpf0qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf0qps, Vgatherpf0qps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1dpd, Vgatherpf1dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1dps, Vgatherpf1dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1qpd, Vgatherpf1qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vgatherpf1qps, Vgatherpf1qps, Mem) // AVX512_PF{k} + ASMJIT_INST_2x(vgatherqpd, Vgatherqpd, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherqpd, Vgatherqpd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vgatherqps, Vgatherqps, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vgatherqps, Vgatherqps, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vgetexppd, Vgetexppd, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vgetexppd, Vgetexppd, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vgetexpps, Vgetexpps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vgetexpps, Vgetexpps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vgetexpsd, Vgetexpsd, Xmm, Xmm, Xmm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetexpsd, Vgetexpsd, Xmm, Xmm, Mem) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetexpss, Vgetexpss, Xmm, Xmm, Xmm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetexpss, Vgetexpss, Xmm, Xmm, Mem) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vgetmantpd, Vgetmantpd, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vgetmantpd, Vgetmantpd, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vgetmantps, Vgetmantps, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vgetmantps, Vgetmantps, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vgetmantsd, Vgetmantsd, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vgetmantsd, Vgetmantsd, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vgetmantss, Vgetmantss, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vgetmantss, Vgetmantss, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vgf2p8affineinvqb, Vgf2p8affineinvqb,Vec,Vec,Vec,Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_4x(vgf2p8affineinvqb, Vgf2p8affineinvqb,Vec,Vec,Mem,Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_4x(vgf2p8affineqb, Vgf2p8affineqb, Vec, Vec, Vec, Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_4x(vgf2p8affineqb, Vgf2p8affineqb, Vec, Vec, Mem, Imm) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_3x(vgf2p8mulb, Vgf2p8mulb, Vec, Vec, Vec) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_3x(vgf2p8mulb, Vgf2p8mulb, Vec, Vec, Mem) // AVX AVX512_VL{kz} GFNI + ASMJIT_INST_3x(vhaddpd, Vhaddpd, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhaddpd, Vhaddpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vhaddps, Vhaddps, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhaddps, Vhaddps, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vhsubpd, Vhsubpd, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhsubpd, Vhsubpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vhsubps, Vhsubps, Vec, Vec, Vec) // AVX + ASMJIT_INST_3x(vhsubps, Vhsubps, Vec, Vec, Mem) // AVX + ASMJIT_INST_4x(vinsertf128, Vinsertf128, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vinsertf128, Vinsertf128, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vinsertf32x4, Vinsertf32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertf32x4, Vinsertf32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertf32x8, Vinsertf32x8, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf32x8, Vinsertf32x8, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf64x2, Vinsertf64x2, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf64x2, Vinsertf64x2, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinsertf64x4, Vinsertf64x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertf64x4, Vinsertf64x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti128, Vinserti128, Vec, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_4x(vinserti128, Vinserti128, Vec, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_4x(vinserti32x4, Vinserti32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti32x4, Vinserti32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti32x8, Vinserti32x8, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti32x8, Vinserti32x8, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti64x2, Vinserti64x2, Vec, Vec, Vec, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti64x2, Vinserti64x2, Vec, Vec, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vinserti64x4, Vinserti64x4, Vec, Vec, Vec, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinserti64x4, Vinserti64x4, Vec, Vec, Mem, Imm) // AVX512_F{kz} + ASMJIT_INST_4x(vinsertps, Vinsertps, Xmm, Xmm, Xmm, Imm) // AVX AVX512_F + ASMJIT_INST_4x(vinsertps, Vinsertps, Xmm, Xmm, Mem, Imm) // AVX AVX512_F + ASMJIT_INST_2x(vlddqu, Vlddqu, Vec, Mem) // AVX + ASMJIT_INST_1x(vldmxcsr, Vldmxcsr, Mem) // AVX + ASMJIT_INST_3x(vmaskmovdqu, Vmaskmovdqu, Vec, Vec, DS_ZDI) // AVX [EXPLICIT] + ASMJIT_INST_3x(vmaskmovpd, Vmaskmovpd, Mem, Vec, Vec) // AVX + ASMJIT_INST_3x(vmaskmovpd, Vmaskmovpd, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vmaskmovps, Vmaskmovps, Mem, Vec, Vec) // AVX + ASMJIT_INST_3x(vmaskmovps, Vmaskmovps, Vec, Vec, Mem) // AVX + ASMJIT_INST_3x(vmaxpd, Vmaxpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmaxpd, Vmaxpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmaxps, Vmaxps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmaxps, Vmaxps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmaxsd, Vmaxsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vmaxsd, Vmaxsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vmaxss, Vmaxss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vmaxss, Vmaxss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vminpd, Vminpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vminpd, Vminpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vminps, Vminps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vminps, Vminps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vminsd, Vminsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vminsd, Vminsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vminss, Vminss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|sae} + ASMJIT_INST_3x(vminss, Vminss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|sae} + ASMJIT_INST_2x(vmovapd, Vmovapd, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovapd, Vmovapd, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovapd, Vmovapd, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovaps, Vmovaps, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovaps, Vmovaps, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovaps, Vmovaps, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovd, Vmovd, Gp, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovd, Vmovd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovd, Vmovd, Xmm, Gp) // AVX AVX512_F + ASMJIT_INST_2x(vmovd, Vmovd, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovddup, Vmovddup, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovddup, Vmovddup, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa, Vmovdqa, Vec, Vec) // AVX + ASMJIT_INST_2x(vmovdqa, Vmovdqa, Vec, Mem) // AVX + ASMJIT_INST_2x(vmovdqa, Vmovdqa, Mem, Vec) // AVX + ASMJIT_INST_2x(vmovdqa32, Vmovdqa32, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa32, Vmovdqa32, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa32, Vmovdqa32, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa64, Vmovdqa64, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa64, Vmovdqa64, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqa64, Vmovdqa64, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu, Vmovdqu, Vec, Vec) // AVX + ASMJIT_INST_2x(vmovdqu, Vmovdqu, Vec, Mem) // AVX + ASMJIT_INST_2x(vmovdqu, Vmovdqu, Mem, Vec) // AVX + ASMJIT_INST_2x(vmovdqu16, Vmovdqu16, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu16, Vmovdqu16, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu16, Vmovdqu16, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu32, Vmovdqu32, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu32, Vmovdqu32, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu32, Vmovdqu32, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu64, Vmovdqu64, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu64, Vmovdqu64, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu64, Vmovdqu64, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vmovdqu8, Vmovdqu8, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu8, Vmovdqu8, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_2x(vmovdqu8, Vmovdqu8, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vmovhlps, Vmovhlps, Xmm, Xmm, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovhpd, Vmovhpd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovhpd, Vmovhpd, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovhps, Vmovhps, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovhps, Vmovhps, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_3x(vmovlhps, Vmovlhps, Xmm, Xmm, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovlpd, Vmovlpd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovlpd, Vmovlpd, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovlps, Vmovlps, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_3x(vmovlps, Vmovlps, Xmm, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovmskpd, Vmovmskpd, Gp, Vec) // AVX + ASMJIT_INST_2x(vmovmskps, Vmovmskps, Gp, Vec) // AVX + ASMJIT_INST_2x(vmovntdq, Vmovntdq, Mem, Vec) // AVX+ AVX512_F + ASMJIT_INST_2x(vmovntdqa, Vmovntdqa, Vec, Mem) // AVX+ AVX512_F + ASMJIT_INST_2x(vmovntpd, Vmovntpd, Mem, Vec) // AVX AVX512_F + ASMJIT_INST_2x(vmovntps, Vmovntps, Mem, Vec) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Gp, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Xmm, Mem) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Xmm, Gp) // AVX AVX512_F + ASMJIT_INST_2x(vmovq, Vmovq, Xmm, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovsd, Vmovsd, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovsd, Vmovsd, Xmm, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vmovsd, Vmovsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovshdup, Vmovshdup, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovshdup, Vmovshdup, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovsldup, Vmovsldup, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovsldup, Vmovsldup, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovss, Vmovss, Mem, Xmm) // AVX AVX512_F + ASMJIT_INST_2x(vmovss, Vmovss, Xmm, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vmovss, Vmovss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovupd, Vmovupd, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovupd, Vmovupd, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovupd, Vmovupd, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovups, Vmovups, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovups, Vmovups, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_2x(vmovups, Vmovups, Mem, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_4x(vmpsadbw, Vmpsadbw, Vec, Vec, Vec, Imm) // AVX+ + ASMJIT_INST_4x(vmpsadbw, Vmpsadbw, Vec, Vec, Mem, Imm) // AVX+ + ASMJIT_INST_3x(vmulpd, Vmulpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmulpd, Vmulpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vmulps, Vmulps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmulps, Vmulps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vmulsd, Vmulsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vmulsd, Vmulsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vmulss, Vmulss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vmulss, Vmulss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vorpd, Vorpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vorpd, Vorpd, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vorps, Vorps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vorps, Vorps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_4x(vp2intersectd, Vp2intersectd, KReg, KReg, Vec, Vec) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_4x(vp2intersectd, Vp2intersectd, KReg, KReg, Vec, Mem) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_4x(vp2intersectq, Vp2intersectq, KReg, KReg, Vec, Vec) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_4x(vp2intersectq, Vp2intersectq, KReg, KReg, Vec, Mem) // AVX512_VP2INTERSECT{kz} + ASMJIT_INST_6x(vp4dpwssd, Vp4dpwssd, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_6x(vp4dpwssds, Vp4dpwssds, Zmm, Zmm, Zmm, Zmm, Zmm, Mem) // AVX512_4FMAPS{kz} + ASMJIT_INST_2x(vpabsb, Vpabsb, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpabsb, Vpabsb, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpabsd, Vpabsd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpabsd, Vpabsd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpabsq, Vpabsq, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpabsq, Vpabsq, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vpabsw, Vpabsw, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpabsw, Vpabsw, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpackssdw, Vpackssdw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpackssdw, Vpackssdw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpacksswb, Vpacksswb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpacksswb, Vpacksswb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpackusdw, Vpackusdw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpackusdw, Vpackusdw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz|b32} + ASMJIT_INST_3x(vpackuswb, Vpackuswb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpackuswb, Vpackuswb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddb, Vpaddb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddb, Vpaddb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddd, Vpaddd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpaddd, Vpaddd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpaddq, Vpaddq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpaddq, Vpaddq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpaddsb, Vpaddsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddsb, Vpaddsb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddsw, Vpaddsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddsw, Vpaddsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusb, Vpaddusb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusb, Vpaddusb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusw, Vpaddusw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddusw, Vpaddusw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddw, Vpaddw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpaddw, Vpaddw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_4x(vpalignr, Vpalignr, Vec, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_4x(vpalignr, Vpalignr, Vec, Vec, Mem, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpand, Vpand, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpand, Vpand, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpandd, Vpandd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandd, Vpandd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandn, Vpandn, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpandn, Vpandn, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpandnd, Vpandnd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandnd, Vpandnd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpandnq, Vpandnq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpandnq, Vpandnq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpandq, Vpandq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpandq, Vpandq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpavgb, Vpavgb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpavgb, Vpavgb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpavgw, Vpavgw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpavgw, Vpavgw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_4x(vpblendd, Vpblendd, Vec, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_4x(vpblendd, Vpblendd, Vec, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_3x(vpblendmb, Vpblendmb, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpblendmb, Vpblendmb, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpblendmd, Vpblendmd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpblendmd, Vpblendmd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpblendmq, Vpblendmq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpblendmq, Vpblendmq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpblendmw, Vpblendmw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpblendmw, Vpblendmw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_4x(vpblendvb, Vpblendvb, Vec, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_4x(vpblendvb, Vpblendvb, Vec, Vec, Mem, Vec) // AVX+ + ASMJIT_INST_4x(vpblendw, Vpblendw, Vec, Vec, Vec, Imm) // AVX+ + ASMJIT_INST_4x(vpblendw, Vpblendw, Vec, Vec, Mem, Imm) // AVX+ + ASMJIT_INST_2x(vpbroadcastb, Vpbroadcastb, Vec, Vec) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastb, Vpbroadcastb, Vec, Mem) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastb, Vpbroadcastb, Vec, Gp) // AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastd, Vpbroadcastd, Vec, Vec) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastd, Vpbroadcastd, Vec, Mem) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastd, Vpbroadcastd, Vec, Gp) // AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastmb2q, Vpbroadcastmb2q, Vec, KReg) // AVX512_CD + ASMJIT_INST_2x(vpbroadcastmw2d, Vpbroadcastmw2d, Vec, KReg) // AVX512_CD + ASMJIT_INST_2x(vpbroadcastq, Vpbroadcastq, Vec, Vec) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastq, Vpbroadcastq, Vec, Mem) // AVX2 AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastq, Vpbroadcastq, Vec, Gp) // AVX512_F{kz} + ASMJIT_INST_2x(vpbroadcastw, Vpbroadcastw, Vec, Vec) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastw, Vpbroadcastw, Vec, Mem) // AVX2 AVX512_BW{kz} + ASMJIT_INST_2x(vpbroadcastw, Vpbroadcastw, Vec, Gp) // AVX512_BW{kz} + ASMJIT_INST_4x(vpclmulqdq, Vpclmulqdq, Vec, Vec, Vec, Imm) // AVX VPCLMULQDQ AVX512_F + ASMJIT_INST_4x(vpclmulqdq, Vpclmulqdq, Vec, Vec, Mem, Imm) // AVX VPCLMULQDQ AVX512_F + ASMJIT_INST_4x(vpcmpb, Vpcmpb, KReg, Vec, Vec, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpb, Vpcmpb, KReg, Vec, Mem, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpd, Vpcmpd, KReg, Vec, Vec, Imm) // AVX512_F{k|b32} + ASMJIT_INST_4x(vpcmpd, Vpcmpd, KReg, Vec, Mem, Imm) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpeqb, Vpcmpeqb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpeqd, Vpcmpeqd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpeqq, Vpcmpeqq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpeqw, Vpcmpeqw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_6x(vpcmpestri, Vpcmpestri, Vec, Vec, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_6x(vpcmpestri, Vpcmpestri, Vec, Mem, Imm, Gp_ECX, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_6x(vpcmpestrm, Vpcmpestrm, Vec, Vec, Imm, XMM0, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_6x(vpcmpestrm, Vpcmpestrm, Vec, Mem, Imm, XMM0, Gp_EAX, Gp_EDX) // AVX [EXPLICIT] + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpgtb, Vpcmpgtb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpgtd, Vpcmpgtd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpgtq, Vpcmpgtq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vpcmpgtw, Vpcmpgtw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpistri, Vpcmpistri, Vec, Vec, Imm, Gp_ECX) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpistri, Vpcmpistri, Vec, Mem, Imm, Gp_ECX) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpistrm, Vpcmpistrm, Vec, Vec, Imm, XMM0) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpistrm, Vpcmpistrm, Vec, Mem, Imm, XMM0) // AVX [EXPLICIT] + ASMJIT_INST_4x(vpcmpq, Vpcmpq, KReg, Vec, Vec, Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpq, Vpcmpq, KReg, Vec, Mem, Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpub, Vpcmpub, KReg, Vec, Vec, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpub, Vpcmpub, KReg, Vec, Mem, Imm) // AVX512_BW{k} + ASMJIT_INST_4x(vpcmpud, Vpcmpud, KReg, Vec, Vec, Imm) // AVX512_F{k|b32} + ASMJIT_INST_4x(vpcmpud, Vpcmpud, KReg, Vec, Mem, Imm) // AVX512_F{k|b32} + ASMJIT_INST_4x(vpcmpuq, Vpcmpuq, KReg, Vec, Vec, Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpuq, Vpcmpuq, KReg, Vec, Mem, Imm) // AVX512_F{k|b64} + ASMJIT_INST_4x(vpcmpuw, Vpcmpuw, KReg, Vec, Vec, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_4x(vpcmpuw, Vpcmpuw, KReg, Vec, Mem, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_4x(vpcmpw, Vpcmpw, KReg, Vec, Vec, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_4x(vpcmpw, Vpcmpw, KReg, Vec, Mem, Imm) // AVX512_BW{k|b64} + ASMJIT_INST_2x(vpcompressb, Vpcompressb, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpcompressb, Vpcompressb, Mem, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpcompressd, Vpcompressd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressd, Vpcompressd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressq, Vpcompressq, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressq, Vpcompressq, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpcompressw, Vpcompressw, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpcompressw, Vpcompressw, Mem, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpconflictd, Vpconflictd, Vec, Vec) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vpconflictd, Vpconflictd, Vec, Mem) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vpconflictq, Vpconflictq, Vec, Vec) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vpconflictq, Vpconflictq, Vec, Mem) // AVX512_CD{kz|b32} + ASMJIT_INST_3x(vpdpbusd, Vpdpbusd, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpbusd, Vpdpbusd, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpbusds, Vpdpbusds, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpbusds, Vpdpbusds, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssd, Vpdpwssd, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssd, Vpdpwssd, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssds, Vpdpwssds, Vec, Vec, Vec) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_3x(vpdpwssds, Vpdpwssds, Vec, Vec, Mem) // AVX_VNNI AVX512_VNNI{kz|b32} + ASMJIT_INST_4x(vperm2f128, Vperm2f128, Vec, Vec, Vec, Imm) // AVX + ASMJIT_INST_4x(vperm2f128, Vperm2f128, Vec, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vperm2i128, Vperm2i128, Vec, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_4x(vperm2i128, Vperm2i128, Vec, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_3x(vpermb, Vpermb, Vec, Vec, Vec) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermb, Vpermb, Vec, Vec, Mem) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermd, Vpermd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermd, Vpermd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2b, Vpermi2b, Vec, Vec, Vec) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermi2b, Vpermi2b, Vec, Vec, Mem) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermi2d, Vpermi2d, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2d, Vpermi2d, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2pd, Vpermi2pd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2pd, Vpermi2pd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2ps, Vpermi2ps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2ps, Vpermi2ps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermi2q, Vpermi2q, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2q, Vpermi2q, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermi2w, Vpermi2w, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermi2w, Vpermi2w, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilpd, Vpermilpd, Vec, Mem, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermilps, Vpermilps, Vec, Mem, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Vec, Imm) // AVX2 + ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Mem, Imm) // AVX2 + ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermpd, Vpermpd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermps, Vpermps, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermps, Vpermps, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Vec, Imm) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Mem, Imm) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermq, Vpermq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2b, Vpermt2b, Vec, Vec, Vec) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermt2b, Vpermt2b, Vec, Vec, Mem) // AVX512_VBMI{kz} + ASMJIT_INST_3x(vpermt2d, Vpermt2d, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2d, Vpermt2d, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2pd, Vpermt2pd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2pd, Vpermt2pd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2ps, Vpermt2ps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2ps, Vpermt2ps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpermt2q, Vpermt2q, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2q, Vpermt2q, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpermt2w, Vpermt2w, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermt2w, Vpermt2w, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermw, Vpermw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpermw, Vpermw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_2x(vpexpandb, Vpexpandb, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpexpandb, Vpexpandb, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpexpandd, Vpexpandd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpexpandd, Vpexpandd, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vpexpandq, Vpexpandq, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpexpandq, Vpexpandq, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vpexpandw, Vpexpandw, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_2x(vpexpandw, Vpexpandw, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpextrb, Vpextrb, Gp, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_3x(vpextrb, Vpextrb, Mem, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_3x(vpextrd, Vpextrd, Gp, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrd, Vpextrd, Mem, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrq, Vpextrq, Gp, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrq, Vpextrq, Mem, Xmm, Imm) // AVX AVX512_DQ + ASMJIT_INST_3x(vpextrw, Vpextrw, Gp, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_3x(vpextrw, Vpextrw, Mem, Xmm, Imm) // AVX AVX512_BW + ASMJIT_INST_2x(vpgatherdd, Vpgatherdd, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherdd, Vpgatherdd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vpgatherdq, Vpgatherdq, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherdq, Vpgatherdq, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vpgatherqd, Vpgatherqd, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherqd, Vpgatherqd, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_2x(vpgatherqq, Vpgatherqq, Vec, Mem) // AVX512_F{k} + ASMJIT_INST_3x(vpgatherqq, Vpgatherqq, Vec, Mem, Vec) // AVX2 + ASMJIT_INST_3x(vphaddd, Vphaddd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphaddd, Vphaddd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vphaddsw, Vphaddsw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphaddsw, Vphaddsw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vphaddw, Vphaddw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphaddw, Vphaddw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_2x(vphminposuw, Vphminposuw, Vec, Vec) // AVX + ASMJIT_INST_2x(vphminposuw, Vphminposuw, Vec, Mem) // AVX + ASMJIT_INST_3x(vphsubd, Vphsubd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphsubd, Vphsubd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vphsubsw, Vphsubsw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphsubsw, Vphsubsw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vphsubw, Vphsubw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vphsubw, Vphsubw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_4x(vpinsrb, Vpinsrb, Xmm, Xmm, Gp, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpinsrb, Vpinsrb, Xmm, Xmm, Mem, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpinsrd, Vpinsrd, Xmm, Xmm, Gp, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrd, Vpinsrd, Xmm, Xmm, Mem, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrq, Vpinsrq, Xmm, Xmm, Gp, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrq, Vpinsrq, Xmm, Xmm, Mem, Imm) // AVX AVX512_DQ{kz} + ASMJIT_INST_4x(vpinsrw, Vpinsrw, Xmm, Xmm, Gp, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpinsrw, Vpinsrw, Xmm, Xmm, Mem, Imm) // AVX AVX512_BW{kz} + ASMJIT_INST_2x(vplzcntd, Vplzcntd, Vec, Vec) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vplzcntd, Vplzcntd, Vec, Mem) // AVX512_CD{kz|b32} + ASMJIT_INST_2x(vplzcntq, Vplzcntq, Vec, Vec) // AVX512_CD{kz|b64} + ASMJIT_INST_2x(vplzcntq, Vplzcntq, Vec, Mem) // AVX512_CD{kz|b64} + ASMJIT_INST_3x(vpmadd52huq, Vpmadd52huq, Vec, Vec, Vec) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmadd52huq, Vpmadd52huq, Vec, Vec, Mem) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmadd52luq, Vpmadd52luq, Vec, Vec, Vec) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmadd52luq, Vpmadd52luq, Vec, Vec, Mem) // AVX512_IFMA{kz|b64} + ASMJIT_INST_3x(vpmaddubsw, Vpmaddubsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaddubsw, Vpmaddubsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaddwd, Vpmaddwd, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaddwd, Vpmaddwd, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaskmovd, Vpmaskmovd, Mem, Vec, Vec) // AVX2 + ASMJIT_INST_3x(vpmaskmovd, Vpmaskmovd, Vec, Vec, Mem) // AVX2 + ASMJIT_INST_3x(vpmaskmovq, Vpmaskmovq, Mem, Vec, Vec) // AVX2 + ASMJIT_INST_3x(vpmaskmovq, Vpmaskmovq, Vec, Vec, Mem) // AVX2 + ASMJIT_INST_3x(vpmaxsb, Vpmaxsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxsb, Vpmaxsb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxsd, Vpmaxsd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxsd, Vpmaxsd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxsq, Vpmaxsq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxsq, Vpmaxsq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxsw, Vpmaxsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxsw, Vpmaxsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxub, Vpmaxub, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxub, Vpmaxub, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxud, Vpmaxud, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxud, Vpmaxud, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmaxuq, Vpmaxuq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxuq, Vpmaxuq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmaxuw, Vpmaxuw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmaxuw, Vpmaxuw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsb, Vpminsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsb, Vpminsb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsd, Vpminsd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminsd, Vpminsd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminsq, Vpminsq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminsq, Vpminsq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminsw, Vpminsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminsw, Vpminsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminub, Vpminub, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminub, Vpminub, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminud, Vpminud, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminud, Vpminud, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpminuq, Vpminuq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminuq, Vpminuq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpminuw, Vpminuw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpminuw, Vpminuw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovb2m, Vpmovb2m, KReg, Vec) // AVX512_BW + ASMJIT_INST_2x(vpmovd2m, Vpmovd2m, KReg, Vec) // AVX512_DQ + ASMJIT_INST_2x(vpmovdb, Vpmovdb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovdb, Vpmovdb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovdw, Vpmovdw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovdw, Vpmovdw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovm2b, Vpmovm2b, Vec, KReg) // AVX512_BW + ASMJIT_INST_2x(vpmovm2d, Vpmovm2d, Vec, KReg) // AVX512_DQ + ASMJIT_INST_2x(vpmovm2q, Vpmovm2q, Vec, KReg) // AVX512_DQ + ASMJIT_INST_2x(vpmovm2w, Vpmovm2w, Vec, KReg) // AVX512_BW + ASMJIT_INST_2x(vpmovmskb, Vpmovmskb, Gp, Vec) // AVX+ + ASMJIT_INST_2x(vpmovq2m, Vpmovq2m, KReg, Vec) // AVX512_DQ + ASMJIT_INST_2x(vpmovqb, Vpmovqb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqb, Vpmovqb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqd, Vpmovqd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqd, Vpmovqd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqw, Vpmovqw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovqw, Vpmovqw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdb, Vpmovsdb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdb, Vpmovsdb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdw, Vpmovsdw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsdw, Vpmovsdw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqb, Vpmovsqb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqb, Vpmovsqb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqd, Vpmovsqd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqd, Vpmovsqd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqw, Vpmovsqw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovsqw, Vpmovsqw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovswb, Vpmovswb, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovswb, Vpmovswb, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovsxbd, Vpmovsxbd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbd, Vpmovsxbd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbq, Vpmovsxbq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbq, Vpmovsxbq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxbw, Vpmovsxbw, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovsxbw, Vpmovsxbw, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovsxdq, Vpmovsxdq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxdq, Vpmovsxdq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwd, Vpmovsxwd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwd, Vpmovsxwd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwq, Vpmovsxwq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovsxwq, Vpmovsxwq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdb, Vpmovusdb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdb, Vpmovusdb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdw, Vpmovusdw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusdw, Vpmovusdw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqb, Vpmovusqb, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqb, Vpmovusqb, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqd, Vpmovusqd, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqd, Vpmovusqd, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqw, Vpmovusqw, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovusqw, Vpmovusqw, Mem, Vec) // AVX512_F{kz} + ASMJIT_INST_2x(vpmovuswb, Vpmovuswb, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovuswb, Vpmovuswb, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovw2m, Vpmovw2m, KReg, Vec) // AVX512_BW + ASMJIT_INST_2x(vpmovwb, Vpmovwb, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovwb, Vpmovwb, Mem, Vec) // AVX512_BW{kz} + ASMJIT_INST_2x(vpmovzxbd, Vpmovzxbd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbd, Vpmovzxbd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbq, Vpmovzxbq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbq, Vpmovzxbq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxbw, Vpmovzxbw, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovzxbw, Vpmovzxbw, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_2x(vpmovzxdq, Vpmovzxdq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxdq, Vpmovzxdq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwd, Vpmovzxwd, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwd, Vpmovzxwd, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwq, Vpmovzxwq, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_2x(vpmovzxwq, Vpmovzxwq, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpmuldq, Vpmuldq, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmuldq, Vpmuldq, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmulhrsw, Vpmulhrsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhrsw, Vpmulhrsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhuw, Vpmulhuw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhuw, Vpmulhuw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhw, Vpmulhw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulhw, Vpmulhw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmulld, Vpmulld, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmulld, Vpmulld, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpmullq, Vpmullq, Vec, Vec, Vec) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vpmullq, Vpmullq, Vec, Vec, Mem) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vpmullw, Vpmullw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmullw, Vpmullw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpmultishiftqb, Vpmultishiftqb, Vec, Vec, Vec) // AVX512_VBMI{kz|b64} + ASMJIT_INST_3x(vpmultishiftqb, Vpmultishiftqb, Vec, Vec, Mem) // AVX512_VBMI{kz|b64} + ASMJIT_INST_3x(vpmuludq, Vpmuludq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpmuludq, Vpmuludq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_2x(vpopcntb, Vpopcntb, Vec, Vec) // AVX512_BITALG{kz|b32} + ASMJIT_INST_2x(vpopcntb, Vpopcntb, Vec, Mem) // AVX512_BITALG{kz|b32} + ASMJIT_INST_2x(vpopcntd, Vpopcntd, Vec, Vec) // AVX512_VPOPCNTDQ{kz|b32} + ASMJIT_INST_2x(vpopcntd, Vpopcntd, Vec, Mem) // AVX512_VPOPCNTDQ{kz|b32} + ASMJIT_INST_2x(vpopcntq, Vpopcntq, Vec, Vec) // AVX512_VPOPCNTDQ{kz|b64} + ASMJIT_INST_2x(vpopcntq, Vpopcntq, Vec, Mem) // AVX512_VPOPCNTDQ{kz|b64} + ASMJIT_INST_2x(vpopcntw, Vpopcntw, Vec, Vec) // AVX512_BITALG{kz|b32} + ASMJIT_INST_2x(vpopcntw, Vpopcntw, Vec, Mem) // AVX512_BITALG{kz|b32} + ASMJIT_INST_3x(vpor, Vpor, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpor, Vpor, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpord, Vpord, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpord, Vpord, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vporq, Vporq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vporq, Vporq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprold, Vprold, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprold, Vprold, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprolq, Vprolq, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprolq, Vprolq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprolvd, Vprolvd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprolvd, Vprolvd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprolvq, Vprolvq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprolvq, Vprolvq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprord, Vprord, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprord, Vprord, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprorq, Vprorq, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprorq, Vprorq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprorvd, Vprorvd, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprorvd, Vprorvd, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vprorvq, Vprorvq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vprorvq, Vprorvq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsadbw, Vpsadbw, Vec, Vec, Vec) // AVX+ AVX512_BW + ASMJIT_INST_3x(vpsadbw, Vpsadbw, Vec, Vec, Mem) // AVX+ AVX512_BW + ASMJIT_INST_2x(vpscatterdd, Vpscatterdd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vpscatterdq, Vpscatterdq, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vpscatterqd, Vpscatterqd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vpscatterqq, Vpscatterqq, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_4x(vpshldd, Vpshldd, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldd, Vpshldd, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldq, Vpshldq, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldq, Vpshldq, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvd, Vpshldvd, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvd, Vpshldvd, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvq, Vpshldvq, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvq, Vpshldvq, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvw, Vpshldvw, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshldvw, Vpshldvw, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldw, Vpshldw, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshldw, Vpshldw, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdd, Vpshrdd, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdd, Vpshrdd, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdq, Vpshrdq, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdq, Vpshrdq, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvd, Vpshrdvd, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvd, Vpshrdvd, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvq, Vpshrdvq, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvq, Vpshrdvq, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvw, Vpshrdvw, Vec, Vec, Vec) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshrdvw, Vpshrdvw, Vec, Vec, Mem) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdw, Vpshrdw, Vec, Vec, Vec, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_4x(vpshrdw, Vpshrdw, Vec, Vec, Mem, Imm) // AVX512_VBMI2{kz} + ASMJIT_INST_3x(vpshufb, Vpshufb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshufb, Vpshufb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshufbitqmb, Vpshufbitqmb, KReg, Vec, Vec) // AVX512_BITALG{k} + ASMJIT_INST_3x(vpshufbitqmb, Vpshufbitqmb, KReg, Vec, Mem) // AVX512_BITALG{k} + ASMJIT_INST_3x(vpshufd, Vpshufd, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpshufd, Vpshufd, Vec, Mem, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpshufhw, Vpshufhw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshufhw, Vpshufhw, Vec, Mem, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshuflw, Vpshuflw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpshuflw, Vpshuflw, Vec, Mem, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsignb, Vpsignb, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpsignb, Vpsignb, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpsignd, Vpsignd, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpsignd, Vpsignd, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpsignw, Vpsignw, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpsignw, Vpsignw, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpslld, Vpslld, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpslldq, Vpslldq, Vec, Vec, Imm) // AVX+ AVX512_BW + ASMJIT_INST_3x(vpslldq, Vpslldq, Vec, Mem, Imm) // AVX512_BW + ASMJIT_INST_3x(vpsllq, Vpsllq, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllq, Vpsllq, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsllq, Vpsllq, Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsllq, Vpsllq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllvd, Vpsllvd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsllvd, Vpsllvd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsllvq, Vpsllvq, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllvq, Vpsllvq, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsllvw, Vpsllvw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsllvw, Vpsllvw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsllw, Vpsllw, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrad, Vpsrad, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Vec, Vec) // AVX512_F{kz} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Vec, Mem) // AVX512_F{kz} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsraq, Vpsraq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsravd, Vpsravd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsravd, Vpsravd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsravq, Vpsravq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsravq, Vpsravq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsravw, Vpsravw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsravw, Vpsravw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsraw, Vpsraw, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Vec, Imm) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Vec, Vec) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Vec, Mem) // AVX+ AVX512_F{kz} + ASMJIT_INST_3x(vpsrld, Vpsrld, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrldq, Vpsrldq, Vec, Vec, Imm) // AVX+ AVX512_BW + ASMJIT_INST_3x(vpsrldq, Vpsrldq, Vec, Mem, Imm) // AVX512_BW + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Vec, Vec) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Vec, Mem) // AVX AVX512_F{kz} + ASMJIT_INST_3x(vpsrlq, Vpsrlq, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlvd, Vpsrlvd, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrlvd, Vpsrlvd, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsrlvq, Vpsrlvq, Vec, Vec, Vec) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlvq, Vpsrlvq, Vec, Vec, Mem) // AVX2 AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsrlvw, Vpsrlvw, Vec, Vec, Vec) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlvw, Vpsrlvw, Vec, Vec, Mem) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Vec, Imm) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsrlw, Vpsrlw, Vec, Mem, Imm) // AVX512_BW{kz} + ASMJIT_INST_3x(vpsubb, Vpsubb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubb, Vpsubb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubd, Vpsubd, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsubd, Vpsubd, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpsubq, Vpsubq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsubq, Vpsubq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpsubsb, Vpsubsb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubsb, Vpsubsb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubsw, Vpsubsw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubsw, Vpsubsw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusb, Vpsubusb, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusb, Vpsubusb, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusw, Vpsubusw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubusw, Vpsubusw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpsubw, Vpsubw, Vec, Vec, Vec) // AVX AVX512_BW{kz} + ASMJIT_INST_3x(vpsubw, Vpsubw, Vec, Vec, Mem) // AVX AVX512_BW{kz} + ASMJIT_INST_4x(vpternlogd, Vpternlogd, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vpternlogd, Vpternlogd, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vpternlogq, Vpternlogq, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vpternlogq, Vpternlogq, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vptest, Vptest, Vec, Vec) // AVX + ASMJIT_INST_2x(vptest, Vptest, Vec, Mem) // AVX + ASMJIT_INST_3x(vptestmb, Vptestmb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestmb, Vptestmb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vptestmd, Vptestmd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestmd, Vptestmd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestmq, Vptestmq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestmq, Vptestmq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestmw, Vptestmw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestmw, Vptestmw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmb, Vptestnmb, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmb, Vptestnmb, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmd, Vptestnmd, KReg, Vec, Vec) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestnmd, Vptestnmd, KReg, Vec, Mem) // AVX512_F{k|b32} + ASMJIT_INST_3x(vptestnmq, Vptestnmq, KReg, Vec, Vec) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestnmq, Vptestnmq, KReg, Vec, Mem) // AVX512_F{k|b64} + ASMJIT_INST_3x(vptestnmw, Vptestnmw, KReg, Vec, Vec) // AVX512_BW{k} + ASMJIT_INST_3x(vptestnmw, Vptestnmw, KReg, Vec, Mem) // AVX512_BW{k} + ASMJIT_INST_3x(vpunpckhbw, Vpunpckhbw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckhbw, Vpunpckhbw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckhdq, Vpunpckhdq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpckhdq, Vpunpckhdq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpckhqdq, Vpunpckhqdq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpckhqdq, Vpunpckhqdq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpckhwd, Vpunpckhwd, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckhwd, Vpunpckhwd, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpcklbw, Vpunpcklbw, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpcklbw, Vpunpcklbw, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpckldq, Vpunpckldq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpckldq, Vpunpckldq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b32} + ASMJIT_INST_3x(vpunpcklqdq, Vpunpcklqdq, Vec, Vec, Vec) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpcklqdq, Vpunpcklqdq, Vec, Vec, Mem) // AVX+ AVX512_F{kz|b64} + ASMJIT_INST_3x(vpunpcklwd, Vpunpcklwd, Vec, Vec, Vec) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpunpcklwd, Vpunpcklwd, Vec, Vec, Mem) // AVX+ AVX512_BW{kz} + ASMJIT_INST_3x(vpxor, Vpxor, Vec, Vec, Vec) // AVX+ + ASMJIT_INST_3x(vpxor, Vpxor, Vec, Vec, Mem) // AVX+ + ASMJIT_INST_3x(vpxord, Vpxord, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpxord, Vpxord, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vpxorq, Vpxorq, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vpxorq, Vpxorq, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vrangepd, Vrangepd, Vec, Vec, Vec, Imm) // AVX512_DQ{kz|b64} + ASMJIT_INST_4x(vrangepd, Vrangepd, Vec, Vec, Mem, Imm) // AVX512_DQ{kz|b64} + ASMJIT_INST_4x(vrangeps, Vrangeps, Vec, Vec, Vec, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_4x(vrangeps, Vrangeps, Vec, Vec, Mem, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_4x(vrangesd, Vrangesd, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_4x(vrangesd, Vrangesd, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_4x(vrangess, Vrangess, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_4x(vrangess, Vrangess, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz|sae} + ASMJIT_INST_2x(vrcp14pd, Vrcp14pd, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrcp14pd, Vrcp14pd, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrcp14ps, Vrcp14ps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vrcp14ps, Vrcp14ps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vrcp14sd, Vrcp14sd, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrcp14sd, Vrcp14sd, Xmm, Xmm, Mem) // AVX512_F{kz} + ASMJIT_INST_3x(vrcp14ss, Vrcp14ss, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrcp14ss, Vrcp14ss, Xmm, Xmm, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vrcp28pd, Vrcp28pd, Vec, Vec) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrcp28pd, Vrcp28pd, Vec, Mem) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrcp28ps, Vrcp28ps, Vec, Vec) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vrcp28ps, Vrcp28ps, Vec, Mem) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_3x(vrcp28sd, Vrcp28sd, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrcp28sd, Vrcp28sd, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrcp28ss, Vrcp28ss, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrcp28ss, Vrcp28ss, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_2x(vrcpps, Vrcpps, Vec, Vec) // AVX + ASMJIT_INST_2x(vrcpps, Vrcpps, Vec, Mem) // AVX + ASMJIT_INST_3x(vrcpss, Vrcpss, Xmm, Xmm, Xmm) // AVX + ASMJIT_INST_3x(vrcpss, Vrcpss, Xmm, Xmm, Mem) // AVX + ASMJIT_INST_3x(vreducepd, Vreducepd, Vec, Vec, Imm) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vreducepd, Vreducepd, Vec, Mem, Imm) // AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vreduceps, Vreduceps, Vec, Vec, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vreduceps, Vreduceps, Vec, Mem, Imm) // AVX512_DQ{kz|b32} + ASMJIT_INST_4x(vreducesd, Vreducesd, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vreducesd, Vreducesd, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vreducess, Vreducess, Xmm, Xmm, Xmm, Imm) // AVX512_DQ{kz} + ASMJIT_INST_4x(vreducess, Vreducess, Xmm, Xmm, Mem, Imm) // AVX512_DQ{kz} + ASMJIT_INST_3x(vrndscalepd, Vrndscalepd, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vrndscalepd, Vrndscalepd, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vrndscaleps, Vrndscaleps, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vrndscaleps, Vrndscaleps, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vrndscalesd, Vrndscalesd, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vrndscalesd, Vrndscalesd, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vrndscaless, Vrndscaless, Xmm, Xmm, Xmm, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_4x(vrndscaless, Vrndscaless, Xmm, Xmm, Mem, Imm) // AVX512_F{kz|sae} + ASMJIT_INST_3x(vroundpd, Vroundpd, Vec, Vec, Imm) // AVX + ASMJIT_INST_3x(vroundpd, Vroundpd, Vec, Mem, Imm) // AVX + ASMJIT_INST_3x(vroundps, Vroundps, Vec, Vec, Imm) // AVX + ASMJIT_INST_3x(vroundps, Vroundps, Vec, Mem, Imm) // AVX + ASMJIT_INST_4x(vroundsd, Vroundsd, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vroundsd, Vroundsd, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_4x(vroundss, Vroundss, Xmm, Xmm, Xmm, Imm) // AVX + ASMJIT_INST_4x(vroundss, Vroundss, Xmm, Xmm, Mem, Imm) // AVX + ASMJIT_INST_2x(vrsqrt14pd, Vrsqrt14pd, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrsqrt14pd, Vrsqrt14pd, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_2x(vrsqrt14ps, Vrsqrt14ps, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_2x(vrsqrt14ps, Vrsqrt14ps, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vrsqrt14sd, Vrsqrt14sd, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrsqrt14sd, Vrsqrt14sd, Xmm, Xmm, Mem) // AVX512_F{kz} + ASMJIT_INST_3x(vrsqrt14ss, Vrsqrt14ss, Xmm, Xmm, Xmm) // AVX512_F{kz} + ASMJIT_INST_3x(vrsqrt14ss, Vrsqrt14ss, Xmm, Xmm, Mem) // AVX512_F{kz} + ASMJIT_INST_2x(vrsqrt28pd, Vrsqrt28pd, Vec, Vec) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrsqrt28pd, Vrsqrt28pd, Vec, Mem) // AVX512_ER{kz|sae|b64} + ASMJIT_INST_2x(vrsqrt28ps, Vrsqrt28ps, Vec, Vec) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_2x(vrsqrt28ps, Vrsqrt28ps, Vec, Mem) // AVX512_ER{kz|sae|b32} + ASMJIT_INST_3x(vrsqrt28sd, Vrsqrt28sd, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrsqrt28sd, Vrsqrt28sd, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrsqrt28ss, Vrsqrt28ss, Xmm, Xmm, Xmm) // AVX512_ER{kz|sae} + ASMJIT_INST_3x(vrsqrt28ss, Vrsqrt28ss, Xmm, Xmm, Mem) // AVX512_ER{kz|sae} + ASMJIT_INST_2x(vrsqrtps, Vrsqrtps, Vec, Vec) // AVX + ASMJIT_INST_2x(vrsqrtps, Vrsqrtps, Vec, Mem) // AVX + ASMJIT_INST_3x(vrsqrtss, Vrsqrtss, Xmm, Xmm, Xmm) // AVX + ASMJIT_INST_3x(vrsqrtss, Vrsqrtss, Xmm, Xmm, Mem) // AVX + ASMJIT_INST_3x(vscalefpd, Vscalefpd, Vec, Vec, Vec) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vscalefpd, Vscalefpd, Vec, Vec, Mem) // AVX512_F{kz|b64} + ASMJIT_INST_3x(vscalefps, Vscalefps, Vec, Vec, Vec) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vscalefps, Vscalefps, Vec, Vec, Mem) // AVX512_F{kz|b32} + ASMJIT_INST_3x(vscalefsd, Vscalefsd, Xmm, Xmm, Xmm) // AVX512_F{kz|er} + ASMJIT_INST_3x(vscalefsd, Vscalefsd, Xmm, Xmm, Mem) // AVX512_F{kz|er} + ASMJIT_INST_3x(vscalefss, Vscalefss, Xmm, Xmm, Xmm) // AVX512_F{kz|er} + ASMJIT_INST_3x(vscalefss, Vscalefss, Xmm, Xmm, Mem) // AVX512_F{kz|er} + ASMJIT_INST_2x(vscatterdpd, Vscatterdpd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vscatterdps, Vscatterdps, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_1x(vscatterpf0dpd, Vscatterpf0dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf0dps, Vscatterpf0dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf0qpd, Vscatterpf0qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf0qps, Vscatterpf0qps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf1dpd, Vscatterpf1dpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf1dps, Vscatterpf1dps, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf1qpd, Vscatterpf1qpd, Mem) // AVX512_PF{k} + ASMJIT_INST_1x(vscatterpf1qps, Vscatterpf1qps, Mem) // AVX512_PF{k} + ASMJIT_INST_2x(vscatterqpd, Vscatterqpd, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_2x(vscatterqps, Vscatterqps, Mem, Vec) // AVX512_F{k} + ASMJIT_INST_4x(vshuff32x4, Vshuff32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshuff32x4, Vshuff32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshuff64x2, Vshuff64x2, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshuff64x2, Vshuff64x2, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufi32x4, Vshufi32x4, Vec, Vec, Vec, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufi32x4, Vshufi32x4, Vec, Vec, Mem, Imm) // AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufi64x2, Vshufi64x2, Vec, Vec, Vec, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufi64x2, Vshufi64x2, Vec, Vec, Mem, Imm) // AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufpd, Vshufpd, Vec, Vec, Vec, Imm) // AVX AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufpd, Vshufpd, Vec, Vec, Mem, Imm) // AVX AVX512_F{kz|b32} + ASMJIT_INST_4x(vshufps, Vshufps, Vec, Vec, Vec, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_4x(vshufps, Vshufps, Vec, Vec, Mem, Imm) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vsqrtpd, Vsqrtpd, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vsqrtpd, Vsqrtpd, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_2x(vsqrtps, Vsqrtps, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_2x(vsqrtps, Vsqrtps, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vsqrtsd, Vsqrtsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsqrtsd, Vsqrtsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsqrtss, Vsqrtss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsqrtss, Vsqrtss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_1x(vstmxcsr, Vstmxcsr, Mem) // AVX + ASMJIT_INST_3x(vsubpd, Vsubpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vsubpd, Vsubpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vsubps, Vsubps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vsubps, Vsubps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vsubsd, Vsubsd, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsubsd, Vsubsd, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsubss, Vsubss, Xmm, Xmm, Xmm) // AVX AVX512_F{kz|er} + ASMJIT_INST_3x(vsubss, Vsubss, Xmm, Xmm, Mem) // AVX AVX512_F{kz|er} + ASMJIT_INST_2x(vtestpd, Vtestpd, Vec, Vec) // AVX + ASMJIT_INST_2x(vtestpd, Vtestpd, Vec, Mem) // AVX + ASMJIT_INST_2x(vtestps, Vtestps, Vec, Vec) // AVX + ASMJIT_INST_2x(vtestps, Vtestps, Vec, Mem) // AVX + ASMJIT_INST_2x(vucomisd, Vucomisd, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vucomisd, Vucomisd, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vucomiss, Vucomiss, Xmm, Xmm) // AVX AVX512_F{sae} + ASMJIT_INST_2x(vucomiss, Vucomiss, Xmm, Mem) // AVX AVX512_F{sae} + ASMJIT_INST_3x(vunpckhpd, Vunpckhpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpckhpd, Vunpckhpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpckhps, Vunpckhps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vunpckhps, Vunpckhps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vunpcklpd, Vunpcklpd, Vec, Vec, Vec) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpcklpd, Vunpcklpd, Vec, Vec, Mem) // AVX AVX512_F{kz|b64} + ASMJIT_INST_3x(vunpcklps, Vunpcklps, Vec, Vec, Vec) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vunpcklps, Vunpcklps, Vec, Vec, Mem) // AVX AVX512_F{kz|b32} + ASMJIT_INST_3x(vxorpd, Vxorpd, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vxorpd, Vxorpd, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b64} + ASMJIT_INST_3x(vxorps, Vxorps, Vec, Vec, Vec) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_3x(vxorps, Vxorps, Vec, Vec, Mem) // AVX AVX512_DQ{kz|b32} + ASMJIT_INST_0x(vzeroall, Vzeroall) // AVX + ASMJIT_INST_0x(vzeroupper, Vzeroupper) // AVX + + //! \} + + //! \name FMA4 Instructions + //! \{ + + ASMJIT_INST_4x(vfmaddpd, Vfmaddpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddpd, Vfmaddpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddpd, Vfmaddpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddps, Vfmaddps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddps, Vfmaddps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddps, Vfmaddps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddsd, Vfmaddsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddsd, Vfmaddsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddsd, Vfmaddsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddss, Vfmaddss, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddss, Vfmaddss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmaddss, Vfmaddss, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddsubpd, Vfmaddsubpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubpd, Vfmaddsubpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubpd, Vfmaddsubpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmaddsubps, Vfmaddsubps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubps, Vfmaddsubps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmaddsubps, Vfmaddsubps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubaddpd, Vfmsubaddpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddpd, Vfmsubaddpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddpd, Vfmsubaddpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubaddps, Vfmsubaddps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddps, Vfmsubaddps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubaddps, Vfmsubaddps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubpd, Vfmsubpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubpd, Vfmsubpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubpd, Vfmsubpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubps, Vfmsubps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubps, Vfmsubps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfmsubps, Vfmsubps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubsd, Vfmsubsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubsd, Vfmsubsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubsd, Vfmsubsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfmsubss, Vfmsubss, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubss, Vfmsubss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfmsubss, Vfmsubss, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddpd, Vfnmaddpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddpd, Vfnmaddpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddpd, Vfnmaddpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddps, Vfnmaddps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddps, Vfnmaddps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmaddps, Vfnmaddps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddsd, Vfnmaddsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddsd, Vfnmaddsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddsd, Vfnmaddsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmaddss, Vfnmaddss, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddss, Vfnmaddss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmaddss, Vfnmaddss, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubpd, Vfnmsubpd, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubpd, Vfnmsubpd, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubpd, Vfnmsubpd, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubps, Vfnmsubps, Vec, Vec, Vec, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubps, Vfnmsubps, Vec, Vec, Mem, Vec) // FMA4 + ASMJIT_INST_4x(vfnmsubps, Vfnmsubps, Vec, Vec, Vec, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubsd, Vfnmsubsd, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmsubsd, Vfnmsubsd, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmsubsd, Vfnmsubsd, Xmm, Xmm, Xmm, Mem) // FMA4 + ASMJIT_INST_4x(vfnmsubss, Vfnmsubss, Xmm, Xmm, Xmm, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmsubss, Vfnmsubss, Xmm, Xmm, Mem, Xmm) // FMA4 + ASMJIT_INST_4x(vfnmsubss, Vfnmsubss, Xmm, Xmm, Xmm, Mem) // FMA4 + + //! \} + + //! \name XOP Instructions (Deprecated) + //! \{ + + ASMJIT_INST_2x(vfrczpd, Vfrczpd, Vec, Vec) // XOP + ASMJIT_INST_2x(vfrczpd, Vfrczpd, Vec, Mem) // XOP + ASMJIT_INST_2x(vfrczps, Vfrczps, Vec, Vec) // XOP + ASMJIT_INST_2x(vfrczps, Vfrczps, Vec, Mem) // XOP + ASMJIT_INST_2x(vfrczsd, Vfrczsd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vfrczsd, Vfrczsd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vfrczss, Vfrczss, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vfrczss, Vfrczss, Xmm, Mem) // XOP + ASMJIT_INST_4x(vpcmov, Vpcmov, Vec, Vec, Vec, Vec) // XOP + ASMJIT_INST_4x(vpcmov, Vpcmov, Vec, Vec, Mem, Vec) // XOP + ASMJIT_INST_4x(vpcmov, Vpcmov, Vec, Vec, Vec, Mem) // XOP + ASMJIT_INST_4x(vpcomb, Vpcomb, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomb, Vpcomb, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomd, Vpcomd, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomd, Vpcomd, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomq, Vpcomq, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomq, Vpcomq, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomw, Vpcomw, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomw, Vpcomw, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomub, Vpcomub, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomub, Vpcomub, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomud, Vpcomud, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomud, Vpcomud, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomuq, Vpcomuq, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomuq, Vpcomuq, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_4x(vpcomuw, Vpcomuw, Xmm, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_4x(vpcomuw, Vpcomuw, Xmm, Xmm, Mem, Imm) // XOP + ASMJIT_INST_5x(vpermil2pd, Vpermil2pd, Vec, Vec, Vec, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2pd, Vpermil2pd, Vec, Vec, Mem, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2pd, Vpermil2pd, Vec, Vec, Vec, Mem, Imm) // XOP + ASMJIT_INST_5x(vpermil2ps, Vpermil2ps, Vec, Vec, Vec, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2ps, Vpermil2ps, Vec, Vec, Mem, Vec, Imm) // XOP + ASMJIT_INST_5x(vpermil2ps, Vpermil2ps, Vec, Vec, Vec, Mem, Imm) // XOP + ASMJIT_INST_2x(vphaddbd, Vphaddbd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddbd, Vphaddbd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddbq, Vphaddbq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddbq, Vphaddbq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddbw, Vphaddbw, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddbw, Vphaddbw, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphadddq, Vphadddq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphadddq, Vphadddq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddwd, Vphaddwd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddwd, Vphaddwd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddwq, Vphaddwq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddwq, Vphaddwq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddubd, Vphaddubd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddubd, Vphaddubd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddubq, Vphaddubq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddubq, Vphaddubq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddubw, Vphaddubw, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddubw, Vphaddubw, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphaddudq, Vphaddudq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphaddudq, Vphaddudq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphadduwd, Vphadduwd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphadduwd, Vphadduwd, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphadduwq, Vphadduwq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphadduwq, Vphadduwq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphsubbw, Vphsubbw, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphsubbw, Vphsubbw, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphsubdq, Vphsubdq, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphsubdq, Vphsubdq, Xmm, Mem) // XOP + ASMJIT_INST_2x(vphsubwd, Vphsubwd, Xmm, Xmm) // XOP + ASMJIT_INST_2x(vphsubwd, Vphsubwd, Xmm, Mem) // XOP + ASMJIT_INST_4x(vpmacsdd, Vpmacsdd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdd, Vpmacsdd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdqh, Vpmacsdqh, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdqh, Vpmacsdqh, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdql, Vpmacsdql, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsdql, Vpmacsdql, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacswd, Vpmacswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacswd, Vpmacswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacsww, Vpmacsww, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsww, Vpmacsww, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdd, Vpmacssdd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdd, Vpmacssdd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdqh, Vpmacssdqh, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdqh, Vpmacssdqh, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdql, Vpmacssdql, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssdql, Vpmacssdql, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacsswd, Vpmacsswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacsswd, Vpmacsswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmacssww, Vpmacssww, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmacssww, Vpmacssww, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmadcsswd, Vpmadcsswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmadcsswd, Vpmadcsswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpmadcswd, Vpmadcswd, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpmadcswd, Vpmadcswd, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpperm, Vpperm, Xmm, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_4x(vpperm, Vpperm, Xmm, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_4x(vpperm, Vpperm, Xmm, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotb, Vprotb, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotd, Vprotd, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotq, Vprotq, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Xmm, Imm) // XOP + ASMJIT_INST_3x(vprotw, Vprotw, Xmm, Mem, Imm) // XOP + ASMJIT_INST_3x(vpshab, Vpshab, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshab, Vpshab, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshab, Vpshab, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshad, Vpshad, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshad, Vpshad, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshad, Vpshad, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshaq, Vpshaq, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshaq, Vpshaq, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshaq, Vpshaq, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshaw, Vpshaw, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshaw, Vpshaw, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshaw, Vpshaw, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshlb, Vpshlb, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshlb, Vpshlb, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshlb, Vpshlb, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshld, Vpshld, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshld, Vpshld, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshld, Vpshld, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshlq, Vpshlq, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshlq, Vpshlq, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshlq, Vpshlq, Xmm, Xmm, Mem) // XOP + ASMJIT_INST_3x(vpshlw, Vpshlw, Xmm, Xmm, Xmm) // XOP + ASMJIT_INST_3x(vpshlw, Vpshlw, Xmm, Mem, Xmm) // XOP + ASMJIT_INST_3x(vpshlw, Vpshlw, Xmm, Xmm, Mem) // XOP + + //! \} + + //! \name AVX_NE_CONVERT Instructions + //! \{ + + ASMJIT_INST_2x(vbcstnebf162ps, Vbcstnebf162ps, Vec, Mem) + ASMJIT_INST_2x(vbcstnesh2ps, Vbcstnesh2ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneebf162ps, Vcvtneebf162ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneeph2ps, Vcvtneeph2ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneobf162ps, Vcvtneobf162ps, Vec, Mem) + ASMJIT_INST_2x(vcvtneoph2ps, Vcvtneoph2ps, Vec, Mem) + + //! \} + + //! \name AVX_VNNI_INT8 Instructions + //! \{ + + ASMJIT_INST_3x(vpdpbssd, Vpdpbssd, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbssd, Vpdpbssd, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbssds, Vpdpbssds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbssds, Vpdpbssds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbsud, Vpdpbsud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbsud, Vpdpbsud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbsuds, Vpdpbsuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbsuds, Vpdpbsuds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbuud, Vpdpbuud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbuud, Vpdpbuud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpbuuds, Vpdpbuuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpbuuds, Vpdpbuuds, Vec, Vec, Mem) + + //! \} + + //! \name AVX_VNNI_INT16 Instructions + //! \{ + + ASMJIT_INST_3x(vpdpwsud, Vpdpwsud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwsud, Vpdpwsud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwsuds, Vpdpwsuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwsuds, Vpdpwsuds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwusd, Vpdpwusd, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwusd, Vpdpwusd, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwusds, Vpdpwusds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwusds, Vpdpwusds, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwuud, Vpdpwuud, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwuud, Vpdpwuud, Vec, Vec, Mem) + ASMJIT_INST_3x(vpdpwuuds, Vpdpwuuds, Vec, Vec, Vec) + ASMJIT_INST_3x(vpdpwuuds, Vpdpwuuds, Vec, Vec, Mem) + + //! \} + + //! \name AVX+SHA512 Instructions + //! \{ + ASMJIT_INST_2x(vsha512msg1, Vsha512msg1, Vec, Vec) + ASMJIT_INST_2x(vsha512msg2, Vsha512msg2, Vec, Vec) + ASMJIT_INST_3x(vsha512rnds2, Vsha512rnds2, Vec, Vec, Vec) + //! \} + + //! \name AVX+SM3 Instructions + //! \{ + + ASMJIT_INST_3x(vsm3msg1, Vsm3msg1, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm3msg1, Vsm3msg1, Vec, Vec, Mem) + ASMJIT_INST_3x(vsm3msg2, Vsm3msg2, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm3msg2, Vsm3msg2, Vec, Vec, Mem) + ASMJIT_INST_4x(vsm3rnds2, Vsm3rnds2, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vsm3rnds2, Vsm3rnds2, Vec, Vec, Mem, Imm) + + //! \} + + //! \name AVX+SM4 Instructions + //! \{ + + ASMJIT_INST_3x(vsm4key4, Vsm4key4, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm4key4, Vsm4key4, Vec, Vec, Mem) + ASMJIT_INST_3x(vsm4rnds4, Vsm4rnds4, Vec, Vec, Vec) + ASMJIT_INST_3x(vsm4rnds4, Vsm4rnds4, Vec, Vec, Mem) + + //! \} + + //! \name AVX512_FP16 Instructions + //! \{ + + ASMJIT_INST_3x(vaddph, Vaddph, Vec, Vec, Vec) + ASMJIT_INST_3x(vaddph, Vaddph, Vec, Vec, Mem) + ASMJIT_INST_3x(vaddsh, Vaddsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vaddsh, Vaddsh, Vec, Vec, Mem) + ASMJIT_INST_4x(vcmpph, Vcmpph, KReg, Vec, Vec, Imm) + ASMJIT_INST_4x(vcmpph, Vcmpph, KReg, Vec, Mem, Imm) + ASMJIT_INST_4x(vcmpsh, Vcmpsh, KReg, Vec, Vec, Imm) + ASMJIT_INST_4x(vcmpsh, Vcmpsh, KReg, Vec, Mem, Imm) + ASMJIT_INST_2x(vcomish, Vcomish, Vec, Vec) + ASMJIT_INST_2x(vcomish, Vcomish, Vec, Mem) + ASMJIT_INST_2x(vcvtdq2ph, Vcvtdq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtdq2ph, Vcvtdq2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtpd2ph, Vcvtpd2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtpd2ph, Vcvtpd2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtph2dq, Vcvtph2dq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2dq, Vcvtph2dq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2pd, Vcvtph2pd, Vec, Vec) + ASMJIT_INST_2x(vcvtph2pd, Vcvtph2pd, Vec, Mem) + ASMJIT_INST_2x(vcvtph2psx, Vcvtph2psx, Vec, Vec) + ASMJIT_INST_2x(vcvtph2psx, Vcvtph2psx, Vec, Mem) + ASMJIT_INST_2x(vcvtph2qq, Vcvtph2qq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2qq, Vcvtph2qq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2udq, Vcvtph2udq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2udq, Vcvtph2udq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2uqq, Vcvtph2uqq, Vec, Vec) + ASMJIT_INST_2x(vcvtph2uqq, Vcvtph2uqq, Vec, Mem) + ASMJIT_INST_2x(vcvtph2uw, Vcvtph2uw, Vec, Vec) + ASMJIT_INST_2x(vcvtph2uw, Vcvtph2uw, Vec, Mem) + ASMJIT_INST_2x(vcvtph2w, Vcvtph2w, Vec, Vec) + ASMJIT_INST_2x(vcvtph2w, Vcvtph2w, Vec, Mem) + ASMJIT_INST_2x(vcvtps2phx, Vcvtps2phx, Vec, Vec) + ASMJIT_INST_2x(vcvtps2phx, Vcvtps2phx, Vec, Mem) + ASMJIT_INST_2x(vcvtqq2ph, Vcvtqq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtqq2ph, Vcvtqq2ph, Vec, Mem) + ASMJIT_INST_3x(vcvtsd2sh, Vcvtsd2sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtsd2sh, Vcvtsd2sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vcvtsh2sd, Vcvtsh2sd, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtsh2sd, Vcvtsh2sd, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvtsh2si, Vcvtsh2si, Gp, Vec) + ASMJIT_INST_2x(vcvtsh2si, Vcvtsh2si, Gp, Mem) + ASMJIT_INST_3x(vcvtsh2ss, Vcvtsh2ss, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtsh2ss, Vcvtsh2ss, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvtsh2usi, Vcvtsh2usi, Gp, Vec) + ASMJIT_INST_2x(vcvtsh2usi, Vcvtsh2usi, Gp, Mem) + ASMJIT_INST_3x(vcvtsi2sh, Vcvtsi2sh, Vec, Vec, Gp) + ASMJIT_INST_3x(vcvtsi2sh, Vcvtsi2sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vcvtss2sh, Vcvtss2sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vcvtss2sh, Vcvtss2sh, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvttph2dq, Vcvttph2dq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2dq, Vcvttph2dq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2qq, Vcvttph2qq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2qq, Vcvttph2qq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2udq, Vcvttph2udq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2udq, Vcvttph2udq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2uqq, Vcvttph2uqq, Vec, Vec) + ASMJIT_INST_2x(vcvttph2uqq, Vcvttph2uqq, Vec, Mem) + ASMJIT_INST_2x(vcvttph2uw, Vcvttph2uw, Vec, Vec) + ASMJIT_INST_2x(vcvttph2uw, Vcvttph2uw, Vec, Mem) + ASMJIT_INST_2x(vcvttph2w, Vcvttph2w, Vec, Vec) + ASMJIT_INST_2x(vcvttph2w, Vcvttph2w, Vec, Mem) + ASMJIT_INST_2x(vcvttsh2si, Vcvttsh2si, Gp, Vec) + ASMJIT_INST_2x(vcvttsh2si, Vcvttsh2si, Gp, Mem) + ASMJIT_INST_2x(vcvttsh2usi, Vcvttsh2usi, Gp, Vec) + ASMJIT_INST_2x(vcvttsh2usi, Vcvttsh2usi, Gp, Mem) + ASMJIT_INST_2x(vcvtudq2ph, Vcvtudq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtudq2ph, Vcvtudq2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtuqq2ph, Vcvtuqq2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtuqq2ph, Vcvtuqq2ph, Vec, Mem) + ASMJIT_INST_3x(vcvtusi2sh, Vcvtusi2sh, Vec, Vec, Gp) + ASMJIT_INST_3x(vcvtusi2sh, Vcvtusi2sh, Vec, Vec, Mem) + ASMJIT_INST_2x(vcvtuw2ph, Vcvtuw2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtuw2ph, Vcvtuw2ph, Vec, Mem) + ASMJIT_INST_2x(vcvtw2ph, Vcvtw2ph, Vec, Vec) + ASMJIT_INST_2x(vcvtw2ph, Vcvtw2ph, Vec, Mem) + ASMJIT_INST_3x(vdivph, Vdivph, Vec, Vec, Vec) + ASMJIT_INST_3x(vdivph, Vdivph, Vec, Vec, Mem) + ASMJIT_INST_3x(vdivsh, Vdivsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vdivsh, Vdivsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmaddcph, Vfcmaddcph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmaddcph, Vfcmaddcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmaddcsh, Vfcmaddcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmaddcsh, Vfcmaddcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmulcph, Vfcmulcph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmulcph, Vfcmulcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfcmulcsh, Vfcmulcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfcmulcsh, Vfcmulcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd132ph, Vfmadd132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd132ph, Vfmadd132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd132sh, Vfmadd132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd132sh, Vfmadd132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd213ph, Vfmadd213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd213ph, Vfmadd213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd213sh, Vfmadd213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd213sh, Vfmadd213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd231ph, Vfmadd231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd231ph, Vfmadd231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmadd231sh, Vfmadd231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmadd231sh, Vfmadd231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddcph, Vfmaddcph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddcph, Vfmaddcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddcsh, Vfmaddcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddcsh, Vfmaddcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddsub132ph, Vfmaddsub132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddsub132ph, Vfmaddsub132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddsub213ph, Vfmaddsub213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddsub213ph, Vfmaddsub213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmaddsub231ph, Vfmaddsub231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmaddsub231ph, Vfmaddsub231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub132ph, Vfmsub132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub132ph, Vfmsub132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub132sh, Vfmsub132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub132sh, Vfmsub132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub213ph, Vfmsub213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub213ph, Vfmsub213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub213sh, Vfmsub213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub213sh, Vfmsub213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub231ph, Vfmsub231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub231ph, Vfmsub231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsub231sh, Vfmsub231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsub231sh, Vfmsub231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsubadd132ph, Vfmsubadd132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsubadd132ph, Vfmsubadd132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsubadd213ph, Vfmsubadd213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsubadd213ph, Vfmsubadd213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmsubadd231ph, Vfmsubadd231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmsubadd231ph, Vfmsubadd231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmulcph, Vfmulcph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmulcph, Vfmulcph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfmulcsh, Vfmulcsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfmulcsh, Vfmulcsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd132ph, Vfnmadd132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd132ph, Vfnmadd132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd132sh, Vfnmadd132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd132sh, Vfnmadd132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd213ph, Vfnmadd213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd213ph, Vfnmadd213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd213sh, Vfnmadd213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd213sh, Vfnmadd213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd231ph, Vfnmadd231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd231ph, Vfnmadd231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmadd231sh, Vfnmadd231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmadd231sh, Vfnmadd231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub132ph, Vfnmsub132ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub132ph, Vfnmsub132ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub132sh, Vfnmsub132sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub132sh, Vfnmsub132sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub213ph, Vfnmsub213ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub213ph, Vfnmsub213ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub213sh, Vfnmsub213sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub213sh, Vfnmsub213sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub231ph, Vfnmsub231ph, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub231ph, Vfnmsub231ph, Vec, Vec, Mem) + ASMJIT_INST_3x(vfnmsub231sh, Vfnmsub231sh, Vec, Vec, Vec) + ASMJIT_INST_3x(vfnmsub231sh, Vfnmsub231sh, Vec, Vec, Mem) + ASMJIT_INST_3x(vfpclassph, Vfpclassph, KReg, Vec, Imm) + ASMJIT_INST_3x(vfpclassph, Vfpclassph, KReg, Mem, Imm) + ASMJIT_INST_3x(vfpclasssh, Vfpclasssh, KReg, Vec, Imm) + ASMJIT_INST_3x(vfpclasssh, Vfpclasssh, KReg, Mem, Imm) + ASMJIT_INST_2x(vgetexpph, Vgetexpph, Vec, Vec) + ASMJIT_INST_2x(vgetexpph, Vgetexpph, Vec, Mem) + ASMJIT_INST_3x(vgetexpsh, Vgetexpsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vgetexpsh, Vgetexpsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vgetmantph, Vgetmantph, Vec, Vec, Imm) + ASMJIT_INST_3x(vgetmantph, Vgetmantph, Vec, Mem, Imm) + ASMJIT_INST_4x(vgetmantsh, Vgetmantsh, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vgetmantsh, Vgetmantsh, Vec, Vec, Mem, Imm) + ASMJIT_INST_3x(vmaxph, Vmaxph, Vec, Vec, Vec) + ASMJIT_INST_3x(vmaxph, Vmaxph, Vec, Vec, Mem) + ASMJIT_INST_3x(vmaxsh, Vmaxsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vmaxsh, Vmaxsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vminph, Vminph, Vec, Vec, Vec) + ASMJIT_INST_3x(vminph, Vminph, Vec, Vec, Mem) + ASMJIT_INST_3x(vminsh, Vminsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vminsh, Vminsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vmovsh, Vmovsh, Mem, Xmm) + ASMJIT_INST_2x(vmovsh, Vmovsh, Xmm, Mem) + ASMJIT_INST_3x(vmovsh, Vmovsh, Xmm, Xmm, Xmm) + ASMJIT_INST_2x(vmovw, Vmovw, Gp, Xmm) + ASMJIT_INST_2x(vmovw, Vmovw, Mem, Xmm) + ASMJIT_INST_2x(vmovw, Vmovw, Xmm, Gp) + ASMJIT_INST_2x(vmovw, Vmovw, Xmm, Mem) + ASMJIT_INST_3x(vmulph, Vmulph, Vec, Vec, Vec) + ASMJIT_INST_3x(vmulph, Vmulph, Vec, Vec, Mem) + ASMJIT_INST_3x(vmulsh, Vmulsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vmulsh, Vmulsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vrcpph, Vrcpph, Vec, Vec) + ASMJIT_INST_2x(vrcpph, Vrcpph, Vec, Mem) + ASMJIT_INST_3x(vrcpsh, Vrcpsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vrcpsh, Vrcpsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vreduceph, Vreduceph, Vec, Vec, Imm) + ASMJIT_INST_3x(vreduceph, Vreduceph, Vec, Mem, Imm) + ASMJIT_INST_4x(vreducesh, Vreducesh, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vreducesh, Vreducesh, Vec, Vec, Mem, Imm) + ASMJIT_INST_3x(vrndscaleph, Vrndscaleph, Vec, Vec, Imm) + ASMJIT_INST_3x(vrndscaleph, Vrndscaleph, Vec, Mem, Imm) + ASMJIT_INST_4x(vrndscalesh, Vrndscalesh, Vec, Vec, Vec, Imm) + ASMJIT_INST_4x(vrndscalesh, Vrndscalesh, Vec, Vec, Mem, Imm) + ASMJIT_INST_2x(vrsqrtph, Vrsqrtph, Vec, Vec) + ASMJIT_INST_2x(vrsqrtph, Vrsqrtph, Vec, Mem) + ASMJIT_INST_3x(vrsqrtsh, Vrsqrtsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vrsqrtsh, Vrsqrtsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vscalefph, Vscalefph, Vec, Vec, Vec) + ASMJIT_INST_3x(vscalefph, Vscalefph, Vec, Vec, Mem) + ASMJIT_INST_3x(vscalefsh, Vscalefsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vscalefsh, Vscalefsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vsqrtph, Vsqrtph, Vec, Vec) + ASMJIT_INST_2x(vsqrtph, Vsqrtph, Vec, Mem) + ASMJIT_INST_3x(vsqrtsh, Vsqrtsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vsqrtsh, Vsqrtsh, Vec, Vec, Mem) + ASMJIT_INST_3x(vsubph, Vsubph, Vec, Vec, Vec) + ASMJIT_INST_3x(vsubph, Vsubph, Vec, Vec, Mem) + ASMJIT_INST_3x(vsubsh, Vsubsh, Vec, Vec, Vec) + ASMJIT_INST_3x(vsubsh, Vsubsh, Vec, Vec, Mem) + ASMJIT_INST_2x(vucomish, Vucomish, Vec, Vec) + ASMJIT_INST_2x(vucomish, Vucomish, Vec, Mem) + + //! \} + + //! \name AMX_TILE Instructions + //! \{ + + ASMJIT_INST_1x(ldtilecfg, Ldtilecfg, Mem) + ASMJIT_INST_1x(sttilecfg, Sttilecfg, Mem) + ASMJIT_INST_2x(tileloadd, Tileloadd, Tmm, Mem) + ASMJIT_INST_2x(tileloaddt1, Tileloaddt1, Tmm, Mem) + ASMJIT_INST_0x(tilerelease, Tilerelease) + ASMJIT_INST_2x(tilestored, Tilestored, Mem, Tmm) + ASMJIT_INST_1x(tilezero, Tilezero, Tmm) + + //! \} + + //! \name AMX_BF16 Instructions + //! \{ + + ASMJIT_INST_3x(tdpbf16ps, Tdpbf16ps, Tmm, Tmm, Tmm) + + //! \} + + //! \name AMX_COMPLEX Instructions + //! \{ + + ASMJIT_INST_3x(tcmmimfp16ps, Tcmmimfp16ps, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tcmmrlfp16ps, Tcmmrlfp16ps, Tmm, Tmm, Tmm) + + //! \} + + //! \name AMX_FP16 Instructions + //! \{ + + ASMJIT_INST_3x(tdpfp16ps, Tdpfp16ps, Tmm, Tmm, Tmm) + + //! \} + + //! \name AMX_INT8 Instructions + //! \{ + + ASMJIT_INST_3x(tdpbssd, Tdpbssd, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tdpbsud, Tdpbsud, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tdpbusd, Tdpbusd, Tmm, Tmm, Tmm) + ASMJIT_INST_3x(tdpbuud, Tdpbuud, Tmm, Tmm, Tmm) + + //! \} +}; + +//! Emitter (X86 - implicit). +template +struct EmitterImplicitT : public EmitterExplicitT { + //! \cond + using EmitterExplicitT::_emitter; + //! \endcond + + //! \name Prefix Options + //! \{ + + //! Use REP/REPE prefix. + inline This& rep() noexcept { return EmitterExplicitT::_addInstOptions(InstOptions::kX86_Rep); } + //! Use REP/REPE prefix. + inline This& repe() noexcept { return rep(); } + //! Use REP/REPE prefix. + inline This& repz() noexcept { return rep(); } + + //! Use REPNE prefix. + inline This& repne() noexcept { return EmitterExplicitT::_addInstOptions(InstOptions::kX86_Repne); } + //! Use REPNE prefix. + inline This& repnz() noexcept { return repne(); } + + //! \} + + //! \name Core Instructions + //! \{ + + //! \cond + using EmitterExplicitT::cbw; + using EmitterExplicitT::cdq; + using EmitterExplicitT::cdqe; + using EmitterExplicitT::cqo; + using EmitterExplicitT::cwd; + using EmitterExplicitT::cwde; + using EmitterExplicitT::cmpsd; + using EmitterExplicitT::cmpxchg; + using EmitterExplicitT::cmpxchg8b; + using EmitterExplicitT::cmpxchg16b; + using EmitterExplicitT::div; + using EmitterExplicitT::idiv; + using EmitterExplicitT::imul; + using EmitterExplicitT::jecxz; + using EmitterExplicitT::loop; + using EmitterExplicitT::loope; + using EmitterExplicitT::loopne; + using EmitterExplicitT::mul; + //! \endcond + + ASMJIT_INST_0x(cbw, Cbw) // ANY [IMPLICIT] AX <- Sign Extend AL + ASMJIT_INST_0x(cdq, Cdq) // ANY [IMPLICIT] EDX:EAX <- Sign Extend EAX + ASMJIT_INST_0x(cdqe, Cdqe) // X64 [IMPLICIT] RAX <- Sign Extend EAX + ASMJIT_INST_2x(cmpxchg, Cmpxchg, Gp, Gp) // I486 [IMPLICIT] + ASMJIT_INST_2x(cmpxchg, Cmpxchg, Mem, Gp) // I486 [IMPLICIT] + ASMJIT_INST_1x(cmpxchg16b, Cmpxchg16b, Mem) // CMPXCHG8B [IMPLICIT] m == RDX:RAX ? m <- RCX:RBX + ASMJIT_INST_1x(cmpxchg8b, Cmpxchg8b, Mem) // CMPXCHG16B[IMPLICIT] m == EDX:EAX ? m <- ECX:EBX + ASMJIT_INST_0x(cqo, Cqo) // X64 [IMPLICIT] RDX:RAX <- Sign Extend RAX + ASMJIT_INST_0x(cwd, Cwd) // ANY [IMPLICIT] DX:AX <- Sign Extend AX + ASMJIT_INST_0x(cwde, Cwde) // ANY [IMPLICIT] EAX <- Sign Extend AX + ASMJIT_INST_1x(div, Div, Gp) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / r8} {xDX[Rem]:xAX[Quot] <- DX:AX / r16|r32|r64} + ASMJIT_INST_1x(div, Div, Mem) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / m8} {xDX[Rem]:xAX[Quot] <- DX:AX / m16|m32|m64} + ASMJIT_INST_1x(idiv, Idiv, Gp) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / r8} {xDX[Rem]:xAX[Quot] <- DX:AX / r16|r32|r64} + ASMJIT_INST_1x(idiv, Idiv, Mem) // ANY [IMPLICIT] {AH[Rem]: AL[Quot] <- AX / m8} {xDX[Rem]:xAX[Quot] <- DX:AX / m16|m32|m64} + ASMJIT_INST_1x(imul, Imul, Gp) // ANY [IMPLICIT] {AX <- AL * r8} {xAX:xDX <- xAX * r16|r32|r64} + ASMJIT_INST_1x(imul, Imul, Mem) // ANY [IMPLICIT] {AX <- AL * m8} {xAX:xDX <- xAX * m16|m32|m64} + ASMJIT_INST_0x(iret, Iret) // ANY [IMPLICIT] + ASMJIT_INST_0x(iretd, Iretd) // ANY [IMPLICIT] + ASMJIT_INST_0x(iretq, Iretq) // X64 [IMPLICIT] + ASMJIT_INST_1x(jecxz, Jecxz, Label) // ANY [IMPLICIT] Short jump if CX/ECX/RCX is zero. + ASMJIT_INST_1x(jecxz, Jecxz, Imm) // ANY [IMPLICIT] Short jump if CX/ECX/RCX is zero. + ASMJIT_INST_1x(loop, Loop, Label) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_1x(loop, Loop, Imm) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0. + ASMJIT_INST_1x(loope, Loope, Label) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_1x(loope, Loope, Imm) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 1. + ASMJIT_INST_1x(loopne, Loopne, Label) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. + ASMJIT_INST_1x(loopne, Loopne, Imm) // ANY [IMPLICIT] Decrement xCX; short jump if xCX != 0 && ZF == 0. + ASMJIT_INST_1x(mul, Mul, Gp) // ANY [IMPLICIT] {AX <- AL * r8} {xDX:xAX <- xAX * r16|r32|r64} + ASMJIT_INST_1x(mul, Mul, Mem) // ANY [IMPLICIT] {AX <- AL * m8} {xDX:xAX <- xAX * m16|m32|m64} + ASMJIT_INST_0x(ret, Ret) + ASMJIT_INST_1x(ret, Ret, Imm) + ASMJIT_INST_0x(retf, Retf) + ASMJIT_INST_1x(retf, Retf, Imm) + ASMJIT_INST_0x(xlatb, Xlatb) // ANY [IMPLICIT] + + //! \} + + //! \name String Instruction Aliases + //! \{ + + //! \cond + using EmitterExplicitT::movsd; + //! \endcond + + inline Error cmpsb() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 1), EmitterExplicitT::ptr_zdi(0, 1)); } + inline Error cmpsd() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 4), EmitterExplicitT::ptr_zdi(0, 4)); } + inline Error cmpsq() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 8), EmitterExplicitT::ptr_zdi(0, 8)); } + inline Error cmpsw() { return _emitter()->emit(Inst::kIdCmps, EmitterExplicitT::ptr_zsi(0, 2), EmitterExplicitT::ptr_zdi(0, 2)); } + + inline Error lodsb() { return _emitter()->emit(Inst::kIdLods, al , EmitterExplicitT::ptr_zsi(0, 1)); } + inline Error lodsd() { return _emitter()->emit(Inst::kIdLods, eax, EmitterExplicitT::ptr_zsi(0, 4)); } + inline Error lodsq() { return _emitter()->emit(Inst::kIdLods, rax, EmitterExplicitT::ptr_zsi(0, 8)); } + inline Error lodsw() { return _emitter()->emit(Inst::kIdLods, ax , EmitterExplicitT::ptr_zsi(0, 2)); } + + inline Error movsb() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 1), EmitterExplicitT::ptr_zsi(0, 1)); } + inline Error movsd() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 4), EmitterExplicitT::ptr_zsi(0, 4)); } + inline Error movsq() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 8), EmitterExplicitT::ptr_zsi(0, 8)); } + inline Error movsw() { return _emitter()->emit(Inst::kIdMovs, EmitterExplicitT::ptr_zdi(0, 2), EmitterExplicitT::ptr_zsi(0, 2)); } + + inline Error scasb() { return _emitter()->emit(Inst::kIdScas, al , EmitterExplicitT::ptr_zdi(0, 1)); } + inline Error scasd() { return _emitter()->emit(Inst::kIdScas, eax, EmitterExplicitT::ptr_zdi(0, 4)); } + inline Error scasq() { return _emitter()->emit(Inst::kIdScas, rax, EmitterExplicitT::ptr_zdi(0, 8)); } + inline Error scasw() { return _emitter()->emit(Inst::kIdScas, ax , EmitterExplicitT::ptr_zdi(0, 2)); } + + inline Error stosb() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 1), al ); } + inline Error stosd() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 4), eax); } + inline Error stosq() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 8), rax); } + inline Error stosw() { return _emitter()->emit(Inst::kIdStos, EmitterExplicitT::ptr_zdi(0, 2), ax ); } + + //! \} + + //! \name Deprecated 32-bit Instructions + //! \{ + + //! \cond + using EmitterExplicitT::aaa; + using EmitterExplicitT::aad; + using EmitterExplicitT::aam; + using EmitterExplicitT::aas; + using EmitterExplicitT::daa; + using EmitterExplicitT::das; + //! \endcond + + ASMJIT_INST_0x(aaa, Aaa) // X86 [IMPLICIT] + ASMJIT_INST_1x(aad, Aad, Imm) // X86 [IMPLICIT] + ASMJIT_INST_1x(aam, Aam, Imm) // X86 [IMPLICIT] + ASMJIT_INST_0x(aas, Aas) // X86 [IMPLICIT] + ASMJIT_INST_0x(daa, Daa) // X86 [IMPLICIT] + ASMJIT_INST_0x(das, Das) // X86 [IMPLICIT] + + //! \} + + //! \name LAHF/SAHF Instructions + //! \{ + + //! \cond + using EmitterExplicitT::lahf; + using EmitterExplicitT::sahf; + //! \endcond + + ASMJIT_INST_0x(lahf, Lahf) // LAHFSAHF [IMPLICIT] AH <- EFL + ASMJIT_INST_0x(sahf, Sahf) // LAHFSAHF [IMPLICIT] EFL <- AH + + //! \} + + //! \name CPUID Instruction + //! \{ + + //! \cond + using EmitterExplicitT::cpuid; + //! \endcond + + ASMJIT_INST_0x(cpuid, Cpuid) // I486 [IMPLICIT] EAX:EBX:ECX:EDX <- CPUID[EAX:ECX] + + //! \} + + //! \name CacheLine Instructions + //! \{ + + //! \cond + using EmitterExplicitT::clzero; + //! \endcond + + ASMJIT_INST_0x(clzero, Clzero) // CLZERO [IMPLICIT] + + //! \} + + //! \name RDPRU/RDPKRU Instructions + //! \{ + + //! \cond + using EmitterExplicitT::rdpru; + using EmitterExplicitT::rdpkru; + //! \endcond + + ASMJIT_INST_0x(rdpru, Rdpru) // RDPRU [IMPLICIT] EDX:EAX <- PRU[ECX] + ASMJIT_INST_0x(rdpkru, Rdpkru) // RDPKRU [IMPLICIT] EDX:EAX <- PKRU[ECX] + + //! \} + + //! \name RDTSC/RDTSCP Instructions + //! \{ + + //! \cond + using EmitterExplicitT::rdtsc; + using EmitterExplicitT::rdtscp; + //! \endcond + + ASMJIT_INST_0x(rdtsc, Rdtsc) // RDTSC [IMPLICIT] EDX:EAX <- CNT + ASMJIT_INST_0x(rdtscp, Rdtscp) // RDTSCP [IMPLICIT] EDX:EAX:EXC <- CNT + + //! \} + + //! \name BMI2 Instructions + //! \{ + + //! \cond + using EmitterExplicitT::mulx; + //! \endcond + + ASMJIT_INST_3x(mulx, Mulx, Gp, Gp, Gp) // BMI2 [IMPLICIT] + ASMJIT_INST_3x(mulx, Mulx, Gp, Gp, Mem) // BMI2 [IMPLICIT] + + //! \} + + //! \name XSAVE Instructions + //! \{ + + //! \cond + using EmitterExplicitT::xgetbv; + using EmitterExplicitT::xrstor; + using EmitterExplicitT::xrstor64; + using EmitterExplicitT::xrstors; + using EmitterExplicitT::xrstors64; + using EmitterExplicitT::xsave; + using EmitterExplicitT::xsave64; + using EmitterExplicitT::xsavec; + using EmitterExplicitT::xsavec64; + using EmitterExplicitT::xsaveopt; + using EmitterExplicitT::xsaveopt64; + using EmitterExplicitT::xsaves; + using EmitterExplicitT::xsaves64; + //! \endcond + + ASMJIT_INST_0x(xgetbv, Xgetbv) // XSAVE [IMPLICIT] EDX:EAX <- XCR[ECX] + ASMJIT_INST_1x(xrstor, Xrstor, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xrstor64, Xrstor64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xrstors, Xrstors, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xrstors64, Xrstors64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsave, Xsave, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsave64, Xsave64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsavec, Xsavec, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsavec64, Xsavec64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsaveopt, Xsaveopt, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsaveopt64, Xsaveopt64, Mem) // XSAVE+X64 [IMPLICIT] + ASMJIT_INST_1x(xsaves, Xsaves, Mem) // XSAVE [IMPLICIT] + ASMJIT_INST_1x(xsaves64, Xsaves64, Mem) // XSAVE+X64 [IMPLICIT] + + //! \} + + //! \name SYSCALL/SYSENTER Instructions + //! \{ + + ASMJIT_INST_0x(syscall, Syscall) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysenter, Sysenter) // X64 [IMPLICIT] + + //! \} + + //! \name HRESET Instructions + //! \{ + + //! \cond + using EmitterExplicitT::hreset; + //! \endcond + + ASMJIT_INST_1x(hreset, Hreset, Imm) // HRESET [IMPLICIT] + + //! \} + + //! \name SEAM Instructions + //! \{ + + ASMJIT_INST_0x(seamcall, Seamcall) + ASMJIT_INST_0x(seamops, Seamops) + ASMJIT_INST_0x(seamret, Seamret) + ASMJIT_INST_0x(tdcall, Tdcall) + + //! \} + + //! \name Privileged Instructions + //! \{ + + //! \cond + using EmitterExplicitT::rdmsr; + using EmitterExplicitT::rdpmc; + using EmitterExplicitT::wrmsr; + using EmitterExplicitT::xsetbv; + //! \endcond + + ASMJIT_INST_0x(pconfig, Pconfig) // PCONFIG [IMPLICIT] + ASMJIT_INST_0x(rdmsr, Rdmsr) // ANY [IMPLICIT] + ASMJIT_INST_0x(rdpmc, Rdpmc) // ANY [IMPLICIT] + ASMJIT_INST_0x(sysexit, Sysexit) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysexitq, Sysexitq) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysret, Sysret) // X64 [IMPLICIT] + ASMJIT_INST_0x(sysretq, Sysretq) // X64 [IMPLICIT] + ASMJIT_INST_0x(wrmsr, Wrmsr) // ANY [IMPLICIT] + ASMJIT_INST_0x(xsetbv, Xsetbv) // XSAVE [IMPLICIT] XCR[ECX] <- EDX:EAX + + //! \} + + //! \name Monitor & MWait Instructions + //! \{ + + //! \cond + using EmitterExplicitT::monitor; + using EmitterExplicitT::monitorx; + using EmitterExplicitT::mwait; + using EmitterExplicitT::mwaitx; + //! \endcond + + ASMJIT_INST_0x(monitor, Monitor) + ASMJIT_INST_0x(monitorx, Monitorx) + ASMJIT_INST_0x(mwait, Mwait) + ASMJIT_INST_0x(mwaitx, Mwaitx) + + //! \} + + //! \name WAITPKG Instructions + //! \{ + + //! \cond + using EmitterExplicitT::tpause; + using EmitterExplicitT::umwait; + //! \endcond + + ASMJIT_INST_1x(tpause, Tpause, Gp) + ASMJIT_INST_1x(umwait, Umwait, Gp) + + //! \} + + //! \name MMX & SSE Instructions + //! \{ + + //! \cond + using EmitterExplicitT::blendvpd; + using EmitterExplicitT::blendvps; + using EmitterExplicitT::maskmovq; + using EmitterExplicitT::maskmovdqu; + using EmitterExplicitT::pblendvb; + using EmitterExplicitT::pcmpestri; + using EmitterExplicitT::pcmpestrm; + using EmitterExplicitT::pcmpistri; + using EmitterExplicitT::pcmpistrm; + //! \endcond + + ASMJIT_INST_2x(blendvpd, Blendvpd, Xmm, Xmm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(blendvpd, Blendvpd, Xmm, Mem) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(blendvps, Blendvps, Xmm, Xmm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(blendvps, Blendvps, Xmm, Mem) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(pblendvb, Pblendvb, Xmm, Xmm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(pblendvb, Pblendvb, Xmm, Mem) // SSE4_1 [IMPLICIT] + ASMJIT_INST_2x(maskmovq, Maskmovq, Mm, Mm) // SSE [IMPLICIT] + ASMJIT_INST_2x(maskmovdqu, Maskmovdqu, Xmm, Xmm) // SSE2 [IMPLICIT] + ASMJIT_INST_3x(pcmpestri, Pcmpestri, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpestri, Pcmpestri, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpestrm, Pcmpestrm, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpestrm, Pcmpestrm, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistri, Pcmpistri, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistri, Pcmpistri, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistrm, Pcmpistrm, Xmm, Xmm, Imm) // SSE4_1 [IMPLICIT] + ASMJIT_INST_3x(pcmpistrm, Pcmpistrm, Xmm, Mem, Imm) // SSE4_1 [IMPLICIT] + + //! \} + + //! \name SHA Instructions + //! \{ + + //! \cond + using EmitterExplicitT::sha256rnds2; + //! \endcond + + ASMJIT_INST_2x(sha256rnds2, Sha256rnds2, Xmm, Xmm) // SHA [IMPLICIT] + ASMJIT_INST_2x(sha256rnds2, Sha256rnds2, Xmm, Mem) // SHA [IMPLICIT] + + //! \} + + //! \name AVX, FMA, and AVX512 Instructions + //! \{ + + //! \cond + using EmitterExplicitT::vmaskmovdqu; + using EmitterExplicitT::vpcmpestri; + using EmitterExplicitT::vpcmpestrm; + using EmitterExplicitT::vpcmpistri; + using EmitterExplicitT::vpcmpistrm; + //! \endcond + + ASMJIT_INST_2x(vmaskmovdqu, Vmaskmovdqu, Xmm, Xmm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestri, Vpcmpestri, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestri, Vpcmpestri, Xmm, Mem, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestrm, Vpcmpestrm, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpestrm, Vpcmpestrm, Xmm, Mem, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistri, Vpcmpistri, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistri, Vpcmpistri, Xmm, Mem, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistrm, Vpcmpistrm, Xmm, Xmm, Imm) // AVX [IMPLICIT] + ASMJIT_INST_3x(vpcmpistrm, Vpcmpistrm, Xmm, Mem, Imm) // AVX [IMPLICIT] + + //! \} +}; + +//! Emitter (X86). +//! +//! \note This class cannot be instantiated, you can only cast to it and use it as emitter that emits to either +//! `x86::Assembler`, `x86::Builder`, or `x86::Compiler` (use with caution with `x86::Compiler` as it requires +//! virtual registers). +class Emitter : public BaseEmitter, public EmitterImplicitT { + ASMJIT_NONCONSTRUCTIBLE(Emitter) +}; + +//! \} + +#undef ASMJIT_INST_0x +#undef ASMJIT_INST_1x +#undef ASMJIT_INST_1c +#undef ASMJIT_INST_2x +#undef ASMJIT_INST_2c +#undef ASMJIT_INST_3x +#undef ASMJIT_INST_4x +#undef ASMJIT_INST_5x +#undef ASMJIT_INST_6x + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86EMITTER_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86globals.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86globals.h new file mode 100644 index 0000000000000000000000000000000000000000..06b20c0c9655fda1c0313d4e25afa874afa8f5d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86globals.h @@ -0,0 +1,2234 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86GLOBALS_H_INCLUDED +#define ASMJIT_X86_X86GLOBALS_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/inst.h" + +//! \namespace asmjit::x86 +//! \ingroup asmjit_x86 +//! +//! X86/X64 API. + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! Condition code. +enum class CondCode : uint8_t { + kO = 0x00u, //!< OF==1 + kNO = 0x01u, //!< OF==0 + kC = 0x02u, //!< CF==1 + kB = 0x02u, //!< CF==1 (unsigned < ) + kNAE = 0x02u, //!< CF==1 (unsigned < ) + kNC = 0x03u, //!< CF==0 + kAE = 0x03u, //!< CF==0 (unsigned >=) + kNB = 0x03u, //!< CF==0 (unsigned >=) + kE = 0x04u, //!< ZF==1 (any_sign ==) + kZ = 0x04u, //!< ZF==1 (any_sign ==) + kNE = 0x05u, //!< ZF==0 (any_sign !=) + kNZ = 0x05u, //!< ZF==0 (any_sign !=) + kBE = 0x06u, //!< CF==1 | ZF==1 (unsigned <=) + kNA = 0x06u, //!< CF==1 | ZF==1 (unsigned <=) + kA = 0x07u, //!< CF==0 & ZF==0 (unsigned > ) + kNBE = 0x07u, //!< CF==0 & ZF==0 (unsigned > ) + kS = 0x08u, //!< SF==1 (is negative) + kNS = 0x09u, //!< SF==0 (is positive or zero) + kP = 0x0Au, //!< PF==1 + kPE = 0x0Au, //!< PF==1 + kPO = 0x0Bu, //!< PF==0 + kNP = 0x0Bu, //!< PF==0 + kL = 0x0Cu, //!< SF!=OF (signed < ) + kNGE = 0x0Cu, //!< SF!=OF (signed < ) + kGE = 0x0Du, //!< SF==OF (signed >=) + kNL = 0x0Du, //!< SF==OF (signed >=) + kLE = 0x0Eu, //!< ZF==1 | SF!=OF (signed <=) + kNG = 0x0Eu, //!< ZF==1 | SF!=OF (signed <=) + kG = 0x0Fu, //!< ZF==0 & SF==OF (signed > ) + kNLE = 0x0Fu, //!< ZF==0 & SF==OF (signed > ) + + kZero = kZ, //!< Zero flag. + kNotZero = kNZ, //!< Not zero. + + kEqual = kE, //!< `a == b` (equal). + kNotEqual = kNE, //!< `a != b` (not equal). + + kCarry = kC, //!< Carry flag. + kNotCarry = kNC, //!< Not carry. + + kSign = kS, //!< Sign flag. + kNotSign = kNS, //!< Not sign. + + kNegative = kS, //!< Sign flag. + kPositive = kNS, //!< Not sign. + + kOverflow = kO, //!< Overflow (signed). + kNotOverflow = kNO, //!< Not overflow (signed). + + kSignedLT = kL, //!< `a < b` (signed). + kSignedLE = kLE, //!< `a <= b` (signed). + kSignedGT = kG, //!< `a > b` (signed). + kSignedGE = kGE, //!< `a >= b` (signed). + + kUnsignedLT = kB, //!< `a < b` (unsigned). + kUnsignedLE = kBE, //!< `a <= b` (unsigned). + kUnsignedGT = kA, //!< `a > b` (unsigned). + kUnsignedGE = kAE, //!< `a >= b` (unsigned). + + kBTZero = kNC, //!< Tested bit is zero. + kBTNotZero = kC, //!< Tested bit is non-zero. + + kParityEven = kP, //!< Even parity flag. + kParityOdd = kPO, //!< Odd parity flag. + + kMaxValue = 0x0Fu +}; + +//! \cond +static constexpr CondCode _reverseCondTable[] = { + CondCode::kO, // O <- O + CondCode::kNO, // NO <- NO + CondCode::kA , // A <- B + CondCode::kBE, // BE <- AE + CondCode::kE, // E <- E + CondCode::kNE, // NE <- NE + CondCode::kAE, // AE <- BE + CondCode::kB , // B <- A + CondCode::kS, // S <- S + CondCode::kNS, // NS <- NS + CondCode::kPE, // PE <- PE + CondCode::kPO, // PO <- PO + CondCode::kG, // G <- L + CondCode::kLE, // LE <- GE + CondCode::kGE, // GE <- LE + CondCode::kL // L <- G +}; +//! \endcond + +//! Reverses a condition code (reverses the corresponding operands of a comparison). +static ASMJIT_INLINE_NODEBUG constexpr CondCode reverseCond(CondCode cond) noexcept { return _reverseCondTable[uint8_t(cond)]; } +//! Negates a condition code. +static ASMJIT_INLINE_NODEBUG constexpr CondCode negateCond(CondCode cond) noexcept { return CondCode(uint8_t(cond) ^ 1u); } + +//! Instruction. +//! +//! \note Only used to hold x86-specific instruction identifiers and some additional helper functions. +namespace Inst { + //! Instruction id. + enum Id : uint32_t { + // ${InstId:Begin} + kIdNone = 0, //!< Invalid instruction id. + kIdAaa, //!< Instruction 'aaa' (X86). + kIdAad, //!< Instruction 'aad' (X86). + kIdAadd, //!< Instruction 'aadd' {RAO_INT}. + kIdAam, //!< Instruction 'aam' (X86). + kIdAand, //!< Instruction 'aand' {RAO_INT}. + kIdAas, //!< Instruction 'aas' (X86). + kIdAdc, //!< Instruction 'adc'. + kIdAdcx, //!< Instruction 'adcx' {ADX}. + kIdAdd, //!< Instruction 'add'. + kIdAddpd, //!< Instruction 'addpd' {SSE2}. + kIdAddps, //!< Instruction 'addps' {SSE}. + kIdAddsd, //!< Instruction 'addsd' {SSE2}. + kIdAddss, //!< Instruction 'addss' {SSE}. + kIdAddsubpd, //!< Instruction 'addsubpd' {SSE3}. + kIdAddsubps, //!< Instruction 'addsubps' {SSE3}. + kIdAdox, //!< Instruction 'adox' {ADX}. + kIdAesdec, //!< Instruction 'aesdec' {AESNI}. + kIdAesdeclast, //!< Instruction 'aesdeclast' {AESNI}. + kIdAesenc, //!< Instruction 'aesenc' {AESNI}. + kIdAesenclast, //!< Instruction 'aesenclast' {AESNI}. + kIdAesimc, //!< Instruction 'aesimc' {AESNI}. + kIdAeskeygenassist, //!< Instruction 'aeskeygenassist' {AESNI}. + kIdAnd, //!< Instruction 'and'. + kIdAndn, //!< Instruction 'andn' {BMI}. + kIdAndnpd, //!< Instruction 'andnpd' {SSE2}. + kIdAndnps, //!< Instruction 'andnps' {SSE}. + kIdAndpd, //!< Instruction 'andpd' {SSE2}. + kIdAndps, //!< Instruction 'andps' {SSE}. + kIdAor, //!< Instruction 'aor' {RAO_INT}. + kIdArpl, //!< Instruction 'arpl' (X86). + kIdAxor, //!< Instruction 'axor' {RAO_INT}. + kIdBextr, //!< Instruction 'bextr' {BMI}. + kIdBlcfill, //!< Instruction 'blcfill' {TBM}. + kIdBlci, //!< Instruction 'blci' {TBM}. + kIdBlcic, //!< Instruction 'blcic' {TBM}. + kIdBlcmsk, //!< Instruction 'blcmsk' {TBM}. + kIdBlcs, //!< Instruction 'blcs' {TBM}. + kIdBlendpd, //!< Instruction 'blendpd' {SSE4_1}. + kIdBlendps, //!< Instruction 'blendps' {SSE4_1}. + kIdBlendvpd, //!< Instruction 'blendvpd' {SSE4_1}. + kIdBlendvps, //!< Instruction 'blendvps' {SSE4_1}. + kIdBlsfill, //!< Instruction 'blsfill' {TBM}. + kIdBlsi, //!< Instruction 'blsi' {BMI}. + kIdBlsic, //!< Instruction 'blsic' {TBM}. + kIdBlsmsk, //!< Instruction 'blsmsk' {BMI}. + kIdBlsr, //!< Instruction 'blsr' {BMI}. + kIdBndcl, //!< Instruction 'bndcl' {MPX}. + kIdBndcn, //!< Instruction 'bndcn' {MPX}. + kIdBndcu, //!< Instruction 'bndcu' {MPX}. + kIdBndldx, //!< Instruction 'bndldx' {MPX}. + kIdBndmk, //!< Instruction 'bndmk' {MPX}. + kIdBndmov, //!< Instruction 'bndmov' {MPX}. + kIdBndstx, //!< Instruction 'bndstx' {MPX}. + kIdBound, //!< Instruction 'bound' (X86). + kIdBsf, //!< Instruction 'bsf'. + kIdBsr, //!< Instruction 'bsr'. + kIdBswap, //!< Instruction 'bswap'. + kIdBt, //!< Instruction 'bt'. + kIdBtc, //!< Instruction 'btc'. + kIdBtr, //!< Instruction 'btr'. + kIdBts, //!< Instruction 'bts'. + kIdBzhi, //!< Instruction 'bzhi' {BMI2}. + kIdCall, //!< Instruction 'call'. + kIdCbw, //!< Instruction 'cbw'. + kIdCdq, //!< Instruction 'cdq'. + kIdCdqe, //!< Instruction 'cdqe' (X64). + kIdClac, //!< Instruction 'clac' {SMAP}. + kIdClc, //!< Instruction 'clc'. + kIdCld, //!< Instruction 'cld'. + kIdCldemote, //!< Instruction 'cldemote' {CLDEMOTE}. + kIdClflush, //!< Instruction 'clflush' {CLFLUSH}. + kIdClflushopt, //!< Instruction 'clflushopt' {CLFLUSHOPT}. + kIdClgi, //!< Instruction 'clgi' {SVM}. + kIdCli, //!< Instruction 'cli'. + kIdClrssbsy, //!< Instruction 'clrssbsy' {CET_SS}. + kIdClts, //!< Instruction 'clts'. + kIdClui, //!< Instruction 'clui' {UINTR} (X64). + kIdClwb, //!< Instruction 'clwb' {CLWB}. + kIdClzero, //!< Instruction 'clzero' {CLZERO}. + kIdCmc, //!< Instruction 'cmc'. + kIdCmova, //!< Instruction 'cmova' {CMOV}. + kIdCmovae, //!< Instruction 'cmovae' {CMOV}. + kIdCmovb, //!< Instruction 'cmovb' {CMOV}. + kIdCmovbe, //!< Instruction 'cmovbe' {CMOV}. + kIdCmovc, //!< Instruction 'cmovc' {CMOV}. + kIdCmove, //!< Instruction 'cmove' {CMOV}. + kIdCmovg, //!< Instruction 'cmovg' {CMOV}. + kIdCmovge, //!< Instruction 'cmovge' {CMOV}. + kIdCmovl, //!< Instruction 'cmovl' {CMOV}. + kIdCmovle, //!< Instruction 'cmovle' {CMOV}. + kIdCmovna, //!< Instruction 'cmovna' {CMOV}. + kIdCmovnae, //!< Instruction 'cmovnae' {CMOV}. + kIdCmovnb, //!< Instruction 'cmovnb' {CMOV}. + kIdCmovnbe, //!< Instruction 'cmovnbe' {CMOV}. + kIdCmovnc, //!< Instruction 'cmovnc' {CMOV}. + kIdCmovne, //!< Instruction 'cmovne' {CMOV}. + kIdCmovng, //!< Instruction 'cmovng' {CMOV}. + kIdCmovnge, //!< Instruction 'cmovnge' {CMOV}. + kIdCmovnl, //!< Instruction 'cmovnl' {CMOV}. + kIdCmovnle, //!< Instruction 'cmovnle' {CMOV}. + kIdCmovno, //!< Instruction 'cmovno' {CMOV}. + kIdCmovnp, //!< Instruction 'cmovnp' {CMOV}. + kIdCmovns, //!< Instruction 'cmovns' {CMOV}. + kIdCmovnz, //!< Instruction 'cmovnz' {CMOV}. + kIdCmovo, //!< Instruction 'cmovo' {CMOV}. + kIdCmovp, //!< Instruction 'cmovp' {CMOV}. + kIdCmovpe, //!< Instruction 'cmovpe' {CMOV}. + kIdCmovpo, //!< Instruction 'cmovpo' {CMOV}. + kIdCmovs, //!< Instruction 'cmovs' {CMOV}. + kIdCmovz, //!< Instruction 'cmovz' {CMOV}. + kIdCmp, //!< Instruction 'cmp'. + kIdCmpbexadd, //!< Instruction 'cmpbexadd' {CMPCCXADD}. + kIdCmpbxadd, //!< Instruction 'cmpbxadd' {CMPCCXADD}. + kIdCmplexadd, //!< Instruction 'cmplexadd' {CMPCCXADD}. + kIdCmplxadd, //!< Instruction 'cmplxadd' {CMPCCXADD}. + kIdCmpnbexadd, //!< Instruction 'cmpnbexadd' {CMPCCXADD}. + kIdCmpnbxadd, //!< Instruction 'cmpnbxadd' {CMPCCXADD}. + kIdCmpnlexadd, //!< Instruction 'cmpnlexadd' {CMPCCXADD}. + kIdCmpnlxadd, //!< Instruction 'cmpnlxadd' {CMPCCXADD}. + kIdCmpnoxadd, //!< Instruction 'cmpnoxadd' {CMPCCXADD}. + kIdCmpnpxadd, //!< Instruction 'cmpnpxadd' {CMPCCXADD}. + kIdCmpnsxadd, //!< Instruction 'cmpnsxadd' {CMPCCXADD}. + kIdCmpnzxadd, //!< Instruction 'cmpnzxadd' {CMPCCXADD}. + kIdCmpoxadd, //!< Instruction 'cmpoxadd' {CMPCCXADD}. + kIdCmppd, //!< Instruction 'cmppd' {SSE2}. + kIdCmpps, //!< Instruction 'cmpps' {SSE}. + kIdCmppxadd, //!< Instruction 'cmppxadd' {CMPCCXADD}. + kIdCmps, //!< Instruction 'cmps'. + kIdCmpsd, //!< Instruction 'cmpsd' {SSE2}. + kIdCmpss, //!< Instruction 'cmpss' {SSE}. + kIdCmpsxadd, //!< Instruction 'cmpsxadd' {CMPCCXADD}. + kIdCmpxchg, //!< Instruction 'cmpxchg' {I486}. + kIdCmpxchg16b, //!< Instruction 'cmpxchg16b' {CMPXCHG16B} (X64). + kIdCmpxchg8b, //!< Instruction 'cmpxchg8b' {CMPXCHG8B}. + kIdCmpzxadd, //!< Instruction 'cmpzxadd' {CMPCCXADD}. + kIdComisd, //!< Instruction 'comisd' {SSE2}. + kIdComiss, //!< Instruction 'comiss' {SSE}. + kIdCpuid, //!< Instruction 'cpuid' {I486}. + kIdCqo, //!< Instruction 'cqo' (X64). + kIdCrc32, //!< Instruction 'crc32' {SSE4_2}. + kIdCvtdq2pd, //!< Instruction 'cvtdq2pd' {SSE2}. + kIdCvtdq2ps, //!< Instruction 'cvtdq2ps' {SSE2}. + kIdCvtpd2dq, //!< Instruction 'cvtpd2dq' {SSE2}. + kIdCvtpd2pi, //!< Instruction 'cvtpd2pi' {SSE2}. + kIdCvtpd2ps, //!< Instruction 'cvtpd2ps' {SSE2}. + kIdCvtpi2pd, //!< Instruction 'cvtpi2pd' {SSE2}. + kIdCvtpi2ps, //!< Instruction 'cvtpi2ps' {SSE}. + kIdCvtps2dq, //!< Instruction 'cvtps2dq' {SSE2}. + kIdCvtps2pd, //!< Instruction 'cvtps2pd' {SSE2}. + kIdCvtps2pi, //!< Instruction 'cvtps2pi' {SSE}. + kIdCvtsd2si, //!< Instruction 'cvtsd2si' {SSE2}. + kIdCvtsd2ss, //!< Instruction 'cvtsd2ss' {SSE2}. + kIdCvtsi2sd, //!< Instruction 'cvtsi2sd' {SSE2}. + kIdCvtsi2ss, //!< Instruction 'cvtsi2ss' {SSE}. + kIdCvtss2sd, //!< Instruction 'cvtss2sd' {SSE2}. + kIdCvtss2si, //!< Instruction 'cvtss2si' {SSE}. + kIdCvttpd2dq, //!< Instruction 'cvttpd2dq' {SSE2}. + kIdCvttpd2pi, //!< Instruction 'cvttpd2pi' {SSE2}. + kIdCvttps2dq, //!< Instruction 'cvttps2dq' {SSE2}. + kIdCvttps2pi, //!< Instruction 'cvttps2pi' {SSE}. + kIdCvttsd2si, //!< Instruction 'cvttsd2si' {SSE2}. + kIdCvttss2si, //!< Instruction 'cvttss2si' {SSE}. + kIdCwd, //!< Instruction 'cwd'. + kIdCwde, //!< Instruction 'cwde'. + kIdDaa, //!< Instruction 'daa' (X86). + kIdDas, //!< Instruction 'das' (X86). + kIdDec, //!< Instruction 'dec'. + kIdDiv, //!< Instruction 'div'. + kIdDivpd, //!< Instruction 'divpd' {SSE2}. + kIdDivps, //!< Instruction 'divps' {SSE}. + kIdDivsd, //!< Instruction 'divsd' {SSE2}. + kIdDivss, //!< Instruction 'divss' {SSE}. + kIdDppd, //!< Instruction 'dppd' {SSE4_1}. + kIdDpps, //!< Instruction 'dpps' {SSE4_1}. + kIdEmms, //!< Instruction 'emms' {MMX}. + kIdEndbr32, //!< Instruction 'endbr32' {CET_IBT}. + kIdEndbr64, //!< Instruction 'endbr64' {CET_IBT}. + kIdEnqcmd, //!< Instruction 'enqcmd' {ENQCMD}. + kIdEnqcmds, //!< Instruction 'enqcmds' {ENQCMD}. + kIdEnter, //!< Instruction 'enter'. + kIdExtractps, //!< Instruction 'extractps' {SSE4_1}. + kIdExtrq, //!< Instruction 'extrq' {SSE4A}. + kIdF2xm1, //!< Instruction 'f2xm1' {FPU}. + kIdFabs, //!< Instruction 'fabs' {FPU}. + kIdFadd, //!< Instruction 'fadd' {FPU}. + kIdFaddp, //!< Instruction 'faddp' {FPU}. + kIdFbld, //!< Instruction 'fbld' {FPU}. + kIdFbstp, //!< Instruction 'fbstp' {FPU}. + kIdFchs, //!< Instruction 'fchs' {FPU}. + kIdFclex, //!< Instruction 'fclex' {FPU}. + kIdFcmovb, //!< Instruction 'fcmovb' {CMOV|FPU}. + kIdFcmovbe, //!< Instruction 'fcmovbe' {CMOV|FPU}. + kIdFcmove, //!< Instruction 'fcmove' {CMOV|FPU}. + kIdFcmovnb, //!< Instruction 'fcmovnb' {CMOV|FPU}. + kIdFcmovnbe, //!< Instruction 'fcmovnbe' {CMOV|FPU}. + kIdFcmovne, //!< Instruction 'fcmovne' {CMOV|FPU}. + kIdFcmovnu, //!< Instruction 'fcmovnu' {CMOV|FPU}. + kIdFcmovu, //!< Instruction 'fcmovu' {CMOV|FPU}. + kIdFcom, //!< Instruction 'fcom' {FPU}. + kIdFcomi, //!< Instruction 'fcomi' {FPU}. + kIdFcomip, //!< Instruction 'fcomip' {FPU}. + kIdFcomp, //!< Instruction 'fcomp' {FPU}. + kIdFcompp, //!< Instruction 'fcompp' {FPU}. + kIdFcos, //!< Instruction 'fcos' {FPU}. + kIdFdecstp, //!< Instruction 'fdecstp' {FPU}. + kIdFdiv, //!< Instruction 'fdiv' {FPU}. + kIdFdivp, //!< Instruction 'fdivp' {FPU}. + kIdFdivr, //!< Instruction 'fdivr' {FPU}. + kIdFdivrp, //!< Instruction 'fdivrp' {FPU}. + kIdFemms, //!< Instruction 'femms' {3DNOW}. + kIdFfree, //!< Instruction 'ffree' {FPU}. + kIdFiadd, //!< Instruction 'fiadd' {FPU}. + kIdFicom, //!< Instruction 'ficom' {FPU}. + kIdFicomp, //!< Instruction 'ficomp' {FPU}. + kIdFidiv, //!< Instruction 'fidiv' {FPU}. + kIdFidivr, //!< Instruction 'fidivr' {FPU}. + kIdFild, //!< Instruction 'fild' {FPU}. + kIdFimul, //!< Instruction 'fimul' {FPU}. + kIdFincstp, //!< Instruction 'fincstp' {FPU}. + kIdFinit, //!< Instruction 'finit' {FPU}. + kIdFist, //!< Instruction 'fist' {FPU}. + kIdFistp, //!< Instruction 'fistp' {FPU}. + kIdFisttp, //!< Instruction 'fisttp' {SSE3|FPU}. + kIdFisub, //!< Instruction 'fisub' {FPU}. + kIdFisubr, //!< Instruction 'fisubr' {FPU}. + kIdFld, //!< Instruction 'fld' {FPU}. + kIdFld1, //!< Instruction 'fld1' {FPU}. + kIdFldcw, //!< Instruction 'fldcw' {FPU}. + kIdFldenv, //!< Instruction 'fldenv' {FPU}. + kIdFldl2e, //!< Instruction 'fldl2e' {FPU}. + kIdFldl2t, //!< Instruction 'fldl2t' {FPU}. + kIdFldlg2, //!< Instruction 'fldlg2' {FPU}. + kIdFldln2, //!< Instruction 'fldln2' {FPU}. + kIdFldpi, //!< Instruction 'fldpi' {FPU}. + kIdFldz, //!< Instruction 'fldz' {FPU}. + kIdFmul, //!< Instruction 'fmul' {FPU}. + kIdFmulp, //!< Instruction 'fmulp' {FPU}. + kIdFnclex, //!< Instruction 'fnclex' {FPU}. + kIdFninit, //!< Instruction 'fninit' {FPU}. + kIdFnop, //!< Instruction 'fnop' {FPU}. + kIdFnsave, //!< Instruction 'fnsave' {FPU}. + kIdFnstcw, //!< Instruction 'fnstcw' {FPU}. + kIdFnstenv, //!< Instruction 'fnstenv' {FPU}. + kIdFnstsw, //!< Instruction 'fnstsw' {FPU}. + kIdFpatan, //!< Instruction 'fpatan' {FPU}. + kIdFprem, //!< Instruction 'fprem' {FPU}. + kIdFprem1, //!< Instruction 'fprem1' {FPU}. + kIdFptan, //!< Instruction 'fptan' {FPU}. + kIdFrndint, //!< Instruction 'frndint' {FPU}. + kIdFrstor, //!< Instruction 'frstor' {FPU}. + kIdFsave, //!< Instruction 'fsave' {FPU}. + kIdFscale, //!< Instruction 'fscale' {FPU}. + kIdFsin, //!< Instruction 'fsin' {FPU}. + kIdFsincos, //!< Instruction 'fsincos' {FPU}. + kIdFsqrt, //!< Instruction 'fsqrt' {FPU}. + kIdFst, //!< Instruction 'fst' {FPU}. + kIdFstcw, //!< Instruction 'fstcw' {FPU}. + kIdFstenv, //!< Instruction 'fstenv' {FPU}. + kIdFstp, //!< Instruction 'fstp' {FPU}. + kIdFstsw, //!< Instruction 'fstsw' {FPU}. + kIdFsub, //!< Instruction 'fsub' {FPU}. + kIdFsubp, //!< Instruction 'fsubp' {FPU}. + kIdFsubr, //!< Instruction 'fsubr' {FPU}. + kIdFsubrp, //!< Instruction 'fsubrp' {FPU}. + kIdFtst, //!< Instruction 'ftst' {FPU}. + kIdFucom, //!< Instruction 'fucom' {FPU}. + kIdFucomi, //!< Instruction 'fucomi' {FPU}. + kIdFucomip, //!< Instruction 'fucomip' {FPU}. + kIdFucomp, //!< Instruction 'fucomp' {FPU}. + kIdFucompp, //!< Instruction 'fucompp' {FPU}. + kIdFwait, //!< Instruction 'fwait' {FPU}. + kIdFxam, //!< Instruction 'fxam' {FPU}. + kIdFxch, //!< Instruction 'fxch' {FPU}. + kIdFxrstor, //!< Instruction 'fxrstor' {FXSR}. + kIdFxrstor64, //!< Instruction 'fxrstor64' {FXSR} (X64). + kIdFxsave, //!< Instruction 'fxsave' {FXSR}. + kIdFxsave64, //!< Instruction 'fxsave64' {FXSR} (X64). + kIdFxtract, //!< Instruction 'fxtract' {FPU}. + kIdFyl2x, //!< Instruction 'fyl2x' {FPU}. + kIdFyl2xp1, //!< Instruction 'fyl2xp1' {FPU}. + kIdGetsec, //!< Instruction 'getsec' {SMX}. + kIdGf2p8affineinvqb, //!< Instruction 'gf2p8affineinvqb' {GFNI}. + kIdGf2p8affineqb, //!< Instruction 'gf2p8affineqb' {GFNI}. + kIdGf2p8mulb, //!< Instruction 'gf2p8mulb' {GFNI}. + kIdHaddpd, //!< Instruction 'haddpd' {SSE3}. + kIdHaddps, //!< Instruction 'haddps' {SSE3}. + kIdHlt, //!< Instruction 'hlt'. + kIdHreset, //!< Instruction 'hreset' {HRESET}. + kIdHsubpd, //!< Instruction 'hsubpd' {SSE3}. + kIdHsubps, //!< Instruction 'hsubps' {SSE3}. + kIdIdiv, //!< Instruction 'idiv'. + kIdImul, //!< Instruction 'imul'. + kIdIn, //!< Instruction 'in'. + kIdInc, //!< Instruction 'inc'. + kIdIncsspd, //!< Instruction 'incsspd' {CET_SS}. + kIdIncsspq, //!< Instruction 'incsspq' {CET_SS} (X64). + kIdIns, //!< Instruction 'ins'. + kIdInsertps, //!< Instruction 'insertps' {SSE4_1}. + kIdInsertq, //!< Instruction 'insertq' {SSE4A}. + kIdInt, //!< Instruction 'int'. + kIdInt3, //!< Instruction 'int3'. + kIdInto, //!< Instruction 'into' (X86). + kIdInvd, //!< Instruction 'invd' {I486}. + kIdInvept, //!< Instruction 'invept' {VMX}. + kIdInvlpg, //!< Instruction 'invlpg' {I486}. + kIdInvlpga, //!< Instruction 'invlpga' {SVM}. + kIdInvlpgb, //!< Instruction 'invlpgb' {INVLPGB}. + kIdInvpcid, //!< Instruction 'invpcid' {I486}. + kIdInvvpid, //!< Instruction 'invvpid' {VMX}. + kIdIret, //!< Instruction 'iret'. + kIdIretd, //!< Instruction 'iretd'. + kIdIretq, //!< Instruction 'iretq' (X64). + kIdJa, //!< Instruction 'ja'. + kIdJae, //!< Instruction 'jae'. + kIdJb, //!< Instruction 'jb'. + kIdJbe, //!< Instruction 'jbe'. + kIdJc, //!< Instruction 'jc'. + kIdJe, //!< Instruction 'je'. + kIdJecxz, //!< Instruction 'jecxz'. + kIdJg, //!< Instruction 'jg'. + kIdJge, //!< Instruction 'jge'. + kIdJl, //!< Instruction 'jl'. + kIdJle, //!< Instruction 'jle'. + kIdJmp, //!< Instruction 'jmp'. + kIdJna, //!< Instruction 'jna'. + kIdJnae, //!< Instruction 'jnae'. + kIdJnb, //!< Instruction 'jnb'. + kIdJnbe, //!< Instruction 'jnbe'. + kIdJnc, //!< Instruction 'jnc'. + kIdJne, //!< Instruction 'jne'. + kIdJng, //!< Instruction 'jng'. + kIdJnge, //!< Instruction 'jnge'. + kIdJnl, //!< Instruction 'jnl'. + kIdJnle, //!< Instruction 'jnle'. + kIdJno, //!< Instruction 'jno'. + kIdJnp, //!< Instruction 'jnp'. + kIdJns, //!< Instruction 'jns'. + kIdJnz, //!< Instruction 'jnz'. + kIdJo, //!< Instruction 'jo'. + kIdJp, //!< Instruction 'jp'. + kIdJpe, //!< Instruction 'jpe'. + kIdJpo, //!< Instruction 'jpo'. + kIdJs, //!< Instruction 'js'. + kIdJz, //!< Instruction 'jz'. + kIdKaddb, //!< Instruction 'kaddb' {AVX512_DQ}. + kIdKaddd, //!< Instruction 'kaddd' {AVX512_BW}. + kIdKaddq, //!< Instruction 'kaddq' {AVX512_BW}. + kIdKaddw, //!< Instruction 'kaddw' {AVX512_DQ}. + kIdKandb, //!< Instruction 'kandb' {AVX512_DQ}. + kIdKandd, //!< Instruction 'kandd' {AVX512_BW}. + kIdKandnb, //!< Instruction 'kandnb' {AVX512_DQ}. + kIdKandnd, //!< Instruction 'kandnd' {AVX512_BW}. + kIdKandnq, //!< Instruction 'kandnq' {AVX512_BW}. + kIdKandnw, //!< Instruction 'kandnw' {AVX512_F}. + kIdKandq, //!< Instruction 'kandq' {AVX512_BW}. + kIdKandw, //!< Instruction 'kandw' {AVX512_F}. + kIdKmovb, //!< Instruction 'kmovb' {AVX512_DQ}. + kIdKmovd, //!< Instruction 'kmovd' {AVX512_BW}. + kIdKmovq, //!< Instruction 'kmovq' {AVX512_BW}. + kIdKmovw, //!< Instruction 'kmovw' {AVX512_F}. + kIdKnotb, //!< Instruction 'knotb' {AVX512_DQ}. + kIdKnotd, //!< Instruction 'knotd' {AVX512_BW}. + kIdKnotq, //!< Instruction 'knotq' {AVX512_BW}. + kIdKnotw, //!< Instruction 'knotw' {AVX512_F}. + kIdKorb, //!< Instruction 'korb' {AVX512_DQ}. + kIdKord, //!< Instruction 'kord' {AVX512_BW}. + kIdKorq, //!< Instruction 'korq' {AVX512_BW}. + kIdKortestb, //!< Instruction 'kortestb' {AVX512_DQ}. + kIdKortestd, //!< Instruction 'kortestd' {AVX512_BW}. + kIdKortestq, //!< Instruction 'kortestq' {AVX512_BW}. + kIdKortestw, //!< Instruction 'kortestw' {AVX512_F}. + kIdKorw, //!< Instruction 'korw' {AVX512_F}. + kIdKshiftlb, //!< Instruction 'kshiftlb' {AVX512_DQ}. + kIdKshiftld, //!< Instruction 'kshiftld' {AVX512_BW}. + kIdKshiftlq, //!< Instruction 'kshiftlq' {AVX512_BW}. + kIdKshiftlw, //!< Instruction 'kshiftlw' {AVX512_F}. + kIdKshiftrb, //!< Instruction 'kshiftrb' {AVX512_DQ}. + kIdKshiftrd, //!< Instruction 'kshiftrd' {AVX512_BW}. + kIdKshiftrq, //!< Instruction 'kshiftrq' {AVX512_BW}. + kIdKshiftrw, //!< Instruction 'kshiftrw' {AVX512_F}. + kIdKtestb, //!< Instruction 'ktestb' {AVX512_DQ}. + kIdKtestd, //!< Instruction 'ktestd' {AVX512_BW}. + kIdKtestq, //!< Instruction 'ktestq' {AVX512_BW}. + kIdKtestw, //!< Instruction 'ktestw' {AVX512_DQ}. + kIdKunpckbw, //!< Instruction 'kunpckbw' {AVX512_F}. + kIdKunpckdq, //!< Instruction 'kunpckdq' {AVX512_BW}. + kIdKunpckwd, //!< Instruction 'kunpckwd' {AVX512_BW}. + kIdKxnorb, //!< Instruction 'kxnorb' {AVX512_DQ}. + kIdKxnord, //!< Instruction 'kxnord' {AVX512_BW}. + kIdKxnorq, //!< Instruction 'kxnorq' {AVX512_BW}. + kIdKxnorw, //!< Instruction 'kxnorw' {AVX512_F}. + kIdKxorb, //!< Instruction 'kxorb' {AVX512_DQ}. + kIdKxord, //!< Instruction 'kxord' {AVX512_BW}. + kIdKxorq, //!< Instruction 'kxorq' {AVX512_BW}. + kIdKxorw, //!< Instruction 'kxorw' {AVX512_F}. + kIdLahf, //!< Instruction 'lahf' {LAHFSAHF}. + kIdLar, //!< Instruction 'lar'. + kIdLcall, //!< Instruction 'lcall'. + kIdLddqu, //!< Instruction 'lddqu' {SSE3}. + kIdLdmxcsr, //!< Instruction 'ldmxcsr' {SSE}. + kIdLds, //!< Instruction 'lds' (X86). + kIdLdtilecfg, //!< Instruction 'ldtilecfg' {AMX_TILE} (X64). + kIdLea, //!< Instruction 'lea'. + kIdLeave, //!< Instruction 'leave'. + kIdLes, //!< Instruction 'les' (X86). + kIdLfence, //!< Instruction 'lfence' {SSE2}. + kIdLfs, //!< Instruction 'lfs'. + kIdLgdt, //!< Instruction 'lgdt'. + kIdLgs, //!< Instruction 'lgs'. + kIdLidt, //!< Instruction 'lidt'. + kIdLjmp, //!< Instruction 'ljmp'. + kIdLldt, //!< Instruction 'lldt'. + kIdLlwpcb, //!< Instruction 'llwpcb' {LWP}. + kIdLmsw, //!< Instruction 'lmsw'. + kIdLods, //!< Instruction 'lods'. + kIdLoop, //!< Instruction 'loop'. + kIdLoope, //!< Instruction 'loope'. + kIdLoopne, //!< Instruction 'loopne'. + kIdLsl, //!< Instruction 'lsl'. + kIdLss, //!< Instruction 'lss'. + kIdLtr, //!< Instruction 'ltr'. + kIdLwpins, //!< Instruction 'lwpins' {LWP}. + kIdLwpval, //!< Instruction 'lwpval' {LWP}. + kIdLzcnt, //!< Instruction 'lzcnt' {LZCNT}. + kIdMaskmovdqu, //!< Instruction 'maskmovdqu' {SSE2}. + kIdMaskmovq, //!< Instruction 'maskmovq' {MMX2}. + kIdMaxpd, //!< Instruction 'maxpd' {SSE2}. + kIdMaxps, //!< Instruction 'maxps' {SSE}. + kIdMaxsd, //!< Instruction 'maxsd' {SSE2}. + kIdMaxss, //!< Instruction 'maxss' {SSE}. + kIdMcommit, //!< Instruction 'mcommit' {MCOMMIT}. + kIdMfence, //!< Instruction 'mfence' {SSE2}. + kIdMinpd, //!< Instruction 'minpd' {SSE2}. + kIdMinps, //!< Instruction 'minps' {SSE}. + kIdMinsd, //!< Instruction 'minsd' {SSE2}. + kIdMinss, //!< Instruction 'minss' {SSE}. + kIdMonitor, //!< Instruction 'monitor' {MONITOR}. + kIdMonitorx, //!< Instruction 'monitorx' {MONITORX}. + kIdMov, //!< Instruction 'mov'. + kIdMovabs, //!< Instruction 'movabs'. + kIdMovapd, //!< Instruction 'movapd' {SSE2}. + kIdMovaps, //!< Instruction 'movaps' {SSE}. + kIdMovbe, //!< Instruction 'movbe' {MOVBE}. + kIdMovd, //!< Instruction 'movd' {MMX|SSE2}. + kIdMovddup, //!< Instruction 'movddup' {SSE3}. + kIdMovdir64b, //!< Instruction 'movdir64b' {MOVDIR64B}. + kIdMovdiri, //!< Instruction 'movdiri' {MOVDIRI}. + kIdMovdq2q, //!< Instruction 'movdq2q' {SSE2}. + kIdMovdqa, //!< Instruction 'movdqa' {SSE2}. + kIdMovdqu, //!< Instruction 'movdqu' {SSE2}. + kIdMovhlps, //!< Instruction 'movhlps' {SSE}. + kIdMovhpd, //!< Instruction 'movhpd' {SSE2}. + kIdMovhps, //!< Instruction 'movhps' {SSE}. + kIdMovlhps, //!< Instruction 'movlhps' {SSE}. + kIdMovlpd, //!< Instruction 'movlpd' {SSE2}. + kIdMovlps, //!< Instruction 'movlps' {SSE}. + kIdMovmskpd, //!< Instruction 'movmskpd' {SSE2}. + kIdMovmskps, //!< Instruction 'movmskps' {SSE}. + kIdMovntdq, //!< Instruction 'movntdq' {SSE2}. + kIdMovntdqa, //!< Instruction 'movntdqa' {SSE4_1}. + kIdMovnti, //!< Instruction 'movnti' {SSE2}. + kIdMovntpd, //!< Instruction 'movntpd' {SSE2}. + kIdMovntps, //!< Instruction 'movntps' {SSE}. + kIdMovntq, //!< Instruction 'movntq' {MMX2}. + kIdMovntsd, //!< Instruction 'movntsd' {SSE4A}. + kIdMovntss, //!< Instruction 'movntss' {SSE4A}. + kIdMovq, //!< Instruction 'movq' {MMX|SSE2}. + kIdMovq2dq, //!< Instruction 'movq2dq' {SSE2}. + kIdMovs, //!< Instruction 'movs'. + kIdMovsd, //!< Instruction 'movsd' {SSE2}. + kIdMovshdup, //!< Instruction 'movshdup' {SSE3}. + kIdMovsldup, //!< Instruction 'movsldup' {SSE3}. + kIdMovss, //!< Instruction 'movss' {SSE}. + kIdMovsx, //!< Instruction 'movsx'. + kIdMovsxd, //!< Instruction 'movsxd' (X64). + kIdMovupd, //!< Instruction 'movupd' {SSE2}. + kIdMovups, //!< Instruction 'movups' {SSE}. + kIdMovzx, //!< Instruction 'movzx'. + kIdMpsadbw, //!< Instruction 'mpsadbw' {SSE4_1}. + kIdMul, //!< Instruction 'mul'. + kIdMulpd, //!< Instruction 'mulpd' {SSE2}. + kIdMulps, //!< Instruction 'mulps' {SSE}. + kIdMulsd, //!< Instruction 'mulsd' {SSE2}. + kIdMulss, //!< Instruction 'mulss' {SSE}. + kIdMulx, //!< Instruction 'mulx' {BMI2}. + kIdMwait, //!< Instruction 'mwait' {MONITOR}. + kIdMwaitx, //!< Instruction 'mwaitx' {MONITORX}. + kIdNeg, //!< Instruction 'neg'. + kIdNop, //!< Instruction 'nop'. + kIdNot, //!< Instruction 'not'. + kIdOr, //!< Instruction 'or'. + kIdOrpd, //!< Instruction 'orpd' {SSE2}. + kIdOrps, //!< Instruction 'orps' {SSE}. + kIdOut, //!< Instruction 'out'. + kIdOuts, //!< Instruction 'outs'. + kIdPabsb, //!< Instruction 'pabsb' {SSSE3}. + kIdPabsd, //!< Instruction 'pabsd' {SSSE3}. + kIdPabsw, //!< Instruction 'pabsw' {SSSE3}. + kIdPackssdw, //!< Instruction 'packssdw' {MMX|SSE2}. + kIdPacksswb, //!< Instruction 'packsswb' {MMX|SSE2}. + kIdPackusdw, //!< Instruction 'packusdw' {SSE4_1}. + kIdPackuswb, //!< Instruction 'packuswb' {MMX|SSE2}. + kIdPaddb, //!< Instruction 'paddb' {MMX|SSE2}. + kIdPaddd, //!< Instruction 'paddd' {MMX|SSE2}. + kIdPaddq, //!< Instruction 'paddq' {SSE2}. + kIdPaddsb, //!< Instruction 'paddsb' {MMX|SSE2}. + kIdPaddsw, //!< Instruction 'paddsw' {MMX|SSE2}. + kIdPaddusb, //!< Instruction 'paddusb' {MMX|SSE2}. + kIdPaddusw, //!< Instruction 'paddusw' {MMX|SSE2}. + kIdPaddw, //!< Instruction 'paddw' {MMX|SSE2}. + kIdPalignr, //!< Instruction 'palignr' {SSSE3}. + kIdPand, //!< Instruction 'pand' {MMX|SSE2}. + kIdPandn, //!< Instruction 'pandn' {MMX|SSE2}. + kIdPause, //!< Instruction 'pause'. + kIdPavgb, //!< Instruction 'pavgb' {MMX2|SSE2}. + kIdPavgusb, //!< Instruction 'pavgusb' {3DNOW}. + kIdPavgw, //!< Instruction 'pavgw' {MMX2|SSE2}. + kIdPblendvb, //!< Instruction 'pblendvb' {SSE4_1}. + kIdPblendw, //!< Instruction 'pblendw' {SSE4_1}. + kIdPclmulqdq, //!< Instruction 'pclmulqdq' {PCLMULQDQ}. + kIdPcmpeqb, //!< Instruction 'pcmpeqb' {MMX|SSE2}. + kIdPcmpeqd, //!< Instruction 'pcmpeqd' {MMX|SSE2}. + kIdPcmpeqq, //!< Instruction 'pcmpeqq' {SSE4_1}. + kIdPcmpeqw, //!< Instruction 'pcmpeqw' {MMX|SSE2}. + kIdPcmpestri, //!< Instruction 'pcmpestri' {SSE4_2}. + kIdPcmpestrm, //!< Instruction 'pcmpestrm' {SSE4_2}. + kIdPcmpgtb, //!< Instruction 'pcmpgtb' {MMX|SSE2}. + kIdPcmpgtd, //!< Instruction 'pcmpgtd' {MMX|SSE2}. + kIdPcmpgtq, //!< Instruction 'pcmpgtq' {SSE4_2}. + kIdPcmpgtw, //!< Instruction 'pcmpgtw' {MMX|SSE2}. + kIdPcmpistri, //!< Instruction 'pcmpistri' {SSE4_2}. + kIdPcmpistrm, //!< Instruction 'pcmpistrm' {SSE4_2}. + kIdPconfig, //!< Instruction 'pconfig' {PCONFIG}. + kIdPdep, //!< Instruction 'pdep' {BMI2}. + kIdPext, //!< Instruction 'pext' {BMI2}. + kIdPextrb, //!< Instruction 'pextrb' {SSE4_1}. + kIdPextrd, //!< Instruction 'pextrd' {SSE4_1}. + kIdPextrq, //!< Instruction 'pextrq' {SSE4_1} (X64). + kIdPextrw, //!< Instruction 'pextrw' {MMX2|SSE2|SSE4_1}. + kIdPf2id, //!< Instruction 'pf2id' {3DNOW}. + kIdPf2iw, //!< Instruction 'pf2iw' {3DNOW2}. + kIdPfacc, //!< Instruction 'pfacc' {3DNOW}. + kIdPfadd, //!< Instruction 'pfadd' {3DNOW}. + kIdPfcmpeq, //!< Instruction 'pfcmpeq' {3DNOW}. + kIdPfcmpge, //!< Instruction 'pfcmpge' {3DNOW}. + kIdPfcmpgt, //!< Instruction 'pfcmpgt' {3DNOW}. + kIdPfmax, //!< Instruction 'pfmax' {3DNOW}. + kIdPfmin, //!< Instruction 'pfmin' {3DNOW}. + kIdPfmul, //!< Instruction 'pfmul' {3DNOW}. + kIdPfnacc, //!< Instruction 'pfnacc' {3DNOW2}. + kIdPfpnacc, //!< Instruction 'pfpnacc' {3DNOW2}. + kIdPfrcp, //!< Instruction 'pfrcp' {3DNOW}. + kIdPfrcpit1, //!< Instruction 'pfrcpit1' {3DNOW}. + kIdPfrcpit2, //!< Instruction 'pfrcpit2' {3DNOW}. + kIdPfrcpv, //!< Instruction 'pfrcpv' {GEODE}. + kIdPfrsqit1, //!< Instruction 'pfrsqit1' {3DNOW}. + kIdPfrsqrt, //!< Instruction 'pfrsqrt' {3DNOW}. + kIdPfrsqrtv, //!< Instruction 'pfrsqrtv' {GEODE}. + kIdPfsub, //!< Instruction 'pfsub' {3DNOW}. + kIdPfsubr, //!< Instruction 'pfsubr' {3DNOW}. + kIdPhaddd, //!< Instruction 'phaddd' {SSSE3}. + kIdPhaddsw, //!< Instruction 'phaddsw' {SSSE3}. + kIdPhaddw, //!< Instruction 'phaddw' {SSSE3}. + kIdPhminposuw, //!< Instruction 'phminposuw' {SSE4_1}. + kIdPhsubd, //!< Instruction 'phsubd' {SSSE3}. + kIdPhsubsw, //!< Instruction 'phsubsw' {SSSE3}. + kIdPhsubw, //!< Instruction 'phsubw' {SSSE3}. + kIdPi2fd, //!< Instruction 'pi2fd' {3DNOW}. + kIdPi2fw, //!< Instruction 'pi2fw' {3DNOW2}. + kIdPinsrb, //!< Instruction 'pinsrb' {SSE4_1}. + kIdPinsrd, //!< Instruction 'pinsrd' {SSE4_1}. + kIdPinsrq, //!< Instruction 'pinsrq' {SSE4_1} (X64). + kIdPinsrw, //!< Instruction 'pinsrw' {MMX2|SSE2}. + kIdPmaddubsw, //!< Instruction 'pmaddubsw' {SSSE3}. + kIdPmaddwd, //!< Instruction 'pmaddwd' {MMX|SSE2}. + kIdPmaxsb, //!< Instruction 'pmaxsb' {SSE4_1}. + kIdPmaxsd, //!< Instruction 'pmaxsd' {SSE4_1}. + kIdPmaxsw, //!< Instruction 'pmaxsw' {MMX2|SSE2}. + kIdPmaxub, //!< Instruction 'pmaxub' {MMX2|SSE2}. + kIdPmaxud, //!< Instruction 'pmaxud' {SSE4_1}. + kIdPmaxuw, //!< Instruction 'pmaxuw' {SSE4_1}. + kIdPminsb, //!< Instruction 'pminsb' {SSE4_1}. + kIdPminsd, //!< Instruction 'pminsd' {SSE4_1}. + kIdPminsw, //!< Instruction 'pminsw' {MMX2|SSE2}. + kIdPminub, //!< Instruction 'pminub' {MMX2|SSE2}. + kIdPminud, //!< Instruction 'pminud' {SSE4_1}. + kIdPminuw, //!< Instruction 'pminuw' {SSE4_1}. + kIdPmovmskb, //!< Instruction 'pmovmskb' {MMX2|SSE2}. + kIdPmovsxbd, //!< Instruction 'pmovsxbd' {SSE4_1}. + kIdPmovsxbq, //!< Instruction 'pmovsxbq' {SSE4_1}. + kIdPmovsxbw, //!< Instruction 'pmovsxbw' {SSE4_1}. + kIdPmovsxdq, //!< Instruction 'pmovsxdq' {SSE4_1}. + kIdPmovsxwd, //!< Instruction 'pmovsxwd' {SSE4_1}. + kIdPmovsxwq, //!< Instruction 'pmovsxwq' {SSE4_1}. + kIdPmovzxbd, //!< Instruction 'pmovzxbd' {SSE4_1}. + kIdPmovzxbq, //!< Instruction 'pmovzxbq' {SSE4_1}. + kIdPmovzxbw, //!< Instruction 'pmovzxbw' {SSE4_1}. + kIdPmovzxdq, //!< Instruction 'pmovzxdq' {SSE4_1}. + kIdPmovzxwd, //!< Instruction 'pmovzxwd' {SSE4_1}. + kIdPmovzxwq, //!< Instruction 'pmovzxwq' {SSE4_1}. + kIdPmuldq, //!< Instruction 'pmuldq' {SSE4_1}. + kIdPmulhrsw, //!< Instruction 'pmulhrsw' {SSSE3}. + kIdPmulhrw, //!< Instruction 'pmulhrw' {3DNOW}. + kIdPmulhuw, //!< Instruction 'pmulhuw' {MMX2|SSE2}. + kIdPmulhw, //!< Instruction 'pmulhw' {MMX|SSE2}. + kIdPmulld, //!< Instruction 'pmulld' {SSE4_1}. + kIdPmullw, //!< Instruction 'pmullw' {MMX|SSE2}. + kIdPmuludq, //!< Instruction 'pmuludq' {SSE2}. + kIdPop, //!< Instruction 'pop'. + kIdPopa, //!< Instruction 'popa' (X86). + kIdPopad, //!< Instruction 'popad' (X86). + kIdPopcnt, //!< Instruction 'popcnt' {POPCNT}. + kIdPopf, //!< Instruction 'popf'. + kIdPopfd, //!< Instruction 'popfd' (X86). + kIdPopfq, //!< Instruction 'popfq' (X64). + kIdPor, //!< Instruction 'por' {MMX|SSE2}. + kIdPrefetch, //!< Instruction 'prefetch' {3DNOW}. + kIdPrefetchit0, //!< Instruction 'prefetchit0' {PREFETCHI} (X64). + kIdPrefetchit1, //!< Instruction 'prefetchit1' {PREFETCHI} (X64). + kIdPrefetchnta, //!< Instruction 'prefetchnta' {SSE}. + kIdPrefetcht0, //!< Instruction 'prefetcht0' {SSE}. + kIdPrefetcht1, //!< Instruction 'prefetcht1' {SSE}. + kIdPrefetcht2, //!< Instruction 'prefetcht2' {SSE}. + kIdPrefetchw, //!< Instruction 'prefetchw' {PREFETCHW}. + kIdPrefetchwt1, //!< Instruction 'prefetchwt1' {PREFETCHWT1}. + kIdPsadbw, //!< Instruction 'psadbw' {MMX2|SSE2}. + kIdPshufb, //!< Instruction 'pshufb' {SSSE3}. + kIdPshufd, //!< Instruction 'pshufd' {SSE2}. + kIdPshufhw, //!< Instruction 'pshufhw' {SSE2}. + kIdPshuflw, //!< Instruction 'pshuflw' {SSE2}. + kIdPshufw, //!< Instruction 'pshufw' {MMX2}. + kIdPsignb, //!< Instruction 'psignb' {SSSE3}. + kIdPsignd, //!< Instruction 'psignd' {SSSE3}. + kIdPsignw, //!< Instruction 'psignw' {SSSE3}. + kIdPslld, //!< Instruction 'pslld' {MMX|SSE2}. + kIdPslldq, //!< Instruction 'pslldq' {SSE2}. + kIdPsllq, //!< Instruction 'psllq' {MMX|SSE2}. + kIdPsllw, //!< Instruction 'psllw' {MMX|SSE2}. + kIdPsmash, //!< Instruction 'psmash' {SEV_SNP} (X64). + kIdPsrad, //!< Instruction 'psrad' {MMX|SSE2}. + kIdPsraw, //!< Instruction 'psraw' {MMX|SSE2}. + kIdPsrld, //!< Instruction 'psrld' {MMX|SSE2}. + kIdPsrldq, //!< Instruction 'psrldq' {SSE2}. + kIdPsrlq, //!< Instruction 'psrlq' {MMX|SSE2}. + kIdPsrlw, //!< Instruction 'psrlw' {MMX|SSE2}. + kIdPsubb, //!< Instruction 'psubb' {MMX|SSE2}. + kIdPsubd, //!< Instruction 'psubd' {MMX|SSE2}. + kIdPsubq, //!< Instruction 'psubq' {SSE2}. + kIdPsubsb, //!< Instruction 'psubsb' {MMX|SSE2}. + kIdPsubsw, //!< Instruction 'psubsw' {MMX|SSE2}. + kIdPsubusb, //!< Instruction 'psubusb' {MMX|SSE2}. + kIdPsubusw, //!< Instruction 'psubusw' {MMX|SSE2}. + kIdPsubw, //!< Instruction 'psubw' {MMX|SSE2}. + kIdPswapd, //!< Instruction 'pswapd' {3DNOW2}. + kIdPtest, //!< Instruction 'ptest' {SSE4_1}. + kIdPtwrite, //!< Instruction 'ptwrite' {PTWRITE}. + kIdPunpckhbw, //!< Instruction 'punpckhbw' {MMX|SSE2}. + kIdPunpckhdq, //!< Instruction 'punpckhdq' {MMX|SSE2}. + kIdPunpckhqdq, //!< Instruction 'punpckhqdq' {SSE2}. + kIdPunpckhwd, //!< Instruction 'punpckhwd' {MMX|SSE2}. + kIdPunpcklbw, //!< Instruction 'punpcklbw' {MMX|SSE2}. + kIdPunpckldq, //!< Instruction 'punpckldq' {MMX|SSE2}. + kIdPunpcklqdq, //!< Instruction 'punpcklqdq' {SSE2}. + kIdPunpcklwd, //!< Instruction 'punpcklwd' {MMX|SSE2}. + kIdPush, //!< Instruction 'push'. + kIdPusha, //!< Instruction 'pusha' (X86). + kIdPushad, //!< Instruction 'pushad' (X86). + kIdPushf, //!< Instruction 'pushf'. + kIdPushfd, //!< Instruction 'pushfd' (X86). + kIdPushfq, //!< Instruction 'pushfq' (X64). + kIdPvalidate, //!< Instruction 'pvalidate' {SEV_SNP}. + kIdPxor, //!< Instruction 'pxor' {MMX|SSE2}. + kIdRcl, //!< Instruction 'rcl'. + kIdRcpps, //!< Instruction 'rcpps' {SSE}. + kIdRcpss, //!< Instruction 'rcpss' {SSE}. + kIdRcr, //!< Instruction 'rcr'. + kIdRdfsbase, //!< Instruction 'rdfsbase' {FSGSBASE} (X64). + kIdRdgsbase, //!< Instruction 'rdgsbase' {FSGSBASE} (X64). + kIdRdmsr, //!< Instruction 'rdmsr' {MSR}. + kIdRdpid, //!< Instruction 'rdpid' {RDPID}. + kIdRdpkru, //!< Instruction 'rdpkru' {OSPKE}. + kIdRdpmc, //!< Instruction 'rdpmc'. + kIdRdpru, //!< Instruction 'rdpru' {RDPRU}. + kIdRdrand, //!< Instruction 'rdrand' {RDRAND}. + kIdRdseed, //!< Instruction 'rdseed' {RDSEED}. + kIdRdsspd, //!< Instruction 'rdsspd' {CET_SS}. + kIdRdsspq, //!< Instruction 'rdsspq' {CET_SS} (X64). + kIdRdtsc, //!< Instruction 'rdtsc' {RDTSC}. + kIdRdtscp, //!< Instruction 'rdtscp' {RDTSCP}. + kIdRet, //!< Instruction 'ret'. + kIdRetf, //!< Instruction 'retf'. + kIdRmpadjust, //!< Instruction 'rmpadjust' {SEV_SNP} (X64). + kIdRmpupdate, //!< Instruction 'rmpupdate' {SEV_SNP} (X64). + kIdRol, //!< Instruction 'rol'. + kIdRor, //!< Instruction 'ror'. + kIdRorx, //!< Instruction 'rorx' {BMI2}. + kIdRoundpd, //!< Instruction 'roundpd' {SSE4_1}. + kIdRoundps, //!< Instruction 'roundps' {SSE4_1}. + kIdRoundsd, //!< Instruction 'roundsd' {SSE4_1}. + kIdRoundss, //!< Instruction 'roundss' {SSE4_1}. + kIdRsm, //!< Instruction 'rsm' (X86). + kIdRsqrtps, //!< Instruction 'rsqrtps' {SSE}. + kIdRsqrtss, //!< Instruction 'rsqrtss' {SSE}. + kIdRstorssp, //!< Instruction 'rstorssp' {CET_SS}. + kIdSahf, //!< Instruction 'sahf' {LAHFSAHF}. + kIdSal, //!< Instruction 'sal'. + kIdSar, //!< Instruction 'sar'. + kIdSarx, //!< Instruction 'sarx' {BMI2}. + kIdSaveprevssp, //!< Instruction 'saveprevssp' {CET_SS}. + kIdSbb, //!< Instruction 'sbb'. + kIdScas, //!< Instruction 'scas'. + kIdSeamcall, //!< Instruction 'seamcall' {SEAM}. + kIdSeamops, //!< Instruction 'seamops' {SEAM}. + kIdSeamret, //!< Instruction 'seamret' {SEAM}. + kIdSenduipi, //!< Instruction 'senduipi' {UINTR} (X64). + kIdSerialize, //!< Instruction 'serialize' {SERIALIZE}. + kIdSeta, //!< Instruction 'seta'. + kIdSetae, //!< Instruction 'setae'. + kIdSetb, //!< Instruction 'setb'. + kIdSetbe, //!< Instruction 'setbe'. + kIdSetc, //!< Instruction 'setc'. + kIdSete, //!< Instruction 'sete'. + kIdSetg, //!< Instruction 'setg'. + kIdSetge, //!< Instruction 'setge'. + kIdSetl, //!< Instruction 'setl'. + kIdSetle, //!< Instruction 'setle'. + kIdSetna, //!< Instruction 'setna'. + kIdSetnae, //!< Instruction 'setnae'. + kIdSetnb, //!< Instruction 'setnb'. + kIdSetnbe, //!< Instruction 'setnbe'. + kIdSetnc, //!< Instruction 'setnc'. + kIdSetne, //!< Instruction 'setne'. + kIdSetng, //!< Instruction 'setng'. + kIdSetnge, //!< Instruction 'setnge'. + kIdSetnl, //!< Instruction 'setnl'. + kIdSetnle, //!< Instruction 'setnle'. + kIdSetno, //!< Instruction 'setno'. + kIdSetnp, //!< Instruction 'setnp'. + kIdSetns, //!< Instruction 'setns'. + kIdSetnz, //!< Instruction 'setnz'. + kIdSeto, //!< Instruction 'seto'. + kIdSetp, //!< Instruction 'setp'. + kIdSetpe, //!< Instruction 'setpe'. + kIdSetpo, //!< Instruction 'setpo'. + kIdSets, //!< Instruction 'sets'. + kIdSetssbsy, //!< Instruction 'setssbsy' {CET_SS}. + kIdSetz, //!< Instruction 'setz'. + kIdSfence, //!< Instruction 'sfence' {SSE}. + kIdSgdt, //!< Instruction 'sgdt'. + kIdSha1msg1, //!< Instruction 'sha1msg1' {SHA}. + kIdSha1msg2, //!< Instruction 'sha1msg2' {SHA}. + kIdSha1nexte, //!< Instruction 'sha1nexte' {SHA}. + kIdSha1rnds4, //!< Instruction 'sha1rnds4' {SHA}. + kIdSha256msg1, //!< Instruction 'sha256msg1' {SHA}. + kIdSha256msg2, //!< Instruction 'sha256msg2' {SHA}. + kIdSha256rnds2, //!< Instruction 'sha256rnds2' {SHA}. + kIdShl, //!< Instruction 'shl'. + kIdShld, //!< Instruction 'shld'. + kIdShlx, //!< Instruction 'shlx' {BMI2}. + kIdShr, //!< Instruction 'shr'. + kIdShrd, //!< Instruction 'shrd'. + kIdShrx, //!< Instruction 'shrx' {BMI2}. + kIdShufpd, //!< Instruction 'shufpd' {SSE2}. + kIdShufps, //!< Instruction 'shufps' {SSE}. + kIdSidt, //!< Instruction 'sidt'. + kIdSkinit, //!< Instruction 'skinit' {SKINIT}. + kIdSldt, //!< Instruction 'sldt'. + kIdSlwpcb, //!< Instruction 'slwpcb' {LWP}. + kIdSmsw, //!< Instruction 'smsw'. + kIdSqrtpd, //!< Instruction 'sqrtpd' {SSE2}. + kIdSqrtps, //!< Instruction 'sqrtps' {SSE}. + kIdSqrtsd, //!< Instruction 'sqrtsd' {SSE2}. + kIdSqrtss, //!< Instruction 'sqrtss' {SSE}. + kIdStac, //!< Instruction 'stac' {SMAP}. + kIdStc, //!< Instruction 'stc'. + kIdStd, //!< Instruction 'std'. + kIdStgi, //!< Instruction 'stgi' {SKINIT}. + kIdSti, //!< Instruction 'sti'. + kIdStmxcsr, //!< Instruction 'stmxcsr' {SSE}. + kIdStos, //!< Instruction 'stos'. + kIdStr, //!< Instruction 'str'. + kIdSttilecfg, //!< Instruction 'sttilecfg' {AMX_TILE} (X64). + kIdStui, //!< Instruction 'stui' {UINTR} (X64). + kIdSub, //!< Instruction 'sub'. + kIdSubpd, //!< Instruction 'subpd' {SSE2}. + kIdSubps, //!< Instruction 'subps' {SSE}. + kIdSubsd, //!< Instruction 'subsd' {SSE2}. + kIdSubss, //!< Instruction 'subss' {SSE}. + kIdSwapgs, //!< Instruction 'swapgs' (X64). + kIdSyscall, //!< Instruction 'syscall' (X64). + kIdSysenter, //!< Instruction 'sysenter'. + kIdSysexit, //!< Instruction 'sysexit'. + kIdSysexitq, //!< Instruction 'sysexitq' (X64). + kIdSysret, //!< Instruction 'sysret' (X64). + kIdSysretq, //!< Instruction 'sysretq' (X64). + kIdT1mskc, //!< Instruction 't1mskc' {TBM}. + kIdTcmmimfp16ps, //!< Instruction 'tcmmimfp16ps' {AMX_COMPLEX} (X64). + kIdTcmmrlfp16ps, //!< Instruction 'tcmmrlfp16ps' {AMX_COMPLEX} (X64). + kIdTdcall, //!< Instruction 'tdcall' {SEAM}. + kIdTdpbf16ps, //!< Instruction 'tdpbf16ps' {AMX_BF16} (X64). + kIdTdpbssd, //!< Instruction 'tdpbssd' {AMX_INT8} (X64). + kIdTdpbsud, //!< Instruction 'tdpbsud' {AMX_INT8} (X64). + kIdTdpbusd, //!< Instruction 'tdpbusd' {AMX_INT8} (X64). + kIdTdpbuud, //!< Instruction 'tdpbuud' {AMX_INT8} (X64). + kIdTdpfp16ps, //!< Instruction 'tdpfp16ps' {AMX_FP16} (X64). + kIdTest, //!< Instruction 'test'. + kIdTestui, //!< Instruction 'testui' {UINTR} (X64). + kIdTileloadd, //!< Instruction 'tileloadd' {AMX_TILE} (X64). + kIdTileloaddt1, //!< Instruction 'tileloaddt1' {AMX_TILE} (X64). + kIdTilerelease, //!< Instruction 'tilerelease' {AMX_TILE} (X64). + kIdTilestored, //!< Instruction 'tilestored' {AMX_TILE} (X64). + kIdTilezero, //!< Instruction 'tilezero' {AMX_TILE} (X64). + kIdTlbsync, //!< Instruction 'tlbsync' {INVLPGB}. + kIdTpause, //!< Instruction 'tpause' {WAITPKG}. + kIdTzcnt, //!< Instruction 'tzcnt' {BMI}. + kIdTzmsk, //!< Instruction 'tzmsk' {TBM}. + kIdUcomisd, //!< Instruction 'ucomisd' {SSE2}. + kIdUcomiss, //!< Instruction 'ucomiss' {SSE}. + kIdUd0, //!< Instruction 'ud0'. + kIdUd1, //!< Instruction 'ud1'. + kIdUd2, //!< Instruction 'ud2'. + kIdUiret, //!< Instruction 'uiret' {UINTR} (X64). + kIdUmonitor, //!< Instruction 'umonitor' {WAITPKG}. + kIdUmwait, //!< Instruction 'umwait' {WAITPKG}. + kIdUnpckhpd, //!< Instruction 'unpckhpd' {SSE2}. + kIdUnpckhps, //!< Instruction 'unpckhps' {SSE}. + kIdUnpcklpd, //!< Instruction 'unpcklpd' {SSE2}. + kIdUnpcklps, //!< Instruction 'unpcklps' {SSE}. + kIdV4fmaddps, //!< Instruction 'v4fmaddps' {AVX512_4FMAPS}. + kIdV4fmaddss, //!< Instruction 'v4fmaddss' {AVX512_4FMAPS}. + kIdV4fnmaddps, //!< Instruction 'v4fnmaddps' {AVX512_4FMAPS}. + kIdV4fnmaddss, //!< Instruction 'v4fnmaddss' {AVX512_4FMAPS}. + kIdVaddpd, //!< Instruction 'vaddpd' {AVX|AVX512_F+VL}. + kIdVaddph, //!< Instruction 'vaddph' {AVX512_FP16+VL}. + kIdVaddps, //!< Instruction 'vaddps' {AVX|AVX512_F+VL}. + kIdVaddsd, //!< Instruction 'vaddsd' {AVX|AVX512_F}. + kIdVaddsh, //!< Instruction 'vaddsh' {AVX512_FP16}. + kIdVaddss, //!< Instruction 'vaddss' {AVX|AVX512_F}. + kIdVaddsubpd, //!< Instruction 'vaddsubpd' {AVX}. + kIdVaddsubps, //!< Instruction 'vaddsubps' {AVX}. + kIdVaesdec, //!< Instruction 'vaesdec' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesdeclast, //!< Instruction 'vaesdeclast' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesenc, //!< Instruction 'vaesenc' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesenclast, //!< Instruction 'vaesenclast' {AVX|AVX512_F+VL & AESNI|VAES}. + kIdVaesimc, //!< Instruction 'vaesimc' {AVX & AESNI}. + kIdVaeskeygenassist, //!< Instruction 'vaeskeygenassist' {AVX & AESNI}. + kIdValignd, //!< Instruction 'valignd' {AVX512_F+VL}. + kIdValignq, //!< Instruction 'valignq' {AVX512_F+VL}. + kIdVandnpd, //!< Instruction 'vandnpd' {AVX|AVX512_DQ+VL}. + kIdVandnps, //!< Instruction 'vandnps' {AVX|AVX512_DQ+VL}. + kIdVandpd, //!< Instruction 'vandpd' {AVX|AVX512_DQ+VL}. + kIdVandps, //!< Instruction 'vandps' {AVX|AVX512_DQ+VL}. + kIdVbcstnebf162ps, //!< Instruction 'vbcstnebf162ps' {AVX_NE_CONVERT}. + kIdVbcstnesh2ps, //!< Instruction 'vbcstnesh2ps' {AVX_NE_CONVERT}. + kIdVblendmpd, //!< Instruction 'vblendmpd' {AVX512_F+VL}. + kIdVblendmps, //!< Instruction 'vblendmps' {AVX512_F+VL}. + kIdVblendpd, //!< Instruction 'vblendpd' {AVX}. + kIdVblendps, //!< Instruction 'vblendps' {AVX}. + kIdVblendvpd, //!< Instruction 'vblendvpd' {AVX}. + kIdVblendvps, //!< Instruction 'vblendvps' {AVX}. + kIdVbroadcastf128, //!< Instruction 'vbroadcastf128' {AVX}. + kIdVbroadcastf32x2, //!< Instruction 'vbroadcastf32x2' {AVX512_DQ+VL}. + kIdVbroadcastf32x4, //!< Instruction 'vbroadcastf32x4' {AVX512_F}. + kIdVbroadcastf32x8, //!< Instruction 'vbroadcastf32x8' {AVX512_DQ}. + kIdVbroadcastf64x2, //!< Instruction 'vbroadcastf64x2' {AVX512_DQ+VL}. + kIdVbroadcastf64x4, //!< Instruction 'vbroadcastf64x4' {AVX512_F}. + kIdVbroadcasti128, //!< Instruction 'vbroadcasti128' {AVX2}. + kIdVbroadcasti32x2, //!< Instruction 'vbroadcasti32x2' {AVX512_DQ+VL}. + kIdVbroadcasti32x4, //!< Instruction 'vbroadcasti32x4' {AVX512_F+VL}. + kIdVbroadcasti32x8, //!< Instruction 'vbroadcasti32x8' {AVX512_DQ}. + kIdVbroadcasti64x2, //!< Instruction 'vbroadcasti64x2' {AVX512_DQ+VL}. + kIdVbroadcasti64x4, //!< Instruction 'vbroadcasti64x4' {AVX512_F}. + kIdVbroadcastsd, //!< Instruction 'vbroadcastsd' {AVX|AVX2|AVX512_F+VL}. + kIdVbroadcastss, //!< Instruction 'vbroadcastss' {AVX|AVX2|AVX512_F+VL}. + kIdVcmppd, //!< Instruction 'vcmppd' {AVX|AVX512_F+VL}. + kIdVcmpph, //!< Instruction 'vcmpph' {AVX512_FP16+VL}. + kIdVcmpps, //!< Instruction 'vcmpps' {AVX|AVX512_F+VL}. + kIdVcmpsd, //!< Instruction 'vcmpsd' {AVX|AVX512_F}. + kIdVcmpsh, //!< Instruction 'vcmpsh' {AVX512_FP16}. + kIdVcmpss, //!< Instruction 'vcmpss' {AVX|AVX512_F}. + kIdVcomisd, //!< Instruction 'vcomisd' {AVX|AVX512_F}. + kIdVcomish, //!< Instruction 'vcomish' {AVX512_FP16}. + kIdVcomiss, //!< Instruction 'vcomiss' {AVX|AVX512_F}. + kIdVcompresspd, //!< Instruction 'vcompresspd' {AVX512_F+VL}. + kIdVcompressps, //!< Instruction 'vcompressps' {AVX512_F+VL}. + kIdVcvtdq2pd, //!< Instruction 'vcvtdq2pd' {AVX|AVX512_F+VL}. + kIdVcvtdq2ph, //!< Instruction 'vcvtdq2ph' {AVX512_FP16+VL}. + kIdVcvtdq2ps, //!< Instruction 'vcvtdq2ps' {AVX|AVX512_F+VL}. + kIdVcvtne2ps2bf16, //!< Instruction 'vcvtne2ps2bf16' {AVX512_BF16+VL}. + kIdVcvtneebf162ps, //!< Instruction 'vcvtneebf162ps' {AVX_NE_CONVERT}. + kIdVcvtneeph2ps, //!< Instruction 'vcvtneeph2ps' {AVX_NE_CONVERT}. + kIdVcvtneobf162ps, //!< Instruction 'vcvtneobf162ps' {AVX_NE_CONVERT}. + kIdVcvtneoph2ps, //!< Instruction 'vcvtneoph2ps' {AVX_NE_CONVERT}. + kIdVcvtneps2bf16, //!< Instruction 'vcvtneps2bf16' {AVX_NE_CONVERT|AVX512_BF16+VL}. + kIdVcvtpd2dq, //!< Instruction 'vcvtpd2dq' {AVX|AVX512_F+VL}. + kIdVcvtpd2ph, //!< Instruction 'vcvtpd2ph' {AVX512_FP16+VL}. + kIdVcvtpd2ps, //!< Instruction 'vcvtpd2ps' {AVX|AVX512_F+VL}. + kIdVcvtpd2qq, //!< Instruction 'vcvtpd2qq' {AVX512_DQ+VL}. + kIdVcvtpd2udq, //!< Instruction 'vcvtpd2udq' {AVX512_F+VL}. + kIdVcvtpd2uqq, //!< Instruction 'vcvtpd2uqq' {AVX512_DQ+VL}. + kIdVcvtph2dq, //!< Instruction 'vcvtph2dq' {AVX512_FP16+VL}. + kIdVcvtph2pd, //!< Instruction 'vcvtph2pd' {AVX512_FP16+VL}. + kIdVcvtph2ps, //!< Instruction 'vcvtph2ps' {AVX512_F+VL & F16C}. + kIdVcvtph2psx, //!< Instruction 'vcvtph2psx' {AVX512_FP16+VL}. + kIdVcvtph2qq, //!< Instruction 'vcvtph2qq' {AVX512_FP16+VL}. + kIdVcvtph2udq, //!< Instruction 'vcvtph2udq' {AVX512_FP16+VL}. + kIdVcvtph2uqq, //!< Instruction 'vcvtph2uqq' {AVX512_FP16+VL}. + kIdVcvtph2uw, //!< Instruction 'vcvtph2uw' {AVX512_FP16+VL}. + kIdVcvtph2w, //!< Instruction 'vcvtph2w' {AVX512_FP16+VL}. + kIdVcvtps2dq, //!< Instruction 'vcvtps2dq' {AVX|AVX512_F+VL}. + kIdVcvtps2pd, //!< Instruction 'vcvtps2pd' {AVX|AVX512_F+VL}. + kIdVcvtps2ph, //!< Instruction 'vcvtps2ph' {AVX512_F+VL & F16C}. + kIdVcvtps2phx, //!< Instruction 'vcvtps2phx' {AVX512_FP16+VL}. + kIdVcvtps2qq, //!< Instruction 'vcvtps2qq' {AVX512_DQ+VL}. + kIdVcvtps2udq, //!< Instruction 'vcvtps2udq' {AVX512_F+VL}. + kIdVcvtps2uqq, //!< Instruction 'vcvtps2uqq' {AVX512_DQ+VL}. + kIdVcvtqq2pd, //!< Instruction 'vcvtqq2pd' {AVX512_DQ+VL}. + kIdVcvtqq2ph, //!< Instruction 'vcvtqq2ph' {AVX512_FP16+VL}. + kIdVcvtqq2ps, //!< Instruction 'vcvtqq2ps' {AVX512_DQ+VL}. + kIdVcvtsd2sh, //!< Instruction 'vcvtsd2sh' {AVX512_FP16}. + kIdVcvtsd2si, //!< Instruction 'vcvtsd2si' {AVX|AVX512_F}. + kIdVcvtsd2ss, //!< Instruction 'vcvtsd2ss' {AVX|AVX512_F}. + kIdVcvtsd2usi, //!< Instruction 'vcvtsd2usi' {AVX512_F}. + kIdVcvtsh2sd, //!< Instruction 'vcvtsh2sd' {AVX512_FP16}. + kIdVcvtsh2si, //!< Instruction 'vcvtsh2si' {AVX512_FP16}. + kIdVcvtsh2ss, //!< Instruction 'vcvtsh2ss' {AVX512_FP16}. + kIdVcvtsh2usi, //!< Instruction 'vcvtsh2usi' {AVX512_FP16}. + kIdVcvtsi2sd, //!< Instruction 'vcvtsi2sd' {AVX|AVX512_F}. + kIdVcvtsi2sh, //!< Instruction 'vcvtsi2sh' {AVX512_FP16}. + kIdVcvtsi2ss, //!< Instruction 'vcvtsi2ss' {AVX|AVX512_F}. + kIdVcvtss2sd, //!< Instruction 'vcvtss2sd' {AVX|AVX512_F}. + kIdVcvtss2sh, //!< Instruction 'vcvtss2sh' {AVX512_FP16}. + kIdVcvtss2si, //!< Instruction 'vcvtss2si' {AVX|AVX512_F}. + kIdVcvtss2usi, //!< Instruction 'vcvtss2usi' {AVX512_F}. + kIdVcvttpd2dq, //!< Instruction 'vcvttpd2dq' {AVX|AVX512_F+VL}. + kIdVcvttpd2qq, //!< Instruction 'vcvttpd2qq' {AVX512_F+VL}. + kIdVcvttpd2udq, //!< Instruction 'vcvttpd2udq' {AVX512_F+VL}. + kIdVcvttpd2uqq, //!< Instruction 'vcvttpd2uqq' {AVX512_DQ+VL}. + kIdVcvttph2dq, //!< Instruction 'vcvttph2dq' {AVX512_FP16+VL}. + kIdVcvttph2qq, //!< Instruction 'vcvttph2qq' {AVX512_FP16+VL}. + kIdVcvttph2udq, //!< Instruction 'vcvttph2udq' {AVX512_FP16+VL}. + kIdVcvttph2uqq, //!< Instruction 'vcvttph2uqq' {AVX512_FP16+VL}. + kIdVcvttph2uw, //!< Instruction 'vcvttph2uw' {AVX512_FP16+VL}. + kIdVcvttph2w, //!< Instruction 'vcvttph2w' {AVX512_FP16+VL}. + kIdVcvttps2dq, //!< Instruction 'vcvttps2dq' {AVX|AVX512_F+VL}. + kIdVcvttps2qq, //!< Instruction 'vcvttps2qq' {AVX512_DQ+VL}. + kIdVcvttps2udq, //!< Instruction 'vcvttps2udq' {AVX512_F+VL}. + kIdVcvttps2uqq, //!< Instruction 'vcvttps2uqq' {AVX512_DQ+VL}. + kIdVcvttsd2si, //!< Instruction 'vcvttsd2si' {AVX|AVX512_F}. + kIdVcvttsd2usi, //!< Instruction 'vcvttsd2usi' {AVX512_F}. + kIdVcvttsh2si, //!< Instruction 'vcvttsh2si' {AVX512_FP16}. + kIdVcvttsh2usi, //!< Instruction 'vcvttsh2usi' {AVX512_FP16}. + kIdVcvttss2si, //!< Instruction 'vcvttss2si' {AVX|AVX512_F}. + kIdVcvttss2usi, //!< Instruction 'vcvttss2usi' {AVX512_F}. + kIdVcvtudq2pd, //!< Instruction 'vcvtudq2pd' {AVX512_F+VL}. + kIdVcvtudq2ph, //!< Instruction 'vcvtudq2ph' {AVX512_FP16+VL}. + kIdVcvtudq2ps, //!< Instruction 'vcvtudq2ps' {AVX512_F+VL}. + kIdVcvtuqq2pd, //!< Instruction 'vcvtuqq2pd' {AVX512_DQ+VL}. + kIdVcvtuqq2ph, //!< Instruction 'vcvtuqq2ph' {AVX512_FP16+VL}. + kIdVcvtuqq2ps, //!< Instruction 'vcvtuqq2ps' {AVX512_DQ+VL}. + kIdVcvtusi2sd, //!< Instruction 'vcvtusi2sd' {AVX512_F}. + kIdVcvtusi2sh, //!< Instruction 'vcvtusi2sh' {AVX512_FP16}. + kIdVcvtusi2ss, //!< Instruction 'vcvtusi2ss' {AVX512_F}. + kIdVcvtuw2ph, //!< Instruction 'vcvtuw2ph' {AVX512_FP16+VL}. + kIdVcvtw2ph, //!< Instruction 'vcvtw2ph' {AVX512_FP16+VL}. + kIdVdbpsadbw, //!< Instruction 'vdbpsadbw' {AVX512_BW+VL}. + kIdVdivpd, //!< Instruction 'vdivpd' {AVX|AVX512_F+VL}. + kIdVdivph, //!< Instruction 'vdivph' {AVX512_FP16+VL}. + kIdVdivps, //!< Instruction 'vdivps' {AVX|AVX512_F+VL}. + kIdVdivsd, //!< Instruction 'vdivsd' {AVX|AVX512_F}. + kIdVdivsh, //!< Instruction 'vdivsh' {AVX512_FP16}. + kIdVdivss, //!< Instruction 'vdivss' {AVX|AVX512_F}. + kIdVdpbf16ps, //!< Instruction 'vdpbf16ps' {AVX512_BF16+VL}. + kIdVdppd, //!< Instruction 'vdppd' {AVX}. + kIdVdpps, //!< Instruction 'vdpps' {AVX}. + kIdVerr, //!< Instruction 'verr'. + kIdVerw, //!< Instruction 'verw'. + kIdVexp2pd, //!< Instruction 'vexp2pd' {AVX512_ER}. + kIdVexp2ps, //!< Instruction 'vexp2ps' {AVX512_ER}. + kIdVexpandpd, //!< Instruction 'vexpandpd' {AVX512_F+VL}. + kIdVexpandps, //!< Instruction 'vexpandps' {AVX512_F+VL}. + kIdVextractf128, //!< Instruction 'vextractf128' {AVX}. + kIdVextractf32x4, //!< Instruction 'vextractf32x4' {AVX512_F+VL}. + kIdVextractf32x8, //!< Instruction 'vextractf32x8' {AVX512_DQ}. + kIdVextractf64x2, //!< Instruction 'vextractf64x2' {AVX512_DQ+VL}. + kIdVextractf64x4, //!< Instruction 'vextractf64x4' {AVX512_F}. + kIdVextracti128, //!< Instruction 'vextracti128' {AVX2}. + kIdVextracti32x4, //!< Instruction 'vextracti32x4' {AVX512_F+VL}. + kIdVextracti32x8, //!< Instruction 'vextracti32x8' {AVX512_DQ}. + kIdVextracti64x2, //!< Instruction 'vextracti64x2' {AVX512_DQ+VL}. + kIdVextracti64x4, //!< Instruction 'vextracti64x4' {AVX512_F}. + kIdVextractps, //!< Instruction 'vextractps' {AVX|AVX512_F}. + kIdVfcmaddcph, //!< Instruction 'vfcmaddcph' {AVX512_FP16+VL}. + kIdVfcmaddcsh, //!< Instruction 'vfcmaddcsh' {AVX512_FP16}. + kIdVfcmulcph, //!< Instruction 'vfcmulcph' {AVX512_FP16+VL}. + kIdVfcmulcsh, //!< Instruction 'vfcmulcsh' {AVX512_FP16}. + kIdVfixupimmpd, //!< Instruction 'vfixupimmpd' {AVX512_F+VL}. + kIdVfixupimmps, //!< Instruction 'vfixupimmps' {AVX512_F+VL}. + kIdVfixupimmsd, //!< Instruction 'vfixupimmsd' {AVX512_F}. + kIdVfixupimmss, //!< Instruction 'vfixupimmss' {AVX512_F}. + kIdVfmadd132pd, //!< Instruction 'vfmadd132pd' {FMA|AVX512_F+VL}. + kIdVfmadd132ph, //!< Instruction 'vfmadd132ph' {AVX512_FP16+VL}. + kIdVfmadd132ps, //!< Instruction 'vfmadd132ps' {FMA|AVX512_F+VL}. + kIdVfmadd132sd, //!< Instruction 'vfmadd132sd' {FMA|AVX512_F}. + kIdVfmadd132sh, //!< Instruction 'vfmadd132sh' {AVX512_FP16}. + kIdVfmadd132ss, //!< Instruction 'vfmadd132ss' {FMA|AVX512_F}. + kIdVfmadd213pd, //!< Instruction 'vfmadd213pd' {FMA|AVX512_F+VL}. + kIdVfmadd213ph, //!< Instruction 'vfmadd213ph' {AVX512_FP16+VL}. + kIdVfmadd213ps, //!< Instruction 'vfmadd213ps' {FMA|AVX512_F+VL}. + kIdVfmadd213sd, //!< Instruction 'vfmadd213sd' {FMA|AVX512_F}. + kIdVfmadd213sh, //!< Instruction 'vfmadd213sh' {AVX512_FP16}. + kIdVfmadd213ss, //!< Instruction 'vfmadd213ss' {FMA|AVX512_F}. + kIdVfmadd231pd, //!< Instruction 'vfmadd231pd' {FMA|AVX512_F+VL}. + kIdVfmadd231ph, //!< Instruction 'vfmadd231ph' {AVX512_FP16+VL}. + kIdVfmadd231ps, //!< Instruction 'vfmadd231ps' {FMA|AVX512_F+VL}. + kIdVfmadd231sd, //!< Instruction 'vfmadd231sd' {FMA|AVX512_F}. + kIdVfmadd231sh, //!< Instruction 'vfmadd231sh' {AVX512_FP16}. + kIdVfmadd231ss, //!< Instruction 'vfmadd231ss' {FMA|AVX512_F}. + kIdVfmaddcph, //!< Instruction 'vfmaddcph' {AVX512_FP16+VL}. + kIdVfmaddcsh, //!< Instruction 'vfmaddcsh' {AVX512_FP16}. + kIdVfmaddpd, //!< Instruction 'vfmaddpd' {FMA4}. + kIdVfmaddps, //!< Instruction 'vfmaddps' {FMA4}. + kIdVfmaddsd, //!< Instruction 'vfmaddsd' {FMA4}. + kIdVfmaddss, //!< Instruction 'vfmaddss' {FMA4}. + kIdVfmaddsub132pd, //!< Instruction 'vfmaddsub132pd' {FMA|AVX512_F+VL}. + kIdVfmaddsub132ph, //!< Instruction 'vfmaddsub132ph' {AVX512_FP16+VL}. + kIdVfmaddsub132ps, //!< Instruction 'vfmaddsub132ps' {FMA|AVX512_F+VL}. + kIdVfmaddsub213pd, //!< Instruction 'vfmaddsub213pd' {FMA|AVX512_F+VL}. + kIdVfmaddsub213ph, //!< Instruction 'vfmaddsub213ph' {AVX512_FP16+VL}. + kIdVfmaddsub213ps, //!< Instruction 'vfmaddsub213ps' {FMA|AVX512_F+VL}. + kIdVfmaddsub231pd, //!< Instruction 'vfmaddsub231pd' {FMA|AVX512_F+VL}. + kIdVfmaddsub231ph, //!< Instruction 'vfmaddsub231ph' {AVX512_FP16+VL}. + kIdVfmaddsub231ps, //!< Instruction 'vfmaddsub231ps' {FMA|AVX512_F+VL}. + kIdVfmaddsubpd, //!< Instruction 'vfmaddsubpd' {FMA4}. + kIdVfmaddsubps, //!< Instruction 'vfmaddsubps' {FMA4}. + kIdVfmsub132pd, //!< Instruction 'vfmsub132pd' {FMA|AVX512_F+VL}. + kIdVfmsub132ph, //!< Instruction 'vfmsub132ph' {AVX512_FP16+VL}. + kIdVfmsub132ps, //!< Instruction 'vfmsub132ps' {FMA|AVX512_F+VL}. + kIdVfmsub132sd, //!< Instruction 'vfmsub132sd' {FMA|AVX512_F}. + kIdVfmsub132sh, //!< Instruction 'vfmsub132sh' {AVX512_FP16}. + kIdVfmsub132ss, //!< Instruction 'vfmsub132ss' {FMA|AVX512_F}. + kIdVfmsub213pd, //!< Instruction 'vfmsub213pd' {FMA|AVX512_F+VL}. + kIdVfmsub213ph, //!< Instruction 'vfmsub213ph' {AVX512_FP16+VL}. + kIdVfmsub213ps, //!< Instruction 'vfmsub213ps' {FMA|AVX512_F+VL}. + kIdVfmsub213sd, //!< Instruction 'vfmsub213sd' {FMA|AVX512_F}. + kIdVfmsub213sh, //!< Instruction 'vfmsub213sh' {AVX512_FP16}. + kIdVfmsub213ss, //!< Instruction 'vfmsub213ss' {FMA|AVX512_F}. + kIdVfmsub231pd, //!< Instruction 'vfmsub231pd' {FMA|AVX512_F+VL}. + kIdVfmsub231ph, //!< Instruction 'vfmsub231ph' {AVX512_FP16+VL}. + kIdVfmsub231ps, //!< Instruction 'vfmsub231ps' {FMA|AVX512_F+VL}. + kIdVfmsub231sd, //!< Instruction 'vfmsub231sd' {FMA|AVX512_F}. + kIdVfmsub231sh, //!< Instruction 'vfmsub231sh' {AVX512_FP16}. + kIdVfmsub231ss, //!< Instruction 'vfmsub231ss' {FMA|AVX512_F}. + kIdVfmsubadd132pd, //!< Instruction 'vfmsubadd132pd' {FMA|AVX512_F+VL}. + kIdVfmsubadd132ph, //!< Instruction 'vfmsubadd132ph' {AVX512_FP16+VL}. + kIdVfmsubadd132ps, //!< Instruction 'vfmsubadd132ps' {FMA|AVX512_F+VL}. + kIdVfmsubadd213pd, //!< Instruction 'vfmsubadd213pd' {FMA|AVX512_F+VL}. + kIdVfmsubadd213ph, //!< Instruction 'vfmsubadd213ph' {AVX512_FP16+VL}. + kIdVfmsubadd213ps, //!< Instruction 'vfmsubadd213ps' {FMA|AVX512_F+VL}. + kIdVfmsubadd231pd, //!< Instruction 'vfmsubadd231pd' {FMA|AVX512_F+VL}. + kIdVfmsubadd231ph, //!< Instruction 'vfmsubadd231ph' {AVX512_FP16+VL}. + kIdVfmsubadd231ps, //!< Instruction 'vfmsubadd231ps' {FMA|AVX512_F+VL}. + kIdVfmsubaddpd, //!< Instruction 'vfmsubaddpd' {FMA4}. + kIdVfmsubaddps, //!< Instruction 'vfmsubaddps' {FMA4}. + kIdVfmsubpd, //!< Instruction 'vfmsubpd' {FMA4}. + kIdVfmsubps, //!< Instruction 'vfmsubps' {FMA4}. + kIdVfmsubsd, //!< Instruction 'vfmsubsd' {FMA4}. + kIdVfmsubss, //!< Instruction 'vfmsubss' {FMA4}. + kIdVfmulcph, //!< Instruction 'vfmulcph' {AVX512_FP16+VL}. + kIdVfmulcsh, //!< Instruction 'vfmulcsh' {AVX512_FP16+VL}. + kIdVfnmadd132pd, //!< Instruction 'vfnmadd132pd' {FMA|AVX512_F+VL}. + kIdVfnmadd132ph, //!< Instruction 'vfnmadd132ph' {AVX512_FP16+VL}. + kIdVfnmadd132ps, //!< Instruction 'vfnmadd132ps' {FMA|AVX512_F+VL}. + kIdVfnmadd132sd, //!< Instruction 'vfnmadd132sd' {FMA|AVX512_F}. + kIdVfnmadd132sh, //!< Instruction 'vfnmadd132sh' {AVX512_FP16}. + kIdVfnmadd132ss, //!< Instruction 'vfnmadd132ss' {FMA|AVX512_F}. + kIdVfnmadd213pd, //!< Instruction 'vfnmadd213pd' {FMA|AVX512_F+VL}. + kIdVfnmadd213ph, //!< Instruction 'vfnmadd213ph' {AVX512_FP16+VL}. + kIdVfnmadd213ps, //!< Instruction 'vfnmadd213ps' {FMA|AVX512_F+VL}. + kIdVfnmadd213sd, //!< Instruction 'vfnmadd213sd' {FMA|AVX512_F}. + kIdVfnmadd213sh, //!< Instruction 'vfnmadd213sh' {AVX512_FP16}. + kIdVfnmadd213ss, //!< Instruction 'vfnmadd213ss' {FMA|AVX512_F}. + kIdVfnmadd231pd, //!< Instruction 'vfnmadd231pd' {FMA|AVX512_F+VL}. + kIdVfnmadd231ph, //!< Instruction 'vfnmadd231ph' {AVX512_FP16+VL}. + kIdVfnmadd231ps, //!< Instruction 'vfnmadd231ps' {FMA|AVX512_F+VL}. + kIdVfnmadd231sd, //!< Instruction 'vfnmadd231sd' {FMA|AVX512_F}. + kIdVfnmadd231sh, //!< Instruction 'vfnmadd231sh' {AVX512_FP16}. + kIdVfnmadd231ss, //!< Instruction 'vfnmadd231ss' {FMA|AVX512_F}. + kIdVfnmaddpd, //!< Instruction 'vfnmaddpd' {FMA4}. + kIdVfnmaddps, //!< Instruction 'vfnmaddps' {FMA4}. + kIdVfnmaddsd, //!< Instruction 'vfnmaddsd' {FMA4}. + kIdVfnmaddss, //!< Instruction 'vfnmaddss' {FMA4}. + kIdVfnmsub132pd, //!< Instruction 'vfnmsub132pd' {FMA|AVX512_F+VL}. + kIdVfnmsub132ph, //!< Instruction 'vfnmsub132ph' {AVX512_FP16+VL}. + kIdVfnmsub132ps, //!< Instruction 'vfnmsub132ps' {FMA|AVX512_F+VL}. + kIdVfnmsub132sd, //!< Instruction 'vfnmsub132sd' {FMA|AVX512_F}. + kIdVfnmsub132sh, //!< Instruction 'vfnmsub132sh' {AVX512_FP16}. + kIdVfnmsub132ss, //!< Instruction 'vfnmsub132ss' {FMA|AVX512_F}. + kIdVfnmsub213pd, //!< Instruction 'vfnmsub213pd' {FMA|AVX512_F+VL}. + kIdVfnmsub213ph, //!< Instruction 'vfnmsub213ph' {AVX512_FP16+VL}. + kIdVfnmsub213ps, //!< Instruction 'vfnmsub213ps' {FMA|AVX512_F+VL}. + kIdVfnmsub213sd, //!< Instruction 'vfnmsub213sd' {FMA|AVX512_F}. + kIdVfnmsub213sh, //!< Instruction 'vfnmsub213sh' {AVX512_FP16}. + kIdVfnmsub213ss, //!< Instruction 'vfnmsub213ss' {FMA|AVX512_F}. + kIdVfnmsub231pd, //!< Instruction 'vfnmsub231pd' {FMA|AVX512_F+VL}. + kIdVfnmsub231ph, //!< Instruction 'vfnmsub231ph' {AVX512_FP16+VL}. + kIdVfnmsub231ps, //!< Instruction 'vfnmsub231ps' {FMA|AVX512_F+VL}. + kIdVfnmsub231sd, //!< Instruction 'vfnmsub231sd' {FMA|AVX512_F}. + kIdVfnmsub231sh, //!< Instruction 'vfnmsub231sh' {AVX512_FP16}. + kIdVfnmsub231ss, //!< Instruction 'vfnmsub231ss' {FMA|AVX512_F}. + kIdVfnmsubpd, //!< Instruction 'vfnmsubpd' {FMA4}. + kIdVfnmsubps, //!< Instruction 'vfnmsubps' {FMA4}. + kIdVfnmsubsd, //!< Instruction 'vfnmsubsd' {FMA4}. + kIdVfnmsubss, //!< Instruction 'vfnmsubss' {FMA4}. + kIdVfpclasspd, //!< Instruction 'vfpclasspd' {AVX512_DQ+VL}. + kIdVfpclassph, //!< Instruction 'vfpclassph' {AVX512_FP16+VL}. + kIdVfpclassps, //!< Instruction 'vfpclassps' {AVX512_DQ+VL}. + kIdVfpclasssd, //!< Instruction 'vfpclasssd' {AVX512_DQ}. + kIdVfpclasssh, //!< Instruction 'vfpclasssh' {AVX512_FP16}. + kIdVfpclassss, //!< Instruction 'vfpclassss' {AVX512_DQ}. + kIdVfrczpd, //!< Instruction 'vfrczpd' {XOP}. + kIdVfrczps, //!< Instruction 'vfrczps' {XOP}. + kIdVfrczsd, //!< Instruction 'vfrczsd' {XOP}. + kIdVfrczss, //!< Instruction 'vfrczss' {XOP}. + kIdVgatherdpd, //!< Instruction 'vgatherdpd' {AVX2|AVX512_F+VL}. + kIdVgatherdps, //!< Instruction 'vgatherdps' {AVX2|AVX512_F+VL}. + kIdVgatherpf0dpd, //!< Instruction 'vgatherpf0dpd' {AVX512_PF}. + kIdVgatherpf0dps, //!< Instruction 'vgatherpf0dps' {AVX512_PF}. + kIdVgatherpf0qpd, //!< Instruction 'vgatherpf0qpd' {AVX512_PF}. + kIdVgatherpf0qps, //!< Instruction 'vgatherpf0qps' {AVX512_PF}. + kIdVgatherpf1dpd, //!< Instruction 'vgatherpf1dpd' {AVX512_PF}. + kIdVgatherpf1dps, //!< Instruction 'vgatherpf1dps' {AVX512_PF}. + kIdVgatherpf1qpd, //!< Instruction 'vgatherpf1qpd' {AVX512_PF}. + kIdVgatherpf1qps, //!< Instruction 'vgatherpf1qps' {AVX512_PF}. + kIdVgatherqpd, //!< Instruction 'vgatherqpd' {AVX2|AVX512_F+VL}. + kIdVgatherqps, //!< Instruction 'vgatherqps' {AVX2|AVX512_F+VL}. + kIdVgetexppd, //!< Instruction 'vgetexppd' {AVX512_F+VL}. + kIdVgetexpph, //!< Instruction 'vgetexpph' {AVX512_FP16+VL}. + kIdVgetexpps, //!< Instruction 'vgetexpps' {AVX512_F+VL}. + kIdVgetexpsd, //!< Instruction 'vgetexpsd' {AVX512_F}. + kIdVgetexpsh, //!< Instruction 'vgetexpsh' {AVX512_FP16}. + kIdVgetexpss, //!< Instruction 'vgetexpss' {AVX512_F}. + kIdVgetmantpd, //!< Instruction 'vgetmantpd' {AVX512_F+VL}. + kIdVgetmantph, //!< Instruction 'vgetmantph' {AVX512_FP16+VL}. + kIdVgetmantps, //!< Instruction 'vgetmantps' {AVX512_F+VL}. + kIdVgetmantsd, //!< Instruction 'vgetmantsd' {AVX512_F}. + kIdVgetmantsh, //!< Instruction 'vgetmantsh' {AVX512_FP16}. + kIdVgetmantss, //!< Instruction 'vgetmantss' {AVX512_F}. + kIdVgf2p8affineinvqb, //!< Instruction 'vgf2p8affineinvqb' {AVX|AVX512_F+VL & GFNI}. + kIdVgf2p8affineqb, //!< Instruction 'vgf2p8affineqb' {AVX|AVX512_F+VL & GFNI}. + kIdVgf2p8mulb, //!< Instruction 'vgf2p8mulb' {AVX|AVX512_F+VL & GFNI}. + kIdVhaddpd, //!< Instruction 'vhaddpd' {AVX}. + kIdVhaddps, //!< Instruction 'vhaddps' {AVX}. + kIdVhsubpd, //!< Instruction 'vhsubpd' {AVX}. + kIdVhsubps, //!< Instruction 'vhsubps' {AVX}. + kIdVinsertf128, //!< Instruction 'vinsertf128' {AVX}. + kIdVinsertf32x4, //!< Instruction 'vinsertf32x4' {AVX512_F+VL}. + kIdVinsertf32x8, //!< Instruction 'vinsertf32x8' {AVX512_DQ}. + kIdVinsertf64x2, //!< Instruction 'vinsertf64x2' {AVX512_DQ+VL}. + kIdVinsertf64x4, //!< Instruction 'vinsertf64x4' {AVX512_F}. + kIdVinserti128, //!< Instruction 'vinserti128' {AVX2}. + kIdVinserti32x4, //!< Instruction 'vinserti32x4' {AVX512_F+VL}. + kIdVinserti32x8, //!< Instruction 'vinserti32x8' {AVX512_DQ}. + kIdVinserti64x2, //!< Instruction 'vinserti64x2' {AVX512_DQ+VL}. + kIdVinserti64x4, //!< Instruction 'vinserti64x4' {AVX512_F}. + kIdVinsertps, //!< Instruction 'vinsertps' {AVX|AVX512_F}. + kIdVlddqu, //!< Instruction 'vlddqu' {AVX}. + kIdVldmxcsr, //!< Instruction 'vldmxcsr' {AVX}. + kIdVmaskmovdqu, //!< Instruction 'vmaskmovdqu' {AVX}. + kIdVmaskmovpd, //!< Instruction 'vmaskmovpd' {AVX}. + kIdVmaskmovps, //!< Instruction 'vmaskmovps' {AVX}. + kIdVmaxpd, //!< Instruction 'vmaxpd' {AVX|AVX512_F+VL}. + kIdVmaxph, //!< Instruction 'vmaxph' {AVX512_FP16+VL}. + kIdVmaxps, //!< Instruction 'vmaxps' {AVX|AVX512_F+VL}. + kIdVmaxsd, //!< Instruction 'vmaxsd' {AVX|AVX512_F}. + kIdVmaxsh, //!< Instruction 'vmaxsh' {AVX512_FP16}. + kIdVmaxss, //!< Instruction 'vmaxss' {AVX|AVX512_F}. + kIdVmcall, //!< Instruction 'vmcall' {VMX}. + kIdVmclear, //!< Instruction 'vmclear' {VMX}. + kIdVmfunc, //!< Instruction 'vmfunc' {VMX}. + kIdVmgexit, //!< Instruction 'vmgexit' {SEV_ES}. + kIdVminpd, //!< Instruction 'vminpd' {AVX|AVX512_F+VL}. + kIdVminph, //!< Instruction 'vminph' {AVX512_FP16+VL}. + kIdVminps, //!< Instruction 'vminps' {AVX|AVX512_F+VL}. + kIdVminsd, //!< Instruction 'vminsd' {AVX|AVX512_F}. + kIdVminsh, //!< Instruction 'vminsh' {AVX512_FP16}. + kIdVminss, //!< Instruction 'vminss' {AVX|AVX512_F}. + kIdVmlaunch, //!< Instruction 'vmlaunch' {VMX}. + kIdVmload, //!< Instruction 'vmload' {SVM}. + kIdVmmcall, //!< Instruction 'vmmcall' {SVM}. + kIdVmovapd, //!< Instruction 'vmovapd' {AVX|AVX512_F+VL}. + kIdVmovaps, //!< Instruction 'vmovaps' {AVX|AVX512_F+VL}. + kIdVmovd, //!< Instruction 'vmovd' {AVX|AVX512_F}. + kIdVmovddup, //!< Instruction 'vmovddup' {AVX|AVX512_F+VL}. + kIdVmovdqa, //!< Instruction 'vmovdqa' {AVX}. + kIdVmovdqa32, //!< Instruction 'vmovdqa32' {AVX512_F+VL}. + kIdVmovdqa64, //!< Instruction 'vmovdqa64' {AVX512_F+VL}. + kIdVmovdqu, //!< Instruction 'vmovdqu' {AVX}. + kIdVmovdqu16, //!< Instruction 'vmovdqu16' {AVX512_BW+VL}. + kIdVmovdqu32, //!< Instruction 'vmovdqu32' {AVX512_F+VL}. + kIdVmovdqu64, //!< Instruction 'vmovdqu64' {AVX512_F+VL}. + kIdVmovdqu8, //!< Instruction 'vmovdqu8' {AVX512_BW+VL}. + kIdVmovhlps, //!< Instruction 'vmovhlps' {AVX|AVX512_F}. + kIdVmovhpd, //!< Instruction 'vmovhpd' {AVX|AVX512_F}. + kIdVmovhps, //!< Instruction 'vmovhps' {AVX|AVX512_F}. + kIdVmovlhps, //!< Instruction 'vmovlhps' {AVX|AVX512_F}. + kIdVmovlpd, //!< Instruction 'vmovlpd' {AVX|AVX512_F}. + kIdVmovlps, //!< Instruction 'vmovlps' {AVX|AVX512_F}. + kIdVmovmskpd, //!< Instruction 'vmovmskpd' {AVX}. + kIdVmovmskps, //!< Instruction 'vmovmskps' {AVX}. + kIdVmovntdq, //!< Instruction 'vmovntdq' {AVX|AVX512_F+VL}. + kIdVmovntdqa, //!< Instruction 'vmovntdqa' {AVX|AVX2|AVX512_F+VL}. + kIdVmovntpd, //!< Instruction 'vmovntpd' {AVX|AVX512_F+VL}. + kIdVmovntps, //!< Instruction 'vmovntps' {AVX|AVX512_F+VL}. + kIdVmovq, //!< Instruction 'vmovq' {AVX|AVX512_F}. + kIdVmovsd, //!< Instruction 'vmovsd' {AVX|AVX512_F}. + kIdVmovsh, //!< Instruction 'vmovsh' {AVX512_FP16}. + kIdVmovshdup, //!< Instruction 'vmovshdup' {AVX|AVX512_F+VL}. + kIdVmovsldup, //!< Instruction 'vmovsldup' {AVX|AVX512_F+VL}. + kIdVmovss, //!< Instruction 'vmovss' {AVX|AVX512_F}. + kIdVmovupd, //!< Instruction 'vmovupd' {AVX|AVX512_F+VL}. + kIdVmovups, //!< Instruction 'vmovups' {AVX|AVX512_F+VL}. + kIdVmovw, //!< Instruction 'vmovw' {AVX512_FP16}. + kIdVmpsadbw, //!< Instruction 'vmpsadbw' {AVX|AVX2}. + kIdVmptrld, //!< Instruction 'vmptrld' {VMX}. + kIdVmptrst, //!< Instruction 'vmptrst' {VMX}. + kIdVmread, //!< Instruction 'vmread' {VMX}. + kIdVmresume, //!< Instruction 'vmresume' {VMX}. + kIdVmrun, //!< Instruction 'vmrun' {SVM}. + kIdVmsave, //!< Instruction 'vmsave' {SVM}. + kIdVmulpd, //!< Instruction 'vmulpd' {AVX|AVX512_F+VL}. + kIdVmulph, //!< Instruction 'vmulph' {AVX512_FP16+VL}. + kIdVmulps, //!< Instruction 'vmulps' {AVX|AVX512_F+VL}. + kIdVmulsd, //!< Instruction 'vmulsd' {AVX|AVX512_F}. + kIdVmulsh, //!< Instruction 'vmulsh' {AVX512_FP16}. + kIdVmulss, //!< Instruction 'vmulss' {AVX|AVX512_F}. + kIdVmwrite, //!< Instruction 'vmwrite' {VMX}. + kIdVmxoff, //!< Instruction 'vmxoff' {VMX}. + kIdVmxon, //!< Instruction 'vmxon' {VMX}. + kIdVorpd, //!< Instruction 'vorpd' {AVX|AVX512_DQ+VL}. + kIdVorps, //!< Instruction 'vorps' {AVX|AVX512_DQ+VL}. + kIdVp2intersectd, //!< Instruction 'vp2intersectd' {AVX512_VP2INTERSECT+VL}. + kIdVp2intersectq, //!< Instruction 'vp2intersectq' {AVX512_VP2INTERSECT+VL}. + kIdVp4dpwssd, //!< Instruction 'vp4dpwssd' {AVX512_4VNNIW}. + kIdVp4dpwssds, //!< Instruction 'vp4dpwssds' {AVX512_4VNNIW}. + kIdVpabsb, //!< Instruction 'vpabsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpabsd, //!< Instruction 'vpabsd' {AVX|AVX2|AVX512_F+VL}. + kIdVpabsq, //!< Instruction 'vpabsq' {AVX512_F+VL}. + kIdVpabsw, //!< Instruction 'vpabsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpackssdw, //!< Instruction 'vpackssdw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpacksswb, //!< Instruction 'vpacksswb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpackusdw, //!< Instruction 'vpackusdw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpackuswb, //!< Instruction 'vpackuswb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddb, //!< Instruction 'vpaddb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddd, //!< Instruction 'vpaddd' {AVX|AVX2|AVX512_F+VL}. + kIdVpaddq, //!< Instruction 'vpaddq' {AVX|AVX2|AVX512_F+VL}. + kIdVpaddsb, //!< Instruction 'vpaddsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddsw, //!< Instruction 'vpaddsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddusb, //!< Instruction 'vpaddusb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddusw, //!< Instruction 'vpaddusw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpaddw, //!< Instruction 'vpaddw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpalignr, //!< Instruction 'vpalignr' {AVX|AVX2|AVX512_BW+VL}. + kIdVpand, //!< Instruction 'vpand' {AVX|AVX2}. + kIdVpandd, //!< Instruction 'vpandd' {AVX512_F+VL}. + kIdVpandn, //!< Instruction 'vpandn' {AVX|AVX2}. + kIdVpandnd, //!< Instruction 'vpandnd' {AVX512_F+VL}. + kIdVpandnq, //!< Instruction 'vpandnq' {AVX512_F+VL}. + kIdVpandq, //!< Instruction 'vpandq' {AVX512_F+VL}. + kIdVpavgb, //!< Instruction 'vpavgb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpavgw, //!< Instruction 'vpavgw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpblendd, //!< Instruction 'vpblendd' {AVX2}. + kIdVpblendmb, //!< Instruction 'vpblendmb' {AVX512_BW+VL}. + kIdVpblendmd, //!< Instruction 'vpblendmd' {AVX512_F+VL}. + kIdVpblendmq, //!< Instruction 'vpblendmq' {AVX512_F+VL}. + kIdVpblendmw, //!< Instruction 'vpblendmw' {AVX512_BW+VL}. + kIdVpblendvb, //!< Instruction 'vpblendvb' {AVX|AVX2}. + kIdVpblendw, //!< Instruction 'vpblendw' {AVX|AVX2}. + kIdVpbroadcastb, //!< Instruction 'vpbroadcastb' {AVX2|AVX512_BW+VL}. + kIdVpbroadcastd, //!< Instruction 'vpbroadcastd' {AVX2|AVX512_F+VL}. + kIdVpbroadcastmb2q, //!< Instruction 'vpbroadcastmb2q' {AVX512_CD+VL}. + kIdVpbroadcastmw2d, //!< Instruction 'vpbroadcastmw2d' {AVX512_CD+VL}. + kIdVpbroadcastq, //!< Instruction 'vpbroadcastq' {AVX2|AVX512_F+VL}. + kIdVpbroadcastw, //!< Instruction 'vpbroadcastw' {AVX2|AVX512_BW+VL}. + kIdVpclmulqdq, //!< Instruction 'vpclmulqdq' {AVX|AVX512_F+VL & PCLMULQDQ|VPCLMULQDQ}. + kIdVpcmov, //!< Instruction 'vpcmov' {XOP}. + kIdVpcmpb, //!< Instruction 'vpcmpb' {AVX512_BW+VL}. + kIdVpcmpd, //!< Instruction 'vpcmpd' {AVX512_F+VL}. + kIdVpcmpeqb, //!< Instruction 'vpcmpeqb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpeqd, //!< Instruction 'vpcmpeqd' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpeqq, //!< Instruction 'vpcmpeqq' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpeqw, //!< Instruction 'vpcmpeqw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpestri, //!< Instruction 'vpcmpestri' {AVX}. + kIdVpcmpestrm, //!< Instruction 'vpcmpestrm' {AVX}. + kIdVpcmpgtb, //!< Instruction 'vpcmpgtb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpgtd, //!< Instruction 'vpcmpgtd' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpgtq, //!< Instruction 'vpcmpgtq' {AVX|AVX2|AVX512_F+VL}. + kIdVpcmpgtw, //!< Instruction 'vpcmpgtw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpcmpistri, //!< Instruction 'vpcmpistri' {AVX}. + kIdVpcmpistrm, //!< Instruction 'vpcmpistrm' {AVX}. + kIdVpcmpq, //!< Instruction 'vpcmpq' {AVX512_F+VL}. + kIdVpcmpub, //!< Instruction 'vpcmpub' {AVX512_BW+VL}. + kIdVpcmpud, //!< Instruction 'vpcmpud' {AVX512_F+VL}. + kIdVpcmpuq, //!< Instruction 'vpcmpuq' {AVX512_F+VL}. + kIdVpcmpuw, //!< Instruction 'vpcmpuw' {AVX512_BW+VL}. + kIdVpcmpw, //!< Instruction 'vpcmpw' {AVX512_BW+VL}. + kIdVpcomb, //!< Instruction 'vpcomb' {XOP}. + kIdVpcomd, //!< Instruction 'vpcomd' {XOP}. + kIdVpcompressb, //!< Instruction 'vpcompressb' {AVX512_VBMI2+VL}. + kIdVpcompressd, //!< Instruction 'vpcompressd' {AVX512_F+VL}. + kIdVpcompressq, //!< Instruction 'vpcompressq' {AVX512_F+VL}. + kIdVpcompressw, //!< Instruction 'vpcompressw' {AVX512_VBMI2+VL}. + kIdVpcomq, //!< Instruction 'vpcomq' {XOP}. + kIdVpcomub, //!< Instruction 'vpcomub' {XOP}. + kIdVpcomud, //!< Instruction 'vpcomud' {XOP}. + kIdVpcomuq, //!< Instruction 'vpcomuq' {XOP}. + kIdVpcomuw, //!< Instruction 'vpcomuw' {XOP}. + kIdVpcomw, //!< Instruction 'vpcomw' {XOP}. + kIdVpconflictd, //!< Instruction 'vpconflictd' {AVX512_CD+VL}. + kIdVpconflictq, //!< Instruction 'vpconflictq' {AVX512_CD+VL}. + kIdVpdpbssd, //!< Instruction 'vpdpbssd' {AVX_VNNI_INT8}. + kIdVpdpbssds, //!< Instruction 'vpdpbssds' {AVX_VNNI_INT8}. + kIdVpdpbsud, //!< Instruction 'vpdpbsud' {AVX_VNNI_INT8}. + kIdVpdpbsuds, //!< Instruction 'vpdpbsuds' {AVX_VNNI_INT8}. + kIdVpdpbusd, //!< Instruction 'vpdpbusd' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpbusds, //!< Instruction 'vpdpbusds' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpbuud, //!< Instruction 'vpdpbuud' {AVX_VNNI_INT8}. + kIdVpdpbuuds, //!< Instruction 'vpdpbuuds' {AVX_VNNI_INT8}. + kIdVpdpwssd, //!< Instruction 'vpdpwssd' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpwssds, //!< Instruction 'vpdpwssds' {AVX_VNNI|AVX512_VNNI+VL}. + kIdVpdpwsud, //!< Instruction 'vpdpwsud' {AVX_VNNI_INT16}. + kIdVpdpwsuds, //!< Instruction 'vpdpwsuds' {AVX_VNNI_INT16}. + kIdVpdpwusd, //!< Instruction 'vpdpwusd' {AVX_VNNI_INT16}. + kIdVpdpwusds, //!< Instruction 'vpdpwusds' {AVX_VNNI_INT16}. + kIdVpdpwuud, //!< Instruction 'vpdpwuud' {AVX_VNNI_INT16}. + kIdVpdpwuuds, //!< Instruction 'vpdpwuuds' {AVX_VNNI_INT16}. + kIdVperm2f128, //!< Instruction 'vperm2f128' {AVX}. + kIdVperm2i128, //!< Instruction 'vperm2i128' {AVX2}. + kIdVpermb, //!< Instruction 'vpermb' {AVX512_VBMI+VL}. + kIdVpermd, //!< Instruction 'vpermd' {AVX2|AVX512_F+VL}. + kIdVpermi2b, //!< Instruction 'vpermi2b' {AVX512_VBMI+VL}. + kIdVpermi2d, //!< Instruction 'vpermi2d' {AVX512_F+VL}. + kIdVpermi2pd, //!< Instruction 'vpermi2pd' {AVX512_F+VL}. + kIdVpermi2ps, //!< Instruction 'vpermi2ps' {AVX512_F+VL}. + kIdVpermi2q, //!< Instruction 'vpermi2q' {AVX512_F+VL}. + kIdVpermi2w, //!< Instruction 'vpermi2w' {AVX512_BW+VL}. + kIdVpermil2pd, //!< Instruction 'vpermil2pd' {XOP}. + kIdVpermil2ps, //!< Instruction 'vpermil2ps' {XOP}. + kIdVpermilpd, //!< Instruction 'vpermilpd' {AVX|AVX512_F+VL}. + kIdVpermilps, //!< Instruction 'vpermilps' {AVX|AVX512_F+VL}. + kIdVpermpd, //!< Instruction 'vpermpd' {AVX2|AVX512_F+VL}. + kIdVpermps, //!< Instruction 'vpermps' {AVX2|AVX512_F+VL}. + kIdVpermq, //!< Instruction 'vpermq' {AVX2|AVX512_F+VL}. + kIdVpermt2b, //!< Instruction 'vpermt2b' {AVX512_VBMI+VL}. + kIdVpermt2d, //!< Instruction 'vpermt2d' {AVX512_F+VL}. + kIdVpermt2pd, //!< Instruction 'vpermt2pd' {AVX512_F+VL}. + kIdVpermt2ps, //!< Instruction 'vpermt2ps' {AVX512_F+VL}. + kIdVpermt2q, //!< Instruction 'vpermt2q' {AVX512_F+VL}. + kIdVpermt2w, //!< Instruction 'vpermt2w' {AVX512_BW+VL}. + kIdVpermw, //!< Instruction 'vpermw' {AVX512_BW+VL}. + kIdVpexpandb, //!< Instruction 'vpexpandb' {AVX512_VBMI2+VL}. + kIdVpexpandd, //!< Instruction 'vpexpandd' {AVX512_F+VL}. + kIdVpexpandq, //!< Instruction 'vpexpandq' {AVX512_F+VL}. + kIdVpexpandw, //!< Instruction 'vpexpandw' {AVX512_VBMI2+VL}. + kIdVpextrb, //!< Instruction 'vpextrb' {AVX|AVX512_BW}. + kIdVpextrd, //!< Instruction 'vpextrd' {AVX|AVX512_DQ}. + kIdVpextrq, //!< Instruction 'vpextrq' {AVX|AVX512_DQ} (X64). + kIdVpextrw, //!< Instruction 'vpextrw' {AVX|AVX512_BW}. + kIdVpgatherdd, //!< Instruction 'vpgatherdd' {AVX2|AVX512_F+VL}. + kIdVpgatherdq, //!< Instruction 'vpgatherdq' {AVX2|AVX512_F+VL}. + kIdVpgatherqd, //!< Instruction 'vpgatherqd' {AVX2|AVX512_F+VL}. + kIdVpgatherqq, //!< Instruction 'vpgatherqq' {AVX2|AVX512_F+VL}. + kIdVphaddbd, //!< Instruction 'vphaddbd' {XOP}. + kIdVphaddbq, //!< Instruction 'vphaddbq' {XOP}. + kIdVphaddbw, //!< Instruction 'vphaddbw' {XOP}. + kIdVphaddd, //!< Instruction 'vphaddd' {AVX|AVX2}. + kIdVphadddq, //!< Instruction 'vphadddq' {XOP}. + kIdVphaddsw, //!< Instruction 'vphaddsw' {AVX|AVX2}. + kIdVphaddubd, //!< Instruction 'vphaddubd' {XOP}. + kIdVphaddubq, //!< Instruction 'vphaddubq' {XOP}. + kIdVphaddubw, //!< Instruction 'vphaddubw' {XOP}. + kIdVphaddudq, //!< Instruction 'vphaddudq' {XOP}. + kIdVphadduwd, //!< Instruction 'vphadduwd' {XOP}. + kIdVphadduwq, //!< Instruction 'vphadduwq' {XOP}. + kIdVphaddw, //!< Instruction 'vphaddw' {AVX|AVX2}. + kIdVphaddwd, //!< Instruction 'vphaddwd' {XOP}. + kIdVphaddwq, //!< Instruction 'vphaddwq' {XOP}. + kIdVphminposuw, //!< Instruction 'vphminposuw' {AVX}. + kIdVphsubbw, //!< Instruction 'vphsubbw' {XOP}. + kIdVphsubd, //!< Instruction 'vphsubd' {AVX|AVX2}. + kIdVphsubdq, //!< Instruction 'vphsubdq' {XOP}. + kIdVphsubsw, //!< Instruction 'vphsubsw' {AVX|AVX2}. + kIdVphsubw, //!< Instruction 'vphsubw' {AVX|AVX2}. + kIdVphsubwd, //!< Instruction 'vphsubwd' {XOP}. + kIdVpinsrb, //!< Instruction 'vpinsrb' {AVX|AVX512_BW}. + kIdVpinsrd, //!< Instruction 'vpinsrd' {AVX|AVX512_DQ}. + kIdVpinsrq, //!< Instruction 'vpinsrq' {AVX|AVX512_DQ} (X64). + kIdVpinsrw, //!< Instruction 'vpinsrw' {AVX|AVX512_BW}. + kIdVplzcntd, //!< Instruction 'vplzcntd' {AVX512_CD+VL}. + kIdVplzcntq, //!< Instruction 'vplzcntq' {AVX512_CD+VL}. + kIdVpmacsdd, //!< Instruction 'vpmacsdd' {XOP}. + kIdVpmacsdqh, //!< Instruction 'vpmacsdqh' {XOP}. + kIdVpmacsdql, //!< Instruction 'vpmacsdql' {XOP}. + kIdVpmacssdd, //!< Instruction 'vpmacssdd' {XOP}. + kIdVpmacssdqh, //!< Instruction 'vpmacssdqh' {XOP}. + kIdVpmacssdql, //!< Instruction 'vpmacssdql' {XOP}. + kIdVpmacsswd, //!< Instruction 'vpmacsswd' {XOP}. + kIdVpmacssww, //!< Instruction 'vpmacssww' {XOP}. + kIdVpmacswd, //!< Instruction 'vpmacswd' {XOP}. + kIdVpmacsww, //!< Instruction 'vpmacsww' {XOP}. + kIdVpmadcsswd, //!< Instruction 'vpmadcsswd' {XOP}. + kIdVpmadcswd, //!< Instruction 'vpmadcswd' {XOP}. + kIdVpmadd52huq, //!< Instruction 'vpmadd52huq' {AVX_IFMA|AVX512_IFMA+VL}. + kIdVpmadd52luq, //!< Instruction 'vpmadd52luq' {AVX_IFMA|AVX512_IFMA+VL}. + kIdVpmaddubsw, //!< Instruction 'vpmaddubsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaddwd, //!< Instruction 'vpmaddwd' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaskmovd, //!< Instruction 'vpmaskmovd' {AVX2}. + kIdVpmaskmovq, //!< Instruction 'vpmaskmovq' {AVX2}. + kIdVpmaxsb, //!< Instruction 'vpmaxsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaxsd, //!< Instruction 'vpmaxsd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmaxsq, //!< Instruction 'vpmaxsq' {AVX512_F+VL}. + kIdVpmaxsw, //!< Instruction 'vpmaxsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaxub, //!< Instruction 'vpmaxub' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmaxud, //!< Instruction 'vpmaxud' {AVX|AVX2|AVX512_F+VL}. + kIdVpmaxuq, //!< Instruction 'vpmaxuq' {AVX512_F+VL}. + kIdVpmaxuw, //!< Instruction 'vpmaxuw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminsb, //!< Instruction 'vpminsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminsd, //!< Instruction 'vpminsd' {AVX|AVX2|AVX512_F+VL}. + kIdVpminsq, //!< Instruction 'vpminsq' {AVX512_F+VL}. + kIdVpminsw, //!< Instruction 'vpminsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminub, //!< Instruction 'vpminub' {AVX|AVX2|AVX512_BW+VL}. + kIdVpminud, //!< Instruction 'vpminud' {AVX|AVX2|AVX512_F+VL}. + kIdVpminuq, //!< Instruction 'vpminuq' {AVX512_F+VL}. + kIdVpminuw, //!< Instruction 'vpminuw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmovb2m, //!< Instruction 'vpmovb2m' {AVX512_BW+VL}. + kIdVpmovd2m, //!< Instruction 'vpmovd2m' {AVX512_DQ+VL}. + kIdVpmovdb, //!< Instruction 'vpmovdb' {AVX512_F+VL}. + kIdVpmovdw, //!< Instruction 'vpmovdw' {AVX512_F+VL}. + kIdVpmovm2b, //!< Instruction 'vpmovm2b' {AVX512_BW+VL}. + kIdVpmovm2d, //!< Instruction 'vpmovm2d' {AVX512_DQ+VL}. + kIdVpmovm2q, //!< Instruction 'vpmovm2q' {AVX512_DQ+VL}. + kIdVpmovm2w, //!< Instruction 'vpmovm2w' {AVX512_BW+VL}. + kIdVpmovmskb, //!< Instruction 'vpmovmskb' {AVX|AVX2}. + kIdVpmovq2m, //!< Instruction 'vpmovq2m' {AVX512_DQ+VL}. + kIdVpmovqb, //!< Instruction 'vpmovqb' {AVX512_F+VL}. + kIdVpmovqd, //!< Instruction 'vpmovqd' {AVX512_F+VL}. + kIdVpmovqw, //!< Instruction 'vpmovqw' {AVX512_F+VL}. + kIdVpmovsdb, //!< Instruction 'vpmovsdb' {AVX512_F+VL}. + kIdVpmovsdw, //!< Instruction 'vpmovsdw' {AVX512_F+VL}. + kIdVpmovsqb, //!< Instruction 'vpmovsqb' {AVX512_F+VL}. + kIdVpmovsqd, //!< Instruction 'vpmovsqd' {AVX512_F+VL}. + kIdVpmovsqw, //!< Instruction 'vpmovsqw' {AVX512_F+VL}. + kIdVpmovswb, //!< Instruction 'vpmovswb' {AVX512_BW+VL}. + kIdVpmovsxbd, //!< Instruction 'vpmovsxbd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxbq, //!< Instruction 'vpmovsxbq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxbw, //!< Instruction 'vpmovsxbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmovsxdq, //!< Instruction 'vpmovsxdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxwd, //!< Instruction 'vpmovsxwd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovsxwq, //!< Instruction 'vpmovsxwq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovusdb, //!< Instruction 'vpmovusdb' {AVX512_F+VL}. + kIdVpmovusdw, //!< Instruction 'vpmovusdw' {AVX512_F+VL}. + kIdVpmovusqb, //!< Instruction 'vpmovusqb' {AVX512_F+VL}. + kIdVpmovusqd, //!< Instruction 'vpmovusqd' {AVX512_F+VL}. + kIdVpmovusqw, //!< Instruction 'vpmovusqw' {AVX512_F+VL}. + kIdVpmovuswb, //!< Instruction 'vpmovuswb' {AVX512_BW+VL}. + kIdVpmovw2m, //!< Instruction 'vpmovw2m' {AVX512_BW+VL}. + kIdVpmovwb, //!< Instruction 'vpmovwb' {AVX512_BW+VL}. + kIdVpmovzxbd, //!< Instruction 'vpmovzxbd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxbq, //!< Instruction 'vpmovzxbq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxbw, //!< Instruction 'vpmovzxbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmovzxdq, //!< Instruction 'vpmovzxdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxwd, //!< Instruction 'vpmovzxwd' {AVX|AVX2|AVX512_F+VL}. + kIdVpmovzxwq, //!< Instruction 'vpmovzxwq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmuldq, //!< Instruction 'vpmuldq' {AVX|AVX2|AVX512_F+VL}. + kIdVpmulhrsw, //!< Instruction 'vpmulhrsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmulhuw, //!< Instruction 'vpmulhuw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmulhw, //!< Instruction 'vpmulhw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmulld, //!< Instruction 'vpmulld' {AVX|AVX2|AVX512_F+VL}. + kIdVpmullq, //!< Instruction 'vpmullq' {AVX512_DQ+VL}. + kIdVpmullw, //!< Instruction 'vpmullw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpmultishiftqb, //!< Instruction 'vpmultishiftqb' {AVX512_VBMI+VL}. + kIdVpmuludq, //!< Instruction 'vpmuludq' {AVX|AVX2|AVX512_F+VL}. + kIdVpopcntb, //!< Instruction 'vpopcntb' {AVX512_BITALG+VL}. + kIdVpopcntd, //!< Instruction 'vpopcntd' {AVX512_VPOPCNTDQ+VL}. + kIdVpopcntq, //!< Instruction 'vpopcntq' {AVX512_VPOPCNTDQ+VL}. + kIdVpopcntw, //!< Instruction 'vpopcntw' {AVX512_BITALG+VL}. + kIdVpor, //!< Instruction 'vpor' {AVX|AVX2}. + kIdVpord, //!< Instruction 'vpord' {AVX512_F+VL}. + kIdVporq, //!< Instruction 'vporq' {AVX512_F+VL}. + kIdVpperm, //!< Instruction 'vpperm' {XOP}. + kIdVprold, //!< Instruction 'vprold' {AVX512_F+VL}. + kIdVprolq, //!< Instruction 'vprolq' {AVX512_F+VL}. + kIdVprolvd, //!< Instruction 'vprolvd' {AVX512_F+VL}. + kIdVprolvq, //!< Instruction 'vprolvq' {AVX512_F+VL}. + kIdVprord, //!< Instruction 'vprord' {AVX512_F+VL}. + kIdVprorq, //!< Instruction 'vprorq' {AVX512_F+VL}. + kIdVprorvd, //!< Instruction 'vprorvd' {AVX512_F+VL}. + kIdVprorvq, //!< Instruction 'vprorvq' {AVX512_F+VL}. + kIdVprotb, //!< Instruction 'vprotb' {XOP}. + kIdVprotd, //!< Instruction 'vprotd' {XOP}. + kIdVprotq, //!< Instruction 'vprotq' {XOP}. + kIdVprotw, //!< Instruction 'vprotw' {XOP}. + kIdVpsadbw, //!< Instruction 'vpsadbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpscatterdd, //!< Instruction 'vpscatterdd' {AVX512_F+VL}. + kIdVpscatterdq, //!< Instruction 'vpscatterdq' {AVX512_F+VL}. + kIdVpscatterqd, //!< Instruction 'vpscatterqd' {AVX512_F+VL}. + kIdVpscatterqq, //!< Instruction 'vpscatterqq' {AVX512_F+VL}. + kIdVpshab, //!< Instruction 'vpshab' {XOP}. + kIdVpshad, //!< Instruction 'vpshad' {XOP}. + kIdVpshaq, //!< Instruction 'vpshaq' {XOP}. + kIdVpshaw, //!< Instruction 'vpshaw' {XOP}. + kIdVpshlb, //!< Instruction 'vpshlb' {XOP}. + kIdVpshld, //!< Instruction 'vpshld' {XOP}. + kIdVpshldd, //!< Instruction 'vpshldd' {AVX512_VBMI2+VL}. + kIdVpshldq, //!< Instruction 'vpshldq' {AVX512_VBMI2+VL}. + kIdVpshldvd, //!< Instruction 'vpshldvd' {AVX512_VBMI2+VL}. + kIdVpshldvq, //!< Instruction 'vpshldvq' {AVX512_VBMI2+VL}. + kIdVpshldvw, //!< Instruction 'vpshldvw' {AVX512_VBMI2+VL}. + kIdVpshldw, //!< Instruction 'vpshldw' {AVX512_VBMI2+VL}. + kIdVpshlq, //!< Instruction 'vpshlq' {XOP}. + kIdVpshlw, //!< Instruction 'vpshlw' {XOP}. + kIdVpshrdd, //!< Instruction 'vpshrdd' {AVX512_VBMI2+VL}. + kIdVpshrdq, //!< Instruction 'vpshrdq' {AVX512_VBMI2+VL}. + kIdVpshrdvd, //!< Instruction 'vpshrdvd' {AVX512_VBMI2+VL}. + kIdVpshrdvq, //!< Instruction 'vpshrdvq' {AVX512_VBMI2+VL}. + kIdVpshrdvw, //!< Instruction 'vpshrdvw' {AVX512_VBMI2+VL}. + kIdVpshrdw, //!< Instruction 'vpshrdw' {AVX512_VBMI2+VL}. + kIdVpshufb, //!< Instruction 'vpshufb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpshufbitqmb, //!< Instruction 'vpshufbitqmb' {AVX512_BITALG+VL}. + kIdVpshufd, //!< Instruction 'vpshufd' {AVX|AVX2|AVX512_F+VL}. + kIdVpshufhw, //!< Instruction 'vpshufhw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpshuflw, //!< Instruction 'vpshuflw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsignb, //!< Instruction 'vpsignb' {AVX|AVX2}. + kIdVpsignd, //!< Instruction 'vpsignd' {AVX|AVX2}. + kIdVpsignw, //!< Instruction 'vpsignw' {AVX|AVX2}. + kIdVpslld, //!< Instruction 'vpslld' {AVX|AVX2|AVX512_F+VL}. + kIdVpslldq, //!< Instruction 'vpslldq' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsllq, //!< Instruction 'vpsllq' {AVX|AVX2|AVX512_F+VL}. + kIdVpsllvd, //!< Instruction 'vpsllvd' {AVX2|AVX512_F+VL}. + kIdVpsllvq, //!< Instruction 'vpsllvq' {AVX2|AVX512_F+VL}. + kIdVpsllvw, //!< Instruction 'vpsllvw' {AVX512_BW+VL}. + kIdVpsllw, //!< Instruction 'vpsllw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsrad, //!< Instruction 'vpsrad' {AVX|AVX2|AVX512_F+VL}. + kIdVpsraq, //!< Instruction 'vpsraq' {AVX512_F+VL}. + kIdVpsravd, //!< Instruction 'vpsravd' {AVX2|AVX512_F+VL}. + kIdVpsravq, //!< Instruction 'vpsravq' {AVX512_F+VL}. + kIdVpsravw, //!< Instruction 'vpsravw' {AVX512_BW+VL}. + kIdVpsraw, //!< Instruction 'vpsraw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsrld, //!< Instruction 'vpsrld' {AVX|AVX2|AVX512_F+VL}. + kIdVpsrldq, //!< Instruction 'vpsrldq' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsrlq, //!< Instruction 'vpsrlq' {AVX|AVX2|AVX512_F+VL}. + kIdVpsrlvd, //!< Instruction 'vpsrlvd' {AVX2|AVX512_F+VL}. + kIdVpsrlvq, //!< Instruction 'vpsrlvq' {AVX2|AVX512_F+VL}. + kIdVpsrlvw, //!< Instruction 'vpsrlvw' {AVX512_BW+VL}. + kIdVpsrlw, //!< Instruction 'vpsrlw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubb, //!< Instruction 'vpsubb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubd, //!< Instruction 'vpsubd' {AVX|AVX2|AVX512_F+VL}. + kIdVpsubq, //!< Instruction 'vpsubq' {AVX|AVX2|AVX512_F+VL}. + kIdVpsubsb, //!< Instruction 'vpsubsb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubsw, //!< Instruction 'vpsubsw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubusb, //!< Instruction 'vpsubusb' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubusw, //!< Instruction 'vpsubusw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpsubw, //!< Instruction 'vpsubw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpternlogd, //!< Instruction 'vpternlogd' {AVX512_F+VL}. + kIdVpternlogq, //!< Instruction 'vpternlogq' {AVX512_F+VL}. + kIdVptest, //!< Instruction 'vptest' {AVX}. + kIdVptestmb, //!< Instruction 'vptestmb' {AVX512_BW+VL}. + kIdVptestmd, //!< Instruction 'vptestmd' {AVX512_F+VL}. + kIdVptestmq, //!< Instruction 'vptestmq' {AVX512_F+VL}. + kIdVptestmw, //!< Instruction 'vptestmw' {AVX512_BW+VL}. + kIdVptestnmb, //!< Instruction 'vptestnmb' {AVX512_BW+VL}. + kIdVptestnmd, //!< Instruction 'vptestnmd' {AVX512_F+VL}. + kIdVptestnmq, //!< Instruction 'vptestnmq' {AVX512_F+VL}. + kIdVptestnmw, //!< Instruction 'vptestnmw' {AVX512_BW+VL}. + kIdVpunpckhbw, //!< Instruction 'vpunpckhbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpunpckhdq, //!< Instruction 'vpunpckhdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpckhqdq, //!< Instruction 'vpunpckhqdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpckhwd, //!< Instruction 'vpunpckhwd' {AVX|AVX2|AVX512_BW+VL}. + kIdVpunpcklbw, //!< Instruction 'vpunpcklbw' {AVX|AVX2|AVX512_BW+VL}. + kIdVpunpckldq, //!< Instruction 'vpunpckldq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpcklqdq, //!< Instruction 'vpunpcklqdq' {AVX|AVX2|AVX512_F+VL}. + kIdVpunpcklwd, //!< Instruction 'vpunpcklwd' {AVX|AVX2|AVX512_BW+VL}. + kIdVpxor, //!< Instruction 'vpxor' {AVX|AVX2}. + kIdVpxord, //!< Instruction 'vpxord' {AVX512_F+VL}. + kIdVpxorq, //!< Instruction 'vpxorq' {AVX512_F+VL}. + kIdVrangepd, //!< Instruction 'vrangepd' {AVX512_DQ+VL}. + kIdVrangeps, //!< Instruction 'vrangeps' {AVX512_DQ+VL}. + kIdVrangesd, //!< Instruction 'vrangesd' {AVX512_DQ}. + kIdVrangess, //!< Instruction 'vrangess' {AVX512_DQ}. + kIdVrcp14pd, //!< Instruction 'vrcp14pd' {AVX512_F+VL}. + kIdVrcp14ps, //!< Instruction 'vrcp14ps' {AVX512_F+VL}. + kIdVrcp14sd, //!< Instruction 'vrcp14sd' {AVX512_F}. + kIdVrcp14ss, //!< Instruction 'vrcp14ss' {AVX512_F}. + kIdVrcp28pd, //!< Instruction 'vrcp28pd' {AVX512_ER}. + kIdVrcp28ps, //!< Instruction 'vrcp28ps' {AVX512_ER}. + kIdVrcp28sd, //!< Instruction 'vrcp28sd' {AVX512_ER}. + kIdVrcp28ss, //!< Instruction 'vrcp28ss' {AVX512_ER}. + kIdVrcpph, //!< Instruction 'vrcpph' {AVX512_FP16}. + kIdVrcpps, //!< Instruction 'vrcpps' {AVX}. + kIdVrcpsh, //!< Instruction 'vrcpsh' {AVX512_FP16}. + kIdVrcpss, //!< Instruction 'vrcpss' {AVX}. + kIdVreducepd, //!< Instruction 'vreducepd' {AVX512_DQ+VL}. + kIdVreduceph, //!< Instruction 'vreduceph' {AVX512_FP16+VL}. + kIdVreduceps, //!< Instruction 'vreduceps' {AVX512_DQ+VL}. + kIdVreducesd, //!< Instruction 'vreducesd' {AVX512_DQ}. + kIdVreducesh, //!< Instruction 'vreducesh' {AVX512_FP16}. + kIdVreducess, //!< Instruction 'vreducess' {AVX512_DQ}. + kIdVrndscalepd, //!< Instruction 'vrndscalepd' {AVX512_F+VL}. + kIdVrndscaleph, //!< Instruction 'vrndscaleph' {AVX512_FP16+VL}. + kIdVrndscaleps, //!< Instruction 'vrndscaleps' {AVX512_F+VL}. + kIdVrndscalesd, //!< Instruction 'vrndscalesd' {AVX512_F}. + kIdVrndscalesh, //!< Instruction 'vrndscalesh' {AVX512_FP16}. + kIdVrndscaless, //!< Instruction 'vrndscaless' {AVX512_F}. + kIdVroundpd, //!< Instruction 'vroundpd' {AVX}. + kIdVroundps, //!< Instruction 'vroundps' {AVX}. + kIdVroundsd, //!< Instruction 'vroundsd' {AVX}. + kIdVroundss, //!< Instruction 'vroundss' {AVX}. + kIdVrsqrt14pd, //!< Instruction 'vrsqrt14pd' {AVX512_F+VL}. + kIdVrsqrt14ps, //!< Instruction 'vrsqrt14ps' {AVX512_F+VL}. + kIdVrsqrt14sd, //!< Instruction 'vrsqrt14sd' {AVX512_F}. + kIdVrsqrt14ss, //!< Instruction 'vrsqrt14ss' {AVX512_F}. + kIdVrsqrt28pd, //!< Instruction 'vrsqrt28pd' {AVX512_ER}. + kIdVrsqrt28ps, //!< Instruction 'vrsqrt28ps' {AVX512_ER}. + kIdVrsqrt28sd, //!< Instruction 'vrsqrt28sd' {AVX512_ER}. + kIdVrsqrt28ss, //!< Instruction 'vrsqrt28ss' {AVX512_ER}. + kIdVrsqrtph, //!< Instruction 'vrsqrtph' {AVX512_FP16+VL}. + kIdVrsqrtps, //!< Instruction 'vrsqrtps' {AVX}. + kIdVrsqrtsh, //!< Instruction 'vrsqrtsh' {AVX512_FP16}. + kIdVrsqrtss, //!< Instruction 'vrsqrtss' {AVX}. + kIdVscalefpd, //!< Instruction 'vscalefpd' {AVX512_F+VL}. + kIdVscalefph, //!< Instruction 'vscalefph' {AVX512_FP16+VL}. + kIdVscalefps, //!< Instruction 'vscalefps' {AVX512_F+VL}. + kIdVscalefsd, //!< Instruction 'vscalefsd' {AVX512_F}. + kIdVscalefsh, //!< Instruction 'vscalefsh' {AVX512_FP16}. + kIdVscalefss, //!< Instruction 'vscalefss' {AVX512_F}. + kIdVscatterdpd, //!< Instruction 'vscatterdpd' {AVX512_F+VL}. + kIdVscatterdps, //!< Instruction 'vscatterdps' {AVX512_F+VL}. + kIdVscatterpf0dpd, //!< Instruction 'vscatterpf0dpd' {AVX512_PF}. + kIdVscatterpf0dps, //!< Instruction 'vscatterpf0dps' {AVX512_PF}. + kIdVscatterpf0qpd, //!< Instruction 'vscatterpf0qpd' {AVX512_PF}. + kIdVscatterpf0qps, //!< Instruction 'vscatterpf0qps' {AVX512_PF}. + kIdVscatterpf1dpd, //!< Instruction 'vscatterpf1dpd' {AVX512_PF}. + kIdVscatterpf1dps, //!< Instruction 'vscatterpf1dps' {AVX512_PF}. + kIdVscatterpf1qpd, //!< Instruction 'vscatterpf1qpd' {AVX512_PF}. + kIdVscatterpf1qps, //!< Instruction 'vscatterpf1qps' {AVX512_PF}. + kIdVscatterqpd, //!< Instruction 'vscatterqpd' {AVX512_F+VL}. + kIdVscatterqps, //!< Instruction 'vscatterqps' {AVX512_F+VL}. + kIdVsha512msg1, //!< Instruction 'vsha512msg1' {AVX & SHA512}. + kIdVsha512msg2, //!< Instruction 'vsha512msg2' {AVX & SHA512}. + kIdVsha512rnds2, //!< Instruction 'vsha512rnds2' {AVX & SHA512}. + kIdVshuff32x4, //!< Instruction 'vshuff32x4' {AVX512_F+VL}. + kIdVshuff64x2, //!< Instruction 'vshuff64x2' {AVX512_F+VL}. + kIdVshufi32x4, //!< Instruction 'vshufi32x4' {AVX512_F+VL}. + kIdVshufi64x2, //!< Instruction 'vshufi64x2' {AVX512_F+VL}. + kIdVshufpd, //!< Instruction 'vshufpd' {AVX|AVX512_F+VL}. + kIdVshufps, //!< Instruction 'vshufps' {AVX|AVX512_F+VL}. + kIdVsm3msg1, //!< Instruction 'vsm3msg1' {AVX & SM3}. + kIdVsm3msg2, //!< Instruction 'vsm3msg2' {AVX & SM3}. + kIdVsm3rnds2, //!< Instruction 'vsm3rnds2' {AVX & SM3}. + kIdVsm4key4, //!< Instruction 'vsm4key4' {AVX & SM4}. + kIdVsm4rnds4, //!< Instruction 'vsm4rnds4' {AVX & SM4}. + kIdVsqrtpd, //!< Instruction 'vsqrtpd' {AVX|AVX512_F+VL}. + kIdVsqrtph, //!< Instruction 'vsqrtph' {AVX512_FP16+VL}. + kIdVsqrtps, //!< Instruction 'vsqrtps' {AVX|AVX512_F+VL}. + kIdVsqrtsd, //!< Instruction 'vsqrtsd' {AVX|AVX512_F}. + kIdVsqrtsh, //!< Instruction 'vsqrtsh' {AVX512_FP16}. + kIdVsqrtss, //!< Instruction 'vsqrtss' {AVX|AVX512_F}. + kIdVstmxcsr, //!< Instruction 'vstmxcsr' {AVX}. + kIdVsubpd, //!< Instruction 'vsubpd' {AVX|AVX512_F+VL}. + kIdVsubph, //!< Instruction 'vsubph' {AVX512_FP16+VL}. + kIdVsubps, //!< Instruction 'vsubps' {AVX|AVX512_F+VL}. + kIdVsubsd, //!< Instruction 'vsubsd' {AVX|AVX512_F}. + kIdVsubsh, //!< Instruction 'vsubsh' {AVX512_FP16}. + kIdVsubss, //!< Instruction 'vsubss' {AVX|AVX512_F}. + kIdVtestpd, //!< Instruction 'vtestpd' {AVX}. + kIdVtestps, //!< Instruction 'vtestps' {AVX}. + kIdVucomisd, //!< Instruction 'vucomisd' {AVX|AVX512_F}. + kIdVucomish, //!< Instruction 'vucomish' {AVX512_FP16}. + kIdVucomiss, //!< Instruction 'vucomiss' {AVX|AVX512_F}. + kIdVunpckhpd, //!< Instruction 'vunpckhpd' {AVX|AVX512_F+VL}. + kIdVunpckhps, //!< Instruction 'vunpckhps' {AVX|AVX512_F+VL}. + kIdVunpcklpd, //!< Instruction 'vunpcklpd' {AVX|AVX512_F+VL}. + kIdVunpcklps, //!< Instruction 'vunpcklps' {AVX|AVX512_F+VL}. + kIdVxorpd, //!< Instruction 'vxorpd' {AVX|AVX512_DQ+VL}. + kIdVxorps, //!< Instruction 'vxorps' {AVX|AVX512_DQ+VL}. + kIdVzeroall, //!< Instruction 'vzeroall' {AVX}. + kIdVzeroupper, //!< Instruction 'vzeroupper' {AVX}. + kIdWbinvd, //!< Instruction 'wbinvd' {I486}. + kIdWbnoinvd, //!< Instruction 'wbnoinvd' {WBNOINVD}. + kIdWrfsbase, //!< Instruction 'wrfsbase' {FSGSBASE} (X64). + kIdWrgsbase, //!< Instruction 'wrgsbase' {FSGSBASE} (X64). + kIdWrmsr, //!< Instruction 'wrmsr' {MSR}. + kIdWrssd, //!< Instruction 'wrssd' {CET_SS}. + kIdWrssq, //!< Instruction 'wrssq' {CET_SS} (X64). + kIdWrussd, //!< Instruction 'wrussd' {CET_SS}. + kIdWrussq, //!< Instruction 'wrussq' {CET_SS} (X64). + kIdXabort, //!< Instruction 'xabort' {RTM}. + kIdXadd, //!< Instruction 'xadd' {I486}. + kIdXbegin, //!< Instruction 'xbegin' {RTM}. + kIdXchg, //!< Instruction 'xchg'. + kIdXend, //!< Instruction 'xend' {RTM}. + kIdXgetbv, //!< Instruction 'xgetbv' {XSAVE}. + kIdXlatb, //!< Instruction 'xlatb'. + kIdXor, //!< Instruction 'xor'. + kIdXorpd, //!< Instruction 'xorpd' {SSE2}. + kIdXorps, //!< Instruction 'xorps' {SSE}. + kIdXresldtrk, //!< Instruction 'xresldtrk' {TSXLDTRK}. + kIdXrstor, //!< Instruction 'xrstor' {XSAVE}. + kIdXrstor64, //!< Instruction 'xrstor64' {XSAVE} (X64). + kIdXrstors, //!< Instruction 'xrstors' {XSAVES}. + kIdXrstors64, //!< Instruction 'xrstors64' {XSAVES} (X64). + kIdXsave, //!< Instruction 'xsave' {XSAVE}. + kIdXsave64, //!< Instruction 'xsave64' {XSAVE} (X64). + kIdXsavec, //!< Instruction 'xsavec' {XSAVEC}. + kIdXsavec64, //!< Instruction 'xsavec64' {XSAVEC} (X64). + kIdXsaveopt, //!< Instruction 'xsaveopt' {XSAVEOPT}. + kIdXsaveopt64, //!< Instruction 'xsaveopt64' {XSAVEOPT} (X64). + kIdXsaves, //!< Instruction 'xsaves' {XSAVES}. + kIdXsaves64, //!< Instruction 'xsaves64' {XSAVES} (X64). + kIdXsetbv, //!< Instruction 'xsetbv' {XSAVE}. + kIdXsusldtrk, //!< Instruction 'xsusldtrk' {TSXLDTRK}. + kIdXtest, //!< Instruction 'xtest' {TSX}. + _kIdCount + // ${InstId:End} + }; + + //! Tests whether the `instId` is defined. + static ASMJIT_INLINE_NODEBUG constexpr bool isDefinedId(InstId instId) noexcept { return instId < _kIdCount; } + + //! \cond + #define ASMJIT_INST_FROM_COND(ID) \ + ID##o, ID##no, ID##b , ID##ae, \ + ID##e, ID##ne, ID##be, ID##a , \ + ID##s, ID##ns, ID##pe, ID##po, \ + ID##l, ID##ge, ID##le, ID##g + + static constexpr uint16_t _jccTable[] = { ASMJIT_INST_FROM_COND(Inst::kIdJ) }; + static constexpr uint16_t _setccTable[] = { ASMJIT_INST_FROM_COND(Inst::kIdSet) }; + static constexpr uint16_t _cmovccTable[] = { ASMJIT_INST_FROM_COND(Inst::kIdCmov) }; + + #undef ASMJIT_INST_FROM_COND + //! \endcond + + //! Translates a condition code `cond` to a `jcc` instruction id. + static ASMJIT_INLINE_NODEBUG constexpr InstId jccFromCond(CondCode cond) noexcept { return _jccTable[uint8_t(cond)]; } + //! Translates a condition code `cond` to a `setcc` instruction id. + static ASMJIT_INLINE_NODEBUG constexpr InstId setccFromCond(CondCode cond) noexcept { return _setccTable[uint8_t(cond)]; } + //! Translates a condition code `cond` to a `cmovcc` instruction id. + static ASMJIT_INLINE_NODEBUG constexpr InstId cmovccFromCond(CondCode cond) noexcept { return _cmovccTable[uint8_t(cond)]; } +} // {Inst} + +//! FPU status word bits. +enum class FpuStatusWord : uint16_t { + kNone = 0x0000u, //!< No bits set. + + kInvalid = 0x0001u, //!< Invalid operation. + kDenormalized = 0x0002u, //!< Denormalized operand. + kDivByZero = 0x0004u, //!< Division by zero. + kOverflow = 0x0008u, //!< Overflown. + kUnderflow = 0x0010u, //!< Underflown. + kPrecision = 0x0020u, //!< Precision lost. + kStackFault = 0x0040u, //!< Stack fault. + kInterrupt = 0x0080u, //!< Interrupt. + kC0 = 0x0100u, //!< C0 flag. + kC1 = 0x0200u, //!< C1 flag. + kC2 = 0x0400u, //!< C2 flag. + kTopMask = 0x3800u, //!< Top of the stack (mask). + kC3 = 0x4000u, //!< C3 flag. + kBusy = 0x8000u //!< FPU is busy. +}; +ASMJIT_DEFINE_ENUM_FLAGS(FpuStatusWord) + +//! FPU control word bits. +enum class FpuControlWord : uint16_t { + kNone = 0x0000u, //!< No bits set. + + // Bits 0-5 + // -------- + + kEM_Mask = 0x003Fu, //!< Exception mask (0x3F). + kEM_Invalid = 0x0001u, //!< Invalid operation exception. + kEM_Denormal = 0x0002u, //!< Denormalized operand exception. + kEM_DivByZero = 0x0004u, //!< Division by zero exception. + kEM_Overflow = 0x0008u, //!< Overflow exception. + kEM_Underflow = 0x0010u, //!< Underflow exception. + kEM_Inexact = 0x0020u, //!< Inexact operation exception. + + // Bits 8-9 + // -------- + + kPC_Mask = 0x0300u, //!< Precision control mask. + kPC_Float = 0x0000u, //!< Single precision (24 bits). + kPC_Reserved = 0x0100u, //!< Reserved. + kPC_Double = 0x0200u, //!< Double precision (53 bits). + kPC_Extended = 0x0300u, //!< Extended precision (64 bits). + + // Bits 10-11 + // ---------- + + kRC_Mask = 0x0C00u, //!< Rounding control mask. + kRC_Nearest = 0x0000u, //!< Round to nearest even. + kRC_Down = 0x0400u, //!< Round down (floor). + kRC_Up = 0x0800u, //!< Round up (ceil). + kRC_Truncate = 0x0C00u, //!< Round towards zero (truncate). + + // Bit 12 + // ------ + + kIC_Mask = 0x1000u, //!< Infinity control. + kIC_Projective = 0x0000u, //!< Projective (not supported on X64). + kIC_Affine = 0x1000u //!< Affine (default). +}; +ASMJIT_DEFINE_ENUM_FLAGS(FpuControlWord) + +//! An immediate value that can be used with CMP[PD|PS|SD|SS] instructions. +enum class CmpImm : uint8_t { + kEQ = 0x00u, //!< Equal (Quiet), same as \ref VCmpImm::kEQ_OQ. + kLT = 0x01u, //!< Less (Signaling), same as \ref VCmpImm::kLT_OS. + kLE = 0x02u, //!< Less/Equal (Signaling), same as \ref VCmpImm::kLE_OS. + kUNORD = 0x03u, //!< Unordered (Quiet), same as \ref VCmpImm::kUNORD_Q. + kNEQ = 0x04u, //!< Not Equal (Quiet), same as \ref VCmpImm::kNEQ_UQ. + kNLT = 0x05u, //!< Not Less (Signaling), same as \ref VCmpImm::kNLT_US. + kNLE = 0x06u, //!< Not Less/Equal (Signaling), same as \ref VCmpImm::kNLE_US. + kORD = 0x07u //!< Ordered (Quiet), same as \ref VCmpImm::kORD_Q. +}; + +//! An immediate value that can be used with [V]PCMP[I|E]STR[I|M] instructions. +enum class PCmpStrImm : uint8_t { + // Source Data Format + // ------------------ + + kUB = 0x00u << 0, //!< The source data format is unsigned bytes. + kUW = 0x01u << 0, //!< The source data format is unsigned words. + kSB = 0x02u << 0, //!< The source data format is signed bytes. + kSW = 0x03u << 0, //!< The source data format is signed words. + + // Aggregation Operation + // --------------------- + + kEqualAny = 0x00u << 2, //!< The arithmetic comparison is "equal". + kRanges = 0x01u << 2, //!< The arithmetic comparison is "greater than or equal" between even indexed + //!< elements and "less than or equal" between odd indexed elements. + kEqualEach = 0x02u << 2, //!< The arithmetic comparison is "equal". + kEqualOrdered = 0x03u << 2, //!< The arithmetic comparison is "equal". + + // Polarity + // -------- + + kPosPolarity = 0x00u << 4, //!< IntRes2 = IntRes1. + kNegPolarity = 0x01u << 4, //!< IntRes2 = -1 XOR IntRes1. + kPosMasked = 0x02u << 4, //!< IntRes2 = IntRes1. + kNegMasked = 0x03u << 4, //!< IntRes2[i] = second[i] == invalid ? IntRes1[i] : ~IntRes1[i]. + + // Output Selection (pcmpstri) + // --------------------------- + + kOutputLSI = 0x00u << 6, //!< The index returned to ECX is of the least significant set bit in IntRes2. + kOutputMSI = 0x01u << 6, //!< The index returned to ECX is of the most significant set bit in IntRes2. + + // Output Selection (pcmpstrm) + // --------------------------- + + kBitMask = 0x00u << 6, //!< IntRes2 is returned as the mask to the least significant bits of XMM0. + kIndexMask = 0x01u << 6 //!< IntRes2 is expanded into a byte/word mask and placed in XMM0. +}; +ASMJIT_DEFINE_ENUM_FLAGS(PCmpStrImm) + +//! An immediate value that can be used with ROUND[PD|PS|SD|SS] instructions. +//! +//! \note `kSuppress` is a mask that can be used with any other value. +enum class RoundImm : uint8_t { + kNearest = 0x00u, //!< Round to nearest (even). + kDown = 0x01u, //!< Round to down toward -INF (floor), + kUp = 0x02u, //!< Round to up toward +INF (ceil). + kTrunc = 0x03u, //!< Round toward zero (truncate). + kCurrent = 0x04u, //!< Round to the current rounding mode set (ignores other RC bits). + kSuppress = 0x08u //!< Suppress exceptions (avoids inexact exception, if set). +}; +ASMJIT_DEFINE_ENUM_FLAGS(RoundImm) + +//! An immediate value that can be used with VCMP[PD|PS|SD|SS] instructions (AVX). +//! +//! The first 8 values are compatible with \ref CmpImm. +enum class VCmpImm : uint8_t { + kEQ_OQ = 0x00u, //!< Equal (Quiet , Ordered) , same as \ref CmpImm::kEQ. + kLT_OS = 0x01u, //!< Less (Signaling, Ordered) , same as \ref CmpImm::kLT. + kLE_OS = 0x02u, //!< Less/Equal (Signaling, Ordered) , same as \ref CmpImm::kLE. + kUNORD_Q = 0x03u, //!< Unordered (Quiet) , same as \ref CmpImm::kUNORD. + kNEQ_UQ = 0x04u, //!< Not Equal (Quiet , Unordered), same as \ref CmpImm::kNEQ. + kNLT_US = 0x05u, //!< Not Less (Signaling, Unordered), same as \ref CmpImm::kNLT. + kNLE_US = 0x06u, //!< Not Less/Equal (Signaling, Unordered), same as \ref CmpImm::kNLE. + kORD_Q = 0x07u, //!< Ordered (Quiet) , same as \ref CmpImm::kORD. + kEQ_UQ = 0x08u, //!< Equal (Quiet , Unordered). + kNGE_US = 0x09u, //!< Not Greater/Equal (Signaling, Unordered). + kNGT_US = 0x0Au, //!< Not Greater (Signaling, Unordered). + kFALSE_OQ = 0x0Bu, //!< False (Quiet , Ordered). + kNEQ_OQ = 0x0Cu, //!< Not Equal (Quiet , Ordered). + kGE_OS = 0x0Du, //!< Greater/Equal (Signaling, Ordered). + kGT_OS = 0x0Eu, //!< Greater (Signaling, Ordered). + kTRUE_UQ = 0x0Fu, //!< True (Quiet , Unordered). + kEQ_OS = 0x10u, //!< Equal (Signaling, Ordered). + kLT_OQ = 0x11u, //!< Less (Quiet , Ordered). + kLE_OQ = 0x12u, //!< Less/Equal (Quiet , Ordered). + kUNORD_S = 0x13u, //!< Unordered (Signaling). + kNEQ_US = 0x14u, //!< Not Equal (Signaling, Unordered). + kNLT_UQ = 0x15u, //!< Not Less (Quiet , Unordered). + kNLE_UQ = 0x16u, //!< Not Less/Equal (Quiet , Unordered). + kORD_S = 0x17u, //!< Ordered (Signaling). + kEQ_US = 0x18u, //!< Equal (Signaling, Unordered). + kNGE_UQ = 0x19u, //!< Not Greater/Equal (Quiet , Unordered). + kNGT_UQ = 0x1Au, //!< Not Greater (Quiet , Unordered). + kFALSE_OS = 0x1Bu, //!< False (Signaling, Ordered). + kNEQ_OS = 0x1Cu, //!< Not Equal (Signaling, Ordered). + kGE_OQ = 0x1Du, //!< Greater/Equal (Quiet , Ordered). + kGT_OQ = 0x1Eu, //!< Greater (Quiet , Ordered). + kTRUE_US = 0x1Fu //!< True (Signaling, Unordered). +}; + +//! An immediate value that can be used with VFIXUPIMM[PD|PS|SD|SS] instructions (AVX-512). +//! +//! The final immediate is a combination of all possible control bits. +enum class VFixupImm : uint8_t { + kNone = 0x00u, + kZEOnZero = 0x01u, + kIEOnZero = 0x02u, + kZEOnOne = 0x04u, + kIEOnOne = 0x08u, + kIEOnSNaN = 0x10u, + kIEOnNInf = 0x20u, + kIEOnNegative = 0x40u, + kIEOnPInf = 0x80u +}; +ASMJIT_DEFINE_ENUM_FLAGS(VFixupImm) + +//! An immediate value that can be used with VFPCLASS[PD|PS|SD|SS] instructions (AVX-512). +//! +//! The values can be combined together to form the final 8-bit mask. +enum class VFPClassImm : uint8_t { + kNone = 0x00u, + kQNaN = 0x01u, //!< Checks for QNaN. + kPZero = 0x02u, //!< Checks for +0. + kNZero = 0x04u, //!< Checks for -0. + kPInf = 0x08u, //!< Checks for +Inf. + kNInf = 0x10u, //!< Checks for -Inf. + kDenormal = 0x20u, //!< Checks for denormal. + kNegative = 0x40u, //!< Checks for negative finite value. + kSNaN = 0x80u //!< Checks for SNaN. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VFPClassImm) + +//! An immediate value that can be used with VGETMANT[PD|PS|SD|SS] instructions (AVX-512). +//! +//! The value is a combination of a normalization interval and a sign control. +enum class VGetMantImm : uint8_t { + // Normalization Interval + // ---------------------- + + k1To2 = 0x00u, //!< Normalization interval is [1, 2) + k1Div2To2 = 0x01u, //!< Normalization interval is [0.5, 2) + k1Div2To1 = 0x02u, //!< Normalization interval is [0.5, 1) + k3Div4To3Div2 = 0x03u, //!< Normalization interval is [3/4, 3/2) + + // Sign Control + // ------------ + + kSrcSign = 0x00u, //!< Source sign. + kNoSign = 0x04u, //!< Zero sign + kQNaNIfSign = 0x08u //!< QNAN_Indefinite if sign(src) != 0, regardless of `kSignSrc` or `kNoSign`. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VGetMantImm) + +//! A predicate used by VPCMP[U][B|W|D|Q] instructions (AVX-512). +enum class VPCmpImm : uint8_t { + kEQ = 0x00u, //!< Equal. + kLT = 0x01u, //!< Less. + kLE = 0x02u, //!< Less/Equal. + kFALSE = 0x03u, //!< False. + kNE = 0x04u, //!< Not Equal. + kGE = 0x05u, //!< Greater/Equal. + kGT = 0x06u, //!< Greater. + kTRUE = 0x07u //!< True. +}; + +//! A predicate used by VPCOM[U][B|W|D|Q] instructions (XOP). +enum class VPComImm : uint8_t { + kLT = 0x00u, //!< Less. + kLE = 0x01u, //!< Less/Equal + kGT = 0x02u, //!< Greater. + kGE = 0x03u, //!< Greater/Equal. + kEQ = 0x04u, //!< Equal. + kNE = 0x05u, //!< Not Equal. + kFALSE = 0x06u, //!< False. + kTRUE = 0x07u //!< True. +}; + +//! A predicate used by VRANGE[PD|PS|SD|SS] instructions (AVX-512). +enum class VRangeImm : uint8_t { + // Selector + // -------- + + kSelectMin = 0x00u, //!< Select minimum value. + kSelectMax = 0x01u, //!< Select maximum value. + kSelectAbsMin = 0x02u, //!< Select minimum absolute value. + kSelectAbsMax = 0x03u, //!< Select maximum absolute value. + + // Sign + // ---- + + kSignSrc1 = 0x00u, //!< Select sign of SRC1. + kSignSrc2 = 0x04u, //!< Select sign of SRC2. + kSign0 = 0x08u, //!< Set sign to 0. + kSign1 = 0x0Cu //!< Set sign to 1. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VRangeImm) + +//! A predicate used by VREDUCE[PD|PS|SD|SS] instructions (AVX-512). +enum class VReduceImm : uint8_t { + kRoundEven = 0x00u, //!< Round to nearest even. + kRoundDown = 0x01u, //!< Round down. + kRoundUp = 0x02u, //!< Round up. + kRoundTrunc = 0x03u, //!< Truncate. + kRoundCurrent = 0x04u, //!< Round to the current mode set. + kSuppress = 0x08u, //!< Suppress exceptions. + kFixedImmMask = 0xF0u //!< Fixed length value mask. +}; +ASMJIT_DEFINE_ENUM_FLAGS(VReduceImm) + +//! Creates a \ref VReduceImm from a combination of `flags` and `fixedPointLength`. +static ASMJIT_INLINE_NODEBUG constexpr VReduceImm vReduceImm(VReduceImm flags, uint32_t fixedPointLength) noexcept { + return flags | VReduceImm(fixedPointLength << 4); +} + +//! A predicate that can be used as an immediate value with VPTERNLOG[D|Q] instruction. +//! +//! There are 3 inputs to the instruction (\ref kA, \ref kB, \ref kC). Ternary logic can define any combination +//! that would be performed on these 3 inputs to get the desired output - any combination of AND, OR, XOR, NOT +//! is possible. +//! +//! \sa \ref tLogFromBits and \ref fLogIfElse +enum class TLogImm : uint8_t { + k0 = 0x00u, //!< 0 value. + k1 = 0xFFu, //!< 1 value. + kA = 0xF0u, //!< A value. + kB = 0xCCu, //!< B value. + kC = 0xAAu, //!< C value. + + kNotA = kA ^ k1, //!< `!A` expression. + kNotB = kB ^ k1, //!< `!B` expression. + kNotC = kC ^ k1, //!< `!C` expression. + + kAB = kA & kB, //!< `A & B` expression. + kAC = kA & kC, //!< `A & C` expression. + kBC = kB & kC, //!< `B & C` expression. + kNotAB = kAB ^ k1, //!< `!(A & B)` expression. + kNotAC = kAC ^ k1, //!< `!(A & C)` expression. + kNotBC = kBC ^ k1, //!< `!(B & C)` expression. + + kABC = kAB & kC, //!< `A & B & C` expression. + kNotABC = kABC ^ k1 //!< `!(A & B & C)` expression. +}; +ASMJIT_DEFINE_ENUM_FLAGS(TLogImm) + +//! Creates an immediate that can be used by VPTERNLOG[D|Q] instructions. +static ASMJIT_INLINE_NODEBUG constexpr TLogImm tLogFromBits(uint8_t b000, uint8_t b001, uint8_t b010, uint8_t b011, uint8_t b100, uint8_t b101, uint8_t b110, uint8_t b111) noexcept { + return TLogImm(uint8_t(b000 << 0) | + uint8_t(b001 << 1) | + uint8_t(b010 << 2) | + uint8_t(b011 << 3) | + uint8_t(b100 << 4) | + uint8_t(b101 << 5) | + uint8_t(b110 << 6) | + uint8_t(b111 << 7)); +} + +//! Creates an if/else logic that can be used by VPTERNLOG[D|Q] instructions. +static ASMJIT_INLINE_NODEBUG constexpr TLogImm fLogIfElse(TLogImm condition, TLogImm a, TLogImm b) noexcept { return (condition & a) | (~condition & b); } + +//! Creates a shuffle immediate value that be used with SSE/AVX/AVX-512 instructions to shuffle 2 elements in a vector. +//! +//! \param a Position of the first component [0, 1]. +//! \param b Position of the second component [0, 1]. +//! +//! Shuffle constants can be used to encode an immediate for these instructions: +//! - `shufpd|vshufpd` +static ASMJIT_INLINE_NODEBUG constexpr uint32_t shuffleImm(uint32_t a, uint32_t b) noexcept { + return (a << 1) | b; +} + +//! Creates a shuffle immediate value that be used with SSE/AVX/AVX-512 instructions to shuffle 4 elements in a vector. +//! +//! \param a Position of the first component [0, 3]. +//! \param b Position of the second component [0, 3]. +//! \param c Position of the third component [0, 3]. +//! \param d Position of the fourth component [0, 3]. +//! +//! Shuffle constants can be used to encode an immediate for these instructions: +//! - `pshufw` +//! - `pshuflw|vpshuflw` +//! - `pshufhw|vpshufhw` +//! - `pshufd|vpshufd` +//! - `shufps|vshufps` +static ASMJIT_INLINE_NODEBUG constexpr uint32_t shuffleImm(uint32_t a, uint32_t b, uint32_t c, uint32_t d) noexcept { + return (a << 6) | (b << 4) | (c << 2) | d; +} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86GLOBALS_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86instdb.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86instdb.h new file mode 100644 index 0000000000000000000000000000000000000000..b0695d05db5307f6d8b182162ea74b46f6102b0b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86instdb.h @@ -0,0 +1,563 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86INSTDB_H_INCLUDED +#define ASMJIT_X86_X86INSTDB_H_INCLUDED + +#include "../x86/x86globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +//! Instruction database (X86). +namespace InstDB { + +//! Describes which operation mode is supported by an instruction. +enum class Mode : uint8_t { + //! Invalid mode. + kNone = 0x00u, + //! X86 mode supported. + kX86 = 0x01u, + //! X64 mode supported. + kX64 = 0x02u, + //! Both X86 and X64 modes supported. + kAny = 0x03u +}; +ASMJIT_DEFINE_ENUM_FLAGS(Mode) + +//! Converts architecture to operation mode, see \ref Mode. +static ASMJIT_INLINE_NODEBUG constexpr Mode modeFromArch(Arch arch) noexcept { + return arch == Arch::kX86 ? Mode::kX86 : + arch == Arch::kX64 ? Mode::kX64 : Mode::kNone; +} + +//! Operand signature flags used by \ref OpSignature. +enum class OpFlags : uint64_t { + //! No operand flags. + kNone = 0u, + + kRegGpbLo = 0x0000000000000001u, //!< Operand can be low 8-bit GPB register. + kRegGpbHi = 0x0000000000000002u, //!< Operand can be high 8-bit GPB register. + kRegGpw = 0x0000000000000004u, //!< Operand can be 16-bit GPW register. + kRegGpd = 0x0000000000000008u, //!< Operand can be 32-bit GPD register. + kRegGpq = 0x0000000000000010u, //!< Operand can be 64-bit GPQ register. + kRegXmm = 0x0000000000000020u, //!< Operand can be 128-bit XMM register. + kRegYmm = 0x0000000000000040u, //!< Operand can be 256-bit YMM register. + kRegZmm = 0x0000000000000080u, //!< Operand can be 512-bit ZMM register. + kRegMm = 0x0000000000000100u, //!< Operand can be 64-bit MM register. + kRegKReg = 0x0000000000000200u, //!< Operand can be 64-bit K register. + kRegSReg = 0x0000000000000400u, //!< Operand can be SReg (segment register). + kRegCReg = 0x0000000000000800u, //!< Operand can be CReg (control register). + kRegDReg = 0x0000000000001000u, //!< Operand can be DReg (debug register). + kRegSt = 0x0000000000002000u, //!< Operand can be 80-bit ST register (X87). + kRegBnd = 0x0000000000004000u, //!< Operand can be 128-bit BND register. + kRegTmm = 0x0000000000008000u, //!< Operand can be 0..8192-bit TMM register. + kRegMask = 0x000000000000FFFFu, //!< Mask of all possible register types. + + kMemUnspecified = 0x0000000000040000u, //!< Operand can be a scalar memory pointer without size. + kMem8 = 0x0000000000080000u, //!< Operand can be an 8-bit memory pointer. + kMem16 = 0x0000000000100000u, //!< Operand can be a 16-bit memory pointer. + kMem32 = 0x0000000000200000u, //!< Operand can be a 32-bit memory pointer. + kMem48 = 0x0000000000400000u, //!< Operand can be a 48-bit memory pointer (FAR pointers only). + kMem64 = 0x0000000000800000u, //!< Operand can be a 64-bit memory pointer. + kMem80 = 0x0000000001000000u, //!< Operand can be an 80-bit memory pointer. + kMem128 = 0x0000000002000000u, //!< Operand can be a 128-bit memory pointer. + kMem256 = 0x0000000004000000u, //!< Operand can be a 256-bit memory pointer. + kMem512 = 0x0000000008000000u, //!< Operand can be a 512-bit memory pointer. + kMem1024 = 0x0000000010000000u, //!< Operand can be a 1024-bit memory pointer. + kMemMask = 0x000000001FFC0000u, //!< Mask of all possible scalar memory types. + + kVm32x = 0x0000000040000000u, //!< Operand can be a vm32x (vector) pointer. + kVm32y = 0x0000000080000000u, //!< Operand can be a vm32y (vector) pointer. + kVm32z = 0x0000000100000000u, //!< Operand can be a vm32z (vector) pointer. + kVm64x = 0x0000000200000000u, //!< Operand can be a vm64x (vector) pointer. + kVm64y = 0x0000000400000000u, //!< Operand can be a vm64y (vector) pointer. + kVm64z = 0x0000000800000000u, //!< Operand can be a vm64z (vector) pointer. + kVmMask = 0x0000000FC0000000u, //!< Mask of all possible vector memory types. + + kImmI4 = 0x0000001000000000u, //!< Operand can be signed 4-bit immediate. + kImmU4 = 0x0000002000000000u, //!< Operand can be unsigned 4-bit immediate. + kImmI8 = 0x0000004000000000u, //!< Operand can be signed 8-bit immediate. + kImmU8 = 0x0000008000000000u, //!< Operand can be unsigned 8-bit immediate. + kImmI16 = 0x0000010000000000u, //!< Operand can be signed 16-bit immediate. + kImmU16 = 0x0000020000000000u, //!< Operand can be unsigned 16-bit immediate. + kImmI32 = 0x0000040000000000u, //!< Operand can be signed 32-bit immediate. + kImmU32 = 0x0000080000000000u, //!< Operand can be unsigned 32-bit immediate. + kImmI64 = 0x0000100000000000u, //!< Operand can be signed 64-bit immediate. + kImmU64 = 0x0000200000000000u, //!< Operand can be unsigned 64-bit immediate. + kImmMask = 0x00003FF000000000u, //!< Mask of all immediate types. + + kRel8 = 0x0000400000000000u, //!< Operand can be relative 8-bit displacement. + kRel32 = 0x0000800000000000u, //!< Operand can be relative 32-bit displacement. + kRelMask = 0x0000C00000000000u, //!< Mask of all relative displacement types. + + kFlagMemBase = 0x0001000000000000u, //!< Flag: Only memory base is allowed (no index, no offset). + kFlagMemDs = 0x0002000000000000u, //!< Flag: Implicit memory operand's DS segment. + kFlagMemEs = 0x0004000000000000u, //!< Flag: Implicit memory operand's ES segment. + + kFlagMib = 0x0008000000000000u, //!< Flag: Operand is MIB (base+index) pointer. + kFlagTMem = 0x0010000000000000u, //!< Flag: Operand is TMEM (sib_mem), AMX memory pointer. + + kFlagImplicit = 0x0080000000000000u, //!< Flag: Operand is implicit. + kFlagMask = 0x009F000000000000u, //!< Mask of all flags. + + //! Contains mask of all registers, memory operands, immediate operands, and displacement operands. + kOpMask = kRegMask | kMemMask | kVmMask | kImmMask | kRelMask +}; +ASMJIT_DEFINE_ENUM_FLAGS(OpFlags) + +//! Operand signature. +//! +//! Contains all possible operand combinations, memory size information, and a fixed register id (or `BaseReg::kIdBad` +//! if fixed id isn't required). +struct OpSignature { + //! \name Members + //! \{ + + uint64_t _flags : 56; + uint64_t _regMask : 8; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns operand signature flags. + inline OpFlags flags() const noexcept { return (OpFlags)_flags; } + + //! Tests whether the given `flag` is set. + inline bool hasFlag(OpFlags flag) const noexcept { return (_flags & uint64_t(flag)) != 0; } + + //! Tests whether this signature contains at least one register operand of any type. + inline bool hasReg() const noexcept { return hasFlag(OpFlags::kRegMask); } + //! Tests whether this signature contains at least one scalar memory operand of any type. + inline bool hasMem() const noexcept { return hasFlag(OpFlags::kMemMask); } + //! Tests whether this signature contains at least one vector memory operand of any type. + inline bool hasVm() const noexcept { return hasFlag(OpFlags::kVmMask); } + //! Tests whether this signature contains at least one immediate operand of any type. + inline bool hasImm() const noexcept { return hasFlag(OpFlags::kImmMask); } + //! Tests whether this signature contains at least one relative displacement operand of any type. + inline bool hasRel() const noexcept { return hasFlag(OpFlags::kRelMask); } + + //! Tests whether the operand is implicit. + inline bool isImplicit() const noexcept { return hasFlag(OpFlags::kFlagImplicit); } + + //! Returns a physical register mask. + inline RegMask regMask() const noexcept { return _regMask; } + + //! \} +}; + +ASMJIT_VARAPI const OpSignature _opSignatureTable[]; + +//! Instruction signature. +//! +//! Contains a sequence of operands' combinations and other metadata that defines a single instruction. This data is +//! used by instruction validator. +struct InstSignature { + //! \name Members + //! \{ + + //! Count of operands in `opIndex` (0..6). + uint8_t _opCount : 3; + //! Architecture modes supported (X86 / X64). + uint8_t _mode : 2; + //! Number of implicit operands. + uint8_t _implicitOpCount : 3; + //! Reserved for future use. + uint8_t _reserved; + //! Indexes to `OpSignature` table. + uint8_t _opSignatureIndexes[Globals::kMaxOpCount]; + + //! \} + + //! \name Accessors + //! \{ + + //! Returns instruction operation mode. + inline Mode mode() const noexcept { return (Mode)_mode; } + //! Tests whether the instruction supports the given operating mode. + inline bool supportsMode(Mode mode) const noexcept { return (uint8_t(_mode) & uint8_t(mode)) != 0; } + + //! Returns the number of operands of this signature. + inline uint32_t opCount() const noexcept { return _opCount; } + //! Returns the number of implicit operands this signature has. + inline uint32_t implicitOpCount() const noexcept { return _implicitOpCount; } + //! Tests whether this instruction signature has at least one implicit operand. + inline bool hasImplicitOperands() const noexcept { return _implicitOpCount != 0; } + + //! Returns indexes to \ref _opSignatureTable for each operand of the instruction. + //! + //! \note The returned array always provides indexes for all operands (see \ref Globals::kMaxOpCount) even if the + //! instruction provides less operands. Undefined operands have always index of zero. + inline const uint8_t* opSignatureIndexes() const noexcept { return _opSignatureIndexes; } + + //! Returns index to \ref _opSignatureTable, corresponding to the requested operand `index` of the instruction. + inline uint8_t opSignatureIndex(size_t index) const noexcept { + ASMJIT_ASSERT(index < Globals::kMaxOpCount); + return _opSignatureIndexes[index]; + } + + //! Returns \ref OpSignature corresponding to the requested operand `index` of the instruction. + inline const OpSignature& opSignature(size_t index) const noexcept { + ASMJIT_ASSERT(index < Globals::kMaxOpCount); + return _opSignatureTable[_opSignatureIndexes[index]]; + } + + //! \} +}; + +ASMJIT_VARAPI const InstSignature _instSignatureTable[]; + +//! Instruction flags. +//! +//! Details about instruction encoding, operation, features, and some limitations. +enum class InstFlags : uint32_t { + //! No flags. + kNone = 0x00000000u, + + // Instruction Family + // ------------------ + // + // Instruction family information. + + //! Instruction that accesses FPU registers. + kFpu = 0x00000100u, + //! Instruction that accesses MMX registers (including 3DNOW and GEODE) and EMMS. + kMmx = 0x00000200u, + //! Instruction that accesses XMM registers (SSE, AVX, AVX512). + kVec = 0x00000400u, + + // FPU Flags + // --------- + // + // Used to tell the encoder which memory operand sizes are encodable. + + //! FPU instruction can address `word_ptr` (shared with M80). + kFpuM16 = 0x00000800u, + //! FPU instruction can address `dword_ptr`. + kFpuM32 = 0x00001000u, + //! FPU instruction can address `qword_ptr`. + kFpuM64 = 0x00002000u, + //! FPU instruction can address `tword_ptr` (shared with M16). + kFpuM80 = 0x00000800u, + + // Prefixes and Encoding Flags + // --------------------------- + // + // These describe optional X86 prefixes that can be used to change the instruction's operation. + + //! Instruction can be prefixed with using the REP(REPE) or REPNE prefix. + kRep = 0x00004000u, + //! Rep prefix is accepted, but it has no effect other than being emitted with the instruction (as an extra byte). + kRepIgnored = 0x00008000u, + //! Instruction can be prefixed with using the LOCK prefix. + kLock = 0x00010000u, + //! Instruction can be prefixed with using the XACQUIRE prefix. + kXAcquire = 0x00020000u, + //! Instruction can be prefixed with using the XRELEASE prefix. + kXRelease = 0x00040000u, + //! Instruction uses MIB (BNDLDX|BNDSTX) to encode two registers. + kMib = 0x00080000u, + //! Instruction uses VSIB instead of legacy SIB. + kVsib = 0x00100000u, + //! Instruction uses TSIB (or SIB_MEM) encoding (MODRM followed by SIB). + kTsib = 0x00200000u, + + // If both `kPrefixVex` and `kPrefixEvex` flags are specified it means that the instructions can be encoded + // by either VEX or EVEX prefix. In that case AsmJit checks global options and also instruction options to decide + // whether to emit VEX or EVEX prefix. + + //! Instruction can be encoded by VEX|XOP (AVX|AVX2|BMI|XOP|...). + kVex = 0x00400000u, + //! Instruction can be encoded by EVEX (AVX512). + kEvex = 0x00800000u, + //! EVEX encoding is preferred over VEX encoding (AVX515_VNNI vs AVX_VNNI). + kPreferEvex = 0x01000000u, + //! EVEX and VEX signatures are compatible. + kEvexCompat = 0x02000000u, + //! EVEX instruction requires K register in the first operand (compare instructions). + kEvexKReg = 0x04000000u, + //! EVEX instruction requires two operands and K register as a selector (gather instructions). + kEvexTwoOp = 0x08000000u, + //! VEX instruction that can be transformed to a compatible EVEX instruction. + kEvexTransformable = 0x10000000u, + + // Other Flags + // ----------- + + //! Instruction uses consecutive registers. + //! + //! Used by V4FMADDPS, V4FMADDSS, V4FNMADDPS, V4FNMADDSS, VP4DPWSSD, VP4DPWSSDS, VP2INTERSECTD, and VP2INTERSECTQ + //! instructions + kConsecutiveRegs = 0x20000000u +}; +ASMJIT_DEFINE_ENUM_FLAGS(InstFlags) + +//! AVX-512 flags. +enum class Avx512Flags : uint32_t { + //! No AVX-512 flags. + kNone = 0, + + //! Internally used in tables, has no meaning. + k_ = 0x00000000u, + //! Supports masking {k1..k7}. + kK = 0x00000001u, + //! Supports zeroing {z}, must be used together with `kAvx512k`. + kZ = 0x00000002u, + //! Supports 'embedded-rounding' {er} with implicit {sae}, + kER = 0x00000004u, + //! Supports 'suppress-all-exceptions' {sae}. + kSAE = 0x00000008u, + //! Supports 16-bit broadcast 'b16'. + kB16 = 0x00000010u, + //! Supports 32-bit broadcast 'b32'. + kB32 = 0x00000020u, + //! Supports 64-bit broadcast 'b64'. + kB64 = 0x00000040u, + //! Operates on a vector of consecutive registers (AVX512_4FMAPS and AVX512_4VNNIW). + kT4X = 0x00000080u, + + //! Implicit zeroing if {k} masking is used. Using {z} is not valid in this case as it's implicit. + kImplicitZ = 0x00000100, +}; +ASMJIT_DEFINE_ENUM_FLAGS(Avx512Flags) + +//! Instruction common information. +//! +//! Aggregated information shared across one or more instruction. +struct CommonInfo { + //! Instruction flags. + uint32_t _flags; + //! Reserved for future use. + uint32_t _avx512Flags : 11; + //! First `InstSignature` entry in the database. + uint32_t _iSignatureIndex : 11; + //! Number of relevant `ISignature` entries. + uint32_t _iSignatureCount : 5; + //! Instruction control flow category, see \ref InstControlFlow. + uint32_t _controlFlow : 3; + //! Specifies what happens if all source operands share the same register. + uint32_t _sameRegHint : 2; + + //! \name Accessors + //! \{ + + //! Returns instruction flags. + ASMJIT_INLINE_NODEBUG InstFlags flags() const noexcept { return (InstFlags)_flags; } + //! Tests whether the instruction has a `flag`. + ASMJIT_INLINE_NODEBUG bool hasFlag(InstFlags flag) const noexcept { return Support::test(_flags, flag); } + + //! Returns instruction AVX-512 flags. + ASMJIT_INLINE_NODEBUG Avx512Flags avx512Flags() const noexcept { return (Avx512Flags)_avx512Flags; } + //! Tests whether the instruction has an AVX-512 `flag`. + ASMJIT_INLINE_NODEBUG bool hasAvx512Flag(Avx512Flags flag) const noexcept { return Support::test(_avx512Flags, flag); } + + //! Tests whether the instruction is FPU instruction. + ASMJIT_INLINE_NODEBUG bool isFpu() const noexcept { return hasFlag(InstFlags::kFpu); } + //! Tests whether the instruction is MMX/3DNOW instruction that accesses MMX registers (includes EMMS and FEMMS). + ASMJIT_INLINE_NODEBUG bool isMmx() const noexcept { return hasFlag(InstFlags::kMmx); } + //! Tests whether the instruction is SSE|AVX|AVX512 instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isVec() const noexcept { return hasFlag(InstFlags::kVec); } + //! Tests whether the instruction is SSE+ (SSE4.2, AES, SHA included) instruction that accesses XMM registers. + ASMJIT_INLINE_NODEBUG bool isSse() const noexcept { return (flags() & (InstFlags::kVec | InstFlags::kVex | InstFlags::kEvex)) == InstFlags::kVec; } + //! Tests whether the instruction is AVX+ (FMA included) instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isAvx() const noexcept { return isVec() && isVexOrEvex(); } + + //! Tests whether the instruction can be prefixed with LOCK prefix. + ASMJIT_INLINE_NODEBUG bool hasLockPrefix() const noexcept { return hasFlag(InstFlags::kLock); } + //! Tests whether the instruction can be prefixed with REP (REPE|REPZ) prefix. + ASMJIT_INLINE_NODEBUG bool hasRepPrefix() const noexcept { return hasFlag(InstFlags::kRep); } + //! Tests whether the instruction can be prefixed with XACQUIRE prefix. + ASMJIT_INLINE_NODEBUG bool hasXAcquirePrefix() const noexcept { return hasFlag(InstFlags::kXAcquire); } + //! Tests whether the instruction can be prefixed with XRELEASE prefix. + ASMJIT_INLINE_NODEBUG bool hasXReleasePrefix() const noexcept { return hasFlag(InstFlags::kXRelease); } + + //! Tests whether the rep prefix is supported by the instruction, but ignored (has no effect). + ASMJIT_INLINE_NODEBUG bool isRepIgnored() const noexcept { return hasFlag(InstFlags::kRepIgnored); } + //! Tests whether the instruction uses MIB. + ASMJIT_INLINE_NODEBUG bool isMibOp() const noexcept { return hasFlag(InstFlags::kMib); } + //! Tests whether the instruction uses VSIB. + ASMJIT_INLINE_NODEBUG bool isVsibOp() const noexcept { return hasFlag(InstFlags::kVsib); } + //! Tests whether the instruction uses TSIB (AMX, instruction requires MOD+SIB). + ASMJIT_INLINE_NODEBUG bool isTsibOp() const noexcept { return hasFlag(InstFlags::kTsib); } + //! Tests whether the instruction uses VEX (can be set together with EVEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isVex() const noexcept { return hasFlag(InstFlags::kVex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isEvex() const noexcept { return hasFlag(InstFlags::kEvex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isVexOrEvex() const noexcept { return hasFlag(InstFlags::kVex | InstFlags::kEvex); } + + //! Tests whether the instruction should prefer EVEX prefix instead of VEX prefix. + ASMJIT_INLINE_NODEBUG bool preferEvex() const noexcept { return hasFlag(InstFlags::kPreferEvex); } + + ASMJIT_INLINE_NODEBUG bool isEvexCompatible() const noexcept { return hasFlag(InstFlags::kEvexCompat); } + ASMJIT_INLINE_NODEBUG bool isEvexKRegOnly() const noexcept { return hasFlag(InstFlags::kEvexKReg); } + ASMJIT_INLINE_NODEBUG bool isEvexTwoOpOnly() const noexcept { return hasFlag(InstFlags::kEvexTwoOp); } + ASMJIT_INLINE_NODEBUG bool isEvexTransformable() const noexcept { return hasFlag(InstFlags::kEvexTransformable); } + + //! Tests whether the instruction supports AVX512 masking {k}. + ASMJIT_INLINE_NODEBUG bool hasAvx512K() const noexcept { return hasAvx512Flag(Avx512Flags::kK); } + //! Tests whether the instruction supports AVX512 zeroing {k}{z}. + ASMJIT_INLINE_NODEBUG bool hasAvx512Z() const noexcept { return hasAvx512Flag(Avx512Flags::kZ); } + //! Tests whether the instruction supports AVX512 embedded-rounding {er}. + ASMJIT_INLINE_NODEBUG bool hasAvx512ER() const noexcept { return hasAvx512Flag(Avx512Flags::kER); } + //! Tests whether the instruction supports AVX512 suppress-all-exceptions {sae}. + ASMJIT_INLINE_NODEBUG bool hasAvx512SAE() const noexcept { return hasAvx512Flag(Avx512Flags::kSAE); } + //! Tests whether the instruction supports AVX512 broadcast (either 32-bit or 64-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B() const noexcept { return hasAvx512Flag(Avx512Flags::kB16 | Avx512Flags::kB32 | Avx512Flags::kB64); } + //! Tests whether the instruction supports AVX512 broadcast (16-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B16() const noexcept { return hasAvx512Flag(Avx512Flags::kB16); } + //! Tests whether the instruction supports AVX512 broadcast (32-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B32() const noexcept { return hasAvx512Flag(Avx512Flags::kB32); } + //! Tests whether the instruction supports AVX512 broadcast (64-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B64() const noexcept { return hasAvx512Flag(Avx512Flags::kB64); } + + // Returns the size of the broadcast - either 2, 4, or 8, or 0 if broadcast is not supported. + ASMJIT_INLINE_NODEBUG uint32_t broadcastSize() const noexcept { + constexpr uint32_t kShift = Support::ConstCTZ::value; + return (uint32_t(_avx512Flags) & uint32_t(Avx512Flags::kB16 | Avx512Flags::kB32 | Avx512Flags::kB64)) >> (kShift - 1); + } + + ASMJIT_INLINE_NODEBUG uint32_t signatureIndex() const noexcept { return _iSignatureIndex; } + ASMJIT_INLINE_NODEBUG uint32_t signatureCount() const noexcept { return _iSignatureCount; } + + ASMJIT_INLINE_NODEBUG const InstSignature* signatureData() const noexcept { return _instSignatureTable + _iSignatureIndex; } + ASMJIT_INLINE_NODEBUG const InstSignature* signatureEnd() const noexcept { return _instSignatureTable + _iSignatureIndex + _iSignatureCount; } + + //! Returns a control flow category of the instruction. + ASMJIT_INLINE_NODEBUG InstControlFlow controlFlow() const noexcept { return (InstControlFlow)_controlFlow; } + + //! Returns a hint that can be used when both inputs are the same register. + ASMJIT_INLINE_NODEBUG InstSameRegHint sameRegHint() const noexcept { return (InstSameRegHint)_sameRegHint; } + + //! \} +}; + +ASMJIT_VARAPI const CommonInfo _commonInfoTable[]; + +//! Instruction information. +struct InstInfo { + //! Reserved for future use. + uint32_t _reserved : 14; + //! Index to \ref _commonInfoTable. + uint32_t _commonInfoIndex : 10; + //! Index to \ref _additionalInfoTable. + uint32_t _additionalInfoIndex : 8; + + //! Instruction encoding (internal encoding identifier used by \ref Assembler). + uint8_t _encoding; + //! Main opcode value (0..255). + uint8_t _mainOpcodeValue; + //! Index to \ref _mainOpcodeTable` that is combined with \ref _mainOpcodeValue to form the final opcode. + uint8_t _mainOpcodeIndex; + //! Index to \ref _altOpcodeTable that contains a full alternative opcode. + uint8_t _altOpcodeIndex; + + //! \name Accessors + //! \{ + + //! Returns common information, see \ref CommonInfo. + ASMJIT_INLINE_NODEBUG const CommonInfo& commonInfo() const noexcept { return _commonInfoTable[_commonInfoIndex]; } + + //! Returns instruction flags, see \ref InstFlags. + ASMJIT_INLINE_NODEBUG InstFlags flags() const noexcept { return commonInfo().flags(); } + //! Tests whether the instruction has flag `flag`, see \ref InstFlags. + ASMJIT_INLINE_NODEBUG bool hasFlag(InstFlags flag) const noexcept { return commonInfo().hasFlag(flag); } + + //! Returns instruction AVX-512 flags, see \ref Avx512Flags. + ASMJIT_INLINE_NODEBUG Avx512Flags avx512Flags() const noexcept { return commonInfo().avx512Flags(); } + //! Tests whether the instruction has an AVX-512 `flag`, see \ref Avx512Flags. + ASMJIT_INLINE_NODEBUG bool hasAvx512Flag(Avx512Flags flag) const noexcept { return commonInfo().hasAvx512Flag(flag); } + + //! Tests whether the instruction is FPU instruction. + ASMJIT_INLINE_NODEBUG bool isFpu() const noexcept { return commonInfo().isFpu(); } + //! Tests whether the instruction is MMX/3DNOW instruction that accesses MMX registers (includes EMMS and FEMMS). + ASMJIT_INLINE_NODEBUG bool isMmx() const noexcept { return commonInfo().isMmx(); } + //! Tests whether the instruction is SSE|AVX|AVX512 instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isVec() const noexcept { return commonInfo().isVec(); } + //! Tests whether the instruction is SSE+ (SSE4.2, AES, SHA included) instruction that accesses XMM registers. + ASMJIT_INLINE_NODEBUG bool isSse() const noexcept { return commonInfo().isSse(); } + //! Tests whether the instruction is AVX+ (FMA included) instruction that accesses XMM|YMM|ZMM registers. + ASMJIT_INLINE_NODEBUG bool isAvx() const noexcept { return commonInfo().isAvx(); } + + //! Tests whether the instruction can be prefixed with LOCK prefix. + ASMJIT_INLINE_NODEBUG bool hasLockPrefix() const noexcept { return commonInfo().hasLockPrefix(); } + //! Tests whether the instruction can be prefixed with REP (REPE|REPZ) prefix. + ASMJIT_INLINE_NODEBUG bool hasRepPrefix() const noexcept { return commonInfo().hasRepPrefix(); } + //! Tests whether the instruction can be prefixed with XACQUIRE prefix. + ASMJIT_INLINE_NODEBUG bool hasXAcquirePrefix() const noexcept { return commonInfo().hasXAcquirePrefix(); } + //! Tests whether the instruction can be prefixed with XRELEASE prefix. + ASMJIT_INLINE_NODEBUG bool hasXReleasePrefix() const noexcept { return commonInfo().hasXReleasePrefix(); } + + //! Tests whether the rep prefix is supported by the instruction, but ignored (has no effect). + ASMJIT_INLINE_NODEBUG bool isRepIgnored() const noexcept { return commonInfo().isRepIgnored(); } + //! Tests whether the instruction uses MIB. + ASMJIT_INLINE_NODEBUG bool isMibOp() const noexcept { return hasFlag(InstFlags::kMib); } + //! Tests whether the instruction uses VSIB. + ASMJIT_INLINE_NODEBUG bool isVsibOp() const noexcept { return hasFlag(InstFlags::kVsib); } + //! Tests whether the instruction uses VEX (can be set together with EVEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isVex() const noexcept { return hasFlag(InstFlags::kVex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isEvex() const noexcept { return hasFlag(InstFlags::kEvex); } + //! Tests whether the instruction uses EVEX (can be set together with VEX if both are encodable). + ASMJIT_INLINE_NODEBUG bool isVexOrEvex() const noexcept { return hasFlag(InstFlags::kVex | InstFlags::kEvex); } + + ASMJIT_INLINE_NODEBUG bool isEvexCompatible() const noexcept { return hasFlag(InstFlags::kEvexCompat); } + ASMJIT_INLINE_NODEBUG bool isEvexKRegOnly() const noexcept { return hasFlag(InstFlags::kEvexKReg); } + ASMJIT_INLINE_NODEBUG bool isEvexTwoOpOnly() const noexcept { return hasFlag(InstFlags::kEvexTwoOp); } + ASMJIT_INLINE_NODEBUG bool isEvexTransformable() const noexcept { return hasFlag(InstFlags::kEvexTransformable); } + + //! Tests whether the instruction supports AVX512 masking {k}. + ASMJIT_INLINE_NODEBUG bool hasAvx512K() const noexcept { return hasAvx512Flag(Avx512Flags::kK); } + //! Tests whether the instruction supports AVX512 zeroing {k}{z}. + ASMJIT_INLINE_NODEBUG bool hasAvx512Z() const noexcept { return hasAvx512Flag(Avx512Flags::kZ); } + //! Tests whether the instruction supports AVX512 embedded-rounding {er}. + ASMJIT_INLINE_NODEBUG bool hasAvx512ER() const noexcept { return hasAvx512Flag(Avx512Flags::kER); } + //! Tests whether the instruction supports AVX512 suppress-all-exceptions {sae}. + ASMJIT_INLINE_NODEBUG bool hasAvx512SAE() const noexcept { return hasAvx512Flag(Avx512Flags::kSAE); } + //! Tests whether the instruction supports AVX512 broadcast (either 32-bit or 64-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B() const noexcept { return hasAvx512Flag(Avx512Flags::kB16 | Avx512Flags::kB32 | Avx512Flags::kB64); } + //! Tests whether the instruction supports AVX512 broadcast (16-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B16() const noexcept { return hasAvx512Flag(Avx512Flags::kB16); } + //! Tests whether the instruction supports AVX512 broadcast (32-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B32() const noexcept { return hasAvx512Flag(Avx512Flags::kB32); } + //! Tests whether the instruction supports AVX512 broadcast (64-bit). + ASMJIT_INLINE_NODEBUG bool hasAvx512B64() const noexcept { return hasAvx512Flag(Avx512Flags::kB64); } + + //! Returns a control flow category of the instruction. + ASMJIT_INLINE_NODEBUG InstControlFlow controlFlow() const noexcept { return commonInfo().controlFlow(); } + //! Returns a hint that can be used when both inputs are the same register. + ASMJIT_INLINE_NODEBUG InstSameRegHint sameRegHint() const noexcept { return commonInfo().sameRegHint(); } + + ASMJIT_INLINE_NODEBUG uint32_t signatureIndex() const noexcept { return commonInfo().signatureIndex(); } + ASMJIT_INLINE_NODEBUG uint32_t signatureCount() const noexcept { return commonInfo().signatureCount(); } + + ASMJIT_INLINE_NODEBUG const InstSignature* signatureData() const noexcept { return commonInfo().signatureData(); } + ASMJIT_INLINE_NODEBUG const InstSignature* signatureEnd() const noexcept { return commonInfo().signatureEnd(); } + + //! \} +}; + +ASMJIT_VARAPI const InstInfo _instInfoTable[]; + +static inline const InstInfo& infoById(InstId instId) noexcept { + ASMJIT_ASSERT(Inst::isDefinedId(instId)); + return _instInfoTable[instId]; +} + +//! \cond INTERNAL +static_assert(sizeof(OpSignature) == 8, "InstDB::OpSignature must be 8 bytes long"); +//! \endcond + +} // {InstDB} + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +#endif // ASMJIT_X86_X86INSTDB_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86operand.h b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86operand.h new file mode 100644 index 0000000000000000000000000000000000000000..94ba4115846001fdc2804db51385dbac29ed550c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/asmjit/x86/x86operand.h @@ -0,0 +1,1145 @@ +// This file is part of AsmJit project +// +// See asmjit.h or LICENSE.md for license and copyright information +// SPDX-License-Identifier: Zlib + +#ifndef ASMJIT_X86_X86OPERAND_H_INCLUDED +#define ASMJIT_X86_X86OPERAND_H_INCLUDED + +#include "../core/archtraits.h" +#include "../core/operand.h" +#include "../core/type.h" +#include "../x86/x86globals.h" + +ASMJIT_BEGIN_SUB_NAMESPACE(x86) + +//! \addtogroup asmjit_x86 +//! \{ + +class Reg; +class Mem; + +class Gp; +class Gpb; +class GpbLo; +class GpbHi; +class Gpw; +class Gpd; +class Gpq; +class Vec; +class Xmm; +class Ymm; +class Zmm; +class Mm; +class KReg; +class SReg; +class CReg; +class DReg; +class St; +class Bnd; +class Tmm; +class Rip; + +//! Register traits (X86). +//! +//! Register traits contains information about a particular register type. It's used by asmjit to setup register +//! information on-the-fly and to populate tables that contain register information (this way it's possible to change +//! register types and groups without having to reorder these tables). +template +struct RegTraits : public BaseRegTraits {}; + +//! \cond +// <--------------------+------------------------+------------------------+---+------------------+ +// | Reg-Type | Reg-Group |Sz | TypeId | +// <--------------------+------------------------+------------------------+---+------------------+ +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Rip , RegGroup::kX86_Rip , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_GpbLo , RegGroup::kGp , 1 , TypeId::kInt8 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_GpbHi , RegGroup::kGp , 1 , TypeId::kInt8 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Gpw , RegGroup::kGp , 2 , TypeId::kInt16 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Gpd , RegGroup::kGp , 4 , TypeId::kInt32 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Gpq , RegGroup::kGp , 8 , TypeId::kInt64 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Xmm , RegGroup::kVec , 16, TypeId::kInt32x4 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Ymm , RegGroup::kVec , 32, TypeId::kInt32x8 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Zmm , RegGroup::kVec , 64, TypeId::kInt32x16); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_KReg , RegGroup::kX86_K , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Mm , RegGroup::kX86_MM , 8 , TypeId::kMmx64 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_SReg , RegGroup::kX86_SReg , 2 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_CReg , RegGroup::kX86_CReg , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_DReg , RegGroup::kX86_DReg , 0 , TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_St , RegGroup::kX86_St , 10, TypeId::kFloat80 ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Bnd , RegGroup::kX86_Bnd , 16, TypeId::kVoid ); +ASMJIT_DEFINE_REG_TRAITS(RegType::kX86_Tmm , RegGroup::kX86_Tmm , 0 , TypeId::kVoid ); +//! \endcond + +//! Register (X86). +class Reg : public BaseReg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Reg, BaseReg) + + //! Tests whether the register is a GPB register (8-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpb() const noexcept { return size() == 1; } + //! Tests whether the register is a low GPB register (8-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpbLo() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a high GPB register (8-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpbHi() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a GPW register (16-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpw() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a GPD register (32-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpd() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a GPQ register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isGpq() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! Tests whether the register is a 32-bit general purpose register, alias of \ref isGpd(). + ASMJIT_INLINE_NODEBUG constexpr bool isGp32() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a 64-bit general purpose register, alias of \ref isGpq() + ASMJIT_INLINE_NODEBUG constexpr bool isGp64() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! Tests whether the register is an XMM register (128-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isXmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a YMM register (256-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isYmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a ZMM register (512-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isZmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! Tests whether the register is a 128-bit vector register, alias of \ref isXmm(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec128() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a 256-bit vector register, alias of \ref isYmm(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec256() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a 512-bit vector register, alias of \ref isZmm(). + ASMJIT_INLINE_NODEBUG constexpr bool isVec512() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + //! Tests whether the register is an MMX register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isMm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a K register (64-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isKReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a segment register. + ASMJIT_INLINE_NODEBUG constexpr bool isSReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a control register. + ASMJIT_INLINE_NODEBUG constexpr bool isCReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a debug register. + ASMJIT_INLINE_NODEBUG constexpr bool isDReg() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is an FPU register (80-bit). + ASMJIT_INLINE_NODEBUG constexpr bool isSt() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a bound register. + ASMJIT_INLINE_NODEBUG constexpr bool isBnd() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is a TMM register. + ASMJIT_INLINE_NODEBUG constexpr bool isTmm() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + //! Tests whether the register is RIP. + ASMJIT_INLINE_NODEBUG constexpr bool isRip() const noexcept { return hasBaseSignature(RegTraits::kSignature); } + + template + ASMJIT_INLINE_NODEBUG void setRegT(uint32_t rId) noexcept { + setSignature(OperandSignature{RegTraits::kSignature}); + setId(rId); + } + + ASMJIT_INLINE_NODEBUG void setTypeAndId(RegType type, uint32_t id) noexcept { + setSignature(signatureOf(type)); + setId(id); + } + + static ASMJIT_INLINE_NODEBUG RegGroup groupOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kX86).regTypeToGroup(type); } + static ASMJIT_INLINE_NODEBUG TypeId typeIdOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kX86).regTypeToTypeId(type); } + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOf(RegType type) noexcept { return ArchTraits::byArch(Arch::kX86).regTypeToSignature(type); } + + template + static ASMJIT_INLINE_NODEBUG RegGroup groupOfT() noexcept { return RegGroup(RegTraits::kGroup); } + + template + static ASMJIT_INLINE_NODEBUG TypeId typeIdOfT() noexcept { return TypeId(RegTraits::kTypeId); } + + template + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfT() noexcept { return OperandSignature{RegTraits::kSignature}; } + + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfVecByType(TypeId typeId) noexcept { + return OperandSignature{typeId <= TypeId::_kVec128End ? uint32_t(RegTraits::kSignature) : + typeId <= TypeId::_kVec256End ? uint32_t(RegTraits::kSignature) : + uint32_t(RegTraits::kSignature)}; + } + + static ASMJIT_INLINE_NODEBUG OperandSignature signatureOfVecBySize(uint32_t size) noexcept { + return OperandSignature{size <= 16 ? uint32_t(RegTraits::kSignature) : + size <= 32 ? uint32_t(RegTraits::kSignature) : + uint32_t(RegTraits::kSignature)}; + } + + //! Tests whether the `op` operand is either a low or high 8-bit GPB register. + static ASMJIT_INLINE_NODEBUG bool isGpb(const Operand_& op) noexcept { + // Check operand type, register group, and size. Not interested in register type. + return op.signature().subset(Signature::kOpTypeMask | Signature::kRegGroupMask | Signature::kSizeMask) == + (Signature::fromOpType(OperandType::kReg) | Signature::fromRegGroup(RegGroup::kGp) | Signature::fromSize(1)); + } + + static ASMJIT_INLINE_NODEBUG bool isGpbLo(const Operand_& op) noexcept { return op.as().isGpbLo(); } + static ASMJIT_INLINE_NODEBUG bool isGpbHi(const Operand_& op) noexcept { return op.as().isGpbHi(); } + static ASMJIT_INLINE_NODEBUG bool isGpw(const Operand_& op) noexcept { return op.as().isGpw(); } + static ASMJIT_INLINE_NODEBUG bool isGpd(const Operand_& op) noexcept { return op.as().isGpd(); } + static ASMJIT_INLINE_NODEBUG bool isGpq(const Operand_& op) noexcept { return op.as().isGpq(); } + static ASMJIT_INLINE_NODEBUG bool isXmm(const Operand_& op) noexcept { return op.as().isXmm(); } + static ASMJIT_INLINE_NODEBUG bool isYmm(const Operand_& op) noexcept { return op.as().isYmm(); } + static ASMJIT_INLINE_NODEBUG bool isZmm(const Operand_& op) noexcept { return op.as().isZmm(); } + static ASMJIT_INLINE_NODEBUG bool isMm(const Operand_& op) noexcept { return op.as().isMm(); } + static ASMJIT_INLINE_NODEBUG bool isKReg(const Operand_& op) noexcept { return op.as().isKReg(); } + static ASMJIT_INLINE_NODEBUG bool isSReg(const Operand_& op) noexcept { return op.as().isSReg(); } + static ASMJIT_INLINE_NODEBUG bool isCReg(const Operand_& op) noexcept { return op.as().isCReg(); } + static ASMJIT_INLINE_NODEBUG bool isDReg(const Operand_& op) noexcept { return op.as().isDReg(); } + static ASMJIT_INLINE_NODEBUG bool isSt(const Operand_& op) noexcept { return op.as().isSt(); } + static ASMJIT_INLINE_NODEBUG bool isBnd(const Operand_& op) noexcept { return op.as().isBnd(); } + static ASMJIT_INLINE_NODEBUG bool isTmm(const Operand_& op) noexcept { return op.as().isTmm(); } + static ASMJIT_INLINE_NODEBUG bool isRip(const Operand_& op) noexcept { return op.as().isRip(); } + + static ASMJIT_INLINE_NODEBUG bool isGpb(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpb(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpbLo(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpbLo(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpbHi(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpbHi(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpw(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpw(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpd(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpd(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isGpq(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isGpq(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isXmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isXmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isYmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isYmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isZmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isZmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isMm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isMm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isKReg(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isKReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isSReg(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isSReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isCReg(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isCReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isDReg(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isDReg(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isSt(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isSt(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isBnd(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isBnd(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isTmm(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isTmm(op)) & unsigned(op.id() == rId)); } + static ASMJIT_INLINE_NODEBUG bool isRip(const Operand_& op, uint32_t rId) noexcept { return bool(unsigned(isRip(op)) & unsigned(op.id() == rId)); } +}; + +//! General purpose register (X86). +class Gp : public Reg { +public: + ASMJIT_DEFINE_ABSTRACT_REG(Gp, Reg) + + //! Physical id (X86). + //! + //! \note Register indexes have been reduced to only support general purpose registers. There is no need to + //! have enumerations with number suffix that expands to the exactly same value as the suffix value itself. + enum Id : uint32_t { + kIdAx = 0, //!< Physical id of AL|AH|AX|EAX|RAX registers. + kIdCx = 1, //!< Physical id of CL|CH|CX|ECX|RCX registers. + kIdDx = 2, //!< Physical id of DL|DH|DX|EDX|RDX registers. + kIdBx = 3, //!< Physical id of BL|BH|BX|EBX|RBX registers. + kIdSp = 4, //!< Physical id of SPL|SP|ESP|RSP registers. + kIdBp = 5, //!< Physical id of BPL|BP|EBP|RBP registers. + kIdSi = 6, //!< Physical id of SIL|SI|ESI|RSI registers. + kIdDi = 7, //!< Physical id of DIL|DI|EDI|RDI registers. + kIdR8 = 8, //!< Physical id of R8B|R8W|R8D|R8 registers (64-bit only). + kIdR9 = 9, //!< Physical id of R9B|R9W|R9D|R9 registers (64-bit only). + kIdR10 = 10, //!< Physical id of R10B|R10W|R10D|R10 registers (64-bit only). + kIdR11 = 11, //!< Physical id of R11B|R11W|R11D|R11 registers (64-bit only). + kIdR12 = 12, //!< Physical id of R12B|R12W|R12D|R12 registers (64-bit only). + kIdR13 = 13, //!< Physical id of R13B|R13W|R13D|R13 registers (64-bit only). + kIdR14 = 14, //!< Physical id of R14B|R14W|R14D|R14 registers (64-bit only). + kIdR15 = 15 //!< Physical id of R15B|R15W|R15D|R15 registers (64-bit only). + }; + + //! Casts this register to 8-bit (LO) part. + ASMJIT_INLINE_NODEBUG GpbLo r8() const noexcept; + //! Casts this register to 8-bit (LO) part. + ASMJIT_INLINE_NODEBUG GpbLo r8Lo() const noexcept; + //! Casts this register to 8-bit (HI) part. + ASMJIT_INLINE_NODEBUG GpbHi r8Hi() const noexcept; + //! Casts this register to 16-bit. + ASMJIT_INLINE_NODEBUG Gpw r16() const noexcept; + //! Casts this register to 32-bit. + ASMJIT_INLINE_NODEBUG Gpd r32() const noexcept; + //! Casts this register to 64-bit. + ASMJIT_INLINE_NODEBUG Gpq r64() const noexcept; +}; + +//! Vector register (XMM|YMM|ZMM) (X86). +class Vec : public Reg { + ASMJIT_DEFINE_ABSTRACT_REG(Vec, Reg) + + //! Casts this register to XMM (clone). + ASMJIT_INLINE_NODEBUG Xmm xmm() const noexcept; + //! Casts this register to YMM (clone). + ASMJIT_INLINE_NODEBUG Ymm ymm() const noexcept; + //! Casts this register to ZMM (clone). + ASMJIT_INLINE_NODEBUG Zmm zmm() const noexcept; + + //! Casts this register to XMM (clone). + ASMJIT_INLINE_NODEBUG Vec v128() const noexcept; + //! Casts this register to YMM (clone). + ASMJIT_INLINE_NODEBUG Vec v256() const noexcept; + //! Casts this register to ZMM (clone). + ASMJIT_INLINE_NODEBUG Vec v512() const noexcept; + + //! Casts this register to a register that has half the size (or XMM if it's already XMM). + ASMJIT_INLINE_NODEBUG Vec half() const noexcept { + return Vec(type() == RegType::kX86_Zmm ? signatureOfT() : signatureOfT(), id()); + } +}; + +//! Segment register (X86). +class SReg : public Reg { + ASMJIT_DEFINE_FINAL_REG(SReg, Reg, RegTraits) + + //! X86 segment id. + enum Id : uint32_t { + //! No segment (default). + kIdNone = 0, + //! ES segment. + kIdEs = 1, + //! CS segment. + kIdCs = 2, + //! SS segment. + kIdSs = 3, + //! DS segment. + kIdDs = 4, + //! FS segment. + kIdFs = 5, + //! GS segment. + kIdGs = 6, + + //! Count of X86 segment registers supported by AsmJit. + //! + //! \note X86 architecture has 6 segment registers - ES, CS, SS, DS, FS, GS. X64 architecture lowers them down to + //! just FS and GS. AsmJit supports 7 segment registers - all addressable in both X86 and X64 modes and one extra + //! called `SReg::kIdNone`, which is AsmJit specific and means that there is no segment register specified. + kIdCount = 7 + }; +}; + +//! GPB low or high register (X86). +class Gpb : public Gp { ASMJIT_DEFINE_ABSTRACT_REG(Gpb, Gp) }; +//! GPB low register (X86). +class GpbLo : public Gpb { ASMJIT_DEFINE_FINAL_REG(GpbLo, Gpb, RegTraits) }; +//! GPB high register (X86). +class GpbHi : public Gpb { ASMJIT_DEFINE_FINAL_REG(GpbHi, Gpb, RegTraits) }; +//! GPW register (X86). +class Gpw : public Gp { ASMJIT_DEFINE_FINAL_REG(Gpw, Gp, RegTraits) }; +//! GPD register (X86). +class Gpd : public Gp { ASMJIT_DEFINE_FINAL_REG(Gpd, Gp, RegTraits) }; +//! GPQ register (X86_64). +class Gpq : public Gp { ASMJIT_DEFINE_FINAL_REG(Gpq, Gp, RegTraits) }; + +//! 128-bit XMM register (SSE+). +class Xmm : public Vec { + ASMJIT_DEFINE_FINAL_REG(Xmm, Vec, RegTraits) + //! Casts this register to a register that has half the size (XMM). + ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { return Xmm(id()); } +}; + +//! 256-bit YMM register (AVX+). +class Ymm : public Vec { + ASMJIT_DEFINE_FINAL_REG(Ymm, Vec, RegTraits) + //! Casts this register to a register that has half the size (XMM). + ASMJIT_INLINE_NODEBUG Xmm half() const noexcept { return Xmm(id()); } +}; + +//! 512-bit ZMM register (AVX512+). +class Zmm : public Vec { + ASMJIT_DEFINE_FINAL_REG(Zmm, Vec, RegTraits) + //! Casts this register to a register that has half the size (YMM). + ASMJIT_INLINE_NODEBUG Ymm half() const noexcept { return Ymm(id()); } +}; + +//! 64-bit MMX register (MMX+). +class Mm : public Reg { ASMJIT_DEFINE_FINAL_REG(Mm, Reg, RegTraits) }; +//! 64-bit K register (AVX512+). +class KReg : public Reg { ASMJIT_DEFINE_FINAL_REG(KReg, Reg, RegTraits) }; +//! 32-bit or 64-bit control register (X86). +class CReg : public Reg { ASMJIT_DEFINE_FINAL_REG(CReg, Reg, RegTraits) }; +//! 32-bit or 64-bit debug register (X86). +class DReg : public Reg { ASMJIT_DEFINE_FINAL_REG(DReg, Reg, RegTraits) }; +//! 80-bit FPU register (X86). +class St : public Reg { ASMJIT_DEFINE_FINAL_REG(St, Reg, RegTraits) }; +//! 128-bit BND register (BND+). +class Bnd : public Reg { ASMJIT_DEFINE_FINAL_REG(Bnd, Reg, RegTraits) }; +//! 8192-bit TMM register (AMX). +class Tmm : public Reg { ASMJIT_DEFINE_FINAL_REG(Tmm, Reg, RegTraits) }; +//! RIP register (X86). +class Rip : public Reg { ASMJIT_DEFINE_FINAL_REG(Rip, Reg, RegTraits) }; + +//! \cond +ASMJIT_INLINE_NODEBUG GpbLo Gp::r8() const noexcept { return GpbLo(id()); } +ASMJIT_INLINE_NODEBUG GpbLo Gp::r8Lo() const noexcept { return GpbLo(id()); } +ASMJIT_INLINE_NODEBUG GpbHi Gp::r8Hi() const noexcept { return GpbHi(id()); } +ASMJIT_INLINE_NODEBUG Gpw Gp::r16() const noexcept { return Gpw(id()); } +ASMJIT_INLINE_NODEBUG Gpd Gp::r32() const noexcept { return Gpd(id()); } +ASMJIT_INLINE_NODEBUG Gpq Gp::r64() const noexcept { return Gpq(id()); } +ASMJIT_INLINE_NODEBUG Xmm Vec::xmm() const noexcept { return Xmm(id()); } +ASMJIT_INLINE_NODEBUG Ymm Vec::ymm() const noexcept { return Ymm(id()); } +ASMJIT_INLINE_NODEBUG Zmm Vec::zmm() const noexcept { return Zmm(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v128() const noexcept { return Xmm(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v256() const noexcept { return Ymm(id()); } +ASMJIT_INLINE_NODEBUG Vec Vec::v512() const noexcept { return Zmm(id()); } +//! \endcond + +//! \namespace asmjit::x86::regs +//! +//! Registers provided by X86 and X64 ISAs are in both `asmjit::x86` and `asmjit::x86::regs` namespaces so they can +//! be included with using directive. For example `using namespace asmjit::x86::regs` would include all registers, +//! but not other X86-specific API, whereas `using namespace asmjit::x86` would include everything X86-specific. +#ifndef _DOXYGEN +namespace regs { +#endif + +//! Creates an 8-bit low GPB register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpbLo gpb(uint32_t rId) noexcept { return GpbLo(rId); } +//! Creates an 8-bit low GPB register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpbLo gpb_lo(uint32_t rId) noexcept { return GpbLo(rId); } +//! Creates an 8-bit high GPB register operand. +static ASMJIT_INLINE_NODEBUG constexpr GpbHi gpb_hi(uint32_t rId) noexcept { return GpbHi(rId); } +//! Creates a 16-bit GPW register operand. +static ASMJIT_INLINE_NODEBUG constexpr Gpw gpw(uint32_t rId) noexcept { return Gpw(rId); } +//! Creates a 32-bit GPD register operand. +static ASMJIT_INLINE_NODEBUG constexpr Gpd gpd(uint32_t rId) noexcept { return Gpd(rId); } +//! Creates a 64-bit GPQ register operand (64-bit). +static ASMJIT_INLINE_NODEBUG constexpr Gpq gpq(uint32_t rId) noexcept { return Gpq(rId); } +//! Creates a 128-bit XMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Xmm xmm(uint32_t rId) noexcept { return Xmm(rId); } +//! Creates a 256-bit YMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Ymm ymm(uint32_t rId) noexcept { return Ymm(rId); } +//! Creates a 512-bit ZMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Zmm zmm(uint32_t rId) noexcept { return Zmm(rId); } +//! Creates a 64-bit Mm register operand. +static ASMJIT_INLINE_NODEBUG constexpr Mm mm(uint32_t rId) noexcept { return Mm(rId); } +//! Creates a 64-bit K register operand. +static ASMJIT_INLINE_NODEBUG constexpr KReg k(uint32_t rId) noexcept { return KReg(rId); } +//! Creates a 32-bit or 64-bit control register operand. +static ASMJIT_INLINE_NODEBUG constexpr CReg cr(uint32_t rId) noexcept { return CReg(rId); } +//! Creates a 32-bit or 64-bit debug register operand. +static ASMJIT_INLINE_NODEBUG constexpr DReg dr(uint32_t rId) noexcept { return DReg(rId); } +//! Creates an 80-bit st register operand. +static ASMJIT_INLINE_NODEBUG constexpr St st(uint32_t rId) noexcept { return St(rId); } +//! Creates a 128-bit bound register operand. +static ASMJIT_INLINE_NODEBUG constexpr Bnd bnd(uint32_t rId) noexcept { return Bnd(rId); } +//! Creates a TMM register operand. +static ASMJIT_INLINE_NODEBUG constexpr Tmm tmm(uint32_t rId) noexcept { return Tmm(rId); } + +static constexpr GpbLo al = GpbLo(Gp::kIdAx); +static constexpr GpbLo bl = GpbLo(Gp::kIdBx); +static constexpr GpbLo cl = GpbLo(Gp::kIdCx); +static constexpr GpbLo dl = GpbLo(Gp::kIdDx); +static constexpr GpbLo spl = GpbLo(Gp::kIdSp); +static constexpr GpbLo bpl = GpbLo(Gp::kIdBp); +static constexpr GpbLo sil = GpbLo(Gp::kIdSi); +static constexpr GpbLo dil = GpbLo(Gp::kIdDi); +static constexpr GpbLo r8b = GpbLo(Gp::kIdR8); +static constexpr GpbLo r9b = GpbLo(Gp::kIdR9); +static constexpr GpbLo r10b = GpbLo(Gp::kIdR10); +static constexpr GpbLo r11b = GpbLo(Gp::kIdR11); +static constexpr GpbLo r12b = GpbLo(Gp::kIdR12); +static constexpr GpbLo r13b = GpbLo(Gp::kIdR13); +static constexpr GpbLo r14b = GpbLo(Gp::kIdR14); +static constexpr GpbLo r15b = GpbLo(Gp::kIdR15); + +static constexpr GpbHi ah = GpbHi(Gp::kIdAx); +static constexpr GpbHi bh = GpbHi(Gp::kIdBx); +static constexpr GpbHi ch = GpbHi(Gp::kIdCx); +static constexpr GpbHi dh = GpbHi(Gp::kIdDx); + +static constexpr Gpw ax = Gpw(Gp::kIdAx); +static constexpr Gpw bx = Gpw(Gp::kIdBx); +static constexpr Gpw cx = Gpw(Gp::kIdCx); +static constexpr Gpw dx = Gpw(Gp::kIdDx); +static constexpr Gpw sp = Gpw(Gp::kIdSp); +static constexpr Gpw bp = Gpw(Gp::kIdBp); +static constexpr Gpw si = Gpw(Gp::kIdSi); +static constexpr Gpw di = Gpw(Gp::kIdDi); +static constexpr Gpw r8w = Gpw(Gp::kIdR8); +static constexpr Gpw r9w = Gpw(Gp::kIdR9); +static constexpr Gpw r10w = Gpw(Gp::kIdR10); +static constexpr Gpw r11w = Gpw(Gp::kIdR11); +static constexpr Gpw r12w = Gpw(Gp::kIdR12); +static constexpr Gpw r13w = Gpw(Gp::kIdR13); +static constexpr Gpw r14w = Gpw(Gp::kIdR14); +static constexpr Gpw r15w = Gpw(Gp::kIdR15); + +static constexpr Gpd eax = Gpd(Gp::kIdAx); +static constexpr Gpd ebx = Gpd(Gp::kIdBx); +static constexpr Gpd ecx = Gpd(Gp::kIdCx); +static constexpr Gpd edx = Gpd(Gp::kIdDx); +static constexpr Gpd esp = Gpd(Gp::kIdSp); +static constexpr Gpd ebp = Gpd(Gp::kIdBp); +static constexpr Gpd esi = Gpd(Gp::kIdSi); +static constexpr Gpd edi = Gpd(Gp::kIdDi); +static constexpr Gpd r8d = Gpd(Gp::kIdR8); +static constexpr Gpd r9d = Gpd(Gp::kIdR9); +static constexpr Gpd r10d = Gpd(Gp::kIdR10); +static constexpr Gpd r11d = Gpd(Gp::kIdR11); +static constexpr Gpd r12d = Gpd(Gp::kIdR12); +static constexpr Gpd r13d = Gpd(Gp::kIdR13); +static constexpr Gpd r14d = Gpd(Gp::kIdR14); +static constexpr Gpd r15d = Gpd(Gp::kIdR15); + +static constexpr Gpq rax = Gpq(Gp::kIdAx); +static constexpr Gpq rbx = Gpq(Gp::kIdBx); +static constexpr Gpq rcx = Gpq(Gp::kIdCx); +static constexpr Gpq rdx = Gpq(Gp::kIdDx); +static constexpr Gpq rsp = Gpq(Gp::kIdSp); +static constexpr Gpq rbp = Gpq(Gp::kIdBp); +static constexpr Gpq rsi = Gpq(Gp::kIdSi); +static constexpr Gpq rdi = Gpq(Gp::kIdDi); +static constexpr Gpq r8 = Gpq(Gp::kIdR8); +static constexpr Gpq r9 = Gpq(Gp::kIdR9); +static constexpr Gpq r10 = Gpq(Gp::kIdR10); +static constexpr Gpq r11 = Gpq(Gp::kIdR11); +static constexpr Gpq r12 = Gpq(Gp::kIdR12); +static constexpr Gpq r13 = Gpq(Gp::kIdR13); +static constexpr Gpq r14 = Gpq(Gp::kIdR14); +static constexpr Gpq r15 = Gpq(Gp::kIdR15); + +static constexpr Xmm xmm0 = Xmm(0); +static constexpr Xmm xmm1 = Xmm(1); +static constexpr Xmm xmm2 = Xmm(2); +static constexpr Xmm xmm3 = Xmm(3); +static constexpr Xmm xmm4 = Xmm(4); +static constexpr Xmm xmm5 = Xmm(5); +static constexpr Xmm xmm6 = Xmm(6); +static constexpr Xmm xmm7 = Xmm(7); +static constexpr Xmm xmm8 = Xmm(8); +static constexpr Xmm xmm9 = Xmm(9); +static constexpr Xmm xmm10 = Xmm(10); +static constexpr Xmm xmm11 = Xmm(11); +static constexpr Xmm xmm12 = Xmm(12); +static constexpr Xmm xmm13 = Xmm(13); +static constexpr Xmm xmm14 = Xmm(14); +static constexpr Xmm xmm15 = Xmm(15); +static constexpr Xmm xmm16 = Xmm(16); +static constexpr Xmm xmm17 = Xmm(17); +static constexpr Xmm xmm18 = Xmm(18); +static constexpr Xmm xmm19 = Xmm(19); +static constexpr Xmm xmm20 = Xmm(20); +static constexpr Xmm xmm21 = Xmm(21); +static constexpr Xmm xmm22 = Xmm(22); +static constexpr Xmm xmm23 = Xmm(23); +static constexpr Xmm xmm24 = Xmm(24); +static constexpr Xmm xmm25 = Xmm(25); +static constexpr Xmm xmm26 = Xmm(26); +static constexpr Xmm xmm27 = Xmm(27); +static constexpr Xmm xmm28 = Xmm(28); +static constexpr Xmm xmm29 = Xmm(29); +static constexpr Xmm xmm30 = Xmm(30); +static constexpr Xmm xmm31 = Xmm(31); + +static constexpr Ymm ymm0 = Ymm(0); +static constexpr Ymm ymm1 = Ymm(1); +static constexpr Ymm ymm2 = Ymm(2); +static constexpr Ymm ymm3 = Ymm(3); +static constexpr Ymm ymm4 = Ymm(4); +static constexpr Ymm ymm5 = Ymm(5); +static constexpr Ymm ymm6 = Ymm(6); +static constexpr Ymm ymm7 = Ymm(7); +static constexpr Ymm ymm8 = Ymm(8); +static constexpr Ymm ymm9 = Ymm(9); +static constexpr Ymm ymm10 = Ymm(10); +static constexpr Ymm ymm11 = Ymm(11); +static constexpr Ymm ymm12 = Ymm(12); +static constexpr Ymm ymm13 = Ymm(13); +static constexpr Ymm ymm14 = Ymm(14); +static constexpr Ymm ymm15 = Ymm(15); +static constexpr Ymm ymm16 = Ymm(16); +static constexpr Ymm ymm17 = Ymm(17); +static constexpr Ymm ymm18 = Ymm(18); +static constexpr Ymm ymm19 = Ymm(19); +static constexpr Ymm ymm20 = Ymm(20); +static constexpr Ymm ymm21 = Ymm(21); +static constexpr Ymm ymm22 = Ymm(22); +static constexpr Ymm ymm23 = Ymm(23); +static constexpr Ymm ymm24 = Ymm(24); +static constexpr Ymm ymm25 = Ymm(25); +static constexpr Ymm ymm26 = Ymm(26); +static constexpr Ymm ymm27 = Ymm(27); +static constexpr Ymm ymm28 = Ymm(28); +static constexpr Ymm ymm29 = Ymm(29); +static constexpr Ymm ymm30 = Ymm(30); +static constexpr Ymm ymm31 = Ymm(31); + +static constexpr Zmm zmm0 = Zmm(0); +static constexpr Zmm zmm1 = Zmm(1); +static constexpr Zmm zmm2 = Zmm(2); +static constexpr Zmm zmm3 = Zmm(3); +static constexpr Zmm zmm4 = Zmm(4); +static constexpr Zmm zmm5 = Zmm(5); +static constexpr Zmm zmm6 = Zmm(6); +static constexpr Zmm zmm7 = Zmm(7); +static constexpr Zmm zmm8 = Zmm(8); +static constexpr Zmm zmm9 = Zmm(9); +static constexpr Zmm zmm10 = Zmm(10); +static constexpr Zmm zmm11 = Zmm(11); +static constexpr Zmm zmm12 = Zmm(12); +static constexpr Zmm zmm13 = Zmm(13); +static constexpr Zmm zmm14 = Zmm(14); +static constexpr Zmm zmm15 = Zmm(15); +static constexpr Zmm zmm16 = Zmm(16); +static constexpr Zmm zmm17 = Zmm(17); +static constexpr Zmm zmm18 = Zmm(18); +static constexpr Zmm zmm19 = Zmm(19); +static constexpr Zmm zmm20 = Zmm(20); +static constexpr Zmm zmm21 = Zmm(21); +static constexpr Zmm zmm22 = Zmm(22); +static constexpr Zmm zmm23 = Zmm(23); +static constexpr Zmm zmm24 = Zmm(24); +static constexpr Zmm zmm25 = Zmm(25); +static constexpr Zmm zmm26 = Zmm(26); +static constexpr Zmm zmm27 = Zmm(27); +static constexpr Zmm zmm28 = Zmm(28); +static constexpr Zmm zmm29 = Zmm(29); +static constexpr Zmm zmm30 = Zmm(30); +static constexpr Zmm zmm31 = Zmm(31); + +static constexpr Mm mm0 = Mm(0); +static constexpr Mm mm1 = Mm(1); +static constexpr Mm mm2 = Mm(2); +static constexpr Mm mm3 = Mm(3); +static constexpr Mm mm4 = Mm(4); +static constexpr Mm mm5 = Mm(5); +static constexpr Mm mm6 = Mm(6); +static constexpr Mm mm7 = Mm(7); + +static constexpr KReg k0 = KReg(0); +static constexpr KReg k1 = KReg(1); +static constexpr KReg k2 = KReg(2); +static constexpr KReg k3 = KReg(3); +static constexpr KReg k4 = KReg(4); +static constexpr KReg k5 = KReg(5); +static constexpr KReg k6 = KReg(6); +static constexpr KReg k7 = KReg(7); + +static constexpr SReg no_seg = SReg(SReg::kIdNone); +static constexpr SReg es = SReg(SReg::kIdEs); +static constexpr SReg cs = SReg(SReg::kIdCs); +static constexpr SReg ss = SReg(SReg::kIdSs); +static constexpr SReg ds = SReg(SReg::kIdDs); +static constexpr SReg fs = SReg(SReg::kIdFs); +static constexpr SReg gs = SReg(SReg::kIdGs); + +static constexpr CReg cr0 = CReg(0); +static constexpr CReg cr1 = CReg(1); +static constexpr CReg cr2 = CReg(2); +static constexpr CReg cr3 = CReg(3); +static constexpr CReg cr4 = CReg(4); +static constexpr CReg cr5 = CReg(5); +static constexpr CReg cr6 = CReg(6); +static constexpr CReg cr7 = CReg(7); +static constexpr CReg cr8 = CReg(8); +static constexpr CReg cr9 = CReg(9); +static constexpr CReg cr10 = CReg(10); +static constexpr CReg cr11 = CReg(11); +static constexpr CReg cr12 = CReg(12); +static constexpr CReg cr13 = CReg(13); +static constexpr CReg cr14 = CReg(14); +static constexpr CReg cr15 = CReg(15); + +static constexpr DReg dr0 = DReg(0); +static constexpr DReg dr1 = DReg(1); +static constexpr DReg dr2 = DReg(2); +static constexpr DReg dr3 = DReg(3); +static constexpr DReg dr4 = DReg(4); +static constexpr DReg dr5 = DReg(5); +static constexpr DReg dr6 = DReg(6); +static constexpr DReg dr7 = DReg(7); +static constexpr DReg dr8 = DReg(8); +static constexpr DReg dr9 = DReg(9); +static constexpr DReg dr10 = DReg(10); +static constexpr DReg dr11 = DReg(11); +static constexpr DReg dr12 = DReg(12); +static constexpr DReg dr13 = DReg(13); +static constexpr DReg dr14 = DReg(14); +static constexpr DReg dr15 = DReg(15); + +static constexpr St st0 = St(0); +static constexpr St st1 = St(1); +static constexpr St st2 = St(2); +static constexpr St st3 = St(3); +static constexpr St st4 = St(4); +static constexpr St st5 = St(5); +static constexpr St st6 = St(6); +static constexpr St st7 = St(7); + +static constexpr Bnd bnd0 = Bnd(0); +static constexpr Bnd bnd1 = Bnd(1); +static constexpr Bnd bnd2 = Bnd(2); +static constexpr Bnd bnd3 = Bnd(3); + +static constexpr Tmm tmm0 = Tmm(0); +static constexpr Tmm tmm1 = Tmm(1); +static constexpr Tmm tmm2 = Tmm(2); +static constexpr Tmm tmm3 = Tmm(3); +static constexpr Tmm tmm4 = Tmm(4); +static constexpr Tmm tmm5 = Tmm(5); +static constexpr Tmm tmm6 = Tmm(6); +static constexpr Tmm tmm7 = Tmm(7); + +static constexpr Rip rip = Rip(0); + +#ifndef _DOXYGEN +} // {regs} + +// Make `x86::regs` accessible through `x86` namespace as well. +using namespace regs; +#endif + +//! Memory operand specific to X86 and X86_64 architecture. +class Mem : public BaseMem { +public: + //! \name Constants + //! \{ + + //! Additional bits of operand's signature used by `x86::Mem`. + enum AdditionalBits : uint32_t { + // Memory address type (2 bits). + // |........|........|XX......|........| + kSignatureMemAddrTypeShift = 14, + kSignatureMemAddrTypeMask = 0x03u << kSignatureMemAddrTypeShift, + + // Memory shift amount (2 bits). + // |........|......XX|........|........| + kSignatureMemShiftValueShift = 16, + kSignatureMemShiftValueMask = 0x03u << kSignatureMemShiftValueShift, + + // Memory segment reg (3 bits). + // |........|...XXX..|........|........| + kSignatureMemSegmentShift = 18, + kSignatureMemSegmentMask = 0x07u << kSignatureMemSegmentShift, + + // Memory broadcast type (3 bits). + // |........|XXX.....|........|........| + kSignatureMemBroadcastShift = 21, + kSignatureMemBroadcastMask = 0x7u << kSignatureMemBroadcastShift + }; + + //! Address type. + enum class AddrType : uint32_t { + //! Default address type, Assembler will select the best type when necessary. + kDefault = 0, + //! Absolute address type. + kAbs = 1, + //! Relative address type. + kRel = 2, + + //! Maximum value of `AddrType`. + kMaxValue = kRel + }; + + //! Memory broadcast type. + enum class Broadcast : uint32_t { + //! No broadcast (regular memory operand). + kNone = 0, + //! Broadcast {1to2}. + k1To2 = 1, + //! Broadcast {1to4}. + k1To4 = 2, + //! Broadcast {1to8}. + k1To8 = 3, + //! Broadcast {1to16}. + k1To16 = 4, + //! Broadcast {1to32}. + k1To32 = 5, + //! Broadcast {1to64}. + k1To64 = 6, + + //! Maximum value of `Broadcast`. + kMaxValue = k1To64 + }; + + //! \} + + //! \name Construction & Destruction + //! \{ + + //! Creates a default `Mem` operand that points to [0]. + ASMJIT_INLINE_NODEBUG constexpr Mem() noexcept + : BaseMem() {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Mem& other) noexcept + : BaseMem(other) {} + + ASMJIT_INLINE_NODEBUG explicit Mem(Globals::NoInit_) noexcept + : BaseMem(Globals::NoInit) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Signature& signature, uint32_t baseId, uint32_t indexId, int32_t offset) noexcept + : BaseMem(signature, baseId, indexId, offset) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Label& base, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(RegType::kLabelTag) | + Signature::fromSize(size) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const Label& base, const BaseReg& index, uint32_t shift, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(RegType::kLabelTag) | + Signature::fromMemIndexType(index.type()) | + Signature::fromValue(shift) | + Signature::fromSize(size) | + signature, base.id(), index.id(), off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromSize(size) | + signature, base.id(), 0, off) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(const BaseReg& base, const BaseReg& index, uint32_t shift, int32_t off, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemBaseType(base.type()) | + Signature::fromMemIndexType(index.type()) | + Signature::fromValue(shift) | + Signature::fromSize(size) | + signature, base.id(), index.id(), off) {} + + ASMJIT_INLINE_NODEBUG constexpr explicit Mem(uint64_t base, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromSize(size) | + signature, uint32_t(base >> 32), 0, int32_t(uint32_t(base & 0xFFFFFFFFu))) {} + + ASMJIT_INLINE_NODEBUG constexpr Mem(uint64_t base, const BaseReg& index, uint32_t shift = 0, uint32_t size = 0, Signature signature = OperandSignature{0}) noexcept + : BaseMem(Signature::fromOpType(OperandType::kMem) | + Signature::fromMemIndexType(index.type()) | + Signature::fromValue(shift) | + Signature::fromSize(size) | + signature, uint32_t(base >> 32), index.id(), int32_t(uint32_t(base & 0xFFFFFFFFu))) {} + + //! \} + + //! \name Overloaded Operators + //! \{ + + ASMJIT_INLINE_NODEBUG Mem& operator=(const Mem& other) noexcept = default; + + //! \} + + //! \name Clone + //! \{ + + //! Clones the memory operand. + ASMJIT_INLINE_NODEBUG constexpr Mem clone() const noexcept { return Mem(*this); } + + //! Creates a copy of this memory operand adjusted by `off`. + inline Mem cloneAdjusted(int64_t off) const noexcept { + Mem result(*this); + result.addOffset(off); + return result; + } + + //! Creates a copy of this memory operand resized to `size`. + inline Mem cloneResized(uint32_t size) const noexcept { + Mem result(*this); + result.setSize(size); + return result; + } + + //! Creates a copy of this memory operand with a broadcast `bcst`. + ASMJIT_INLINE_NODEBUG constexpr Mem cloneBroadcasted(Broadcast bcst) const noexcept { + return Mem((_signature & ~Signature{kSignatureMemBroadcastMask}) | Signature::fromValue(bcst), _baseId, _data[0], int32_t(_data[1])); + } + + //! \} + + //! \name Base & Index + //! \{ + + //! Converts memory `baseType` and `baseId` to `x86::Reg` instance. + //! + //! The memory must have a valid base register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg baseReg() const noexcept { return Reg::fromTypeAndId(baseType(), baseId()); } + + //! Converts memory `indexType` and `indexId` to `x86::Reg` instance. + //! + //! The memory must have a valid index register otherwise the result will be wrong. + ASMJIT_INLINE_NODEBUG Reg indexReg() const noexcept { return Reg::fromTypeAndId(indexType(), indexId()); } + + using BaseMem::setIndex; + + ASMJIT_INLINE_NODEBUG void setIndex(const BaseReg& index, uint32_t shift) noexcept { + setIndex(index); + setShift(shift); + } + + //! \} + + //! \name Memory Size + //! \{ + + //! Tests whether the memory operand specifies a size (i.e. the size is not zero). + ASMJIT_INLINE_NODEBUG constexpr bool hasSize() const noexcept { return _signature.hasField(); } + //! Tests whether the memory operand size matches size `s`. + ASMJIT_INLINE_NODEBUG constexpr bool hasSize(uint32_t s) const noexcept { return size() == s; } + + //! Returns the size of the memory operand in bytes. + //! + //! \note Most instructions would deduce the size of the memory operand, so in most cases it's expected that the + //! returned value would be zero. However, some instruction require the size to select between multiple variations, + //! so in some cases size is required would be non-zero (for example `inc [mem], immediate` requires size to + //! distinguish between 8-bit, 16-bit, 32-bit, and 64-bit increments. + ASMJIT_INLINE_NODEBUG constexpr uint32_t size() const noexcept { return _signature.getField(); } + + //! \} + + //! \name Address Type + //! \{ + + //! Returns the address type of the memory operand. + //! + //! By default, address type of newly created memory operands is always \ref AddrType::kDefault. + ASMJIT_INLINE_NODEBUG constexpr AddrType addrType() const noexcept { return (AddrType)_signature.getField(); } + //! Sets the address type to `addrType`. + ASMJIT_INLINE_NODEBUG void setAddrType(AddrType addrType) noexcept { _signature.setField(uint32_t(addrType)); } + //! Resets the address type to \ref AddrType::kDefault. + ASMJIT_INLINE_NODEBUG void resetAddrType() noexcept { _signature.setField(uint32_t(AddrType::kDefault)); } + + //! Tests whether the address type is \ref AddrType::kAbs. + ASMJIT_INLINE_NODEBUG constexpr bool isAbs() const noexcept { return addrType() == AddrType::kAbs; } + //! Sets the address type to \ref AddrType::kAbs. + ASMJIT_INLINE_NODEBUG void setAbs() noexcept { setAddrType(AddrType::kAbs); } + + //! Tests whether the address type is \ref AddrType::kRel. + ASMJIT_INLINE_NODEBUG constexpr bool isRel() const noexcept { return addrType() == AddrType::kRel; } + //! Sets the address type to \ref AddrType::kRel. + ASMJIT_INLINE_NODEBUG void setRel() noexcept { setAddrType(AddrType::kRel); } + + //! \} + + //! \name Segment + //! \{ + + //! Tests whether the memory operand has a segment override. + ASMJIT_INLINE_NODEBUG constexpr bool hasSegment() const noexcept { return _signature.hasField(); } + //! Returns the associated segment override as `SReg` operand. + ASMJIT_INLINE_NODEBUG constexpr SReg segment() const noexcept { return SReg(segmentId()); } + //! Returns segment override register id, see `SReg::Id`. + ASMJIT_INLINE_NODEBUG constexpr uint32_t segmentId() const noexcept { return _signature.getField(); } + + //! Sets the segment override to `seg`. + ASMJIT_INLINE_NODEBUG void setSegment(const SReg& seg) noexcept { setSegment(seg.id()); } + //! Sets the segment override to `id`. + ASMJIT_INLINE_NODEBUG void setSegment(uint32_t rId) noexcept { _signature.setField(rId); } + //! Resets the segment override. + ASMJIT_INLINE_NODEBUG void resetSegment() noexcept { _signature.setField(0); } + + //! \} + + //! \name Shift + //! \{ + + //! Tests whether the memory operand has shift (aka scale) value. + ASMJIT_INLINE_NODEBUG constexpr bool hasShift() const noexcept { return _signature.hasField(); } + //! Returns the memory operand's shift (aka scale) value. + ASMJIT_INLINE_NODEBUG constexpr uint32_t shift() const noexcept { return _signature.getField(); } + //! Sets the memory operand's shift (aka scale) value. + ASMJIT_INLINE_NODEBUG void setShift(uint32_t shift) noexcept { _signature.setField(shift); } + //! Resets the memory operand's shift (aka scale) value to zero. + ASMJIT_INLINE_NODEBUG void resetShift() noexcept { _signature.setField(0); } + + //! \} + + //! \name Broadcast + //! \{ + + //! Tests whether the memory operand has broadcast {1tox}. + ASMJIT_INLINE_NODEBUG constexpr bool hasBroadcast() const noexcept { return _signature.hasField(); } + //! Returns the memory operand's broadcast. + ASMJIT_INLINE_NODEBUG constexpr Broadcast getBroadcast() const noexcept { return (Broadcast)_signature.getField(); } + //! Sets the memory operand's broadcast. + ASMJIT_INLINE_NODEBUG void setBroadcast(Broadcast b) noexcept { _signature.setField(uint32_t(b)); } + //! Resets the memory operand's broadcast to none. + ASMJIT_INLINE_NODEBUG void resetBroadcast() noexcept { _signature.setField(0); } + + //! Returns a new `Mem` without a broadcast (the possible broadcast is cleared). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to1() const noexcept { return cloneBroadcasted(Broadcast::kNone); } + //! Returns a new `Mem` with {1to2} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to2() const noexcept { return cloneBroadcasted(Broadcast::k1To2); } + //! Returns a new `Mem` with {1to4} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to4() const noexcept { return cloneBroadcasted(Broadcast::k1To4); } + //! Returns a new `Mem` with {1to8} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to8() const noexcept { return cloneBroadcasted(Broadcast::k1To8); } + //! Returns a new `Mem` with {1to16} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to16() const noexcept { return cloneBroadcasted(Broadcast::k1To16); } + //! Returns a new `Mem` with {1to32} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to32() const noexcept { return cloneBroadcasted(Broadcast::k1To32); } + //! Returns a new `Mem` with {1to64} broadcast (AVX-512). + ASMJIT_INLINE_NODEBUG constexpr Mem _1to64() const noexcept { return cloneBroadcasted(Broadcast::k1To64); } + + //! \} +}; + +//! Creates `[base.reg + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, offset, size); +} +//! Creates `[base.reg + (index << shift) + offset]` memory operand (scalar index). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} +//! Creates `[base.reg + (index << shift) + offset]` memory operand (vector index). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Gp& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} + +//! Creates `[base + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, offset, size); +} +//! Creates `[base + (index << shift) + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} +//! Creates `[base + (index << shift) + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Label& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, offset, size); +} + +//! Creates `[rip + offset]` memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(const Rip& rip_, int32_t offset = 0, uint32_t size = 0) noexcept { + return Mem(rip_, offset, size); +} + +//! Creates `[base]` absolute memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base, uint32_t size = 0) noexcept { + return Mem(base, size); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base, const Reg& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand. +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr(uint64_t base, const Vec& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size); +} + +//! Creates `[base]` absolute memory operand (absolute). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_abs(uint64_t base, uint32_t size = 0) noexcept { + return Mem(base, size, OperandSignature::fromValue(Mem::AddrType::kAbs)); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand (absolute). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_abs(uint64_t base, const Reg& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kAbs)); +} +//! Creates `[base + (index.reg << shift)]` absolute memory operand (absolute). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_abs(uint64_t base, const Vec& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kAbs)); +} + +//! Creates `[base]` relative memory operand (relative). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_rel(uint64_t base, uint32_t size = 0) noexcept { + return Mem(base, size, OperandSignature::fromValue(Mem::AddrType::kRel)); +} +//! Creates `[base + (index.reg << shift)]` relative memory operand (relative). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_rel(uint64_t base, const Reg& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kRel)); +} +//! Creates `[base + (index.reg << shift)]` relative memory operand (relative). +static ASMJIT_INLINE_NODEBUG constexpr Mem ptr_rel(uint64_t base, const Vec& index, uint32_t shift = 0, uint32_t size = 0) noexcept { + return Mem(base, index, shift, size, OperandSignature::fromValue(Mem::AddrType::kRel)); +} + +#define ASMJIT_MEM_PTR(FUNC, SIZE) \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Gp& base, int32_t offset = 0) noexcept \ + { return Mem(base, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Gp& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0) noexcept \ + { return Mem(base, index, shift, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Gp& base, const Vec& index, uint32_t shift = 0, int32_t offset = 0) noexcept \ + { return Mem(base, index, shift, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Label& base, int32_t offset = 0) noexcept \ + { return Mem(base, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Label& base, const Gp& index, uint32_t shift = 0, int32_t offset = 0) noexcept \ + { return Mem(base, index, shift, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + const Rip& rip_, int32_t offset = 0) noexcept \ + { return Mem(rip_, offset, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + uint64_t base) noexcept \ + { return Mem(base, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + uint64_t base, const Gp& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC( \ + uint64_t base, const Vec& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_abs( \ + uint64_t base) noexcept \ + { return Mem(base, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kAbs)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_abs( \ + uint64_t base, const Gp& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kAbs)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_abs( \ + uint64_t base, const Vec& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kAbs)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_rel( \ + uint64_t base) noexcept \ + { return Mem(base, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kRel)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_rel( \ + uint64_t base, const Gp& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kRel)); } \ + \ + static ASMJIT_INLINE_NODEBUG constexpr Mem FUNC##_rel( \ + uint64_t base, const Vec& index, uint32_t shift = 0) noexcept \ + { return Mem(base, index, shift, SIZE, \ + OperandSignature::fromValue(Mem::AddrType::kRel)); } + +// Definition of memory operand constructors that use platform independent naming. +ASMJIT_MEM_PTR(ptr_8, 1) +ASMJIT_MEM_PTR(ptr_16, 2) +ASMJIT_MEM_PTR(ptr_32, 4) +ASMJIT_MEM_PTR(ptr_48, 6) +ASMJIT_MEM_PTR(ptr_64, 8) +ASMJIT_MEM_PTR(ptr_80, 10) +ASMJIT_MEM_PTR(ptr_128, 16) +ASMJIT_MEM_PTR(ptr_256, 32) +ASMJIT_MEM_PTR(ptr_512, 64) + +// Definition of memory operand constructors that use X86-specific convention. +ASMJIT_MEM_PTR(byte_ptr, 1) +ASMJIT_MEM_PTR(word_ptr, 2) +ASMJIT_MEM_PTR(dword_ptr, 4) +ASMJIT_MEM_PTR(fword_ptr, 6) +ASMJIT_MEM_PTR(qword_ptr, 8) +ASMJIT_MEM_PTR(tbyte_ptr, 10) +ASMJIT_MEM_PTR(tword_ptr, 10) +ASMJIT_MEM_PTR(oword_ptr, 16) +ASMJIT_MEM_PTR(dqword_ptr, 16) +ASMJIT_MEM_PTR(qqword_ptr, 32) +ASMJIT_MEM_PTR(xmmword_ptr, 16) +ASMJIT_MEM_PTR(ymmword_ptr, 32) +ASMJIT_MEM_PTR(zmmword_ptr, 64) + +#undef ASMJIT_MEM_PTR + +//! \} + +ASMJIT_END_SUB_NAMESPACE + +//! \cond INTERNAL +ASMJIT_BEGIN_NAMESPACE +ASMJIT_DEFINE_TYPE_ID(x86::Gpb, TypeId::kInt8); +ASMJIT_DEFINE_TYPE_ID(x86::Gpw, TypeId::kInt16); +ASMJIT_DEFINE_TYPE_ID(x86::Gpd, TypeId::kInt32); +ASMJIT_DEFINE_TYPE_ID(x86::Gpq, TypeId::kInt64); +ASMJIT_DEFINE_TYPE_ID(x86::Mm , TypeId::kMmx64); +ASMJIT_DEFINE_TYPE_ID(x86::Xmm, TypeId::kInt32x4); +ASMJIT_DEFINE_TYPE_ID(x86::Ymm, TypeId::kInt32x8); +ASMJIT_DEFINE_TYPE_ID(x86::Zmm, TypeId::kInt32x16); +ASMJIT_END_NAMESPACE +//! \endcond + +#endif // ASMJIT_X86_X86OPERAND_H_INCLUDED diff --git a/phivenv/Lib/site-packages/torch/include/cpuinfo.h b/phivenv/Lib/site-packages/torch/include/cpuinfo.h new file mode 100644 index 0000000000000000000000000000000000000000..ce794de8651d364d48f8f4a96e44aa8f000cb8dd --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/cpuinfo.h @@ -0,0 +1,2321 @@ +#pragma once +#ifndef CPUINFO_H +#define CPUINFO_H + +#ifndef __cplusplus +#include +#endif + +#ifdef __APPLE__ +#include +#endif + +#include + +/* Identify architecture and define corresponding macro */ + +#if defined(__i386__) || defined(__i486__) || defined(__i586__) || defined(__i686__) || defined(_M_IX86) +#define CPUINFO_ARCH_X86 1 +#endif + +#if defined(__x86_64__) || defined(__x86_64) || defined(_M_X64) || defined(_M_AMD64) +#define CPUINFO_ARCH_X86_64 1 +#endif + +#if defined(__arm__) || defined(_M_ARM) +#define CPUINFO_ARCH_ARM 1 +#endif + +#if defined(__aarch64__) || defined(_M_ARM64) +#define CPUINFO_ARCH_ARM64 1 +#endif + +#if defined(__PPC64__) || defined(__powerpc64__) || defined(_ARCH_PPC64) +#define CPUINFO_ARCH_PPC64 1 +#endif + +#if defined(__asmjs__) +#define CPUINFO_ARCH_ASMJS 1 +#endif + +#if defined(__wasm__) +#if defined(__wasm_simd128__) +#define CPUINFO_ARCH_WASMSIMD 1 +#else +#define CPUINFO_ARCH_WASM 1 +#endif +#endif + +#if defined(__riscv) +#if (__riscv_xlen == 32) +#define CPUINFO_ARCH_RISCV32 1 +#elif (__riscv_xlen == 64) +#define CPUINFO_ARCH_RISCV64 1 +#endif +#endif + +/* Define other architecture-specific macros as 0 */ + +#ifndef CPUINFO_ARCH_X86 +#define CPUINFO_ARCH_X86 0 +#endif + +#ifndef CPUINFO_ARCH_X86_64 +#define CPUINFO_ARCH_X86_64 0 +#endif + +#ifndef CPUINFO_ARCH_ARM +#define CPUINFO_ARCH_ARM 0 +#endif + +#ifndef CPUINFO_ARCH_ARM64 +#define CPUINFO_ARCH_ARM64 0 +#endif + +#ifndef CPUINFO_ARCH_PPC64 +#define CPUINFO_ARCH_PPC64 0 +#endif + +#ifndef CPUINFO_ARCH_ASMJS +#define CPUINFO_ARCH_ASMJS 0 +#endif + +#ifndef CPUINFO_ARCH_WASM +#define CPUINFO_ARCH_WASM 0 +#endif + +#ifndef CPUINFO_ARCH_WASMSIMD +#define CPUINFO_ARCH_WASMSIMD 0 +#endif + +#ifndef CPUINFO_ARCH_RISCV32 +#define CPUINFO_ARCH_RISCV32 0 +#endif + +#ifndef CPUINFO_ARCH_RISCV64 +#define CPUINFO_ARCH_RISCV64 0 +#endif + +#if CPUINFO_ARCH_X86 && defined(_MSC_VER) +#define CPUINFO_ABI __cdecl +#elif CPUINFO_ARCH_X86 && defined(__GNUC__) +#define CPUINFO_ABI __attribute__((__cdecl__)) +#else +#define CPUINFO_ABI +#endif + +#define CPUINFO_CACHE_UNIFIED 0x00000001 +#define CPUINFO_CACHE_INCLUSIVE 0x00000002 +#define CPUINFO_CACHE_COMPLEX_INDEXING 0x00000004 + +struct cpuinfo_cache { + /** Cache size in bytes */ + uint32_t size; + /** Number of ways of associativity */ + uint32_t associativity; + /** Number of sets */ + uint32_t sets; + /** Number of partitions */ + uint32_t partitions; + /** Line size in bytes */ + uint32_t line_size; + /** + * Binary characteristics of the cache (unified cache, inclusive cache, + * cache with complex indexing). + * + * @see CPUINFO_CACHE_UNIFIED, CPUINFO_CACHE_INCLUSIVE, + * CPUINFO_CACHE_COMPLEX_INDEXING + */ + uint32_t flags; + /** Index of the first logical processor that shares this cache */ + uint32_t processor_start; + /** Number of logical processors that share this cache */ + uint32_t processor_count; +}; + +struct cpuinfo_trace_cache { + uint32_t uops; + uint32_t associativity; +}; + +#define CPUINFO_PAGE_SIZE_4KB 0x1000 +#define CPUINFO_PAGE_SIZE_1MB 0x100000 +#define CPUINFO_PAGE_SIZE_2MB 0x200000 +#define CPUINFO_PAGE_SIZE_4MB 0x400000 +#define CPUINFO_PAGE_SIZE_16MB 0x1000000 +#define CPUINFO_PAGE_SIZE_1GB 0x40000000 + +struct cpuinfo_tlb { + uint32_t entries; + uint32_t associativity; + uint64_t pages; +}; + +/** Vendor of processor core design */ +enum cpuinfo_vendor { + /** Processor vendor is not known to the library, or the library failed + to get vendor information from the OS. */ + cpuinfo_vendor_unknown = 0, + + /* Active vendors of modern CPUs */ + + /** + * Intel Corporation. Vendor of x86, x86-64, IA64, and ARM processor + * microarchitectures. + * + * Sold its ARM design subsidiary in 2006. The last ARM processor design + * was released in 2004. + */ + cpuinfo_vendor_intel = 1, + /** Advanced Micro Devices, Inc. Vendor of x86 and x86-64 processor + microarchitectures. */ + cpuinfo_vendor_amd = 2, + /** ARM Holdings plc. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_arm = 3, + /** Qualcomm Incorporated. Vendor of ARM and ARM64 processor + microarchitectures. */ + cpuinfo_vendor_qualcomm = 4, + /** Apple Inc. Vendor of ARM and ARM64 processor microarchitectures. */ + cpuinfo_vendor_apple = 5, + /** Samsung Electronics Co., Ltd. Vendir if ARM64 processor + microarchitectures. */ + cpuinfo_vendor_samsung = 6, + /** Nvidia Corporation. Vendor of ARM64-compatible processor + microarchitectures. */ + cpuinfo_vendor_nvidia = 7, + /** MIPS Technologies, Inc. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_mips = 8, + /** International Business Machines Corporation. Vendor of PowerPC + processor microarchitectures. */ + cpuinfo_vendor_ibm = 9, + /** Ingenic Semiconductor. Vendor of MIPS processor microarchitectures. + */ + cpuinfo_vendor_ingenic = 10, + /** + * VIA Technologies, Inc. Vendor of x86 and x86-64 processor + * microarchitectures. + * + * Processors are designed by Centaur Technology, a subsidiary of VIA + * Technologies. + */ + cpuinfo_vendor_via = 11, + /** Cavium, Inc. Vendor of ARM64 processor microarchitectures. */ + cpuinfo_vendor_cavium = 12, + /** Broadcom, Inc. Vendor of ARM processor microarchitectures. */ + cpuinfo_vendor_broadcom = 13, + /** Applied Micro Circuits Corporation (APM). Vendor of ARM64 processor + microarchitectures. */ + cpuinfo_vendor_apm = 14, + /** + * Huawei Technologies Co., Ltd. Vendor of ARM64 processor + * microarchitectures. + * + * Processors are designed by HiSilicon, a subsidiary of Huawei. + */ + cpuinfo_vendor_huawei = 15, + /** + * Hygon (Chengdu Haiguang Integrated Circuit Design Co., Ltd), Vendor + * of x86-64 processor microarchitectures. + * + * Processors are variants of AMD cores. + */ + cpuinfo_vendor_hygon = 16, + /** SiFive, Inc. Vendor of RISC-V processor microarchitectures. */ + cpuinfo_vendor_sifive = 17, + + /* Active vendors of embedded CPUs */ + + /** Texas Instruments Inc. Vendor of ARM processor microarchitectures. + */ + cpuinfo_vendor_texas_instruments = 30, + /** Marvell Technology Group Ltd. Vendor of ARM processor + * microarchitectures. + */ + cpuinfo_vendor_marvell = 31, + /** RDC Semiconductor Co., Ltd. Vendor of x86 processor + microarchitectures. */ + cpuinfo_vendor_rdc = 32, + /** DM&P Electronics Inc. Vendor of x86 processor microarchitectures. */ + cpuinfo_vendor_dmp = 33, + /** Motorola, Inc. Vendor of PowerPC and ARM processor + microarchitectures. */ + cpuinfo_vendor_motorola = 34, + + /* Defunct CPU vendors */ + + /** + * Transmeta Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 2004. + * Transmeta processors implemented VLIW ISA and used binary translation + * to execute x86 code. + */ + cpuinfo_vendor_transmeta = 50, + /** + * Cyrix Corporation. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1996. + */ + cpuinfo_vendor_cyrix = 51, + /** + * Rise Technology. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1999. + */ + cpuinfo_vendor_rise = 52, + /** + * National Semiconductor. Vendor of x86 processor microarchitectures. + * + * Sold its x86 design subsidiary in 1999. The last processor design was + * released in 1998. + */ + cpuinfo_vendor_nsc = 53, + /** + * Silicon Integrated Systems. Vendor of x86 processor + * microarchitectures. + * + * Sold its x86 design subsidiary in 2001. The last processor design was + * released in 2001. + */ + cpuinfo_vendor_sis = 54, + /** + * NexGen. Vendor of x86 processor microarchitectures. + * + * Now defunct. The last processor design was released in 1994. + * NexGen designed the first x86 microarchitecture which decomposed x86 + * instructions into simple microoperations. + */ + cpuinfo_vendor_nexgen = 55, + /** + * United Microelectronics Corporation. Vendor of x86 processor + * microarchitectures. + * + * Ceased x86 in the early 1990s. The last processor design was released + * in 1991. Designed U5C and U5D processors. Both are 486 level. + */ + cpuinfo_vendor_umc = 56, + /** + * Digital Equipment Corporation. Vendor of ARM processor + * microarchitecture. + * + * Sold its ARM designs in 1997. The last processor design was released + * in 1997. + */ + cpuinfo_vendor_dec = 57, +}; + +/** + * Processor microarchitecture + * + * Processors with different microarchitectures often have different instruction + * performance characteristics, and may have dramatically different pipeline + * organization. + */ +enum cpuinfo_uarch { + /** Microarchitecture is unknown, or the library failed to get + information about the microarchitecture from OS */ + cpuinfo_uarch_unknown = 0, + + /** Pentium and Pentium MMX microarchitecture. */ + cpuinfo_uarch_p5 = 0x00100100, + /** Intel Quark microarchitecture. */ + cpuinfo_uarch_quark = 0x00100101, + + /** Pentium Pro, Pentium II, and Pentium III. */ + cpuinfo_uarch_p6 = 0x00100200, + /** Pentium M. */ + cpuinfo_uarch_dothan = 0x00100201, + /** Intel Core microarchitecture. */ + cpuinfo_uarch_yonah = 0x00100202, + /** Intel Core 2 microarchitecture on 65 nm process. */ + cpuinfo_uarch_conroe = 0x00100203, + /** Intel Core 2 microarchitecture on 45 nm process. */ + cpuinfo_uarch_penryn = 0x00100204, + /** Intel Nehalem and Westmere microarchitectures (Core i3/i5/i7 1st + gen). */ + cpuinfo_uarch_nehalem = 0x00100205, + /** Intel Sandy Bridge microarchitecture (Core i3/i5/i7 2nd gen). */ + cpuinfo_uarch_sandy_bridge = 0x00100206, + /** Intel Ivy Bridge microarchitecture (Core i3/i5/i7 3rd gen). */ + cpuinfo_uarch_ivy_bridge = 0x00100207, + /** Intel Haswell microarchitecture (Core i3/i5/i7 4th gen). */ + cpuinfo_uarch_haswell = 0x00100208, + /** Intel Broadwell microarchitecture. */ + cpuinfo_uarch_broadwell = 0x00100209, + /** Intel Sky Lake microarchitecture (14 nm, including + Kaby/Coffee/Whiskey/Amber/Comet/Cascade/Cooper Lake). */ + cpuinfo_uarch_sky_lake = 0x0010020A, + /** DEPRECATED (Intel Kaby Lake microarchitecture). */ + cpuinfo_uarch_kaby_lake = 0x0010020A, + /** Intel Palm Cove microarchitecture (10 nm, Cannon Lake). */ + cpuinfo_uarch_palm_cove = 0x0010020B, + /** Intel Sunny Cove microarchitecture (10 nm, Ice Lake). */ + cpuinfo_uarch_sunny_cove = 0x0010020C, + + /** Pentium 4 with Willamette, Northwood, or Foster cores. */ + cpuinfo_uarch_willamette = 0x00100300, + /** Pentium 4 with Prescott and later cores. */ + cpuinfo_uarch_prescott = 0x00100301, + + /** Intel Atom on 45 nm process. */ + cpuinfo_uarch_bonnell = 0x00100400, + /** Intel Atom on 32 nm process. */ + cpuinfo_uarch_saltwell = 0x00100401, + /** Intel Silvermont microarchitecture (22 nm out-of-order Atom). */ + cpuinfo_uarch_silvermont = 0x00100402, + /** Intel Airmont microarchitecture (14 nm out-of-order Atom). */ + cpuinfo_uarch_airmont = 0x00100403, + /** Intel Goldmont microarchitecture (Denverton, Apollo Lake). */ + cpuinfo_uarch_goldmont = 0x00100404, + /** Intel Goldmont Plus microarchitecture (Gemini Lake). */ + cpuinfo_uarch_goldmont_plus = 0x00100405, + + /** Intel Knights Ferry HPC boards. */ + cpuinfo_uarch_knights_ferry = 0x00100500, + /** Intel Knights Corner HPC boards (aka Xeon Phi). */ + cpuinfo_uarch_knights_corner = 0x00100501, + /** Intel Knights Landing microarchitecture (second-gen MIC). */ + cpuinfo_uarch_knights_landing = 0x00100502, + /** Intel Knights Hill microarchitecture (third-gen MIC). */ + cpuinfo_uarch_knights_hill = 0x00100503, + /** Intel Knights Mill Xeon Phi. */ + cpuinfo_uarch_knights_mill = 0x00100504, + + /** Intel/Marvell XScale series. */ + cpuinfo_uarch_xscale = 0x00100600, + + /** AMD K5. */ + cpuinfo_uarch_k5 = 0x00200100, + /** AMD K6 and alike. */ + cpuinfo_uarch_k6 = 0x00200101, + /** AMD Athlon and Duron. */ + cpuinfo_uarch_k7 = 0x00200102, + /** AMD Athlon 64, Opteron 64. */ + cpuinfo_uarch_k8 = 0x00200103, + /** AMD Family 10h (Barcelona, Istambul, Magny-Cours). */ + cpuinfo_uarch_k10 = 0x00200104, + /** + * AMD Bulldozer microarchitecture + * Zambezi FX-series CPUs, Zurich, Valencia and Interlagos Opteron CPUs. + */ + cpuinfo_uarch_bulldozer = 0x00200105, + /** + * AMD Piledriver microarchitecture + * Vishera FX-series CPUs, Trinity and Richland APUs, Delhi, Seoul, Abu + * Dhabi Opteron CPUs. + */ + cpuinfo_uarch_piledriver = 0x00200106, + /** AMD Steamroller microarchitecture (Kaveri APUs). */ + cpuinfo_uarch_steamroller = 0x00200107, + /** AMD Excavator microarchitecture (Carizzo APUs). */ + cpuinfo_uarch_excavator = 0x00200108, + /** AMD Zen microarchitecture (12/14 nm Ryzen and EPYC CPUs). */ + cpuinfo_uarch_zen = 0x00200109, + /** AMD Zen 2 microarchitecture (7 nm Ryzen and EPYC CPUs). */ + cpuinfo_uarch_zen2 = 0x0020010A, + /** AMD Zen 3 microarchitecture. */ + cpuinfo_uarch_zen3 = 0x0020010B, + /** AMD Zen 4 microarchitecture. */ + cpuinfo_uarch_zen4 = 0x0020010C, + /** AMD Zen 5 microarchitecture. */ + cpuinfo_uarch_zen5 = 0x0020010D, + + /** NSC Geode and AMD Geode GX and LX. */ + cpuinfo_uarch_geode = 0x00200200, + /** AMD Bobcat mobile microarchitecture. */ + cpuinfo_uarch_bobcat = 0x00200201, + /** AMD Jaguar mobile microarchitecture. */ + cpuinfo_uarch_jaguar = 0x00200202, + /** AMD Puma mobile microarchitecture. */ + cpuinfo_uarch_puma = 0x00200203, + + /** ARM7 series. */ + cpuinfo_uarch_arm7 = 0x00300100, + /** ARM9 series. */ + cpuinfo_uarch_arm9 = 0x00300101, + /** ARM 1136, ARM 1156, ARM 1176, or ARM 11MPCore. */ + cpuinfo_uarch_arm11 = 0x00300102, + + /** ARM Cortex-A5. */ + cpuinfo_uarch_cortex_a5 = 0x00300205, + /** ARM Cortex-A7. */ + cpuinfo_uarch_cortex_a7 = 0x00300207, + /** ARM Cortex-A8. */ + cpuinfo_uarch_cortex_a8 = 0x00300208, + /** ARM Cortex-A9. */ + cpuinfo_uarch_cortex_a9 = 0x00300209, + /** ARM Cortex-A12. */ + cpuinfo_uarch_cortex_a12 = 0x00300212, + /** ARM Cortex-A15. */ + cpuinfo_uarch_cortex_a15 = 0x00300215, + /** ARM Cortex-A17. */ + cpuinfo_uarch_cortex_a17 = 0x00300217, + + /** ARM Cortex-A32. */ + cpuinfo_uarch_cortex_a32 = 0x00300332, + /** ARM Cortex-A35. */ + cpuinfo_uarch_cortex_a35 = 0x00300335, + /** ARM Cortex-A53. */ + cpuinfo_uarch_cortex_a53 = 0x00300353, + /** ARM Cortex-A55 revision 0 (restricted dual-issue capabilities + compared to revision 1+). */ + cpuinfo_uarch_cortex_a55r0 = 0x00300354, + /** ARM Cortex-A55. */ + cpuinfo_uarch_cortex_a55 = 0x00300355, + /** ARM Cortex-A57. */ + cpuinfo_uarch_cortex_a57 = 0x00300357, + /** ARM Cortex-A65. */ + cpuinfo_uarch_cortex_a65 = 0x00300365, + /** ARM Cortex-A72. */ + cpuinfo_uarch_cortex_a72 = 0x00300372, + /** ARM Cortex-A73. */ + cpuinfo_uarch_cortex_a73 = 0x00300373, + /** ARM Cortex-A75. */ + cpuinfo_uarch_cortex_a75 = 0x00300375, + /** ARM Cortex-A76. */ + cpuinfo_uarch_cortex_a76 = 0x00300376, + /** ARM Cortex-A77. */ + cpuinfo_uarch_cortex_a77 = 0x00300377, + /** ARM Cortex-A78. */ + cpuinfo_uarch_cortex_a78 = 0x00300378, + + /** ARM Neoverse N1. */ + cpuinfo_uarch_neoverse_n1 = 0x00300400, + /** ARM Neoverse E1. */ + cpuinfo_uarch_neoverse_e1 = 0x00300401, + /** ARM Neoverse V1. */ + cpuinfo_uarch_neoverse_v1 = 0x00300402, + /** ARM Neoverse N2. */ + cpuinfo_uarch_neoverse_n2 = 0x00300403, + /** ARM Neoverse V2. */ + cpuinfo_uarch_neoverse_v2 = 0x00300404, + + /** ARM Cortex-X1. */ + cpuinfo_uarch_cortex_x1 = 0x00300501, + /** ARM Cortex-X2. */ + cpuinfo_uarch_cortex_x2 = 0x00300502, + /** ARM Cortex-X3. */ + cpuinfo_uarch_cortex_x3 = 0x00300503, + /** ARM Cortex-X4. */ + cpuinfo_uarch_cortex_x4 = 0x00300504, + + /** ARM Cortex-A510. */ + cpuinfo_uarch_cortex_a510 = 0x00300551, + /** ARM Cortex-A520. */ + cpuinfo_uarch_cortex_a520 = 0x00300552, + /** ARM Cortex-A710. */ + cpuinfo_uarch_cortex_a710 = 0x00300571, + /** ARM Cortex-A715. */ + cpuinfo_uarch_cortex_a715 = 0x00300572, + /** ARM Cortex-A720. */ + cpuinfo_uarch_cortex_a720 = 0x00300573, + + /** Qualcomm Scorpion. */ + cpuinfo_uarch_scorpion = 0x00400100, + /** Qualcomm Krait. */ + cpuinfo_uarch_krait = 0x00400101, + /** Qualcomm Kryo. */ + cpuinfo_uarch_kryo = 0x00400102, + /** Qualcomm Falkor. */ + cpuinfo_uarch_falkor = 0x00400103, + /** Qualcomm Saphira. */ + cpuinfo_uarch_saphira = 0x00400104, + /** Qualcomm Oryon. */ + cpuinfo_uarch_oryon = 0x00400105, + + /** Nvidia Denver. */ + cpuinfo_uarch_denver = 0x00500100, + /** Nvidia Denver 2. */ + cpuinfo_uarch_denver2 = 0x00500101, + /** Nvidia Carmel. */ + cpuinfo_uarch_carmel = 0x00500102, + + /** Samsung Exynos M1 (Exynos 8890 big cores). */ + cpuinfo_uarch_exynos_m1 = 0x00600100, + /** Samsung Exynos M2 (Exynos 8895 big cores). */ + cpuinfo_uarch_exynos_m2 = 0x00600101, + /** Samsung Exynos M3 (Exynos 9810 big cores). */ + cpuinfo_uarch_exynos_m3 = 0x00600102, + /** Samsung Exynos M4 (Exynos 9820 big cores). */ + cpuinfo_uarch_exynos_m4 = 0x00600103, + /** Samsung Exynos M5 (Exynos 9830 big cores). */ + cpuinfo_uarch_exynos_m5 = 0x00600104, + + /* Deprecated synonym for Cortex-A76 */ + cpuinfo_uarch_cortex_a76ae = 0x00300376, + /* Deprecated names for Exynos. */ + cpuinfo_uarch_mongoose_m1 = 0x00600100, + cpuinfo_uarch_mongoose_m2 = 0x00600101, + cpuinfo_uarch_meerkat_m3 = 0x00600102, + cpuinfo_uarch_meerkat_m4 = 0x00600103, + + /** Apple A6 and A6X processors. */ + cpuinfo_uarch_swift = 0x00700100, + /** Apple A7 processor. */ + cpuinfo_uarch_cyclone = 0x00700101, + /** Apple A8 and A8X processor. */ + cpuinfo_uarch_typhoon = 0x00700102, + /** Apple A9 and A9X processor. */ + cpuinfo_uarch_twister = 0x00700103, + /** Apple A10 and A10X processor. */ + cpuinfo_uarch_hurricane = 0x00700104, + /** Apple A11 processor (big cores). */ + cpuinfo_uarch_monsoon = 0x00700105, + /** Apple A11 processor (little cores). */ + cpuinfo_uarch_mistral = 0x00700106, + /** Apple A12 processor (big cores). */ + cpuinfo_uarch_vortex = 0x00700107, + /** Apple A12 processor (little cores). */ + cpuinfo_uarch_tempest = 0x00700108, + /** Apple A13 processor (big cores). */ + cpuinfo_uarch_lightning = 0x00700109, + /** Apple A13 processor (little cores). */ + cpuinfo_uarch_thunder = 0x0070010A, + /** Apple A14 / M1 processor (big cores). */ + cpuinfo_uarch_firestorm = 0x0070010B, + /** Apple A14 / M1 processor (little cores). */ + cpuinfo_uarch_icestorm = 0x0070010C, + /** Apple A15 / M2 processor (big cores). */ + cpuinfo_uarch_avalanche = 0x0070010D, + /** Apple A15 / M2 processor (little cores). */ + cpuinfo_uarch_blizzard = 0x0070010E, + + /** Cavium ThunderX. */ + cpuinfo_uarch_thunderx = 0x00800100, + /** Cavium ThunderX2 (originally Broadcom Vulkan). */ + cpuinfo_uarch_thunderx2 = 0x00800200, + + /** Marvell PJ4. */ + cpuinfo_uarch_pj4 = 0x00900100, + + /** Broadcom Brahma B15. */ + cpuinfo_uarch_brahma_b15 = 0x00A00100, + /** Broadcom Brahma B53. */ + cpuinfo_uarch_brahma_b53 = 0x00A00101, + + /** Applied Micro X-Gene. */ + cpuinfo_uarch_xgene = 0x00B00100, + + /* Hygon Dhyana (a modification of AMD Zen for Chinese market). */ + cpuinfo_uarch_dhyana = 0x01000100, + + /** HiSilicon TaiShan v110 (Huawei Kunpeng 920 series processors). */ + cpuinfo_uarch_taishan_v110 = 0x00C00100, +}; + +struct cpuinfo_processor { + /** SMT (hyperthread) ID within a core */ + uint32_t smt_id; + /** Core containing this logical processor */ + const struct cpuinfo_core* core; + /** Cluster of cores containing this logical processor */ + const struct cpuinfo_cluster* cluster; + /** Physical package containing this logical processor */ + const struct cpuinfo_package* package; +#if defined(__linux__) + /** + * Linux-specific ID for the logical processor: + * - Linux kernel exposes information about this logical processor in + * /sys/devices/system/cpu/cpu/ + * - Bit in the cpu_set_t identifies this logical processor + */ + int linux_id; +#endif +#if defined(_WIN32) || defined(__CYGWIN__) + /** Windows-specific ID for the group containing the logical processor. + */ + uint16_t windows_group_id; + /** + * Windows-specific ID of the logical processor within its group: + * - Bit in the KAFFINITY mask identifies this + * logical processor within its group. + */ + uint16_t windows_processor_id; +#endif +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** APIC ID (unique x86-specific ID of the logical processor) */ + uint32_t apic_id; +#endif + struct { + /** Level 1 instruction cache */ + const struct cpuinfo_cache* l1i; + /** Level 1 data cache */ + const struct cpuinfo_cache* l1d; + /** Level 2 unified or data cache */ + const struct cpuinfo_cache* l2; + /** Level 3 unified or data cache */ + const struct cpuinfo_cache* l3; + /** Level 4 unified or data cache */ + const struct cpuinfo_cache* l4; + } cache; +}; + +struct cpuinfo_core { + /** Index of the first logical processor on this core. */ + uint32_t processor_start; + /** Number of logical processors on this core */ + uint32_t processor_count; + /** Core ID within a package */ + uint32_t core_id; + /** Cluster containing this core */ + const struct cpuinfo_cluster* cluster; + /** Physical package containing this core. */ + const struct cpuinfo_package* package; + /** Vendor of the CPU microarchitecture for this core */ + enum cpuinfo_vendor vendor; + /** CPU microarchitecture for this core */ + enum cpuinfo_uarch uarch; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** Value of CPUID leaf 1 EAX register for this core */ + uint32_t cpuid; +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + /** Value of Main ID Register (MIDR) for this core */ + uint32_t midr; +#endif + /** Clock rate (non-Turbo) of the core, in Hz */ + uint64_t frequency; +}; + +struct cpuinfo_cluster { + /** Index of the first logical processor in the cluster */ + uint32_t processor_start; + /** Number of logical processors in the cluster */ + uint32_t processor_count; + /** Index of the first core in the cluster */ + uint32_t core_start; + /** Number of cores on the cluster */ + uint32_t core_count; + /** Cluster ID within a package */ + uint32_t cluster_id; + /** Physical package containing the cluster */ + const struct cpuinfo_package* package; + /** CPU microarchitecture vendor of the cores in the cluster */ + enum cpuinfo_vendor vendor; + /** CPU microarchitecture of the cores in the cluster */ + enum cpuinfo_uarch uarch; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** Value of CPUID leaf 1 EAX register of the cores in the cluster */ + uint32_t cpuid; +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + /** Value of Main ID Register (MIDR) of the cores in the cluster */ + uint32_t midr; +#endif + /** Clock rate (non-Turbo) of the cores in the cluster, in Hz */ + uint64_t frequency; +}; + +#define CPUINFO_PACKAGE_NAME_MAX 48 + +struct cpuinfo_package { + /** SoC or processor chip model name */ + char name[CPUINFO_PACKAGE_NAME_MAX]; + /** Index of the first logical processor on this physical package */ + uint32_t processor_start; + /** Number of logical processors on this physical package */ + uint32_t processor_count; + /** Index of the first core on this physical package */ + uint32_t core_start; + /** Number of cores on this physical package */ + uint32_t core_count; + /** Index of the first cluster of cores on this physical package */ + uint32_t cluster_start; + /** Number of clusters of cores on this physical package */ + uint32_t cluster_count; +}; + +struct cpuinfo_uarch_info { + /** Type of CPU microarchitecture */ + enum cpuinfo_uarch uarch; +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + /** Value of CPUID leaf 1 EAX register for the microarchitecture */ + uint32_t cpuid; +#elif CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + /** Value of Main ID Register (MIDR) for the microarchitecture */ + uint32_t midr; +#endif + /** Number of logical processors with the microarchitecture */ + uint32_t processor_count; + /** Number of cores with the microarchitecture */ + uint32_t core_count; +}; + +#ifdef __cplusplus +extern "C" { +#endif + +bool CPUINFO_ABI cpuinfo_initialize(void); + +void CPUINFO_ABI cpuinfo_deinitialize(void); + +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +/* This structure is not a part of stable API. Use cpuinfo_has_x86_* functions + * instead. */ +struct cpuinfo_x86_isa { +#if CPUINFO_ARCH_X86 + bool rdtsc; +#endif + bool rdtscp; + bool rdpid; + bool sysenter; +#if CPUINFO_ARCH_X86 + bool syscall; +#endif + bool msr; + bool clzero; + bool clflush; + bool clflushopt; + bool mwait; + bool mwaitx; +#if CPUINFO_ARCH_X86 + bool emmx; +#endif + bool fxsave; + bool xsave; +#if CPUINFO_ARCH_X86 + bool fpu; + bool mmx; + bool mmx_plus; +#endif + bool three_d_now; + bool three_d_now_plus; +#if CPUINFO_ARCH_X86 + bool three_d_now_geode; +#endif + bool prefetch; + bool prefetchw; + bool prefetchwt1; +#if CPUINFO_ARCH_X86 + bool daz; + bool sse; + bool sse2; +#endif + bool sse3; + bool ssse3; + bool sse4_1; + bool sse4_2; + bool sse4a; + bool misaligned_sse; + bool avx; + bool avxvnni; + bool fma3; + bool fma4; + bool xop; + bool f16c; + bool avx2; + bool avx512f; + bool avx512pf; + bool avx512er; + bool avx512cd; + bool avx512dq; + bool avx512bw; + bool avx512vl; + bool avx512ifma; + bool avx512vbmi; + bool avx512vbmi2; + bool avx512bitalg; + bool avx512vpopcntdq; + bool avx512vnni; + bool avx512bf16; + bool avx512fp16; + bool avx512vp2intersect; + bool avx512_4vnniw; + bool avx512_4fmaps; + bool avx10_1; + bool avx10_2; + bool amx_bf16; + bool amx_tile; + bool amx_int8; + bool amx_fp16; + bool avx_vnni_int8; + bool avx_vnni_int16; + bool avx_ne_convert; + bool hle; + bool rtm; + bool xtest; + bool mpx; +#if CPUINFO_ARCH_X86 + bool cmov; + bool cmpxchg8b; +#endif + bool cmpxchg16b; + bool clwb; + bool movbe; +#if CPUINFO_ARCH_X86_64 + bool lahf_sahf; +#endif + bool fs_gs_base; + bool lzcnt; + bool popcnt; + bool tbm; + bool bmi; + bool bmi2; + bool adx; + bool aes; + bool vaes; + bool pclmulqdq; + bool vpclmulqdq; + bool gfni; + bool rdrand; + bool rdseed; + bool sha; + bool rng; + bool ace; + bool ace2; + bool phe; + bool pmm; + bool lwp; +}; + +extern struct cpuinfo_x86_isa cpuinfo_isa; +#endif + +static inline bool cpuinfo_has_x86_rdtsc(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.rdtsc; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdtscp(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdtscp; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdpid(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdpid; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_clzero(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.clzero; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mwait(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.mwait; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mwaitx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.mwaitx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fxsave(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.fxsave; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_xsave(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.xsave; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fpu(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.fpu; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mmx(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.mmx; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mmx_plus(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.mmx_plus; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_3dnow(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.three_d_now; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_3dnow_plus(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.three_d_now_plus; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_3dnow_geode(void) { +#if CPUINFO_ARCH_X86_64 + return false; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return false; +#else + return cpuinfo_isa.three_d_now_geode; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_prefetch(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.prefetch; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_prefetchw(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.prefetchw; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_prefetchwt1(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.prefetchwt1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_daz(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.daz; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse2(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse2; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse3(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse3; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_ssse3(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.ssse3; +#endif +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse4_1(void) { +#if CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse4_1; +#endif +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.sse4_1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse4_2(void) { +#if CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.sse4_2; +#endif +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.sse4_2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sse4a(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.sse4a; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_misaligned_sse(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.misaligned_sse; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avxvnni(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avxvnni; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fma3(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.fma3; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_fma4(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.fma4; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_xop(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.xop; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_f16c(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.f16c; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx2(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512f(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512f; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512pf(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512pf; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512er(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512er; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512cd(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512cd; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512dq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512dq; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512bw(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512bw; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vl(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vl; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512ifma(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512ifma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vbmi(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vbmi; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vbmi2(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vbmi2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512bitalg(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512bitalg; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vpopcntdq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vpopcntdq; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vnni(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vnni; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512bf16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512fp16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512fp16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512vp2intersect(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512vp2intersect; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512_4vnniw(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512_4vnniw; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx512_4fmaps(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx512_4fmaps; +#else + return false; +#endif +} + +/* [NOTE] Intel Advanced Matrix Extensions (AMX) detection + * + * I. AMX is a new extensions to the x86 ISA to work on matrices, consists of + * 1) 2-dimentional registers (tiles), hold sub-matrices from larger matrices in memory + * 2) Accelerator called Tile Matrix Multiply (TMUL), contains instructions operating on tiles + * + * II. Platforms that supports AMX: + * +-----------------+-----+----------+----------+----------+----------+ + * | Platforms | Gen | amx-bf16 | amx-tile | amx-int8 | amx-fp16 | + * +-----------------+-----+----------+----------+----------+----------+ + * | Sapphire Rapids | 4th | YES | YES | YES | NO | + * +-----------------+-----+----------+----------+----------+----------+ + * | Emerald Rapids | 5th | YES | YES | YES | NO | + * +-----------------+-----+----------+----------+----------+----------+ + * | Granite Rapids | 6th | YES | YES | YES | YES | + * +-----------------+-----+----------+----------+----------+----------+ + * + * Reference: https://www.intel.com/content/www/us/en/products/docs + * /accelerator-engines/advanced-matrix-extensions/overview.html + */ +static inline bool cpuinfo_has_x86_amx_bf16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_amx_tile(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_tile; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_amx_int8(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_int8; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_amx_fp16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.amx_fp16; +#else + return false; +#endif +} + +/* + * Intel AVX Vector Neural Network Instructions (VNNI) INT8 + * Supported Platfroms: Sierra Forest, Arrow Lake, Lunar Lake + */ +static inline bool cpuinfo_has_x86_avx_vnni_int8(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx_vnni_int8; +#else + return false; +#endif +} + +/* + * Intel AVX Vector Neural Network Instructions (VNNI) INT16 + * Supported Platfroms: Arrow Lake, Lunar Lake + */ +static inline bool cpuinfo_has_x86_avx_vnni_int16(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx_vnni_int16; +#else + return false; +#endif +} + +/* + * A new set of instructions, which can convert low precision floating point + * like BF16/FP16 to high precision floating point FP32, as well as convert FP32 + * elements to BF16. This instruction allows the platform to have improved AI + * capabilities and better compatibility. + * + * Supported Platforms: Sierra Forest, Arrow Lake, Lunar Lake + */ +static inline bool cpuinfo_has_x86_avx_ne_convert(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx_ne_convert; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx10_1(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx10_1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_avx10_2(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.avx10_2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_hle(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.hle; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rtm(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rtm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_xtest(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.xtest; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_mpx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.mpx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_cmov(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.cmov; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_cmpxchg8b(void) { +#if CPUINFO_ARCH_X86_64 + return true; +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.cmpxchg8b; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_cmpxchg16b(void) { +#if CPUINFO_ARCH_X86_64 + return cpuinfo_isa.cmpxchg16b; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_clwb(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.clwb; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_movbe(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.movbe; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_lahf_sahf(void) { +#if CPUINFO_ARCH_X86 + return true; +#elif CPUINFO_ARCH_X86_64 + return cpuinfo_isa.lahf_sahf; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_lzcnt(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.lzcnt; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_popcnt(void) { +#if CPUINFO_ARCH_X86_64 +#if defined(__ANDROID__) + return true; +#else + return cpuinfo_isa.popcnt; +#endif +#elif CPUINFO_ARCH_X86 + return cpuinfo_isa.popcnt; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_tbm(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.tbm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_bmi(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.bmi; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_bmi2(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.bmi2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_adx(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.adx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_aes(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.aes; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_vaes(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.vaes; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_pclmulqdq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.pclmulqdq; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_vpclmulqdq(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.vpclmulqdq; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_gfni(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.gfni; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdrand(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdrand; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_rdseed(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.rdseed; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_x86_sha(void) { +#if CPUINFO_ARCH_X86 || CPUINFO_ARCH_X86_64 + return cpuinfo_isa.sha; +#else + return false; +#endif +} + +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 +/* This structure is not a part of stable API. Use cpuinfo_has_arm_* functions + * instead. */ +struct cpuinfo_arm_isa { +#if CPUINFO_ARCH_ARM + bool thumb; + bool thumb2; + bool thumbee; + bool jazelle; + bool armv5e; + bool armv6; + bool armv6k; + bool armv7; + bool armv7mp; + bool armv8; + bool idiv; + + bool vfpv2; + bool vfpv3; + bool d32; + bool fp16; + bool fma; + + bool wmmx; + bool wmmx2; + bool neon; +#endif +#if CPUINFO_ARCH_ARM64 + bool atomics; + bool bf16; + bool sve; + bool sve2; + bool i8mm; + bool sme; + bool sme2; + bool sme2p1; + bool sme_i16i32; + bool sme_bi32i32; + bool sme_b16b16; + bool sme_f16f16; + uint32_t svelen; +#endif + bool rdm; + bool fp16arith; + bool dot; + bool jscvt; + bool fcma; + bool fhm; + + bool aes; + bool sha1; + bool sha2; + bool pmull; + bool crc32; +}; + +extern struct cpuinfo_arm_isa cpuinfo_isa; +#endif + +static inline bool cpuinfo_has_arm_thumb(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.thumb; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_thumb2(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.thumb2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v5e(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv5e; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v6(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv6; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v6k(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv6k; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v7(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv7; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v7mp(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.armv7mp; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_v8(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.armv8; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_idiv(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.idiv; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv2(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3_d32(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.d32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3_fp16(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fp16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv3_fp16_d32(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fp16 && cpuinfo_isa.d32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv4(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_vfpv4_d32(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.vfpv3 && cpuinfo_isa.fma && cpuinfo_isa.d32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_fp16_arith(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fp16arith; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_bf16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_wmmx(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.wmmx; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_wmmx2(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.wmmx2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_fp16(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.fp16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_fma(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.fma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_v8(void) { +#if CPUINFO_ARCH_ARM64 + return true; +#elif CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.armv8; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_atomics(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.atomics; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_rdm(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.rdm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_fp16_arith(void) { +#if CPUINFO_ARCH_ARM + return cpuinfo_isa.neon && cpuinfo_isa.fp16arith; +#elif CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fp16arith; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_fhm(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fhm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_dot(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.dot; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_neon_bf16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_jscvt(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.jscvt; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_fcma(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.fcma; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_i8mm(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.i8mm; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_aes(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.aes; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sha1(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sha1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sha2(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sha2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_pmull(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.pmull; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_crc32(void) { +#if CPUINFO_ARCH_ARM || CPUINFO_ARCH_ARM64 + return cpuinfo_isa.crc32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sve(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sve; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sve_bf16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sve && cpuinfo_isa.bf16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sve2(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sve2; +#else + return false; +#endif +} + +// Function to get the max SVE vector length on ARM CPU's which support SVE. +static inline uint32_t cpuinfo_get_max_arm_sve_length(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.svelen * 8; // bytes * 8 = bit length(vector length) +#else + return 0; +#endif +} + +static inline bool cpuinfo_has_arm_sme(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sme2(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme2; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sme2p1(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme2p1; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sme_i16i32(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme_i16i32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sme_bi32i32(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme_bi32i32; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sme_b16b16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme_b16b16; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_arm_sme_f16f16(void) { +#if CPUINFO_ARCH_ARM64 + return cpuinfo_isa.sme_f16f16; +#else + return false; +#endif +} + +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 +/* This structure is not a part of stable API. Use cpuinfo_has_riscv_* functions + * instead. */ +struct cpuinfo_riscv_isa { + /** + * Keep fields in line with the canonical order as defined by + * Section 27.11 Subset Naming Convention. + */ + /* RV32I/64I/128I Base ISA. */ + bool i; +#if CPUINFO_ARCH_RISCV32 + /* RV32E Base ISA. */ + bool e; +#endif + /* Integer Multiply/Divide Extension. */ + bool m; + /* Atomic Extension. */ + bool a; + /* Single-Precision Floating-Point Extension. */ + bool f; + /* Double-Precision Floating-Point Extension. */ + bool d; + /* Compressed Extension. */ + bool c; + /* Vector Extension. */ + bool v; +}; + +extern struct cpuinfo_riscv_isa cpuinfo_isa; +#endif + +static inline bool cpuinfo_has_riscv_i(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.i; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_e(void) { +#if CPUINFO_ARCH_RISCV32 + return cpuinfo_isa.e; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_m(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.m; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_a(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.a; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_f(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.f; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_d(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.d; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_g(void) { + // The 'G' extension is simply shorthand for 'IMAFD'. + return cpuinfo_has_riscv_i() && cpuinfo_has_riscv_m() && cpuinfo_has_riscv_a() && cpuinfo_has_riscv_f() && + cpuinfo_has_riscv_d(); +} + +static inline bool cpuinfo_has_riscv_c(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.c; +#else + return false; +#endif +} + +static inline bool cpuinfo_has_riscv_v(void) { +#if CPUINFO_ARCH_RISCV32 || CPUINFO_ARCH_RISCV64 + return cpuinfo_isa.v; +#else + return false; +#endif +} + +const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processors(void); +const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_cores(void); +const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_clusters(void); +const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_packages(void); +const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarchs(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_caches(void); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_caches(void); + +const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_processor(uint32_t index); +const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_core(uint32_t index); +const struct cpuinfo_cluster* CPUINFO_ABI cpuinfo_get_cluster(uint32_t index); +const struct cpuinfo_package* CPUINFO_ABI cpuinfo_get_package(uint32_t index); +const struct cpuinfo_uarch_info* CPUINFO_ABI cpuinfo_get_uarch(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1i_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l1d_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l2_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l3_cache(uint32_t index); +const struct cpuinfo_cache* CPUINFO_ABI cpuinfo_get_l4_cache(uint32_t index); + +uint32_t CPUINFO_ABI cpuinfo_get_processors_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_cores_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_clusters_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_packages_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_uarchs_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l1i_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l1d_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l2_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l3_caches_count(void); +uint32_t CPUINFO_ABI cpuinfo_get_l4_caches_count(void); + +/** + * Returns upper bound on cache size. + */ +uint32_t CPUINFO_ABI cpuinfo_get_max_cache_size(void); + +/** + * Identify the logical processor that executes the current thread. + * + * There is no guarantee that the thread will stay on the same logical processor + * for any time. Callers should treat the result as only a hint, and be prepared + * to handle NULL return value. + */ +const struct cpuinfo_processor* CPUINFO_ABI cpuinfo_get_current_processor(void); + +/** + * Identify the core that executes the current thread. + * + * There is no guarantee that the thread will stay on the same core for any + * time. Callers should treat the result as only a hint, and be prepared to + * handle NULL return value. + */ +const struct cpuinfo_core* CPUINFO_ABI cpuinfo_get_current_core(void); + +/** + * Identify the microarchitecture index of the core that executes the current + * thread. If the system does not support such identification, the function + * returns 0. + * + * There is no guarantee that the thread will stay on the same type of core for + * any time. Callers should treat the result as only a hint. + */ +uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index(void); + +/** + * Identify the microarchitecture index of the core that executes the current + * thread. If the system does not support such identification, the function + * returns the user-specified default value. + * + * There is no guarantee that the thread will stay on the same type of core for + * any time. Callers should treat the result as only a hint. + */ +uint32_t CPUINFO_ABI cpuinfo_get_current_uarch_index_with_default(uint32_t default_uarch_index); + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#endif /* CPUINFO_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl.h b/phivenv/Lib/site-packages/torch/include/dnnl.h new file mode 100644 index 0000000000000000000000000000000000000000..e8ebd3b838642891a300fc77d88bea921a4cad73 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_H +#define DNNL_H + +#include "oneapi/dnnl/dnnl.h" + +#endif /* DNNL_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl.hpp b/phivenv/Lib/site-packages/torch/include/dnnl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..f564d0309711aee596bc5fca6c4746a7992e6a20 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl.hpp @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_HPP +#define DNNL_HPP + +#include "oneapi/dnnl/dnnl.hpp" + +#endif /* DNNL_HPP */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_config.h b/phivenv/Lib/site-packages/torch/include/dnnl_config.h new file mode 100644 index 0000000000000000000000000000000000000000..ebc23d438ff76b329065e710bee5dfbde5505c2c --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_config.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_CONFIG_H +#define DNNL_CONFIG_H + +#include "oneapi/dnnl/dnnl_config.h" + +#endif /* DNNL_CONFIG_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_debug.h b/phivenv/Lib/site-packages/torch/include/dnnl_debug.h new file mode 100644 index 0000000000000000000000000000000000000000..d650115f2ebde41d7e1a726e9164d31ad5bcdaf3 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_debug.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_DEBUG_H +#define DNNL_DEBUG_H + +#include "oneapi/dnnl/dnnl_debug.h" + +#endif /* DNNL_DEBUG_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_ocl.h b/phivenv/Lib/site-packages/torch/include/dnnl_ocl.h new file mode 100644 index 0000000000000000000000000000000000000000..b4554534de61413e6aaa3e4e38eb3721b7dcf910 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_ocl.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_OCL_H +#define DNNL_OCL_H + +#include "oneapi/dnnl/dnnl_ocl.h" + +#endif /* DNNL_OCL_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_ocl.hpp b/phivenv/Lib/site-packages/torch/include/dnnl_ocl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ed29f35ab88be897a76c6a33d19ac83d8661cf38 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_ocl.hpp @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_OCL_HPP +#define DNNL_OCL_HPP + +#include "oneapi/dnnl/dnnl_ocl.hpp" + +#endif /* DNNL_OCL_HPP */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_sycl.h b/phivenv/Lib/site-packages/torch/include/dnnl_sycl.h new file mode 100644 index 0000000000000000000000000000000000000000..3d029e0e17bf548fdeff9c1866eb0115d31674d7 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_sycl.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_SYCL_H +#define DNNL_SYCL_H + +#include "oneapi/dnnl/dnnl_sycl.h" + +#endif /* DNNL_SYCL_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_sycl.hpp b/phivenv/Lib/site-packages/torch/include/dnnl_sycl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..4606ae3149a8485fada6efdf698f029f6d9a6926 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_sycl.hpp @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_SYCL_HPP +#define DNNL_SYCL_HPP + +#include "oneapi/dnnl/dnnl_sycl.hpp" + +#endif /* DNNL_SYCL_HPP */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_sycl_types.h b/phivenv/Lib/site-packages/torch/include/dnnl_sycl_types.h new file mode 100644 index 0000000000000000000000000000000000000000..2f071ca6cfbd5c76d5ce1fbfd7ca6ee27dd354b0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_sycl_types.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_SYCL_TYPES_H +#define DNNL_SYCL_TYPES_H + +#include "oneapi/dnnl/dnnl_sycl_types.h" + +#endif /* DNNL_SYCL_TYPES_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_threadpool.h b/phivenv/Lib/site-packages/torch/include/dnnl_threadpool.h new file mode 100644 index 0000000000000000000000000000000000000000..5a189ac44d923aeae9bc4e448450bd5ae583356b --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_threadpool.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_THREADPOOL_H +#define DNNL_THREADPOOL_H + +#include "oneapi/dnnl/dnnl_threadpool.h" + +#endif /* DNNL_THREADPOOL_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_threadpool.hpp b/phivenv/Lib/site-packages/torch/include/dnnl_threadpool.hpp new file mode 100644 index 0000000000000000000000000000000000000000..554d8bb3f8cbd7ad18b5f7d992795c6d626838aa --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_threadpool.hpp @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_THREADPOOL_HPP +#define DNNL_THREADPOOL_HPP + +#include "oneapi/dnnl/dnnl_threadpool.hpp" + +#endif /* DNNL_THREADPOOL_HPP */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_threadpool_iface.hpp b/phivenv/Lib/site-packages/torch/include/dnnl_threadpool_iface.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c61a10eef52078f48372953b980776500e12a770 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_threadpool_iface.hpp @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_THREADPOOL_IFACE_HPP +#define DNNL_THREADPOOL_IFACE_HPP + +#include "oneapi/dnnl/dnnl_threadpool_iface.hpp" + +#endif /* DNNL_THREADPOOL_IFACE_HPP */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_types.h b/phivenv/Lib/site-packages/torch/include/dnnl_types.h new file mode 100644 index 0000000000000000000000000000000000000000..130d67e0b3c2175b11863cdc9e62d31b55f57d45 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_types.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_TYPES_H +#define DNNL_TYPES_H + +#include "oneapi/dnnl/dnnl_types.h" + +#endif /* DNNL_TYPES_H */ diff --git a/phivenv/Lib/site-packages/torch/include/dnnl_version.h b/phivenv/Lib/site-packages/torch/include/dnnl_version.h new file mode 100644 index 0000000000000000000000000000000000000000..956a253adbba299be9747a5fc32fe647b115ee89 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/dnnl_version.h @@ -0,0 +1,22 @@ +/******************************************************************************* +* Copyright 2020 Intel Corporation +* +* Licensed under the Apache License, Version 2.0 (the "License"); +* you may not use this file except in compliance with the License. +* You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*******************************************************************************/ + +#ifndef DNNL_VERSION_H +#define DNNL_VERSION_H + +#include "oneapi/dnnl/dnnl_version.h" + +#endif /* DNNL_VERSION_H */ diff --git a/phivenv/Lib/site-packages/torch/include/experiments-config.h b/phivenv/Lib/site-packages/torch/include/experiments-config.h new file mode 100644 index 0000000000000000000000000000000000000000..62281b179244e313b5106e098431f1a4f9774c02 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/experiments-config.h @@ -0,0 +1,25 @@ +// Copyright 2023 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct xnn_experiment_config { + int dummy; // C requires that a struct or union has at least one member +}; + +struct xnn_experiment_config* xnn_get_experiment_config(); + +void xnn_experiment_enable_adaptive_avx_optimization(); + + +#ifdef __cplusplus +} // extern "C" +#endif diff --git a/phivenv/Lib/site-packages/torch/include/fp16.h b/phivenv/Lib/site-packages/torch/include/fp16.h new file mode 100644 index 0000000000000000000000000000000000000000..84a01d40d9225b9fb1d5bcb25d845ee3471502ac --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/fp16.h @@ -0,0 +1,11 @@ +#pragma once +#ifndef FP16_H +#define FP16_H + +#include + +#if defined(PSIMD_H) +#include +#endif + +#endif /* FP16_H */ diff --git a/phivenv/Lib/site-packages/torch/include/fxdiv.h b/phivenv/Lib/site-packages/torch/include/fxdiv.h new file mode 100644 index 0000000000000000000000000000000000000000..f0bb47c45fcfa0ca4f41cb8fe08d6f4ce55a94af --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/fxdiv.h @@ -0,0 +1,425 @@ +#pragma once +#ifndef FXDIV_H +#define FXDIV_H + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include + #include + #include +#elif !defined(__OPENCL_VERSION__) + #include + #include + #include +#endif + +#if defined(_MSC_VER) + #include + #if defined(_M_IX86) || defined(_M_X64) + #include + #endif +#endif + +#ifndef FXDIV_USE_INLINE_ASSEMBLY + #define FXDIV_USE_INLINE_ASSEMBLY 0 +#endif + +static inline uint64_t fxdiv_mulext_uint32_t(uint32_t a, uint32_t b) { +#if defined(_MSC_VER) && defined(_M_IX86) + return (uint64_t) __emulu((unsigned int) a, (unsigned int) b); +#else + return (uint64_t) a * (uint64_t) b; +#endif +} + +static inline uint32_t fxdiv_mulhi_uint32_t(uint32_t a, uint32_t b) { +#if defined(__OPENCL_VERSION__) + return mul_hi(a, b); +#elif defined(__CUDA_ARCH__) + return (uint32_t) __umulhi((unsigned int) a, (unsigned int) b); +#elif defined(_MSC_VER) && defined(_M_IX86) + return (uint32_t) (__emulu((unsigned int) a, (unsigned int) b) >> 32); +#elif defined(_MSC_VER) && defined(_M_ARM) + return (uint32_t) _MulUnsignedHigh((unsigned long) a, (unsigned long) b); +#else + return (uint32_t) (((uint64_t) a * (uint64_t) b) >> 32); +#endif +} + +static inline uint64_t fxdiv_mulhi_uint64_t(uint64_t a, uint64_t b) { +#if defined(__OPENCL_VERSION__) + return mul_hi(a, b); +#elif defined(__CUDA_ARCH__) + return (uint64_t) __umul64hi((unsigned long long) a, (unsigned long long) b); +#elif defined(_MSC_VER) && defined(_M_X64) + return (uint64_t) __umulh((unsigned __int64) a, (unsigned __int64) b); +#elif defined(__GNUC__) && defined(__SIZEOF_INT128__) + return (uint64_t) (((((unsigned __int128) a) * ((unsigned __int128) b))) >> 64); +#else + const uint32_t a_lo = (uint32_t) a; + const uint32_t a_hi = (uint32_t) (a >> 32); + const uint32_t b_lo = (uint32_t) b; + const uint32_t b_hi = (uint32_t) (b >> 32); + + const uint64_t t = fxdiv_mulext_uint32_t(a_hi, b_lo) + + (uint64_t) fxdiv_mulhi_uint32_t(a_lo, b_lo); + return fxdiv_mulext_uint32_t(a_hi, b_hi) + (t >> 32) + + ((fxdiv_mulext_uint32_t(a_lo, b_hi) + (uint64_t) (uint32_t) t) >> 32); +#endif +} + +static inline size_t fxdiv_mulhi_size_t(size_t a, size_t b) { +#if SIZE_MAX == UINT32_MAX + return (size_t) fxdiv_mulhi_uint32_t((uint32_t) a, (uint32_t) b); +#elif SIZE_MAX == UINT64_MAX + return (size_t) fxdiv_mulhi_uint64_t((uint64_t) a, (uint64_t) b); +#else + #error Unsupported platform +#endif +} + +struct fxdiv_divisor_uint32_t { + uint32_t value; + uint32_t m; + uint8_t s1; + uint8_t s2; +}; + +struct fxdiv_result_uint32_t { + uint32_t quotient; + uint32_t remainder; +}; + +struct fxdiv_divisor_uint64_t { + uint64_t value; + uint64_t m; + uint8_t s1; + uint8_t s2; +}; + +struct fxdiv_result_uint64_t { + uint64_t quotient; + uint64_t remainder; +}; + +struct fxdiv_divisor_size_t { + size_t value; + size_t m; + uint8_t s1; + uint8_t s2; +}; + +struct fxdiv_result_size_t { + size_t quotient; + size_t remainder; +}; + +static inline struct fxdiv_divisor_uint32_t fxdiv_init_uint32_t(uint32_t d) { + struct fxdiv_divisor_uint32_t result = { d }; + if (d == 1) { + result.m = UINT32_C(1); + result.s1 = 0; + result.s2 = 0; + } else { + #if defined(__OPENCL_VERSION__) + const uint32_t l_minus_1 = 31 - clz(d - 1); + #elif defined(__CUDA_ARCH__) + const uint32_t l_minus_1 = 31 - __clz((int) (d - 1)); + #elif defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_X64) || defined(_M_ARM) || defined(_M_ARM64)) + unsigned long l_minus_1; + _BitScanReverse(&l_minus_1, (unsigned long) (d - 1)); + #elif defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__)) && FXDIV_USE_INLINE_ASSEMBLY + uint32_t l_minus_1; + __asm__("BSRL %[d_minus_1], %[l_minus_1]" + : [l_minus_1] "=r" (l_minus_1) + : [d_minus_1] "r" (d - 1) + : "cc"); + #elif defined(__GNUC__) + const uint32_t l_minus_1 = 31 - __builtin_clz(d - 1); + #else + /* Based on Algorithm 2 from Hacker's delight */ + + uint32_t l_minus_1 = 0; + uint32_t x = d - 1; + uint32_t y = x >> 16; + if (y != 0) { + l_minus_1 += 16; + x = y; + } + y = x >> 8; + if (y != 0) { + l_minus_1 += 8; + x = y; + } + y = x >> 4; + if (y != 0) { + l_minus_1 += 4; + x = y; + } + y = x >> 2; + if (y != 0) { + l_minus_1 += 2; + x = y; + } + if ((x & 2) != 0) { + l_minus_1 += 1; + } + #endif + uint32_t u_hi = (UINT32_C(2) << (uint32_t) l_minus_1) - d; + + /* Division of 64-bit number u_hi:UINT32_C(0) by 32-bit number d, 32-bit quotient output q */ + #if defined(__GNUC__) && defined(__i386__) && FXDIV_USE_INLINE_ASSEMBLY + uint32_t q; + __asm__("DIVL %[d]" + : "=a" (q), "+d" (u_hi) + : [d] "r" (d), "a" (0) + : "cc"); + #elif (defined(_MSC_VER) && _MSC_VER >= 1920) && !defined(__clang__) && !defined(__INTEL_COMPILER) && (defined(_M_IX86) || defined(_M_X64)) + unsigned int remainder; + const uint32_t q = (uint32_t) _udiv64((unsigned __int64) ((uint64_t) u_hi << 32), (unsigned int) d, &remainder); + #else + const uint32_t q = ((uint64_t) u_hi << 32) / d; + #endif + + result.m = q + UINT32_C(1); + result.s1 = 1; + result.s2 = (uint8_t) l_minus_1; + } + return result; +} + +static inline struct fxdiv_divisor_uint64_t fxdiv_init_uint64_t(uint64_t d) { + struct fxdiv_divisor_uint64_t result = { d }; + if (d == 1) { + result.m = UINT64_C(1); + result.s1 = 0; + result.s2 = 0; + } else { + #if defined(__OPENCL_VERSION__) + const uint32_t nlz_d = clz(d); + const uint32_t l_minus_1 = 63 - clz(d - 1); + #elif defined(__CUDA_ARCH__) + const uint32_t nlz_d = __clzll((long long) d); + const uint32_t l_minus_1 = 63 - __clzll((long long) (d - 1)); + #elif defined(_MSC_VER) && (defined(_M_X64) || defined(_M_ARM64)) + unsigned long l_minus_1; + _BitScanReverse64(&l_minus_1, (unsigned __int64) (d - 1)); + unsigned long bsr_d; + _BitScanReverse64(&bsr_d, (unsigned __int64) d); + const uint32_t nlz_d = bsr_d ^ 0x3F; + #elif defined(_MSC_VER) && (defined(_M_IX86) || defined(_M_ARM)) + const uint64_t d_minus_1 = d - 1; + const uint8_t d_is_power_of_2 = (d & d_minus_1) == 0; + unsigned long l_minus_1; + if ((uint32_t) (d_minus_1 >> 32) == 0) { + _BitScanReverse(&l_minus_1, (unsigned long) d_minus_1); + } else { + _BitScanReverse(&l_minus_1, (unsigned long) (uint32_t) (d_minus_1 >> 32)); + l_minus_1 += 32; + } + const uint32_t nlz_d = ((uint8_t) l_minus_1 ^ UINT8_C(0x3F)) - d_is_power_of_2; + #elif defined(__GNUC__) && defined(__x86_64__) && FXDIV_USE_INLINE_ASSEMBLY + uint64_t l_minus_1; + __asm__("BSRQ %[d_minus_1], %[l_minus_1]" + : [l_minus_1] "=r" (l_minus_1) + : [d_minus_1] "r" (d - 1) + : "cc"); + #elif defined(__GNUC__) + const uint32_t l_minus_1 = 63 - __builtin_clzll(d - 1); + const uint32_t nlz_d = __builtin_clzll(d); + #else + /* Based on Algorithm 2 from Hacker's delight */ + const uint64_t d_minus_1 = d - 1; + const uint32_t d_is_power_of_2 = (d & d_minus_1) == 0; + uint32_t l_minus_1 = 0; + uint32_t x = (uint32_t) d_minus_1; + uint32_t y = d_minus_1 >> 32; + if (y != 0) { + l_minus_1 += 32; + x = y; + } + y = x >> 16; + if (y != 0) { + l_minus_1 += 16; + x = y; + } + y = x >> 8; + if (y != 0) { + l_minus_1 += 8; + x = y; + } + y = x >> 4; + if (y != 0) { + l_minus_1 += 4; + x = y; + } + y = x >> 2; + if (y != 0) { + l_minus_1 += 2; + x = y; + } + if ((x & 2) != 0) { + l_minus_1 += 1; + } + const uint32_t nlz_d = (l_minus_1 ^ UINT32_C(0x3F)) - d_is_power_of_2; + #endif + uint64_t u_hi = (UINT64_C(2) << (uint32_t) l_minus_1) - d; + + /* Division of 128-bit number u_hi:UINT64_C(0) by 64-bit number d, 64-bit quotient output q */ + #if defined(__GNUC__) && defined(__x86_64__) && FXDIV_USE_INLINE_ASSEMBLY + uint64_t q; + __asm__("DIVQ %[d]" + : "=a" (q), "+d" (u_hi) + : [d] "r" (d), "a" (UINT64_C(0)) + : "cc"); + #elif 0 && defined(__GNUC__) && defined(__SIZEOF_INT128__) + /* GCC, Clang, and Intel Compiler fail to inline optimized implementation and call into support library for 128-bit division */ + const uint64_t q = (uint64_t) (((unsigned __int128) u_hi << 64) / ((unsigned __int128) d)); + #elif (defined(_MSC_VER) && _MSC_VER >= 1920) && !defined(__clang__) && !defined(__INTEL_COMPILER) && defined(_M_X64) + unsigned __int64 remainder; + const uint64_t q = (uint64_t) _udiv128((unsigned __int64) u_hi, 0, (unsigned __int64) d, &remainder); + #else + /* Implementation based on code from Hacker's delight */ + + /* Normalize divisor and shift divident left */ + d <<= nlz_d; + u_hi <<= nlz_d; + /* Break divisor up into two 32-bit digits */ + const uint64_t d_hi = (uint32_t) (d >> 32); + const uint32_t d_lo = (uint32_t) d; + + /* Compute the first quotient digit, q1 */ + uint64_t q1 = u_hi / d_hi; + uint64_t r1 = u_hi - q1 * d_hi; + + while ((q1 >> 32) != 0 || fxdiv_mulext_uint32_t((uint32_t) q1, d_lo) > (r1 << 32)) { + q1 -= 1; + r1 += d_hi; + if ((r1 >> 32) != 0) { + break; + } + } + + /* Multiply and subtract. */ + u_hi = (u_hi << 32) - q1 * d; + + /* Compute the second quotient digit, q0 */ + uint64_t q0 = u_hi / d_hi; + uint64_t r0 = u_hi - q0 * d_hi; + + while ((q0 >> 32) != 0 || fxdiv_mulext_uint32_t((uint32_t) q0, d_lo) > (r0 << 32)) { + q0 -= 1; + r0 += d_hi; + if ((r0 >> 32) != 0) { + break; + } + } + const uint64_t q = (q1 << 32) | (uint32_t) q0; + #endif + result.m = q + UINT64_C(1); + result.s1 = 1; + result.s2 = (uint8_t) l_minus_1; + } + return result; +} + +static inline struct fxdiv_divisor_size_t fxdiv_init_size_t(size_t d) { +#if SIZE_MAX == UINT32_MAX + const struct fxdiv_divisor_uint32_t uint_result = fxdiv_init_uint32_t((uint32_t) d); +#elif SIZE_MAX == UINT64_MAX + const struct fxdiv_divisor_uint64_t uint_result = fxdiv_init_uint64_t((uint64_t) d); +#else + #error Unsupported platform +#endif + struct fxdiv_divisor_size_t size_result = { + (size_t) uint_result.value, + (size_t) uint_result.m, + uint_result.s1, + uint_result.s2 + }; + return size_result; +} + +static inline uint32_t fxdiv_quotient_uint32_t(uint32_t n, const struct fxdiv_divisor_uint32_t divisor) { + const uint32_t t = fxdiv_mulhi_uint32_t(n, divisor.m); + return (t + ((n - t) >> divisor.s1)) >> divisor.s2; +} + +static inline uint64_t fxdiv_quotient_uint64_t(uint64_t n, const struct fxdiv_divisor_uint64_t divisor) { + const uint64_t t = fxdiv_mulhi_uint64_t(n, divisor.m); + return (t + ((n - t) >> divisor.s1)) >> divisor.s2; +} + +static inline size_t fxdiv_quotient_size_t(size_t n, const struct fxdiv_divisor_size_t divisor) { +#if SIZE_MAX == UINT32_MAX + const struct fxdiv_divisor_uint32_t uint32_divisor = { + (uint32_t) divisor.value, + (uint32_t) divisor.m, + divisor.s1, + divisor.s2 + }; + return fxdiv_quotient_uint32_t((uint32_t) n, uint32_divisor); +#elif SIZE_MAX == UINT64_MAX + const struct fxdiv_divisor_uint64_t uint64_divisor = { + (uint64_t) divisor.value, + (uint64_t) divisor.m, + divisor.s1, + divisor.s2 + }; + return fxdiv_quotient_uint64_t((uint64_t) n, uint64_divisor); +#else + #error Unsupported platform +#endif +} + +static inline uint32_t fxdiv_remainder_uint32_t(uint32_t n, const struct fxdiv_divisor_uint32_t divisor) { + const uint32_t quotient = fxdiv_quotient_uint32_t(n, divisor); + return n - quotient * divisor.value; +} + +static inline uint64_t fxdiv_remainder_uint64_t(uint64_t n, const struct fxdiv_divisor_uint64_t divisor) { + const uint64_t quotient = fxdiv_quotient_uint64_t(n, divisor); + return n - quotient * divisor.value; +} + +static inline size_t fxdiv_remainder_size_t(size_t n, const struct fxdiv_divisor_size_t divisor) { + const size_t quotient = fxdiv_quotient_size_t(n, divisor); + return n - quotient * divisor.value; +} + +static inline uint32_t fxdiv_round_down_uint32_t(uint32_t n, const struct fxdiv_divisor_uint32_t granularity) { + const uint32_t quotient = fxdiv_quotient_uint32_t(n, granularity); + return quotient * granularity.value; +} + +static inline uint64_t fxdiv_round_down_uint64_t(uint64_t n, const struct fxdiv_divisor_uint64_t granularity) { + const uint64_t quotient = fxdiv_quotient_uint64_t(n, granularity); + return quotient * granularity.value; +} + +static inline size_t fxdiv_round_down_size_t(size_t n, const struct fxdiv_divisor_size_t granularity) { + const size_t quotient = fxdiv_quotient_size_t(n, granularity); + return quotient * granularity.value; +} + +static inline struct fxdiv_result_uint32_t fxdiv_divide_uint32_t(uint32_t n, const struct fxdiv_divisor_uint32_t divisor) { + const uint32_t quotient = fxdiv_quotient_uint32_t(n, divisor); + const uint32_t remainder = n - quotient * divisor.value; + struct fxdiv_result_uint32_t result = { quotient, remainder }; + return result; +} + +static inline struct fxdiv_result_uint64_t fxdiv_divide_uint64_t(uint64_t n, const struct fxdiv_divisor_uint64_t divisor) { + const uint64_t quotient = fxdiv_quotient_uint64_t(n, divisor); + const uint64_t remainder = n - quotient * divisor.value; + struct fxdiv_result_uint64_t result = { quotient, remainder }; + return result; +} + +static inline struct fxdiv_result_size_t fxdiv_divide_size_t(size_t n, const struct fxdiv_divisor_size_t divisor) { + const size_t quotient = fxdiv_quotient_size_t(n, divisor); + const size_t remainder = n - quotient * divisor.value; + struct fxdiv_result_size_t result = { quotient, remainder }; + return result; +} + +#endif /* FXDIV_H */ diff --git a/phivenv/Lib/site-packages/torch/include/ittnotify-zca.h b/phivenv/Lib/site-packages/torch/include/ittnotify-zca.h new file mode 100644 index 0000000000000000000000000000000000000000..9da96394c4a6cca537f36252b44bb4e905501ee0 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ittnotify-zca.h @@ -0,0 +1,81 @@ +/* + Copyright (C) 2005-2019 Intel Corporation + + SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause +*/ + +/** + * Zero Cost Annotations (ZCA) + * + * Intel Compiler supports two intrinsics that could be used for code annotations + * without incurring significant run-time costs when the tools are not in use. + * Each annotation is more than a mere mark in the instruction stream. + * It can accept an expression argument like a call to a routine does. + * There are two forms of the intrinsic, with the following signatures: + * + * extern "C" void __notify_intrinsic( const char *annotation, const volatile void *tag); + * extern "C" void __notify_zc_intrinsic(const char *annotation, const volatile void *tag); + * + * The string annotation must be a compile-time constant. It specifies the type of the annotation. + * The pointer tag is computed at run time. It specifies the data associated with the annotation. + * Each intrinsic implies a compiler fence: the compiler must not move any memory + * operation across it. The reason for this restriction is that annotation might denote an + * event that must be precisely placed with respect to memory operations. + * + * The difference between the two intrinsics is that __notify_intrinsic must leave a + * probe-ready instruction sequence in the instruction stream where the instrinsic + * occurs. The __notify_zc_intrinsic does not leave such a sequence, and hence is closer to "zero cost". + **/ + +#pragma once +#include "ittnotify.h" + +#ifndef INTEL_NO_ITTNOTIFY_API +#if (defined(__INTEL_COMPILER) || defined(__INTEL_LLVM_COMPILER)) && (ITT_PLATFORM == ITT_PLATFORM_WIN || ITT_PLATFORM == ITT_PLATFORM_POSIX) +#define ITT_ENABLE_LOW_OVERHEAD_ANNOTATIONS +#else +#error Zero cost (low overhead) annotations are not supported on this platform +#endif +#endif + +/** + * Zero cost annotations for memory allocation and deallocation + **/ +#ifdef ITT_ENABLE_LOW_OVERHEAD_ANNOTATIONS +#pragma pack(push, 1) +typedef struct ___itt_zca_allocation_info { + size_t size; /*!< Size of allocated memory */ + void** ptr; /*!< Pointer to allocated memory pointer */ + int initialized; /*!< Is allocated memory initialized */ +} __itt_zca_allocation_info; +#pragma pack(pop) + +#define __itt_zca_mem_allocate_begin() __notify_intrinsic((char*)"mem_allocate_begin", 0) +#define __itt_zca_mem_allocate_end(ptr, size, init) { __itt_zca_allocation_info __itt_zca_alloc_info = { size, ptr, init }; __notify_intrinsic((char*)"mem_allocate_end", (void*)&__itt_zca_alloc_info); } +#define __itt_zca_mem_free_begin(ptr) __notify_intrinsic((char*)"mem_free_begin", (void*)ptr) +#define __itt_zca_mem_free_end() __notify_intrinsic((char*)"mem_free_end", 0) +#else +#define __itt_zca_mem_allocate_begin() +#define __itt_zca_mem_allocate_end(ptr, size, init) +#define __itt_zca_mem_free_begin(ptr) +#define __itt_zca_mem_free_end() +#endif + +/** + * Zero cost annotations for threading + **/ +#ifdef ITT_ENABLE_LOW_OVERHEAD_ANNOTATIONS +#define __itt_zca_suppress_push(id) __notify_zc_intrinsic((char*)"__itt_suppress_push", (void*)id); +#define __itt_zca_suppress_pop(id) __notify_zc_intrinsic((char*)"__itt_suppress_pop", (void*)id); +#define __itt_zca_sync_create(id) __notify_zc_intrinsic((char*)"__itt_sync_create", (void*)id) +#define __itt_zca_sync_acquired(id) __notify_zc_intrinsic((char*)"__itt_sync_acquired", (void*)id) +#define __itt_zca_sync_releasing(id) __notify_zc_intrinsic((char*)"__itt_sync_releasing", (void*)id) +#define __itt_zca_sync_destroy(id) __notify_zc_intrinsic((char*)"__itt_sync_destroy", (void*)id) +#else +#define __itt_zca_suppress_push(id) +#define __itt_zca_suppress_pop(id) +#define __itt_zca_sync_create(id) +#define __itt_zca_sync_acquired(id) +#define __itt_zca_sync_releasing(id) +#define __itt_zca_sync_destroy(id) +#endif diff --git a/phivenv/Lib/site-packages/torch/include/ittnotify.h b/phivenv/Lib/site-packages/torch/include/ittnotify.h new file mode 100644 index 0000000000000000000000000000000000000000..0418be309aedac94482631910a7e9cea9a92f22d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/ittnotify.h @@ -0,0 +1,4665 @@ +/* + Copyright (C) 2005-2019 Intel Corporation + + SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause +*/ +#ifndef _ITTNOTIFY_H_ +#define _ITTNOTIFY_H_ + +/** +@file +@brief Public User API functions and types +@mainpage + +The Instrumentation and Tracing Technology API (ITT API) is used to +annotate a user's program with additional information +that can be used by correctness and performance tools. The user inserts +calls in their program. Those calls generate information that is collected +at runtime, and used by Intel(R) Threading Tools. + +@section API Concepts +The following general concepts are used throughout the API. + +@subsection Unicode Support +Many API functions take character string arguments. On Windows, there +are two versions of each such function. The function name is suffixed +by W if Unicode support is enabled, and by A otherwise. Any API function +that takes a character string argument adheres to this convention. + +@subsection Conditional Compilation +Many users prefer having an option to modify ITT API code when linking it +inside their runtimes. ITT API header file provides a mechanism to replace +ITT API function names inside your code with empty strings. To do this, +define the macros INTEL_NO_ITTNOTIFY_API during compilation and remove the +static library from the linker script. + +@subsection Domains +[see domains] +Domains provide a way to separate notification for different modules or +libraries in a program. Domains are specified by dotted character strings, +e.g. TBB.Internal.Control. + +A mechanism (to be specified) is provided to enable and disable +domains. By default, all domains are enabled. +@subsection Named Entities and Instances +Named entities (frames, regions, tasks, and markers) communicate +information about the program to the analysis tools. A named entity often +refers to a section of program code, or to some set of logical concepts +that the programmer wants to group together. + +Named entities relate to the programmer's static view of the program. When +the program actually executes, many instances of a given named entity +may be created. + +The API annotations denote instances of named entities. The actual +named entities are displayed using the analysis tools. In other words, +the named entities come into existence when instances are created. + +Instances of named entities may have instance identifiers (IDs). Some +API calls use instance identifiers to create relationships between +different instances of named entities. Other API calls associate data +with instances of named entities. + +Some named entities must always have instance IDs. In particular, regions +and frames always have IDs. Task and markers need IDs only if the ID is +needed in another API call (such as adding a relation or metadata). + +The lifetime of instance IDs is distinct from the lifetime of +instances. This allows various relationships to be specified separate +from the actual execution of instances. This flexibility comes at the +expense of extra API calls. + +The same ID may not be reused for different instances, unless a previous +[ref] __itt_id_destroy call for that ID has been issued. +*/ + +/** @cond exclude_from_documentation */ +#ifndef ITT_OS_WIN +# define ITT_OS_WIN 1 +#endif /* ITT_OS_WIN */ + +#ifndef ITT_OS_LINUX +# define ITT_OS_LINUX 2 +#endif /* ITT_OS_LINUX */ + +#ifndef ITT_OS_MAC +# define ITT_OS_MAC 3 +#endif /* ITT_OS_MAC */ + +#ifndef ITT_OS_FREEBSD +# define ITT_OS_FREEBSD 4 +#endif /* ITT_OS_FREEBSD */ + +#ifndef ITT_OS_OPENBSD +# define ITT_OS_OPENBSD 5 +#endif /* ITT_OS_OPENBSD */ + +#ifndef ITT_OS +# if defined WIN32 || defined _WIN32 +# define ITT_OS ITT_OS_WIN +# elif defined( __APPLE__ ) && defined( __MACH__ ) +# define ITT_OS ITT_OS_MAC +# elif defined( __FreeBSD__ ) +# define ITT_OS ITT_OS_FREEBSD +# elif defined( __OpenBSD__) +# define ITT_OS ITT_OS_OPENBSD +# else +# define ITT_OS ITT_OS_LINUX +# endif +#endif /* ITT_OS */ + +#ifndef ITT_PLATFORM_WIN +# define ITT_PLATFORM_WIN 1 +#endif /* ITT_PLATFORM_WIN */ + +#ifndef ITT_PLATFORM_POSIX +# define ITT_PLATFORM_POSIX 2 +#endif /* ITT_PLATFORM_POSIX */ + +#ifndef ITT_PLATFORM_MAC +# define ITT_PLATFORM_MAC 3 +#endif /* ITT_PLATFORM_MAC */ + +#ifndef ITT_PLATFORM_FREEBSD +# define ITT_PLATFORM_FREEBSD 4 +#endif /* ITT_PLATFORM_FREEBSD */ + +#ifndef ITT_PLATFORM_OPENBSD +# define ITT_PLATFORM_OPENBSD 5 +#endif /* ITT_PLATFORM_OPENBSD */ + +#ifndef ITT_PLATFORM +# if ITT_OS==ITT_OS_WIN +# define ITT_PLATFORM ITT_PLATFORM_WIN +# elif ITT_OS==ITT_OS_MAC +# define ITT_PLATFORM ITT_PLATFORM_MAC +# elif ITT_OS==ITT_OS_FREEBSD +# define ITT_PLATFORM ITT_PLATFORM_FREEBSD +# elif ITT_OS==ITT_OS_OPENBSD +# define ITT_PLATFORM ITT_PLATFORM_OPENBSD +# else +# define ITT_PLATFORM ITT_PLATFORM_POSIX +# endif +#endif /* ITT_PLATFORM */ + +#if defined(_UNICODE) && !defined(UNICODE) +#define UNICODE +#endif + +#include +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#include +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#include +#if defined(UNICODE) || defined(_UNICODE) +#include +#endif /* UNICODE || _UNICODE */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef ITTAPI_CDECL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define ITTAPI_CDECL __cdecl +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define ITTAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define ITTAPI_CDECL /* actual only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* ITTAPI_CDECL */ + +#ifndef STDCALL +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# define STDCALL __stdcall +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# if defined _M_IX86 || defined __i386__ +# define STDCALL __attribute__ ((stdcall)) +# else /* _M_IX86 || __i386__ */ +# define STDCALL /* supported only on x86 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* STDCALL */ + +#define ITTAPI ITTAPI_CDECL +#define LIBITTAPI ITTAPI_CDECL + +/* TODO: Temporary for compatibility! */ +#define ITTAPI_CALL ITTAPI_CDECL +#define LIBITTAPI_CALL ITTAPI_CDECL + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +/* use __forceinline (VC++ specific) */ +#if defined(__MINGW32__) && !defined(__cplusplus) +#define ITT_INLINE static __inline__ __attribute__((__always_inline__,__gnu_inline__)) +#else +#define ITT_INLINE static __forceinline +#endif /* __MINGW32__ */ + +#define ITT_INLINE_ATTRIBUTE /* nothing */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/* + * Generally, functions are not inlined unless optimization is specified. + * For functions declared inline, this attribute inlines the function even + * if no optimization level was specified. + */ +#ifdef __STRICT_ANSI__ +#define ITT_INLINE static +#define ITT_INLINE_ATTRIBUTE __attribute__((unused)) +#else /* __STRICT_ANSI__ */ +#define ITT_INLINE static inline +#define ITT_INLINE_ATTRIBUTE __attribute__((always_inline, unused)) +#endif /* __STRICT_ANSI__ */ +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +/** @endcond */ + +#ifdef INTEL_ITTNOTIFY_ENABLE_LEGACY +# if ITT_PLATFORM==ITT_PLATFORM_WIN +# pragma message("WARNING!!! Deprecated API is used. Please undefine INTEL_ITTNOTIFY_ENABLE_LEGACY macro") +# else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# warning "Deprecated API is used. Please undefine INTEL_ITTNOTIFY_ENABLE_LEGACY macro" +# endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +# include "legacy/ittnotify.h" +#endif /* INTEL_ITTNOTIFY_ENABLE_LEGACY */ + +/** @cond exclude_from_documentation */ +/* Helper macro for joining tokens */ +#define ITT_JOIN_AUX(p,n) p##n +#define ITT_JOIN(p,n) ITT_JOIN_AUX(p,n) + +#ifdef ITT_MAJOR +#undef ITT_MAJOR +#endif +#ifdef ITT_MINOR +#undef ITT_MINOR +#endif +#define ITT_MAJOR 3 +#define ITT_MINOR 0 + +/* Standard versioning of a token with major and minor version numbers */ +#define ITT_VERSIONIZE(x) \ + ITT_JOIN(x, \ + ITT_JOIN(_, \ + ITT_JOIN(ITT_MAJOR, \ + ITT_JOIN(_, ITT_MINOR)))) + +#ifndef INTEL_ITTNOTIFY_PREFIX +# define INTEL_ITTNOTIFY_PREFIX __itt_ +#endif /* INTEL_ITTNOTIFY_PREFIX */ +#ifndef INTEL_ITTNOTIFY_POSTFIX +# define INTEL_ITTNOTIFY_POSTFIX _ptr_ +#endif /* INTEL_ITTNOTIFY_POSTFIX */ + +#define ITTNOTIFY_NAME_AUX(n) ITT_JOIN(INTEL_ITTNOTIFY_PREFIX,n) +#define ITTNOTIFY_NAME(n) ITT_VERSIONIZE(ITTNOTIFY_NAME_AUX(ITT_JOIN(n,INTEL_ITTNOTIFY_POSTFIX))) + +#define ITTNOTIFY_VOID(n) (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n) +#define ITTNOTIFY_DATA(n) (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n) + +#define ITTNOTIFY_VOID_D0(n,d) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d) +#define ITTNOTIFY_VOID_D1(n,d,x) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x) +#define ITTNOTIFY_VOID_D2(n,d,x,y) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y) +#define ITTNOTIFY_VOID_D3(n,d,x,y,z) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z) +#define ITTNOTIFY_VOID_D4(n,d,x,y,z,a) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) +#define ITTNOTIFY_VOID_D5(n,d,x,y,z,a,b) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) +#define ITTNOTIFY_VOID_D6(n,d,x,y,z,a,b,c) (d == NULL) ? (void)0 : (!(d)->flags) ? (void)0 : (!ITTNOTIFY_NAME(n)) ? (void)0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) +#define ITTNOTIFY_DATA_D0(n,d) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d) +#define ITTNOTIFY_DATA_D1(n,d,x) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x) +#define ITTNOTIFY_DATA_D2(n,d,x,y) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y) +#define ITTNOTIFY_DATA_D3(n,d,x,y,z) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z) +#define ITTNOTIFY_DATA_D4(n,d,x,y,z,a) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a) +#define ITTNOTIFY_DATA_D5(n,d,x,y,z,a,b) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b) +#define ITTNOTIFY_DATA_D6(n,d,x,y,z,a,b,c) (d == NULL) ? 0 : (!(d)->flags) ? 0 : (!ITTNOTIFY_NAME(n)) ? 0 : ITTNOTIFY_NAME(n)(d,x,y,z,a,b,c) + +#ifdef ITT_STUB +#undef ITT_STUB +#endif +#ifdef ITT_STUBV +#undef ITT_STUBV +#endif +#define ITT_STUBV(api,type,name,args) \ + typedef type (api* ITT_JOIN(ITTNOTIFY_NAME(name),_t)) args; \ + extern ITT_JOIN(ITTNOTIFY_NAME(name),_t) ITTNOTIFY_NAME(name); +#define ITT_STUB ITT_STUBV +/** @endcond */ + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** @cond exclude_from_gpa_documentation */ +/** + * @defgroup public Public API + * @{ + * @} + */ + +/** + * @defgroup control Collection Control + * @ingroup public + * General behavior: application continues to run, but no profiling information is being collected + * + * Pausing occurs not only for the current thread but for all process as well as spawned processes + * - Intel(R) Parallel Inspector and Intel(R) Inspector XE: + * - Does not analyze or report errors that involve memory access. + * - Other errors are reported as usual. Pausing data collection in + * Intel(R) Parallel Inspector and Intel(R) Inspector XE + * only pauses tracing and analyzing memory access. + * It does not pause tracing or analyzing threading APIs. + * . + * - Intel(R) VTune(TM) Profiler: + * - Does continue to record when new threads are started. + * . + * - Other effects: + * - Possible reduction of runtime overhead. + * . + * @{ + */ +/** @brief Pause collection */ +void ITTAPI __itt_pause(void); +/** @brief Resume collection */ +void ITTAPI __itt_resume(void); +/** @brief Detach collection */ +void ITTAPI __itt_detach(void); + +/** + * @enum __itt_collection_scope + * @brief Enumerator for collection scopes + */ +typedef enum { + __itt_collection_scope_host = 1 << 0, + __itt_collection_scope_offload = 1 << 1, + __itt_collection_scope_all = 0x7FFFFFFF +} __itt_collection_scope; + +/** @brief Pause scoped collection */ +void ITTAPI __itt_pause_scoped(__itt_collection_scope); +/** @brief Resume scoped collection */ +void ITTAPI __itt_resume_scoped(__itt_collection_scope); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, pause, (void)) +ITT_STUBV(ITTAPI, void, pause_scoped, (__itt_collection_scope)) +ITT_STUBV(ITTAPI, void, resume, (void)) +ITT_STUBV(ITTAPI, void, resume_scoped, (__itt_collection_scope)) +ITT_STUBV(ITTAPI, void, detach, (void)) +#define __itt_pause ITTNOTIFY_VOID(pause) +#define __itt_pause_ptr ITTNOTIFY_NAME(pause) +#define __itt_pause_scoped ITTNOTIFY_VOID(pause_scoped) +#define __itt_pause_scoped_ptr ITTNOTIFY_NAME(pause_scoped) +#define __itt_resume ITTNOTIFY_VOID(resume) +#define __itt_resume_ptr ITTNOTIFY_NAME(resume) +#define __itt_resume_scoped ITTNOTIFY_VOID(resume_scoped) +#define __itt_resume_scoped_ptr ITTNOTIFY_NAME(resume_scoped) +#define __itt_detach ITTNOTIFY_VOID(detach) +#define __itt_detach_ptr ITTNOTIFY_NAME(detach) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_pause() +#define __itt_pause_ptr 0 +#define __itt_pause_scoped(scope) +#define __itt_pause_scoped_ptr 0 +#define __itt_resume() +#define __itt_resume_ptr 0 +#define __itt_resume_scoped(scope) +#define __itt_resume_scoped_ptr 0 +#define __itt_detach() +#define __itt_detach_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_pause_ptr 0 +#define __itt_pause_scoped_ptr 0 +#define __itt_resume_ptr 0 +#define __itt_resume_scoped_ptr 0 +#define __itt_detach_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} control group */ +/** @endcond */ + +/** + * @defgroup Intel Processor Trace control + * API from this group provides control over collection and analysis of Intel Processor Trace (Intel PT) data + * Information about Intel Processor Trace technology can be found here (Volume 3 chapter 35): + * https://software.intel.com/sites/default/files/managed/39/c5/325462-sdm-vol-1-2abcd-3abcd.pdf + * Use this API to mark particular code regions for loading detailed performance statistics. + * This mode makes your analysis faster and more accurate. + * @{ +*/ +typedef unsigned char __itt_pt_region; + +/** + * @brief function saves a region name marked with Intel PT API and returns a region id. + * Only 7 names can be registered. Attempts to register more names will be ignored and a region id with auto names will be returned. + * For automatic naming of regions pass NULL as function parameter +*/ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_pt_region ITTAPI __itt_pt_region_createA(const char *name); +__itt_pt_region ITTAPI __itt_pt_region_createW(const wchar_t *name); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_pt_region_create __itt_pt_region_createW +#else /* UNICODE */ +# define __itt_pt_region_create __itt_pt_region_createA +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_pt_region ITTAPI __itt_pt_region_create(const char *name); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_pt_region, pt_region_createA, (const char *name)) +ITT_STUB(ITTAPI, __itt_pt_region, pt_region_createW, (const wchar_t *name)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_pt_region, pt_region_create, (const char *name)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_pt_region_createA ITTNOTIFY_DATA(pt_region_createA) +#define __itt_pt_region_createA_ptr ITTNOTIFY_NAME(pt_region_createA) +#define __itt_pt_region_createW ITTNOTIFY_DATA(pt_region_createW) +#define __itt_pt_region_createW_ptr ITTNOTIFY_NAME(pt_region_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_pt_region_create ITTNOTIFY_DATA(pt_region_create) +#define __itt_pt_region_create_ptr ITTNOTIFY_NAME(pt_region_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_pt_region_createA(name) (__itt_pt_region)0 +#define __itt_pt_region_createA_ptr 0 +#define __itt_pt_region_createW(name) (__itt_pt_region)0 +#define __itt_pt_region_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_pt_region_create(name) (__itt_pt_region)0 +#define __itt_pt_region_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_pt_region_createA_ptr 0 +#define __itt_pt_region_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_pt_region_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief function contains a special code pattern identified on the post-processing stage and + * marks the beginning of a code region targeted for Intel PT analysis + * @param[in] region - region id, 0 <= region < 8 +*/ +void __itt_mark_pt_region_begin(__itt_pt_region region); +/** + * @brief function contains a special code pattern identified on the post-processing stage and + * marks the end of a code region targeted for Intel PT analysis + * @param[in] region - region id, 0 <= region < 8 +*/ +void __itt_mark_pt_region_end(__itt_pt_region region); +/** @} Intel PT control group*/ + +/** + * @defgroup threads Threads + * @ingroup public + * Give names to threads + * @{ + */ +/** + * @brief Sets thread name of calling thread + * @param[in] name - name of thread + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_thread_set_nameA(const char *name); +void ITTAPI __itt_thread_set_nameW(const wchar_t *name); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_thread_set_name __itt_thread_set_nameW +# define __itt_thread_set_name_ptr __itt_thread_set_nameW_ptr +#else /* UNICODE */ +# define __itt_thread_set_name __itt_thread_set_nameA +# define __itt_thread_set_name_ptr __itt_thread_set_nameA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +void ITTAPI __itt_thread_set_name(const char *name); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, thread_set_nameA, (const char *name)) +ITT_STUBV(ITTAPI, void, thread_set_nameW, (const wchar_t *name)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUBV(ITTAPI, void, thread_set_name, (const char *name)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_thread_set_nameA ITTNOTIFY_VOID(thread_set_nameA) +#define __itt_thread_set_nameA_ptr ITTNOTIFY_NAME(thread_set_nameA) +#define __itt_thread_set_nameW ITTNOTIFY_VOID(thread_set_nameW) +#define __itt_thread_set_nameW_ptr ITTNOTIFY_NAME(thread_set_nameW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_thread_set_name ITTNOTIFY_VOID(thread_set_name) +#define __itt_thread_set_name_ptr ITTNOTIFY_NAME(thread_set_name) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_thread_set_nameA(name) +#define __itt_thread_set_nameA_ptr 0 +#define __itt_thread_set_nameW(name) +#define __itt_thread_set_nameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_thread_set_name(name) +#define __itt_thread_set_name_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_thread_set_nameA_ptr 0 +#define __itt_thread_set_nameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_thread_set_name_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @cond exclude_from_gpa_documentation */ + +/** + * @brief Mark current thread as ignored from this point on, for the duration of its existence. + */ +void ITTAPI __itt_thread_ignore(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, thread_ignore, (void)) +#define __itt_thread_ignore ITTNOTIFY_VOID(thread_ignore) +#define __itt_thread_ignore_ptr ITTNOTIFY_NAME(thread_ignore) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_thread_ignore() +#define __itt_thread_ignore_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_thread_ignore_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} threads group */ + +/** + * @defgroup suppress Error suppression + * @ingroup public + * General behavior: application continues to run, but errors are suppressed + * + * @{ + */ + +/*****************************************************************//** + * @name group of functions used for error suppression in correctness tools + *********************************************************************/ +/** @{ */ +/** + * @hideinitializer + * @brief possible value for suppression mask + */ +#define __itt_suppress_all_errors 0x7fffffff + +/** + * @hideinitializer + * @brief possible value for suppression mask (suppresses errors from threading analysis) + */ +#define __itt_suppress_threading_errors 0x000000ff + +/** + * @hideinitializer + * @brief possible value for suppression mask (suppresses errors from memory analysis) + */ +#define __itt_suppress_memory_errors 0x0000ff00 + +/** + * @brief Start suppressing errors identified in mask on this thread + */ +void ITTAPI __itt_suppress_push(unsigned int mask); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, suppress_push, (unsigned int mask)) +#define __itt_suppress_push ITTNOTIFY_VOID(suppress_push) +#define __itt_suppress_push_ptr ITTNOTIFY_NAME(suppress_push) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_suppress_push(mask) +#define __itt_suppress_push_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_suppress_push_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Undo the effects of the matching call to __itt_suppress_push + */ +void ITTAPI __itt_suppress_pop(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, suppress_pop, (void)) +#define __itt_suppress_pop ITTNOTIFY_VOID(suppress_pop) +#define __itt_suppress_pop_ptr ITTNOTIFY_NAME(suppress_pop) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_suppress_pop() +#define __itt_suppress_pop_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_suppress_pop_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @enum __itt_suppress_mode + * @brief Enumerator for the suppressing modes + */ +typedef enum __itt_suppress_mode { + __itt_unsuppress_range, + __itt_suppress_range +} __itt_suppress_mode_t; + +/** + * @enum __itt_collection_state + * @brief Enumerator for collection state. + */ +typedef enum { + __itt_collection_uninitialized = 0, /* uninitialized */ + __itt_collection_init_fail = 1, /* failed to init */ + __itt_collection_collector_absent = 2, /* non work state collector is absent */ + __itt_collection_collector_exists = 3, /* work state collector exists */ + __itt_collection_init_successful = 4 /* success to init */ +} __itt_collection_state; + +/** + * @brief Mark a range of memory for error suppression or unsuppression for error types included in mask + */ +void ITTAPI __itt_suppress_mark_range(__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, suppress_mark_range, (__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size)) +#define __itt_suppress_mark_range ITTNOTIFY_VOID(suppress_mark_range) +#define __itt_suppress_mark_range_ptr ITTNOTIFY_NAME(suppress_mark_range) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_suppress_mark_range(mask) +#define __itt_suppress_mark_range_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_suppress_mark_range_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Undo the effect of a matching call to __itt_suppress_mark_range. If not matching + * call is found, nothing is changed. + */ +void ITTAPI __itt_suppress_clear_range(__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, suppress_clear_range, (__itt_suppress_mode_t mode, unsigned int mask, void * address, size_t size)) +#define __itt_suppress_clear_range ITTNOTIFY_VOID(suppress_clear_range) +#define __itt_suppress_clear_range_ptr ITTNOTIFY_NAME(suppress_clear_range) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_suppress_clear_range(mask) +#define __itt_suppress_clear_range_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_suppress_clear_range_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} */ +/** @} suppress group */ + +/** + * @defgroup sync Synchronization + * @ingroup public + * Indicate user-written synchronization code + * @{ + */ +/** + * @hideinitializer + * @brief possible value of attribute argument for sync object type + */ +#define __itt_attr_barrier 1 + +/** + * @hideinitializer + * @brief possible value of attribute argument for sync object type + */ +#define __itt_attr_mutex 2 + +/** +@brief Name a synchronization object +@param[in] addr Handle for the synchronization object. You should +use a real address to uniquely identify the synchronization object. +@param[in] objtype null-terminated object type string. If NULL is +passed, the name will be "User Synchronization". +@param[in] objname null-terminated object name string. If NULL, +no name will be assigned to the object. +@param[in] attribute one of [#__itt_attr_barrier, #__itt_attr_mutex] + */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_sync_createA(void *addr, const char *objtype, const char *objname, int attribute); +void ITTAPI __itt_sync_createW(void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_sync_create __itt_sync_createW +# define __itt_sync_create_ptr __itt_sync_createW_ptr +#else /* UNICODE */ +# define __itt_sync_create __itt_sync_createA +# define __itt_sync_create_ptr __itt_sync_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +void ITTAPI __itt_sync_create (void *addr, const char *objtype, const char *objname, int attribute); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, sync_createA, (void *addr, const char *objtype, const char *objname, int attribute)) +ITT_STUBV(ITTAPI, void, sync_createW, (void *addr, const wchar_t *objtype, const wchar_t *objname, int attribute)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUBV(ITTAPI, void, sync_create, (void *addr, const char* objtype, const char* objname, int attribute)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_createA ITTNOTIFY_VOID(sync_createA) +#define __itt_sync_createA_ptr ITTNOTIFY_NAME(sync_createA) +#define __itt_sync_createW ITTNOTIFY_VOID(sync_createW) +#define __itt_sync_createW_ptr ITTNOTIFY_NAME(sync_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_create ITTNOTIFY_VOID(sync_create) +#define __itt_sync_create_ptr ITTNOTIFY_NAME(sync_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_createA(addr, objtype, objname, attribute) +#define __itt_sync_createA_ptr 0 +#define __itt_sync_createW(addr, objtype, objname, attribute) +#define __itt_sync_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_create(addr, objtype, objname, attribute) +#define __itt_sync_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_createA_ptr 0 +#define __itt_sync_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** +@brief Rename a synchronization object + +You can use the rename call to assign or reassign a name to a given +synchronization object. +@param[in] addr handle for the synchronization object. +@param[in] name null-terminated object name string. +*/ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_sync_renameA(void *addr, const char *name); +void ITTAPI __itt_sync_renameW(void *addr, const wchar_t *name); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_sync_rename __itt_sync_renameW +# define __itt_sync_rename_ptr __itt_sync_renameW_ptr +#else /* UNICODE */ +# define __itt_sync_rename __itt_sync_renameA +# define __itt_sync_rename_ptr __itt_sync_renameA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +void ITTAPI __itt_sync_rename(void *addr, const char *name); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, sync_renameA, (void *addr, const char *name)) +ITT_STUBV(ITTAPI, void, sync_renameW, (void *addr, const wchar_t *name)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUBV(ITTAPI, void, sync_rename, (void *addr, const char *name)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_renameA ITTNOTIFY_VOID(sync_renameA) +#define __itt_sync_renameA_ptr ITTNOTIFY_NAME(sync_renameA) +#define __itt_sync_renameW ITTNOTIFY_VOID(sync_renameW) +#define __itt_sync_renameW_ptr ITTNOTIFY_NAME(sync_renameW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_rename ITTNOTIFY_VOID(sync_rename) +#define __itt_sync_rename_ptr ITTNOTIFY_NAME(sync_rename) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_renameA(addr, name) +#define __itt_sync_renameA_ptr 0 +#define __itt_sync_renameW(addr, name) +#define __itt_sync_renameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_rename(addr, name) +#define __itt_sync_rename_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_sync_renameA_ptr 0 +#define __itt_sync_renameW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_sync_rename_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + @brief Destroy a synchronization object. + @param addr Handle for the synchronization object. + */ +void ITTAPI __itt_sync_destroy(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, sync_destroy, (void *addr)) +#define __itt_sync_destroy ITTNOTIFY_VOID(sync_destroy) +#define __itt_sync_destroy_ptr ITTNOTIFY_NAME(sync_destroy) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_sync_destroy(addr) +#define __itt_sync_destroy_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_sync_destroy_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/*****************************************************************//** + * @name group of functions is used for performance measurement tools + *********************************************************************/ +/** @{ */ +/** + * @brief Enter spin loop on user-defined sync object + */ +void ITTAPI __itt_sync_prepare(void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, sync_prepare, (void *addr)) +#define __itt_sync_prepare ITTNOTIFY_VOID(sync_prepare) +#define __itt_sync_prepare_ptr ITTNOTIFY_NAME(sync_prepare) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_sync_prepare(addr) +#define __itt_sync_prepare_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_sync_prepare_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Quit spin loop without acquiring spin object + */ +void ITTAPI __itt_sync_cancel(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, sync_cancel, (void *addr)) +#define __itt_sync_cancel ITTNOTIFY_VOID(sync_cancel) +#define __itt_sync_cancel_ptr ITTNOTIFY_NAME(sync_cancel) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_sync_cancel(addr) +#define __itt_sync_cancel_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_sync_cancel_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Successful spin loop completion (sync object acquired) + */ +void ITTAPI __itt_sync_acquired(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, sync_acquired, (void *addr)) +#define __itt_sync_acquired ITTNOTIFY_VOID(sync_acquired) +#define __itt_sync_acquired_ptr ITTNOTIFY_NAME(sync_acquired) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_sync_acquired(addr) +#define __itt_sync_acquired_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_sync_acquired_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Start sync object releasing code. Is called before the lock release call. + */ +void ITTAPI __itt_sync_releasing(void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, sync_releasing, (void *addr)) +#define __itt_sync_releasing ITTNOTIFY_VOID(sync_releasing) +#define __itt_sync_releasing_ptr ITTNOTIFY_NAME(sync_releasing) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_sync_releasing(addr) +#define __itt_sync_releasing_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_sync_releasing_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} */ + +/** @} sync group */ + +/**************************************************************//** + * @name group of functions is used for correctness checking tools + ******************************************************************/ +/** @{ */ +/** + * @ingroup legacy + * @deprecated Legacy API + * @brief Fast synchronization which does no require spinning. + * - This special function is to be used by TBB and OpenMP libraries only when they know + * there is no spin but they need to suppress TC warnings about shared variable modifications. + * - It only has corresponding pointers in static library and does not have corresponding function + * in dynamic library. + * @see void __itt_sync_prepare(void* addr); + */ +void ITTAPI __itt_fsync_prepare(void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, fsync_prepare, (void *addr)) +#define __itt_fsync_prepare ITTNOTIFY_VOID(fsync_prepare) +#define __itt_fsync_prepare_ptr ITTNOTIFY_NAME(fsync_prepare) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_fsync_prepare(addr) +#define __itt_fsync_prepare_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_fsync_prepare_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup legacy + * @deprecated Legacy API + * @brief Fast synchronization which does no require spinning. + * - This special function is to be used by TBB and OpenMP libraries only when they know + * there is no spin but they need to suppress TC warnings about shared variable modifications. + * - It only has corresponding pointers in static library and does not have corresponding function + * in dynamic library. + * @see void __itt_sync_cancel(void *addr); + */ +void ITTAPI __itt_fsync_cancel(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, fsync_cancel, (void *addr)) +#define __itt_fsync_cancel ITTNOTIFY_VOID(fsync_cancel) +#define __itt_fsync_cancel_ptr ITTNOTIFY_NAME(fsync_cancel) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_fsync_cancel(addr) +#define __itt_fsync_cancel_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_fsync_cancel_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup legacy + * @deprecated Legacy API + * @brief Fast synchronization which does no require spinning. + * - This special function is to be used by TBB and OpenMP libraries only when they know + * there is no spin but they need to suppress TC warnings about shared variable modifications. + * - It only has corresponding pointers in static library and does not have corresponding function + * in dynamic library. + * @see void __itt_sync_acquired(void *addr); + */ +void ITTAPI __itt_fsync_acquired(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, fsync_acquired, (void *addr)) +#define __itt_fsync_acquired ITTNOTIFY_VOID(fsync_acquired) +#define __itt_fsync_acquired_ptr ITTNOTIFY_NAME(fsync_acquired) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_fsync_acquired(addr) +#define __itt_fsync_acquired_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_fsync_acquired_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup legacy + * @deprecated Legacy API + * @brief Fast synchronization which does no require spinning. + * - This special function is to be used by TBB and OpenMP libraries only when they know + * there is no spin but they need to suppress TC warnings about shared variable modifications. + * - It only has corresponding pointers in static library and does not have corresponding function + * in dynamic library. + * @see void __itt_sync_releasing(void* addr); + */ +void ITTAPI __itt_fsync_releasing(void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, fsync_releasing, (void *addr)) +#define __itt_fsync_releasing ITTNOTIFY_VOID(fsync_releasing) +#define __itt_fsync_releasing_ptr ITTNOTIFY_NAME(fsync_releasing) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_fsync_releasing(addr) +#define __itt_fsync_releasing_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_fsync_releasing_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} */ + +/** + * @defgroup model Modeling by Intel(R) Parallel Advisor + * @ingroup public + * This is the subset of itt used for modeling by Intel(R) Parallel Advisor. + * This API is called ONLY using annotate.h, by "Annotation" macros + * the user places in their sources during the parallelism modeling steps. + * + * site_begin/end and task_begin/end take the address of handle variables, + * which are writeable by the API. Handles must be 0 initialized prior + * to the first call to begin, or may cause a run-time failure. + * The handles are initialized in a multi-thread safe way by the API if + * the handle is 0. The commonly expected idiom is one static handle to + * identify a site or task. If a site or task of the same name has already + * been started during this collection, the same handle MAY be returned, + * but is not required to be - it is unspecified if data merging is done + * based on name. These routines also take an instance variable. Like + * the lexical instance, these must be 0 initialized. Unlike the lexical + * instance, this is used to track a single dynamic instance. + * + * API used by the Intel(R) Parallel Advisor to describe potential concurrency + * and related activities. User-added source annotations expand to calls + * to these procedures to enable modeling of a hypothetical concurrent + * execution serially. + * @{ + */ +#if !defined(_ADVISOR_ANNOTATE_H_) || defined(ANNOTATE_EXPAND_NULL) + +typedef void* __itt_model_site; /*!< @brief handle for lexical site */ +typedef void* __itt_model_site_instance; /*!< @brief handle for dynamic instance */ +typedef void* __itt_model_task; /*!< @brief handle for lexical site */ +typedef void* __itt_model_task_instance; /*!< @brief handle for dynamic instance */ + +/** + * @enum __itt_model_disable + * @brief Enumerator for the disable methods + */ +typedef enum { + __itt_model_disable_observation, + __itt_model_disable_collection +} __itt_model_disable; + +#endif /* !_ADVISOR_ANNOTATE_H_ || ANNOTATE_EXPAND_NULL */ + +/** + * @brief ANNOTATE_SITE_BEGIN/ANNOTATE_SITE_END support. + * + * site_begin/end model a potential concurrency site. + * site instances may be recursively nested with themselves. + * site_end exits the most recently started but unended site for the current + * thread. The handle passed to end may be used to validate structure. + * Instances of a site encountered on different threads concurrently + * are considered completely distinct. If the site name for two different + * lexical sites match, it is unspecified whether they are treated as the + * same or different for data presentation. + */ +void ITTAPI __itt_model_site_begin(__itt_model_site *site, __itt_model_site_instance *instance, const char *name); +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_model_site_beginW(const wchar_t *name); +#endif +void ITTAPI __itt_model_site_beginA(const char *name); +void ITTAPI __itt_model_site_beginAL(const char *name, size_t siteNameLen); +void ITTAPI __itt_model_site_end (__itt_model_site *site, __itt_model_site_instance *instance); +void ITTAPI __itt_model_site_end_2(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_site_begin, (__itt_model_site *site, __itt_model_site_instance *instance, const char *name)) +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, model_site_beginW, (const wchar_t *name)) +#endif +ITT_STUBV(ITTAPI, void, model_site_beginA, (const char *name)) +ITT_STUBV(ITTAPI, void, model_site_beginAL, (const char *name, size_t siteNameLen)) +ITT_STUBV(ITTAPI, void, model_site_end, (__itt_model_site *site, __itt_model_site_instance *instance)) +ITT_STUBV(ITTAPI, void, model_site_end_2, (void)) +#define __itt_model_site_begin ITTNOTIFY_VOID(model_site_begin) +#define __itt_model_site_begin_ptr ITTNOTIFY_NAME(model_site_begin) +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_model_site_beginW ITTNOTIFY_VOID(model_site_beginW) +#define __itt_model_site_beginW_ptr ITTNOTIFY_NAME(model_site_beginW) +#endif +#define __itt_model_site_beginA ITTNOTIFY_VOID(model_site_beginA) +#define __itt_model_site_beginA_ptr ITTNOTIFY_NAME(model_site_beginA) +#define __itt_model_site_beginAL ITTNOTIFY_VOID(model_site_beginAL) +#define __itt_model_site_beginAL_ptr ITTNOTIFY_NAME(model_site_beginAL) +#define __itt_model_site_end ITTNOTIFY_VOID(model_site_end) +#define __itt_model_site_end_ptr ITTNOTIFY_NAME(model_site_end) +#define __itt_model_site_end_2 ITTNOTIFY_VOID(model_site_end_2) +#define __itt_model_site_end_2_ptr ITTNOTIFY_NAME(model_site_end_2) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_site_begin(site, instance, name) +#define __itt_model_site_begin_ptr 0 +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_model_site_beginW(name) +#define __itt_model_site_beginW_ptr 0 +#endif +#define __itt_model_site_beginA(name) +#define __itt_model_site_beginA_ptr 0 +#define __itt_model_site_beginAL(name, siteNameLen) +#define __itt_model_site_beginAL_ptr 0 +#define __itt_model_site_end(site, instance) +#define __itt_model_site_end_ptr 0 +#define __itt_model_site_end_2() +#define __itt_model_site_end_2_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_site_begin_ptr 0 +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_model_site_beginW_ptr 0 +#endif +#define __itt_model_site_beginA_ptr 0 +#define __itt_model_site_beginAL_ptr 0 +#define __itt_model_site_end_ptr 0 +#define __itt_model_site_end_2_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_TASK_BEGIN/ANNOTATE_TASK_END support + * + * task_begin/end model a potential task, which is contained within the most + * closely enclosing dynamic site. task_end exits the most recently started + * but unended task. The handle passed to end may be used to validate + * structure. It is unspecified if bad dynamic nesting is detected. If it + * is, it should be encoded in the resulting data collection. The collector + * should not fail due to construct nesting issues, nor attempt to directly + * indicate the problem. + */ +void ITTAPI __itt_model_task_begin(__itt_model_task *task, __itt_model_task_instance *instance, const char *name); +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_model_task_beginW(const wchar_t *name); +void ITTAPI __itt_model_iteration_taskW(const wchar_t *name); +#endif +void ITTAPI __itt_model_task_beginA(const char *name); +void ITTAPI __itt_model_task_beginAL(const char *name, size_t taskNameLen); +void ITTAPI __itt_model_iteration_taskA(const char *name); +void ITTAPI __itt_model_iteration_taskAL(const char *name, size_t taskNameLen); +void ITTAPI __itt_model_task_end (__itt_model_task *task, __itt_model_task_instance *instance); +void ITTAPI __itt_model_task_end_2(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_task_begin, (__itt_model_task *task, __itt_model_task_instance *instance, const char *name)) +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, model_task_beginW, (const wchar_t *name)) +ITT_STUBV(ITTAPI, void, model_iteration_taskW, (const wchar_t *name)) +#endif +ITT_STUBV(ITTAPI, void, model_task_beginA, (const char *name)) +ITT_STUBV(ITTAPI, void, model_task_beginAL, (const char *name, size_t taskNameLen)) +ITT_STUBV(ITTAPI, void, model_iteration_taskA, (const char *name)) +ITT_STUBV(ITTAPI, void, model_iteration_taskAL, (const char *name, size_t taskNameLen)) +ITT_STUBV(ITTAPI, void, model_task_end, (__itt_model_task *task, __itt_model_task_instance *instance)) +ITT_STUBV(ITTAPI, void, model_task_end_2, (void)) +#define __itt_model_task_begin ITTNOTIFY_VOID(model_task_begin) +#define __itt_model_task_begin_ptr ITTNOTIFY_NAME(model_task_begin) +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_model_task_beginW ITTNOTIFY_VOID(model_task_beginW) +#define __itt_model_task_beginW_ptr ITTNOTIFY_NAME(model_task_beginW) +#define __itt_model_iteration_taskW ITTNOTIFY_VOID(model_iteration_taskW) +#define __itt_model_iteration_taskW_ptr ITTNOTIFY_NAME(model_iteration_taskW) +#endif +#define __itt_model_task_beginA ITTNOTIFY_VOID(model_task_beginA) +#define __itt_model_task_beginA_ptr ITTNOTIFY_NAME(model_task_beginA) +#define __itt_model_task_beginAL ITTNOTIFY_VOID(model_task_beginAL) +#define __itt_model_task_beginAL_ptr ITTNOTIFY_NAME(model_task_beginAL) +#define __itt_model_iteration_taskA ITTNOTIFY_VOID(model_iteration_taskA) +#define __itt_model_iteration_taskA_ptr ITTNOTIFY_NAME(model_iteration_taskA) +#define __itt_model_iteration_taskAL ITTNOTIFY_VOID(model_iteration_taskAL) +#define __itt_model_iteration_taskAL_ptr ITTNOTIFY_NAME(model_iteration_taskAL) +#define __itt_model_task_end ITTNOTIFY_VOID(model_task_end) +#define __itt_model_task_end_ptr ITTNOTIFY_NAME(model_task_end) +#define __itt_model_task_end_2 ITTNOTIFY_VOID(model_task_end_2) +#define __itt_model_task_end_2_ptr ITTNOTIFY_NAME(model_task_end_2) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_task_begin(task, instance, name) +#define __itt_model_task_begin_ptr 0 +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_model_task_beginW(name) +#define __itt_model_task_beginW_ptr 0 +#endif +#define __itt_model_task_beginA(name) +#define __itt_model_task_beginA_ptr 0 +#define __itt_model_task_beginAL(name, siteNameLen) +#define __itt_model_task_beginAL_ptr 0 +#define __itt_model_iteration_taskA(name) +#define __itt_model_iteration_taskA_ptr 0 +#define __itt_model_iteration_taskAL(name, siteNameLen) +#define __itt_model_iteration_taskAL_ptr 0 +#define __itt_model_task_end(task, instance) +#define __itt_model_task_end_ptr 0 +#define __itt_model_task_end_2() +#define __itt_model_task_end_2_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_task_begin_ptr 0 +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_model_task_beginW_ptr 0 +#endif +#define __itt_model_task_beginA_ptr 0 +#define __itt_model_task_beginAL_ptr 0 +#define __itt_model_iteration_taskA_ptr 0 +#define __itt_model_iteration_taskAL_ptr 0 +#define __itt_model_task_end_ptr 0 +#define __itt_model_task_end_2_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_LOCK_ACQUIRE/ANNOTATE_LOCK_RELEASE support + * + * lock_acquire/release model a potential lock for both lockset and + * performance modeling. Each unique address is modeled as a separate + * lock, with invalid addresses being valid lock IDs. Specifically: + * no storage is accessed by the API at the specified address - it is only + * used for lock identification. Lock acquires may be self-nested and are + * unlocked by a corresponding number of releases. + * (These closely correspond to __itt_sync_acquired/__itt_sync_releasing, + * but may not have identical semantics.) + */ +void ITTAPI __itt_model_lock_acquire(void *lock); +void ITTAPI __itt_model_lock_acquire_2(void *lock); +void ITTAPI __itt_model_lock_release(void *lock); +void ITTAPI __itt_model_lock_release_2(void *lock); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_lock_acquire, (void *lock)) +ITT_STUBV(ITTAPI, void, model_lock_acquire_2, (void *lock)) +ITT_STUBV(ITTAPI, void, model_lock_release, (void *lock)) +ITT_STUBV(ITTAPI, void, model_lock_release_2, (void *lock)) +#define __itt_model_lock_acquire ITTNOTIFY_VOID(model_lock_acquire) +#define __itt_model_lock_acquire_ptr ITTNOTIFY_NAME(model_lock_acquire) +#define __itt_model_lock_acquire_2 ITTNOTIFY_VOID(model_lock_acquire_2) +#define __itt_model_lock_acquire_2_ptr ITTNOTIFY_NAME(model_lock_acquire_2) +#define __itt_model_lock_release ITTNOTIFY_VOID(model_lock_release) +#define __itt_model_lock_release_ptr ITTNOTIFY_NAME(model_lock_release) +#define __itt_model_lock_release_2 ITTNOTIFY_VOID(model_lock_release_2) +#define __itt_model_lock_release_2_ptr ITTNOTIFY_NAME(model_lock_release_2) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_lock_acquire(lock) +#define __itt_model_lock_acquire_ptr 0 +#define __itt_model_lock_acquire_2(lock) +#define __itt_model_lock_acquire_2_ptr 0 +#define __itt_model_lock_release(lock) +#define __itt_model_lock_release_ptr 0 +#define __itt_model_lock_release_2(lock) +#define __itt_model_lock_release_2_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_lock_acquire_ptr 0 +#define __itt_model_lock_acquire_2_ptr 0 +#define __itt_model_lock_release_ptr 0 +#define __itt_model_lock_release_2_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_RECORD_ALLOCATION/ANNOTATE_RECORD_DEALLOCATION support + * + * record_allocation/deallocation describe user-defined memory allocator + * behavior, which may be required for correctness modeling to understand + * when storage is not expected to be actually reused across threads. + */ +void ITTAPI __itt_model_record_allocation (void *addr, size_t size); +void ITTAPI __itt_model_record_deallocation(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_record_allocation, (void *addr, size_t size)) +ITT_STUBV(ITTAPI, void, model_record_deallocation, (void *addr)) +#define __itt_model_record_allocation ITTNOTIFY_VOID(model_record_allocation) +#define __itt_model_record_allocation_ptr ITTNOTIFY_NAME(model_record_allocation) +#define __itt_model_record_deallocation ITTNOTIFY_VOID(model_record_deallocation) +#define __itt_model_record_deallocation_ptr ITTNOTIFY_NAME(model_record_deallocation) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_record_allocation(addr, size) +#define __itt_model_record_allocation_ptr 0 +#define __itt_model_record_deallocation(addr) +#define __itt_model_record_deallocation_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_record_allocation_ptr 0 +#define __itt_model_record_deallocation_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_INDUCTION_USES support + * + * Note particular storage is inductive through the end of the current site + */ +void ITTAPI __itt_model_induction_uses(void* addr, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_induction_uses, (void *addr, size_t size)) +#define __itt_model_induction_uses ITTNOTIFY_VOID(model_induction_uses) +#define __itt_model_induction_uses_ptr ITTNOTIFY_NAME(model_induction_uses) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_induction_uses(addr, size) +#define __itt_model_induction_uses_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_induction_uses_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_REDUCTION_USES support + * + * Note particular storage is used for reduction through the end + * of the current site + */ +void ITTAPI __itt_model_reduction_uses(void* addr, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_reduction_uses, (void *addr, size_t size)) +#define __itt_model_reduction_uses ITTNOTIFY_VOID(model_reduction_uses) +#define __itt_model_reduction_uses_ptr ITTNOTIFY_NAME(model_reduction_uses) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_reduction_uses(addr, size) +#define __itt_model_reduction_uses_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_reduction_uses_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_OBSERVE_USES support + * + * Have correctness modeling record observations about uses of storage + * through the end of the current site + */ +void ITTAPI __itt_model_observe_uses(void* addr, size_t size); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_observe_uses, (void *addr, size_t size)) +#define __itt_model_observe_uses ITTNOTIFY_VOID(model_observe_uses) +#define __itt_model_observe_uses_ptr ITTNOTIFY_NAME(model_observe_uses) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_observe_uses(addr, size) +#define __itt_model_observe_uses_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_observe_uses_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_CLEAR_USES support + * + * Clear the special handling of a piece of storage related to induction, + * reduction or observe_uses + */ +void ITTAPI __itt_model_clear_uses(void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_clear_uses, (void *addr)) +#define __itt_model_clear_uses ITTNOTIFY_VOID(model_clear_uses) +#define __itt_model_clear_uses_ptr ITTNOTIFY_NAME(model_clear_uses) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_clear_uses(addr) +#define __itt_model_clear_uses_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_clear_uses_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief ANNOTATE_DISABLE_*_PUSH/ANNOTATE_DISABLE_*_POP support + * + * disable_push/disable_pop push and pop disabling based on a parameter. + * Disabling observations stops processing of memory references during + * correctness modeling, and all annotations that occur in the disabled + * region. This allows description of code that is expected to be handled + * specially during conversion to parallelism or that is not recognized + * by tools (e.g. some kinds of synchronization operations.) + * This mechanism causes all annotations in the disabled region, other + * than disable_push and disable_pop, to be ignored. (For example, this + * might validly be used to disable an entire parallel site and the contained + * tasks and locking in it for data collection purposes.) + * The disable for collection is a more expensive operation, but reduces + * collector overhead significantly. This applies to BOTH correctness data + * collection and performance data collection. For example, a site + * containing a task might only enable data collection for the first 10 + * iterations. Both performance and correctness data should reflect this, + * and the program should run as close to full speed as possible when + * collection is disabled. + */ +void ITTAPI __itt_model_disable_push(__itt_model_disable x); +void ITTAPI __itt_model_disable_pop(void); +void ITTAPI __itt_model_aggregate_task(size_t x); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, model_disable_push, (__itt_model_disable x)) +ITT_STUBV(ITTAPI, void, model_disable_pop, (void)) +ITT_STUBV(ITTAPI, void, model_aggregate_task, (size_t x)) +#define __itt_model_disable_push ITTNOTIFY_VOID(model_disable_push) +#define __itt_model_disable_push_ptr ITTNOTIFY_NAME(model_disable_push) +#define __itt_model_disable_pop ITTNOTIFY_VOID(model_disable_pop) +#define __itt_model_disable_pop_ptr ITTNOTIFY_NAME(model_disable_pop) +#define __itt_model_aggregate_task ITTNOTIFY_VOID(model_aggregate_task) +#define __itt_model_aggregate_task_ptr ITTNOTIFY_NAME(model_aggregate_task) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_model_disable_push(x) +#define __itt_model_disable_push_ptr 0 +#define __itt_model_disable_pop() +#define __itt_model_disable_pop_ptr 0 +#define __itt_model_aggregate_task(x) +#define __itt_model_aggregate_task_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_model_disable_push_ptr 0 +#define __itt_model_disable_pop_ptr 0 +#define __itt_model_aggregate_task_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} model group */ + +/** + * @defgroup heap Heap + * @ingroup public + * Heap group + * @{ + */ + +typedef void* __itt_heap_function; + +/** + * @brief Create an identification for heap function + * @return non-zero identifier or NULL + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_heap_function ITTAPI __itt_heap_function_createA(const char* name, const char* domain); +__itt_heap_function ITTAPI __itt_heap_function_createW(const wchar_t* name, const wchar_t* domain); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_heap_function_create __itt_heap_function_createW +# define __itt_heap_function_create_ptr __itt_heap_function_createW_ptr +#else +# define __itt_heap_function_create __itt_heap_function_createA +# define __itt_heap_function_create_ptr __itt_heap_function_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_heap_function ITTAPI __itt_heap_function_create(const char* name, const char* domain); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_heap_function, heap_function_createA, (const char* name, const char* domain)) +ITT_STUB(ITTAPI, __itt_heap_function, heap_function_createW, (const wchar_t* name, const wchar_t* domain)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_heap_function, heap_function_create, (const char* name, const char* domain)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_heap_function_createA ITTNOTIFY_DATA(heap_function_createA) +#define __itt_heap_function_createA_ptr ITTNOTIFY_NAME(heap_function_createA) +#define __itt_heap_function_createW ITTNOTIFY_DATA(heap_function_createW) +#define __itt_heap_function_createW_ptr ITTNOTIFY_NAME(heap_function_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_heap_function_create ITTNOTIFY_DATA(heap_function_create) +#define __itt_heap_function_create_ptr ITTNOTIFY_NAME(heap_function_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_heap_function_createA(name, domain) (__itt_heap_function)0 +#define __itt_heap_function_createA_ptr 0 +#define __itt_heap_function_createW(name, domain) (__itt_heap_function)0 +#define __itt_heap_function_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_heap_function_create(name, domain) (__itt_heap_function)0 +#define __itt_heap_function_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_heap_function_createA_ptr 0 +#define __itt_heap_function_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_heap_function_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record an allocation begin occurrence. + */ +void ITTAPI __itt_heap_allocate_begin(__itt_heap_function h, size_t size, int initialized); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_allocate_begin, (__itt_heap_function h, size_t size, int initialized)) +#define __itt_heap_allocate_begin ITTNOTIFY_VOID(heap_allocate_begin) +#define __itt_heap_allocate_begin_ptr ITTNOTIFY_NAME(heap_allocate_begin) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_allocate_begin(h, size, initialized) +#define __itt_heap_allocate_begin_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_allocate_begin_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record an allocation end occurrence. + */ +void ITTAPI __itt_heap_allocate_end(__itt_heap_function h, void** addr, size_t size, int initialized); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_allocate_end, (__itt_heap_function h, void** addr, size_t size, int initialized)) +#define __itt_heap_allocate_end ITTNOTIFY_VOID(heap_allocate_end) +#define __itt_heap_allocate_end_ptr ITTNOTIFY_NAME(heap_allocate_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_allocate_end(h, addr, size, initialized) +#define __itt_heap_allocate_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_allocate_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record a free begin occurrence. + */ +void ITTAPI __itt_heap_free_begin(__itt_heap_function h, void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_free_begin, (__itt_heap_function h, void* addr)) +#define __itt_heap_free_begin ITTNOTIFY_VOID(heap_free_begin) +#define __itt_heap_free_begin_ptr ITTNOTIFY_NAME(heap_free_begin) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_free_begin(h, addr) +#define __itt_heap_free_begin_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_free_begin_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record a free end occurrence. + */ +void ITTAPI __itt_heap_free_end(__itt_heap_function h, void* addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_free_end, (__itt_heap_function h, void* addr)) +#define __itt_heap_free_end ITTNOTIFY_VOID(heap_free_end) +#define __itt_heap_free_end_ptr ITTNOTIFY_NAME(heap_free_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_free_end(h, addr) +#define __itt_heap_free_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_free_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record a reallocation begin occurrence. + */ +void ITTAPI __itt_heap_reallocate_begin(__itt_heap_function h, void* addr, size_t new_size, int initialized); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_reallocate_begin, (__itt_heap_function h, void* addr, size_t new_size, int initialized)) +#define __itt_heap_reallocate_begin ITTNOTIFY_VOID(heap_reallocate_begin) +#define __itt_heap_reallocate_begin_ptr ITTNOTIFY_NAME(heap_reallocate_begin) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_reallocate_begin(h, addr, new_size, initialized) +#define __itt_heap_reallocate_begin_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_reallocate_begin_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record a reallocation end occurrence. + */ +void ITTAPI __itt_heap_reallocate_end(__itt_heap_function h, void* addr, void** new_addr, size_t new_size, int initialized); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_reallocate_end, (__itt_heap_function h, void* addr, void** new_addr, size_t new_size, int initialized)) +#define __itt_heap_reallocate_end ITTNOTIFY_VOID(heap_reallocate_end) +#define __itt_heap_reallocate_end_ptr ITTNOTIFY_NAME(heap_reallocate_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_reallocate_end(h, addr, new_addr, new_size, initialized) +#define __itt_heap_reallocate_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_reallocate_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @brief internal access begin */ +void ITTAPI __itt_heap_internal_access_begin(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_internal_access_begin, (void)) +#define __itt_heap_internal_access_begin ITTNOTIFY_VOID(heap_internal_access_begin) +#define __itt_heap_internal_access_begin_ptr ITTNOTIFY_NAME(heap_internal_access_begin) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_internal_access_begin() +#define __itt_heap_internal_access_begin_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_internal_access_begin_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @brief internal access end */ +void ITTAPI __itt_heap_internal_access_end(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_internal_access_end, (void)) +#define __itt_heap_internal_access_end ITTNOTIFY_VOID(heap_internal_access_end) +#define __itt_heap_internal_access_end_ptr ITTNOTIFY_NAME(heap_internal_access_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_internal_access_end() +#define __itt_heap_internal_access_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_internal_access_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @brief record memory growth begin */ +void ITTAPI __itt_heap_record_memory_growth_begin(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_record_memory_growth_begin, (void)) +#define __itt_heap_record_memory_growth_begin ITTNOTIFY_VOID(heap_record_memory_growth_begin) +#define __itt_heap_record_memory_growth_begin_ptr ITTNOTIFY_NAME(heap_record_memory_growth_begin) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_record_memory_growth_begin() +#define __itt_heap_record_memory_growth_begin_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_record_memory_growth_begin_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @brief record memory growth end */ +void ITTAPI __itt_heap_record_memory_growth_end(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_record_memory_growth_end, (void)) +#define __itt_heap_record_memory_growth_end ITTNOTIFY_VOID(heap_record_memory_growth_end) +#define __itt_heap_record_memory_growth_end_ptr ITTNOTIFY_NAME(heap_record_memory_growth_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_record_memory_growth_end() +#define __itt_heap_record_memory_growth_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_record_memory_growth_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Specify the type of heap detection/reporting to modify. + */ +/** + * @hideinitializer + * @brief Report on memory leaks. + */ +#define __itt_heap_leaks 0x00000001 + +/** + * @hideinitializer + * @brief Report on memory growth. + */ +#define __itt_heap_growth 0x00000002 + + +/** @brief heap reset detection */ +void ITTAPI __itt_heap_reset_detection(unsigned int reset_mask); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_reset_detection, (unsigned int reset_mask)) +#define __itt_heap_reset_detection ITTNOTIFY_VOID(heap_reset_detection) +#define __itt_heap_reset_detection_ptr ITTNOTIFY_NAME(heap_reset_detection) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_reset_detection() +#define __itt_heap_reset_detection_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_reset_detection_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @brief report */ +void ITTAPI __itt_heap_record(unsigned int record_mask); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, heap_record, (unsigned int record_mask)) +#define __itt_heap_record ITTNOTIFY_VOID(heap_record) +#define __itt_heap_record_ptr ITTNOTIFY_NAME(heap_record) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_heap_record() +#define __itt_heap_record_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_heap_record_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @} heap group */ +/** @endcond */ +/* ========================================================================== */ + +/** + * @defgroup domains Domains + * @ingroup public + * Domains group + * @{ + */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_domain +{ + volatile int flags; /*!< Zero if disabled, non-zero if enabled. The meaning of different non-zero values is reserved to the runtime */ + const char* nameA; /*!< Copy of original name in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; /*!< Copy of original name in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct ___itt_domain* next; +} __itt_domain; + +#pragma pack(pop) +/** @endcond */ + +/** + * @ingroup domains + * @brief Create a domain. + * Create domain using some domain name: the URI naming style is recommended. + * Because the set of domains is expected to be static over the application's + * execution time, there is no mechanism to destroy a domain. + * Any domain can be accessed by any thread in the process, regardless of + * which thread created the domain. This call is thread-safe. + * @param[in] name name of domain + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_domain* ITTAPI __itt_domain_createA(const char *name); +__itt_domain* ITTAPI __itt_domain_createW(const wchar_t *name); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_domain_create __itt_domain_createW +# define __itt_domain_create_ptr __itt_domain_createW_ptr +#else /* UNICODE */ +# define __itt_domain_create __itt_domain_createA +# define __itt_domain_create_ptr __itt_domain_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_domain* ITTAPI __itt_domain_create(const char *name); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_domain*, domain_createA, (const char *name)) +ITT_STUB(ITTAPI, __itt_domain*, domain_createW, (const wchar_t *name)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_domain*, domain_create, (const char *name)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_domain_createA ITTNOTIFY_DATA(domain_createA) +#define __itt_domain_createA_ptr ITTNOTIFY_NAME(domain_createA) +#define __itt_domain_createW ITTNOTIFY_DATA(domain_createW) +#define __itt_domain_createW_ptr ITTNOTIFY_NAME(domain_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_domain_create ITTNOTIFY_DATA(domain_create) +#define __itt_domain_create_ptr ITTNOTIFY_NAME(domain_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_domain_createA(name) (__itt_domain*)0 +#define __itt_domain_createA_ptr 0 +#define __itt_domain_createW(name) (__itt_domain*)0 +#define __itt_domain_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_domain_create(name) (__itt_domain*)0 +#define __itt_domain_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_domain_createA_ptr 0 +#define __itt_domain_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_domain_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} domains group */ + +/** + * @defgroup ids IDs + * @ingroup public + * IDs group + * @{ + */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_id +{ + unsigned long long d1, d2, d3; +} __itt_id; + +#pragma pack(pop) +/** @endcond */ + +static const __itt_id __itt_null = { 0, 0, 0 }; + +/** + * @ingroup ids + * @brief A convenience function is provided to create an ID without domain control. + * @brief This is a convenience function to initialize an __itt_id structure. This function + * does not affect the collector runtime in any way. After you make the ID with this + * function, you still must create it with the __itt_id_create function before using the ID + * to identify a named entity. + * @param[in] addr The address of object; high QWORD of the ID value. + * @param[in] extra The extra data to unique identify object; low QWORD of the ID value. + */ + +ITT_INLINE __itt_id ITTAPI __itt_id_make(void* addr, unsigned long long extra) ITT_INLINE_ATTRIBUTE; +ITT_INLINE __itt_id ITTAPI __itt_id_make(void* addr, unsigned long long extra) +{ + __itt_id id = __itt_null; + id.d1 = (unsigned long long)((uintptr_t)addr); + id.d2 = (unsigned long long)extra; + id.d3 = (unsigned long long)0; /* Reserved. Must be zero */ + return id; +} + +/** + * @ingroup ids + * @brief Create an instance of identifier. + * This establishes the beginning of the lifetime of an instance of + * the given ID in the trace. Once this lifetime starts, the ID + * can be used to tag named entity instances in calls such as + * __itt_task_begin, and to specify relationships among + * identified named entity instances, using the \ref relations APIs. + * Instance IDs are not domain specific! + * @param[in] domain The domain controlling the execution of this call. + * @param[in] id The ID to create. + */ +void ITTAPI __itt_id_create(const __itt_domain *domain, __itt_id id); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, id_create, (const __itt_domain *domain, __itt_id id)) +#define __itt_id_create(d,x) ITTNOTIFY_VOID_D1(id_create,d,x) +#define __itt_id_create_ptr ITTNOTIFY_NAME(id_create) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_id_create(domain,id) +#define __itt_id_create_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_id_create_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup ids + * @brief Destroy an instance of identifier. + * This ends the lifetime of the current instance of the given ID value in the trace. + * Any relationships that are established after this lifetime ends are invalid. + * This call must be performed before the given ID value can be reused for a different + * named entity instance. + * @param[in] domain The domain controlling the execution of this call. + * @param[in] id The ID to destroy. + */ +void ITTAPI __itt_id_destroy(const __itt_domain *domain, __itt_id id); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, id_destroy, (const __itt_domain *domain, __itt_id id)) +#define __itt_id_destroy(d,x) ITTNOTIFY_VOID_D1(id_destroy,d,x) +#define __itt_id_destroy_ptr ITTNOTIFY_NAME(id_destroy) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_id_destroy(domain,id) +#define __itt_id_destroy_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_id_destroy_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} ids group */ + +/** + * @defgroup handless String Handles + * @ingroup public + * String Handles group + * @{ + */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_string_handle +{ + const char* strA; /*!< Copy of original string in ASCII. */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* strW; /*!< Copy of original string in UNICODE. */ +#else /* UNICODE || _UNICODE */ + void* strW; +#endif /* UNICODE || _UNICODE */ + int extra1; /*!< Reserved. Must be zero */ + void* extra2; /*!< Reserved. Must be zero */ + struct ___itt_string_handle* next; +} __itt_string_handle; + +#pragma pack(pop) +/** @endcond */ + +/** + * @ingroup handles + * @brief Create a string handle. + * Create and return handle value that can be associated with a string. + * Consecutive calls to __itt_string_handle_create with the same name + * return the same value. Because the set of string handles is expected to remain + * static during the application's execution time, there is no mechanism to destroy a string handle. + * Any string handle can be accessed by any thread in the process, regardless of which thread created + * the string handle. This call is thread-safe. + * @param[in] name The input string + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_string_handle* ITTAPI __itt_string_handle_createA(const char *name); +__itt_string_handle* ITTAPI __itt_string_handle_createW(const wchar_t *name); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_string_handle_create __itt_string_handle_createW +# define __itt_string_handle_create_ptr __itt_string_handle_createW_ptr +#else /* UNICODE */ +# define __itt_string_handle_create __itt_string_handle_createA +# define __itt_string_handle_create_ptr __itt_string_handle_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_string_handle* ITTAPI __itt_string_handle_create(const char *name); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_string_handle*, string_handle_createA, (const char *name)) +ITT_STUB(ITTAPI, __itt_string_handle*, string_handle_createW, (const wchar_t *name)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_string_handle*, string_handle_create, (const char *name)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_string_handle_createA ITTNOTIFY_DATA(string_handle_createA) +#define __itt_string_handle_createA_ptr ITTNOTIFY_NAME(string_handle_createA) +#define __itt_string_handle_createW ITTNOTIFY_DATA(string_handle_createW) +#define __itt_string_handle_createW_ptr ITTNOTIFY_NAME(string_handle_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_string_handle_create ITTNOTIFY_DATA(string_handle_create) +#define __itt_string_handle_create_ptr ITTNOTIFY_NAME(string_handle_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_string_handle_createA(name) (__itt_string_handle*)0 +#define __itt_string_handle_createA_ptr 0 +#define __itt_string_handle_createW(name) (__itt_string_handle*)0 +#define __itt_string_handle_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_string_handle_create(name) (__itt_string_handle*)0 +#define __itt_string_handle_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_string_handle_createA_ptr 0 +#define __itt_string_handle_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_string_handle_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} handles group */ + +/** @cond exclude_from_documentation */ +typedef unsigned long long __itt_timestamp; +/** @endcond */ + +#define __itt_timestamp_none ((__itt_timestamp)-1LL) + +/** @cond exclude_from_gpa_documentation */ + +/** + * @ingroup timestamps + * @brief Return timestamp corresponding to the current moment. + * This returns the timestamp in the format that is the most relevant for the current + * host or platform (RDTSC, QPC, and others). You can use the "<" operator to + * compare __itt_timestamp values. + */ +__itt_timestamp ITTAPI __itt_get_timestamp(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_timestamp, get_timestamp, (void)) +#define __itt_get_timestamp ITTNOTIFY_DATA(get_timestamp) +#define __itt_get_timestamp_ptr ITTNOTIFY_NAME(get_timestamp) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_get_timestamp() +#define __itt_get_timestamp_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_get_timestamp_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} timestamps */ +/** @endcond */ + +/** @cond exclude_from_gpa_documentation */ + +/** + * @defgroup regions Regions + * @ingroup public + * Regions group + * @{ + */ +/** + * @ingroup regions + * @brief Begin of region instance. + * Successive calls to __itt_region_begin with the same ID are ignored + * until a call to __itt_region_end with the same ID + * @param[in] domain The domain for this region instance + * @param[in] id The instance ID for this region instance. Must not be __itt_null + * @param[in] parentid The instance ID for the parent of this region instance, or __itt_null + * @param[in] name The name of this region + */ +void ITTAPI __itt_region_begin(const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name); + +/** + * @ingroup regions + * @brief End of region instance. + * The first call to __itt_region_end with a given ID ends the + * region. Successive calls with the same ID are ignored, as are + * calls that do not have a matching __itt_region_begin call. + * @param[in] domain The domain for this region instance + * @param[in] id The instance ID for this region instance + */ +void ITTAPI __itt_region_end(const __itt_domain *domain, __itt_id id); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, region_begin, (const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name)) +ITT_STUBV(ITTAPI, void, region_end, (const __itt_domain *domain, __itt_id id)) +#define __itt_region_begin(d,x,y,z) ITTNOTIFY_VOID_D3(region_begin,d,x,y,z) +#define __itt_region_begin_ptr ITTNOTIFY_NAME(region_begin) +#define __itt_region_end(d,x) ITTNOTIFY_VOID_D1(region_end,d,x) +#define __itt_region_end_ptr ITTNOTIFY_NAME(region_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_region_begin(d,x,y,z) +#define __itt_region_begin_ptr 0 +#define __itt_region_end(d,x) +#define __itt_region_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_region_begin_ptr 0 +#define __itt_region_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} regions group */ + +/** + * @defgroup frames Frames + * @ingroup public + * Frames are similar to regions, but are intended to be easier to use and to implement. + * In particular: + * - Frames always represent periods of elapsed time + * - By default, frames have no nesting relationships + * @{ + */ + +/** + * @ingroup frames + * @brief Begin a frame instance. + * Successive calls to __itt_frame_begin with the + * same ID are ignored until a call to __itt_frame_end with the same ID. + * @param[in] domain The domain for this frame instance + * @param[in] id The instance ID for this frame instance or NULL + */ +void ITTAPI __itt_frame_begin_v3(const __itt_domain *domain, __itt_id *id); + +/** + * @ingroup frames + * @brief End a frame instance. + * The first call to __itt_frame_end with a given ID + * ends the frame. Successive calls with the same ID are ignored, as are + * calls that do not have a matching __itt_frame_begin call. + * @param[in] domain The domain for this frame instance + * @param[in] id The instance ID for this frame instance or NULL for current + */ +void ITTAPI __itt_frame_end_v3(const __itt_domain *domain, __itt_id *id); + +/** + * @ingroup frames + * @brief Submits a frame instance. + * Successive calls to __itt_frame_begin or __itt_frame_submit with the + * same ID are ignored until a call to __itt_frame_end or __itt_frame_submit + * with the same ID. + * Passing special __itt_timestamp_none value as "end" argument means + * take the current timestamp as the end timestamp. + * @param[in] domain The domain for this frame instance + * @param[in] id The instance ID for this frame instance or NULL + * @param[in] begin Timestamp of the beginning of the frame + * @param[in] end Timestamp of the end of the frame + */ +void ITTAPI __itt_frame_submit_v3(const __itt_domain *domain, __itt_id *id, + __itt_timestamp begin, __itt_timestamp end); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, frame_begin_v3, (const __itt_domain *domain, __itt_id *id)) +ITT_STUBV(ITTAPI, void, frame_end_v3, (const __itt_domain *domain, __itt_id *id)) +ITT_STUBV(ITTAPI, void, frame_submit_v3, (const __itt_domain *domain, __itt_id *id, __itt_timestamp begin, __itt_timestamp end)) +#define __itt_frame_begin_v3(d,x) ITTNOTIFY_VOID_D1(frame_begin_v3,d,x) +#define __itt_frame_begin_v3_ptr ITTNOTIFY_NAME(frame_begin_v3) +#define __itt_frame_end_v3(d,x) ITTNOTIFY_VOID_D1(frame_end_v3,d,x) +#define __itt_frame_end_v3_ptr ITTNOTIFY_NAME(frame_end_v3) +#define __itt_frame_submit_v3(d,x,b,e) ITTNOTIFY_VOID_D3(frame_submit_v3,d,x,b,e) +#define __itt_frame_submit_v3_ptr ITTNOTIFY_NAME(frame_submit_v3) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_frame_begin_v3(domain,id) +#define __itt_frame_begin_v3_ptr 0 +#define __itt_frame_end_v3(domain,id) +#define __itt_frame_end_v3_ptr 0 +#define __itt_frame_submit_v3(domain,id,begin,end) +#define __itt_frame_submit_v3_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_frame_begin_v3_ptr 0 +#define __itt_frame_end_v3_ptr 0 +#define __itt_frame_submit_v3_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} frames group */ +/** @endcond */ + +/** + * @defgroup taskgroup Task Group + * @ingroup public + * Task Group + * @{ + */ +/** + * @ingroup task_groups + * @brief Denotes a task_group instance. + * Successive calls to __itt_task_group with the same ID are ignored. + * @param[in] domain The domain for this task_group instance + * @param[in] id The instance ID for this task_group instance. Must not be __itt_null. + * @param[in] parentid The instance ID for the parent of this task_group instance, or __itt_null. + * @param[in] name The name of this task_group + */ +void ITTAPI __itt_task_group(const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, task_group, (const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name)) +#define __itt_task_group(d,x,y,z) ITTNOTIFY_VOID_D3(task_group,d,x,y,z) +#define __itt_task_group_ptr ITTNOTIFY_NAME(task_group) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_task_group(d,x,y,z) +#define __itt_task_group_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_task_group_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} taskgroup group */ + +/** + * @defgroup tasks Tasks + * @ingroup public + * A task instance represents a piece of work performed by a particular + * thread for a period of time. A call to __itt_task_begin creates a + * task instance. This becomes the current instance for that task on that + * thread. A following call to __itt_task_end on the same thread ends the + * instance. There may be multiple simultaneous instances of tasks with the + * same name on different threads. If an ID is specified, the task instance + * receives that ID. Nested tasks are allowed. + * + * Note: The task is defined by the bracketing of __itt_task_begin and + * __itt_task_end on the same thread. If some scheduling mechanism causes + * task switching (the thread executes a different user task) or task + * switching (the user task switches to a different thread) then this breaks + * the notion of current instance. Additional API calls are required to + * deal with that possibility. + * @{ + */ + +/** + * @ingroup tasks + * @brief Begin a task instance. + * @param[in] domain The domain for this task + * @param[in] taskid The instance ID for this task instance, or __itt_null + * @param[in] parentid The parent instance to which this task instance belongs, or __itt_null + * @param[in] name The name of this task + */ +void ITTAPI __itt_task_begin(const __itt_domain *domain, __itt_id taskid, __itt_id parentid, __itt_string_handle *name); + +/** + * @ingroup tasks + * @brief Begin a task instance. + * @param[in] domain The domain for this task + * @param[in] taskid The identifier for this task instance (may be 0) + * @param[in] parentid The parent of this task (may be 0) + * @param[in] fn The pointer to the function you are tracing + */ +void ITTAPI __itt_task_begin_fn(const __itt_domain *domain, __itt_id taskid, __itt_id parentid, void* fn); + +/** + * @ingroup tasks + * @brief End the current task instance. + * @param[in] domain The domain for this task + */ +void ITTAPI __itt_task_end(const __itt_domain *domain); + +/** + * @ingroup tasks + * @brief Begin an overlapped task instance. + * @param[in] domain The domain for this task. + * @param[in] taskid The identifier for this task instance, *cannot* be __itt_null. + * @param[in] parentid The parent of this task, or __itt_null. + * @param[in] name The name of this task. + */ +void ITTAPI __itt_task_begin_overlapped(const __itt_domain* domain, __itt_id taskid, __itt_id parentid, __itt_string_handle* name); + +/** + * @ingroup tasks + * @brief End an overlapped task instance. + * @param[in] domain The domain for this task + * @param[in] taskid Explicit ID of finished task + */ +void ITTAPI __itt_task_end_overlapped(const __itt_domain *domain, __itt_id taskid); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, task_begin, (const __itt_domain *domain, __itt_id id, __itt_id parentid, __itt_string_handle *name)) +ITT_STUBV(ITTAPI, void, task_begin_fn, (const __itt_domain *domain, __itt_id id, __itt_id parentid, void* fn)) +ITT_STUBV(ITTAPI, void, task_end, (const __itt_domain *domain)) +ITT_STUBV(ITTAPI, void, task_begin_overlapped, (const __itt_domain *domain, __itt_id taskid, __itt_id parentid, __itt_string_handle *name)) +ITT_STUBV(ITTAPI, void, task_end_overlapped, (const __itt_domain *domain, __itt_id taskid)) +#define __itt_task_begin(d,x,y,z) ITTNOTIFY_VOID_D3(task_begin,d,x,y,z) +#define __itt_task_begin_ptr ITTNOTIFY_NAME(task_begin) +#define __itt_task_begin_fn(d,x,y,z) ITTNOTIFY_VOID_D3(task_begin_fn,d,x,y,z) +#define __itt_task_begin_fn_ptr ITTNOTIFY_NAME(task_begin_fn) +#define __itt_task_end(d) ITTNOTIFY_VOID_D0(task_end,d) +#define __itt_task_end_ptr ITTNOTIFY_NAME(task_end) +#define __itt_task_begin_overlapped(d,x,y,z) ITTNOTIFY_VOID_D3(task_begin_overlapped,d,x,y,z) +#define __itt_task_begin_overlapped_ptr ITTNOTIFY_NAME(task_begin_overlapped) +#define __itt_task_end_overlapped(d,x) ITTNOTIFY_VOID_D1(task_end_overlapped,d,x) +#define __itt_task_end_overlapped_ptr ITTNOTIFY_NAME(task_end_overlapped) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_task_begin(domain,id,parentid,name) +#define __itt_task_begin_ptr 0 +#define __itt_task_begin_fn(domain,id,parentid,fn) +#define __itt_task_begin_fn_ptr 0 +#define __itt_task_end(domain) +#define __itt_task_end_ptr 0 +#define __itt_task_begin_overlapped(domain,taskid,parentid,name) +#define __itt_task_begin_overlapped_ptr 0 +#define __itt_task_end_overlapped(domain,taskid) +#define __itt_task_end_overlapped_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_task_begin_ptr 0 +#define __itt_task_begin_fn_ptr 0 +#define __itt_task_end_ptr 0 +#define __itt_task_begin_overlapped_ptr 0 +#define __itt_task_end_overlapped_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} tasks group */ + + +/** + * @defgroup markers Markers + * Markers represent a single discreet event in time. Markers have a scope, + * described by an enumerated type __itt_scope. Markers are created by + * the API call __itt_marker. A marker instance can be given an ID for use in + * adding metadata. + * @{ + */ + +/** + * @brief Describes the scope of an event object in the trace. + */ +typedef enum +{ + __itt_scope_unknown = 0, + __itt_scope_global, + __itt_scope_track_group, + __itt_scope_track, + __itt_scope_task, + __itt_scope_marker +} __itt_scope; + +/** @cond exclude_from_documentation */ +#define __itt_marker_scope_unknown __itt_scope_unknown +#define __itt_marker_scope_global __itt_scope_global +#define __itt_marker_scope_process __itt_scope_track_group +#define __itt_marker_scope_thread __itt_scope_track +#define __itt_marker_scope_task __itt_scope_task +/** @endcond */ + +/** + * @ingroup markers + * @brief Create a marker instance + * @param[in] domain The domain for this marker + * @param[in] id The instance ID for this marker or __itt_null + * @param[in] name The name for this marker + * @param[in] scope The scope for this marker + */ +void ITTAPI __itt_marker(const __itt_domain *domain, __itt_id id, __itt_string_handle *name, __itt_scope scope); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, marker, (const __itt_domain *domain, __itt_id id, __itt_string_handle *name, __itt_scope scope)) +#define __itt_marker(d,x,y,z) ITTNOTIFY_VOID_D3(marker,d,x,y,z) +#define __itt_marker_ptr ITTNOTIFY_NAME(marker) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_marker(domain,id,name,scope) +#define __itt_marker_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_marker_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} markers group */ + +/** + * @defgroup metadata Metadata + * The metadata API is used to attach extra information to named + * entities. Metadata can be attached to an identified named entity by ID, + * or to the current entity (which is always a task). + * + * Conceptually metadata has a type (what kind of metadata), a key (the + * name of the metadata), and a value (the actual data). The encoding of + * the value depends on the type of the metadata. + * + * The type of metadata is specified by an enumerated type __itt_metdata_type. + * @{ + */ + +/** + * @ingroup parameters + * @brief describes the type of metadata + */ +typedef enum { + __itt_metadata_unknown = 0, + __itt_metadata_u64, /**< Unsigned 64-bit integer */ + __itt_metadata_s64, /**< Signed 64-bit integer */ + __itt_metadata_u32, /**< Unsigned 32-bit integer */ + __itt_metadata_s32, /**< Signed 32-bit integer */ + __itt_metadata_u16, /**< Unsigned 16-bit integer */ + __itt_metadata_s16, /**< Signed 16-bit integer */ + __itt_metadata_float, /**< Signed 32-bit floating-point */ + __itt_metadata_double /**< SIgned 64-bit floating-point */ +} __itt_metadata_type; + +/** + * @ingroup parameters + * @brief Add metadata to an instance of a named entity. + * @param[in] domain The domain controlling the call + * @param[in] id The identifier of the instance to which the metadata is to be added, or __itt_null to add to the current task + * @param[in] key The name of the metadata + * @param[in] type The type of the metadata + * @param[in] count The number of elements of the given type. If count == 0, no metadata will be added. + * @param[in] data The metadata itself +*/ +void ITTAPI __itt_metadata_add(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, metadata_add, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data)) +#define __itt_metadata_add(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(metadata_add,d,x,y,z,a,b) +#define __itt_metadata_add_ptr ITTNOTIFY_NAME(metadata_add) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_metadata_add(d,x,y,z,a,b) +#define __itt_metadata_add_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_metadata_add_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup parameters + * @brief Add string metadata to an instance of a named entity. + * @param[in] domain The domain controlling the call + * @param[in] id The identifier of the instance to which the metadata is to be added, or __itt_null to add to the current task + * @param[in] key The name of the metadata + * @param[in] data The metadata itself + * @param[in] length The number of characters in the string, or -1 if the length is unknown but the string is null-terminated +*/ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_metadata_str_addA(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length); +void ITTAPI __itt_metadata_str_addW(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const wchar_t *data, size_t length); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_metadata_str_add __itt_metadata_str_addW +# define __itt_metadata_str_add_ptr __itt_metadata_str_addW_ptr +#else /* UNICODE */ +# define __itt_metadata_str_add __itt_metadata_str_addA +# define __itt_metadata_str_add_ptr __itt_metadata_str_addA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +void ITTAPI __itt_metadata_str_add(const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length); +#endif + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, metadata_str_addA, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length)) +ITT_STUBV(ITTAPI, void, metadata_str_addW, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const wchar_t *data, size_t length)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUBV(ITTAPI, void, metadata_str_add, (const __itt_domain *domain, __itt_id id, __itt_string_handle *key, const char *data, size_t length)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_metadata_str_addA(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_addA,d,x,y,z,a) +#define __itt_metadata_str_addA_ptr ITTNOTIFY_NAME(metadata_str_addA) +#define __itt_metadata_str_addW(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_addW,d,x,y,z,a) +#define __itt_metadata_str_addW_ptr ITTNOTIFY_NAME(metadata_str_addW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_metadata_str_add(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add,d,x,y,z,a) +#define __itt_metadata_str_add_ptr ITTNOTIFY_NAME(metadata_str_add) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_metadata_str_addA(d,x,y,z,a) +#define __itt_metadata_str_addA_ptr 0 +#define __itt_metadata_str_addW(d,x,y,z,a) +#define __itt_metadata_str_addW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_metadata_str_add(d,x,y,z,a) +#define __itt_metadata_str_add_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_metadata_str_addA_ptr 0 +#define __itt_metadata_str_addW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_metadata_str_add_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup parameters + * @brief Add metadata to an instance of a named entity. + * @param[in] domain The domain controlling the call + * @param[in] scope The scope of the instance to which the metadata is to be added + + * @param[in] id The identifier of the instance to which the metadata is to be added, or __itt_null to add to the current task + + * @param[in] key The name of the metadata + * @param[in] type The type of the metadata + * @param[in] count The number of elements of the given type. If count == 0, no metadata will be added. + * @param[in] data The metadata itself +*/ +void ITTAPI __itt_metadata_add_with_scope(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, metadata_add_with_scope, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, __itt_metadata_type type, size_t count, void *data)) +#define __itt_metadata_add_with_scope(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(metadata_add_with_scope,d,x,y,z,a,b) +#define __itt_metadata_add_with_scope_ptr ITTNOTIFY_NAME(metadata_add_with_scope) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_metadata_add_with_scope(d,x,y,z,a,b) +#define __itt_metadata_add_with_scope_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_metadata_add_with_scope_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup parameters + * @brief Add string metadata to an instance of a named entity. + * @param[in] domain The domain controlling the call + * @param[in] scope The scope of the instance to which the metadata is to be added + + * @param[in] id The identifier of the instance to which the metadata is to be added, or __itt_null to add to the current task + + * @param[in] key The name of the metadata + * @param[in] data The metadata itself + * @param[in] length The number of characters in the string, or -1 if the length is unknown but the string is null-terminated +*/ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_metadata_str_add_with_scopeA(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length); +void ITTAPI __itt_metadata_str_add_with_scopeW(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const wchar_t *data, size_t length); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_metadata_str_add_with_scope __itt_metadata_str_add_with_scopeW +# define __itt_metadata_str_add_with_scope_ptr __itt_metadata_str_add_with_scopeW_ptr +#else /* UNICODE */ +# define __itt_metadata_str_add_with_scope __itt_metadata_str_add_with_scopeA +# define __itt_metadata_str_add_with_scope_ptr __itt_metadata_str_add_with_scopeA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +void ITTAPI __itt_metadata_str_add_with_scope(const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length); +#endif + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUBV(ITTAPI, void, metadata_str_add_with_scopeA, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length)) +ITT_STUBV(ITTAPI, void, metadata_str_add_with_scopeW, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const wchar_t *data, size_t length)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUBV(ITTAPI, void, metadata_str_add_with_scope, (const __itt_domain *domain, __itt_scope scope, __itt_string_handle *key, const char *data, size_t length)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_metadata_str_add_with_scopeA(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add_with_scopeA,d,x,y,z,a) +#define __itt_metadata_str_add_with_scopeA_ptr ITTNOTIFY_NAME(metadata_str_add_with_scopeA) +#define __itt_metadata_str_add_with_scopeW(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add_with_scopeW,d,x,y,z,a) +#define __itt_metadata_str_add_with_scopeW_ptr ITTNOTIFY_NAME(metadata_str_add_with_scopeW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_metadata_str_add_with_scope(d,x,y,z,a) ITTNOTIFY_VOID_D4(metadata_str_add_with_scope,d,x,y,z,a) +#define __itt_metadata_str_add_with_scope_ptr ITTNOTIFY_NAME(metadata_str_add_with_scope) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_metadata_str_add_with_scopeA(d,x,y,z,a) +#define __itt_metadata_str_add_with_scopeA_ptr 0 +#define __itt_metadata_str_add_with_scopeW(d,x,y,z,a) +#define __itt_metadata_str_add_with_scopeW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_metadata_str_add_with_scope(d,x,y,z,a) +#define __itt_metadata_str_add_with_scope_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_metadata_str_add_with_scopeA_ptr 0 +#define __itt_metadata_str_add_with_scopeW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_metadata_str_add_with_scope_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @} metadata group */ + +/** + * @defgroup relations Relations + * Instances of named entities can be explicitly associated with other + * instances using instance IDs and the relationship API calls. + * + * @{ + */ + +/** + * @ingroup relations + * @brief The kind of relation between two instances is specified by the enumerated type __itt_relation. + * Relations between instances can be added with an API call. The relation + * API uses instance IDs. Relations can be added before or after the actual + * instances are created and persist independently of the instances. This + * is the motivation for having different lifetimes for instance IDs and + * the actual instances. + */ +typedef enum +{ + __itt_relation_is_unknown = 0, + __itt_relation_is_dependent_on, /**< "A is dependent on B" means that A cannot start until B completes */ + __itt_relation_is_sibling_of, /**< "A is sibling of B" means that A and B were created as a group */ + __itt_relation_is_parent_of, /**< "A is parent of B" means that A created B */ + __itt_relation_is_continuation_of, /**< "A is continuation of B" means that A assumes the dependencies of B */ + __itt_relation_is_child_of, /**< "A is child of B" means that A was created by B (inverse of is_parent_of) */ + __itt_relation_is_continued_by, /**< "A is continued by B" means that B assumes the dependencies of A (inverse of is_continuation_of) */ + __itt_relation_is_predecessor_to /**< "A is predecessor to B" means that B cannot start until A completes (inverse of is_dependent_on) */ +} __itt_relation; + +/** + * @ingroup relations + * @brief Add a relation to the current task instance. + * The current task instance is the head of the relation. + * @param[in] domain The domain controlling this call + * @param[in] relation The kind of relation + * @param[in] tail The ID for the tail of the relation + */ +void ITTAPI __itt_relation_add_to_current(const __itt_domain *domain, __itt_relation relation, __itt_id tail); + +/** + * @ingroup relations + * @brief Add a relation between two instance identifiers. + * @param[in] domain The domain controlling this call + * @param[in] head The ID for the head of the relation + * @param[in] relation The kind of relation + * @param[in] tail The ID for the tail of the relation + */ +void ITTAPI __itt_relation_add(const __itt_domain *domain, __itt_id head, __itt_relation relation, __itt_id tail); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, relation_add_to_current, (const __itt_domain *domain, __itt_relation relation, __itt_id tail)) +ITT_STUBV(ITTAPI, void, relation_add, (const __itt_domain *domain, __itt_id head, __itt_relation relation, __itt_id tail)) +#define __itt_relation_add_to_current(d,x,y) ITTNOTIFY_VOID_D2(relation_add_to_current,d,x,y) +#define __itt_relation_add_to_current_ptr ITTNOTIFY_NAME(relation_add_to_current) +#define __itt_relation_add(d,x,y,z) ITTNOTIFY_VOID_D3(relation_add,d,x,y,z) +#define __itt_relation_add_ptr ITTNOTIFY_NAME(relation_add) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_relation_add_to_current(d,x,y) +#define __itt_relation_add_to_current_ptr 0 +#define __itt_relation_add(d,x,y,z) +#define __itt_relation_add_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_relation_add_to_current_ptr 0 +#define __itt_relation_add_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} relations group */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_clock_info +{ + unsigned long long clock_freq; /*!< Clock domain frequency */ + unsigned long long clock_base; /*!< Clock domain base timestamp */ +} __itt_clock_info; + +#pragma pack(pop) +/** @endcond */ + +/** @cond exclude_from_documentation */ +typedef void (ITTAPI *__itt_get_clock_info_fn)(__itt_clock_info* clock_info, void* data); +/** @endcond */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_clock_domain +{ + __itt_clock_info info; /*!< Most recent clock domain info */ + __itt_get_clock_info_fn fn; /*!< Callback function pointer */ + void* fn_data; /*!< Input argument for the callback function */ + int extra1; /*!< Reserved. Must be zero */ + void* extra2; /*!< Reserved. Must be zero */ + struct ___itt_clock_domain* next; +} __itt_clock_domain; + +#pragma pack(pop) +/** @endcond */ + +/** + * @ingroup clockdomains + * @brief Create a clock domain. + * Certain applications require the capability to trace their application using + * a clock domain different than the CPU, for instance the instrumentation of events + * that occur on a GPU. + * Because the set of domains is expected to be static over the application's execution time, + * there is no mechanism to destroy a domain. + * Any domain can be accessed by any thread in the process, regardless of which thread created + * the domain. This call is thread-safe. + * @param[in] fn A pointer to a callback function which retrieves alternative CPU timestamps + * @param[in] fn_data Argument for a callback function; may be NULL + */ +__itt_clock_domain* ITTAPI __itt_clock_domain_create(__itt_get_clock_info_fn fn, void* fn_data); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_clock_domain*, clock_domain_create, (__itt_get_clock_info_fn fn, void* fn_data)) +#define __itt_clock_domain_create ITTNOTIFY_DATA(clock_domain_create) +#define __itt_clock_domain_create_ptr ITTNOTIFY_NAME(clock_domain_create) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_clock_domain_create(fn,fn_data) (__itt_clock_domain*)0 +#define __itt_clock_domain_create_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_clock_domain_create_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup clockdomains + * @brief Recalculate clock domains frequencies and clock base timestamps. + */ +void ITTAPI __itt_clock_domain_reset(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, clock_domain_reset, (void)) +#define __itt_clock_domain_reset ITTNOTIFY_VOID(clock_domain_reset) +#define __itt_clock_domain_reset_ptr ITTNOTIFY_NAME(clock_domain_reset) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_clock_domain_reset() +#define __itt_clock_domain_reset_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_clock_domain_reset_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup clockdomain + * @brief Create an instance of identifier. This establishes the beginning of the lifetime of + * an instance of the given ID in the trace. Once this lifetime starts, the ID can be used to + * tag named entity instances in calls such as __itt_task_begin, and to specify relationships among + * identified named entity instances, using the \ref relations APIs. + * @param[in] domain The domain controlling the execution of this call. + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] id The ID to create. + */ +void ITTAPI __itt_id_create_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id); + +/** + * @ingroup clockdomain + * @brief Destroy an instance of identifier. This ends the lifetime of the current instance of the + * given ID value in the trace. Any relationships that are established after this lifetime ends are + * invalid. This call must be performed before the given ID value can be reused for a different + * named entity instance. + * @param[in] domain The domain controlling the execution of this call. + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] id The ID to destroy. + */ +void ITTAPI __itt_id_destroy_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, id_create_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id)) +ITT_STUBV(ITTAPI, void, id_destroy_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id)) +#define __itt_id_create_ex(d,x,y,z) ITTNOTIFY_VOID_D3(id_create_ex,d,x,y,z) +#define __itt_id_create_ex_ptr ITTNOTIFY_NAME(id_create_ex) +#define __itt_id_destroy_ex(d,x,y,z) ITTNOTIFY_VOID_D3(id_destroy_ex,d,x,y,z) +#define __itt_id_destroy_ex_ptr ITTNOTIFY_NAME(id_destroy_ex) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_id_create_ex(domain,clock_domain,timestamp,id) +#define __itt_id_create_ex_ptr 0 +#define __itt_id_destroy_ex(domain,clock_domain,timestamp,id) +#define __itt_id_destroy_ex_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_id_create_ex_ptr 0 +#define __itt_id_destroy_ex_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup clockdomain + * @brief Begin a task instance. + * @param[in] domain The domain for this task + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] taskid The instance ID for this task instance, or __itt_null + * @param[in] parentid The parent instance to which this task instance belongs, or __itt_null + * @param[in] name The name of this task + */ +void ITTAPI __itt_task_begin_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, __itt_string_handle* name); + +/** + * @ingroup clockdomain + * @brief Begin a task instance. + * @param[in] domain The domain for this task + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] taskid The identifier for this task instance, or __itt_null + * @param[in] parentid The parent of this task, or __itt_null + * @param[in] fn The pointer to the function you are tracing + */ +void ITTAPI __itt_task_begin_fn_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, void* fn); + +/** + * @ingroup clockdomain + * @brief End the current task instance. + * @param[in] domain The domain for this task + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + */ +void ITTAPI __itt_task_end_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, task_begin_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_id parentid, __itt_string_handle *name)) +ITT_STUBV(ITTAPI, void, task_begin_fn_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_id parentid, void* fn)) +ITT_STUBV(ITTAPI, void, task_end_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp)) +#define __itt_task_begin_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(task_begin_ex,d,x,y,z,a,b) +#define __itt_task_begin_ex_ptr ITTNOTIFY_NAME(task_begin_ex) +#define __itt_task_begin_fn_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(task_begin_fn_ex,d,x,y,z,a,b) +#define __itt_task_begin_fn_ex_ptr ITTNOTIFY_NAME(task_begin_fn_ex) +#define __itt_task_end_ex(d,x,y) ITTNOTIFY_VOID_D2(task_end_ex,d,x,y) +#define __itt_task_end_ex_ptr ITTNOTIFY_NAME(task_end_ex) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_task_begin_ex(domain,clock_domain,timestamp,id,parentid,name) +#define __itt_task_begin_ex_ptr 0 +#define __itt_task_begin_fn_ex(domain,clock_domain,timestamp,id,parentid,fn) +#define __itt_task_begin_fn_ex_ptr 0 +#define __itt_task_end_ex(domain,clock_domain,timestamp) +#define __itt_task_end_ex_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_task_begin_ex_ptr 0 +#define __itt_task_begin_fn_ex_ptr 0 +#define __itt_task_end_ex_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @defgroup counters Counters + * @ingroup public + * Counters are user-defined objects with a monotonically increasing + * value. Counter values are 64-bit unsigned integers. + * Counters have names that can be displayed in + * the tools. + * @{ + */ + +/** + * @brief opaque structure for counter identification + */ +/** @cond exclude_from_documentation */ + +typedef struct ___itt_counter* __itt_counter; + +/** + * @brief Create an unsigned 64 bits integer counter with given name/domain + * + * After __itt_counter_create() is called, __itt_counter_inc(id), __itt_counter_inc_delta(id, delta), + * __itt_counter_set_value(id, value_ptr) or __itt_counter_set_value_ex(id, clock_domain, timestamp, value_ptr) + * can be used to change the value of the counter, where value_ptr is a pointer to an unsigned 64 bits integer + * + * The call is equal to __itt_counter_create_typed(name, domain, __itt_metadata_u64) + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_counter ITTAPI __itt_counter_createA(const char *name, const char *domain); +__itt_counter ITTAPI __itt_counter_createW(const wchar_t *name, const wchar_t *domain); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_counter_create __itt_counter_createW +# define __itt_counter_create_ptr __itt_counter_createW_ptr +#else /* UNICODE */ +# define __itt_counter_create __itt_counter_createA +# define __itt_counter_create_ptr __itt_counter_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_counter ITTAPI __itt_counter_create(const char *name, const char *domain); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_counter, counter_createA, (const char *name, const char *domain)) +ITT_STUB(ITTAPI, __itt_counter, counter_createW, (const wchar_t *name, const wchar_t *domain)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_counter, counter_create, (const char *name, const char *domain)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_createA ITTNOTIFY_DATA(counter_createA) +#define __itt_counter_createA_ptr ITTNOTIFY_NAME(counter_createA) +#define __itt_counter_createW ITTNOTIFY_DATA(counter_createW) +#define __itt_counter_createW_ptr ITTNOTIFY_NAME(counter_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create ITTNOTIFY_DATA(counter_create) +#define __itt_counter_create_ptr ITTNOTIFY_NAME(counter_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_createA(name, domain) +#define __itt_counter_createA_ptr 0 +#define __itt_counter_createW(name, domain) +#define __itt_counter_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create(name, domain) +#define __itt_counter_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_createA_ptr 0 +#define __itt_counter_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Increment the unsigned 64 bits integer counter value + * + * Calling this function to non-unsigned 64 bits integer counters has no effect + */ +void ITTAPI __itt_counter_inc(__itt_counter id); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_inc, (__itt_counter id)) +#define __itt_counter_inc ITTNOTIFY_VOID(counter_inc) +#define __itt_counter_inc_ptr ITTNOTIFY_NAME(counter_inc) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_inc(id) +#define __itt_counter_inc_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_inc_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** + * @brief Increment the unsigned 64 bits integer counter value with x + * + * Calling this function to non-unsigned 64 bits integer counters has no effect + */ +void ITTAPI __itt_counter_inc_delta(__itt_counter id, unsigned long long value); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_inc_delta, (__itt_counter id, unsigned long long value)) +#define __itt_counter_inc_delta ITTNOTIFY_VOID(counter_inc_delta) +#define __itt_counter_inc_delta_ptr ITTNOTIFY_NAME(counter_inc_delta) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_inc_delta(id, value) +#define __itt_counter_inc_delta_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_inc_delta_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Decrement the unsigned 64 bits integer counter value + * + * Calling this function to non-unsigned 64 bits integer counters has no effect + */ +void ITTAPI __itt_counter_dec(__itt_counter id); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_dec, (__itt_counter id)) +#define __itt_counter_dec ITTNOTIFY_VOID(counter_dec) +#define __itt_counter_dec_ptr ITTNOTIFY_NAME(counter_dec) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_dec(id) +#define __itt_counter_dec_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_dec_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** + * @brief Decrement the unsigned 64 bits integer counter value with x + * + * Calling this function to non-unsigned 64 bits integer counters has no effect + */ +void ITTAPI __itt_counter_dec_delta(__itt_counter id, unsigned long long value); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_dec_delta, (__itt_counter id, unsigned long long value)) +#define __itt_counter_dec_delta ITTNOTIFY_VOID(counter_dec_delta) +#define __itt_counter_dec_delta_ptr ITTNOTIFY_NAME(counter_dec_delta) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_dec_delta(id, value) +#define __itt_counter_dec_delta_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_dec_delta_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup counters + * @brief Increment a counter by one. + * The first call with a given name creates a counter by that name and sets its + * value to zero. Successive calls increment the counter value. + * @param[in] domain The domain controlling the call. Counter names are not domain specific. + * The domain argument is used only to enable or disable the API calls. + * @param[in] name The name of the counter + */ +void ITTAPI __itt_counter_inc_v3(const __itt_domain *domain, __itt_string_handle *name); + +/** + * @ingroup counters + * @brief Increment a counter by the value specified in delta. + * @param[in] domain The domain controlling the call. Counter names are not domain specific. + * The domain argument is used only to enable or disable the API calls. + * @param[in] name The name of the counter + * @param[in] delta The amount by which to increment the counter + */ +void ITTAPI __itt_counter_inc_delta_v3(const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_inc_v3, (const __itt_domain *domain, __itt_string_handle *name)) +ITT_STUBV(ITTAPI, void, counter_inc_delta_v3, (const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta)) +#define __itt_counter_inc_v3(d,x) ITTNOTIFY_VOID_D1(counter_inc_v3,d,x) +#define __itt_counter_inc_v3_ptr ITTNOTIFY_NAME(counter_inc_v3) +#define __itt_counter_inc_delta_v3(d,x,y) ITTNOTIFY_VOID_D2(counter_inc_delta_v3,d,x,y) +#define __itt_counter_inc_delta_v3_ptr ITTNOTIFY_NAME(counter_inc_delta_v3) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_inc_v3(domain,name) +#define __itt_counter_inc_v3_ptr 0 +#define __itt_counter_inc_delta_v3(domain,name,delta) +#define __itt_counter_inc_delta_v3_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_inc_v3_ptr 0 +#define __itt_counter_inc_delta_v3_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + + +/** + * @ingroup counters + * @brief Decrement a counter by one. + * The first call with a given name creates a counter by that name and sets its + * value to zero. Successive calls decrement the counter value. + * @param[in] domain The domain controlling the call. Counter names are not domain specific. + * The domain argument is used only to enable or disable the API calls. + * @param[in] name The name of the counter + */ +void ITTAPI __itt_counter_dec_v3(const __itt_domain *domain, __itt_string_handle *name); + +/** + * @ingroup counters + * @brief Decrement a counter by the value specified in delta. + * @param[in] domain The domain controlling the call. Counter names are not domain specific. + * The domain argument is used only to enable or disable the API calls. + * @param[in] name The name of the counter + * @param[in] delta The amount by which to decrement the counter + */ +void ITTAPI __itt_counter_dec_delta_v3(const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_dec_v3, (const __itt_domain *domain, __itt_string_handle *name)) +ITT_STUBV(ITTAPI, void, counter_dec_delta_v3, (const __itt_domain *domain, __itt_string_handle *name, unsigned long long delta)) +#define __itt_counter_dec_v3(d,x) ITTNOTIFY_VOID_D1(counter_dec_v3,d,x) +#define __itt_counter_dec_v3_ptr ITTNOTIFY_NAME(counter_dec_v3) +#define __itt_counter_dec_delta_v3(d,x,y) ITTNOTIFY_VOID_D2(counter_dec_delta_v3,d,x,y) +#define __itt_counter_dec_delta_v3_ptr ITTNOTIFY_NAME(counter_dec_delta_v3) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_dec_v3(domain,name) +#define __itt_counter_dec_v3_ptr 0 +#define __itt_counter_dec_delta_v3(domain,name,delta) +#define __itt_counter_dec_delta_v3_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_dec_v3_ptr 0 +#define __itt_counter_dec_delta_v3_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @} counters group */ + + +/** + * @brief Set the counter value + */ +void ITTAPI __itt_counter_set_value(__itt_counter id, void *value_ptr); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_set_value, (__itt_counter id, void *value_ptr)) +#define __itt_counter_set_value ITTNOTIFY_VOID(counter_set_value) +#define __itt_counter_set_value_ptr ITTNOTIFY_NAME(counter_set_value) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_set_value(id, value_ptr) +#define __itt_counter_set_value_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_set_value_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Set the counter value + */ +void ITTAPI __itt_counter_set_value_ex(__itt_counter id, __itt_clock_domain *clock_domain, unsigned long long timestamp, void *value_ptr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_set_value_ex, (__itt_counter id, __itt_clock_domain *clock_domain, unsigned long long timestamp, void *value_ptr)) +#define __itt_counter_set_value_ex ITTNOTIFY_VOID(counter_set_value_ex) +#define __itt_counter_set_value_ex_ptr ITTNOTIFY_NAME(counter_set_value_ex) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_set_value_ex(id, clock_domain, timestamp, value_ptr) +#define __itt_counter_set_value_ex_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_set_value_ex_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Create a typed counter with given name/domain + * + * After __itt_counter_create_typed() is called, __itt_counter_inc(id), __itt_counter_inc_delta(id, delta), + * __itt_counter_set_value(id, value_ptr) or __itt_counter_set_value_ex(id, clock_domain, timestamp, value_ptr) + * can be used to change the value of the counter + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_counter ITTAPI __itt_counter_create_typedA(const char *name, const char *domain, __itt_metadata_type type); +__itt_counter ITTAPI __itt_counter_create_typedW(const wchar_t *name, const wchar_t *domain, __itt_metadata_type type); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_counter_create_typed __itt_counter_create_typedW +# define __itt_counter_create_typed_ptr __itt_counter_create_typedW_ptr +#else /* UNICODE */ +# define __itt_counter_create_typed __itt_counter_create_typedA +# define __itt_counter_create_typed_ptr __itt_counter_create_typedA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_counter ITTAPI __itt_counter_create_typed(const char *name, const char *domain, __itt_metadata_type type); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_counter, counter_create_typedA, (const char *name, const char *domain, __itt_metadata_type type)) +ITT_STUB(ITTAPI, __itt_counter, counter_create_typedW, (const wchar_t *name, const wchar_t *domain, __itt_metadata_type type)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_counter, counter_create_typed, (const char *name, const char *domain, __itt_metadata_type type)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_create_typedA ITTNOTIFY_DATA(counter_create_typedA) +#define __itt_counter_create_typedA_ptr ITTNOTIFY_NAME(counter_create_typedA) +#define __itt_counter_create_typedW ITTNOTIFY_DATA(counter_create_typedW) +#define __itt_counter_create_typedW_ptr ITTNOTIFY_NAME(counter_create_typedW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create_typed ITTNOTIFY_DATA(counter_create_typed) +#define __itt_counter_create_typed_ptr ITTNOTIFY_NAME(counter_create_typed) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_create_typedA(name, domain, type) +#define __itt_counter_create_typedA_ptr 0 +#define __itt_counter_create_typedW(name, domain, type) +#define __itt_counter_create_typedW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create_typed(name, domain, type) +#define __itt_counter_create_typed_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_create_typedA_ptr 0 +#define __itt_counter_create_typedW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create_typed_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Destroy the counter identified by the pointer previously returned by __itt_counter_create() or + * __itt_counter_create_typed() + */ +void ITTAPI __itt_counter_destroy(__itt_counter id); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_destroy, (__itt_counter id)) +#define __itt_counter_destroy ITTNOTIFY_VOID(counter_destroy) +#define __itt_counter_destroy_ptr ITTNOTIFY_NAME(counter_destroy) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_destroy(id) +#define __itt_counter_destroy_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_destroy_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} counters group */ + +/** + * @ingroup markers + * @brief Create a marker instance. + * @param[in] domain The domain for this marker + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] id The instance ID for this marker, or __itt_null + * @param[in] name The name for this marker + * @param[in] scope The scope for this marker + */ +void ITTAPI __itt_marker_ex(const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_string_handle *name, __itt_scope scope); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, marker_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id id, __itt_string_handle *name, __itt_scope scope)) +#define __itt_marker_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(marker_ex,d,x,y,z,a,b) +#define __itt_marker_ex_ptr ITTNOTIFY_NAME(marker_ex) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_marker_ex(domain,clock_domain,timestamp,id,name,scope) +#define __itt_marker_ex_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_marker_ex_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @ingroup clockdomain + * @brief Add a relation to the current task instance. + * The current task instance is the head of the relation. + * @param[in] domain The domain controlling this call + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] relation The kind of relation + * @param[in] tail The ID for the tail of the relation + */ +void ITTAPI __itt_relation_add_to_current_ex(const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_relation relation, __itt_id tail); + +/** + * @ingroup clockdomain + * @brief Add a relation between two instance identifiers. + * @param[in] domain The domain controlling this call + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] head The ID for the head of the relation + * @param[in] relation The kind of relation + * @param[in] tail The ID for the tail of the relation + */ +void ITTAPI __itt_relation_add_ex(const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id head, __itt_relation relation, __itt_id tail); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, relation_add_to_current_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_relation relation, __itt_id tail)) +ITT_STUBV(ITTAPI, void, relation_add_ex, (const __itt_domain *domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id head, __itt_relation relation, __itt_id tail)) +#define __itt_relation_add_to_current_ex(d,x,y,z,a) ITTNOTIFY_VOID_D4(relation_add_to_current_ex,d,x,y,z,a) +#define __itt_relation_add_to_current_ex_ptr ITTNOTIFY_NAME(relation_add_to_current_ex) +#define __itt_relation_add_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(relation_add_ex,d,x,y,z,a,b) +#define __itt_relation_add_ex_ptr ITTNOTIFY_NAME(relation_add_ex) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_relation_add_to_current_ex(domain,clock_domain,timestame,relation,tail) +#define __itt_relation_add_to_current_ex_ptr 0 +#define __itt_relation_add_ex(domain,clock_domain,timestamp,head,relation,tail) +#define __itt_relation_add_ex_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_relation_add_to_current_ex_ptr 0 +#define __itt_relation_add_ex_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @cond exclude_from_documentation */ +typedef enum ___itt_track_group_type +{ + __itt_track_group_type_normal = 0 +} __itt_track_group_type; +/** @endcond */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_track_group +{ + __itt_string_handle* name; /*!< Name of the track group */ + struct ___itt_track* track; /*!< List of child tracks */ + __itt_track_group_type tgtype; /*!< Type of the track group */ + int extra1; /*!< Reserved. Must be zero */ + void* extra2; /*!< Reserved. Must be zero */ + struct ___itt_track_group* next; +} __itt_track_group; + +#pragma pack(pop) +/** @endcond */ + +/** + * @brief Placeholder for custom track types. Currently, "normal" custom track + * is the only available track type. + */ +typedef enum ___itt_track_type +{ + __itt_track_type_normal = 0 +#ifdef INTEL_ITTNOTIFY_API_PRIVATE + , __itt_track_type_queue +#endif /* INTEL_ITTNOTIFY_API_PRIVATE */ +} __itt_track_type; + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_track +{ + __itt_string_handle* name; /*!< Name of the track group */ + __itt_track_group* group; /*!< Parent group to a track */ + __itt_track_type ttype; /*!< Type of the track */ + int extra1; /*!< Reserved. Must be zero */ + void* extra2; /*!< Reserved. Must be zero */ + struct ___itt_track* next; +} __itt_track; + +#pragma pack(pop) +/** @endcond */ + +/** + * @brief Create logical track group. + */ +__itt_track_group* ITTAPI __itt_track_group_create(__itt_string_handle* name, __itt_track_group_type track_group_type); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_track_group*, track_group_create, (__itt_string_handle* name, __itt_track_group_type track_group_type)) +#define __itt_track_group_create ITTNOTIFY_DATA(track_group_create) +#define __itt_track_group_create_ptr ITTNOTIFY_NAME(track_group_create) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_track_group_create(name) (__itt_track_group*)0 +#define __itt_track_group_create_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_track_group_create_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Create logical track. + */ +__itt_track* ITTAPI __itt_track_create(__itt_track_group* track_group, __itt_string_handle* name, __itt_track_type track_type); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_track*, track_create, (__itt_track_group* track_group,__itt_string_handle* name, __itt_track_type track_type)) +#define __itt_track_create ITTNOTIFY_DATA(track_create) +#define __itt_track_create_ptr ITTNOTIFY_NAME(track_create) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_track_create(track_group,name,track_type) (__itt_track*)0 +#define __itt_track_create_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_track_create_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Set the logical track. + */ +void ITTAPI __itt_set_track(__itt_track* track); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, set_track, (__itt_track *track)) +#define __itt_set_track ITTNOTIFY_VOID(set_track) +#define __itt_set_track_ptr ITTNOTIFY_NAME(set_track) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_set_track(track) +#define __itt_set_track_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_set_track_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/* ========================================================================== */ +/** @cond exclude_from_gpa_documentation */ +/** + * @defgroup events Events + * @ingroup public + * Events group + * @{ + */ +/** @brief user event type */ +typedef int __itt_event; + +/** + * @brief Create an event notification + * @note name or namelen being null/name and namelen not matching, user event feature not enabled + * @return non-zero event identifier upon success and __itt_err otherwise + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_event LIBITTAPI __itt_event_createA(const char *name, int namelen); +__itt_event LIBITTAPI __itt_event_createW(const wchar_t *name, int namelen); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_event_create __itt_event_createW +# define __itt_event_create_ptr __itt_event_createW_ptr +#else +# define __itt_event_create __itt_event_createA +# define __itt_event_create_ptr __itt_event_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_event LIBITTAPI __itt_event_create(const char *name, int namelen); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(LIBITTAPI, __itt_event, event_createA, (const char *name, int namelen)) +ITT_STUB(LIBITTAPI, __itt_event, event_createW, (const wchar_t *name, int namelen)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(LIBITTAPI, __itt_event, event_create, (const char *name, int namelen)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_event_createA ITTNOTIFY_DATA(event_createA) +#define __itt_event_createA_ptr ITTNOTIFY_NAME(event_createA) +#define __itt_event_createW ITTNOTIFY_DATA(event_createW) +#define __itt_event_createW_ptr ITTNOTIFY_NAME(event_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_event_create ITTNOTIFY_DATA(event_create) +#define __itt_event_create_ptr ITTNOTIFY_NAME(event_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_event_createA(name, namelen) (__itt_event)0 +#define __itt_event_createA_ptr 0 +#define __itt_event_createW(name, namelen) (__itt_event)0 +#define __itt_event_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_event_create(name, namelen) (__itt_event)0 +#define __itt_event_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_event_createA_ptr 0 +#define __itt_event_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_event_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record an event occurrence. + * @return __itt_err upon failure (invalid event id/user event feature not enabled) + */ +int LIBITTAPI __itt_event_start(__itt_event event); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(LIBITTAPI, int, event_start, (__itt_event event)) +#define __itt_event_start ITTNOTIFY_DATA(event_start) +#define __itt_event_start_ptr ITTNOTIFY_NAME(event_start) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_event_start(event) (int)0 +#define __itt_event_start_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_event_start_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Record an event end occurrence. + * @note It is optional if events do not have durations. + * @return __itt_err upon failure (invalid event id/user event feature not enabled) + */ +int LIBITTAPI __itt_event_end(__itt_event event); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(LIBITTAPI, int, event_end, (__itt_event event)) +#define __itt_event_end ITTNOTIFY_DATA(event_end) +#define __itt_event_end_ptr ITTNOTIFY_NAME(event_end) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_event_end(event) (int)0 +#define __itt_event_end_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_event_end_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} events group */ + + +/** + * @defgroup arrays Arrays Visualizer + * @ingroup public + * Visualize arrays + * @{ + */ + +/** + * @enum __itt_av_data_type + * @brief Defines types of arrays data (for C/C++ intrinsic types) + */ +typedef enum +{ + __itt_e_first = 0, + __itt_e_char = 0, /* 1-byte integer */ + __itt_e_uchar, /* 1-byte unsigned integer */ + __itt_e_int16, /* 2-byte integer */ + __itt_e_uint16, /* 2-byte unsigned integer */ + __itt_e_int32, /* 4-byte integer */ + __itt_e_uint32, /* 4-byte unsigned integer */ + __itt_e_int64, /* 8-byte integer */ + __itt_e_uint64, /* 8-byte unsigned integer */ + __itt_e_float, /* 4-byte floating */ + __itt_e_double, /* 8-byte floating */ + __itt_e_last = __itt_e_double +} __itt_av_data_type; + +/** + * @brief Save an array data to a file. + * Output format is defined by the file extension. The csv and bmp formats are supported (bmp - for 2-dimensional array only). + * @param[in] data - pointer to the array data + * @param[in] rank - the rank of the array + * @param[in] dimensions - pointer to an array of integers, which specifies the array dimensions. + * The size of dimensions must be equal to the rank + * @param[in] type - the type of the array, specified as one of the __itt_av_data_type values (for intrinsic types) + * @param[in] filePath - the file path; the output format is defined by the file extension + * @param[in] columnOrder - defines how the array is stored in the linear memory. + * It should be 1 for column-major order (e.g. in FORTRAN) or 0 - for row-major order (e.g. in C). + */ + +#if ITT_PLATFORM==ITT_PLATFORM_WIN +int ITTAPI __itt_av_saveA(void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder); +int ITTAPI __itt_av_saveW(void *data, int rank, const int *dimensions, int type, const wchar_t *filePath, int columnOrder); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_av_save __itt_av_saveW +# define __itt_av_save_ptr __itt_av_saveW_ptr +#else /* UNICODE */ +# define __itt_av_save __itt_av_saveA +# define __itt_av_save_ptr __itt_av_saveA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +int ITTAPI __itt_av_save(void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, int, av_saveA, (void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder)) +ITT_STUB(ITTAPI, int, av_saveW, (void *data, int rank, const int *dimensions, int type, const wchar_t *filePath, int columnOrder)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, int, av_save, (void *data, int rank, const int *dimensions, int type, const char *filePath, int columnOrder)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_av_saveA ITTNOTIFY_DATA(av_saveA) +#define __itt_av_saveA_ptr ITTNOTIFY_NAME(av_saveA) +#define __itt_av_saveW ITTNOTIFY_DATA(av_saveW) +#define __itt_av_saveW_ptr ITTNOTIFY_NAME(av_saveW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_av_save ITTNOTIFY_DATA(av_save) +#define __itt_av_save_ptr ITTNOTIFY_NAME(av_save) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_av_saveA(name) +#define __itt_av_saveA_ptr 0 +#define __itt_av_saveW(name) +#define __itt_av_saveW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_av_save(name) +#define __itt_av_save_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_av_saveA_ptr 0 +#define __itt_av_saveW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_av_save_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +void ITTAPI __itt_enable_attach(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, enable_attach, (void)) +#define __itt_enable_attach ITTNOTIFY_VOID(enable_attach) +#define __itt_enable_attach_ptr ITTNOTIFY_NAME(enable_attach) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_enable_attach() +#define __itt_enable_attach_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_enable_attach_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @cond exclude_from_gpa_documentation */ + +/** @} arrays group */ + +/** @endcond */ + +/** + * @brief Module load notification + * This API is used to report necessary information in case of bypassing default system loader. + * Notification should be done immidiatelly after this module is loaded to process memory. + * @param[in] start_addr - module start address + * @param[in] end_addr - module end address + * @param[in] path - file system full path to the module + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +void ITTAPI __itt_module_loadA(void *start_addr, void *end_addr, const char *path); +void ITTAPI __itt_module_loadW(void *start_addr, void *end_addr, const wchar_t *path); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_module_load __itt_module_loadW +# define __itt_module_load_ptr __itt_module_loadW_ptr +#else /* UNICODE */ +# define __itt_module_load __itt_module_loadA +# define __itt_module_load_ptr __itt_module_loadA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +void ITTAPI __itt_module_load(void *start_addr, void *end_addr, const char *path); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, void, module_loadA, (void *start_addr, void *end_addr, const char *path)) +ITT_STUB(ITTAPI, void, module_loadW, (void *start_addr, void *end_addr, const wchar_t *path)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, void, module_load, (void *start_addr, void *end_addr, const char *path)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_module_loadA ITTNOTIFY_VOID(module_loadA) +#define __itt_module_loadA_ptr ITTNOTIFY_NAME(module_loadA) +#define __itt_module_loadW ITTNOTIFY_VOID(module_loadW) +#define __itt_module_loadW_ptr ITTNOTIFY_NAME(module_loadW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_module_load ITTNOTIFY_VOID(module_load) +#define __itt_module_load_ptr ITTNOTIFY_NAME(module_load) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_module_loadA(start_addr, end_addr, path) +#define __itt_module_loadA_ptr 0 +#define __itt_module_loadW(start_addr, end_addr, path) +#define __itt_module_loadW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_module_load(start_addr, end_addr, path) +#define __itt_module_load_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_module_loadA_ptr 0 +#define __itt_module_loadW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_module_load_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Report module unload + * This API is used to report necessary information in case of bypassing default system loader. + * Notification should be done just before the module is unloaded from process memory. + * @param[in] addr - base address of loaded module + */ +void ITTAPI __itt_module_unload(void *addr); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, module_unload, (void *addr)) +#define __itt_module_unload ITTNOTIFY_VOID(module_unload) +#define __itt_module_unload_ptr ITTNOTIFY_NAME(module_unload) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_module_unload(addr) +#define __itt_module_unload_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_module_unload_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @cond exclude_from_documentation */ +typedef enum +{ + __itt_module_type_unknown = 0, + __itt_module_type_elf, + __itt_module_type_coff +} __itt_module_type; +/** @endcond */ + +/** @cond exclude_from_documentation */ +typedef enum +{ + itt_section_type_unknown, + itt_section_type_bss, /* notifies that the section contains uninitialized data. These are the relevant section types and the modules that contain them: + * ELF module: SHT_NOBITS section type + * COFF module: IMAGE_SCN_CNT_UNINITIALIZED_DATA section type + */ + itt_section_type_data, /* notifies that section contains initialized data. These are the relevant section types and the modules that contain them: + * ELF module: SHT_PROGBITS section type + * COFF module: IMAGE_SCN_CNT_INITIALIZED_DATA section type + */ + itt_section_type_text /* notifies that the section contains executable code. These are the relevant section types and the modules that contain them: + * ELF module: SHT_PROGBITS section type + * COFF module: IMAGE_SCN_CNT_CODE section type + */ +} __itt_section_type; +/** @endcond */ + +/** + * @hideinitializer + * @brief bit-mask, detects a section attribute that indicates whether a section can be executed as code: + * These are the relevant section attributes and the modules that contain them: + * ELF module: PF_X section attribute + * COFF module: IMAGE_SCN_MEM_EXECUTE attribute + */ +#define __itt_section_exec 0x20000000 + +/** + * @hideinitializer + * @brief bit-mask, detects a section attribute that indicates whether a section can be read. + * These are the relevant section attributes and the modules that contain them: + * ELF module: PF_R attribute + * COFF module: IMAGE_SCN_MEM_READ attribute + */ +#define __itt_section_read 0x40000000 + +/** + * @hideinitializer + * @brief bit-mask, detects a section attribute that indicates whether a section can be written to. + * These are the relevant section attributes and the modules that contain them: + * ELF module: PF_W attribute + * COFF module: IMAGE_SCN_MEM_WRITE attribute + */ +#define __itt_section_write 0x80000000 + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_section_info +{ + const char* name; /*!< Section name in UTF8 */ + __itt_section_type type; /*!< Section content and semantics description */ + size_t flags; /*!< Section bit flags that describe attributes using bit mask + * Zero if disabled, non-zero if enabled + */ + void* start_addr; /*!< Section load(relocated) start address */ + size_t size; /*!< Section file offset */ + size_t file_offset; /*!< Section size */ +} __itt_section_info; + +#pragma pack(pop) +/** @endcond */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_module_object +{ + unsigned int version; /*!< API version*/ + __itt_id module_id; /*!< Unique identifier. This is unchanged for sections that belong to the same module */ + __itt_module_type module_type; /*!< Binary module format */ + const char* module_name; /*!< Unique module name or path to module in UTF8 + * Contains module name when module_bufer and module_size exist + * Contains module path when module_bufer and module_size absent + * module_name remains the same for the certain module_id + */ + void* module_buffer; /*!< Module buffer content */ + size_t module_size; /*!< Module buffer size */ + /*!< If module_buffer and module_size exist, the binary module is dumped onto the system. + * If module_buffer and module_size do not exist, + * the binary module exists on the system already. + * The module_name parameter contains the path to the module. + */ + __itt_section_info* section_array; /*!< Reference to section information */ + size_t section_number; +} __itt_module_object; + +#pragma pack(pop) +/** @endcond */ + +/** + * @brief Load module content and its loaded(relocated) sections. + * This API is useful to save a module, or specify its location on the system and report information about loaded sections. + * The target module is saved on the system if module buffer content and size are available. + * If module buffer content and size are unavailable, the module name contains the path to the existing binary module. + * @param[in] module_obj - provides module and section information, along with unique module identifiers (name,module ID) + * which bind the binary module to particular sections. + */ +void ITTAPI __itt_module_load_with_sections(__itt_module_object* module_obj); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, module_load_with_sections, (__itt_module_object* module_obj)) +#define __itt_module_load_with_sections ITTNOTIFY_VOID(module_load_with_sections) +#define __itt_module_load_with_sections_ptr ITTNOTIFY_NAME(module_load_with_sections) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_module_load_with_sections(module_obj) +#define __itt_module_load_with_sections_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_module_load_with_sections_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Unload a module and its loaded(relocated) sections. + * This API notifies that the module and its sections were unloaded. + * @param[in] module_obj - provides module and sections information, along with unique module identifiers (name,module ID) + * which bind the binary module to particular sections. + */ +void ITTAPI __itt_module_unload_with_sections(__itt_module_object* module_obj); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, module_unload_with_sections, (__itt_module_object* module_obj)) +#define __itt_module_unload_with_sections ITTNOTIFY_VOID(module_unload_with_sections) +#define __itt_module_unload_with_sections_ptr ITTNOTIFY_NAME(module_unload_with_sections) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_module_unload_with_sections(module_obj) +#define __itt_module_unload_with_sections_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_module_unload_with_sections_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_histogram +{ + const __itt_domain* domain; /*!< Domain of the histogram*/ + const char* nameA; /*!< Name of the histogram */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* nameW; +#else /* UNICODE || _UNICODE */ + void* nameW; +#endif /* UNICODE || _UNICODE */ + __itt_metadata_type x_type; /*!< Type of the histogram X axis */ + __itt_metadata_type y_type; /*!< Type of the histogram Y axis */ + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct ___itt_histogram* next; +} __itt_histogram; + +#pragma pack(pop) +/** @endcond */ + +/** + * @brief Create a typed histogram instance with given name/domain. + * @param[in] domain The domain controlling the call. + * @param[in] name The name of the histogram. + * @param[in] x_type The type of the X axis in histogram (may be 0 to calculate batch statistics). + * @param[in] y_type The type of the Y axis in histogram. +*/ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_histogram* ITTAPI __itt_histogram_createA(const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type); +__itt_histogram* ITTAPI __itt_histogram_createW(const __itt_domain* domain, const wchar_t* name, __itt_metadata_type x_type, __itt_metadata_type y_type); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_histogram_create __itt_histogram_createW +# define __itt_histogram_create_ptr __itt_histogram_createW_ptr +#else /* UNICODE */ +# define __itt_histogram_create __itt_histogram_createA +# define __itt_histogram_create_ptr __itt_histogram_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_histogram* ITTAPI __itt_histogram_create(const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_histogram*, histogram_createA, (const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type)) +ITT_STUB(ITTAPI, __itt_histogram*, histogram_createW, (const __itt_domain* domain, const wchar_t* name, __itt_metadata_type x_type, __itt_metadata_type y_type)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_histogram*, histogram_create, (const __itt_domain* domain, const char* name, __itt_metadata_type x_type, __itt_metadata_type y_type)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_histogram_createA ITTNOTIFY_DATA(histogram_createA) +#define __itt_histogram_createA_ptr ITTNOTIFY_NAME(histogram_createA) +#define __itt_histogram_createW ITTNOTIFY_DATA(histogram_createW) +#define __itt_histogram_createW_ptr ITTNOTIFY_NAME(histogram_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_histogram_create ITTNOTIFY_DATA(histogram_create) +#define __itt_histogram_create_ptr ITTNOTIFY_NAME(histogram_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_histogram_createA(domain, name, x_type, y_type) (__itt_histogram*)0 +#define __itt_histogram_createA_ptr 0 +#define __itt_histogram_createW(domain, name, x_type, y_type) (__itt_histogram*)0 +#define __itt_histogram_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_histogram_create(domain, name, x_type, y_type) (__itt_histogram*)0 +#define __itt_histogram_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_histogram_createA_ptr 0 +#define __itt_histogram_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_histogram_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Submit statistics for a histogram instance. + * @param[in] hist Pointer to the histogram instance to which the histogram statistic is to be dumped. + * @param[in] length The number of elements in dumped axis data array. + * @param[in] x_data The X axis dumped data itself (may be NULL to calculate batch statistics). + * @param[in] y_data The Y axis dumped data itself. +*/ +void ITTAPI __itt_histogram_submit(__itt_histogram* hist, size_t length, void* x_data, void* y_data); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, histogram_submit, (__itt_histogram* hist, size_t length, void* x_data, void* y_data)) +#define __itt_histogram_submit ITTNOTIFY_VOID(histogram_submit) +#define __itt_histogram_submit_ptr ITTNOTIFY_NAME(histogram_submit) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_histogram_submit(hist, length, x_data, y_data) +#define __itt_histogram_submit_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_histogram_submit_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ + +/** +* @brief function allows to obtain the current collection state at the moment +* @return collection state as a enum __itt_collection_state +*/ +__itt_collection_state __itt_get_collection_state(void); + +/** +* @brief function releases resources allocated by ITT API static part +* this API should be called from the library destructor +* @return void +*/ +void __itt_release_resources(void); +/** @endcond */ + +/** + * @brief Create a typed counter with given domain pointer, string name and counter type +*/ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_counter ITTAPI __itt_counter_createA_v3(const __itt_domain* domain, const char* name, __itt_metadata_type type); +__itt_counter ITTAPI __itt_counter_createW_v3(const __itt_domain* domain, const wchar_t* name, __itt_metadata_type type); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_counter_create_v3 __itt_counter_createW_v3 +# define __itt_counter_create_v3_ptr __itt_counter_createW_v3_ptr +#else /* UNICODE */ +# define __itt_counter_create_v3 __itt_counter_createA_v3 +# define __itt_counter_create_v3_ptr __itt_counter_createA_v3_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_counter ITTAPI __itt_counter_create_v3(const __itt_domain* domain, const char* name, __itt_metadata_type type); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_counter, counter_createA_v3, (const __itt_domain* domain, const char* name, __itt_metadata_type type)) +ITT_STUB(ITTAPI, __itt_counter, counter_createW_v3, (const __itt_domain* domain, const wchar_t* name, __itt_metadata_type type)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_counter, counter_create_v3, (const __itt_domain* domain, const char* name, __itt_metadata_type type)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_createA_v3 ITTNOTIFY_DATA(counter_createA_v3) +#define __itt_counter_createA_v3_ptr ITTNOTIFY_NAME(counter_createA_v3) +#define __itt_counter_createW_v3 ITTNOTIFY_DATA(counter_createW_v3) +#define __itt_counter_createW_v3_ptr ITTNOTIFY_NAME(counter_createW_v3) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create_v3 ITTNOTIFY_DATA(counter_create_v3) +#define __itt_counter_create_v3_ptr ITTNOTIFY_NAME(counter_create_v3) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_createA_v3(domain, name, type) (__itt_counter)0 +#define __itt_counter_createA_v3_ptr 0 +#define __itt_counter_createW_v3(domain, name, type) (__itt_counter)0 +#define __itt_counter_create_typedW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create_v3(domain, name, type) (__itt_counter)0 +#define __itt_counter_create_v3_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_counter_createA_v3_ptr 0 +#define __itt_counter_createW_v3_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_counter_create_v3_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Set the counter value api + */ +void ITTAPI __itt_counter_set_value_v3(__itt_counter counter, void *value_ptr); + +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, counter_set_value_v3, (__itt_counter counter, void *value_ptr)) +#define __itt_counter_set_value_v3 ITTNOTIFY_VOID(counter_set_value_v3) +#define __itt_counter_set_value_v3_ptr ITTNOTIFY_NAME(counter_set_value_v3) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_counter_set_value_v3(counter, value_ptr) +#define __itt_counter_set_value_v3_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_counter_set_value_v3_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief describes the type of context metadata +*/ +typedef enum { + __itt_context_unknown = 0, /*!< Undefined type */ + __itt_context_nameA, /*!< ASCII string char* type */ + __itt_context_nameW, /*!< Unicode string wchar_t* type */ + __itt_context_deviceA, /*!< ASCII string char* type */ + __itt_context_deviceW, /*!< Unicode string wchar_t* type */ + __itt_context_unitsA, /*!< ASCII string char* type */ + __itt_context_unitsW, /*!< Unicode string wchar_t* type */ + __itt_context_pci_addrA, /*!< ASCII string char* type */ + __itt_context_pci_addrW, /*!< Unicode string wchar_t* type */ + __itt_context_tid, /*!< Unsigned 64-bit integer type */ + __itt_context_max_val, /*!< Unsigned 64-bit integer type */ + __itt_context_bandwidth_flag, /*!< Unsigned 64-bit integer type */ + __itt_context_latency_flag, /*!< Unsigned 64-bit integer type */ + __itt_context_occupancy_flag, /*!< Unsigned 64-bit integer type */ + __itt_context_on_thread_flag, /*!< Unsigned 64-bit integer type */ + __itt_context_is_abs_val_flag, /*!< Unsigned 64-bit integer type */ + __itt_context_cpu_instructions_flag, /*!< Unsigned 64-bit integer type */ + __itt_context_cpu_cycles_flag /*!< Unsigned 64-bit integer type */ +} __itt_context_type; + +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_context_name __itt_context_nameW +# define __itt_context_device __itt_context_deviceW +# define __itt_context_units __itt_context_unitsW +# define __itt_context_pci_addr __itt_context_pci_addrW +#else /* UNICODE || _UNICODE */ +# define __itt_context_name __itt_context_nameA +# define __itt_context_device __itt_context_deviceA +# define __itt_context_units __itt_context_unitsA +# define __itt_context_pci_addr __itt_context_pci_addrA +#endif /* UNICODE || _UNICODE */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_context_metadata +{ + __itt_context_type type; /*!< Type of the context metadata value */ + void* value; /*!< Pointer to context metadata value itself */ +} __itt_context_metadata; + +#pragma pack(pop) +/** @endcond */ + +/** @cond exclude_from_documentation */ +#pragma pack(push, 8) + +typedef struct ___itt_counter_metadata +{ + __itt_counter counter; /*!< Associated context metadata counter */ + __itt_context_type type; /*!< Type of the context metadata value */ + const char* str_valueA; /*!< String context metadata value */ +#if defined(UNICODE) || defined(_UNICODE) + const wchar_t* str_valueW; +#else /* UNICODE || _UNICODE */ + void* str_valueW; +#endif /* UNICODE || _UNICODE */ + unsigned long long value; /*!< Numeric context metadata value */ + int extra1; /*!< Reserved to the runtime */ + void* extra2; /*!< Reserved to the runtime */ + struct ___itt_counter_metadata* next; +} __itt_counter_metadata; + +#pragma pack(pop) +/** @endcond */ + +/** + * @brief Bind context metadata to counter instance + * @param[in] counter Pointer to the counter instance to which the context metadata is to be associated. + * @param[in] length The number of elements in context metadata array. + * @param[in] metadata The context metadata itself. +*/ +void ITTAPI __itt_bind_context_metadata_to_counter(__itt_counter counter, size_t length, __itt_context_metadata* metadata); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, bind_context_metadata_to_counter, (__itt_counter counter, size_t length, __itt_context_metadata* metadata)) +#define __itt_bind_context_metadata_to_counter ITTNOTIFY_VOID(bind_context_metadata_to_counter) +#define __itt_bind_context_metadata_to_counter_ptr ITTNOTIFY_NAME(bind_context_metadata_to_counter) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_bind_context_metadata_to_counter(counter, length, metadata) +#define __itt_bind_context_metadata_to_counter_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_bind_context_metadata_to_counter_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif /* _ITTNOTIFY_H_ */ + +#ifdef INTEL_ITTNOTIFY_API_PRIVATE + +#ifndef _ITTNOTIFY_PRIVATE_ +#define _ITTNOTIFY_PRIVATE_ + +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +/** + * @ingroup clockdomain + * @brief Begin an overlapped task instance. + * @param[in] domain The domain for this task + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] taskid The identifier for this task instance, *cannot* be __itt_null. + * @param[in] parentid The parent of this task, or __itt_null. + * @param[in] name The name of this task. + */ +void ITTAPI __itt_task_begin_overlapped_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, __itt_string_handle* name); + +/** + * @ingroup clockdomain + * @brief End an overlapped task instance. + * @param[in] domain The domain for this task + * @param[in] clock_domain The clock domain controlling the execution of this call. + * @param[in] timestamp The user defined timestamp. + * @param[in] taskid Explicit ID of finished task + */ +void ITTAPI __itt_task_end_overlapped_ex(const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, task_begin_overlapped_ex, (const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid, __itt_id parentid, __itt_string_handle* name)) +ITT_STUBV(ITTAPI, void, task_end_overlapped_ex, (const __itt_domain* domain, __itt_clock_domain* clock_domain, unsigned long long timestamp, __itt_id taskid)) +#define __itt_task_begin_overlapped_ex(d,x,y,z,a,b) ITTNOTIFY_VOID_D5(task_begin_overlapped_ex,d,x,y,z,a,b) +#define __itt_task_begin_overlapped_ex_ptr ITTNOTIFY_NAME(task_begin_overlapped_ex) +#define __itt_task_end_overlapped_ex(d,x,y,z) ITTNOTIFY_VOID_D3(task_end_overlapped_ex,d,x,y,z) +#define __itt_task_end_overlapped_ex_ptr ITTNOTIFY_NAME(task_end_overlapped_ex) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_task_begin_overlapped_ex(domain,clock_domain,timestamp,taskid,parentid,name) +#define __itt_task_begin_overlapped_ex_ptr 0 +#define __itt_task_end_overlapped_ex(domain,clock_domain,timestamp,taskid) +#define __itt_task_end_overlapped_ex_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_task_begin_overlapped_ex_ptr 0 +#define __itt_task_end_overlapped_ptr 0 +#define __itt_task_end_overlapped_ex_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @defgroup makrs_internal Marks + * @ingroup internal + * Marks group + * @warning Internal API: + * - It is not shipped to outside of Intel + * - It is delivered to internal Intel teams using e-mail or SVN access only + * @{ + */ +/** @brief user mark type */ +typedef int __itt_mark_type; + +/** + * @brief Creates a user mark type with the specified name using char or Unicode string. + * @param[in] name - name of mark to create + * @return Returns a handle to the mark type + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +__itt_mark_type ITTAPI __itt_mark_createA(const char *name); +__itt_mark_type ITTAPI __itt_mark_createW(const wchar_t *name); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_mark_create __itt_mark_createW +# define __itt_mark_create_ptr __itt_mark_createW_ptr +#else /* UNICODE */ +# define __itt_mark_create __itt_mark_createA +# define __itt_mark_create_ptr __itt_mark_createA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +__itt_mark_type ITTAPI __itt_mark_create(const char *name); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, __itt_mark_type, mark_createA, (const char *name)) +ITT_STUB(ITTAPI, __itt_mark_type, mark_createW, (const wchar_t *name)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, __itt_mark_type, mark_create, (const char *name)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_mark_createA ITTNOTIFY_DATA(mark_createA) +#define __itt_mark_createA_ptr ITTNOTIFY_NAME(mark_createA) +#define __itt_mark_createW ITTNOTIFY_DATA(mark_createW) +#define __itt_mark_createW_ptr ITTNOTIFY_NAME(mark_createW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark_create ITTNOTIFY_DATA(mark_create) +#define __itt_mark_create_ptr ITTNOTIFY_NAME(mark_create) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_mark_createA(name) (__itt_mark_type)0 +#define __itt_mark_createA_ptr 0 +#define __itt_mark_createW(name) (__itt_mark_type)0 +#define __itt_mark_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark_create(name) (__itt_mark_type)0 +#define __itt_mark_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_mark_createA_ptr 0 +#define __itt_mark_createW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark_create_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Creates a "discrete" user mark type of the specified type and an optional parameter using char or Unicode string. + * + * - The mark of "discrete" type is placed to collection results in case of success. It appears in overtime view(s) as a special tick sign. + * - The call is "synchronous" - function returns after mark is actually added to results. + * - This function is useful, for example, to mark different phases of application + * (beginning of the next mark automatically meand end of current region). + * - Can be used together with "continuous" marks (see below) at the same collection session + * @param[in] mt - mark, created by __itt_mark_create(const char* name) function + * @param[in] parameter - string parameter of mark + * @return Returns zero value in case of success, non-zero value otherwise. + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +int ITTAPI __itt_markA(__itt_mark_type mt, const char *parameter); +int ITTAPI __itt_markW(__itt_mark_type mt, const wchar_t *parameter); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_mark __itt_markW +# define __itt_mark_ptr __itt_markW_ptr +#else /* UNICODE */ +# define __itt_mark __itt_markA +# define __itt_mark_ptr __itt_markA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +int ITTAPI __itt_mark(__itt_mark_type mt, const char *parameter); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, int, markA, (__itt_mark_type mt, const char *parameter)) +ITT_STUB(ITTAPI, int, markW, (__itt_mark_type mt, const wchar_t *parameter)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, int, mark, (__itt_mark_type mt, const char *parameter)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_markA ITTNOTIFY_DATA(markA) +#define __itt_markA_ptr ITTNOTIFY_NAME(markA) +#define __itt_markW ITTNOTIFY_DATA(markW) +#define __itt_markW_ptr ITTNOTIFY_NAME(markW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark ITTNOTIFY_DATA(mark) +#define __itt_mark_ptr ITTNOTIFY_NAME(mark) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_markA(mt, parameter) (int)0 +#define __itt_markA_ptr 0 +#define __itt_markW(mt, parameter) (int)0 +#define __itt_markW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark(mt, parameter) (int)0 +#define __itt_mark_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_markA_ptr 0 +#define __itt_markW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Use this if necessary to create a "discrete" user event type (mark) for process + * rather then for one thread + * @see int __itt_mark(__itt_mark_type mt, const char* parameter); + */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +int ITTAPI __itt_mark_globalA(__itt_mark_type mt, const char *parameter); +int ITTAPI __itt_mark_globalW(__itt_mark_type mt, const wchar_t *parameter); +#if defined(UNICODE) || defined(_UNICODE) +# define __itt_mark_global __itt_mark_globalW +# define __itt_mark_global_ptr __itt_mark_globalW_ptr +#else /* UNICODE */ +# define __itt_mark_global __itt_mark_globalA +# define __itt_mark_global_ptr __itt_mark_globalA_ptr +#endif /* UNICODE */ +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +int ITTAPI __itt_mark_global(__itt_mark_type mt, const char *parameter); +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#if ITT_PLATFORM==ITT_PLATFORM_WIN +ITT_STUB(ITTAPI, int, mark_globalA, (__itt_mark_type mt, const char *parameter)) +ITT_STUB(ITTAPI, int, mark_globalW, (__itt_mark_type mt, const wchar_t *parameter)) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +ITT_STUB(ITTAPI, int, mark_global, (__itt_mark_type mt, const char *parameter)) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_mark_globalA ITTNOTIFY_DATA(mark_globalA) +#define __itt_mark_globalA_ptr ITTNOTIFY_NAME(mark_globalA) +#define __itt_mark_globalW ITTNOTIFY_DATA(mark_globalW) +#define __itt_mark_globalW_ptr ITTNOTIFY_NAME(mark_globalW) +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark_global ITTNOTIFY_DATA(mark_global) +#define __itt_mark_global_ptr ITTNOTIFY_NAME(mark_global) +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#else /* INTEL_NO_ITTNOTIFY_API */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_mark_globalA(mt, parameter) (int)0 +#define __itt_mark_globalA_ptr 0 +#define __itt_mark_globalW(mt, parameter) (int)0 +#define __itt_mark_globalW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark_global(mt, parameter) (int)0 +#define __itt_mark_global_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#if ITT_PLATFORM==ITT_PLATFORM_WIN +#define __itt_mark_globalA_ptr 0 +#define __itt_mark_globalW_ptr 0 +#else /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#define __itt_mark_global_ptr 0 +#endif /* ITT_PLATFORM==ITT_PLATFORM_WIN */ +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Creates an "end" point for "continuous" mark with specified name. + * + * - Returns zero value in case of success, non-zero value otherwise. + * Also returns non-zero value when preceding "begin" point for the + * mark with the same name failed to be created or not created. + * - The mark of "continuous" type is placed to collection results in + * case of success. It appears in overtime view(s) as a special tick + * sign (different from "discrete" mark) together with line from + * corresponding "begin" mark to "end" mark. + * @note Continuous marks can overlap and be nested inside each other. + * Discrete mark can be nested inside marked region + * @param[in] mt - mark, created by __itt_mark_create(const char* name) function + * @return Returns zero value in case of success, non-zero value otherwise. + */ +int ITTAPI __itt_mark_off(__itt_mark_type mt); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, int, mark_off, (__itt_mark_type mt)) +#define __itt_mark_off ITTNOTIFY_DATA(mark_off) +#define __itt_mark_off_ptr ITTNOTIFY_NAME(mark_off) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_mark_off(mt) (int)0 +#define __itt_mark_off_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_mark_off_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Use this if necessary to create an "end" point for mark of process + * @see int __itt_mark_off(__itt_mark_type mt); + */ +int ITTAPI __itt_mark_global_off(__itt_mark_type mt); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, int, mark_global_off, (__itt_mark_type mt)) +#define __itt_mark_global_off ITTNOTIFY_DATA(mark_global_off) +#define __itt_mark_global_off_ptr ITTNOTIFY_NAME(mark_global_off) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_mark_global_off(mt) (int)0 +#define __itt_mark_global_off_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_mark_global_off_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ +/** @} marks group */ + +/** + * @defgroup counters_internal Counters + * @ingroup internal + * Counters group + * @{ + */ + + +/** + * @defgroup stitch Stack Stitching + * @ingroup internal + * Stack Stitching group + * @{ + */ +/** + * @brief opaque structure for counter identification + */ +typedef struct ___itt_caller *__itt_caller; + +/** + * @brief Create the stitch point e.g. a point in call stack where other stacks should be stitched to. + * The function returns a unique identifier which is used to match the cut points with corresponding stitch points. + */ +__itt_caller ITTAPI __itt_stack_caller_create(void); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUB(ITTAPI, __itt_caller, stack_caller_create, (void)) +#define __itt_stack_caller_create ITTNOTIFY_DATA(stack_caller_create) +#define __itt_stack_caller_create_ptr ITTNOTIFY_NAME(stack_caller_create) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_stack_caller_create() (__itt_caller)0 +#define __itt_stack_caller_create_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_stack_caller_create_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Destroy the information about stitch point identified by the pointer previously returned by __itt_stack_caller_create() + */ +void ITTAPI __itt_stack_caller_destroy(__itt_caller id); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, stack_caller_destroy, (__itt_caller id)) +#define __itt_stack_caller_destroy ITTNOTIFY_VOID(stack_caller_destroy) +#define __itt_stack_caller_destroy_ptr ITTNOTIFY_NAME(stack_caller_destroy) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_stack_caller_destroy(id) +#define __itt_stack_caller_destroy_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_stack_caller_destroy_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief Sets the cut point. Stack from each event which occurs after this call will be cut + * at the same stack level the function was called and stitched to the corresponding stitch point. + */ +void ITTAPI __itt_stack_callee_enter(__itt_caller id); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, stack_callee_enter, (__itt_caller id)) +#define __itt_stack_callee_enter ITTNOTIFY_VOID(stack_callee_enter) +#define __itt_stack_callee_enter_ptr ITTNOTIFY_NAME(stack_callee_enter) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_stack_callee_enter(id) +#define __itt_stack_callee_enter_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_stack_callee_enter_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** + * @brief This function eliminates the cut point which was set by latest __itt_stack_callee_enter(). + */ +void ITTAPI __itt_stack_callee_leave(__itt_caller id); + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +ITT_STUBV(ITTAPI, void, stack_callee_leave, (__itt_caller id)) +#define __itt_stack_callee_leave ITTNOTIFY_VOID(stack_callee_leave) +#define __itt_stack_callee_leave_ptr ITTNOTIFY_NAME(stack_callee_leave) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_stack_callee_leave(id) +#define __itt_stack_callee_leave_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_stack_callee_leave_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +/** @} stitch group */ + +/* ***************************************************************************************************************************** */ + +#include + +/** @cond exclude_from_documentation */ +typedef enum __itt_error_code +{ + __itt_error_success = 0, /*!< no error */ + __itt_error_no_module = 1, /*!< module can't be loaded */ + /* %1$s -- library name; win: %2$d -- system error code; unx: %2$s -- system error message. */ + __itt_error_no_symbol = 2, /*!< symbol not found */ + /* %1$s -- library name, %2$s -- symbol name. */ + __itt_error_unknown_group = 3, /*!< unknown group specified */ + /* %1$s -- env var name, %2$s -- group name. */ + __itt_error_cant_read_env = 4, /*!< GetEnvironmentVariable() failed */ + /* %1$s -- env var name, %2$d -- system error. */ + __itt_error_env_too_long = 5, /*!< variable value too long */ + /* %1$s -- env var name, %2$d -- actual length of the var, %3$d -- max allowed length. */ + __itt_error_system = 6 /*!< pthread_mutexattr_init or pthread_mutex_init failed */ + /* %1$s -- function name, %2$d -- errno. */ +} __itt_error_code; + +typedef void (__itt_error_handler_t)(__itt_error_code code, va_list); +__itt_error_handler_t* __itt_set_error_handler(__itt_error_handler_t*); + +const char* ITTAPI __itt_api_version(void); +/** @endcond */ + +/** @cond exclude_from_documentation */ +#ifndef INTEL_NO_MACRO_BODY +#ifndef INTEL_NO_ITTNOTIFY_API +#define __itt_error_handler ITT_JOIN(INTEL_ITTNOTIFY_PREFIX, error_handler) +void __itt_error_handler(__itt_error_code code, va_list args); +extern const int ITTNOTIFY_NAME(err); +#define __itt_err ITTNOTIFY_NAME(err) +ITT_STUB(ITTAPI, const char*, api_version, (void)) +#define __itt_api_version ITTNOTIFY_DATA(api_version) +#define __itt_api_version_ptr ITTNOTIFY_NAME(api_version) +#else /* INTEL_NO_ITTNOTIFY_API */ +#define __itt_api_version() (const char*)0 +#define __itt_api_version_ptr 0 +#endif /* INTEL_NO_ITTNOTIFY_API */ +#else /* INTEL_NO_MACRO_BODY */ +#define __itt_api_version_ptr 0 +#endif /* INTEL_NO_MACRO_BODY */ +/** @endcond */ + +#ifdef __cplusplus +} +#endif /* __cplusplus */ + +#endif /* _ITTNOTIFY_PRIVATE_ */ + +#endif /* INTEL_ITTNOTIFY_API_PRIVATE */ diff --git a/phivenv/Lib/site-packages/torch/include/jitprofiling.h b/phivenv/Lib/site-packages/torch/include/jitprofiling.h new file mode 100644 index 0000000000000000000000000000000000000000..01dcd6863ec3701643dd1ebd63cecc29284abb41 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/jitprofiling.h @@ -0,0 +1,642 @@ +/* + Copyright (C) 2005-2019 Intel Corporation + + SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause +*/ + +#ifndef __JITPROFILING_H__ +#define __JITPROFILING_H__ + +/** + * @brief JIT Profiling APIs + * + * The JIT Profiling API is used to report information about just-in-time + * generated code that can be used by performance tools. The user inserts + * calls in the code generator to report information before JIT-compiled + * code goes to execution. This information is collected at runtime and used + * by tools like Intel(R) VTune(TM) Profiler to display performance metrics + * associated with JIT-compiled code. + * + * These APIs can be used to\n + * - **Profile trace-based and method-based JIT-compiled + * code**. Some examples of environments that you can profile with these APIs: + * dynamic JIT compilation of JavaScript code traces, JIT execution in OpenCL(TM) + * software technology, Java/.NET managed execution environments, and custom + * ISV JIT engines. + * @code + * #include + * + * if (iJIT_IsProfilingActive != iJIT_SAMPLING_ON) { + * return; + * } + * + * iJIT_Method_Load jmethod = {0}; + * jmethod.method_id = iJIT_GetNewMethodID(); + * jmethod.method_name = "method_name"; + * jmethod.class_file_name = "class_name"; + * jmethod.source_file_name = "source_file_name"; + * jmethod.method_load_address = code_addr; + * jmethod.method_size = code_size; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&jmethod); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_SHUTDOWN, NULL); + * @endcode + * + * * Expected behavior: + * * If any iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and its + * memory region is treated as unloaded. VTune Profiler displays the metrics + * collected by the method until it is overwritten. + * * If supplied line number information contains multiple source lines for + * the same assembly instruction (code location), then VTune Profiler picks up + * the first line number. + * * Dynamically generated code can be associated with a module name. + * Use the iJIT_Method_Load_V2 structure.\n + * Clarification of some cases: + * * If you register a function with the same method ID multiple times, + * specifying different module names, then the VTune Profiler picks up + * the module name registered first. If you want to distinguish the same + * function between different JIT engines, supply different method IDs for + * each function. Other symbolic information (for example, source file) + * can be identical. + * + * - **Analyze split functions** (multiple joint or disjoint code regions + * belonging to the same function) **including re-JIT** + * with potential overlapping of code regions in time, which is common in + * resource-limited environments. + * @code + * #include + * + * unsigned int method_id = iJIT_GetNewMethodID(); + * + * iJIT_Method_Load a = {0}; + * a.method_id = method_id; + * a.method_load_address = 0x100; + * a.method_size = 0x20; + * + * iJIT_Method_Load b = {0}; + * b.method_id = method_id; + * b.method_load_address = 0x200; + * b.method_size = 0x30; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&b); + * @endcode + * + * * Expected behaviour: + * * If a iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED event overwrites an + * already reported method, then such a method becomes invalid and + * its memory region is treated as unloaded. + * * All code regions reported with the same method ID are considered as + * belonging to the same method. Symbolic information (method name, + * source file name) will be taken from the first notification, and all + * subsequent notifications with the same method ID will be processed + * only for line number table information. So, the VTune Profiler will map + * samples to a source line using the line number table from the current + * notification while taking the source file name from the very first one.\n + * Clarification of some cases:\n + * * If you register a second code region with a different source file + * name and the same method ID, then this information will be saved and + * will not be considered as an extension of the first code region, but + * VTune Profiler will use the source file of the first code region and map + * performance metrics incorrectly. + * * If you register a second code region with the same source file as + * for the first region and the same method ID, then the source file will be + * discarded but VTune Profiler will map metrics to the source file correctly. + * * If you register a second code region with a null source file and + * the same method ID, then provided line number info will be associated + * with the source file of the first code region. + * + * - **Explore inline functions** including multi-level hierarchy of + * nested inline methods which shows how performance metrics are distributed through them. + * @code + * #include + * + * // method_id parent_id + * // [-- c --] 3000 2000 + * // [---- d -----] 2001 1000 + * // [---- b ----] 2000 1000 + * // [------------ a ----------------] 1000 n/a + * + * iJIT_Method_Load a = {0}; + * a.method_id = 1000; + * + * iJIT_Method_Inline_Load b = {0}; + * b.method_id = 2000; + * b.parent_method_id = 1000; + * + * iJIT_Method_Inline_Load c = {0}; + * c.method_id = 3000; + * c.parent_method_id = 2000; + * + * iJIT_Method_Inline_Load d = {0}; + * d.method_id = 2001; + * d.parent_method_id = 1000; + * + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED, (void*)&a); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&b); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&c); + * iJIT_NotifyEvent(iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, (void*)&d); + * @endcode + * + * * Requirements: + * * Each inline (iJIT_Method_Inline_Load) method should be associated + * with two method IDs: one for itself; one for its immediate parent. + * * Address regions of inline methods of the same parent method cannot + * overlap each other. + * * Execution of the parent method must not be started until it and all + * its inline methods are reported. + * * Expected behaviour: + * * In case of nested inline methods an order of + * iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED events is not important. + * * If any event overwrites either inline method or top parent method, + * then the parent, including inline methods, becomes invalid and its memory + * region is treated as unloaded. + * + * **Life time of allocated data**\n + * The client sends an event notification to the agent with event-specific + * data, which is a structure. The pointers in the structure refer to memory + * allocated by the client, which responsible for releasing it. The pointers are + * used by the iJIT_NotifyEvent method to copy client's data in a trace file, + * and they are not used after the iJIT_NotifyEvent method returns. + */ + +/** + * @defgroup jitapi JIT Profiling + * @ingroup internal + * @{ + */ + +/** + * @brief Enumerator for the types of notifications + */ +typedef enum iJIT_jvm_event +{ + iJVM_EVENT_TYPE_SHUTDOWN = 2, /**<\brief Send this to shutdown the agent. + * Use NULL for event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED = 13, /**<\brief Send when dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load as event + * data. */ +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UNLOAD_START, /**<\brief Send when compiled dynamic + * code is being unloaded from memory. + * Use iJIT_Method_Load as event data.*/ +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_UPDATE, /**<\brief Send to provide new content for + * a previously reported dynamic code. + * The previous content will be invalidated + * starting from the time of the notification. + * Use iJIT_Method_Load as event data but + * required fields are following: + * - method_id identify the code to update. + * - method_load_address specify start address + * within identified code range + * where update should be started. + * - method_size specify length of updated code + * range. */ + + + iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED, /**<\brief Send when an inline dynamic + * code is JIT compiled and loaded + * into memory by the JIT engine, + * but before the parent code region + * starts executing. + * Use iJIT_Method_Inline_Load as event data.*/ + +/** @cond exclude_from_documentation */ + iJVM_EVENT_TYPE_METHOD_UPDATE_V2, +/** @endcond */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 = 21, /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V2 as event data. */ + + iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 /**<\brief Send when a dynamic code is + * JIT compiled and loaded into + * memory by the JIT engine, but + * before the code is executed. + * Use iJIT_Method_Load_V3 as event data. */ +} iJIT_JVM_EVENT; + +/** + * @brief Enumerator for the agent's mode + */ +typedef enum _iJIT_IsProfilingActiveFlags +{ + iJIT_NOTHING_RUNNING = 0x0000, /**<\brief The agent is not running; + * iJIT_NotifyEvent calls will + * not be processed. */ + iJIT_SAMPLING_ON = 0x0001, /**<\brief The agent is running and + * ready to process notifications. */ +} iJIT_IsProfilingActiveFlags; + +/** + * @brief Description of a single entry in the line number information of a code region. + * @details A table of line number entries gives information about how the reported code region + * is mapped to source file. + * Intel(R) VTune(TM) Profiler uses line number information to attribute + * the samples (virtual address) to a line number. \n + * It is acceptable to report different code addresses for the same source line: + * @code + * Offset LineNumber + * 1 2 + * 12 4 + * 15 2 + * 18 1 + * 21 30 + * + * VTune Profiler constructs the following table using the client data + * + * Code subrange Line number + * 0-1 2 + * 1-12 4 + * 12-15 2 + * 15-18 1 + * 18-21 30 + * @endcode + */ +typedef struct _LineNumberInfo +{ + unsigned int Offset; /**<\brief Offset from the begining of the code region. */ + unsigned int LineNumber; /**<\brief Matching source line number offset (from beginning of source file). */ + +} *pLineNumberInfo, LineNumberInfo; + +/** + * @brief Enumerator for the code architecture. + */ +typedef enum _iJIT_CodeArchitecture +{ + iJIT_CA_NATIVE = 0, /**<\brief Native to the process architecture that is calling it. */ + + iJIT_CA_32, /**<\brief 32-bit machine code. */ + + iJIT_CA_64 /**<\brief 64-bit machine code. */ + +} iJIT_CodeArchitecture; + +#pragma pack(push, 8) + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, data provided with + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table.0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + unsigned int class_id; /**<\brief This field is obsolete. */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Load, iJIT_Method_Load; + +/** + * @brief Description of a JIT-compiled method + * @details When you use the iJIT_Method_Load_V2 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V2 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V2 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise different + * method IDs specify different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + The module name can be useful for distinguishing among + different JIT engines. VTune Profiler will display + reported methods grouped by specific module. */ + +} *piJIT_Method_Load_V2, iJIT_Method_Load_V2; + +/** + * @brief Description of a JIT-compiled method + * @details The iJIT_Method_Load_V3 structure is the same as iJIT_Method_Load_V2 + * with a newly introduced 'arch' field that specifies architecture of the code region. + * When you use the iJIT_Method_Load_V3 structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_LOAD_FINISHED_V3 + * as an event type to report it. + */ +typedef struct _iJIT_Method_Load_V3 +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or manage ID uniqueness + * and correct range by yourself.\n + * You must use the same method ID for all code + * regions of the same method, otherwise they are + * treated as regions of different methods. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Cannot be NULL. */ + + void* method_load_address; /**<\brief The start virtual address of the method code + * region. If NULL, then data provided with the + * event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array. */ + + char* class_file_name; /**<\brief Class name. Can be NULL. */ + + char* source_file_name; /**<\brief Source file name. Can be NULL. */ + + char* module_name; /**<\brief Module name. Can be NULL. + * The module name can be useful for distinguishing among + * different JIT engines. VTune Profiler will display + * reported methods grouped by specific module. */ + + iJIT_CodeArchitecture module_arch; /**<\brief Architecture of the method's code region. + * By default, it is the same as the process + * architecture that is calling it. + * For example, you can use it if your 32-bit JIT + * engine generates 64-bit code. + * + * If JIT engine reports both 32-bit and 64-bit types + * of methods then VTune Profiler splits the methods + * with the same module name but with different + * architectures in two different modules. VTune Profiler + * modifies the original name provided with a 64-bit method + * version by ending it with '(64)' */ + +} *piJIT_Method_Load_V3, iJIT_Method_Load_V3; + +/** + * @brief Description of an inline JIT-compiled method + * @details When you use the_iJIT_Method_Inline_Load structure to describe + * the JIT compiled method, use iJVM_EVENT_TYPE_METHOD_INLINE_LOAD_FINISHED + * as an event type to report it. + */ +typedef struct _iJIT_Method_Inline_Load +{ + unsigned int method_id; /**<\brief Unique method ID. Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + unsigned int parent_method_id; /**<\brief Unique immediate parent's method ID. + * Cannot be 0. + * You must either use the API function + * iJIT_GetNewMethodID to get a valid and unique + * method ID, or else manage ID uniqueness + * and correct range by yourself. */ + + char* method_name; /**<\brief The name of the method. It can be optionally + * prefixed with its class name and appended with + * its complete signature. Can't be NULL. */ + + void* method_load_address; /** <\brief The virtual address on which the method + * is inlined. If NULL, then data provided with + * the event are not accepted. */ + + unsigned int method_size; /**<\brief The code size of the method in memory. + * If 0, then data provided with the event are not + * accepted. */ + + unsigned int line_number_size; /**<\brief The number of entries in the line number + * table. 0 if none. */ + + pLineNumberInfo line_number_table; /**<\brief Pointer to the line numbers info + * array. Can be NULL if + * line_number_size is 0. See + * LineNumberInfo Structure for a + * description of a single entry in + * the line number info array */ + + char* class_file_name; /**<\brief Class name. Can be NULL.*/ + + char* source_file_name; /**<\brief Source file name. Can be NULL.*/ + +} *piJIT_Method_Inline_Load, iJIT_Method_Inline_Load; + +/** @cond exclude_from_documentation */ +/** + * @brief Description of a segment type + * @details Use the segment type to specify a type of data supplied + * with the iJVM_EVENT_TYPE_METHOD_UPDATE_V2 event to be applied to + * a certain code trace. + */ +typedef enum _iJIT_SegmentType +{ + iJIT_CT_UNKNOWN = 0, + + iJIT_CT_CODE, /**<\brief Executable code. */ + + iJIT_CT_DATA, /**<\brief Data (not executable code). + * VTune Profiler uses the format string + * (see iJIT_Method_Update) to represent + * this data in the VTune Profiler GUI */ + + iJIT_CT_KEEP, /**<\brief Use the previous markup for the trace. + * Can be used for the following + * iJVM_EVENT_TYPE_METHOD_UPDATE_V2 events, + * if the type of the previously reported segment + * type is the same. */ + iJIT_CT_EOF +} iJIT_SegmentType; + +/** + * @brief Description of a dynamic update of the content within JIT-compiled method + * @details The JIT engine may generate the methods that are updated at runtime + * partially by mixed (data + executable code) content. When you use the iJIT_Method_Update + * structure to describe the update of the content within a JIT-compiled method, + * use iJVM_EVENT_TYPE_METHOD_UPDATE_V2 as an event type to report it. + * + * On the first Update event, VTune Profiler copies the original code range reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event, then modifies it with the supplied bytes and + * adds the modified range to the original method. For next update events, VTune Profiler + * does the same but it uses the latest modified version of a code region for update. + * Eventually, VTune Profiler GUI displays multiple code ranges for the method reported by + * the iJVM_EVENT_TYPE_METHOD_LOAD event. + * Notes: + * - Multiple update events with different types for the same trace are allowed + * but they must be reported for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [code] Ignored + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + * - The types of previously reported events can be changed but they must be reported + * for the same code ranges. + * Example, + * @code + * [-- data---] Allowed + * [-- code --] Allowed + * [-- data---] Allowed + * [-- code --] Allowed + * [------------ trace ---------] + * @endcode + */ + +typedef struct _iJIT_Method_Update +{ + void* load_address; /**<\brief Start address of the update within a method */ + + unsigned int size; /**<\brief The update size */ + + iJIT_SegmentType type; /**<\brief Type of the update */ + + const char* data_format; /**<\brief C string that contains a format string + * that follows the same specifications as format in printf. + * The format string is used for iJIT_CT_CODE only + * and cannot be NULL. + * Format can be changed on the fly. */ +} *piJIT_Method_Update, iJIT_Method_Update; + +/** @endcond */ + +#pragma pack(pop) + +/** @cond exclude_from_documentation */ +#ifdef __cplusplus +extern "C" { +#endif /* __cplusplus */ + +#ifndef JITAPI_CDECL +# if defined WIN32 || defined _WIN32 +# define JITAPI_CDECL __cdecl +# else /* defined WIN32 || defined _WIN32 */ +# if defined _M_IX86 || defined __i386__ +# define JITAPI_CDECL __attribute__ ((cdecl)) +# else /* _M_IX86 || __i386__ */ +# define JITAPI_CDECL /* actual only on x86_64 platform */ +# endif /* _M_IX86 || __i386__ */ +# endif /* defined WIN32 || defined _WIN32 */ +#endif /* JITAPI_CDECL */ + +#define JITAPI JITAPI_CDECL +/** @endcond */ + +/** + * @brief Generates a new unique method ID. + * + * You must use this API to obtain unique and valid method IDs for methods or + * traces reported to the agent if you don't have your own mechanism to generate + * unique method IDs. + * + * @return a new unique method ID. When out of unique method IDs, this API + * returns 0, which is not an accepted value. + */ +unsigned int JITAPI iJIT_GetNewMethodID(void); + +/** + * @brief Returns the current mode of the agent. + * + * @return iJIT_SAMPLING_ON, indicating that agent is running, or + * iJIT_NOTHING_RUNNING if no agent is running. + */ +iJIT_IsProfilingActiveFlags JITAPI iJIT_IsProfilingActive(void); + +/** + * @brief Reports infomation about JIT-compiled code to the agent. + * + * The reported information is used to attribute samples obtained from any + * Intel(R) VTune(TM) Profiler collector. This API needs to be called + * after JIT compilation and before the first entry into the JIT-compiled + * code. + * + * @param[in] event_type - type of the data sent to the agent + * @param[in] EventSpecificData - pointer to event-specific data + * + * @returns 1 on success, otherwise 0. + */ +int JITAPI iJIT_NotifyEvent(iJIT_JVM_EVENT event_type, void *EventSpecificData); + +#ifdef __cplusplus +} +#endif /* __cplusplus */ +/** @endcond */ + +/** @} jitapi group */ + +#endif /* __JITPROFILING_H__ */ diff --git a/phivenv/Lib/site-packages/torch/include/libittnotify.h b/phivenv/Lib/site-packages/torch/include/libittnotify.h new file mode 100644 index 0000000000000000000000000000000000000000..c58ba26904c0cfa9b39ebaaf26968e2809767e90 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/libittnotify.h @@ -0,0 +1,19 @@ +/* + Copyright (C) 2005-2019 Intel Corporation + + SPDX-License-Identifier: GPL-2.0-only OR BSD-3-Clause +*/ + +#ifndef _LIBITTNOTIFY_H_ +#define _LIBITTNOTIFY_H_ + +#ifndef __ITT_INTERNAL_INCLUDE +# if defined WIN32 || defined _WIN32 +# pragma message("WARNING!!! Include file libittnotify.h is deprecated and should not be included anymore") +# else /* WIN32 */ +# warning "Include file libittnotify.h is deprecated and should not be included anymore" +# endif /* WIN32 */ +#endif /* __ITT_INTERNAL_INCLUDE */ +#include "legacy/ittnotify.h" + +#endif /* _LIBITTNOTIFY_H_ */ diff --git a/phivenv/Lib/site-packages/torch/include/libshm.h b/phivenv/Lib/site-packages/torch/include/libshm.h new file mode 100644 index 0000000000000000000000000000000000000000..bb916f32cce15081f3b11b40c7d7c0e128283acc --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/libshm.h @@ -0,0 +1,36 @@ +#pragma once + +#include + +#ifdef __cplusplus + +#ifdef SHM_EXPORTS +#define SHM_API __declspec(dllexport) +#else +#define SHM_API __declspec(dllimport) +#endif + +SHM_API void libshm_init(const char* manager_exec_path); + +class SHM_API THManagedMapAllocator : public at::RefcountedMapAllocator { + public: + THManagedMapAllocator( + const char* manager_handle, + const char* filename, + int flags, + size_t size) + : at::RefcountedMapAllocator(filename, flags, size) {} + + static at::DataPtr makeDataPtr( + const char* manager_handle, + const char* filename, + int flags, + size_t size); + static THManagedMapAllocator* fromDataPtr(const at::DataPtr&); + + const char* manager_handle() const { + return "no_manager"; + } +}; + +#endif diff --git a/phivenv/Lib/site-packages/torch/include/psimd.h b/phivenv/Lib/site-packages/torch/include/psimd.h new file mode 100644 index 0000000000000000000000000000000000000000..824d3ad546827eebcbc7bcd64ee7ff11ee5491b4 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/psimd.h @@ -0,0 +1,1384 @@ +#pragma once +#ifndef PSIMD_H +#define PSIMD_H + +#if defined(__CUDA_ARCH__) + /* CUDA compiler */ + #define PSIMD_INTRINSIC __forceinline__ __device__ +#elif defined(__OPENCL_VERSION__) + /* OpenCL compiler */ + #define PSIMD_INTRINSIC inline static +#elif defined(__INTEL_COMPILER) + /* Intel compiler, even on Windows */ + #define PSIMD_INTRINSIC inline static __attribute__((__always_inline__)) +#elif defined(__GNUC__) + /* GCC-compatible compiler (gcc/clang/icc) */ + #define PSIMD_INTRINSIC inline static __attribute__((__always_inline__)) +#elif defined(_MSC_VER) + /* MSVC-compatible compiler (cl/icl/clang-cl) */ + #define PSIMD_INTRINSIC __forceinline static +#elif defined(__cplusplus) + /* Generic C++ compiler */ + #define PSIMD_INTRINSIC inline static +#elif defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) + /* Generic C99 compiler */ + #define PSIMD_INTRINSIC inline static +#else + /* Generic C compiler */ + #define PSIMD_INTRINSIC static +#endif + +#if defined(__GNUC__) || defined(__clang__) + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + #include + #endif + + #if defined(__SSE2__) + #include + #endif + + #if defined(__SSE3__) + #include + #endif + + #if defined(__SSSE3__) + #include + #endif + + #if defined(__SSE4_1__) + #include + #endif + + #if defined(__SSE4_2__) + #include + #endif + + #if defined(__AVX__) + #include + #endif +#elif defined(_MSC_VER) + #include +#endif + +#if defined(__cplusplus) + #define PSIMD_CXX_SYNTAX +#elif defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201112L) + #define PSIMD_C11_SYNTAX +#elif defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 199901L) + #define PSIMD_C99_SYNTAX +#else + #define PSIMD_C89_SYNTAX +#endif + +#if defined(__cplusplus) && (__cplusplus >= 201103L) + #include + #include +#elif !defined(__OPENCL_VERSION__) + #include + #include +#endif + +#if defined(__GNUC__) || defined(__clang__) + #define PSIMD_HAVE_F64 0 + #define PSIMD_HAVE_F32 1 + #define PSIMD_HAVE_U8 1 + #define PSIMD_HAVE_S8 1 + #define PSIMD_HAVE_U16 1 + #define PSIMD_HAVE_S16 1 + #define PSIMD_HAVE_U32 1 + #define PSIMD_HAVE_S32 1 + #define PSIMD_HAVE_U64 0 + #define PSIMD_HAVE_S64 0 + + typedef int8_t psimd_s8 __attribute__((vector_size(16), aligned(1))); + typedef uint8_t psimd_u8 __attribute__((vector_size(16), aligned(1))); + typedef int16_t psimd_s16 __attribute__((vector_size(16), aligned(2))); + typedef uint16_t psimd_u16 __attribute__((vector_size(16), aligned(2))); + typedef int32_t psimd_s32 __attribute__((vector_size(16), aligned(4))); + typedef uint32_t psimd_u32 __attribute__((vector_size(16), aligned(4))); + typedef float psimd_f32 __attribute__((vector_size(16), aligned(4))); + + typedef struct { + psimd_s8 lo; + psimd_s8 hi; + } psimd_s8x2; + + typedef struct { + psimd_u8 lo; + psimd_u8 hi; + } psimd_u8x2; + + typedef struct { + psimd_s16 lo; + psimd_s16 hi; + } psimd_s16x2; + + typedef struct { + psimd_u16 lo; + psimd_u16 hi; + } psimd_u16x2; + + typedef struct { + psimd_s32 lo; + psimd_s32 hi; + } psimd_s32x2; + + typedef struct { + psimd_u32 lo; + psimd_u32 hi; + } psimd_u32x2; + + typedef struct { + psimd_f32 lo; + psimd_f32 hi; + } psimd_f32x2; + + /* Bit casts */ + PSIMD_INTRINSIC psimd_u32x2 psimd_cast_s32x2_u32x2(psimd_s32x2 v) { + return (psimd_u32x2) { .lo = (psimd_u32) v.lo, .hi = (psimd_u32) v.hi }; + } + + PSIMD_INTRINSIC psimd_f32x2 psimd_cast_s32x2_f32x2(psimd_s32x2 v) { + return (psimd_f32x2) { .lo = (psimd_f32) v.lo, .hi = (psimd_f32) v.hi }; + } + + PSIMD_INTRINSIC psimd_s32x2 psimd_cast_u32x2_s32x2(psimd_u32x2 v) { + return (psimd_s32x2) { .lo = (psimd_s32) v.lo, .hi = (psimd_s32) v.hi }; + } + + PSIMD_INTRINSIC psimd_f32x2 psimd_cast_u32x2_f32x2(psimd_u32x2 v) { + return (psimd_f32x2) { .lo = (psimd_f32) v.lo, .hi = (psimd_f32) v.hi }; + } + + PSIMD_INTRINSIC psimd_s32x2 psimd_cast_f32x2_s32x2(psimd_f32x2 v) { + return (psimd_s32x2) { .lo = (psimd_s32) v.lo, .hi = (psimd_s32) v.hi }; + } + + PSIMD_INTRINSIC psimd_u32x2 psimd_cast_f32x2_u32x2(psimd_f32x2 v) { + return (psimd_u32x2) { .lo = (psimd_u32) v.lo, .hi = (psimd_u32) v.hi }; + } + + /* Swap */ + PSIMD_INTRINSIC void psimd_swap_s8(psimd_s8 a[1], psimd_s8 b[1]) { + const psimd_s8 new_a = *b; + const psimd_s8 new_b = *a; + *a = new_a; + *b = new_b; + } + + PSIMD_INTRINSIC void psimd_swap_u8(psimd_u8 a[1], psimd_u8 b[1]) { + const psimd_u8 new_a = *b; + const psimd_u8 new_b = *a; + *a = new_a; + *b = new_b; + } + + PSIMD_INTRINSIC void psimd_swap_s16(psimd_s16 a[1], psimd_s16 b[1]) { + const psimd_s16 new_a = *b; + const psimd_s16 new_b = *a; + *a = new_a; + *b = new_b; + } + + PSIMD_INTRINSIC void psimd_swap_u16(psimd_u16 a[1], psimd_u16 b[1]) { + const psimd_u16 new_a = *b; + const psimd_u16 new_b = *a; + *a = new_a; + *b = new_b; + } + + PSIMD_INTRINSIC void psimd_swap_s32(psimd_s32 a[1], psimd_s32 b[1]) { + const psimd_s32 new_a = *b; + const psimd_s32 new_b = *a; + *a = new_a; + *b = new_b; + } + + PSIMD_INTRINSIC void psimd_swap_u32(psimd_u32 a[1], psimd_u32 b[1]) { + const psimd_u32 new_a = *b; + const psimd_u32 new_b = *a; + *a = new_a; + *b = new_b; + } + + PSIMD_INTRINSIC void psimd_swap_f32(psimd_f32 a[1], psimd_f32 b[1]) { + const psimd_f32 new_a = *b; + const psimd_f32 new_b = *a; + *a = new_a; + *b = new_b; + } + + /* Zero-initialization */ + PSIMD_INTRINSIC psimd_s8 psimd_zero_s8(void) { + return (psimd_s8) { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_u8 psimd_zero_u8(void) { + return (psimd_u8) { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_s16 psimd_zero_s16(void) { + return (psimd_s16) { 0, 0, 0, 0, 0, 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_u16 psimd_zero_u16(void) { + return (psimd_u16) { 0, 0, 0, 0, 0, 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_s32 psimd_zero_s32(void) { + return (psimd_s32) { 0, 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_u32 psimd_zero_u32(void) { + return (psimd_u32) { 0, 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_zero_f32(void) { + return (psimd_f32) { 0.0f, 0.0f, 0.0f, 0.0f }; + } + + /* Initialization to the same constant */ + PSIMD_INTRINSIC psimd_s8 psimd_splat_s8(int8_t c) { + return (psimd_s8) { c, c, c, c, c, c, c, c, c, c, c, c, c, c, c, c }; + } + + PSIMD_INTRINSIC psimd_u8 psimd_splat_u8(uint8_t c) { + return (psimd_u8) { c, c, c, c, c, c, c, c, c, c, c, c, c, c, c, c }; + } + + PSIMD_INTRINSIC psimd_s16 psimd_splat_s16(int16_t c) { + return (psimd_s16) { c, c, c, c, c, c, c, c }; + } + + PSIMD_INTRINSIC psimd_u16 psimd_splat_u16(uint16_t c) { + return (psimd_u16) { c, c, c, c, c, c, c, c }; + } + + PSIMD_INTRINSIC psimd_s32 psimd_splat_s32(int32_t c) { + return (psimd_s32) { c, c, c, c }; + } + + PSIMD_INTRINSIC psimd_u32 psimd_splat_u32(uint32_t c) { + return (psimd_u32) { c, c, c, c }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_splat_f32(float c) { + return (psimd_f32) { c, c, c, c }; + } + + /* Load vector */ + PSIMD_INTRINSIC psimd_s8 psimd_load_s8(const void* address) { + return *((const psimd_s8*) address); + } + + PSIMD_INTRINSIC psimd_u8 psimd_load_u8(const void* address) { + return *((const psimd_u8*) address); + } + + PSIMD_INTRINSIC psimd_s16 psimd_load_s16(const void* address) { + return *((const psimd_s16*) address); + } + + PSIMD_INTRINSIC psimd_u16 psimd_load_u16(const void* address) { + return *((const psimd_u16*) address); + } + + PSIMD_INTRINSIC psimd_s32 psimd_load_s32(const void* address) { + return *((const psimd_s32*) address); + } + + PSIMD_INTRINSIC psimd_u32 psimd_load_u32(const void* address) { + return *((const psimd_u32*) address); + } + + PSIMD_INTRINSIC psimd_f32 psimd_load_f32(const void* address) { + return *((const psimd_f32*) address); + } + + PSIMD_INTRINSIC psimd_s8 psimd_load_splat_s8(const void* address) { + return psimd_splat_s8(*((const int8_t*) address)); + } + + PSIMD_INTRINSIC psimd_u8 psimd_load_splat_u8(const void* address) { + return psimd_splat_u8(*((const uint8_t*) address)); + } + + PSIMD_INTRINSIC psimd_s16 psimd_load_splat_s16(const void* address) { + return psimd_splat_s16(*((const int16_t*) address)); + } + + PSIMD_INTRINSIC psimd_u16 psimd_load_splat_u16(const void* address) { + return psimd_splat_u16(*((const uint16_t*) address)); + } + + PSIMD_INTRINSIC psimd_s32 psimd_load_splat_s32(const void* address) { + return psimd_splat_s32(*((const int32_t*) address)); + } + + PSIMD_INTRINSIC psimd_u32 psimd_load_splat_u32(const void* address) { + return psimd_splat_u32(*((const uint32_t*) address)); + } + + PSIMD_INTRINSIC psimd_f32 psimd_load_splat_f32(const void* address) { + return psimd_splat_f32(*((const float*) address)); + } + + PSIMD_INTRINSIC psimd_s32 psimd_load1_s32(const void* address) { + return (psimd_s32) { *((const int32_t*) address), 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_u32 psimd_load1_u32(const void* address) { + return (psimd_u32) { *((const uint32_t*) address), 0, 0, 0 }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_load1_f32(const void* address) { + return (psimd_f32) { *((const float*) address), 0.0f, 0.0f, 0.0f }; + } + + PSIMD_INTRINSIC psimd_s32 psimd_load2_s32(const void* address) { + const int32_t* address_s32 = (const int32_t*) address; + return (psimd_s32) { address_s32[0], address_s32[1], 0, 0 }; + } + + PSIMD_INTRINSIC psimd_u32 psimd_load2_u32(const void* address) { + const uint32_t* address_u32 = (const uint32_t*) address; + return (psimd_u32) { address_u32[0], address_u32[1], 0, 0 }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_load2_f32(const void* address) { + const float* address_f32 = (const float*) address; + return (psimd_f32) { address_f32[0], address_f32[1], 0.0f, 0.0f }; + } + + PSIMD_INTRINSIC psimd_s32 psimd_load3_s32(const void* address) { + const int32_t* address_s32 = (const int32_t*) address; + return (psimd_s32) { address_s32[0], address_s32[1], address_s32[2], 0 }; + } + + PSIMD_INTRINSIC psimd_u32 psimd_load3_u32(const void* address) { + const uint32_t* address_u32 = (const uint32_t*) address; + return (psimd_u32) { address_u32[0], address_u32[1], address_u32[2], 0 }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_load3_f32(const void* address) { + const float* address_f32 = (const float*) address; + return (psimd_f32) { address_f32[0], address_f32[1], address_f32[2], 0.0f }; + } + + PSIMD_INTRINSIC psimd_s32 psimd_load4_s32(const void* address) { + return psimd_load_s32(address); + } + + PSIMD_INTRINSIC psimd_u32 psimd_load4_u32(const void* address) { + return psimd_load_u32(address); + } + + PSIMD_INTRINSIC psimd_f32 psimd_load4_f32(const void* address) { + return psimd_load_f32(address); + } + + PSIMD_INTRINSIC psimd_f32 psimd_load_stride2_f32(const void* address) { + const psimd_f32 v0x1x = psimd_load_f32(address); + const psimd_f32 vx2x3 = psimd_load_f32((const float*) address + 3); + #if defined(__clang__) + return __builtin_shufflevector(v0x1x, vx2x3, 0, 2, 5, 7); + #else + return __builtin_shuffle(v0x1x, vx2x3, (psimd_s32) { 0, 2, 5, 7 }); + #endif + } + + PSIMD_INTRINSIC psimd_f32 psimd_load1_stride2_f32(const void* address) { + return psimd_load_f32(address); + } + + PSIMD_INTRINSIC psimd_f32 psimd_load2_stride2_f32(const void* address) { + const float* address_f32 = (const float*) address; + return (psimd_f32) { address_f32[0], address_f32[2], 0.0f, 0.0f }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_load3_stride2_f32(const void* address) { + const psimd_f32 v0x1x = psimd_load_f32(address); + const psimd_f32 v2zzz = psimd_load1_f32((const float*) address + 2); + #if defined(__clang__) + return __builtin_shufflevector(v0x1x, v2zzz, 0, 2, 4, 6); + #else + return __builtin_shuffle(v0x1x, v2zzz, (psimd_s32) { 0, 2, 4, 6 }); + #endif + } + + PSIMD_INTRINSIC psimd_f32 psimd_load4_stride2_f32(const void* address) { + return psimd_load_stride2_f32(address); + } + + PSIMD_INTRINSIC psimd_f32 psimd_load_stride_f32(const void* address, size_t stride) { + const float* address0_f32 = (const float*) address; + const float* address1_f32 = address0_f32 + stride; + const float* address2_f32 = address1_f32 + stride; + const float* address3_f32 = address2_f32 + stride; + return (psimd_f32) { *address0_f32, *address1_f32, *address2_f32, *address3_f32 }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_load1_stride_f32(const void* address, size_t stride) { + return psimd_load1_f32(address); + } + + PSIMD_INTRINSIC psimd_f32 psimd_load2_stride_f32(const void* address, size_t stride) { + const float* address_f32 = (const float*) address; + return (psimd_f32) { address_f32[0], address_f32[stride], 0.0f, 0.0f }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_load3_stride_f32(const void* address, size_t stride) { + const float* address0_f32 = (const float*) address; + const float* address1_f32 = address0_f32 + stride; + const float* address2_f32 = address1_f32 + stride; + return (psimd_f32) { *address0_f32, *address1_f32, *address2_f32, 0.0f }; + } + + PSIMD_INTRINSIC psimd_f32 psimd_load4_stride_f32(const void* address, size_t stride) { + return psimd_load_stride_f32(address, stride); + } + + /* Store vector */ + PSIMD_INTRINSIC void psimd_store_s8(void* address, psimd_s8 value) { + *((psimd_s8*) address) = value; + } + + PSIMD_INTRINSIC void psimd_store_u8(void* address, psimd_u8 value) { + *((psimd_u8*) address) = value; + } + + PSIMD_INTRINSIC void psimd_store_s16(void* address, psimd_s16 value) { + *((psimd_s16*) address) = value; + } + + PSIMD_INTRINSIC void psimd_store_u16(void* address, psimd_u16 value) { + *((psimd_u16*) address) = value; + } + + PSIMD_INTRINSIC void psimd_store_s32(void* address, psimd_s32 value) { + *((psimd_s32*) address) = value; + } + + PSIMD_INTRINSIC void psimd_store_u32(void* address, psimd_u32 value) { + *((psimd_u32*) address) = value; + } + + PSIMD_INTRINSIC void psimd_store_f32(void* address, psimd_f32 value) { + *((psimd_f32*) address) = value; + } + + PSIMD_INTRINSIC void psimd_store1_s32(void* address, psimd_s32 value) { + *((int32_t*) address) = value[0]; + } + + PSIMD_INTRINSIC void psimd_store1_u32(void* address, psimd_u32 value) { + *((uint32_t*) address) = value[0]; + } + + PSIMD_INTRINSIC void psimd_store1_f32(void* address, psimd_f32 value) { + *((float*) address) = value[0]; + } + + PSIMD_INTRINSIC void psimd_store2_s32(void* address, psimd_s32 value) { + int32_t* address_s32 = (int32_t*) address; + address_s32[0] = value[0]; + address_s32[1] = value[1]; + } + + PSIMD_INTRINSIC void psimd_store2_u32(void* address, psimd_u32 value) { + uint32_t* address_u32 = (uint32_t*) address; + address_u32[0] = value[0]; + address_u32[1] = value[1]; + } + + PSIMD_INTRINSIC void psimd_store2_f32(void* address, psimd_f32 value) { + float* address_f32 = (float*) address; + address_f32[0] = value[0]; + address_f32[1] = value[1]; + } + + PSIMD_INTRINSIC void psimd_store3_s32(void* address, psimd_s32 value) { + int32_t* address_s32 = (int32_t*) address; + address_s32[0] = value[0]; + address_s32[1] = value[1]; + address_s32[2] = value[2]; + } + + PSIMD_INTRINSIC void psimd_store3_u32(void* address, psimd_u32 value) { + uint32_t* address_u32 = (uint32_t*) address; + address_u32[0] = value[0]; + address_u32[1] = value[1]; + address_u32[2] = value[2]; + } + + PSIMD_INTRINSIC void psimd_store3_f32(void* address, psimd_f32 value) { + float* address_f32 = (float*) address; + address_f32[0] = value[0]; + address_f32[1] = value[1]; + address_f32[2] = value[2]; + } + + PSIMD_INTRINSIC void psimd_store4_s32(void* address, psimd_s32 value) { + psimd_store_s32(address, value); + } + + PSIMD_INTRINSIC void psimd_store4_u32(void* address, psimd_u32 value) { + psimd_store_u32(address, value); + } + + PSIMD_INTRINSIC void psimd_store4_f32(void* address, psimd_f32 value) { + psimd_store_f32(address, value); + } + + PSIMD_INTRINSIC void psimd_store_stride_f32(void* address, size_t stride, psimd_f32 value) { + float* address0_f32 = (float*) address; + float* address1_f32 = address0_f32 + stride; + float* address2_f32 = address1_f32 + stride; + float* address3_f32 = address2_f32 + stride; + *address0_f32 = value[0]; + *address1_f32 = value[1]; + *address2_f32 = value[2]; + *address3_f32 = value[3]; + } + + PSIMD_INTRINSIC void psimd_store1_stride_f32(void* address, size_t stride, psimd_f32 value) { + psimd_store1_f32(address, value); + } + + PSIMD_INTRINSIC void psimd_store2_stride_f32(void* address, size_t stride, psimd_f32 value) { + float* address_f32 = (float*) address; + address_f32[0] = value[0]; + address_f32[stride] = value[1]; + } + + PSIMD_INTRINSIC void psimd_store3_stride_f32(void* address, size_t stride, psimd_f32 value) { + float* address0_f32 = (float*) address; + float* address1_f32 = address0_f32 + stride; + float* address2_f32 = address1_f32 + stride; + *address0_f32 = value[0]; + *address1_f32 = value[1]; + *address2_f32 = value[2]; + } + + /* Vector addition */ + PSIMD_INTRINSIC psimd_s8 psimd_add_s8(psimd_s8 a, psimd_s8 b) { + return a + b; + } + + PSIMD_INTRINSIC psimd_u8 psimd_add_u8(psimd_u8 a, psimd_u8 b) { + return a + b; + } + + PSIMD_INTRINSIC psimd_s16 psimd_add_s16(psimd_s16 a, psimd_s16 b) { + return a + b; + } + + PSIMD_INTRINSIC psimd_u16 psimd_add_u16(psimd_u16 a, psimd_u16 b) { + return a + b; + } + + PSIMD_INTRINSIC psimd_s32 psimd_add_s32(psimd_s32 a, psimd_s32 b) { + return a + b; + } + + PSIMD_INTRINSIC psimd_u32 psimd_add_u32(psimd_u32 a, psimd_u32 b) { + return a + b; + } + + PSIMD_INTRINSIC psimd_f32 psimd_add_f32(psimd_f32 a, psimd_f32 b) { + #if defined(__ARM_ARCH_7A__) && defined(__ARM_NEON__) && !defined(__FAST_MATH__) + return (psimd_f32) vaddq_f32((float32x4_t) a, (float32x4_t) b); + #else + return a + b; + #endif + } + + /* Vector subtraction */ + PSIMD_INTRINSIC psimd_s8 psimd_sub_s8(psimd_s8 a, psimd_s8 b) { + return a - b; + } + + PSIMD_INTRINSIC psimd_u8 psimd_sub_u8(psimd_u8 a, psimd_u8 b) { + return a - b; + } + + PSIMD_INTRINSIC psimd_s16 psimd_sub_s16(psimd_s16 a, psimd_s16 b) { + return a - b; + } + + PSIMD_INTRINSIC psimd_u16 psimd_sub_u16(psimd_u16 a, psimd_u16 b) { + return a - b; + } + + PSIMD_INTRINSIC psimd_s32 psimd_sub_s32(psimd_s32 a, psimd_s32 b) { + return a - b; + } + + PSIMD_INTRINSIC psimd_u32 psimd_sub_u32(psimd_u32 a, psimd_u32 b) { + return a - b; + } + + PSIMD_INTRINSIC psimd_f32 psimd_sub_f32(psimd_f32 a, psimd_f32 b) { + #if defined(__ARM_ARCH_7A__) && defined(__ARM_NEON__) && !defined(__FAST_MATH__) + return (psimd_f32) vsubq_f32((float32x4_t) a, (float32x4_t) b); + #else + return a - b; + #endif + } + + /* Vector multiplication */ + PSIMD_INTRINSIC psimd_s8 psimd_mul_s8(psimd_s8 a, psimd_s8 b) { + return a * b; + } + + PSIMD_INTRINSIC psimd_u8 psimd_mul_u8(psimd_u8 a, psimd_u8 b) { + return a * b; + } + + PSIMD_INTRINSIC psimd_s16 psimd_mul_s16(psimd_s16 a, psimd_s16 b) { + return a * b; + } + + PSIMD_INTRINSIC psimd_u16 psimd_mul_u16(psimd_u16 a, psimd_u16 b) { + return a * b; + } + + PSIMD_INTRINSIC psimd_s32 psimd_mul_s32(psimd_s32 a, psimd_s32 b) { + return a * b; + } + + PSIMD_INTRINSIC psimd_u32 psimd_mul_u32(psimd_u32 a, psimd_u32 b) { + return a * b; + } + + PSIMD_INTRINSIC psimd_f32 psimd_mul_f32(psimd_f32 a, psimd_f32 b) { + #if defined(__ARM_ARCH_7A__) && defined(__ARM_NEON__) && !defined(__FAST_MATH__) + return (psimd_f32) vmulq_f32((float32x4_t) a, (float32x4_t) b); + #else + return a * b; + #endif + } + + /* Quasi-Fused Multiply-Add */ + PSIMD_INTRINSIC psimd_f32 psimd_qfma_f32(psimd_f32 a, psimd_f32 b, psimd_f32 c) { + #if defined(__aarch64__) || defined(__ARM_NEON__) && defined(__ARM_FEATURE_FMA) + return (psimd_f32) vfmaq_f32((float32x4_t) a, (float32x4_t) b, (float32x4_t) c); + #elif (defined(__x86_64__) || defined(__i386__) || defined(__i686__)) && defined(__FMA__) + return (psimd_f32) _mm_fmadd_ps((__m128) b, (__m128) c, (__m128) a); + #elif (defined(__x86_64__) || defined(__i386__) || defined(__i686__)) && defined(__FMA4__) + return (psimd_f32) _mm_macc_ps((__m128) b, (__m128) c, (__m128) a); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) && PSIMD_ENABLE_WASM_QFMA + return (psimd_f32) __builtin_wasm_qfma_f32x4(a, b, c); + #else + return a + b * c; + #endif + } + + PSIMD_INTRINSIC psimd_f32 psimd_div_f32(psimd_f32 a, psimd_f32 b) { + return a / b; + } + + /* Vector and */ + PSIMD_INTRINSIC psimd_f32 psimd_andmask_f32(psimd_s32 mask, psimd_f32 v) { + return (psimd_f32) (mask & (psimd_s32) v); + } + + /* Vector and-not */ + PSIMD_INTRINSIC psimd_f32 psimd_andnotmask_f32(psimd_s32 mask, psimd_f32 v) { + return (psimd_f32) (~mask & (psimd_s32) v); + } + + /* Vector blend */ + PSIMD_INTRINSIC psimd_s8 psimd_blend_s8(psimd_s8 mask, psimd_s8 a, psimd_s8 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s8) vbslq_s8((uint8x16_t) mask, (int8x16_t) a, (int8x16_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return (psimd_s8) __builtin_wasm_bitselect(a, b, mask); + #else + return (mask & a) | (~mask & b); + #endif + } + + PSIMD_INTRINSIC psimd_u8 psimd_blend_u8(psimd_s8 mask, psimd_u8 a, psimd_u8 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u8) vbslq_u8((uint8x16_t) mask, (uint8x16_t) a, (uint8x16_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return (psimd_u8) __builtin_wasm_bitselect(a, b, mask); + #else + return (psimd_u8) ((mask & (psimd_s8) a) | (~mask & (psimd_s8) b)); + #endif + } + + PSIMD_INTRINSIC psimd_s16 psimd_blend_s16(psimd_s16 mask, psimd_s16 a, psimd_s16 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s16) vbslq_s16((uint16x8_t) mask, (int16x8_t) a, (int16x8_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return (psimd_s16) __builtin_wasm_bitselect(a, b, mask); + #else + return (mask & a) | (~mask & b); + #endif + } + + PSIMD_INTRINSIC psimd_u16 psimd_blend_u16(psimd_s16 mask, psimd_u16 a, psimd_u16 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u16) vbslq_u16((uint16x8_t) mask, (uint16x8_t) a, (uint16x8_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return (psimd_u16) __builtin_wasm_bitselect(a, b, mask); + #else + return (psimd_u16) ((mask & (psimd_s16) a) | (~mask & (psimd_s16) b)); + #endif + } + + PSIMD_INTRINSIC psimd_s32 psimd_blend_s32(psimd_s32 mask, psimd_s32 a, psimd_s32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s32) vbslq_s32((uint32x4_t) mask, (int32x4_t) a, (int32x4_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return (psimd_s32) __builtin_wasm_bitselect(a, b, mask); + #else + return (mask & a) | (~mask & b); + #endif + } + + PSIMD_INTRINSIC psimd_u32 psimd_blend_u32(psimd_s32 mask, psimd_u32 a, psimd_u32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u32) vbslq_u32((uint32x4_t) mask, (uint32x4_t) a, (uint32x4_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return (psimd_u32) __builtin_wasm_bitselect(a, b, mask); + #else + return (psimd_u32) ((mask & (psimd_s32) a) | (~mask & (psimd_s32) b)); + #endif + } + + PSIMD_INTRINSIC psimd_f32 psimd_blend_f32(psimd_s32 mask, psimd_f32 a, psimd_f32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_f32) vbslq_f32((uint32x4_t) mask, (float32x4_t) a, (float32x4_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return (psimd_f32) __builtin_wasm_bitselect(a, b, mask); + #else + return (psimd_f32) ((mask & (psimd_s32) a) | (~mask & (psimd_s32) b)); + #endif + } + + /* Vector blend on sign */ + PSIMD_INTRINSIC psimd_s8 psimd_signblend_s8(psimd_s8 x, psimd_s8 a, psimd_s8 b) { + return psimd_blend_s8(x >> psimd_splat_s8(7), a, b); + } + + PSIMD_INTRINSIC psimd_u8 psimd_signblend_u8(psimd_s8 x, psimd_u8 a, psimd_u8 b) { + return psimd_blend_u8((x >> psimd_splat_s8(7)), a, b); + } + + PSIMD_INTRINSIC psimd_s16 psimd_signblend_s16(psimd_s16 x, psimd_s16 a, psimd_s16 b) { + return psimd_blend_s16(x >> psimd_splat_s16(15), a, b); + } + + PSIMD_INTRINSIC psimd_u16 psimd_signblend_u16(psimd_s16 x, psimd_u16 a, psimd_u16 b) { + return psimd_blend_u16((x >> psimd_splat_s16(15)), a, b); + } + + PSIMD_INTRINSIC psimd_s32 psimd_signblend_s32(psimd_s32 x, psimd_s32 a, psimd_s32 b) { + return psimd_blend_s32(x >> psimd_splat_s32(31), a, b); + } + + PSIMD_INTRINSIC psimd_u32 psimd_signblend_u32(psimd_s32 x, psimd_u32 a, psimd_u32 b) { + return psimd_blend_u32((x >> psimd_splat_s32(31)), a, b); + } + + PSIMD_INTRINSIC psimd_f32 psimd_signblend_f32(psimd_f32 x, psimd_f32 a, psimd_f32 b) { + const psimd_s32 mask = (psimd_s32) x >> psimd_splat_s32(31); + return psimd_blend_f32(mask, a, b); + } + + /* Vector absolute value */ + PSIMD_INTRINSIC psimd_f32 psimd_abs_f32(psimd_f32 v) { + const psimd_s32 mask = (psimd_s32) psimd_splat_f32(-0.0f); + return (psimd_f32) ((psimd_s32) v & ~mask); + } + + /* Vector negation */ + PSIMD_INTRINSIC psimd_f32 psimd_neg_f32(psimd_f32 v) { + const psimd_s32 mask = (psimd_s32) psimd_splat_f32(-0.0f); + return (psimd_f32) ((psimd_s32) v ^ mask); + } + + /* Vector maximum */ + PSIMD_INTRINSIC psimd_s8 psimd_max_s8(psimd_s8 a, psimd_s8 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s8) vmaxq_s8((int8x16_t) a, (int8x16_t) b); + #else + return psimd_blend_s8(a > b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_u8 psimd_max_u8(psimd_u8 a, psimd_u8 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u8) vmaxq_u8((uint8x16_t) a, (uint8x16_t) b); + #else + return psimd_blend_u8(a > b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_s16 psimd_max_s16(psimd_s16 a, psimd_s16 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s16) vmaxq_s16((int16x8_t) a, (int16x8_t) b); + #else + return psimd_blend_s16(a > b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_u16 psimd_max_u16(psimd_u16 a, psimd_u16 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u16) vmaxq_u16((uint16x8_t) a, (uint16x8_t) b); + #else + return psimd_blend_u16(a > b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_s32 psimd_max_s32(psimd_s32 a, psimd_s32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s32) vmaxq_s32((int32x4_t) a, (int32x4_t) b); + #else + return psimd_blend_s32(a > b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_u32 psimd_max_u32(psimd_u32 a, psimd_u32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u32) vmaxq_u32((uint32x4_t) a, (uint32x4_t) b); + #else + return psimd_blend_u32(a > b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_f32 psimd_max_f32(psimd_f32 a, psimd_f32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_f32) vmaxq_f32((float32x4_t) a, (float32x4_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return __builtin_wasm_max_f32x4(a, b); + #else + return psimd_blend_f32(a > b, a, b); + #endif + } + + /* Vector minimum */ + PSIMD_INTRINSIC psimd_s8 psimd_min_s8(psimd_s8 a, psimd_s8 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s8) vminq_s8((int8x16_t) a, (int8x16_t) b); + #else + return psimd_blend_s8(a < b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_u8 psimd_min_u8(psimd_u8 a, psimd_u8 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u8) vminq_u8((uint8x16_t) a, (uint8x16_t) b); + #else + return psimd_blend_u8(a < b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_s16 psimd_min_s16(psimd_s16 a, psimd_s16 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s16) vminq_s16((int16x8_t) a, (int16x8_t) b); + #else + return psimd_blend_s16(a < b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_u16 psimd_min_u16(psimd_u16 a, psimd_u16 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u16) vminq_u16((uint16x8_t) a, (uint16x8_t) b); + #else + return psimd_blend_u16(a < b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_s32 psimd_min_s32(psimd_s32 a, psimd_s32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_s32) vminq_s32((int32x4_t) a, (int32x4_t) b); + #else + return psimd_blend_s32(a < b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_u32 psimd_min_u32(psimd_u32 a, psimd_u32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_u32) vminq_u32((uint32x4_t) a, (uint32x4_t) b); + #else + return psimd_blend_u32(a < b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_f32 psimd_min_f32(psimd_f32 a, psimd_f32 b) { + #if defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_f32) vminq_f32((float32x4_t) a, (float32x4_t) b); + #elif defined(__wasm__) && defined(__wasm_simd128__) && defined(__clang__) + return __builtin_wasm_min_f32x4(a, b); + #else + return psimd_blend_f32(a < b, a, b); + #endif + } + + PSIMD_INTRINSIC psimd_f32 psimd_cvt_s32_f32(psimd_s32 v) { + #if defined(__clang__) + return __builtin_convertvector(v, psimd_f32); + #elif defined(__ARM_NEON__) || defined(__ARM_NEON) + return (psimd_f32) vcvtq_f32_s32((int32x4_t) v); + #elif defined(__SSE2__) + return (psimd_f32) _mm_cvtepi32_ps((__m128i) v); + #else + return (psimd_f32) { (float) v[0], (float) v[1], (float) v[2], (float) v[3] }; + #endif + } + + /* Broadcast vector element */ + #if defined(__clang__) + PSIMD_INTRINSIC psimd_f32 psimd_splat0_f32(psimd_f32 v) { + return __builtin_shufflevector(v, v, 0, 0, 0, 0); + } + + PSIMD_INTRINSIC psimd_f32 psimd_splat1_f32(psimd_f32 v) { + return __builtin_shufflevector(v, v, 1, 1, 1, 1); + } + + PSIMD_INTRINSIC psimd_f32 psimd_splat2_f32(psimd_f32 v) { + return __builtin_shufflevector(v, v, 2, 2, 2, 2); + } + + PSIMD_INTRINSIC psimd_f32 psimd_splat3_f32(psimd_f32 v) { + return __builtin_shufflevector(v, v, 3, 3, 3, 3); + } + #else + PSIMD_INTRINSIC psimd_f32 psimd_splat0_f32(psimd_f32 v) { + return __builtin_shuffle(v, (psimd_s32) { 0, 0, 0, 0 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_splat1_f32(psimd_f32 v) { + return __builtin_shuffle(v, (psimd_s32) { 1, 1, 1, 1 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_splat2_f32(psimd_f32 v) { + return __builtin_shuffle(v, (psimd_s32) { 2, 2, 2, 2 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_splat3_f32(psimd_f32 v) { + return __builtin_shuffle(v, (psimd_s32) { 3, 3, 3, 3 }); + } + #endif + + /* Reversal of vector elements */ + #if defined(__clang__) + PSIMD_INTRINSIC psimd_s8 psimd_reverse_s8(psimd_s8 v) { + return __builtin_shufflevector(v, v, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + } + + PSIMD_INTRINSIC psimd_u8 psimd_reverse_u8(psimd_u8 v) { + return __builtin_shufflevector(v, v, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0); + } + + PSIMD_INTRINSIC psimd_s16 psimd_reverse_s16(psimd_s16 v) { + return __builtin_shufflevector(v, v, 7, 6, 5, 4, 3, 2, 1, 0); + } + + PSIMD_INTRINSIC psimd_u16 psimd_reverse_u16(psimd_u16 v) { + return __builtin_shufflevector(v, v, 7, 6, 5, 4, 3, 2, 1, 0); + } + + PSIMD_INTRINSIC psimd_s32 psimd_reverse_s32(psimd_s32 v) { + return __builtin_shufflevector(v, v, 3, 2, 1, 0); + } + + PSIMD_INTRINSIC psimd_u32 psimd_reverse_u32(psimd_u32 v) { + return __builtin_shufflevector(v, v, 3, 2, 1, 0); + } + + PSIMD_INTRINSIC psimd_f32 psimd_reverse_f32(psimd_f32 v) { + return __builtin_shufflevector(v, v, 3, 2, 1, 0); + } + #else + PSIMD_INTRINSIC psimd_s8 psimd_reverse_s8(psimd_s8 v) { + return __builtin_shuffle(v, (psimd_s8) { 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 }); + } + + PSIMD_INTRINSIC psimd_u8 psimd_reverse_u8(psimd_u8 v) { + return __builtin_shuffle(v, (psimd_s8) { 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0 }); + } + + PSIMD_INTRINSIC psimd_s16 psimd_reverse_s16(psimd_s16 v) { + return __builtin_shuffle(v, (psimd_s16) { 7, 6, 5, 4, 3, 2, 1, 0 }); + } + + PSIMD_INTRINSIC psimd_u16 psimd_reverse_u16(psimd_u16 v) { + return __builtin_shuffle(v, (psimd_s16) { 7, 6, 5, 4, 3, 2, 1, 0 }); + } + + PSIMD_INTRINSIC psimd_s32 psimd_reverse_s32(psimd_s32 v) { + return __builtin_shuffle(v, (psimd_s32) { 3, 2, 1, 0 }); + } + + PSIMD_INTRINSIC psimd_u32 psimd_reverse_u32(psimd_u32 v) { + return __builtin_shuffle(v, (psimd_s32) { 3, 2, 1, 0 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_reverse_f32(psimd_f32 v) { + return __builtin_shuffle(v, (psimd_s32) { 3, 2, 1, 0 }); + } + #endif + + /* Interleaving of vector elements */ + #if defined(__clang__) + PSIMD_INTRINSIC psimd_s16 psimd_interleave_lo_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shufflevector(a, b, 0, 8+0, 1, 8+1, 2, 8+2, 3, 8+3); + } + + PSIMD_INTRINSIC psimd_s16 psimd_interleave_hi_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shufflevector(a, b, 4, 8+4, 5, 8+5, 6, 8+6, 7, 8+7); + } + + PSIMD_INTRINSIC psimd_u16 psimd_interleave_lo_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shufflevector(a, b, 0, 8+0, 1, 8+1, 2, 8+2, 3, 8+3); + } + + PSIMD_INTRINSIC psimd_u16 psimd_interleave_hi_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shufflevector(a, b, 4, 8+4, 5, 8+5, 6, 8+6, 7, 8+7); + } + + PSIMD_INTRINSIC psimd_s32 psimd_interleave_lo_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shufflevector(a, b, 0, 4+0, 1, 4+1); + } + + PSIMD_INTRINSIC psimd_s32 psimd_interleave_hi_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shufflevector(a, b, 2, 4+2, 3, 4+3); + } + + PSIMD_INTRINSIC psimd_u32 psimd_interleave_lo_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shufflevector(a, b, 0, 4+0, 1, 4+1); + } + + PSIMD_INTRINSIC psimd_u32 psimd_interleave_hi_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shufflevector(a, b, 2, 4+2, 3, 4+3); + } + + PSIMD_INTRINSIC psimd_f32 psimd_interleave_lo_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shufflevector(a, b, 0, 4+0, 1, 4+1); + } + + PSIMD_INTRINSIC psimd_f32 psimd_interleave_hi_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shufflevector(a, b, 2, 4+2, 3, 4+3); + } + #else + PSIMD_INTRINSIC psimd_s16 psimd_interleave_lo_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 0, 8+0, 1, 8+1, 2, 8+2, 3, 8+3 }); + } + + PSIMD_INTRINSIC psimd_s16 psimd_interleave_hi_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 4, 8+4, 5, 8+5, 6, 8+6, 7, 8+7 }); + } + + PSIMD_INTRINSIC psimd_u16 psimd_interleave_lo_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 0, 8+0, 1, 8+1, 2, 8+2, 3, 8+3 }); + } + + PSIMD_INTRINSIC psimd_u16 psimd_interleave_hi_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 4, 8+4, 5, 8+5, 6, 8+6, 7, 8+7 }); + } + + PSIMD_INTRINSIC psimd_s32 psimd_interleave_lo_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 4+0, 1, 4+1 }); + } + + PSIMD_INTRINSIC psimd_s32 psimd_interleave_hi_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 2, 4+2, 3, 4+3 }); + } + + PSIMD_INTRINSIC psimd_u32 psimd_interleave_lo_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 4+0, 1, 4+1 }); + } + + PSIMD_INTRINSIC psimd_u32 psimd_interleave_hi_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 2, 4+2, 3, 4+3 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_interleave_lo_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 4+0, 1, 4+1 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_interleave_hi_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 2, 4+2, 3, 4+3 }); + } + #endif + + /* Concatenation of low/high vector elements */ + #if defined(__clang__) + PSIMD_INTRINSIC psimd_s16 psimd_concat_lo_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shufflevector(a, b, 0, 1, 2, 3, 8+0, 8+1, 8+2, 8+3); + } + + PSIMD_INTRINSIC psimd_s16 psimd_concat_hi_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shufflevector(a, b, 4, 5, 6, 7, 8+4, 8+5, 8+6, 8+7); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_lo_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shufflevector(a, b, 0, 1, 2, 3, 8+0, 8+1, 8+2, 8+3); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_hi_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shufflevector(a, b, 4, 5, 6, 7, 8+4, 8+5, 8+6, 8+7); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_lo_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shufflevector(a, b, 0, 1, 4+0, 4+1); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_hi_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shufflevector(a, b, 2, 3, 4+2, 4+3); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_lo_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shufflevector(a, b, 0, 1, 4+0, 4+1); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_hi_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shufflevector(a, b, 2, 3, 4+2, 4+3); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_lo_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shufflevector(a, b, 0, 1, 4+0, 4+1); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_hi_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shufflevector(a, b, 2, 3, 4+2, 4+3); + } + #else + PSIMD_INTRINSIC psimd_s16 psimd_concat_lo_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 0, 1, 2, 3, 8+0, 8+1, 8+2, 8+3 }); + } + + PSIMD_INTRINSIC psimd_s16 psimd_concat_hi_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 4, 5, 6, 7, 8+4, 8+5, 8+6, 8+7 }); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_lo_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 0, 1, 2, 3, 8+0, 8+1, 8+2, 8+3 }); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_hi_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 4, 5, 6, 7, 8+4, 8+5, 8+6, 8+7 }); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_lo_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 1, 4+0, 4+1 }); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_hi_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 2, 3, 4+2, 4+3 }); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_lo_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 1, 4+0, 4+1 }); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_hi_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 2, 3, 4+2, 4+3 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_lo_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 1, 4+0, 4+1 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_hi_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 2, 3, 4+2, 4+3 }); + } + #endif + + /* Concatenation of even/odd vector elements */ + #if defined(__clang__) + PSIMD_INTRINSIC psimd_s8 psimd_concat_even_s8(psimd_s8 a, psimd_s8 b) { + return __builtin_shufflevector(a, b, + 0, 2, 4, 6, 8, 10, 12, 14, 16+0, 16+2, 16+4, 16+6, 16+8, 16+10, 16+12, 16+14); + } + + PSIMD_INTRINSIC psimd_s8 psimd_concat_odd_s8(psimd_s8 a, psimd_s8 b) { + return __builtin_shufflevector(a, b, + 1, 3, 5, 7, 9, 11, 13, 15, 16+1, 16+3, 16+5, 16+7, 16+9, 16+11, 16+13, 16+15); + } + + PSIMD_INTRINSIC psimd_u8 psimd_concat_even_u8(psimd_u8 a, psimd_u8 b) { + return __builtin_shufflevector(a, b, + 0, 2, 4, 6, 8, 10, 12, 14, 16+0, 16+2, 16+4, 16+6, 16+8, 16+10, 16+12, 16+14); + } + + PSIMD_INTRINSIC psimd_u8 psimd_concat_odd_u8(psimd_u8 a, psimd_u8 b) { + return __builtin_shufflevector(a, b, + 1, 3, 5, 7, 9, 11, 13, 15, 16+1, 16+3, 16+5, 16+7, 16+9, 16+11, 16+13, 16+15); + } + + PSIMD_INTRINSIC psimd_s16 psimd_concat_even_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shufflevector(a, b, 0, 2, 4, 6, 8+0, 8+2, 8+4, 8+6); + } + + PSIMD_INTRINSIC psimd_s16 psimd_concat_odd_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shufflevector(a, b, 1, 3, 5, 7, 8+1, 8+3, 8+5, 8+7); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_even_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shufflevector(a, b, 0, 2, 4, 6, 8+0, 8+2, 8+4, 8+6); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_odd_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shufflevector(a, b, 1, 3, 5, 7, 8+1, 8+3, 8+5, 8+7); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_even_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shufflevector(a, b, 0, 2, 4+0, 4+2); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_odd_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shufflevector(a, b, 1, 3, 4+1, 4+3); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_even_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shufflevector(a, b, 0, 2, 4+0, 4+2); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_odd_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shufflevector(a, b, 1, 3, 4+1, 4+3); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_even_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shufflevector(a, b, 0, 2, 4+0, 4+2); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_odd_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shufflevector(a, b, 1, 3, 4+1, 4+3); + } + #else + PSIMD_INTRINSIC psimd_s8 psimd_concat_even_s8(psimd_s8 a, psimd_s8 b) { + return __builtin_shuffle(a, b, + (psimd_s8) { 0, 2, 4, 6, 8, 10, 12, 14, 16+0, 16+2, 16+4, 16+6, 16+8, 16+10, 16+12, 16+14 }); + } + + PSIMD_INTRINSIC psimd_s8 psimd_concat_odd_s8(psimd_s8 a, psimd_s8 b) { + return __builtin_shuffle(a, b, + (psimd_s8) { 1, 3, 5, 7, 9, 11, 13, 15, 16+1, 16+3, 16+5, 16+7, 16+9, 16+11, 16+13, 16+15 }); + } + + PSIMD_INTRINSIC psimd_u8 psimd_concat_even_u8(psimd_u8 a, psimd_u8 b) { + return __builtin_shuffle(a, b, + (psimd_s8) { 0, 2, 4, 6, 8, 10, 12, 14, 16+0, 16+2, 16+4, 16+6, 16+8, 16+10, 16+12, 16+14 }); + } + + PSIMD_INTRINSIC psimd_u8 psimd_concat_odd_u8(psimd_u8 a, psimd_u8 b) { + return __builtin_shuffle(a, b, + (psimd_s8) { 1, 3, 5, 7, 9, 11, 13, 15, 16+1, 16+3, 16+5, 16+7, 16+9, 16+11, 16+13, 16+15 }); + } + + PSIMD_INTRINSIC psimd_s16 psimd_concat_even_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 0, 2, 4, 6, 8+0, 8+2, 8+4, 8+6 }); + } + + PSIMD_INTRINSIC psimd_s16 psimd_concat_odd_s16(psimd_s16 a, psimd_s16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 1, 3, 5, 7, 8+1, 8+3, 8+5, 8+7 }); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_even_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 0, 2, 4, 6, 8+0, 8+2, 8+4, 8+6 }); + } + + PSIMD_INTRINSIC psimd_u16 psimd_concat_odd_u16(psimd_u16 a, psimd_u16 b) { + return __builtin_shuffle(a, b, (psimd_s16) { 1, 3, 5, 7, 8+1, 8+3, 8+5, 8+7 }); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_even_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 2, 4+0, 4+2 }); + } + + PSIMD_INTRINSIC psimd_s32 psimd_concat_odd_s32(psimd_s32 a, psimd_s32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 1, 3, 4+1, 4+3 }); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_even_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 2, 4+0, 4+2 }); + } + + PSIMD_INTRINSIC psimd_u32 psimd_concat_odd_u32(psimd_u32 a, psimd_u32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 1, 3, 4+1, 4+3 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_even_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 0, 2, 4+0, 4+2 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_concat_odd_f32(psimd_f32 a, psimd_f32 b) { + return __builtin_shuffle(a, b, (psimd_s32) { 1, 3, 4+1, 4+3 }); + } + #endif + + /* Vector reduce */ + #if defined(__clang__) + PSIMD_INTRINSIC psimd_f32 psimd_allreduce_sum_f32(psimd_f32 v) { + const psimd_f32 temp = v + __builtin_shufflevector(v, v, 2, 3, 0, 1); + return temp + __builtin_shufflevector(temp, temp, 1, 0, 3, 2); + } + + PSIMD_INTRINSIC psimd_f32 psimd_allreduce_max_f32(psimd_f32 v) { + const psimd_f32 temp = psimd_max_f32(v, __builtin_shufflevector(v, v, 2, 3, 0, 1)); + return psimd_max_f32(temp, __builtin_shufflevector(temp, temp, 1, 0, 3, 2)); + } + + PSIMD_INTRINSIC psimd_f32 psimd_allreduce_min_f32(psimd_f32 v) { + const psimd_f32 temp = psimd_min_f32(v, __builtin_shufflevector(v, v, 2, 3, 0, 1)); + return psimd_min_f32(temp, __builtin_shufflevector(temp, temp, 1, 0, 3, 2)); + } + + PSIMD_INTRINSIC float psimd_reduce_sum_f32(psimd_f32 v) { + const psimd_f32 temp = v + __builtin_shufflevector(v, v, 2, 3, -1, -1); + const psimd_f32 result = temp + __builtin_shufflevector(temp, temp, 1, -1, -1, -1); + return result[0]; + } + + PSIMD_INTRINSIC float psimd_reduce_max_f32(psimd_f32 v) { + const psimd_f32 temp = psimd_max_f32(v, __builtin_shufflevector(v, v, 2, 3, -1, -1)); + const psimd_f32 result = psimd_max_f32(temp, __builtin_shufflevector(temp, temp, 1, -1, -1, -1)); + return result[0]; + } + + PSIMD_INTRINSIC float psimd_reduce_min_f32(psimd_f32 v) { + const psimd_f32 temp = psimd_min_f32(v, __builtin_shufflevector(v, v, 2, 3, -1, -1)); + const psimd_f32 result = psimd_min_f32(temp, __builtin_shufflevector(temp, temp, 1, -1, -1, -1)); + return result[0]; + } + #else + PSIMD_INTRINSIC psimd_f32 psimd_allreduce_sum_f32(psimd_f32 v) { + const psimd_f32 temp = v + __builtin_shuffle(v, (psimd_s32) { 2, 3, 0, 1 }); + return temp + __builtin_shuffle(temp, (psimd_s32) { 1, 0, 3, 2 }); + } + + PSIMD_INTRINSIC psimd_f32 psimd_allreduce_max_f32(psimd_f32 v) { + const psimd_f32 temp = psimd_max_f32(v, __builtin_shuffle(v, (psimd_s32) { 2, 3, 0, 1 })); + return psimd_max_f32(temp, __builtin_shuffle(temp, (psimd_s32) { 1, 0, 3, 2 })); + } + + PSIMD_INTRINSIC psimd_f32 psimd_allreduce_min_f32(psimd_f32 v) { + const psimd_f32 temp = psimd_min_f32(v, __builtin_shuffle(v, (psimd_s32) { 2, 3, 0, 1 })); + return psimd_min_f32(temp, __builtin_shuffle(temp, (psimd_s32) { 1, 0, 3, 2 })); + } + + PSIMD_INTRINSIC float psimd_reduce_sum_f32(psimd_f32 v) { + const psimd_f32 result = psimd_allreduce_sum_f32(v); + return result[0]; + } + + PSIMD_INTRINSIC float psimd_reduce_max_f32(psimd_f32 v) { + const psimd_f32 result = psimd_allreduce_max_f32(v); + return result[0]; + } + + PSIMD_INTRINSIC float psimd_reduce_min_f32(psimd_f32 v) { + const psimd_f32 result = psimd_allreduce_min_f32(v); + return result[0]; + } + #endif +#endif + +#endif /* PSIMD_H */ diff --git a/phivenv/Lib/site-packages/torch/include/pthreadpool.h b/phivenv/Lib/site-packages/torch/include/pthreadpool.h new file mode 100644 index 0000000000000000000000000000000000000000..3b779adcb96bd845944856387195c98b40567afe --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/pthreadpool.h @@ -0,0 +1,2555 @@ +#ifndef PTHREADPOOL_H_ +#define PTHREADPOOL_H_ + +#include +#include + +typedef struct pthreadpool* pthreadpool_t; + +typedef void (*pthreadpool_task_1d_t)(void*, size_t); +typedef void (*pthreadpool_task_1d_with_thread_t)(void*, size_t, size_t); +typedef void (*pthreadpool_task_1d_tile_1d_t)(void*, size_t, size_t); +typedef void (*pthreadpool_task_2d_t)(void*, size_t, size_t); +typedef void (*pthreadpool_task_2d_with_thread_t)(void*, size_t, size_t, size_t); +typedef void (*pthreadpool_task_2d_tile_1d_t)(void*, size_t, size_t, size_t); +typedef void (*pthreadpool_task_2d_tile_2d_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_3d_t)(void*, size_t, size_t, size_t); +typedef void (*pthreadpool_task_3d_tile_1d_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_3d_tile_1d_with_thread_t)(void*, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_3d_tile_2d_t)(void*, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_4d_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_4d_tile_1d_t)(void*, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_4d_tile_2d_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_5d_t)(void*, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_5d_tile_1d_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_5d_tile_2d_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_6d_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_6d_tile_1d_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_6d_tile_2d_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t, size_t, size_t); + +typedef void (*pthreadpool_task_1d_with_id_t)(void*, uint32_t, size_t); +typedef void (*pthreadpool_task_2d_tile_1d_with_id_t)(void*, uint32_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_2d_tile_2d_with_id_t)(void*, uint32_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_3d_tile_1d_with_id_t)(void*, uint32_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_3d_tile_2d_with_id_t)(void*, uint32_t, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_4d_tile_2d_with_id_t)(void*, uint32_t, size_t, size_t, size_t, size_t, size_t, size_t); + +typedef void (*pthreadpool_task_2d_tile_1d_with_id_with_thread_t)(void*, uint32_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_task_3d_tile_1d_with_id_with_thread_t)(void*, uint32_t, size_t, size_t, size_t, size_t, size_t); + + +/** + * Disable support for denormalized numbers to the maximum extent possible for + * the duration of the computation. + * + * Handling denormalized floating-point numbers is often implemented in + * microcode, and incurs significant performance degradation. This hint + * instructs the thread pool to disable support for denormalized numbers before + * running the computation by manipulating architecture-specific control + * registers, and restore the initial value of control registers after the + * computation is complete. The thread pool temporary disables denormalized + * numbers on all threads involved in the computation (i.e. the caller threads, + * and potentially worker threads). + * + * Disabling denormalized numbers may have a small negative effect on results' + * accuracy. As various architectures differ in capabilities to control + * processing of denormalized numbers, using this flag may also hurt results' + * reproducibility across different instruction set architectures. + */ +#define PTHREADPOOL_FLAG_DISABLE_DENORMALS 0x00000001 + +/** + * Yield worker threads to the system scheduler after the operation is finished. + * + * Force workers to use kernel wait (instead of active spin-wait by default) for + * new commands after this command is processed. This flag affects only the + * immediate next operation on this thread pool. To make the thread pool always + * use kernel wait, pass this flag to all parallelization functions. + */ +#define PTHREADPOOL_FLAG_YIELD_WORKERS 0x00000002 + +#ifdef __cplusplus +extern "C" { +#endif + +/** + * Create a thread pool with the specified number of threads. + * + * @param threads_count the number of threads in the thread pool. + * A value of 0 has special interpretation: it creates a thread pool with as + * many threads as there are logical processors in the system. + * + * @returns A pointer to an opaque thread pool object if the call is + * successful, or NULL pointer if the call failed. + */ +pthreadpool_t pthreadpool_create(size_t threads_count); + +/** + * Query the number of threads in a thread pool. + * + * @param threadpool the thread pool to query. + * + * @returns The number of threads in the thread pool. + */ +size_t pthreadpool_get_threads_count(pthreadpool_t threadpool); + +/** + * Process items on a 1D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range; i++) + * function(context, i); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each item. + * @param context the first argument passed to the specified function. + * @param range the number of items on the 1D grid to process. The + * specified function will be called once for each item. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_1d( + pthreadpool_t threadpool, + pthreadpool_task_1d_t function, + void* context, + size_t range, + uint32_t flags); + +/** + * Process items on a 1D grid passing along the current thread id. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range; i++) + * function(context, thread_index, i); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each item. + * @param context the first argument passed to the specified function. + * @param range the number of items on the 1D grid to process. The + * specified function will be called once for each item. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_1d_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_1d_with_thread_t function, + void* context, + size_t range, + uint32_t flags); + +/** + * Process items on a 1D grid using a microarchitecture-aware task function. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range; i++) + * function(context, uarch_index, i); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If + * threadpool is NULL, all items are processed serially on the calling + * thread. + * @param function the function to call for each item. + * @param context the first argument passed to the specified + * function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, cpuinfo initialization failed, + * or index returned by cpuinfo_get_current_uarch_index() exceeds the + * max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected by + * the specified function. If the index returned by + * cpuinfo_get_current_uarch_index() exceeds this value, default_uarch_index + * will be used instead. default_uarch_index can exceed max_uarch_index. + * @param range the number of items on the 1D grid to process. + * The specified function will be called once for each item. + * @param flags a bitwise combination of zero or more optional + * flags (PTHREADPOOL_FLAG_DISABLE_DENORMALS or + * PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_1d_with_uarch( + pthreadpool_t threadpool, + pthreadpool_task_1d_with_id_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range, + uint32_t flags); + +/** + * Process items on a 1D grid with specified maximum tile size. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range; i += tile) + * function(context, i, min(range - i, tile)); + * + * When the call returns, all items have been processed and the thread pool is + * ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, + * the calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range the number of items on the 1D grid to process. + * @param tile the maximum number of items on the 1D grid to process in + * one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_1d_tile_1d( + pthreadpool_t threadpool, + pthreadpool_task_1d_tile_1d_t function, + void* context, + size_t range, + size_t tile, + uint32_t flags); + +/** + * Process items on a 2D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * function(context, i, j); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each item. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_2d( + pthreadpool_t threadpool, + pthreadpool_task_2d_t function, + void* context, + size_t range_i, + size_t range_j, + uint32_t flags); + +/** + * Process items on a 2D grid passing along the current thread id. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * function(context, thread_index, i, j); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each item. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_2d_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_2d_with_thread_t function, + void* context, + size_t range_i, + size_t range_j, + uint32_t flags); + +/** + * Process items on a 2D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j += tile_j) + * function(context, i, j, min(range_j - j, tile_j)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param tile_j the maximum number of items along the second dimension of + * the 2D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_2d_tile_1d( + pthreadpool_t threadpool, + pthreadpool_task_2d_tile_1d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t tile_j, + uint32_t flags); + +/** + * Process items on a 2D grid with the specified maximum tile size along the + * last grid dimension using a microarchitecture-aware task function. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j += tile_j) + * function(context, uarch_index, i, j, min(range_j - j, tile_j)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, cpuinfo initialization failed, + * or index returned by cpuinfo_get_current_uarch_index() exceeds the + * max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected by + * the specified function. If the index returned by + * cpuinfo_get_current_uarch_index() exceeds this value, default_uarch_index + * will be used instead. default_uarch_index can exceed max_uarch_index. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param tile_j the maximum number of items along the second dimension of + * the 2D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_2d_tile_1d_with_uarch( + pthreadpool_t threadpool, + pthreadpool_task_2d_tile_1d_with_id_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range_i, + size_t range_j, + size_t tile_j, + uint32_t flags); + +/** + * Process items on a 2D grid with the specified maximum tile size along the + * last grid dimension using a microarchitecture-aware task function and passing + * along the current thread id. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j += tile_j) + * function(context, uarch_index, thread_index, i, j, min(range_j - j, tile_j)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, cpuinfo initialization failed, + * or index returned by cpuinfo_get_current_uarch_index() exceeds the + * max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected by + * the specified function. If the index returned by + * cpuinfo_get_current_uarch_index() exceeds this value, default_uarch_index + * will be used instead. default_uarch_index can exceed max_uarch_index. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param tile_j the maximum number of items along the second dimension of + * the 2D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_2d_tile_1d_with_uarch_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_2d_tile_1d_with_id_with_thread_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range_i, + size_t range_j, + size_t tile_j, + uint32_t flags); + +/** + * Process items on a 2D grid with the specified maximum tile size along each + * grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i += tile_i) + * for (size_t j = 0; j < range_j; j += tile_j) + * function(context, i, j, + * min(range_i - i, tile_i), min(range_j - j, tile_j)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param tile_j the maximum number of items along the first dimension of + * the 2D grid to process in one function call. + * @param tile_j the maximum number of items along the second dimension of + * the 2D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_2d_tile_2d( + pthreadpool_t threadpool, + pthreadpool_task_2d_tile_2d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j, + uint32_t flags); + +/** + * Process items on a 2D grid with the specified maximum tile size along each + * grid dimension using a microarchitecture-aware task function. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range_i; i += tile_i) + * for (size_t j = 0; j < range_j; j += tile_j) + * function(context, uarch_index, i, j, + * min(range_i - i, tile_i), min(range_j - j, tile_j)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If + * threadpool is NULL, all items are processed serially on the calling + * thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified + * function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, + * cpuinfo initialization failed, or index returned + * by cpuinfo_get_current_uarch_index() exceeds + * the max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected + * by the specified function. If the index returned + * by cpuinfo_get_current_uarch_index() exceeds this + * value, default_uarch_index will be used instead. + * default_uarch_index can exceed max_uarch_index. + * @param range_i the number of items to process along the first + * dimension of the 2D grid. + * @param range_j the number of items to process along the second + * dimension of the 2D grid. + * @param tile_j the maximum number of items along the first + * dimension of the 2D grid to process in one function call. + * @param tile_j the maximum number of items along the second + * dimension of the 2D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional + * flags (PTHREADPOOL_FLAG_DISABLE_DENORMALS or + * PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_2d_tile_2d_with_uarch( + pthreadpool_t threadpool, + pthreadpool_task_2d_tile_2d_with_id_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j, + uint32_t flags); + +/** + * Process items on a 3D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * function(context, i, j, k); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 3D grid. + * @param range_j the number of items to process along the second dimension + * of the 3D grid. + * @param range_k the number of items to process along the third dimension + * of the 3D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_3d( + pthreadpool_t threadpool, + pthreadpool_task_3d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + uint32_t flags); + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * function(context, i, j, k, min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 3D grid. + * @param range_j the number of items to process along the second dimension + * of the 3D grid. + * @param range_k the number of items to process along the third dimension + * of the 3D grid. + * @param tile_k the maximum number of items along the third dimension of + * the 3D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_3d_tile_1d( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_k, + uint32_t flags); + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last grid dimension and passing along the current thread id. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * function(context, thread_index, i, j, k, min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 3D grid. + * @param range_j the number of items to process along the second dimension + * of the 3D grid. + * @param range_k the number of items to process along the third dimension + * of the 3D grid. + * @param tile_k the maximum number of items along the third dimension of + * the 3D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_3d_tile_1d_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_with_thread_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_k, + uint32_t flags); + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last grid dimension using a microarchitecture-aware task function. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * function(context, uarch_index, i, j, k, min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If + * threadpool is NULL, all items are processed serially on the calling + * thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified + * function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, cpuinfo initialization failed, + * or index returned by cpuinfo_get_current_uarch_index() exceeds the + * max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected by + * the specified function. If the index returned by + * cpuinfo_get_current_uarch_index() exceeds this value, default_uarch_index + * will be used instead. default_uarch_index can exceed max_uarch_index. + * @param range_i the number of items to process along the first + * dimension of the 3D grid. + * @param range_j the number of items to process along the second + * dimension of the 3D grid. + * @param range_k the number of items to process along the third + * dimension of the 3D grid. + * @param tile_k the maximum number of items along the third + * dimension of the 3D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional + * flags (PTHREADPOOL_FLAG_DISABLE_DENORMALS or + * PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_3d_tile_1d_with_uarch( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_with_id_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_k, + uint32_t flags); + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last grid dimension using a microarchitecture-aware task function and passing + * along the current thread id. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * function(context, uarch_index, thread_index, i, j, k, min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If + * threadpool is NULL, all items are processed serially on the calling + * thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified + * function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, cpuinfo initialization failed, + * or index returned by cpuinfo_get_current_uarch_index() exceeds the + * max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected by + * the specified function. If the index returned by + * cpuinfo_get_current_uarch_index() exceeds this value, default_uarch_index + * will be used instead. default_uarch_index can exceed max_uarch_index. + * @param range_i the number of items to process along the first + * dimension of the 3D grid. + * @param range_j the number of items to process along the second + * dimension of the 3D grid. + * @param range_k the number of items to process along the third + * dimension of the 3D grid. + * @param tile_k the maximum number of items along the third + * dimension of the 3D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional + * flags (PTHREADPOOL_FLAG_DISABLE_DENORMALS or + * PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_3d_tile_1d_with_uarch_with_thread( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_1d_with_id_with_thread_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_k, + uint32_t flags); + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j += tile_j) + * for (size_t k = 0; k < range_k; k += tile_k) + * function(context, i, j, k, + * min(range_j - j, tile_j), min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 3D grid. + * @param range_j the number of items to process along the second dimension + * of the 3D grid. + * @param range_k the number of items to process along the third dimension + * of the 3D grid. + * @param tile_j the maximum number of items along the second dimension of + * the 3D grid to process in one function call. + * @param tile_k the maximum number of items along the third dimension of + * the 3D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_3d_tile_2d( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_2d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_j, + size_t tile_k, + uint32_t flags); + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last two grid dimensions using a microarchitecture-aware task function. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j += tile_j) + * for (size_t k = 0; k < range_k; k += tile_k) + * function(context, uarch_index, i, j, k, + * min(range_j - j, tile_j), min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If + * threadpool is NULL, all items are processed serially on the calling + * thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified + * function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, cpuinfo initialization failed, + * or index returned by cpuinfo_get_current_uarch_index() exceeds the + * max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected by + * the specified function. If the index returned by + * cpuinfo_get_current_uarch_index() exceeds this value, default_uarch_index + * will be used instead. default_uarch_index can exceed max_uarch_index. + * @param range_i the number of items to process along the first + * dimension of the 3D grid. + * @param range_j the number of items to process along the second + * dimension of the 3D grid. + * @param range_k the number of items to process along the third + * dimension of the 3D grid. + * @param tile_j the maximum number of items along the second + * dimension of the 3D grid to process in one function call. + * @param tile_k the maximum number of items along the third + * dimension of the 3D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional + * flags (PTHREADPOOL_FLAG_DISABLE_DENORMALS or + * PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_3d_tile_2d_with_uarch( + pthreadpool_t threadpool, + pthreadpool_task_3d_tile_2d_with_id_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_j, + size_t tile_k, + uint32_t flags); + +/** + * Process items on a 4D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * function(context, i, j, k, l); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 4D grid. + * @param range_j the number of items to process along the second dimension + * of the 4D grid. + * @param range_k the number of items to process along the third dimension + * of the 4D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 4D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_4d( + pthreadpool_t threadpool, + pthreadpool_task_4d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + uint32_t flags); + +/** + * Process items on a 4D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l += tile_l) + * function(context, i, j, k, l, min(range_l - l, tile_l)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 4D grid. + * @param range_j the number of items to process along the second dimension + * of the 4D grid. + * @param range_k the number of items to process along the third dimension + * of the 4D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 4D grid. + * @param tile_l the maximum number of items along the fourth dimension of + * the 4D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_4d_tile_1d( + pthreadpool_t threadpool, + pthreadpool_task_4d_tile_1d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_l, + uint32_t flags); + +/** + * Process items on a 4D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * for (size_t l = 0; l < range_l; l += tile_l) + * function(context, i, j, k, l, + * min(range_k - k, tile_k), min(range_l - l, tile_l)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 4D grid. + * @param range_j the number of items to process along the second dimension + * of the 4D grid. + * @param range_k the number of items to process along the third dimension + * of the 4D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 4D grid. + * @param tile_k the maximum number of items along the third dimension of + * the 4D grid to process in one function call. + * @param tile_l the maximum number of items along the fourth dimension of + * the 4D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_4d_tile_2d( + pthreadpool_t threadpool, + pthreadpool_task_4d_tile_2d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_k, + size_t tile_l, + uint32_t flags); + +/** + * Process items on a 4D grid with the specified maximum tile size along the + * last two grid dimensions using a microarchitecture-aware task function. + * + * The function implements a parallel version of the following snippet: + * + * uint32_t uarch_index = cpuinfo_initialize() ? + * cpuinfo_get_current_uarch_index() : default_uarch_index; + * if (uarch_index > max_uarch_index) uarch_index = default_uarch_index; + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * for (size_t l = 0; l < range_l; l += tile_l) + * function(context, uarch_index, i, j, k, l, + * min(range_k - k, tile_k), min(range_l - l, tile_l)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If + * threadpool is NULL, all items are processed serially on the calling + * thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified + * function. + * @param default_uarch_index the microarchitecture index to use when + * pthreadpool is configured without cpuinfo, cpuinfo initialization failed, + * or index returned by cpuinfo_get_current_uarch_index() exceeds the + * max_uarch_index value. + * @param max_uarch_index the maximum microarchitecture index expected by + * the specified function. If the index returned by + * cpuinfo_get_current_uarch_index() exceeds this value, default_uarch_index + * will be used instead. default_uarch_index can exceed max_uarch_index. + * @param range_i the number of items to process along the first + * dimension of the 4D grid. + * @param range_j the number of items to process along the second + * dimension of the 4D grid. + * @param range_k the number of items to process along the third + * dimension of the 4D grid. + * @param range_l the number of items to process along the fourth + * dimension of the 4D grid. + * @param tile_k the maximum number of items along the third + * dimension of the 4D grid to process in one function call. + * @param tile_l the maximum number of items along the fourth + * dimension of the 4D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional + * flags (PTHREADPOOL_FLAG_DISABLE_DENORMALS or + * PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_4d_tile_2d_with_uarch( + pthreadpool_t threadpool, + pthreadpool_task_4d_tile_2d_with_id_t function, + void* context, + uint32_t default_uarch_index, + uint32_t max_uarch_index, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_k, + size_t tile_l, + uint32_t flags); + +/** + * Process items on a 5D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m++) + * function(context, i, j, k, l, m); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 5D grid. + * @param range_j the number of items to process along the second dimension + * of the 5D grid. + * @param range_k the number of items to process along the third dimension + * of the 5D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 5D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 5D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_5d( + pthreadpool_t threadpool, + pthreadpool_task_5d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + uint32_t flags); + +/** + * Process items on a 5D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m += tile_m) + * function(context, i, j, k, l, m, min(range_m - m, tile_m)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 5D grid. + * @param range_j the number of items to process along the second dimension + * of the 5D grid. + * @param range_k the number of items to process along the third dimension + * of the 5D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 5D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 5D grid. + * @param tile_m the maximum number of items along the fifth dimension of + * the 5D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_5d_tile_1d( + pthreadpool_t threadpool, + pthreadpool_task_5d_tile_1d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t tile_m, + uint32_t flags); + +/** + * Process items on a 5D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l += tile_l) + * for (size_t m = 0; m < range_m; m += tile_m) + * function(context, i, j, k, l, m, + * min(range_l - l, tile_l), min(range_m - m, tile_m)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 5D grid. + * @param range_j the number of items to process along the second dimension + * of the 5D grid. + * @param range_k the number of items to process along the third dimension + * of the 5D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 5D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 5D grid. + * @param tile_l the maximum number of items along the fourth dimension of + * the 5D grid to process in one function call. + * @param tile_m the maximum number of items along the fifth dimension of + * the 5D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_5d_tile_2d( + pthreadpool_t threadpool, + pthreadpool_task_5d_tile_2d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t tile_l, + size_t tile_m, + uint32_t flags); + +/** + * Process items on a 6D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m++) + * for (size_t n = 0; n < range_n; n++) + * function(context, i, j, k, l, m, n); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 6D grid. + * @param range_j the number of items to process along the second dimension + * of the 6D grid. + * @param range_k the number of items to process along the third dimension + * of the 6D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 6D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 6D grid. + * @param range_n the number of items to process along the sixth dimension + * of the 6D grid. + * @param tile_n the maximum number of items along the sixth dimension of + * the 6D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_6d( + pthreadpool_t threadpool, + pthreadpool_task_6d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t range_n, + uint32_t flags); + +/** + * Process items on a 6D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m++) + * for (size_t n = 0; n < range_n; n += tile_n) + * function(context, i, j, k, l, m, n, min(range_n - n, tile_n)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 6D grid. + * @param range_j the number of items to process along the second dimension + * of the 6D grid. + * @param range_k the number of items to process along the third dimension + * of the 6D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 6D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 6D grid. + * @param range_n the number of items to process along the sixth dimension + * of the 6D grid. + * @param tile_n the maximum number of items along the sixth dimension of + * the 6D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_6d_tile_1d( + pthreadpool_t threadpool, + pthreadpool_task_6d_tile_1d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t range_n, + size_t tile_n, + uint32_t flags); + +/** + * Process items on a 6D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m += tile_m) + * for (size_t n = 0; n < range_n; n += tile_n) + * function(context, i, j, k, l, m, n, + * min(range_m - m, tile_m), min(range_n - n, tile_n)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param function the function to call for each tile. + * @param context the first argument passed to the specified function. + * @param range_i the number of items to process along the first dimension + * of the 6D grid. + * @param range_j the number of items to process along the second dimension + * of the 6D grid. + * @param range_k the number of items to process along the third dimension + * of the 6D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 6D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 6D grid. + * @param range_n the number of items to process along the sixth dimension + * of the 6D grid. + * @param tile_m the maximum number of items along the fifth dimension of + * the 6D grid to process in one function call. + * @param tile_n the maximum number of items along the sixth dimension of + * the 6D grid to process in one function call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +void pthreadpool_parallelize_6d_tile_2d( + pthreadpool_t threadpool, + pthreadpool_task_6d_tile_2d_t function, + void* context, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t range_n, + size_t tile_m, + size_t tile_n, + uint32_t flags); + +/** + * Terminates threads in the thread pool and releases associated resources. + * + * @warning Accessing the thread pool after a call to this function constitutes + * undefined behaviour and may cause data corruption. + * + * @param[in,out] threadpool The thread pool to destroy. + */ +void pthreadpool_destroy(pthreadpool_t threadpool); + +#ifndef PTHREADPOOL_NO_DEPRECATED_API + +/* Legacy API for compatibility with pre-existing users (e.g. NNPACK) */ +#if defined(__GNUC__) + #define PTHREADPOOL_DEPRECATED __attribute__((__deprecated__)) +#else + #define PTHREADPOOL_DEPRECATED +#endif + +typedef void (*pthreadpool_function_1d_t)(void*, size_t); +typedef void (*pthreadpool_function_1d_tiled_t)(void*, size_t, size_t); +typedef void (*pthreadpool_function_2d_t)(void*, size_t, size_t); +typedef void (*pthreadpool_function_2d_tiled_t)(void*, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_function_3d_tiled_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t); +typedef void (*pthreadpool_function_4d_tiled_t)(void*, size_t, size_t, size_t, size_t, size_t, size_t, size_t, size_t); + +void pthreadpool_compute_1d( + pthreadpool_t threadpool, + pthreadpool_function_1d_t function, + void* argument, + size_t range) PTHREADPOOL_DEPRECATED; + +void pthreadpool_compute_1d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_1d_tiled_t function, + void* argument, + size_t range, + size_t tile) PTHREADPOOL_DEPRECATED; + +void pthreadpool_compute_2d( + pthreadpool_t threadpool, + pthreadpool_function_2d_t function, + void* argument, + size_t range_i, + size_t range_j) PTHREADPOOL_DEPRECATED; + +void pthreadpool_compute_2d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_2d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j) PTHREADPOOL_DEPRECATED; + +void pthreadpool_compute_3d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_3d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_i, + size_t tile_j, + size_t tile_k) PTHREADPOOL_DEPRECATED; + +void pthreadpool_compute_4d_tiled( + pthreadpool_t threadpool, + pthreadpool_function_4d_tiled_t function, + void* argument, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_i, + size_t tile_j, + size_t tile_k, + size_t tile_l) PTHREADPOOL_DEPRECATED; + +#endif /* PTHREADPOOL_NO_DEPRECATED_API */ + +#ifdef __cplusplus +} /* extern "C" */ +#endif + +#ifdef __cplusplus + +namespace libpthreadpool { +namespace detail { +namespace { + +template +void call_wrapper_1d(void* arg, size_t i) { + (*static_cast(arg))(i); +} + +template +void call_wrapper_1d_tile_1d(void* arg, size_t range_i, size_t tile_i) { + (*static_cast(arg))(range_i, tile_i); +} + +template +void call_wrapper_2d(void* functor, size_t i, size_t j) { + (*static_cast(functor))(i, j); +} + +template +void call_wrapper_2d_tile_1d(void* functor, + size_t i, size_t range_j, size_t tile_j) +{ + (*static_cast(functor))(i, range_j, tile_j); +} + +template +void call_wrapper_2d_tile_2d(void* functor, + size_t range_i, size_t range_j, + size_t tile_i, size_t tile_j) +{ + (*static_cast(functor))(range_i, range_j, tile_i, tile_j); +} + +template +void call_wrapper_3d(void* functor, size_t i, size_t j, size_t k) { + (*static_cast(functor))(i, j, k); +} + +template +void call_wrapper_3d_tile_1d(void* functor, + size_t i, size_t j, size_t range_k, + size_t tile_k) +{ + (*static_cast(functor))(i, j, range_k, tile_k); +} + +template +void call_wrapper_3d_tile_2d(void* functor, + size_t i, size_t range_j, size_t range_k, + size_t tile_j, size_t tile_k) +{ + (*static_cast(functor))(i, range_j, range_k, tile_j, tile_k); +} + +template +void call_wrapper_4d(void* functor, size_t i, size_t j, size_t k, size_t l) { + (*static_cast(functor))(i, j, k, l); +} + +template +void call_wrapper_4d_tile_1d(void* functor, + size_t i, size_t j, size_t k, size_t range_l, + size_t tile_l) +{ + (*static_cast(functor))(i, j, k, range_l, tile_l); +} + +template +void call_wrapper_4d_tile_2d(void* functor, + size_t i, size_t j, size_t range_k, size_t range_l, + size_t tile_k, size_t tile_l) +{ + (*static_cast(functor))(i, j, range_k, range_l, tile_k, tile_l); +} + +template +void call_wrapper_5d(void* functor, size_t i, size_t j, size_t k, size_t l, size_t m) { + (*static_cast(functor))(i, j, k, l, m); +} + +template +void call_wrapper_5d_tile_1d(void* functor, + size_t i, size_t j, size_t k, size_t l, size_t range_m, + size_t tile_m) +{ + (*static_cast(functor))(i, j, k, l, range_m, tile_m); +} + +template +void call_wrapper_5d_tile_2d(void* functor, + size_t i, size_t j, size_t k, size_t range_l, size_t range_m, + size_t tile_l, size_t tile_m) +{ + (*static_cast(functor))(i, j, k, range_l, range_m, tile_l, tile_m); +} + +template +void call_wrapper_6d(void* functor, size_t i, size_t j, size_t k, size_t l, size_t m, size_t n) { + (*static_cast(functor))(i, j, k, l, m, n); +} + +template +void call_wrapper_6d_tile_1d(void* functor, + size_t i, size_t j, size_t k, size_t l, size_t m, size_t range_n, + size_t tile_n) +{ + (*static_cast(functor))(i, j, k, l, m, range_n, tile_n); +} + +template +void call_wrapper_6d_tile_2d(void* functor, + size_t i, size_t j, size_t k, size_t l, size_t range_m, size_t range_n, + size_t tile_m, size_t tile_n) +{ + (*static_cast(functor))(i, j, k, l, range_m, range_n, tile_m, tile_n); +} + +} /* namespace */ +} /* namespace detail */ +} /* namespace libpthreadpool */ + +/** + * Process items on a 1D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range; i++) + * functor(i); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each item. + * @param range the number of items on the 1D grid to process. The + * specified functor will be called once for each item. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_1d( + pthreadpool_t threadpool, + const T& functor, + size_t range, + uint32_t flags = 0) +{ + pthreadpool_parallelize_1d( + threadpool, + &libpthreadpool::detail::call_wrapper_1d, + const_cast(static_cast(&functor)), + range, + flags); +} + +/** + * Process items on a 1D grid with specified maximum tile size. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range; i += tile) + * functor(i, min(range - i, tile)); + * + * When the call returns, all items have been processed and the thread pool is + * ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, + * the calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range the number of items on the 1D grid to process. + * @param tile the maximum number of items on the 1D grid to process in + * one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_1d_tile_1d( + pthreadpool_t threadpool, + const T& functor, + size_t range, + size_t tile, + uint32_t flags = 0) +{ + pthreadpool_parallelize_1d_tile_1d( + threadpool, + &libpthreadpool::detail::call_wrapper_1d_tile_1d, + const_cast(static_cast(&functor)), + range, + tile, + flags); +} + +/** + * Process items on a 2D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * functor(i, j); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each item. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_2d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + uint32_t flags = 0) +{ + pthreadpool_parallelize_2d( + threadpool, + &libpthreadpool::detail::call_wrapper_2d, + const_cast(static_cast(&functor)), + range_i, + range_j, + flags); +} + +/** + * Process items on a 2D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j += tile_j) + * functor(i, j, min(range_j - j, tile_j)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param tile_j the maximum number of items along the second dimension of + * the 2D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_2d_tile_1d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t tile_j, + uint32_t flags = 0) +{ + pthreadpool_parallelize_2d_tile_1d( + threadpool, + &libpthreadpool::detail::call_wrapper_2d_tile_1d, + const_cast(static_cast(&functor)), + range_i, + range_j, + tile_j, + flags); +} + +/** + * Process items on a 2D grid with the specified maximum tile size along each + * grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i += tile_i) + * for (size_t j = 0; j < range_j; j += tile_j) + * functor(i, j, + * min(range_i - i, tile_i), min(range_j - j, tile_j)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 2D grid. + * @param range_j the number of items to process along the second dimension + * of the 2D grid. + * @param tile_j the maximum number of items along the first dimension of + * the 2D grid to process in one functor call. + * @param tile_j the maximum number of items along the second dimension of + * the 2D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_2d_tile_2d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t tile_i, + size_t tile_j, + uint32_t flags = 0) +{ + pthreadpool_parallelize_2d_tile_2d( + threadpool, + &libpthreadpool::detail::call_wrapper_2d_tile_2d, + const_cast(static_cast(&functor)), + range_i, + range_j, + tile_i, + tile_j, + flags); +} + +/** + * Process items on a 3D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * functor(i, j, k); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 3D grid. + * @param range_j the number of items to process along the second dimension + * of the 3D grid. + * @param range_k the number of items to process along the third dimension + * of the 3D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_3d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + uint32_t flags = 0) +{ + pthreadpool_parallelize_3d( + threadpool, + &libpthreadpool::detail::call_wrapper_3d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + flags); +} + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * functor(i, j, k, min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 3D grid. + * @param range_j the number of items to process along the second dimension + * of the 3D grid. + * @param range_k the number of items to process along the third dimension + * of the 3D grid. + * @param tile_k the maximum number of items along the third dimension of + * the 3D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_3d_tile_1d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_k, + uint32_t flags = 0) +{ + pthreadpool_parallelize_3d_tile_1d( + threadpool, + &libpthreadpool::detail::call_wrapper_3d_tile_1d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + tile_k, + flags); +} + +/** + * Process items on a 3D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j += tile_j) + * for (size_t k = 0; k < range_k; k += tile_k) + * functor(i, j, k, + * min(range_j - j, tile_j), min(range_k - k, tile_k)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 3D grid. + * @param range_j the number of items to process along the second dimension + * of the 3D grid. + * @param range_k the number of items to process along the third dimension + * of the 3D grid. + * @param tile_j the maximum number of items along the second dimension of + * the 3D grid to process in one functor call. + * @param tile_k the maximum number of items along the third dimension of + * the 3D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_3d_tile_2d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t tile_j, + size_t tile_k, + uint32_t flags = 0) +{ + pthreadpool_parallelize_3d_tile_2d( + threadpool, + &libpthreadpool::detail::call_wrapper_3d_tile_2d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + tile_j, + tile_k, + flags); +} + +/** + * Process items on a 4D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * functor(i, j, k, l); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 4D grid. + * @param range_j the number of items to process along the second dimension + * of the 4D grid. + * @param range_k the number of items to process along the third dimension + * of the 4D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 4D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_4d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + uint32_t flags = 0) +{ + pthreadpool_parallelize_4d( + threadpool, + &libpthreadpool::detail::call_wrapper_4d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + flags); +} + +/** + * Process items on a 4D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l += tile_l) + * functor(i, j, k, l, min(range_l - l, tile_l)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 4D grid. + * @param range_j the number of items to process along the second dimension + * of the 4D grid. + * @param range_k the number of items to process along the third dimension + * of the 4D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 4D grid. + * @param tile_l the maximum number of items along the fourth dimension of + * the 4D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_4d_tile_1d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_l, + uint32_t flags = 0) +{ + pthreadpool_parallelize_4d_tile_1d( + threadpool, + &libpthreadpool::detail::call_wrapper_4d_tile_1d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + tile_l, + flags); +} + +/** + * Process items on a 4D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k += tile_k) + * for (size_t l = 0; l < range_l; l += tile_l) + * functor(i, j, k, l, + * min(range_k - k, tile_k), min(range_l - l, tile_l)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 4D grid. + * @param range_j the number of items to process along the second dimension + * of the 4D grid. + * @param range_k the number of items to process along the third dimension + * of the 4D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 4D grid. + * @param tile_k the maximum number of items along the third dimension of + * the 4D grid to process in one functor call. + * @param tile_l the maximum number of items along the fourth dimension of + * the 4D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_4d_tile_2d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t tile_k, + size_t tile_l, + uint32_t flags = 0) +{ + pthreadpool_parallelize_4d_tile_2d( + threadpool, + &libpthreadpool::detail::call_wrapper_4d_tile_2d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + tile_k, + tile_l, + flags); +} + +/** + * Process items on a 5D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m++) + * functor(i, j, k, l, m); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 5D grid. + * @param range_j the number of items to process along the second dimension + * of the 5D grid. + * @param range_k the number of items to process along the third dimension + * of the 5D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 5D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 5D grid. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_5d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + uint32_t flags = 0) +{ + pthreadpool_parallelize_5d( + threadpool, + &libpthreadpool::detail::call_wrapper_5d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + range_m, + flags); +} + +/** + * Process items on a 5D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m += tile_m) + * functor(i, j, k, l, m, min(range_m - m, tile_m)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 5D grid. + * @param range_j the number of items to process along the second dimension + * of the 5D grid. + * @param range_k the number of items to process along the third dimension + * of the 5D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 5D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 5D grid. + * @param tile_m the maximum number of items along the fifth dimension of + * the 5D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_5d_tile_1d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t tile_m, + uint32_t flags = 0) +{ + pthreadpool_parallelize_5d_tile_1d( + threadpool, + &libpthreadpool::detail::call_wrapper_5d_tile_1d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + range_m, + tile_m, + flags); +} + +/** + * Process items on a 5D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l += tile_l) + * for (size_t m = 0; m < range_m; m += tile_m) + * functor(i, j, k, l, m, + * min(range_l - l, tile_l), min(range_m - m, tile_m)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 5D grid. + * @param range_j the number of items to process along the second dimension + * of the 5D grid. + * @param range_k the number of items to process along the third dimension + * of the 5D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 5D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 5D grid. + * @param tile_l the maximum number of items along the fourth dimension of + * the 5D grid to process in one functor call. + * @param tile_m the maximum number of items along the fifth dimension of + * the 5D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_5d_tile_2d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t tile_l, + size_t tile_m, + uint32_t flags = 0) +{ + pthreadpool_parallelize_5d_tile_2d( + threadpool, + &libpthreadpool::detail::call_wrapper_5d_tile_2d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + range_m, + tile_l, + tile_m, + flags); +} + +/** + * Process items on a 6D grid. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m++) + * for (size_t n = 0; n < range_n; n++) + * functor(i, j, k, l, m, n); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 6D grid. + * @param range_j the number of items to process along the second dimension + * of the 6D grid. + * @param range_k the number of items to process along the third dimension + * of the 6D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 6D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 6D grid. + * @param range_n the number of items to process along the sixth dimension + * of the 6D grid. + * @param tile_n the maximum number of items along the sixth dimension of + * the 6D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_6d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t range_n, + uint32_t flags = 0) +{ + pthreadpool_parallelize_6d( + threadpool, + &libpthreadpool::detail::call_wrapper_6d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + range_m, + range_n, + flags); +} + +/** + * Process items on a 6D grid with the specified maximum tile size along the + * last grid dimension. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m++) + * for (size_t n = 0; n < range_n; n += tile_n) + * functor(i, j, k, l, m, n, min(range_n - n, tile_n)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 6D grid. + * @param range_j the number of items to process along the second dimension + * of the 6D grid. + * @param range_k the number of items to process along the third dimension + * of the 6D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 6D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 6D grid. + * @param range_n the number of items to process along the sixth dimension + * of the 6D grid. + * @param tile_n the maximum number of items along the sixth dimension of + * the 6D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_6d_tile_1d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t range_n, + size_t tile_n, + uint32_t flags = 0) +{ + pthreadpool_parallelize_6d_tile_1d( + threadpool, + &libpthreadpool::detail::call_wrapper_6d_tile_1d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + range_m, + range_n, + tile_n, + flags); +} + +/** + * Process items on a 6D grid with the specified maximum tile size along the + * last two grid dimensions. + * + * The function implements a parallel version of the following snippet: + * + * for (size_t i = 0; i < range_i; i++) + * for (size_t j = 0; j < range_j; j++) + * for (size_t k = 0; k < range_k; k++) + * for (size_t l = 0; l < range_l; l++) + * for (size_t m = 0; m < range_m; m += tile_m) + * for (size_t n = 0; n < range_n; n += tile_n) + * functor(i, j, k, l, m, n, + * min(range_m - m, tile_m), min(range_n - n, tile_n)); + * + * When the function returns, all items have been processed and the thread pool + * is ready for a new task. + * + * @note If multiple threads call this function with the same thread pool, the + * calls are serialized. + * + * @param threadpool the thread pool to use for parallelisation. If threadpool + * is NULL, all items are processed serially on the calling thread. + * @param functor the functor to call for each tile. + * @param range_i the number of items to process along the first dimension + * of the 6D grid. + * @param range_j the number of items to process along the second dimension + * of the 6D grid. + * @param range_k the number of items to process along the third dimension + * of the 6D grid. + * @param range_l the number of items to process along the fourth dimension + * of the 6D grid. + * @param range_m the number of items to process along the fifth dimension + * of the 6D grid. + * @param range_n the number of items to process along the sixth dimension + * of the 6D grid. + * @param tile_m the maximum number of items along the fifth dimension of + * the 6D grid to process in one functor call. + * @param tile_n the maximum number of items along the sixth dimension of + * the 6D grid to process in one functor call. + * @param flags a bitwise combination of zero or more optional flags + * (PTHREADPOOL_FLAG_DISABLE_DENORMALS or PTHREADPOOL_FLAG_YIELD_WORKERS) + */ +template +inline void pthreadpool_parallelize_6d_tile_2d( + pthreadpool_t threadpool, + const T& functor, + size_t range_i, + size_t range_j, + size_t range_k, + size_t range_l, + size_t range_m, + size_t range_n, + size_t tile_m, + size_t tile_n, + uint32_t flags = 0) +{ + pthreadpool_parallelize_6d_tile_2d( + threadpool, + &libpthreadpool::detail::call_wrapper_6d_tile_2d, + const_cast(static_cast(&functor)), + range_i, + range_j, + range_k, + range_l, + range_m, + range_n, + tile_m, + tile_n, + flags); +} + +#endif /* __cplusplus */ + +#endif /* PTHREADPOOL_H_ */ diff --git a/phivenv/Lib/site-packages/torch/include/sleef.h b/phivenv/Lib/site-packages/torch/include/sleef.h new file mode 100644 index 0000000000000000000000000000000000000000..2321a33cf13a390bd5171767ba46752e1c26143d --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/sleef.h @@ -0,0 +1,4216 @@ +// Copyright Naoki Shibata and contributors 2010 - 2024. +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#ifndef __SLEEF_H__ +#define __SLEEF_H__ + +#define SLEEF_VERSION_MAJOR 3 +#define SLEEF_VERSION_MINOR 8 +#define SLEEF_VERSION_PATCHLEVEL 0 + +#include +#include + +#if defined (__GNUC__) || defined (__clang__) || defined(__INTEL_COMPILER) +#define SLEEF_CONST __attribute__((const)) +#define SLEEF_INLINE __attribute__((always_inline)) +#elif defined(_MSC_VER) +#define SLEEF_CONST +#define SLEEF_INLINE __forceinline +#endif + +#if defined(__AVX2__) || defined(__aarch64__) || defined(__arm__) || defined(__powerpc64__) || defined(__zarch__) +#ifndef FP_FAST_FMA +#define FP_FAST_FMA +#endif +#ifndef FP_FAST_FMAF +#define FP_FAST_FMAF +#endif +#endif + +#if defined(_MSC_VER) && !defined(__STDC__) +#define __STDC__ 1 +#endif + +#if (defined(__MINGW32__) || defined(__MINGW64__) || defined(__CYGWIN__) || defined(_MSC_VER)) && !defined(SLEEF_STATIC_LIBS) +#ifdef SLEEF_IMPORT_IS_EXPORT +#define SLEEF_IMPORT __declspec(dllexport) +#else // #ifdef SLEEF_IMPORT_IS_EXPORT +#define SLEEF_IMPORT __declspec(dllimport) +#if (defined(_MSC_VER)) +#pragma comment(lib,"sleef.lib") +#endif // #if (defined(_MSC_VER)) +#endif // #ifdef SLEEF_IMPORT_IS_EXPORT +#else // #if (defined(__MINGW32__) || defined(__MINGW64__) || defined(__CYGWIN__) || defined(_MSC_VER)) && !defined(SLEEF_STATIC_LIBS) +#define SLEEF_IMPORT +#endif // #if (defined(__MINGW32__) || defined(__MINGW64__) || defined(__CYGWIN__) || defined(_MSC_VER)) && !defined(SLEEF_STATIC_LIBS) + +#if (defined(__GNUC__) || defined(__CLANG__)) && (defined(__i386__) || defined(__x86_64__)) +#include +#endif + +#if (defined(_MSC_VER)) +#include +#endif + +#if defined(__ARM_NEON__) || defined(__ARM_NEON) +#include +#endif + +#if defined(__ARM_FEATURE_SVE) +#include +#endif + +#if defined(__VSX__) && defined(__PPC64__) && defined(__LITTLE_ENDIAN__) +#include +typedef __vector double SLEEF_VECTOR_DOUBLE; +typedef __vector float SLEEF_VECTOR_FLOAT; +typedef __vector int SLEEF_VECTOR_INT; +typedef __vector unsigned int SLEEF_VECTOR_UINT; +typedef __vector long long SLEEF_VECTOR_LONGLONG; +typedef __vector unsigned long long SLEEF_VECTOR_ULONGLONG; +#endif + +#if defined(__VX__) && defined(__VEC__) +#ifndef SLEEF_VECINTRIN_H_INCLUDED +#include +#define SLEEF_VECINTRIN_H_INCLUDED +#endif +typedef __vector double SLEEF_VECTOR_DOUBLE; +typedef __vector float SLEEF_VECTOR_FLOAT; +typedef __vector int SLEEF_VECTOR_INT; +typedef __vector unsigned int SLEEF_VECTOR_UINT; +typedef __vector long long SLEEF_VECTOR_LONGLONG; +typedef __vector unsigned long long SLEEF_VECTOR_ULONGLONG; +#endif + +// + +#if defined(SLEEF_ENABLE_OMP_SIMD) && (defined(__GNUC__) || defined(__CLANG__)) && !defined(__INTEL_COMPILER) +#if defined(__aarch64__) +//#define SLEEF_PRAGMA_OMP_SIMD_DP _Pragma ("omp declare simd simdlen(2) notinbranch") +//#define SLEEF_PRAGMA_OMP_SIMD_SP _Pragma ("omp declare simd simdlen(4) notinbranch") +//#elif defined(__x86_64__) && defined(__AVX512F__) +//#define SLEEF_PRAGMA_OMP_SIMD_DP _Pragma ("omp declare simd simdlen(8) notinbranch") +//#define SLEEF_PRAGMA_OMP_SIMD_SP _Pragma ("omp declare simd simdlen(16) notinbranch") +#elif defined(__x86_64__) && defined(__AVX__) +#define SLEEF_PRAGMA_OMP_SIMD_DP _Pragma ("omp declare simd simdlen(4) notinbranch") +#define SLEEF_PRAGMA_OMP_SIMD_SP _Pragma ("omp declare simd simdlen(8) notinbranch") +#elif defined(__x86_64__) && defined(__SSE2__) +#define SLEEF_PRAGMA_OMP_SIMD_DP _Pragma ("omp declare simd simdlen(2) notinbranch") +#define SLEEF_PRAGMA_OMP_SIMD_SP _Pragma ("omp declare simd simdlen(4) notinbranch") +#endif +#endif + +#ifndef SLEEF_PRAGMA_OMP_SIMD_DP +#define SLEEF_PRAGMA_OMP_SIMD_DP +#define SLEEF_PRAGMA_OMP_SIMD_SP +#endif + +// + +#ifndef SLEEF_FP_ILOGB0 +#define SLEEF_FP_ILOGB0 ((int)0x80000000) +#endif + +#ifndef SLEEF_FP_ILOGBNAN +#define SLEEF_FP_ILOGBNAN ((int)2147483647) +#endif + +// + +SLEEF_IMPORT void *Sleef_malloc(size_t z); +SLEEF_IMPORT void Sleef_free(void *ptr); +SLEEF_IMPORT uint64_t Sleef_currentTimeMicros(); + +#if defined(__i386__) || defined(__x86_64__) || defined(_MSC_VER) +SLEEF_IMPORT void Sleef_x86CpuID(int32_t out[4], uint32_t eax, uint32_t ecx); +#endif + +// + +#if defined(__riscv_v) +#include +typedef vfloat64m2_t Sleef_vfloat64m1_t_2; +typedef vfloat32m2_t Sleef_vfloat32m1_t_2; +typedef vfloat64m4_t Sleef_vfloat64m2_t_2; +typedef vfloat32m4_t Sleef_vfloat32m2_t_2; +#define Sleef_vfloat64m1_t_2_DEFINED +#define Sleef_vfloat32m1_t_2_DEFINED +#define Sleef_vfloat64m2_t_2_DEFINED +#define Sleef_vfloat32m2_t_2_DEFINED +#endif + +#ifndef Sleef_double2_DEFINED +#define Sleef_double2_DEFINED +typedef struct { + double x, y; +} Sleef_double2; +#endif + +#ifndef Sleef_float2_DEFINED +#define Sleef_float2_DEFINED +typedef struct { + float x, y; +} Sleef_float2; +#endif + +#ifndef Sleef_longdouble2_DEFINED +#define Sleef_longdouble2_DEFINED +typedef struct { + long double x, y; +} Sleef_longdouble2; +#endif + +#if (defined(__SIZEOF_FLOAT128__) && __SIZEOF_FLOAT128__ == 16) || (defined(__linux__) && defined(__GNUC__) && (defined(__i386__) || defined(__x86_64__))) || (defined(__PPC64__) && defined(__GNUC__) && !defined(__clang__) && __GNUC__ >= 8) +#define SLEEF_FLOAT128_IS_IEEEQP +#endif + +#if !defined(SLEEF_FLOAT128_IS_IEEEQP) && defined(__SIZEOF_LONG_DOUBLE__) && __SIZEOF_LONG_DOUBLE__ == 16 && (defined(__aarch64__) || defined(__zarch__)) +#define SLEEF_LONGDOUBLE_IS_IEEEQP +#endif + +#if !defined(Sleef_quad_DEFINED) +#define Sleef_quad_DEFINED +typedef struct { uint64_t x, y; } Sleef_uint64_2t; +#if defined(SLEEF_FLOAT128_IS_IEEEQP) || defined(ENABLEFLOAT128) +typedef __float128 Sleef_quad; +#define SLEEF_QUAD_C(x) (x ## Q) +#elif defined(SLEEF_LONGDOUBLE_IS_IEEEQP) +typedef long double Sleef_quad; +#define SLEEF_QUAD_C(x) (x ## L) +#else +typedef Sleef_uint64_2t Sleef_quad; +#endif +#endif +#if !defined(Sleef_quad2_DEFINED) +#define Sleef_quad2_DEFINED +typedef union { + struct { + Sleef_quad x, y; + }; + Sleef_quad s[2]; +} Sleef_quad2; +#endif + +#ifdef __cplusplus +extern "C" +{ +#endif + +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sin_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cos_u35(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double2 Sleef_sincos_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tan_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_asin_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_acos_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atan_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atan2_u35(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cbrt_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sin_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cos_u10(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double2 Sleef_sincos_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tan_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_asin_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_acos_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atan_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atan2_u10(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cbrt_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_pow_u10(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sinh_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cosh_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tanh_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sinh_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cosh_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tanh_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_asinh_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_acosh_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atanh_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp2_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp10_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp2_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp10_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_expm1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log10_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log2_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log2_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log1p_u10(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double2 Sleef_sincospi_u05(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double2 Sleef_sincospi_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sinpi_u05(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cospi_u05(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_ldexp(double, int); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST int Sleef_ilogb(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fma(double, double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sqrt(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sqrt_u05(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sqrt_u35(double); + +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_hypot_u05(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_hypot_u35(double, double); + +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fabs(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_copysign(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fmax(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fmin(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fdim(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_trunc(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_floor(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_ceil(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_round(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_rint(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_nextafter(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_frfrexp(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST int Sleef_expfrexp(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fmod(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_remainder(double, double); +SLEEF_IMPORT SLEEF_CONST Sleef_double2 Sleef_modf(double); + +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_lgamma_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tgamma_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_erf_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_erfc_u15(double); + +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cosf_u35(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float2 Sleef_sincosf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_asinf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_acosf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atanf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f_u35(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_logf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cosf_u10(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float2 Sleef_sincosf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fastsinf_u3500(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fastcosf_u3500(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_asinf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_acosf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atanf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f_u10(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_logf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_expf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_powf_u10(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fastpowf_u3500(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_coshf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_coshf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_asinhf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_acoshf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atanhf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_expm1f_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log10f_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log2f_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log2f_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log1pf_u10(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float2 Sleef_sincospif_u05(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float2 Sleef_sincospif_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinpif_u05(float d); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cospif_u05(float d); +SLEEF_IMPORT SLEEF_CONST float Sleef_ldexpf(float, int); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST int Sleef_ilogbf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fmaf(float, float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf_u05(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf_u35(float); + +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf_u05(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf_u35(float, float); + +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fabsf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_copysignf(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fmaxf(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fminf(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fdimf(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_truncf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_floorf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_ceilf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_roundf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_rintf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_nextafterf(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_frfrexpf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST int Sleef_expfrexpf(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fmodf(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_remainderf(float, float); +SLEEF_IMPORT SLEEF_CONST Sleef_float2 Sleef_modff(float); + +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_lgammaf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tgammaf_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_erff_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_erfcf_u15(float); + +SLEEF_IMPORT SLEEF_CONST Sleef_longdouble2 Sleef_sincospil_u05(long double); +SLEEF_IMPORT SLEEF_CONST Sleef_longdouble2 Sleef_sincospil_u35(long double); + +#if defined(Sleef_quad2_DEFINED) +SLEEF_IMPORT SLEEF_CONST Sleef_quad2 Sleef_sincospiq_u05(Sleef_quad); +SLEEF_IMPORT SLEEF_CONST Sleef_quad2 Sleef_sincospiq_u35(Sleef_quad); +#endif +#ifdef __SSE2__ + +#ifndef Sleef___m128d_2_DEFINED +typedef struct { + __m128d x, y; +} Sleef___m128d_2; +#define Sleef___m128d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u35(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u10(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_powd2_u10(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastsind2_u3500(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastcosd2_u3500(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastpowd2_u3500(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asinhd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acoshd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atanhd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expm1d2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log10d2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log1pd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u05(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinpid2_u05(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cospid2_u05(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ldexpd2(__m128d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmad2(__m128d, __m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u05(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u35(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u05(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u35(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fabsd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_copysignd2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmaxd2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmind2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fdimd2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_truncd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_floord2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ceild2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_roundd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_rintd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_nextafterd2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_frfrexpd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmodd2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_remainderd2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_modfd2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_lgammad2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tgammad2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfd2_u10(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfcd2_u15(__m128d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd2(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd2(int); + +#ifndef Sleef___m128_2_DEFINED +typedef struct { + __m128 x, y; +} Sleef___m128_2; +#define Sleef___m128_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u35(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u10(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_powf4_u10(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastsinf4_u3500(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastcosf4_u3500(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastpowf4_u3500(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinhf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acoshf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanhf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expm1f4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log10f4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log1pf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u05(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinpif4_u05(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cospif4_u05(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaf4(__m128, __m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u05(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u35(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u05(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u35(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fabsf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_copysignf4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaxf4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fminf4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fdimf4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_truncf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_floorf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_ceilf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_roundf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_rintf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_nextafterf4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_frfrexpf4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmodf4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_remainderf4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_modff4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_lgammaf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tgammaf4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erff4_u10(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erfcf4_u15(__m128); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf4(int); +#endif +#ifdef __SSE2__ + +#ifndef Sleef___m128d_2_DEFINED +typedef struct { + __m128d x, y; +} Sleef___m128d_2; +#define Sleef___m128d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sind2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cosd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincosd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tand2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_asind2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_acosd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atand2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u35sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atan2d2_u35sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_logd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cbrtd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sind2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cosd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincosd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tand2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_asind2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_acosd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atand2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u10sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atan2d2_u10sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_logd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cbrtd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_expd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_powd2_u10sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_powd2_u10sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sinhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_coshd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tanhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sinhd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_coshd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tanhd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastsind2_u3500sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fastsind2_u3500sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastcosd2_u3500sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fastcosd2_u3500sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastpowd2_u3500sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fastpowd2_u3500sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asinhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_asinhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acoshd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_acoshd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atanhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atanhd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp2d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp2d2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp10d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp10d2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expm1d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_expm1d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log10d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log10d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log2d2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log2d2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log1pd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log1pd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincospid2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincospid2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinpid2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sinpid2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cospid2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cospid2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ldexpd2_sse2(__m128d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_ldexpd2_sse2(__m128d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_cinz_ilogbd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmad2_sse2(__m128d, __m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmad2_sse2(__m128d, __m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sqrtd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sqrtd2_u05sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sqrtd2_u35sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u05sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_hypotd2_u05sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u35sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_hypotd2_u35sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fabsd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fabsd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_copysignd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_copysignd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmaxd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmaxd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmind2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmind2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fdimd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fdimd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_truncd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_truncd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_floord2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_floord2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ceild2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_ceild2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_roundd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_roundd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_rintd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_rintd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_nextafterd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_nextafterd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_frfrexpd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_frfrexpd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_cinz_expfrexpd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmodd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmodd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_remainderd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_remainderd2_sse2(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_modfd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_modfd2_sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_lgammad2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_lgammad2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tgammad2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tgammad2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_erfd2_u10sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfcd2_u15sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_erfcd2_u15sse2(__m128d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd2_sse2(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd2_sse2(int); + +#ifndef Sleef___m128_2_DEFINED +typedef struct { + __m128 x, y; +} Sleef___m128_2; +#define Sleef___m128_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cosf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincosf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_asinf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_acosf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atanf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u35sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atan2f4_u35sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_logf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cbrtf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cosf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincosf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_asinf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_acosf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atanf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u10sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atan2f4_u10sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_logf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cbrtf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_expf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_powf4_u10sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_powf4_u10sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_coshf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinhf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_coshf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanhf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastsinf4_u3500sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fastsinf4_u3500sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastcosf4_u3500sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fastcosf4_u3500sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastpowf4_u3500sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fastpowf4_u3500sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_asinhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acoshf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_acoshf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atanhf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp2f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp2f4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp10f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp10f4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expm1f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_expm1f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log10f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log10f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log2f4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log2f4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log1pf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log1pf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincospif4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincospif4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinpif4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinpif4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cospif4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cospif4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaf4_sse2(__m128, __m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fmaf4_sse2(__m128, __m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sqrtf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sqrtf4_u05sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sqrtf4_u35sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u05sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_hypotf4_u05sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u35sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_hypotf4_u35sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fabsf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fabsf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_copysignf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_copysignf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaxf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fmaxf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fminf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fminf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fdimf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fdimf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_truncf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_truncf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_floorf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_floorf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_ceilf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_ceilf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_roundf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_roundf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_rintf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_rintf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_nextafterf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_nextafterf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_frfrexpf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_frfrexpf4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmodf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fmodf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_remainderf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_remainderf4_sse2(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_modff4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_modff4_sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_lgammaf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_lgammaf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tgammaf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tgammaf4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erff4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_erff4_u10sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erfcf4_u15sse2(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_erfcf4_u15sse2(__m128); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf4_sse2(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_cinz_getIntf4_sse2(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf4_sse2(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_cinz_getPtrf4_sse2(int); +#endif +#ifdef __SSE2__ + +#ifndef Sleef___m128d_2_DEFINED +typedef struct { + __m128d x, y; +} Sleef___m128d_2; +#define Sleef___m128d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sind2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cosd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincosd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tand2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_asind2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_acosd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atand2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u35sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atan2d2_u35sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_logd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cbrtd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sind2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cosd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincosd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tand2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_asind2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_acosd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atand2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u10sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atan2d2_u10sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_logd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cbrtd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_expd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_powd2_u10sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_powd2_u10sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sinhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_coshd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tanhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sinhd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_coshd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tanhd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastsind2_u3500sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fastsind2_u3500sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastcosd2_u3500sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fastcosd2_u3500sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastpowd2_u3500sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fastpowd2_u3500sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asinhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_asinhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acoshd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_acoshd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atanhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_atanhd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp2d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp2d2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp10d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_exp10d2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expm1d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_expm1d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log10d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log10d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log2d2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log2d2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log1pd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_log1pd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincospid2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_sincospid2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinpid2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sinpid2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cospid2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_cospid2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ldexpd2_sse4(__m128d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_ldexpd2_sse4(__m128d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_cinz_ilogbd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmad2_sse4(__m128d, __m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmad2_sse4(__m128d, __m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sqrtd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sqrtd2_u05sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_sqrtd2_u35sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u05sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_hypotd2_u05sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u35sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_hypotd2_u35sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fabsd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fabsd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_copysignd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_copysignd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmaxd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmaxd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmind2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmind2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fdimd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fdimd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_truncd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_truncd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_floord2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_floord2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ceild2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_ceild2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_roundd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_roundd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_rintd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_rintd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_nextafterd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_nextafterd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_frfrexpd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_frfrexpd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_cinz_expfrexpd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmodd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_fmodd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_remainderd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_remainderd2_sse4(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_modfd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_cinz_modfd2_sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_lgammad2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_lgammad2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tgammad2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_tgammad2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_erfd2_u10sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfcd2_u15sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cinz_erfcd2_u15sse4(__m128d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd2_sse4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd2_sse4(int); + +#ifndef Sleef___m128_2_DEFINED +typedef struct { + __m128 x, y; +} Sleef___m128_2; +#define Sleef___m128_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cosf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincosf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_asinf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_acosf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atanf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u35sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atan2f4_u35sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_logf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cbrtf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cosf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincosf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_asinf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_acosf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atanf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u10sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atan2f4_u10sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_logf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cbrtf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_expf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_powf4_u10sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_powf4_u10sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_coshf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinhf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_coshf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tanhf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastsinf4_u3500sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fastsinf4_u3500sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastcosf4_u3500sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fastcosf4_u3500sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastpowf4_u3500sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fastpowf4_u3500sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_asinhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acoshf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_acoshf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_atanhf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp2f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp2f4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp10f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_exp10f4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expm1f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_expm1f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log10f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log10f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log2f4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log2f4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log1pf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_log1pf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincospif4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_sincospif4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinpif4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sinpif4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cospif4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_cospif4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaf4_sse4(__m128, __m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fmaf4_sse4(__m128, __m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sqrtf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sqrtf4_u05sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_sqrtf4_u35sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u05sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_hypotf4_u05sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u35sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_hypotf4_u35sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fabsf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fabsf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_copysignf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_copysignf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaxf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fmaxf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fminf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fminf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fdimf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fdimf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_truncf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_truncf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_floorf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_floorf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_ceilf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_ceilf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_roundf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_roundf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_rintf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_rintf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_nextafterf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_nextafterf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_frfrexpf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_frfrexpf4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmodf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_fmodf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_remainderf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_remainderf4_sse4(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_modff4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_cinz_modff4_sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_lgammaf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_lgammaf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tgammaf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_tgammaf4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erff4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_erff4_u10sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erfcf4_u15sse4(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cinz_erfcf4_u15sse4(__m128); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf4_sse4(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_cinz_getIntf4_sse4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf4_sse4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_cinz_getPtrf4_sse4(int); +#endif +#ifdef __AVX__ + +#ifndef Sleef___m256d_2_DEFINED +typedef struct { + __m256d x, y; +} Sleef___m256d_2; +#define Sleef___m256d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u35(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u10(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_powd4_u10(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastsind4_u3500(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastcosd4_u3500(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastpowd4_u3500(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asinhd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acoshd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atanhd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expm1d4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log10d4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log1pd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u05(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinpid4_u05(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cospid4_u05(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ldexpd4(__m256d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmad4(__m256d, __m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u05(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u35(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u05(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u35(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fabsd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_copysignd4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmaxd4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmind4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fdimd4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_truncd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_floord4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ceild4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_roundd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_rintd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_nextafterd4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_frfrexpd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmodd4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_remainderd4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_modfd4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_lgammad4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tgammad4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfd4_u10(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfcd4_u15(__m256d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd4(int); + +#ifndef Sleef___m256_2_DEFINED +typedef struct { + __m256 x, y; +} Sleef___m256_2; +#define Sleef___m256_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u35(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u10(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_powf8_u10(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastsinf8_u3500(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastcosf8_u3500(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastpowf8_u3500(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinhf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acoshf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanhf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expm1f8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log10f8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log1pf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u05(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinpif8_u05(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cospif8_u05(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaf8(__m256, __m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u05(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u35(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u05(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u35(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fabsf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_copysignf8(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaxf8(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fminf8(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fdimf8(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_truncf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_floorf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_ceilf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_roundf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_rintf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_nextafterf8(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_frfrexpf8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmodf8(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_remainderf8(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_modff8(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_lgammaf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tgammaf8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erff8_u10(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erfcf8_u15(__m256); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf8(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf8(int); +#endif +#ifdef __AVX__ + +#ifndef Sleef___m256d_2_DEFINED +typedef struct { + __m256d x, y; +} Sleef___m256d_2; +#define Sleef___m256d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sind4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_cosd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_cinz_sincosd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_tand4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_asind4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_acosd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_atand4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u35avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_atan2d4_u35avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_logd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_cbrtd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sind4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_cosd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_cinz_sincosd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_tand4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_asind4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_acosd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_atand4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u10avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_atan2d4_u10avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_logd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_cbrtd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_expd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_powd4_u10avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_powd4_u10avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sinhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_coshd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_tanhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sinhd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_coshd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_tanhd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastsind4_u3500avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fastsind4_u3500avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastcosd4_u3500avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fastcosd4_u3500avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastpowd4_u3500avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fastpowd4_u3500avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asinhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_asinhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acoshd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_acoshd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atanhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_atanhd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_exp2d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_exp2d4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_exp10d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_exp10d4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expm1d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_expm1d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log10d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_log10d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_log2d4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_log2d4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log1pd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_log1pd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_cinz_sincospid4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_cinz_sincospid4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinpid4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sinpid4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cospid4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_cospid4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ldexpd4_avx(__m256d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_ldexpd4_avx(__m256d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_cinz_ilogbd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmad4_avx(__m256d, __m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fmad4_avx(__m256d, __m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sqrtd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sqrtd4_u05avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_sqrtd4_u35avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u05avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_hypotd4_u05avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u35avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_hypotd4_u35avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fabsd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fabsd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_copysignd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_copysignd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmaxd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fmaxd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmind4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fmind4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fdimd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fdimd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_truncd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_truncd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_floord4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_floord4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ceild4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_ceild4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_roundd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_roundd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_rintd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_rintd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_nextafterd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_nextafterd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_frfrexpd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_frfrexpd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_cinz_expfrexpd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmodd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_fmodd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_remainderd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_remainderd4_avx(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_modfd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_cinz_modfd4_avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_lgammad4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_lgammad4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tgammad4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_tgammad4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_erfd4_u10avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfcd4_u15avx(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cinz_erfcd4_u15avx(__m256d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd4_avx(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd4_avx(int); + +#ifndef Sleef___m256_2_DEFINED +typedef struct { + __m256 x, y; +} Sleef___m256_2; +#define Sleef___m256_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sinf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_cosf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_cinz_sincosf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_tanf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_asinf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_acosf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_atanf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u35avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_atan2f8_u35avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_logf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_cbrtf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sinf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_cosf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_cinz_sincosf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_tanf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_asinf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_acosf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_atanf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u10avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_atan2f8_u10avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_logf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_cbrtf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_expf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_powf8_u10avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_powf8_u10avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sinhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_coshf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_tanhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sinhf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_coshf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_tanhf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastsinf8_u3500avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fastsinf8_u3500avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastcosf8_u3500avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fastcosf8_u3500avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastpowf8_u3500avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fastpowf8_u3500avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_asinhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acoshf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_acoshf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_atanhf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_exp2f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_exp2f8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_exp10f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_exp10f8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expm1f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_expm1f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log10f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_log10f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_log2f8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_log2f8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log1pf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_log1pf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_cinz_sincospif8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_cinz_sincospif8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinpif8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sinpif8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cospif8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_cospif8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaf8_avx(__m256, __m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fmaf8_avx(__m256, __m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sqrtf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sqrtf8_u05avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_sqrtf8_u35avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u05avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_hypotf8_u05avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u35avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_hypotf8_u35avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fabsf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fabsf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_copysignf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_copysignf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaxf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fmaxf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fminf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fminf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fdimf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fdimf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_truncf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_truncf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_floorf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_floorf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_ceilf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_ceilf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_roundf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_roundf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_rintf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_rintf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_nextafterf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_nextafterf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_frfrexpf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_frfrexpf8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmodf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_fmodf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_remainderf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_remainderf8_avx(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_modff8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_cinz_modff8_avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_lgammaf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_lgammaf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tgammaf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_tgammaf8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erff8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_erff8_u10avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erfcf8_u15avx(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cinz_erfcf8_u15avx(__m256); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf8_avx(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_cinz_getIntf8_avx(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf8_avx(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_cinz_getPtrf8_avx(int); +#endif +#ifdef __AVX__ + +#ifndef Sleef___m256d_2_DEFINED +typedef struct { + __m256d x, y; +} Sleef___m256d_2; +#define Sleef___m256d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sind4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cosd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincosd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tand4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_asind4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_acosd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atand4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u35fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atan2d4_u35fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_logd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cbrtd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sind4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cosd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincosd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tand4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_asind4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_acosd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atand4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u10fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atan2d4_u10fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_logd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cbrtd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_expd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_powd4_u10fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_powd4_u10fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sinhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_coshd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tanhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sinhd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_coshd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tanhd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastsind4_u3500fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fastsind4_u3500fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastcosd4_u3500fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fastcosd4_u3500fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastpowd4_u3500fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fastpowd4_u3500fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asinhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_asinhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acoshd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_acoshd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atanhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atanhd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp2d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp2d4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp10d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp10d4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expm1d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_expm1d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log10d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log10d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log2d4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log2d4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log1pd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log1pd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincospid4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincospid4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinpid4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sinpid4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cospid4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cospid4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ldexpd4_fma4(__m256d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_ldexpd4_fma4(__m256d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_finz_ilogbd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmad4_fma4(__m256d, __m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmad4_fma4(__m256d, __m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sqrtd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sqrtd4_u05fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sqrtd4_u35fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u05fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_hypotd4_u05fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u35fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_hypotd4_u35fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fabsd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fabsd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_copysignd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_copysignd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmaxd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmaxd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmind4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmind4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fdimd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fdimd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_truncd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_truncd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_floord4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_floord4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ceild4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_ceild4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_roundd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_roundd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_rintd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_rintd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_nextafterd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_nextafterd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_frfrexpd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_frfrexpd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_finz_expfrexpd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmodd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmodd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_remainderd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_remainderd4_fma4(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_modfd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_modfd4_fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_lgammad4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_lgammad4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tgammad4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tgammad4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_erfd4_u10fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfcd4_u15fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_erfcd4_u15fma4(__m256d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd4_fma4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd4_fma4(int); + +#ifndef Sleef___m256_2_DEFINED +typedef struct { + __m256 x, y; +} Sleef___m256_2; +#define Sleef___m256_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cosf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincosf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_asinf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_acosf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atanf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u35fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atan2f8_u35fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_logf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cbrtf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cosf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincosf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_asinf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_acosf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atanf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u10fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atan2f8_u10fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_logf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cbrtf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_expf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_powf8_u10fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_powf8_u10fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_coshf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinhf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_coshf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanhf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastsinf8_u3500fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fastsinf8_u3500fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastcosf8_u3500fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fastcosf8_u3500fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastpowf8_u3500fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fastpowf8_u3500fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_asinhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acoshf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_acoshf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atanhf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp2f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp2f8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp10f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp10f8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expm1f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_expm1f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log10f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log10f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log2f8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log2f8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log1pf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log1pf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincospif8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincospif8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinpif8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinpif8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cospif8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cospif8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaf8_fma4(__m256, __m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fmaf8_fma4(__m256, __m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sqrtf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sqrtf8_u05fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sqrtf8_u35fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u05fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_hypotf8_u05fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u35fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_hypotf8_u35fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fabsf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fabsf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_copysignf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_copysignf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaxf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fmaxf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fminf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fminf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fdimf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fdimf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_truncf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_truncf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_floorf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_floorf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_ceilf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_ceilf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_roundf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_roundf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_rintf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_rintf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_nextafterf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_nextafterf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_frfrexpf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_frfrexpf8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmodf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fmodf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_remainderf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_remainderf8_fma4(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_modff8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_modff8_fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_lgammaf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_lgammaf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tgammaf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tgammaf8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erff8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_erff8_u10fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erfcf8_u15fma4(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_erfcf8_u15fma4(__m256); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf8_fma4(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_finz_getIntf8_fma4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf8_fma4(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_finz_getPtrf8_fma4(int); +#endif +#ifdef __AVX__ + +#ifndef Sleef___m256d_2_DEFINED +typedef struct { + __m256d x, y; +} Sleef___m256d_2; +#define Sleef___m256d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sind4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cosd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincosd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tand4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_asind4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_acosd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atand4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u35avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atan2d4_u35avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_logd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cbrtd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sind4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sind4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cosd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cosd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincosd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincosd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tand4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tand4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asind4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_asind4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acosd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_acosd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atand4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atand4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atan2d4_u10avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atan2d4_u10avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_logd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_logd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cbrtd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cbrtd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_expd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_powd4_u10avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_powd4_u10avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sinhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_coshd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tanhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinhd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sinhd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_coshd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_coshd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tanhd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tanhd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastsind4_u3500avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fastsind4_u3500avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastcosd4_u3500avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fastcosd4_u3500avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fastpowd4_u3500avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fastpowd4_u3500avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_asinhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_asinhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_acoshd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_acoshd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_atanhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_atanhd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp2d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp2d4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp2d4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp10d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_exp10d4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_exp10d4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_expm1d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_expm1d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log10d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log10d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log2d4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log2d4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log2d4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_log1pd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_log1pd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincospid4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_sincospid4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_sincospid4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sinpid4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sinpid4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_cospid4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_cospid4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ldexpd4_avx2(__m256d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_ldexpd4_avx2(__m256d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_finz_ilogbd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmad4_avx2(__m256d, __m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmad4_avx2(__m256d, __m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sqrtd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sqrtd4_u05avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_sqrtd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_sqrtd4_u35avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u05avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_hypotd4_u05avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_hypotd4_u35avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_hypotd4_u35avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fabsd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fabsd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_copysignd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_copysignd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmaxd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmaxd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmind4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmind4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fdimd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fdimd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_truncd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_truncd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_floord4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_floord4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_ceild4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_ceild4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_roundd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_roundd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_rintd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_rintd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_nextafterd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_nextafterd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_frfrexpd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_frfrexpd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_finz_expfrexpd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_fmodd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_fmodd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_remainderd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_remainderd4_avx2(__m256d, __m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_modfd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST Sleef___m256d_2 Sleef_finz_modfd4_avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_lgammad4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_lgammad4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_tgammad4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_tgammad4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_erfd4_u10avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_erfcd4_u15avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST __m256d Sleef_finz_erfcd4_u15avx2(__m256d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd4_avx2(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd4_avx2(int); + +#ifndef Sleef___m256_2_DEFINED +typedef struct { + __m256 x, y; +} Sleef___m256_2; +#define Sleef___m256_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cosf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincosf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_asinf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_acosf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atanf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u35avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atan2f8_u35avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_logf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cbrtf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cosf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cosf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincosf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincosf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_asinf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acosf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_acosf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atanf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atan2f8_u10avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atan2f8_u10avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_logf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_logf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cbrtf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cbrtf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_expf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_powf8_u10avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_powf8_u10avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_coshf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinhf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinhf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_coshf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_coshf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tanhf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tanhf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastsinf8_u3500avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fastsinf8_u3500avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastcosf8_u3500avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fastcosf8_u3500avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fastpowf8_u3500avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fastpowf8_u3500avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_asinhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_asinhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_acoshf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_acoshf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_atanhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_atanhf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp2f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp2f8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp2f8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp10f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_exp10f8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_exp10f8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_expm1f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_expm1f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log10f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log10f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log2f8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log2f8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log2f8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_log1pf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_log1pf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincospif8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_sincospif8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_sincospif8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sinpif8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sinpif8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_cospif8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_cospif8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaf8_avx2(__m256, __m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fmaf8_avx2(__m256, __m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sqrtf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sqrtf8_u05avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_sqrtf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_sqrtf8_u35avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u05avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_hypotf8_u05avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_hypotf8_u35avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_hypotf8_u35avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fabsf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fabsf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_copysignf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_copysignf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmaxf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fmaxf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fminf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fminf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fdimf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fdimf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_truncf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_truncf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_floorf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_floorf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_ceilf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_ceilf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_roundf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_roundf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_rintf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_rintf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_nextafterf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_nextafterf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_frfrexpf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_frfrexpf8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_fmodf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_fmodf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_remainderf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_remainderf8_avx2(__m256, __m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_modff8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST Sleef___m256_2 Sleef_finz_modff8_avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_lgammaf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_lgammaf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_tgammaf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_tgammaf8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erff8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_erff8_u10avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_erfcf8_u15avx2(__m256); +SLEEF_IMPORT SLEEF_CONST __m256 Sleef_finz_erfcf8_u15avx2(__m256); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf8_avx2(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_finz_getIntf8_avx2(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf8_avx2(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_finz_getPtrf8_avx2(int); +#endif +#ifdef __SSE2__ + +#ifndef Sleef___m128d_2_DEFINED +typedef struct { + __m128d x, y; +} Sleef___m128d_2; +#define Sleef___m128d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sind2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_cosd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_finz_sincosd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_tand2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_asind2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_acosd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_atand2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u35avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_atan2d2_u35avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_logd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_cbrtd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sind2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sind2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cosd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_cosd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincosd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_finz_sincosd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tand2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_tand2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asind2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_asind2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acosd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_acosd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atand2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_atand2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atan2d2_u10avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_atan2d2_u10avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_logd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_logd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cbrtd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_cbrtd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_expd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_powd2_u10avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_powd2_u10avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sinhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_coshd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_tanhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinhd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sinhd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_coshd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_coshd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tanhd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_tanhd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastsind2_u3500avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fastsind2_u3500avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastcosd2_u3500avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fastcosd2_u3500avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fastpowd2_u3500avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fastpowd2_u3500avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_asinhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_asinhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_acoshd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_acoshd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_atanhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_atanhd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_exp2d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp2d2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_exp2d2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_exp10d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_exp10d2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_exp10d2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_expm1d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_expm1d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log10d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_log10d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_log2d2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log2d2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_log2d2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_log1pd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_log1pd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_finz_sincospid2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_sincospid2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_finz_sincospid2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sinpid2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sinpid2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_cospid2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_cospid2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ldexpd2_avx2128(__m128d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_ldexpd2_avx2128(__m128d, __m128i); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_ilogbd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_finz_ilogbd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmad2_avx2128(__m128d, __m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fmad2_avx2128(__m128d, __m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sqrtd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sqrtd2_u05avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_sqrtd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_sqrtd2_u35avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u05avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_hypotd2_u05avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_hypotd2_u35avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_hypotd2_u35avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fabsd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fabsd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_copysignd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_copysignd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmaxd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fmaxd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmind2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fmind2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fdimd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fdimd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_truncd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_truncd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_floord2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_floord2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_ceild2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_ceild2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_roundd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_roundd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_rintd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_rintd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_nextafterd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_nextafterd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_frfrexpd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_frfrexpd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_expfrexpd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128i Sleef_finz_expfrexpd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_fmodd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_fmodd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_remainderd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_remainderd2_avx2128(__m128d, __m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_modfd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST Sleef___m128d_2 Sleef_finz_modfd2_avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_lgammad2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_lgammad2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_tgammad2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_tgammad2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_erfd2_u10avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_erfcd2_u15avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST __m128d Sleef_finz_erfcd2_u15avx2128(__m128d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd2_avx2128(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd2_avx2128(int); + +#ifndef Sleef___m128_2_DEFINED +typedef struct { + __m128 x, y; +} Sleef___m128_2; +#define Sleef___m128_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sinf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_cosf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_finz_sincosf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_tanf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_asinf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_acosf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_atanf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u35avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_atan2f4_u35avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_logf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_cbrtf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sinf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cosf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_cosf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincosf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_finz_sincosf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_tanf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_asinf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acosf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_acosf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_atanf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atan2f4_u10avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_atan2f4_u10avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_logf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_logf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cbrtf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_cbrtf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_expf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_powf4_u10avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_powf4_u10avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sinhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_coshf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_tanhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinhf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sinhf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_coshf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_coshf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tanhf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_tanhf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastsinf4_u3500avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fastsinf4_u3500avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastcosf4_u3500avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fastcosf4_u3500avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fastpowf4_u3500avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fastpowf4_u3500avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_asinhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_asinhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_acoshf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_acoshf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_atanhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_atanhf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_exp2f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp2f4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_exp2f4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_exp10f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_exp10f4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_exp10f4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_expm1f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_expm1f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log10f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_log10f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_log2f4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log2f4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_log2f4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_log1pf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_log1pf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_finz_sincospif4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_sincospif4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_finz_sincospif4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sinpif4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sinpif4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_cospif4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_cospif4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaf4_avx2128(__m128, __m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fmaf4_avx2128(__m128, __m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sqrtf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sqrtf4_u05avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_sqrtf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_sqrtf4_u35avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u05avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_hypotf4_u05avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_hypotf4_u35avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_hypotf4_u35avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fabsf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fabsf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_copysignf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_copysignf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmaxf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fmaxf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fminf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fminf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fdimf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fdimf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_truncf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_truncf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_floorf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_floorf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_ceilf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_ceilf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_roundf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_roundf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_rintf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_rintf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_nextafterf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_nextafterf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_frfrexpf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_frfrexpf4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_fmodf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_fmodf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_remainderf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_remainderf4_avx2128(__m128, __m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_modff4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST Sleef___m128_2 Sleef_finz_modff4_avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_lgammaf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_lgammaf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_tgammaf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_tgammaf4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erff4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_erff4_u10avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_erfcf4_u15avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST __m128 Sleef_finz_erfcf4_u15avx2128(__m128); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf4_avx2128(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_finz_getIntf4_avx2128(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf4_avx2128(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_finz_getPtrf4_avx2128(int); +#endif +#ifdef __AVX512F__ + +#ifndef Sleef___m512d_2_DEFINED +typedef struct { + __m512d x, y; +} Sleef___m512d_2; +#define Sleef___m512d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sind8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cosd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincosd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tand8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asind8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acosd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atand8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atan2d8_u35(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_logd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cbrtd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sind8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cosd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincosd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tand8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asind8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acosd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atand8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atan2d8_u10(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_logd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cbrtd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_expd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_powd8_u10(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinhd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_coshd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tanhd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinhd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_coshd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tanhd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastsind8_u3500(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastcosd8_u3500(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastpowd8_u3500(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asinhd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acoshd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atanhd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp2d8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp2d8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp10d8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp10d8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_expm1d8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log10d8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log2d8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log2d8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log1pd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincospid8_u05(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincospid8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinpid8_u05(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cospid8_u05(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_ldexpd8(__m512d, __m256i); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_ilogbd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmad8(__m512d, __m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_u05(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_u35(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_hypotd8_u05(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_hypotd8_u35(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fabsd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_copysignd8(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmaxd8(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmind8(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fdimd8(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_truncd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_floord8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_ceild8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_roundd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_rintd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_nextafterd8(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_frfrexpd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_expfrexpd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmodd8(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_remainderd8(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_modfd8(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_lgammad8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tgammad8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_erfd8_u10(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_erfcd8_u15(__m512d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd8(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd8(int); + +#ifndef Sleef___m512_2_DEFINED +typedef struct { + __m512 x, y; +} Sleef___m512_2; +#define Sleef___m512_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cosf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincosf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acosf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atan2f16_u35(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_logf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cbrtf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cosf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincosf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acosf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atan2f16_u10(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_logf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cbrtf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_expf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_powf16_u10(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinhf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_coshf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanhf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinhf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_coshf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanhf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastsinf16_u3500(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastcosf16_u3500(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastpowf16_u3500(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinhf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acoshf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanhf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp2f16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp2f16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp10f16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp10f16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_expm1f16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log10f16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log2f16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log2f16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log1pf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincospif16_u05(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincospif16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinpif16_u05(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cospif16_u05(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmaf16(__m512, __m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_u05(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_u35(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_hypotf16_u05(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_hypotf16_u35(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fabsf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_copysignf16(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmaxf16(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fminf16(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fdimf16(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_truncf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_floorf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_ceilf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_roundf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_rintf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_nextafterf16(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_frfrexpf16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmodf16(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_remainderf16(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_modff16(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_lgammaf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tgammaf16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_erff16_u10(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_erfcf16_u15(__m512); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf16(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf16(int); +#endif +#ifdef __AVX512F__ + +#ifndef Sleef___m512d_2_DEFINED +typedef struct { + __m512d x, y; +} Sleef___m512d_2; +#define Sleef___m512d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sind8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sind8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cosd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_cosd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincosd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_finz_sincosd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tand8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_tand8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asind8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_asind8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acosd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_acosd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atand8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_atand8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atan2d8_u35avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_atan2d8_u35avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_logd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_logd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cbrtd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_cbrtd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sind8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sind8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cosd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_cosd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincosd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_finz_sincosd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tand8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_tand8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asind8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_asind8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acosd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_acosd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atand8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_atand8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atan2d8_u10avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_atan2d8_u10avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_logd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_logd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cbrtd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_cbrtd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_expd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_expd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_powd8_u10avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_powd8_u10avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sinhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_coshd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_coshd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tanhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_tanhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinhd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sinhd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_coshd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_coshd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tanhd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_tanhd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastsind8_u3500avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fastsind8_u3500avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastcosd8_u3500avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fastcosd8_u3500avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastpowd8_u3500avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fastpowd8_u3500avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asinhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_asinhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acoshd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_acoshd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atanhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_atanhd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp2d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_exp2d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp2d8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_exp2d8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp10d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_exp10d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp10d8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_exp10d8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_expm1d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_expm1d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log10d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_log10d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log2d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_log2d8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log2d8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_log2d8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log1pd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_log1pd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincospid8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_finz_sincospid8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincospid8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_finz_sincospid8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinpid8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sinpid8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cospid8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_cospid8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_ldexpd8_avx512f(__m512d, __m256i); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_ldexpd8_avx512f(__m512d, __m256i); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_ilogbd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_finz_ilogbd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmad8_avx512f(__m512d, __m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fmad8_avx512f(__m512d, __m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sqrtd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sqrtd8_u05avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_sqrtd8_u35avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_hypotd8_u05avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_hypotd8_u05avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_hypotd8_u35avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_hypotd8_u35avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fabsd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fabsd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_copysignd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_copysignd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmaxd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fmaxd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmind8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fmind8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fdimd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fdimd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_truncd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_truncd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_floord8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_floord8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_ceild8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_ceild8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_roundd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_roundd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_rintd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_rintd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_nextafterd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_nextafterd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_frfrexpd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_frfrexpd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_expfrexpd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_finz_expfrexpd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmodd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_fmodd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_remainderd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_remainderd8_avx512f(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_modfd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_finz_modfd8_avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_lgammad8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_lgammad8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tgammad8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_tgammad8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_erfd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_erfd8_u10avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_erfcd8_u15avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_finz_erfcd8_u15avx512f(__m512d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd8_avx512f(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd8_avx512f(int); + +#ifndef Sleef___m512_2_DEFINED +typedef struct { + __m512 x, y; +} Sleef___m512_2; +#define Sleef___m512_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sinf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cosf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_cosf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincosf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_finz_sincosf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_tanf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_asinf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acosf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_acosf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_atanf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atan2f16_u35avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_atan2f16_u35avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_logf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_logf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cbrtf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_cbrtf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sinf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cosf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_cosf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincosf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_finz_sincosf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_tanf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_asinf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acosf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_acosf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_atanf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atan2f16_u10avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_atan2f16_u10avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_logf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_logf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cbrtf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_cbrtf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_expf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_expf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_powf16_u10avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_powf16_u10avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sinhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_coshf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_coshf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_tanhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinhf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sinhf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_coshf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_coshf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanhf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_tanhf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastsinf16_u3500avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fastsinf16_u3500avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastcosf16_u3500avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fastcosf16_u3500avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastpowf16_u3500avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fastpowf16_u3500avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_asinhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acoshf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_acoshf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_atanhf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp2f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_exp2f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp2f16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_exp2f16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp10f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_exp10f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp10f16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_exp10f16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_expm1f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_expm1f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log10f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_log10f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log2f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_log2f16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log2f16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_log2f16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log1pf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_log1pf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincospif16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_finz_sincospif16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincospif16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_finz_sincospif16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinpif16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sinpif16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cospif16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_cospif16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmaf16_avx512f(__m512, __m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fmaf16_avx512f(__m512, __m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sqrtf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sqrtf16_u05avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_sqrtf16_u35avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_hypotf16_u05avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_hypotf16_u05avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_hypotf16_u35avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_hypotf16_u35avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fabsf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fabsf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_copysignf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_copysignf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmaxf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fmaxf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fminf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fminf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fdimf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fdimf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_truncf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_truncf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_floorf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_floorf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_ceilf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_ceilf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_roundf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_roundf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_rintf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_rintf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_nextafterf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_nextafterf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_frfrexpf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_frfrexpf16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmodf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_fmodf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_remainderf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_remainderf16_avx512f(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_modff16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_finz_modff16_avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_lgammaf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_lgammaf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tgammaf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_tgammaf16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_erff16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_erff16_u10avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_erfcf16_u15avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_finz_erfcf16_u15avx512f(__m512); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf16_avx512f(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_finz_getIntf16_avx512f(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf16_avx512f(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_finz_getPtrf16_avx512f(int); +#endif +#ifdef __AVX512F__ + +#ifndef Sleef___m512d_2_DEFINED +typedef struct { + __m512d x, y; +} Sleef___m512d_2; +#define Sleef___m512d_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sind8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sind8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cosd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_cosd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincosd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_cinz_sincosd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tand8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_tand8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asind8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_asind8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acosd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_acosd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atand8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_atand8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atan2d8_u35avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_atan2d8_u35avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_logd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_logd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cbrtd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_cbrtd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sind8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sind8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cosd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_cosd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincosd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_cinz_sincosd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tand8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_tand8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asind8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_asind8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acosd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_acosd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atand8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_atand8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atan2d8_u10avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_atan2d8_u10avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_logd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_logd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cbrtd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_cbrtd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_expd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_expd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_powd8_u10avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_powd8_u10avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sinhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_coshd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_coshd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tanhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_tanhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinhd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sinhd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_coshd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_coshd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tanhd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_tanhd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastsind8_u3500avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fastsind8_u3500avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastcosd8_u3500avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fastcosd8_u3500avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fastpowd8_u3500avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fastpowd8_u3500avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_asinhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_asinhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_acoshd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_acoshd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_atanhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_atanhd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp2d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_exp2d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp2d8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_exp2d8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp10d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_exp10d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_exp10d8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_exp10d8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_expm1d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_expm1d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log10d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_log10d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log2d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_log2d8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log2d8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_log2d8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_log1pd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_log1pd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincospid8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_cinz_sincospid8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_sincospid8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_cinz_sincospid8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sinpid8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sinpid8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cospid8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_cospid8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_ldexpd8_avx512fnofma(__m512d, __m256i); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_ldexpd8_avx512fnofma(__m512d, __m256i); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_ilogbd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_cinz_ilogbd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmad8_avx512fnofma(__m512d, __m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fmad8_avx512fnofma(__m512d, __m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sqrtd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sqrtd8_u05avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_sqrtd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_sqrtd8_u35avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_hypotd8_u05avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_hypotd8_u05avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_hypotd8_u35avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_hypotd8_u35avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fabsd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fabsd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_copysignd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_copysignd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmaxd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fmaxd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmind8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fmind8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fdimd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fdimd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_truncd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_truncd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_floord8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_floord8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_ceild8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_ceild8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_roundd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_roundd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_rintd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_rintd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_nextafterd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_nextafterd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_frfrexpd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_frfrexpd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_expfrexpd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m256i Sleef_cinz_expfrexpd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_fmodd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_fmodd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_remainderd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_remainderd8_avx512fnofma(__m512d, __m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_modfd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST Sleef___m512d_2 Sleef_cinz_modfd8_avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_lgammad8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_lgammad8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_tgammad8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_tgammad8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_erfd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_erfd8_u10avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_erfcd8_u15avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST __m512d Sleef_cinz_erfcd8_u15avx512fnofma(__m512d); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd8_avx512fnofma(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd8_avx512fnofma(int); + +#ifndef Sleef___m512_2_DEFINED +typedef struct { + __m512 x, y; +} Sleef___m512_2; +#define Sleef___m512_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sinf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cosf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_cosf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincosf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_cinz_sincosf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_tanf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_asinf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acosf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_acosf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_atanf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atan2f16_u35avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_atan2f16_u35avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_logf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_logf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cbrtf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_cbrtf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sinf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cosf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_cosf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincosf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_cinz_sincosf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_tanf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_asinf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acosf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_acosf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_atanf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atan2f16_u10avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_atan2f16_u10avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_logf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_logf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cbrtf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_cbrtf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_expf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_expf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_powf16_u10avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_powf16_u10avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sinhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_coshf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_coshf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_tanhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinhf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sinhf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_coshf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_coshf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tanhf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_tanhf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastsinf16_u3500avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fastsinf16_u3500avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastcosf16_u3500avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fastcosf16_u3500avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fastpowf16_u3500avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fastpowf16_u3500avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_asinhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_asinhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_acoshf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_acoshf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_atanhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_atanhf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp2f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_exp2f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp2f16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_exp2f16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp10f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_exp10f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_exp10f16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_exp10f16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_expm1f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_expm1f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log10f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_log10f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log2f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_log2f16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log2f16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_log2f16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_log1pf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_log1pf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincospif16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_cinz_sincospif16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_sincospif16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_cinz_sincospif16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sinpif16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sinpif16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cospif16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_cospif16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmaf16_avx512fnofma(__m512, __m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fmaf16_avx512fnofma(__m512, __m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sqrtf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sqrtf16_u05avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_sqrtf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_sqrtf16_u35avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_hypotf16_u05avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_hypotf16_u05avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_hypotf16_u35avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_hypotf16_u35avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fabsf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fabsf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_copysignf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_copysignf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmaxf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fmaxf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fminf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fminf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fdimf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fdimf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_truncf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_truncf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_floorf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_floorf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_ceilf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_ceilf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_roundf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_roundf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_rintf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_rintf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_nextafterf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_nextafterf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_frfrexpf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_frfrexpf16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_fmodf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_fmodf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_remainderf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_remainderf16_avx512fnofma(__m512, __m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_modff16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST Sleef___m512_2 Sleef_cinz_modff16_avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_lgammaf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_lgammaf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_tgammaf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_tgammaf16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_erff16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_erff16_u10avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_erfcf16_u15avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST __m512 Sleef_cinz_erfcf16_u15avx512fnofma(__m512); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf16_avx512fnofma(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_cinz_getIntf16_avx512fnofma(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf16_avx512fnofma(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_cinz_getPtrf16_avx512fnofma(int); +#endif +#ifdef __STDC__ + +#ifndef Sleef_double_2_DEFINED +typedef Sleef_double2 Sleef_double_2; +#define Sleef_double_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST double Sleef_sind1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sind1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cosd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_cosd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincosd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_cinz_sincosd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tand1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_tand1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_asind1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_asind1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_acosd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_acosd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atand1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_atand1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atan2d1_u35purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_atan2d1_u35purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_logd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_logd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cbrtd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_cbrtd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sind1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sind1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cosd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_cosd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincosd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_cinz_sincosd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tand1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_tand1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_asind1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_asind1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_acosd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_acosd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atand1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_atand1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atan2d1_u10purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_atan2d1_u10purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_logd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_logd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cbrtd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_cbrtd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_expd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_expd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_powd1_u10purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_powd1_u10purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sinhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sinhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_coshd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_coshd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tanhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_tanhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sinhd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sinhd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_coshd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_coshd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tanhd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_tanhd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fastsind1_u3500purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fastsind1_u3500purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fastcosd1_u3500purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fastcosd1_u3500purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fastpowd1_u3500purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fastpowd1_u3500purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_asinhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_asinhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_acoshd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_acoshd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atanhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_atanhd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp2d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_exp2d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp2d1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_exp2d1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp10d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_exp10d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp10d1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_exp10d1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_expm1d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_expm1d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log10d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_log10d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log2d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_log2d1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log2d1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_log2d1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log1pd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_log1pd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincospid1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_cinz_sincospid1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincospid1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_cinz_sincospid1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sinpid1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sinpid1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cospid1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_cospid1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_ldexpd1_purec(double, int32_t); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_ldexpd1_purec(double, int32_t); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_ilogbd1_purec(double); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_cinz_ilogbd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmad1_purec(double, double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fmad1_purec(double, double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sqrtd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sqrtd1_u05purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_sqrtd1_u35purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_hypotd1_u05purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_hypotd1_u05purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_hypotd1_u35purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_hypotd1_u35purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fabsd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fabsd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_copysignd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_copysignd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmaxd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fmaxd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmind1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fmind1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fdimd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fdimd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_truncd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_truncd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_floord1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_floord1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_ceild1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_ceild1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_roundd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_roundd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_rintd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_rintd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_nextafterd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_nextafterd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_frfrexpd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_frfrexpd1_purec(double); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_expfrexpd1_purec(double); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_cinz_expfrexpd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmodd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_fmodd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_remainderd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_remainderd1_purec(double, double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_modfd1_purec(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_cinz_modfd1_purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_lgammad1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_lgammad1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tgammad1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_tgammad1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_erfd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_erfd1_u10purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_erfcd1_u15purec(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cinz_erfcd1_u15purec(double); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd1_purec(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd1_purec(int); + +#ifndef Sleef_float_2_DEFINED +typedef Sleef_float2 Sleef_float_2; +#define Sleef_float_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST float Sleef_sinf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sinf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cosf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_cosf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincosf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_cinz_sincosf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_tanf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_asinf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_asinf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_acosf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_acosf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atanf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_atanf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f1_u35purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_atan2f1_u35purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_logf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_logf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_cbrtf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sinf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cosf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_cosf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincosf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_cinz_sincosf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_tanf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_asinf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_asinf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_acosf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_acosf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atanf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_atanf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f1_u10purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_atan2f1_u10purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_logf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_logf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_cbrtf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_expf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_expf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_powf1_u10purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_powf1_u10purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sinhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_coshf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_coshf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_tanhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sinhf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_coshf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_coshf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_tanhf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fastsinf1_u3500purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fastsinf1_u3500purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fastcosf1_u3500purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fastcosf1_u3500purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fastpowf1_u3500purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fastpowf1_u3500purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_asinhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_asinhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_acoshf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_acoshf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atanhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_atanhf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_exp2f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_exp2f1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_exp10f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_exp10f1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_expm1f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_expm1f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log10f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_log10f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log2f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_log2f1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log2f1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_log2f1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log1pf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_log1pf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincospif1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_cinz_sincospif1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincospif1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_cinz_sincospif1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinpif1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sinpif1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cospif1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_cospif1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fmaf1_purec(float, float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fmaf1_purec(float, float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sqrtf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sqrtf1_u05purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_sqrtf1_u35purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf1_u05purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_hypotf1_u05purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf1_u35purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_hypotf1_u35purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fabsf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fabsf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_copysignf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_copysignf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fmaxf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fmaxf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fminf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fminf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fdimf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fdimf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_truncf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_truncf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_floorf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_floorf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_ceilf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_ceilf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_roundf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_roundf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_rintf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_rintf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_nextafterf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_nextafterf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_frfrexpf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_frfrexpf1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fmodf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_fmodf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_remainderf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_remainderf1_purec(float, float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_modff1_purec(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_cinz_modff1_purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_lgammaf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_lgammaf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tgammaf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_tgammaf1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_erff1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_erff1_u10purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_erfcf1_u15purec(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cinz_erfcf1_u15purec(float); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf1_purec(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_cinz_getIntf1_purec(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf1_purec(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_cinz_getPtrf1_purec(int); +#endif +#ifdef __STDC__ + +#ifndef Sleef_double_2_DEFINED +typedef Sleef_double2 Sleef_double_2; +#define Sleef_double_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST double Sleef_sind1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sind1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cosd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_cosd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincosd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_finz_sincosd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tand1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_tand1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_asind1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_asind1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_acosd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_acosd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atand1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_atand1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atan2d1_u35purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_atan2d1_u35purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_logd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_logd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cbrtd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_cbrtd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sind1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sind1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cosd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_cosd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincosd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_finz_sincosd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tand1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_tand1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_asind1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_asind1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_acosd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_acosd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atand1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_atand1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atan2d1_u10purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_atan2d1_u10purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_logd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_logd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cbrtd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_cbrtd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_expd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_expd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_powd1_u10purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_powd1_u10purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sinhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sinhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_coshd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_coshd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tanhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_tanhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sinhd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sinhd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_coshd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_coshd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tanhd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_tanhd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fastsind1_u3500purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fastsind1_u3500purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fastcosd1_u3500purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fastcosd1_u3500purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fastpowd1_u3500purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fastpowd1_u3500purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_asinhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_asinhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_acoshd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_acoshd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_atanhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_atanhd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp2d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_exp2d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp2d1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_exp2d1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp10d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_exp10d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_exp10d1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_exp10d1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_expm1d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_expm1d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log10d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_log10d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log2d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_log2d1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log2d1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_log2d1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_log1pd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_log1pd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincospid1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_finz_sincospid1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincospid1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_finz_sincospid1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sinpid1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sinpid1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_cospid1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_cospid1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_ldexpd1_purecfma(double, int32_t); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_ldexpd1_purecfma(double, int32_t); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_ilogbd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_finz_ilogbd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmad1_purecfma(double, double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fmad1_purecfma(double, double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sqrtd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sqrtd1_u05purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_sqrtd1_u35purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_hypotd1_u05purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_hypotd1_u05purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_hypotd1_u35purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_hypotd1_u35purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fabsd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fabsd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_copysignd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_copysignd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmaxd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fmaxd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmind1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fmind1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fdimd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fdimd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_truncd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_truncd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_floord1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_floord1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_ceild1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_ceild1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_roundd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_roundd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_rintd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_rintd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_nextafterd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_nextafterd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_frfrexpd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_frfrexpd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_expfrexpd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_finz_expfrexpd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmodd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_fmodd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_remainderd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_remainderd1_purecfma(double, double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_modfd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_finz_modfd1_purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_lgammad1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_lgammad1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_tgammad1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_tgammad1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_erfd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_erfd1_u10purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_erfcd1_u15purecfma(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_finz_erfcd1_u15purecfma(double); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd1_purecfma(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd1_purecfma(int); + +#ifndef Sleef_float_2_DEFINED +typedef Sleef_float2 Sleef_float_2; +#define Sleef_float_2_DEFINED +#endif + +SLEEF_IMPORT SLEEF_CONST float Sleef_sinf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sinf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cosf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_cosf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincosf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_finz_sincosf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_tanf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_asinf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_asinf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_acosf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_acosf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atanf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_atanf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f1_u35purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_atan2f1_u35purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_logf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_logf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_cbrtf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sinf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cosf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_cosf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincosf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_finz_sincosf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_tanf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_asinf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_asinf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_acosf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_acosf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atanf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_atanf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f1_u10purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_atan2f1_u10purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_logf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_logf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_cbrtf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_expf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_expf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_powf1_u10purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_powf1_u10purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sinhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_coshf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_coshf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_tanhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sinhf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_coshf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_coshf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_tanhf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fastsinf1_u3500purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fastsinf1_u3500purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fastcosf1_u3500purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fastcosf1_u3500purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fastpowf1_u3500purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fastpowf1_u3500purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_asinhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_asinhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_acoshf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_acoshf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_atanhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_atanhf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_exp2f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_exp2f1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_exp10f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_exp10f1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_expm1f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_expm1f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log10f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_log10f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log2f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_log2f1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log2f1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_log2f1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_log1pf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_log1pf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincospif1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_finz_sincospif1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincospif1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_finz_sincospif1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sinpif1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sinpif1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_cospif1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_cospif1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fmaf1_purecfma(float, float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fmaf1_purecfma(float, float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sqrtf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sqrtf1_u05purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_sqrtf1_u35purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf1_u05purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_hypotf1_u05purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf1_u35purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_hypotf1_u35purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fabsf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fabsf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_copysignf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_copysignf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fmaxf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fmaxf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fminf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fminf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fdimf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fdimf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_truncf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_truncf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_floorf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_floorf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_ceilf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_ceilf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_roundf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_roundf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_rintf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_rintf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_nextafterf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_nextafterf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_frfrexpf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_frfrexpf1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_fmodf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_fmodf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_remainderf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_remainderf1_purecfma(float, float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_modff1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_finz_modff1_purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_lgammaf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_lgammaf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_tgammaf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_tgammaf1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_erff1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_erff1_u10purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_erfcf1_u15purecfma(float); +SLEEF_IMPORT SLEEF_CONST float Sleef_finz_erfcf1_u15purecfma(float); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf1_purecfma(int); +SLEEF_IMPORT SLEEF_CONST int Sleef_finz_getIntf1_purecfma(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf1_purecfma(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_finz_getPtrf1_purecfma(int); +#endif +#ifdef __STDC__ + +#ifndef Sleef_double_2_DEFINED +typedef Sleef_double2 Sleef_double_2; +#define Sleef_double_2_DEFINED +#endif + +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sind1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cosd1_u35(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincosd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tand1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_asind1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_acosd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atand1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atan2d1_u35(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_logd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cbrtd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sind1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cosd1_u10(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincosd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tand1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_asind1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_acosd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atand1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atan2d1_u10(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_logd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cbrtd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_expd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_powd1_u10(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sinhd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_coshd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tanhd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sinhd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_coshd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tanhd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fastsind1_u3500(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fastcosd1_u3500(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fastpowd1_u3500(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_asinhd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_acoshd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_atanhd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp2d1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp2d1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp10d1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_exp10d1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_expm1d1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log10d1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log2d1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log2d1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_log1pd1_u10(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincospid1_u05(double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_sincospid1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sinpid1_u05(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_cospid1_u05(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_ldexpd1(double, int32_t); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_ilogbd1(double); +SLEEF_IMPORT SLEEF_CONST double Sleef_fmad1(double, double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_u05(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_sqrtd1_u35(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_hypotd1_u05(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_hypotd1_u35(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fabsd1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_copysignd1(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fmaxd1(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fmind1(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fdimd1(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_truncd1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_floord1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_ceild1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_roundd1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_rintd1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_nextafterd1(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_frfrexpd1(double); +SLEEF_IMPORT SLEEF_CONST int32_t Sleef_expfrexpd1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_fmodd1(double, double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_remainderd1(double, double); +SLEEF_IMPORT SLEEF_CONST Sleef_double_2 Sleef_modfd1(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_lgammad1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_tgammad1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_erfd1_u10(double); +SLEEF_PRAGMA_OMP_SIMD_DP SLEEF_IMPORT SLEEF_CONST double Sleef_erfcd1_u15(double); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntd1(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrd1(int); + +#ifndef Sleef_float_2_DEFINED +typedef Sleef_float2 Sleef_float_2; +#define Sleef_float_2_DEFINED +#endif + +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cosf1_u35(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincosf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_asinf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_acosf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atanf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f1_u35(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_logf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cosf1_u10(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincosf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_asinf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_acosf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atanf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atan2f1_u10(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_logf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cbrtf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_expf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_powf1_u10(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_coshf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinhf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_coshf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tanhf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fastsinf1_u3500(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fastcosf1_u3500(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fastpowf1_u3500(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_asinhf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_acoshf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_atanhf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp2f1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_exp10f1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_expm1f1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log10f1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log2f1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log2f1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_log1pf1_u10(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincospif1_u05(float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_sincospif1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sinpif1_u05(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_cospif1_u05(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fmaf1(float, float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_u05(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_sqrtf1_u35(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf1_u05(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_hypotf1_u35(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fabsf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_copysignf1(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fmaxf1(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fminf1(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fdimf1(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_truncf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_floorf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_ceilf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_roundf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_rintf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_nextafterf1(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_frfrexpf1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_fmodf1(float, float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_remainderf1(float, float); +SLEEF_IMPORT SLEEF_CONST Sleef_float_2 Sleef_modff1(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_lgammaf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_tgammaf1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_erff1_u10(float); +SLEEF_PRAGMA_OMP_SIMD_SP SLEEF_IMPORT SLEEF_CONST float Sleef_erfcf1_u15(float); +SLEEF_IMPORT SLEEF_CONST int Sleef_getIntf1(int); +SLEEF_IMPORT SLEEF_CONST void *Sleef_getPtrf1(int); +#endif + +// + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // #ifndef __SLEEF_H__ diff --git a/phivenv/Lib/site-packages/torch/include/xnnpack.h b/phivenv/Lib/site-packages/torch/include/xnnpack.h new file mode 100644 index 0000000000000000000000000000000000000000..aa5bf4f9ed8caf9aa2216d49566368aed7ad7e58 --- /dev/null +++ b/phivenv/Lib/site-packages/torch/include/xnnpack.h @@ -0,0 +1,4855 @@ +// Copyright (c) Facebook, Inc. and its affiliates. +// All rights reserved. +// +// Copyright 2019 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +#include "pthreadpool.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/// The number of bytes XNNPACK may read beyond array bounds. +/// The caller must allocate at least this many extra bytes after the tensor data passed to XNNPACK. +/// +/// Note: XNNPACK reads, but never writes beyond array bounds. +#if XNN_ARCH_HEXAGON +#define XNN_EXTRA_BYTES 128 +#else +#define XNN_EXTRA_BYTES 16 +#endif // XNN_ARCH_HEXAGON + +/// Maximum number of dimensions in tensor shape. +#define XNN_MAX_TENSOR_DIMS 6 + +/// A value ID that cannot be valid. +#define XNN_INVALID_VALUE_ID UINT32_MAX + +/// Allow sparse inference in a Runtime. +/// +/// Note: this flag is a hint to XNNPACK that it should consider sparse inference, but does not guarantee it. +#define XNN_FLAG_HINT_SPARSE_INFERENCE 0x00000001 + +/// Allow IEEE FP16 inference in a Runtime. +/// +/// Note: this flag hints XNNPACK to consider IEEE FP16 inference, but does not guarantee it. +#define XNN_FLAG_HINT_FP16_INFERENCE 0x00000002 + +/// Force IEEE FP16 inference in a Runtime, and fail if FP16 inference is not possible. +/// +/// Note: this flag guarantees that XNNPACK will use IEEE FP16 inference, or fail to create the Runtime object. +/// Warning: on x86 systems FP16 computations will be emulated at a substantial performance cost. +#define XNN_FLAG_FORCE_FP16_INFERENCE 0x00000004 + +/// Enable timing of each operator's runtime. +#define XNN_FLAG_BASIC_PROFILING 0x00000008 + +/// Enable the just-in-time compiler. +#define XNN_FLAG_JIT 0x00000010 + +/// The convolution operator represents a depthwise convolution, and use HWGo layout for filters. +#define XNN_FLAG_DEPTHWISE_CONVOLUTION 0x00000001 + +/// Assume transposed weights in a fully connected operator. +#define XNN_FLAG_TRANSPOSE_WEIGHTS 0x00000001 + +/// The operator assumes NHWC layout for the input, regardless of the output layout. +#define XNN_FLAG_INPUT_NHWC 0x00000002 + +/// Match "SAME" padding in TensorFlow. Exact padding values are computed dynamically depending on input size. +#define XNN_FLAG_TENSORFLOW_SAME_PADDING 0x00000004 + +/// Assume transposed weights in a batch matrix multiply operator. +#define XNN_FLAG_TRANSPOSE_B XNN_FLAG_TRANSPOSE_WEIGHTS + +/// Assume transposed input in a batch matrix multiply operator. +#define XNN_FLAG_TRANSPOSE_A 0x00000002 + +/// Implicitly flatten and reshape input of a Fully Connected operator into a 2D tensor. +#define XNN_FLAG_TENSORFLOW_RESHAPE_2D 0x00000004 + +/// Match behaviour of TensorFlow 1.x. +#define XNN_FLAG_TENSORFLOW_LEGACY_MODE 0x00000004 + +/// Static weights of the FP16 operator are in FP32 format. +#define XNN_FLAG_FP32_STATIC_WEIGHTS 0x00000008 + +/// Static biases of the FP16 operator are in FP32 format. +#define XNN_FLAG_FP32_STATIC_BIASES 0x00000080 + +/// Align corners of input and output images in resize operations. +#define XNN_FLAG_ALIGN_CORNERS 0x00000008 + +/// Yield worker threads of the thread pool to the system scheduler after the inference. +#define XNN_FLAG_YIELD_WORKERS 0x00000010 + +/// Use transient indirection buffer to reduce memory footprint +#define XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER 0x00000020 + +/// Retain reduced dimensions with length 1. +#define XNN_FLAG_KEEP_DIMS 0x00000040 + +// Next unused flag value: 0x00000100. + +/// The number of entries in an array of xnn_quantization_params that XNNPACK may read beyond array bounds. +/// The caller must allocate at least this many extra xnn_quantization_params before passing the array to XNNPACK. +/// +/// Note: XNNPACK reads, but never writes beyond array bounds. +#define XNN_EXTRA_QUANTIZATION_PARAMS 15 + +/// The minimum blocksize for blockwise quantized operators. +#define XNN_MIN_BLOCKSIZE 32 + +#ifdef __GNUC__ +#define XNN_DEPRECATED __attribute__((deprecated)) +#else +#define XNN_DEPRECATED +#endif + +struct xnn_quantization_params { + int32_t zero_point; + float scale; +}; + +/// Status code for any XNNPACK function call. +enum xnn_status { + /// The call succeeded, and all output arguments now contain valid data. + xnn_status_success = 0, + xnn_status_uninitialized = 1, + xnn_status_invalid_parameter = 2, + xnn_status_invalid_state = 3, + xnn_status_unsupported_parameter = 4, + xnn_status_unsupported_hardware = 5, + xnn_status_out_of_memory = 6, + xnn_status_reallocation_required = 7, + xnn_status_deprecated = 8, +}; + +struct xnn_allocator { + /// User-specified pointer that will be passed as-is to all functions in this structure. + void* context; + /// Pointer to a function to be called for general memory allocation. + /// + /// @param context - The user-specified pointer from xnn_allocator structure. + /// @param size - The size of the memory block to allocate, in bytes. + /// + /// @returns Pointer to the allocated memory block of at least @ref size bytes. + /// If allocation fails, the function must return NULL. + void* (*allocate)(void* context, size_t size); + /// Pointer to a function to be called for general memory re-allocation, i.e. to increase or shrink a previously + /// allocated memory block. The content of the old memory block is copied to the new memory block. + /// + /// @param context - The user-specified pointer from xnn_allocator structure. + /// @param pointer - Pointer to a memory block allocated by @ref allocate or @ref reallocate functions. Can be NULL. + /// If the pointer is NULL, the @ref reallocate call is equivalent to an @ref allocate call. + /// @param size - The new size of the memory block to allocate, in bytes. + /// + /// @returns Pointer to the newly allocated memory block of at least @ref size bytes with the content of the previous + /// memory block. + /// If allocation fails, the function must return NULL, but must not release the previous memory block. + void* (*reallocate)(void* context, void* pointer, size_t size); + /// Pointer to a function to be called for general memory de-allocation. + /// + /// @param context - The user-specified pointer from xnn_allocator structure. + /// @param pointer - Pointer to a memory block allocated by @ref allocate or @ref reallocate functions. Can be NULL. + /// If the pointer is NULL, the @ref deallocate call is a no-op. + void (*deallocate)(void* context, void* pointer); + /// Pointer to a function to be called for aligned memory allocation. + /// + /// @param context - The user-specified pointer from xnn_allocator structure. + /// @param alignment - The alignment of the memory block to allocate, in bytes. Alignment is always a power-of-2. + /// @param size - The size of the memory block to allocate, in bytes. + /// + /// @returns Pointer to the allocated memory block of at least @ref size bytes. + /// If allocation fails, the function must return NULL. + void* (*aligned_allocate)(void* context, size_t alignment, size_t size); + /// Pointer to a function to be called for aligned memory deallocation. + /// + /// @param context - The user-specified pointer from xnn_allocator structure. + /// @param pointer - Pointer to a memory block allocated by @ref aligned_allocate function. Can be NULL. + /// If the pointer is NULL, the @ref aligned_deallocate call is a no-op. + void (*aligned_deallocate)(void* context, void* pointer); +}; + +/// Initialize XNNPACK library. +/// +/// XNNPACK must be successfully initialized before use. During initialization, XNNPACK populates internal structures +/// depending on the host processor. Initialization can be time-consuming. +/// +/// @param[in] allocator - structure with function pointers to be use for memory allocation and de-allocation. +/// If this argument is NULL, system-provided memory management functions (e.g. malloc/free) +/// will be used. +/// +/// @retval xnn_status_success - XNNPACK is successfully initialized and ready to use. +/// @retval xnn_status_out_of_memory - initialization failed due to out-of-memory condition. +/// @retval xnn_status_unsupported_hardware - initialization failed because the host processor does not satisfy the +/// minimum hardware requirements for XNNPACK. E.g. this may happen on x86 +/// processors without SSE2 extension, or on 32-bit ARM processors without +/// the NEON SIMD extension. +enum xnn_status xnn_initialize(const struct xnn_allocator* allocator); + +/// Deinitialize XNNPACK library. +/// +/// To avoid memory and resource leaks, users must call xnn_deinitialize once for each successful xnn_initialize call. +/// +/// @retval xnn_status_success - deinitialization call succeeded. +enum xnn_status xnn_deinitialize(void); + +/// Get the microkernel implementation build identifier's data. +/// +/// That identifier will be unique for the current set of microkernels implementations. +/// +/// @returns A pointer to the current identifier's data. +const void* xnn_experimental_get_build_identifier_data(); + +/// Get the microkernel implementation build identifier's data size. +/// +/// @returns The size in bytes of the identifier's data. +size_t xnn_experimental_get_build_identifier_size(); + +/// Check whether the given data matches this version's identifier. +/// +/// @returns The size in bytes of the identifier's data. +bool xnn_experimental_check_build_identifier(const void* data, size_t size); + +/// Subgraph is an abstract representation of a neural network model. +/// Subgraph objects are used to define Values (tensors) and Nodes (operators) comprising the model. +typedef struct xnn_subgraph* xnn_subgraph_t; + +/// Create a empty Subgraph object. +/// +/// @param external_value_ids - number of Value IDs to reserve for communication with external graph representation. +/// The Subgraph object would avoid creating internal Value IDs in the +/// [0, reserved_value_ids-1] range. +/// @param flags - binary features of the subgraph. No supported flags are currently defined. +/// @param subgraph_out - pointer to the variable that will be initialized with a handle to the Subgraph object upon +/// successful return. +enum xnn_status xnn_create_subgraph( + uint32_t external_value_ids, + uint32_t flags, + xnn_subgraph_t* subgraph_out); + +/// Destroy a Subgraph object, as well as Values, and Nodes associated with the subgraph. +/// +/// @param subgraph - the Subgraph object to destroy. +enum xnn_status xnn_delete_subgraph( + xnn_subgraph_t subgraph); + +#define XNN_VALUE_FLAG_EXTERNAL_INPUT 0x00000001 +#define XNN_VALUE_FLAG_EXTERNAL_OUTPUT 0x00000002 +#define XNN_VALUE_FLAG_PERSISTENT 0x00000004 + +#define XNN_INVALID_VALUE_ID UINT32_MAX + +/// Type of elements in a Value object. +enum xnn_datatype { + /// Invalid data type. Valid Values never have this datatype. + xnn_datatype_invalid = 0, + /// IEEE754 single-precision floating-point. + xnn_datatype_fp32 = 1, + /// IEEE754 half-precision floating-point. + xnn_datatype_fp16 = 2, + /// Quantized 8-bit signed integer with shared per-Value quantization + /// parameters. + xnn_datatype_qint8 = 3, + /// Quantized 8-bit unsigned integer with shared per-Value quantization + /// parameters. + xnn_datatype_quint8 = 4, + /// Quantized 32-bit signed integer with shared per-Value quantization + /// parameters. + xnn_datatype_qint32 = 5, + /// Quantized 8-bit signed integer with shared per-channel quantization + /// parameters. + xnn_datatype_qcint8 = 6, + /// Quantized 32-bit signed integer with shared per-channel quantization + /// parameters. + xnn_datatype_qcint32 = 7, + /// Quantized 4-bit signed integer with shared per-channel quantization + /// parameters. + xnn_datatype_qcint4 = 8, + /// Dynamically quantized 8-bit signed integer with per-batch quantization + /// parameters. + xnn_datatype_qdint8 = 9, + /// Dynamically quantized 8-bit signed integers packed with their per-row + /// quantization parameters. + xnn_datatype_qpint8 = 10, + /// 32-bit signed integers. + xnn_datatype_int32 = 11, + /// Quantized 4-bit signed integer with shared per-channel-block quantization + /// parameters. + xnn_datatype_qbint4 = 12, + /// IEEE754 single-precision packed floating-point. + xnn_datatype_pfp32 = 13, + /// BFloat16, i.e. the upper 16 bits of a float32. + xnn_datatype_bf16 = 14, + /// Dynamically quantized 8-bit unsigned integer with per-batch quantization + /// parameters. + xnn_datatype_qduint8 = 15, +}; + +/// Define a tensor-type Value and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Value. +/// @param datatype - type of the tensor elements. +/// @param num_dims - number of dimensions in the shape. +/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. +/// XNNPACK does not keep any pointers to this array after the function returns. +/// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized, +/// this pointer must be is NULL. If non-NULL, the life-time of the static data must exceed the life-time +/// of the Subgraph object, and of any Runtime objects created from the Subgraph. +/// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on +/// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be +/// created for the Value. +/// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT +/// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT. +/// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a +/// valid @a external_id was provided, the variable will be initialized with the @a external_id value. +enum xnn_status xnn_define_tensor_value( + xnn_subgraph_t subgraph, + enum xnn_datatype datatype, + size_t num_dims, + const size_t* dims, + const void* data, + uint32_t external_id, + uint32_t flags, + uint32_t* id_out); + +/// Define a quantized tensor-type Value and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Value. +/// @param datatype - type of the tensor elements. +/// @param zero_point - offset from zero to subtract from the quantized elements in the Value. +/// @param scale - multiplication factor to convert quantized elements to real representation. +/// @param num_dims - number of dimensions in the shape. +/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. +/// XNNPACK does not keep any pointers to this array after the function returns. +/// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized, +/// this pointer must be is NULL. If non-NULL, the life-time of the static data must exceed the life-time +/// of the Subgraph object, and of any Runtime objects created from the Subgraph. +/// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on +/// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be +/// created for the Value. +/// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT +/// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT. +/// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a +/// valid @a external_id was provided, the variable will be initialized with the @a external_id value. +enum xnn_status xnn_define_quantized_tensor_value( + xnn_subgraph_t subgraph, + enum xnn_datatype datatype, + int32_t zero_point, + float scale, + size_t num_dims, + const size_t* dims, + const void* data, + uint32_t external_id, + uint32_t flags, + uint32_t* id_out); + +enum xnn_status xnn_define_channelwise_quantized_tensor_value( + xnn_subgraph_t subgraph, + enum xnn_datatype datatype, + const float* scale, + size_t num_dims, + size_t channel_dim, + const size_t* dims, + const void* data, + uint32_t external_id, + uint32_t flags, + uint32_t* id_out); + +/// Validate the dimensions, channel_dim, zero point, datatype, and scale of a quantized tensor-type. +/// +/// @param datatype - type of the tensor elements. +/// @param zero_point - offset from zero to subtract from the quantized elements in the Value. +/// @param scale - multiplication factor to convert quantized elements to real representation. +/// @param num_dims - number of dimensions in the shape. +/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. +/// XNNPACK does not keep any pointers to this array after the function returns. +enum xnn_status xnn_validate_quantized_tensor( + enum xnn_datatype datatype, + int32_t zero_point, + float scale, + size_t num_dims, + const size_t* dims); + +/// Validate the dimensions, channel_dim, zero point, datatype, and scales of a channelwise quantized tensor-type. +/// +/// @param datatype - type of the tensor elements. +/// @param zero_point - offset from zero to subtract from the quantized elements in the Value. +/// @param scale - per-channel multiplication factors to convert quantized elements to real representation. +/// @param num_dims - number of dimensions in the shape. +/// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters. +/// Typically this is the first dimension (dimension #0) of the filter tensors in the Convolution, +/// Deconvolution, and Fully Connected operators and the last dimension of the filter tensors in +/// the Depthwise Convolution operators. +/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. +/// XNNPACK does not keep any pointers to this array after the function returns. +enum xnn_status xnn_validate_channelwise_quantized_tensor( + enum xnn_datatype datatype, + int32_t zero_point, + const float* scale, + size_t num_dims, + size_t channel_dim, + const size_t* dims); + +/// Define a channelwise quantized tensor-type Value and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Value. +/// @param datatype - type of the tensor elements. +/// @param zero_point - offset from zero to subtract from the quantized elements in the Value. +/// @param scale - per-channel multiplication factors to convert quantized elements to real representation. +/// @param num_dims - number of dimensions in the shape. +/// @param channel_dim - index of the channel dimension in the tensor with per-channel quantization parameters. +/// Typically this is the first dimension (dimension #0) of the filter tensors in the Convolution, +/// Deconvolution, and Fully Connected operators and the last dimension of the filter tensors in +/// the Depthwise Convolution operators. +/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. +/// XNNPACK does not keep any pointers to this array after the function returns. +/// @param data - pointer to static data used for tensor initialization. If the tensor is not statically initialized, +/// this pointer must be is NULL. If non-NULL, the life-time of the static data must exceed the life-time +/// of the Subgraph object, and of any Runtime objects created from the Subgraph. +/// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on +/// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be +/// created for the Value. +/// @param flags - binary features of the Value. Supported values are any combination of XNN_VALUE_FLAG_EXTERNAL_INPUT +/// and XNN_VALUE_FLAG_EXTERNAL_OUTPUT. +/// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a +/// valid @a external_id was provided, the variable will be initialized with the @a external_id value. +enum xnn_status xnn_define_channelwise_quantized_tensor_value_v2( + xnn_subgraph_t subgraph, + enum xnn_datatype datatype, + int32_t zero_point, + const float* scale, + size_t num_dims, + size_t channel_dim, + const size_t* dims, + const void* data, + uint32_t external_id, + uint32_t flags, + uint32_t* id_out); + +/// Define a blockwise quantized tensor-type Value and add it to a Subgraph. +/// @param block_size - size of a block in the tensor with blockwise quantization parameters. Block is defined as +/// number of input channel element per output channel. +/// For Fully connected operators with 2d filters of size [output_channels, input_channels], +/// expecting number of scale values to be = output_channels * (input_channels / block_size). +enum xnn_status xnn_define_blockwise_quantized_tensor_value( + xnn_subgraph_t subgraph, + enum xnn_datatype datatype, + int32_t zero_point, + const uint16_t* scale, + size_t num_dims, + size_t channel_dim, + size_t block_size, + const size_t* dims, + const void* data, + uint32_t external_id, + uint32_t flags, + uint32_t* id_out); + +/// Define a dynamically quantized tensor-type Value and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Value. +/// @param datatype - type of the tensor elements. +/// @param num_dims - number of dimensions in the shape. +/// @param num_non_batch_dims - number of non-batch dimensions in the shape. The leading (num_dims - num_non_batch_dims) +/// dimensions will be flattened and treated as batch size. A set of quantization parameters +/// will be calculated for each batch element. +/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. +/// XNNPACK does not keep any pointers to this array after the function returns. +/// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on +/// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be +/// created for the Value. +/// @param flags - binary features of the Value. No supported flags are currently defined. +/// @param id_out - pointer to the variable that will be initialized with the Value ID upon successful return. If a +/// valid @a external_id was provided, the variable will be initialized with the @a external_id value. +enum xnn_status xnn_define_dynamically_quantized_tensor_value( + xnn_subgraph_t subgraph, + enum xnn_datatype datatype, + size_t num_dims, + size_t num_nonbatch_dims, + const size_t* dims, + uint32_t external_id, + uint32_t flags, + uint32_t* id_out); + +/// Type of unary operation +enum xnn_unary_operator { + xnn_unary_invalid = -1, + xnn_unary_convert, + xnn_unary_clamp, + xnn_unary_abs, + xnn_unary_bankers_rounding, + xnn_unary_ceiling, + xnn_unary_elu, + xnn_unary_exp, + xnn_unary_floor, + xnn_unary_gelu, + xnn_unary_hardswish, + xnn_unary_leaky_relu, + xnn_unary_log, + xnn_unary_negate, + xnn_unary_sigmoid, + xnn_unary_square, + xnn_unary_square_root, + xnn_unary_reciprocal_square_root, + xnn_unary_tanh, + // The following operators are experimental and may be removed. + xnn_unary_cube_root, + xnn_unary_cosine, + xnn_unary_sine, + xnn_unary_count_leading_zeros, + xnn_unary_bitwise_not, + xnn_unary_popcount, + xnn_unary_sign, +}; + +/// Parameters for xnn_define_unary +union xnn_unary_params { + struct { + /// lower bound for clipping output values. + float min; + /// upper bound for clipping output values. + float max; + } clamp; + struct { + /// scale factor for negative input elements. + float alpha; + } elu; + struct { + /// scale factor for negative input elements. + float negative_slope; + } leaky_relu; +}; + +/// Define a unary operator Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param operator - type of unary operator to define. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param params - parameters to be interpreted by the specific operator type. +/// @param flags - binary features of the Node. No supported flags are currently defined. +enum xnn_status xnn_define_unary( + xnn_subgraph_t subgraph, + enum xnn_unary_operator type, + const union xnn_unary_params* params, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Convert Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Convert Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_convert( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Convolution Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING +/// flag is specified. +/// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param kernel_height - kernel (filter) height. +/// @param kernel_width - kernel (filter) width. +/// @param subsampling_height - height of subsampling region for convolution output (convolution height stride). +/// @param subsampling_width - width of subsampling region for convolution output (convolution width stride). +/// @param dilation_height - dilation of kernel elements along the height dimension. +/// @param dilation_width - dilation of kernel elements along the width dimension. +/// @param groups - number of convolution groups. +/// @param group_input_channels - number of input channels per group. +/// @param group_output_channels - number of output channels per group. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, groups * group_input_channels] dimensions +/// @param filter_id - Value ID for the filter tensor. The filter tensor must ge a 4D tensor defined in the @a subgraph +/// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels] +/// dimensions. +/// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Convolution Node without a bias. If +/// present, the bias tensor must be a 1D tensor defined in the @a subgraph with [groups * +/// group_output_channels] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, OH, OW, groups * group_output_channels] dimensions. +/// @param flags - binary features of the 2D Convolution Node. The only currently supported values is +/// XNN_FLAG_TENSORFLOW_SAME_PADDING. +enum xnn_status xnn_define_convolution_2d( + xnn_subgraph_t subgraph, + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + float output_min, + float output_max, + uint32_t input_id, + uint32_t filter_id, + uint32_t bias_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Deconvolution (Transposed Convolution) Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param padding_top - implicit padding above 2D output data. +/// @param padding_right - implicit padding to the right of 2D output data. +/// @param padding_bottom - implicit padding below 2D output data. +/// @param padding_left - implicit padding to the left of 2D output data. +/// @param adjustment_height - additional elements in the bottom of the 2D output data. +/// @param adjustment_width - additional elements to the right of the 2D output data. +/// @param kernel_height - kernel (filter) height. +/// @param kernel_width - kernel (filter) width. +/// @param upsampling_height - height of upsampling region for deconvolution input (deconvolution height stride). +/// @param upsampling_width - width of upsampling region for deconvolution input (deconvolution width stride). +/// @param dilation_height - dilation of kernel elements along the height dimension. +/// @param dilation_width - dilation of kernel elements along the width dimension. +/// @param groups - number of convolution groups. +/// @param group_input_channels - number of input channels per group. +/// @param group_output_channels - number of output channels per group. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, groups * group_input_channels] dimensions +/// @param filter_id - Value ID for the filter tensor. The filter tensor must ge a 4D tensor defined in the @a subgraph +/// with [groups * group_output_channels, kernel_height, kernel_width, group_input_channels] +/// dimensions. +/// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Convolution Node without a bias. If +/// present, the bias tensor must be a 1D tensor defined in the @a subgraph with +/// [groups * group_output_channels] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, OH, OW, groups * group_output_channels] dimensions. +/// @param flags - binary features of the 2D Deconvolution Node. No supported flags are currently defined. +enum xnn_status xnn_define_deconvolution_2d( + xnn_subgraph_t subgraph, + uint32_t padding_top, + uint32_t padding_right, + uint32_t padding_bottom, + uint32_t padding_left, + uint32_t adjustment_height, + uint32_t adjustment_width, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t upsampling_height, + uint32_t upsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + float output_min, + float output_max, + uint32_t input_id, + uint32_t filter_id, + uint32_t bias_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Depthwise Convolution Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING +/// flag is specified. +/// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param kernel_height - kernel (filter) height. +/// @param kernel_width - kernel (filter) width. +/// @param subsampling_height - height of subsampling region for convolution output (convolution height stride). +/// @param subsampling_width - width of subsampling region for convolution output (convolution width stride). +/// @param dilation_height - dilation of kernel elements along the height dimension. +/// @param dilation_width - dilation of kernel elements along the width dimension. +/// @param depth_multiplier - ratio of output channels to input channels. +/// @param input_channels - number of input channels. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, input_channels] dimensions +/// @param filter_id - Value ID for the filter tensor. The filter tensor must ge a 4D tensor defined in the @a subgraph +/// with [1, kernel_height, kernel_width, input_channels * depth_multiplier] dimensions. +/// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a 2D Depthwise Convolution Node without +/// a bias. If present, the bias tensor must be a 1D tensor defined in the @a subgraph with +/// [input_channels * depth_multiplier] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, OH, OW, input_channels * depth_multiplier] dimensions. +/// @param flags - binary features of the 2D Depthwise Convolution Node. The only currently supported values is +/// XNN_FLAG_TENSORFLOW_SAME_PADDING. +enum xnn_status xnn_define_depthwise_convolution_2d( + xnn_subgraph_t subgraph, + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t depth_multiplier, + size_t input_channels, + float output_min, + float output_max, + uint32_t input_id, + uint32_t filter_id, + uint32_t bias_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Depth To Space Node 2D and add it to a Subgraph. +/// +/// The Depth To Space 2D Node rearranges data from depth into blocks of spatial data (a reverse transform to +/// Space To Depth). For a given input pixel, an output square of pixels with side @a block_size is formed from values +/// in the corresponding number of its channels. The output depth is therefore @a block_size x @a block_size times +/// smaller than that of the input. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param block_size - the size of the spatial block. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, OC * block_size * block_size] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH * block_size, IW * block_size, OC] dimensions. +/// @param flags - binary features of the input_channels Node. No supported flags are currently defined. +enum xnn_status xnn_define_depth_to_space_2d( + xnn_subgraph_t subgraph, + uint32_t block_size, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +enum xnn_status xnn_define_depth_to_space( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t block_size, + uint32_t flags); + +/// Define a 1D Global Average Pooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions +/// defined in the @a subgraph. Averaging is performed across the second-innermost dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more +/// dimensions defined in the @a subgraph. +/// @param flags - binary features of the 1D Global Average Pooling Node. The only currently supported value is +/// XNN_FLAG_KEEP_DIMS. +XNN_DEPRECATED enum xnn_status xnn_define_global_average_pooling_1d( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Global Average Pooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions +/// defined in the @a subgraph. Averaging is performed across the second- and third-innermost +/// dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more +/// dimensions defined in the @a subgraph. +/// @param flags - binary features of the 2D Global Average Pooling Node. The only currently supported value is +/// XNN_FLAG_KEEP_DIMS. +XNN_DEPRECATED enum xnn_status xnn_define_global_average_pooling_2d( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 1D Global Sum Pooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 2 or more dimensions +/// defined in the @a subgraph. Averaging is performed across the second-innermost dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 2 or more +/// dimensions defined in the @a subgraph. +/// @param flags - binary features of the 1D Global Sum Pooling Node. The only currently supported value is +/// XNN_FLAG_KEEP_DIMS. +XNN_DEPRECATED enum xnn_status xnn_define_global_sum_pooling_1d( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Global Sum Pooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with 3 or more dimensions +/// defined in the @a subgraph. Averaging is performed across the second- and third-innermost +/// dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor with 3 or more +/// dimensions defined in the @a subgraph. +/// @param flags - binary features of the 2D Global Sum Pooling Node. The only currently supported value is +/// XNN_FLAG_KEEP_DIMS. +XNN_DEPRECATED enum xnn_status xnn_define_global_sum_pooling_2d( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Average Pooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING +/// flag is specified. +/// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param pooling_height - pooling (kernel) height. +/// @param pooling_width - pooling (kernel) width. +/// @param stride_height - displacing of the pooling window in the vertical dimension of the input pixels corresponding +/// to vertically adjacent output pixels. +/// @param stride_width - displacing of the pooling window in the horizontal dimension of the input pixels corresponding +/// to horizontally adjacent output pixels. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, channels] dimensions +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, OH, OW, channels] dimensions. +/// @param flags - binary features of the 2D Average Pooling Node. The only currently supported values is +/// XNN_FLAG_TENSORFLOW_SAME_PADDING. +enum xnn_status xnn_define_average_pooling_2d( + xnn_subgraph_t subgraph, + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Fully Connected Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the +/// @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least +/// 1D and its last dimension must match the last dimension of the filter tensor. In particular, if +/// input is a 2D tensor, it must have [batch_size, input_channels] dimensions. +/// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be +/// divisible by the input_channels. The tensor will be first flattened into a 1D tensor of +/// [num_input_elements] dimensions, then reshaped into a 2D tensor of +/// [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the +/// total number of elements in the input tensor. +/// @param filter_id - Value ID for the filter tensor. The filter tensor must a 2D tensor defined in the @a subgraph. +/// If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have +/// [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is +/// specified, the filter tensor must have [input_channels, output_channels] dimensions. +/// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Fully Connected Node without a bias. +/// If present, the bias tensor must be a 1D tensor defined in the @a subgraph with [output_channels] +/// dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph. +/// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same +/// dimensionality as the input tensor, all its dimensions but the last one must match the +/// corresponding dimensions of the input tensor, and the last dimensions of the output tensor must +/// match the first dimension of the filter tensor. In particular, if input is a 2D tensor, output +/// must be a 2D tensor of [batch_size, output_channels] dimensions. +/// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of +/// [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the +/// total number of elements in the input tensor. +/// @param flags - binary features of the Fully Connected Node. The only currently supported values are +/// XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS. +enum xnn_status xnn_define_fully_connected( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t filter_id, + uint32_t bias_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Sparse Fully Connected Node and add it to a Subgraph. +/// +/// This operator is experimental, and will be removed in the future. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the +/// @a subgraph. If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the input tensor must be at least +/// 1D and its last dimension must match the last dimension of the filter tensor. In particular, if +/// input is a 2D tensor, it must have [batch_size, input_channels] dimensions. +/// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, the number of elements in the input tensor must be +/// divisible by the input_channels. The tensor will be first flattened into a 1D tensor of +/// [num_input_elements] dimensions, then reshaped into a 2D tensor of +/// [num_input_elements / input_channels, input_channels] dimensions where num_input_elements is the +/// total number of elements in the input tensor. +/// @param filter_id - Value ID for the filter tensor. The filter tensor must a 2D tensor defined in the @a subgraph. +/// If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is not specified, the filter tensor must have +/// [output_channels, input_channels] dimensions. If the XNN_FLAG_TRANSPOSE_WEIGHTS flag is +/// specified, the filter tensor must have [input_channels, output_channels] dimensions. +/// @param bias_id - Value ID for the bias tensor, or XNN_INVALID_VALUE_ID for a Fully Connected Node without a bias. +/// If present, the bias tensor must be a 1D tensor defined in the @a subgraph with [output_channels] +/// dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph. +/// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is not specified, the output tensor must have the same +/// dimensionality as the input tensor, all its dimensions but the last one must match the +/// corresponding dimensions of the input tensor, and the last dimensions of the output tensor must +/// match the first dimension of the filter tensor. In particular, if input is a 2D tensor, output +/// must be a 2D tensor of [batch_size, output_channels] dimensions. +/// If XNN_FLAG_TENSORFLOW_RESHAPE_2D is specified, output must be a 2D tensor of +/// [num_input_elements / input_channels, output_channels] dimensions where num_input_elements is the +/// total number of elements in the input tensor. +/// @param flags - binary features of the Fully Connected Node. The only currently supported values are +/// XNN_FLAG_TENSORFLOW_RESHAPE_2D and XNN_FLAG_TRANSPOSE_WEIGHTS. +enum xnn_status xnn_define_fully_connected_sparse( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t filter_id, + uint32_t bias_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Max Pooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_padding_top - implicit zero-padding above 2D input data. Must be 0 if XNN_FLAG_TENSORFLOW_SAME_PADDING +/// flag is specified. +/// @param input_padding_right - implicit zero-padding to the right of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_bottom - implicit zero-padding below 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param input_padding_left - implicit zero-padding to the left of 2D input data. Must be 0 if +/// XNN_FLAG_TENSORFLOW_SAME_PADDING flag is specified. +/// @param pooling_height - pooling (kernel) height. +/// @param pooling_width - pooling (kernel) width. +/// @param stride_height - displacing of the pooling window in the vertical dimension of the input pixels corresponding +/// to vertically adjacent output pixels. +/// @param stride_width - displacing of the pooling window in the horizontal dimension of the input pixels corresponding +/// to horizontally adjacent output pixels. +/// @param dilation_height - dilation of pooling elements along the height dimension. +/// @param dilation_width - dilation of pooling elements along the width dimension. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, channels] dimensions +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, OH, OW, channels] dimensions. +/// @param flags - binary features of the 2D Max Pooling Node. The only currently supported values is +/// XNN_FLAG_TENSORFLOW_SAME_PADDING. +enum xnn_status xnn_define_max_pooling_2d( + xnn_subgraph_t subgraph, + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D ArgMax Pooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_padding_top - implicit zero-padding above 2D input data. +/// @param input_padding_right - implicit zero-padding to the right of 2D input data. +/// @param input_padding_bottom - implicit zero-padding below 2D input data. +/// @param input_padding_left - implicit zero-padding to the left of 2D input data. +/// @param pooling_height - pooling (kernel) height. Vertical stride between pooling regions match this value. +/// @param pooling_width - pooling (kernel) width. Horizontal stride between pooling regions match this value. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, channels] dimensions +/// @param output_value_id - Value ID for the output tensor with the maximum values in the pools. The output tensor must +/// be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels] dimensions. +/// @param output_index_id - Value ID for the output tensor with the indexes of the maximum values in the pools. The +/// output tensor must be a 4D tensor defined in the @a subgraph with [N, OH, OW, channels] +/// dimensions. +/// @param flags - binary features of the 2D ArgMax Pooling Node. No supported flags are currently defined. +enum xnn_status xnn_define_argmax_pooling_2d( + xnn_subgraph_t subgraph, + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t input_id, + uint32_t output_value_id, + uint32_t output_index_id, + uint32_t flags); + +/// Define a 2D UnPooling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param padding_top - implicit padding above 2D output data. +/// @param padding_right - implicit padding to the right of 2D output data. +/// @param padding_bottom - implicit padding below 2D output data. +/// @param padding_left - implicit padding to the left of 2D output data. +/// @param pooling_height - height of the pooling window. +/// @param pooling_width - width of the pooling window. +/// @param input_value_id - Value ID for the input tensor with the max-pooling values to invert. The input value tensor +/// must be a 4D tensor defined in the @a subgraph with [N, IH, IW, channels] dimensions. +/// @param input_index_id - Value ID for the input tensor with the indices of the per-pool maximum values produced by +/// a 2D UnPooling Node. The input tensor must be a 4D tensor defined in the @a subgraph with +/// [N, IH, IW, channels] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, OH, OW, channels] dimensions. +/// @param flags - binary features of the 2D UnPooling Node. No supported flags are currently defined. +enum xnn_status xnn_define_unpooling_2d( + xnn_subgraph_t subgraph, + uint32_t padding_top, + uint32_t padding_right, + uint32_t padding_bottom, + uint32_t padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t input_value_id, + uint32_t input_index_id, + uint32_t output_id, + uint32_t flags); + +enum xnn_binary_operator { + xnn_binary_invalid = -1, + xnn_binary_add, + xnn_binary_subtract, + xnn_binary_multiply, + xnn_binary_divide, + xnn_binary_maximum, + xnn_binary_minimum, + xnn_binary_copysign, + xnn_binary_squared_difference, + xnn_binary_prelu, + // The following operators are experimental and may be removed. + xnn_binary_modulus, + xnn_binary_atan2, + xnn_binary_pow, + xnn_binary_bitwise_and, + xnn_binary_bitwise_or, + xnn_binary_bitwise_xor, + xnn_binary_shift_left, + xnn_binary_shift_right_logical, + xnn_binary_shift_right_arithmetic, +}; + +struct xnn_binary_params { + /// lower bound for clipping output values. + double output_min; + /// upper bound for clipping output values. + double output_max; +}; + +/// Define a 2-Input binary operator Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param type - Type of operator to apply to the two inputs. +/// @param params - Optional parameters for the operator. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Node. No supported flags are currently defined. +enum xnn_status xnn_define_binary( + xnn_subgraph_t subgraph, + enum xnn_binary_operator type, + const struct xnn_binary_params* params, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2-Input Add Node and add it to a Subgraph. +/// +/// The 2-Input Add Node computes elementwise addition of two tensor inputs with numpy broadcasting rules. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Add Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_add2( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2-Input Multiply Node and add it to a Subgraph. +/// +/// The 2-Input Multiply Node computes elementwise multiplication of two tensor inputs with numpy broadcasting rules. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Multiply Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_multiply2( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +// Cap operations applied to logits (Q * K) of attention operator. +enum xnn_attention_logits_cap_type { + // No capping. + xnn_attention_logits_cap_type_none = 0, + // Cap the absolute values of logits by tanh: tanh(logits / cap) * cap + xnn_attention_logits_cap_type_tanh +}; + +// Params when the cap type is xnn_attention_logits_cap_type_tanh. +struct xnn_attention_logits_cap_tanh_params { + float cap; +}; + +/// Define a Scaled Dot-Product Attention Node and add it to a Subgraph. +/// +/// This operator is experimental. +/// +/// The Scaled Dot-Product Attention Node computes a multi-head or multi-query scaled dot attention on the query, key, +/// and value tensors. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param cap_type - type of cap to be applied to the logits. +/// @param cap_params - parameters for the cap. Must be a pointer to xnn_attention_logits_cap_tanh_params if cap_type +/// is xnn_attention_logits_cap_type_tanh. +/// @param query_id - Value ID for the query tensor. The query tensor must be a 3+-dimensional tensor defined in the +/// @a subgraph with the dimensions as [*, H, T, C], where H/T/C are the heads/tokens/channels, and * +/// is the 0 or more dimensions treated as batch size. +/// @param key_id - Value ID for the key tensor. The key tensor must be a 2+--dimensional tensor defined in the +/// @a subgraph. It can have the same number of dimensions as the query, with the dimensions as +/// [*, H, U, C] (multi-head), or have 1 less dimension than the query, with the dimensions as +/// as [*, U, C] (multi-query, number of heads omitted implies single head), where H/U/C are the +/// heads/key_value_tokens/channels, and * is the 0 or more dimensions treated as batch size. These +/// batch size dimensions must be the same as query. +/// @param value_id - Value ID for the value tensor. The value tensor must be a 2+--dimensional tensor defined in the +/// @a subgraph. It can have the same number of dimensions as the query, with the dimensions as +/// [*, H, U, D] (multi-head), or have 1 less dimension than the query, with the dimensions as +/// as [*, U, D] (multi-query, number of heads omitted implies single head), where H/U/D are the +/// heads/key_value_tokens/value_channels, and * is the 0 or more dimensions treated as batch size. +/// These batch size dimensions must be the same as query and key. +/// @param scale_id - Value ID for the scale tensor. The scale tensor must be a 1D tensor defined in the @a subgraph +/// with [C] dimensions. The query tensor is multiplied with this scale tensor before the dot product +/// with the key tensor. +/// @param mask_id - Value ID for the mask tensor. The mask tensor must be a 2D tensor defined in the @a subgraph with +/// [T, U] dimensions. The mask tensor is added to the logits (query dot value). +/// @param output_id - Value ID for the output tensor. The output tensor must be a 3+-dimensional tensor defined in the +/// @a subgraph with the dimensions as [*, H, T, D], where H/T/D are the heads/tokens/value_channels, +/// and * is the 0 or more dimensions treated as batch size. These batch size dimensions must be the +/// same as query, key, and value. +/// @param flags - binary features of the Scaled Dot Product Attention Node. No supported flags are currently defined. +enum xnn_status xnn_define_scaled_dot_product_attention( + xnn_subgraph_t subgraph, + enum xnn_attention_logits_cap_type cap_type, + const void* cap_params, + uint32_t query_id, + uint32_t key_id, + uint32_t value_id, + uint32_t scale_id, + uint32_t mask_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Subtract Node and add it to a Subgraph. +/// +/// The Subtract Node computes elementwise subtraction of two tensor inputs with numpy broadcasting rules. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Subtract Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_subtract( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Divide Node and add it to a Subgraph. +/// +/// The Divide Node computes elementwise division of two tensor inputs with numpy broadcasting rules. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Divide Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_divide( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2-Input Maximum Node and add it to a Subgraph. +/// +/// The 2-Input Maximum Node computes elementwise maximum of two tensor inputs with numpy broadcasting rules. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Maximum Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_maximum2( + xnn_subgraph_t subgraph, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2-Input Minimum Node and add it to a Subgraph. +/// +/// The 2-Input Minimum Node computes elementwise minimum of two tensor inputs with numpy broadcasting rules. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Minimum Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_minimum2( + xnn_subgraph_t subgraph, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Squared Difference Node and add it to a Subgraph. +/// +/// The Squared Difference Node computes elementwise squared difference of two tensor inputs with numpy broadcasting +/// rules. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the second +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an M-dimensional tensor defined in +/// the @a subgraph with each dimension either equal to the corresponding dimension of the first +/// input, or equal to 1. In the latter case, the elements of the input tensor are broadcasted along +/// that dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be a max(N,M)-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the maximum between the corresponding dimension +/// of the two inputs. +/// @param flags - binary features of the Squared Difference Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_squared_difference( + xnn_subgraph_t subgraph, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Constant Pad Node with static padding specification and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param pre_paddings - number of padding elements to insert before input elements for every dimension. This array +/// must have as many elements as the number of dimensions in the input tensor. +/// @param post_paddings - number of padding elements to insert after input elements for every dimension. This array +/// must have as many elements as the number of dimensions in the input tensor. +/// @param padding_value - constant value used to initialize padding elements. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor with padding. +/// @param flags - binary features of the Constant Pad Node. No supported flags are currently defined. +enum xnn_status xnn_define_static_constant_pad( + xnn_subgraph_t subgraph, + const size_t* pre_paddings, + const size_t* post_paddings, + float padding_value, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Expand Dims Node with and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param num_new_axes - number of new axes of size 1 to be inserted. +/// @param new_axes - The axis positions of the new axes in the expanded dimensions. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor with padding. +/// @param flags - binary features of the Constant Pad Node. No supported flags are currently defined. +enum xnn_status xnn_define_static_expand_dims( + xnn_subgraph_t subgraph, + size_t num_new_axes, + const size_t* new_axes, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Mean Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param num_reduction_axes - number of axes along which mean is computed. +/// @param reduction_axes - axes along which mean is computed. +/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least +/// @a num_reduction_axes dimensions defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the +/// @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor (if +/// XNN_FLAG_KEEP_DIMS is not specified), or has same dimension rank but the dimension at +/// @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is specified). +/// @param flags - binary features of the Mean Node. The only currently supported value is XNN_FLAG_KEEP_DIMS +XNN_DEPRECATED enum xnn_status xnn_define_static_mean( + xnn_subgraph_t subgraph, + size_t num_reduction_axes, + const size_t* reduction_axes, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +enum xnn_reduce_operator { + xnn_reduce_invalid = -1, + xnn_reduce_sum, + xnn_reduce_mean, +}; + +/// Define a Reduce Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param num_reduction_axes - number of axes along which reduce is computed. +/// @param reduction_axes - axes along which reduce is computed. +/// @param input_id - Value ID for the input tensor. The input tensor must be a dense tensor with at least +/// @a num_reduction_axes dimensions defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be a dense tensor defined in the +/// @a subgraph with @a num_reduction_axes fewer dimensions than the input tensor (if +/// XNN_FLAG_KEEP_DIMS is not specified), or has same dimension rank but the dimension at +/// @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is specified). +/// @param flags - binary features of the Reduce Node. The only currently supported value is XNN_FLAG_KEEP_DIMS +enum xnn_status xnn_define_static_reduce( + xnn_subgraph_t subgraph, + enum xnn_reduce_operator reduce_operator_type, + size_t num_reduction_axes, + const size_t* reduction_axes, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Reduce Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param num_reduction_axes - number of axes along which reduce is computed. +/// @param reduction_axes - axes along which reduce is computed. Negative values +/// are interpreted as offsets from @a +/// num_reduction_axes. +/// @param input_id - Value ID for the input tensor. The input tensor must be a +/// dense tensor with at least @a num_reduction_axes +/// dimensions defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be +/// a dense tensor defined in the @a subgraph with @a +/// num_reduction_axes fewer dimensions than the input tensor +/// (if XNN_FLAG_KEEP_DIMS is not specified), or has same +/// dimension rank but the dimension at +/// @a reduction_axes reduced to 1 (if XNN_FLAG_KEEP_DIMS is +/// specified). +/// @param flags - binary features of the Reduce Node. The only currently +/// supported value is XNN_FLAG_KEEP_DIMS +enum xnn_status xnn_define_static_reduce_v2( // + xnn_subgraph_t subgraph, // + enum xnn_reduce_operator reduce_operator_type, // + size_t num_reduction_axes, // + const int64_t* reduction_axes, // + uint32_t input_id, // + uint32_t output_id, // + uint32_t flags); + +/// Define a 2-Input Concatenate Node and add it to a Subgraph. +/// +/// The 2-Input Concatenate Node concatenates two tensors along a specified axis. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of +/// dimensions is added to it. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// second input. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// first input. +/// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the dimension of both inputs, except the axis +/// dimension, where it is the sum of the corresponding dimensions of both inputs. +/// @param flags - binary features of the Concatenate Node. No supported flags are currently defined. +enum xnn_status xnn_define_concatenate2( + xnn_subgraph_t subgraph, + int32_t axis, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 3-Input Concatenate Node and add it to a Subgraph. +/// +/// The 3-Input Concatenate Node concatenates three tensors along a specified axis. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of +/// dimensions is added to it. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis +/// dimension, where it is the sum of the corresponding dimensions of all inputs. +/// @param flags - binary features of the Concatenate Node. No supported flags are currently defined. +enum xnn_status xnn_define_concatenate3( + xnn_subgraph_t subgraph, + int32_t axis, + uint32_t input1_id, + uint32_t input2_id, + uint32_t input3_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 4-Input Concatenate Node and add it to a Subgraph. +/// +/// The 4-Input Concatenate Node concatenates four tensors along a specified axis. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of +/// dimensions is added to it. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis +/// dimension, where it is the sum of the corresponding dimensions of all inputs. +/// @param flags - binary features of the Concatenate Node. No supported flags are currently defined. +enum xnn_status xnn_define_concatenate4( + xnn_subgraph_t subgraph, + int32_t axis, + uint32_t input1_id, + uint32_t input2_id, + uint32_t input3_id, + uint32_t input4_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 5-Input Concatenate Node and add it to a Subgraph. +/// +/// The 5-Input Concatenate Node concatenates four tensors along a specified axis. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param axis - the axis to concatenate the two input tensors along. If this is less than zero, the number of +/// dimensions is added to it. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input3_id - Value ID for the third input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input4_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param input5_id - Value ID for the fourth input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph with each dimension, except the axis, equal to the corresponding dimension of the +/// other inputs. +/// @param output_id - Value ID for the output tensor. The output tensor must be a N-dimensional tensor defined +/// in the @a subgraph with each dimension equal to the dimension of all inputs, except the axis +/// dimension, where it is the sum of the corresponding dimensions of all inputs. +enum xnn_status xnn_define_concatenate5( + xnn_subgraph_t subgraph, + int32_t axis, + uint32_t input1_id, + uint32_t input2_id, + uint32_t input3_id, + uint32_t input4_id, + uint32_t input5_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Copy Sign Node and add it to a Subgraph. +/// +/// The Copy Sign Node copies the sign of the second input to the first input. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be defined in the @a subgraph. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. +/// @param flags - binary features of the Copy Sign Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_copysign( + xnn_subgraph_t subgraph, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Copy Node and add it to a Subgraph. +/// +/// The Copy Node copies an input tensor to an output tensor. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the first input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Copy Node. No supported flags are currently defined. +enum xnn_status xnn_define_copy( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2-Output Split Node and add it to a Subgraph. +/// +/// The 2-Output Split Node splits an input tensor into two output tensors along a specified axis evenly. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of +/// dimensions is added to it. +/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a +/// subgraph. +/// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined +/// in the @a subgraph with each dimension, except the axis, equal to the corresponding dimension +/// of the second output. The split_dim dimension is half of the input's split_dim. +/// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor +/// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding +/// dimension of the first output. The split_dim dimension is half of the input's split_dim. +/// @param flags - binary features of the Split Node. No supported flags are currently defined. +enum xnn_status xnn_define_even_split2( + xnn_subgraph_t subgraph, + int32_t split_dim, + uint32_t input_id, + uint32_t output1_id, + uint32_t output2_id, + uint32_t flags); + +/// Define a 3-Output Split Node and add it to a Subgraph. +/// +/// The 3-Output Split Node splits an input tensor into three output tensors along a specified axis evenly. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of +/// dimensions is added to it. +/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a +/// subgraph. +/// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined +/// in the @a subgraph with each dimension, except the axis, equal to the corresponding dimension +/// of the second and third output. The split_dim dimension is one third of the input's split_dim. +/// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor +/// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding +/// dimension of the first and third output. The split_dim dimension is one third of the input's +/// split_dim. +/// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor +/// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding +/// dimension of the second and third output. The split_dim dimension is one third of the input's +/// split_dim. +/// @param flags - binary features of the Split Node. No supported flags are currently defined. +enum xnn_status xnn_define_even_split3( + xnn_subgraph_t subgraph, + int32_t split_dim, + uint32_t input_id, + uint32_t output1_id, + uint32_t output2_id, + uint32_t output3_id, + uint32_t flags); + +/// Define a 4-Output Split Node and add it to a Subgraph. +/// +/// The 4-Output Split Node splits an input tensor into four output tensors along a specified axis evenly. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param split_dim - the dimension to split the input tensor along. If this is less than zero, the number of +/// dimensions is added to it. +/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in the @a +/// subgraph. +/// @param output1_id - Value ID for the first output tensor. The output tensor must be an N-dimensional tensor defined +/// in the @a subgraph with each dimension, except the axis, equal to the corresponding dimension +/// of the other output tensors. The split_dim dimension is one fourth of the input's split_dim. +/// @param output2_id - Value ID for the second output tensor. The output tensor must be an N-dimensional tensor +/// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding +/// dimension of the other output tensors. The split_dim dimension is one fourth of the input's +/// split_dim. +/// @param output3_id - Value ID for the third output tensor. The output tensor must be an N-dimensional tensor +/// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding +/// dimension of the other output tensors. The split_dim dimension is one fourth of the input's +/// split_dim. +/// @param output4_id - Value ID for the fourth output tensor. The output tensor must be an N-dimensional tensor +/// defined in the @a subgraph with each dimension, except the axis, equal to the corresponding +/// dimension of the other output tensors. The split_dim dimension is one fourth of the input's +/// split_dim. +/// @param flags - binary features of the Split Node. No supported flags are currently defined. +enum xnn_status xnn_define_even_split4( + xnn_subgraph_t subgraph, + int32_t split_dim, + uint32_t input_id, + uint32_t output1_id, + uint32_t output2_id, + uint32_t output3_id, + uint32_t output4_id, + uint32_t flags); + +/// Define a Reshape Node with static shape specification and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param num_dims - number of shape dimensions in the output tensor. +/// @param new_shape - shape dimensions of the output tensor. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor with padding. +/// @param flags - binary features of the Reshape Node. No supported flags are currently defined. +enum xnn_status xnn_define_static_reshape( + xnn_subgraph_t subgraph, + size_t num_dims, + const size_t* new_shape, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a 2D Resize Bilinear Node with static output height & width specification and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param new_height - height dimension of the output tensor. +/// @param new_width - width dimension of the output tensor. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, H, W, C] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, new_height, new_width, C] dimensions. +/// @param flags - binary features of the 2D Resize Bilinear Node. The only currently supported values are +/// XNN_FLAG_TENSORFLOW_LEGACY_MODE and XNN_FLAG_ALIGN_CORNERS, which are mutually exclusive. +enum xnn_status xnn_define_static_resize_bilinear_2d( + xnn_subgraph_t subgraph, + size_t new_height, + size_t new_width, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a PReLU (Parametric ReLU) Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, H, W, channels] dimensions. +/// @param slope_id - Value ID for the slope tensor. The slope tensor must be a 1D tensor defined in the @a subgraph with +/// either [1] or [channels] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, H, W, channels] dimensions. +/// @param flags - binary features of the PReLU Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_prelu( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t slope_id, + uint32_t output_id, + uint32_t flags); + +/// Define a RoPE (Rotary Positional Embeddings) Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param max_tokens - deprecated. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [batch, tokens, heads, channels] dimensions. +/// @param weights_id - Value ID for the weights tensor. The weights tensor must be a 2D tensor defined in the +/// @a subgraph with [max_tokens, channels] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [batch, tokens, heads, channels] dimensions. +/// @param flags - binary features of the RoPE Node. No supported flags are currently defined. +enum xnn_status xnn_define_rope( + xnn_subgraph_t subgraph, + size_t max_sequence_size, + uint32_t input_id, + uint32_t weights_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Abs Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Abs Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_abs( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Bankers' Rounding Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Bankers' Rounding Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_bankers_rounding( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Batch Matrix Multiply Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input1_id - Value ID for the first input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the second input +/// tensor. The last 2 dimensions are [M, K]. If XNN_FLAG_TRANSPOSE_B is not specified, the last +/// dimension must match the second last dimension of the second input tensor. If +/// XNN_FLAG_TRANSPOSE_B is specified, the last dimension must match the last dimension of the +/// second input tensor. +/// @param input2_id - Value ID for the second input tensor. The input tensor must be an N-dimensional tensor defined +/// in the @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first input +/// tensor. If XNN_FLAG_TRANSPOSE_B is not specified, the last 2 dimensions are [K, N], and the +/// second last dimension must match the last dimension of the first input tensor. If +/// XNN_FLAG_TRANSPOSE_B is specified, the last 2 dimensions are [N, K], and the last dimension must +/// match the last dimension of the first input tensor. +/// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined in the +/// @a subgraph. It must be at least 3D. The first N-2 dimensions must match the first and second +/// input tensors . The last 2 dimensions must be [M, N]. +/// @param flags - binary features of the Batch Matrix Multiply Node. The only currently supported value is +/// XNN_FLAG_TRANSPOSE_B. +enum xnn_status xnn_define_batch_matrix_multiply( + xnn_subgraph_t subgraph, + uint32_t input1_id, + uint32_t input2_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Ceiling Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Ceiling Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_ceiling( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Clamp Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param output_min - lower bound for clipping output values. +/// @param output_max - upper bound for clipping output values. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Clamp Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_clamp( + xnn_subgraph_t subgraph, + float output_min, + float output_max, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define an ELU (Exponential Linear Unit) Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param alpha - scale factor for negative output elements. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the ELU Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_elu( + xnn_subgraph_t subgraph, + float alpha, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Exp Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Exp Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_exp( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Floor Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Floor Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_floor( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define an GELU (Gaussian Error Linear Unit) Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the GELU Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_gelu( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a HardSwish Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the HardSwish Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_hardswish( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Leaky ReLU Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param negative_slope - scale factor for negative input elements. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Leaky ReLU Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_leaky_relu( + xnn_subgraph_t subgraph, + float negative_slope, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Log Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Log Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_log( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Negate Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Negate Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_negate( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Sigmoid Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Sigmoid Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_sigmoid( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a SoftMax Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph, and have at +/// least one dimension. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the SoftMax Node. No supported flags are currently defined. +enum xnn_status xnn_define_softmax( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Space To Depth 2D Node and add it to a Subgraph. +/// +/// The Space To Depth 2D Node rearranges blocks of spatial data into blocks (a reverse transform to Depth To Space 2D). +/// For a given input pixel, an output square of pixels with side @a block_size is formed from values in the +/// corresponding number of its channels. The output depth is therefore @a block_size x @a block_size times greater +/// than that of the input. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param block_size - the size of the spatial block. +/// @param input_id - Value ID for the input tensor. The input tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH * block_size, IW * block_size, OC] dimensions. +/// @param output_id - Value ID for the output tensor. The output tensor must be a 4D tensor defined in the @a subgraph +/// with [N, IH, IW, OC * block_size * block_size] dimensions. +/// @param flags - binary features of the input_channels Node. No supported flags are currently defined. +enum xnn_status xnn_define_space_to_depth_2d( + xnn_subgraph_t subgraph, + uint32_t block_size, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Square Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Square Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_square( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Square Root Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Square Root Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_square_root( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Reciprocal Square Root Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be +/// defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be +/// defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Square Root Node. No supported flags +/// are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_reciprocal_square_root( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +enum xnn_status xnn_define_static_slice( + xnn_subgraph_t subgraph, + size_t num_dims, + const size_t* offsets, + const size_t* sizes, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); +/// Define a Static Slice Node add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param num_dims - number of shape dimensions in the input and output tensor. +/// @param offsets - offsets in each dimension of the input tensor. This array must have @a num_dims elements. Can be +/// negative meaning that the offset is relative to the end of the dimension. +/// @param sizes - size of each dimension in output tensor. This array must have @a num_dims elements. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// dimensions must match @a sizes. +/// @param flags - binary features of the Static Slice Node. No supported flags are currently defined. +enum xnn_status xnn_define_static_slice_v2( // + xnn_subgraph_t subgraph, // + size_t num_dims, // + const int64_t* offsets, // + const size_t* sizes, // + uint32_t input_id, // + uint32_t output_id, // + uint32_t flags); + +/// Define a Static Transpose Node and add it to a Subgraph. +/// +/// The Static Transpose Node applies a generalized transpose to the input tensor using the permuation in perm. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be an N-dimensional tensor defined in +/// the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be an N-dimensional tensor defined +/// in the @a subgraph with each dimension equal to its corresponding permuted input dimension. +/// @param num_dims - the number of permutation dimensions. This must be equal to the number of input dimensions. +/// @param perm - The permutation of the axis of the input tensor. The perm array must must contain 0 to N-1 in the +/// permuted order. +/// @param flags - binary features of the Static Transpose Node. No supported flags are currently defined. +enum xnn_status xnn_define_static_transpose( + xnn_subgraph_t subgraph, + size_t num_dims, + const size_t* perm, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Define a Tanh Node and add it to a Subgraph. +/// +/// @param subgraph - a Subgraph object that will own the created Node. +/// @param input_id - Value ID for the input tensor. The input tensor must be defined in the @a subgraph. +/// @param output_id - Value ID for the output tensor. The output tensor must be defined in the @a subgraph, and its +/// shape must match the shape of the input tensor. +/// @param flags - binary features of the Tanh Node. No supported flags are currently defined. +XNN_DEPRECATED enum xnn_status xnn_define_tanh( + xnn_subgraph_t subgraph, + uint32_t input_id, + uint32_t output_id, + uint32_t flags); + +/// Code cache is a cache for JIT generated code. +typedef struct xnn_code_cache* xnn_code_cache_t; + +/// Weights cache can be finalized in these ways: +enum xnn_weights_cache_finalization_kind { + /// Weights cache is finalized, no insert operations into the weights cache is allowed, even if the "inserted" + /// weights already exist in thee cache. Weights cache memory will also be trimmed to page boundary and set to + /// read-only (to prevent writes). + xnn_weights_cache_finalization_kind_hard, + /// Weights cache will be finalized with some extra space at the end, this allows for "inserting" into the cache only + /// if the weights are already in the cache, and errors on inserting uncached weights. There is memory overhead. + xnn_weights_cache_finalization_kind_soft, +}; + +/// A combination of multiple factors to uniquely locate the weights cache. +struct xnn_weights_cache_look_up_key { + /// The unique seed for each ukernel. It is guaranteed that each ukernel provides + /// a consistent and identical seed. + uint32_t seed; + /// Pointer to the original kernel. + const void* kernel; + /// Pointer to the original bias, could be NULL. + const void* bias; +}; + +/// A group of function pointers to manage weights cache. All functions may be +/// called on multi threads. +struct xnn_weights_cache_provider { + /// User-specified pointer that will be passed as-is to all functions in this + /// structure. + void* context; + + /// Looks up the tuple of {cache_key, kernel, bias} in the cache. If it is found, + /// returns the offset to the found entry for reuse. Otherwise, returns SIZE_MAX. + /// @param context - The user-specified pointer from xnn_weights_cache_provider structure. + /// @param cache_key - The key used to locate the weights cache entry. + size_t (*look_up)(void* context, const struct xnn_weights_cache_look_up_key* cache_key); + + /// Ensures that cache has enough space for `n` bytes. Returns the address to + /// store weight cache. Returns NULL if fails to reserve space. + /// @param context - The user-specified pointer from xnn_weights_cache_provider structure. + /// @param n - size to be reserved. + void* (*reserve_space)(void* context, size_t n); + + /// Looks up packed weights at `ptr` in the cache. If it is found, reuse it. + /// Otherwise, it is added to the cache. Returns the offset to the cache. + /// @param context - The user-specified pointer from xnn_weights_cache_provider structure. + /// @param cache_key - The key used to locate the weights cache entry. + /// @param ptr - pointer pointing to the packed weight. + /// @param size - size of the packed weight. + size_t (*look_up_or_insert)(void* context, const struct xnn_weights_cache_look_up_key* cache_key, void* ptr, size_t size); + + /// Returns whether the cache is finalized. + /// @param context - The user-specified pointer from xnn_weights_cache_provider structure. + bool (*is_finalized)(void* context); + + /// Returns the absolute pointer corresponding to `offset`, where the offset is returned from + /// `look_up` or `get_or_insert`. This function must be called after finalize. + /// @param context - The user-specified pointer from xnn_weights_cache_provider structure. + /// @param offset - offset to the start of internal buffer + void* (*offset_to_addr)(void* context, size_t offset); + + /// Destroy a weights cache object, as well as memory used for the cache. + /// @param context - The user-specified pointer from xnn_weights_cache_provider structure. + enum xnn_status (*delete_cache)(void* context); +}; + +/// Weights cache is a cache for packed weights. It can be reused between runtimes. +typedef struct xnn_weights_cache_provider* xnn_weights_cache_t; + +/// Create a weights cache object specifying the initial size of weights cache (in bytes). +/// +/// @param[in] size - initial capacity of the weights cache (in bytes), i.e. it can hold size bytes without growing. +/// @param weights_cache_out - pointer to the variable that will be initialized to a handle to the weights cache provider +/// upon successful return. Once created, the weights cache provider can be shared between +/// different Runtime objects. +enum xnn_status xnn_create_weights_cache_with_size(size_t size, xnn_weights_cache_t* weights_cache_out); + +enum xnn_status xnn_create_weights_cache(xnn_weights_cache_t* weights_cache_out); + +/// Finalizes the weights cache. The kind of finalization is specified by `finalization_kind`. +/// @param weights_cache - the weights cache object to finalize. +/// @param finalization_kind - the kind of finalization. +enum xnn_status xnn_finalize_weights_cache( + xnn_weights_cache_t weights_cache, + enum xnn_weights_cache_finalization_kind finalization_kind); + +// Wrapper function of the function pointers in `xnn_weights_cache_t`. +bool xnn_weights_cache_is_finalized(xnn_weights_cache_t cache); + +/// Destroy a weights cache object, as well as memory used for the cache. +/// @param weights_cache - the weights cache object to destroy. +enum xnn_status xnn_delete_weights_cache(xnn_weights_cache_t weights_cache); + +typedef struct xnn_workspace* xnn_workspace_t; + +/// Create a workspace object. +/// @param workspace_out - pointer to the variable that will be initialized to a handle to the workspace object upon +/// successful return. Once created, the workspace can be shared between different Runtime +/// objects. +enum xnn_status xnn_create_workspace(xnn_workspace_t* workspace_out); +/// Destroy a workspace object, as well as memory used by the workspace. Object destruction can be deferred until all +/// Runtime objects created with this workspace are destroyed. +/// @param workspace - the workspace object to destroy. +enum xnn_status xnn_release_workspace(xnn_workspace_t workspace); + +/// Runtime is a combination of an execution plan for subgraph Nodes and a memory manager for subgraph Values. +typedef struct xnn_runtime* xnn_runtime_t; + +enum xnn_profile_info { + /// Returns a size_t containing the number of operators. + xnn_profile_info_num_operators, + /// Returns a char[] containing the null character separated names of all operators. + xnn_profile_info_operator_name, + /// Returns a uint64_t[] with the runtimes of all operators in the same order as xnn_profile_info_operator_name. + xnn_profile_info_operator_timing, +}; + +/// Return profile information for all operators. +/// +/// @param runtime - a Runtime object created with @ref xnn_create_runtime, @ref xnn_create_runtime_v2 or +/// @ref xnn_create_runtime_v3. +/// @param param_name - type of profile information required. +/// @param param_value_size - the size in bytes of memory pointed to by param_value. If this is not sufficient then +/// param_value_size_ret will be set to the required size and xnn_status_out_of_memory will be +/// returned. +/// @param param_value - a pointer to memory location where appropriate values for a given param_value will be written. +/// @param param_value_size_ret - returns number of bytes required to write the result if param_value_size is not +/// sufficient. +enum xnn_status xnn_get_runtime_profiling_info(xnn_runtime_t runtime, + enum xnn_profile_info param_name, + size_t param_value_size, + void* param_value, + size_t* param_value_size_ret); + +/// Create a Runtime object from a subgraph. +/// +/// @param subgraph - a Subgraph object with all Values and Nodes that would be handled by the runtime. No Values or +/// Nodes can be added to the runtime once it is constructed. +/// @param weights_cache - a cache for packed weights. The runtime will look up and reuse packed weights in this cache, +/// this will reduce memory allocated for packed weights. +/// @param workspace - a workspace to hold internal tensors. The runtime will allocate space used for internal tensors +/// and track them using workspace. Workspace can be shared and reused across different runtimes. If +/// workspace is NULL, there will be no sharing: each runtime has its own workspace. +/// @param threadpool - the thread pool to be used for parallelisation of computations in the runtime. If the thread +/// pool is NULL, the computation would run on the caller thread without parallelization. +/// @param flags - binary features of the runtime. The only currently supported values are +/// XNN_FLAG_HINT_SPARSE_INFERENCE, XNN_FLAG_HINT_FP16_INFERENCE, XNN_FLAG_FORCE_FP16_INFERENCE, +/// XNN_FLAG_YIELD_WORKERS, and XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER. If XNN_FLAG_YIELD_WORKERS is +/// specified, worker threads would be yielded to the system scheduler after processing the last operator +/// in the Runtime. If XNN_FLAG_TRANSIENT_INDIRECTION_BUFFER is specified, convolution operators will +/// initialize indirection buffers on each inference run using temporary memory in the workspace, instead +/// of initializing persistent indirection buffers once. +/// @param runtime_out - pointer to the variable that will be initialized with a handle to the Runtime object upon +/// successful return. Once constructed, the Runtime object is independent of the Subgraph object +/// used to create it. +enum xnn_status xnn_create_runtime_v4( + xnn_subgraph_t subgraph, + xnn_weights_cache_t weights_cache, + xnn_workspace_t workspace, + pthreadpool_t threadpool, + uint32_t flags, + xnn_runtime_t* runtime_out); + +enum xnn_status xnn_create_runtime_v3( + xnn_subgraph_t subgraph, + xnn_weights_cache_t weights_cache, + pthreadpool_t threadpool, + uint32_t flags, + xnn_runtime_t* runtime_out); + +enum xnn_status xnn_create_runtime_v2( + xnn_subgraph_t subgraph, + pthreadpool_t threadpool, + uint32_t flags, + xnn_runtime_t* runtime_out); + +enum xnn_status xnn_create_runtime( + xnn_subgraph_t subgraph, + xnn_runtime_t* runtime_out); + +struct xnn_external_value { + uint32_t id; + void* data; +}; + +/// Reshape an external value. +/// +/// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on +/// the Subgraph creation. If the external ID is XNN_INVALID_VALUE_ID, an internal ID will be +/// created for the Value. +/// @param num_dims - number of dimensions in the shape. +/// @param dims - pointer to an array of @a num_dims shape dimensions. If num_dims is 0, this pointer can be NULL. +/// XNNPACK does not keep any pointers to this array after the function returns. +enum xnn_status xnn_reshape_external_value( + xnn_runtime_t runtime, + uint32_t external_id, + size_t num_dims, + const size_t* dims); + +/// Get the external value shape. +/// +/// @param external_id - external ID for the Value. The ID must be within the range of reversed Value IDs specified on +/// the Subgraph creation. The external ID can not be XNN_INVALID_VALUE_ID. +/// @param num_dims - A valid pointer into which the number of dimensions in the shape will be written. It can not be larger than XNN_MAX_TENSOR_DIMS. +/// @param dims - pointer to an array of @a num_dims shape dimensions. This pointer can't be NULL. It must be large enough to hold +/// at least @a num_dims elements. XNNPACK does not keep any pointers to this array after the function returns. +enum xnn_status xnn_get_external_value_shape( + xnn_runtime_t runtime, + uint32_t external_id, + size_t* num_dims, + size_t* dims); + +/// Reshape the XNNPACK runtime. +/// +/// Propagates the shapes of input tensors through the graph to determine the shapes of intermediate and output tensors. +/// Memory is allocated if required. Output tensor shapes are returned by xnn_get_external_value_shape. +/// +/// @param runtime - a Runtime object created with @ref xnn_create_runtime or @ref xnn_create_runtime_v2. +enum xnn_status xnn_reshape_runtime( + xnn_runtime_t runtime); + +/// Deprecated. Use xnn_reshape_runtime and xnn_setup_runtime_v2. +/// +/// Setup data pointers for external inputs and outputs in a Runtime object and +/// allocate memory. +/// +/// @param runtime - a Runtime object created with @ref xnn_create_runtime or @ref xnn_create_runtime_v2. +/// @param num_external_values - the number of external inputs and outputs specified in this call. This number must +/// match the number of external inputs and outputs in the runtime, i.e. all external +/// inputs and outputs in the runtime must be specified in one call. +/// @param external_values - array with location information for all external inputs and outputs in the runtime. +enum xnn_status xnn_setup_runtime( + xnn_runtime_t runtime, + size_t num_external_values, + const struct xnn_external_value* external_values); + +/// Setup data pointers for external inputs and outputs in a Runtime object. +/// Should be called after xnn_reshape_runtime. +/// +/// @param runtime - a Runtime object created with @ref xnn_create_runtime or @ref xnn_create_runtime_v2. +/// @param num_external_values - the number of external inputs and outputs specified in this call. This number must +/// match the number of external inputs and outputs in the runtime, i.e. all external +/// inputs and outputs in the runtime must be specified in one call. +/// @param external_values - array with location information for all external inputs and outputs in the runtime. +enum xnn_status xnn_setup_runtime_v2( + xnn_runtime_t runtime, + size_t num_external_values, + const struct xnn_external_value* external_values); + +/// Execute forward pass for all operators in the runtime. +/// +/// @param runtime - the Runtime object with the execution plan to invoke. +enum xnn_status xnn_invoke_runtime( + xnn_runtime_t runtime); + +/// Destroy a Runtime object, as well as operators and memory associated with it. +/// +/// @param runtime - the Runtime object to destroy. +enum xnn_status xnn_delete_runtime( + xnn_runtime_t runtime); + +typedef struct xnn_operator* xnn_operator_t; + +enum xnn_status xnn_run_operator( + xnn_operator_t op, + pthreadpool_t threadpool); + +enum xnn_status xnn_delete_operator( + xnn_operator_t op); + +/// Operator API: +/// - create operator will create and populate a xnn_operator_t +/// - reshape operator will update fields in xnn_operator_t with shape/dimensions and parallelization information +/// - setup operator will update pointers to input and outputs +/// Each supported operator must have a create, reshape, and setup function. (Optionally a run function.) +/// Operators listed below are in alphabetical order by operator name; within each operator, we sort alphabetically by +/// data layout and type. We also group create, reshape, setup (and optionally run) functions of each operator together. + +enum xnn_status xnn_create_binary_elementwise_nd( + enum xnn_binary_operator type, + enum xnn_datatype datatype, + const struct xnn_quantization_params* input1_quantization, + const struct xnn_quantization_params* input2_quantization, + const struct xnn_quantization_params* output_quantization, + uint32_t flags, + xnn_operator_t* binary_op_out); + +enum xnn_status xnn_reshape_binary_elementwise_nd( + xnn_operator_t binary_op, + size_t num_input1_dims, + const size_t* input1_shape, + size_t num_input2_dims, + const size_t* input2_shape, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_binary_elementwise_nd( + xnn_operator_t binary_op, + const void* input1, + const void* input2, + void* output); + +enum xnn_status xnn_run_binary_elementwise_nd( + enum xnn_binary_operator type, + enum xnn_datatype datatype, + const struct xnn_quantization_params* input1_quantization, + const struct xnn_quantization_params* input2_quantization, + const struct xnn_quantization_params* output_quantization, + uint32_t flags, + size_t num_input1_dims, + const size_t* input1_shape, + size_t num_input2_dims, + const size_t* input2_shape, + const void* input1, + const void* input2, + void* output, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_unary_elementwise_nc( + enum xnn_unary_operator op_type, + enum xnn_datatype input_datatype, + enum xnn_datatype output_datatype, + const union xnn_unary_params* params, + const struct xnn_quantization_params* input_quantization, + const struct xnn_quantization_params* output_quantization, + uint32_t flags, + xnn_operator_t* op_out); + +enum xnn_status xnn_reshape_unary_elementwise_nc( + xnn_operator_t op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_unary_elementwise_nc( + xnn_operator_t op, + const void* input, + void* output); + +enum xnn_status xnn_run_unary_elementwise_nc( + // create parameters + enum xnn_unary_operator op_type, + enum xnn_datatype input_datatype, + enum xnn_datatype output_datatype, + const union xnn_unary_params* params, + const struct xnn_quantization_params* input_quantization, + const struct xnn_quantization_params* output_quantization, + uint32_t flags, + // reshape parameters + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool, + // setup parameters + const void* input, + void* output); + +enum xnn_status xnn_create_argmax_pooling2d_nhwc_f32( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t flags, + xnn_operator_t* argmax_pooling_op_out); + +enum xnn_status xnn_reshape_argmax_pooling2d_nhwc_f32( + xnn_operator_t argmax_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_argmax_pooling2d_nhwc_f32( + xnn_operator_t argmax_pooling_op, + void* workspace, + const float* input, + float* output, + uint32_t* index); + +enum xnn_status xnn_create_average_pooling2d_nhwc_f16( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + float output_min, + float output_max, + uint32_t flags, + xnn_operator_t* average_pooling_op_out); + +enum xnn_status xnn_reshape_average_pooling2d_nhwc_f16( + xnn_operator_t average_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_average_pooling2d_nhwc_f16( + xnn_operator_t average_pooling_op, + void* workspace, + const void* input, + void* output); + +enum xnn_status xnn_create_average_pooling2d_nhwc_f32( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + float output_min, + float output_max, + uint32_t flags, + xnn_operator_t* average_pooling_op_out); + +enum xnn_status xnn_reshape_average_pooling2d_nhwc_f32( + xnn_operator_t average_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_average_pooling2d_nhwc_f32( + xnn_operator_t average_pooling_op, + void* workspace, + const float* input, + float* output); + +enum xnn_status xnn_create_average_pooling2d_nhwc_qu8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint8_t input_zero_point, + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + xnn_operator_t* average_pooling_op_out); + +enum xnn_status xnn_reshape_average_pooling2d_nhwc_qu8( + xnn_operator_t average_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_average_pooling2d_nhwc_qu8( + xnn_operator_t average_pooling_op, + void* workspace, + const uint8_t* input, + uint8_t* output); + +enum xnn_status xnn_create_batch_matrix_multiply_nc_f16( + uint32_t flags, + xnn_operator_t* batch_matrix_multiply_op); + +enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f16( + xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims, + const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k, + size_t n, size_t* workspace_size, size_t* workspace_alignment, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_batch_matrix_multiply_nc_f16( + xnn_operator_t batch_matrix_multiply_op, void* workspace, + const void* input_a, const void* input_b, void* output); + +enum xnn_status xnn_create_batch_matrix_multiply_nc_f32( + uint32_t flags, xnn_operator_t* batch_matrix_multiply_op); + +enum xnn_status xnn_create_batch_matrix_multiply_nc_f32_const_weights( + size_t batch_size_b, size_t k, size_t n, const float* data_b, + uint32_t flags, xnn_operator_t* batch_matrix_multiply_op); + +enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f32( + xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims, + const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k, + size_t n, size_t* workspace_size, size_t* workspace_alignment, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_batch_matrix_multiply_nc_f32( + xnn_operator_t batch_matrix_multiply_op, void* workspace, + const float* input_a, const float* input_b, float* output); + +enum xnn_status xnn_create_batch_matrix_multiply_nc_qd8_f32_qc8w( + size_t batch_size_b, size_t k, size_t n, const int8_t* data_b, + const float* scale_b, uint32_t flags, + xnn_operator_t* batch_matrix_multiply_op); + +enum xnn_status xnn_reshape_batch_matrix_multiply_nc_qd8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, size_t num_batch_dims, + const size_t* batch_dims_a, const size_t* batch_dims_b, size_t m, size_t k, + size_t n, pthreadpool_t threadpool); + +enum xnn_status xnn_setup_batch_matrix_multiply_nc_qd8_f32_qc8w( + xnn_operator_t batch_matrix_multiply_op, const int8_t* input_a, + const struct xnn_quantization_params* quantization_params, + float* output); + +enum xnn_status xnn_create_channel_shuffle_nc_x8( + size_t groups, + size_t group_channels, + size_t input_stride, + size_t output_stride, + uint32_t flags, + xnn_operator_t* channel_shuffle_op_out); + +enum xnn_status xnn_reshape_channel_shuffle_nc_x8( + xnn_operator_t channel_shuffle_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_channel_shuffle_nc_x8( + xnn_operator_t channel_shuffle_op, + const void* input, + void* output); + +enum xnn_status xnn_create_channel_shuffle_nc_x32( + size_t groups, + size_t group_channels, + size_t input_stride, + size_t output_stride, + uint32_t flags, + xnn_operator_t* channel_shuffle_op_out); + +enum xnn_status xnn_reshape_channel_shuffle_nc_x32( + xnn_operator_t channel_shuffle_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_channel_shuffle_nc_x32( + xnn_operator_t channel_shuffle_op, + const void* input, + void* output); + +enum xnn_status xnn_create_constant_pad_nd_x8( + const void* padding_value, + uint32_t flags, + xnn_operator_t* constant_pad_op_out); + +enum xnn_status xnn_reshape_constant_pad_nd_x8( + xnn_operator_t constant_pad_op, + size_t num_dims, + const size_t* input_shape, + const size_t* pre_padding, + const size_t* post_padding, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_constant_pad_nd_x8( + xnn_operator_t constant_pad_op, + const void* input, + void* output); + +enum xnn_status xnn_run_constant_pad_nd_x8( + uint32_t flags, + size_t num_dims, + const size_t* input_shape, + const size_t* pre_paddings, + const size_t* post_paddings, + const void* input, + void* output, + const void* padding_value, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_constant_pad_nd_x16( + const void* padding_value, + uint32_t flags, + xnn_operator_t* constant_pad_op_out); + +enum xnn_status xnn_reshape_constant_pad_nd_x16( + xnn_operator_t constant_pad_op, + size_t num_dims, + const size_t* input_shape, + const size_t* pre_padding, + const size_t* post_padding, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_constant_pad_nd_x16( + xnn_operator_t constant_pad_op, + const void* input, + void* output); + +enum xnn_status xnn_run_constant_pad_nd_x16( + uint32_t flags, + size_t num_dims, + const size_t* input_shape, + const size_t* pre_paddings, + const size_t* post_paddings, + const void* input, + void* output, + const void* padding_value, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_constant_pad_nd_x32( + const void* padding_value, + uint32_t flags, + xnn_operator_t* constant_pad_op_out); + +enum xnn_status xnn_reshape_constant_pad_nd_x32( + xnn_operator_t constant_pad_op, + size_t num_dims, + const size_t* input_shape, + const size_t* pre_padding, + const size_t* post_padding, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_constant_pad_nd_x32( + xnn_operator_t constant_pad_op, + const void* input, + void* output); + +enum xnn_status xnn_run_constant_pad_nd_x32( + uint32_t flags, + size_t num_dims, + const size_t* input_shape, + const size_t* pre_paddings, + const size_t* post_paddings, + const void* input, + void* output, + const void* padding_value, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_convert_nc_f16_qd8( + uint32_t flags, + xnn_operator_t* convert_op_out); + +enum xnn_status xnn_reshape_convert_nc_f16_qd8( + xnn_operator_t convert_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool); + +// quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries. +enum xnn_status xnn_setup_convert_nc_f16_qd8( + xnn_operator_t convert_op, + const void* input, + int8_t* output, + struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_convert_nc_f32_qd8( + uint32_t flags, + xnn_operator_t* convert_op_out); + +enum xnn_status xnn_reshape_convert_nc_f32_qd8( + xnn_operator_t convert_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool); + +// quantization_params must be padded with at least XNN_EXTRA_QUANTIZATION_PARAMS entries. +enum xnn_status xnn_setup_convert_nc_f32_qd8( + xnn_operator_t convert_op, + const float* input, + int8_t* output, + struct xnn_quantization_params* quantization_params); + +XNN_DEPRECATED enum xnn_status xnn_run_convert_nc_f32_f16( + size_t channels, + size_t input_stride, + size_t output_stride, + size_t batch_size, + const float* input, + void* output, + uint32_t flags, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_convolution2d_nchw_f16( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const void* kernel, + const void* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nchw_f16( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nchw_f16( + xnn_operator_t convolution_op, + const void* input, + void* output); + +enum xnn_status xnn_create_convolution2d_nchw_f32( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const float* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nchw_f32( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nchw_f32( + xnn_operator_t convolution_op, + const float* input, + float* output); + +enum xnn_status xnn_create_convolution2d_nhwc_f16( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const void* kernel, + const void* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nhwc_f16( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nhwc_f16( + xnn_operator_t convolution_op, + void* workspace, + const void* input, + void* output); + +enum xnn_status xnn_create_convolution2d_nhwc_f32( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const float* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_create_convolution2d_nhwc_f32_f16( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const void* kernel, + const void* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +// Forward declare. +struct xnn_post_operation; + +/// Deprecated +enum xnn_status xnn_create_fused_convolution2d_nhwc_f32( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + const float* kernel, + const float* bias, + size_t num_post_operations, + struct xnn_post_operation* post_operations, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nhwc_f32( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nhwc_f32( + xnn_operator_t convolution_op, + void* workspace, + const float* input, + float* output); + +enum xnn_status xnn_create_convolution2d_nhwc_qd8_f16_qc8w( + uint32_t input_padding_top, uint32_t input_padding_right, + uint32_t input_padding_bottom, uint32_t input_padding_left, + uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height, + uint32_t subsampling_width, uint32_t dilation_height, + uint32_t dilation_width, uint32_t groups, size_t group_input_channels, + size_t group_output_channels, size_t input_channel_stride, + size_t output_channel_stride, const float* kernel_scale, + const int8_t* kernel, const float* bias, float output_min, float output_max, + uint32_t flags, xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_create_convolution2d_nhwc_qd8_f32_qc8w( + uint32_t input_padding_top, uint32_t input_padding_right, + uint32_t input_padding_bottom, uint32_t input_padding_left, + uint32_t kernel_height, uint32_t kernel_width, uint32_t subsampling_height, + uint32_t subsampling_width, uint32_t dilation_height, + uint32_t dilation_width, uint32_t groups, size_t group_input_channels, + size_t group_output_channels, size_t input_channel_stride, + size_t output_channel_stride, const float* kernel_scale, + const int8_t* kernel, const float* bias, float output_min, float output_max, + uint32_t flags, xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_create_convolution2d_nhwc_qs8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + int8_t input_zero_point, + float input_scale, + float kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f16_qc8w( + xnn_operator_t convolution_op, size_t batch_size, size_t input_height, + size_t input_width, size_t* workspace_size, size_t* workspace_alignment, + size_t* output_height_out, size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_reshape_convolution2d_nhwc_qd8_f32_qc8w( + xnn_operator_t convolution_op, size_t batch_size, size_t input_height, + size_t input_width, size_t* workspace_size, size_t* workspace_alignment, + size_t* output_height_out, size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_reshape_convolution2d_nhwc_qs8( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f16_qc8w( + xnn_operator_t convolution_op, void* workspace, const int8_t* input, + void* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_setup_convolution2d_nhwc_qd8_f32_qc8w( + xnn_operator_t convolution_op, void* workspace, const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_setup_convolution2d_nhwc_qs8( + xnn_operator_t convolution_op, + void* workspace, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_convolution2d_nhwc_qs8_qc8w( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + int8_t input_zero_point, + float input_scale, + const float* kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nhwc_qs8_qc8w( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nhwc_qs8_qc8w( + xnn_operator_t convolution_op, + void* workspace, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_convolution2d_nhwc_qu8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t subsampling_height, + uint32_t subsampling_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_channel_stride, + size_t output_channel_stride, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* convolution_op_out); + +enum xnn_status xnn_reshape_convolution2d_nhwc_qu8( + xnn_operator_t convolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* workspace_size, + size_t* workspace_alignment, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_convolution2d_nhwc_qu8( + xnn_operator_t convolution_op, + void* workspace, + const uint8_t* input, + uint8_t* output); + +enum xnn_status xnn_create_copy_nc_x8( + uint32_t flags, + xnn_operator_t* copy_op_out); + +enum xnn_status xnn_reshape_copy_nc_x8( + xnn_operator_t copy_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_copy_nc_x8( + xnn_operator_t copy_op, + const void* input, + void* output); + +enum xnn_status xnn_create_copy_nc_x16( + uint32_t flags, + xnn_operator_t* copy_op_out); + +enum xnn_status xnn_reshape_copy_nc_x16( + xnn_operator_t copy_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_copy_nc_x16( + xnn_operator_t copy_op, + const void* input, + void* output); + +enum xnn_status xnn_create_copy_nc_x32( + uint32_t flags, + xnn_operator_t* copy_op_out); + +enum xnn_status xnn_reshape_copy_nc_x32( + xnn_operator_t copy_op, + size_t batch_size, + size_t channels, + size_t input_stride, + size_t output_stride, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_copy_nc_x32( + xnn_operator_t copy_op, + const void* input, + void* output); + +enum xnn_status xnn_run_copy_nc_x32( + size_t channels, + size_t input_stride, + size_t output_stride, + size_t batch_size, + const uint32_t* input, + uint32_t* output, + uint32_t flags, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_deconvolution2d_nhwc_f16( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + const void* kernel, + const void* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_f16( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_deconvolution2d_nhwc_f16( + xnn_operator_t deconvolution_op, + const void* input, + void* output); + +enum xnn_status xnn_create_deconvolution2d_nhwc_f32( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + const float* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_create_deconvolution2d_nhwc_f32_f16( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + const void* kernel, + const void* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_f32( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_deconvolution2d_nhwc_f32( + xnn_operator_t deconvolution_op, + const float* input, + float* output); + +enum xnn_status xnn_create_deconvolution2d_nhwc_qd8_f32_qc8w( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_qd8_f32_qc8w( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_deconvolution2d_nhwc_qd8_f32_qc8w( + xnn_operator_t deconvolution_op, + const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_deconvolution2d_nhwc_qs8( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + int8_t input_zero_point, + float input_scale, + float kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_qs8( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8( + xnn_operator_t deconvolution_op, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_deconvolution2d_nhwc_qs8_qc8w( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + int8_t input_zero_point, + float input_scale, + const float* kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_qs8_qc8w( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_deconvolution2d_nhwc_qs8_qc8w( + xnn_operator_t deconvolution_op, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_deconvolution2d_nhwc_qu8( + uint32_t output_padding_top, + uint32_t output_padding_right, + uint32_t output_padding_bottom, + uint32_t output_padding_left, + uint32_t kernel_height, + uint32_t kernel_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint32_t groups, + size_t group_input_channels, + size_t group_output_channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* deconvolution_op_out); + +enum xnn_status xnn_reshape_deconvolution2d_nhwc_qu8( + xnn_operator_t deconvolution_op, + size_t batch_size, + size_t input_height, + size_t input_width, + uint32_t adjustment_height, + uint32_t adjustment_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_deconvolution2d_nhwc_qu8( + xnn_operator_t deconvolution_op, + const uint8_t* input, + uint8_t* output); + +enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x16( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* depth_to_space_op_out); + +enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x16( + xnn_operator_t depth_to_space_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x16( + xnn_operator_t depth_to_space_op, + const void* input, + void* output); + +enum xnn_status xnn_create_depth_to_space_nchw2nhwc_x32( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* depth_to_space_op_out); + +enum xnn_status xnn_reshape_depth_to_space_nchw2nhwc_x32( + xnn_operator_t depth_to_space_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_depth_to_space_nchw2nhwc_x32( + xnn_operator_t depth_to_space_op, + const void* input, + void* output); + +enum xnn_status xnn_create_depth_to_space_nhwc_x8( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* depth_to_space_op_out); + +enum xnn_status xnn_reshape_depth_to_space_nhwc_x8( + xnn_operator_t depth_to_space_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_depth_to_space_nhwc_x8( + xnn_operator_t depth_to_space_op, + const void* input, + void* output); + +enum xnn_status xnn_create_depth_to_space_nhwc_x16( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* depth_to_space_op_out); + +enum xnn_status xnn_reshape_depth_to_space_nhwc_x16( + xnn_operator_t depth_to_space_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_depth_to_space_nhwc_x16( + xnn_operator_t depth_to_space_op, + const void* input, + void* output); + +enum xnn_status xnn_create_depth_to_space_nhwc_x32( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* depth_to_space_op_out); + +enum xnn_status xnn_reshape_depth_to_space_nhwc_x32( + xnn_operator_t depth_to_space_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_depth_to_space_nhwc_x32( + xnn_operator_t depth_to_space_op, + const void* input, + void* output); + +enum xnn_status xnn_create_dynamic_fully_connected_nc_f16( + float output_min, + float output_max, + uint32_t flags, + xnn_operator_t* dynamic_fully_connected_op_out); + +enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f16( + xnn_operator_t dynamic_fully_connected_op, + size_t batch_size, + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + size_t* workspace_size, + size_t* workspace_alignment, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_dynamic_fully_connected_nc_f16( + xnn_operator_t dynamic_fully_connected_op, + void* workspace, + const void* input, + const void* kernel, + const void* bias, + void* output); + +enum xnn_status xnn_create_dynamic_fully_connected_nc_f32( + float output_min, + float output_max, + uint32_t flags, + xnn_operator_t* dynamic_fully_connected_op_out); + +enum xnn_status xnn_reshape_dynamic_fully_connected_nc_f32( + xnn_operator_t dynamic_fully_connected_op, + size_t batch_size, + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + size_t* workspace_size, + size_t* workspace_alignment, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_dynamic_fully_connected_nc_f32( + xnn_operator_t dynamic_fully_connected_op, + void* workspace, + const float* input, + const float* kernel, + const float* bias, + float* output); + +enum xnn_status xnn_create_fully_connected_nc_f16( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const void* kernel, + const void* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_f16( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_f16( + xnn_operator_t fully_connected_op, + const void* input, + void* output); + +enum xnn_status xnn_create_fully_connected_nc_f32_f16( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const void* kernel, + const void* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_create_fully_connected_nc_f32( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_f32_f16( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_reshape_fully_connected_nc_f32( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_f32_f16( + xnn_operator_t fully_connected_op, + const float* input, + float* output); + +enum xnn_status xnn_setup_fully_connected_nc_f32( + xnn_operator_t fully_connected_op, + const float* input, + float* output); + +enum xnn_status xnn_create_fully_connected_nc_f32_qc4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + uint8_t kernel_zero_point, + const float* kernel_scale, + const uint8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_f32_qc4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_f32_qc4w( + xnn_operator_t fully_connected_op, + const float* input, + float* output); + +enum xnn_status xnn_create_fully_connected_nc_f32_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_f32_qc8w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_f32_qc8w( + xnn_operator_t fully_connected_op, + const float* input, + float* output); + +enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + uint8_t kernel_zero_point, + const float* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc4w( + xnn_operator_t fully_connected_op, + const int8_t* input, + void* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qb4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + size_t block_size, + uint8_t kernel_zero_point, + const uint16_t* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qb4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qb4w( + xnn_operator_t fully_connected_op, + const int8_t* input, + void* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + uint8_t kernel_zero_point, + const float* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc4w( + xnn_operator_t fully_connected_op, + const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qb4w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + size_t block_size, + uint8_t kernel_zero_point, + const uint16_t* kernel_scale, + const void* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qb4w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qb4w( + xnn_operator_t fully_connected_op, + const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_create_fully_connected_nc_qd8_f16_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_setup_fully_connected_nc_qd8_f16_qc8w( + xnn_operator_t fully_connected_op, + const int8_t* input, + void* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_reshape_fully_connected_nc_qd8_f16_qc8w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_fully_connected_nc_qd8_f32_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + const float* kernel_scale, + const int8_t* kernel, + const float* bias, + float output_min, + float output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_setup_fully_connected_nc_qd8_f32_qc8w( + xnn_operator_t fully_connected_op, + const int8_t* input, + float* output, + const struct xnn_quantization_params* quantization_params); + +enum xnn_status xnn_reshape_fully_connected_nc_qd8_f32_qc8w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_fully_connected_nc_qs8( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + int8_t input_zero_point, + float input_scale, + float kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qs8( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qs8( + xnn_operator_t fully_connected_op, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_fully_connected_nc_qs8_qc8w( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + int8_t input_zero_point, + float input_scale, + const float* kernel_scale, + const int8_t* kernel, + const int32_t* bias, + int8_t output_zero_point, + float output_scale, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qs8_qc8w( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qs8_qc8w( + xnn_operator_t fully_connected_op, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_fully_connected_nc_qu8( + size_t input_channels, + size_t output_channels, + size_t input_stride, + size_t output_stride, + uint8_t input_zero_point, + float input_scale, + uint8_t kernel_zero_point, + float kernel_scale, + const uint8_t* kernel, + const int32_t* bias, + uint8_t output_zero_point, + float output_scale, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + xnn_code_cache_t code_cache, + xnn_weights_cache_t weights_cache, + xnn_operator_t* fully_connected_op_out); + +enum xnn_status xnn_reshape_fully_connected_nc_qu8( + xnn_operator_t fully_connected_op, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_fully_connected_nc_qu8( + xnn_operator_t fully_connected_op, + const uint8_t* input, + uint8_t* output); + + +enum xnn_status xnn_create_max_pooling2d_nhwc_f16( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + float output_min, + float output_max, + uint32_t flags, + xnn_operator_t* max_pooling_op_out); + +enum xnn_status xnn_reshape_max_pooling2d_nhwc_f16( + xnn_operator_t max_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_max_pooling2d_nhwc_f16( + xnn_operator_t max_pooling_op, + const void* input, + void* output); + +enum xnn_status xnn_create_max_pooling2d_nhwc_f32( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + float output_min, + float output_max, + uint32_t flags, + xnn_operator_t* max_pooling_op_out); + +enum xnn_status xnn_reshape_max_pooling2d_nhwc_f32( + xnn_operator_t max_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_max_pooling2d_nhwc_f32( + xnn_operator_t max_pooling_op, + const float* input, + float* output); + +enum xnn_status xnn_create_max_pooling2d_nhwc_s8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + int8_t output_min, + int8_t output_max, + uint32_t flags, + xnn_operator_t* max_pooling_op_out); + +enum xnn_status xnn_reshape_max_pooling2d_nhwc_s8( + xnn_operator_t max_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_max_pooling2d_nhwc_s8( + xnn_operator_t max_pooling_op, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_max_pooling2d_nhwc_u8( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + uint32_t stride_height, + uint32_t stride_width, + uint32_t dilation_height, + uint32_t dilation_width, + uint8_t output_min, + uint8_t output_max, + uint32_t flags, + xnn_operator_t* max_pooling_op_out); + +enum xnn_status xnn_reshape_max_pooling2d_nhwc_u8( + xnn_operator_t max_pooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_max_pooling2d_nhwc_u8( + xnn_operator_t max_pooling_op, + const uint8_t* input, + uint8_t* output); + +enum xnn_status xnn_create_reduce_nd( + enum xnn_reduce_operator reduce_operator_type, + enum xnn_datatype datatype, + const struct xnn_quantization_params* input_quantization, + const struct xnn_quantization_params* output_quantization, + uint32_t flags, + xnn_operator_t* reduce_op_out); + +enum xnn_status xnn_reshape_reduce_nd( // + xnn_operator_t reduce_op, // + size_t num_reduction_axes, // + const int64_t* reduction_axes, // + size_t num_input_dims, // + const size_t* input_shape, // + size_t* workspace_size, // + size_t* workspace_alignment, // + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_reduce_nd( + xnn_operator_t reduce_op, + void* workspace, + const void* input, + void* output); + +enum xnn_status xnn_create_resize_bilinear2d_nchw_f32( + size_t output_height, + size_t output_width, + uint32_t flags, + xnn_operator_t* resize_op_out); + +enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f32( + xnn_operator_t resize_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_resize_bilinear2d_nchw_f32( + xnn_operator_t resize_op, + const float* input, + float* output); + +enum xnn_status xnn_create_resize_bilinear2d_nchw_f16( + size_t output_height, + size_t output_width, + uint32_t flags, + xnn_operator_t* resize_op_out); + +enum xnn_status xnn_reshape_resize_bilinear2d_nchw_f16( + xnn_operator_t resize_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_resize_bilinear2d_nchw_f16( + xnn_operator_t resize_op, + const void* input, + void* output); + +enum xnn_status xnn_create_resize_bilinear2d_nhwc_f16( + size_t output_height, + size_t output_width, + uint32_t flags, + xnn_operator_t* resize_op_out); + +enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f16( + xnn_operator_t resize_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace_alignment, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f16( + xnn_operator_t resize_op, + void* workspace, + const void* input, + void* output); + +enum xnn_status xnn_create_resize_bilinear2d_nhwc_f32( + size_t output_height, + size_t output_width, + uint32_t flags, + xnn_operator_t* resize_op_out); + +enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_f32( + xnn_operator_t resize_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace_alignment, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_resize_bilinear2d_nhwc_f32( + xnn_operator_t resize_op, + void* workspace, + const float* input, + float* output); + +enum xnn_status xnn_create_resize_bilinear2d_nhwc_s8( + size_t output_height, + size_t output_width, + uint32_t flags, + xnn_operator_t* resize_op_out); + +enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_s8( + xnn_operator_t resize_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_resize_bilinear2d_nhwc_s8( + xnn_operator_t resize_op, + void* workspace, + const int8_t* input, + int8_t* output); + +enum xnn_status xnn_create_resize_bilinear2d_nhwc_u8( + size_t output_height, + size_t output_width, + uint32_t flags, + xnn_operator_t* resize_op_out); + +enum xnn_status xnn_reshape_resize_bilinear2d_nhwc_u8( + xnn_operator_t resize_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + size_t* workspace_size, + size_t* workspace_alignment, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_resize_bilinear2d_nhwc_u8( + xnn_operator_t resize_op, + void* workspace, + const uint8_t* input, + uint8_t* output); + +enum xnn_status xnn_create_rope_nthc_f16( + uint32_t flags, + xnn_operator_t* rope_op_out); + +enum xnn_status xnn_reshape_rope_nthc_f16( + xnn_operator_t rope_op, + size_t batch_size, + size_t tokens, + size_t heads, + size_t channels, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_rope_nthc_f16( + xnn_operator_t rope_op, + const void* input, + const void* weights, + void* output); + +enum xnn_status xnn_create_rope_nthc_f32( + uint32_t flags, + xnn_operator_t* rope_op_out); + +enum xnn_status xnn_reshape_rope_nthc_f32( + xnn_operator_t rope_op, + size_t batch_size, + size_t tokens, + size_t heads, + size_t channels, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_rope_nthc_f32( + xnn_operator_t rope_op, + const float* input, + const float* weights, + float* output); + +// N: batch size +// H: number of heads +// T: tokens (sequence length) +// C: channels (head dimension) +enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f16( + enum xnn_attention_logits_cap_type cap_type, + const void* cap_params, + uint32_t flags, + xnn_operator_t* attention_op_out); + +enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f16( + xnn_operator_t attention_op, + size_t batch_size, + size_t query_heads, + // Number of tokens in query. + size_t query_tokens, + size_t key_value_heads, + // Number of tokens in key/value. For self-attention, this is same as tokens. + size_t key_value_tokens, + size_t query_key_channels, + size_t value_channels, + size_t* workspace_size, + size_t* workspace_alignment, + pthreadpool_t threadpool); + +// Query is of dimension [batch_size, query_heads, query_tokens, channels]. +// Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, channels]. +// Scale is of dimension [channels]. +// Mask is of dimension [query_tokens, key_value_tokens]. +enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f16( + xnn_operator_t attention_op, + void* workspace, + const void* query, + const void* key, + const void* value, + const void* scale, + const void* mask, + void* output); + +// N: batch size +// H: number of heads +// T: tokens (sequence length) +// C: channels (head dimension) +enum xnn_status xnn_create_scaled_dot_product_attention_nhtc_f32( + enum xnn_attention_logits_cap_type cap_type, + const void* cap_params, + uint32_t flags, + xnn_operator_t* attention_op_out); + +enum xnn_status xnn_reshape_scaled_dot_product_attention_nhtc_f32( + xnn_operator_t attention_op, + size_t batch_size, + size_t query_heads, + // Number of tokens in query. + size_t query_tokens, + size_t key_value_heads, + // Number of tokens in key/value. For self-attention, this is same as tokens. + size_t key_value_tokens, + size_t query_key_channels, + size_t value_channels, + size_t* workspace_size, + size_t* workspace_alignment, + pthreadpool_t threadpool); + +// Query is of dimension [batch_size, query_heads, query_tokens, query_key_channels]. +// Key and value are of dimension [batch_size, key_value_heads, key_value_tokens, query_key_channels]. +// Scale is of dimension [query_key_channels]. +// Mask is of dimension [query_tokens, key_value_tokens]. +// Output is of dimension [batch_size, query_heads, query_tokens, value_channels]. +enum xnn_status xnn_setup_scaled_dot_product_attention_nhtc_f32( + xnn_operator_t attention_op, + void* workspace, + const float* query, + const float* key, + const float* value, + const float* scale, + const float* mask, + float* output); + + +enum xnn_status xnn_create_slice_nd_x16( + uint32_t flags, + xnn_operator_t* slice_op_out); + +enum xnn_status xnn_reshape_slice_nd_x16( + xnn_operator_t slice_op, + size_t num_dims, + const size_t* input_shape, + const size_t* offsets, + const size_t* sizes, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_slice_nd_x16( + xnn_operator_t slice_op, + const void* input, + void* output); + +enum xnn_status xnn_create_slice_nd_x32( + uint32_t flags, + xnn_operator_t* slice_op_out); + +enum xnn_status xnn_reshape_slice_nd_x32( + xnn_operator_t slice_op, + size_t num_dims, + const size_t* input_shape, + const size_t* offsets, + const size_t* sizes, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_slice_nd_x32( + xnn_operator_t slice_op, + const void* input, + void* output); + +enum xnn_status xnn_run_slice_nd_x32( + size_t num_dims, + const size_t* input_shape, + const size_t* offsets, + const size_t* sizes, + const void* input, + void* output, + uint32_t flags, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_softmax_nc_f16( + uint32_t flags, + xnn_operator_t* softmax_op_out); + +enum xnn_status xnn_reshape_softmax_nc_f16( + xnn_operator_t softmax_op, + size_t channels, + size_t input_stride, + size_t output_stride, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_softmax_nc_f16( + xnn_operator_t softmax_op, + const void* input, + void* output); + +enum xnn_status xnn_create_softmax_nc_f32( + uint32_t flags, + xnn_operator_t* softmax_op_out); + +enum xnn_status xnn_reshape_softmax_nc_f32( + xnn_operator_t softmax_op, + size_t channels, + size_t input_stride, + size_t output_stride, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_softmax_nc_f32( + xnn_operator_t softmax_op, + const float* input, + float* output); + +enum xnn_status xnn_create_softmax_nc_qu8( + float input_scale, + uint8_t output_zero_point, + float output_scale, + uint32_t flags, + xnn_operator_t* softmax_op_out); + +enum xnn_status xnn_reshape_softmax_nc_qu8( + xnn_operator_t softmax_op, + size_t channels, + size_t input_stride, + size_t output_stride, + size_t batch_size, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_softmax_nc_qu8( + xnn_operator_t softmax_op, + const uint8_t* input, + uint8_t* output); + +enum xnn_status xnn_create_space_to_depth_nhwc_x16( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* space_to_depth_op_out); + +enum xnn_status xnn_reshape_space_to_depth_nhwc_x16( + xnn_operator_t space_to_depth_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_space_to_depth_nhwc_x16( + xnn_operator_t space_to_depth_op, + const void* input, + void* output); + +enum xnn_status xnn_create_space_to_depth_nhwc_x32( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* space_to_depth_op_out); + +enum xnn_status xnn_reshape_space_to_depth_nhwc_x32( + xnn_operator_t space_to_depth_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_space_to_depth_nhwc_x32( + xnn_operator_t space_to_depth_op, + const void* input, + void* output); + +enum xnn_status xnn_create_transpose_nd_x8( + uint32_t flags, + xnn_operator_t* transpose_op_out); + +enum xnn_status xnn_reshape_transpose_nd_x8( + xnn_operator_t transpose_op, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_transpose_nd_x8( + xnn_operator_t transpose_op, + const void* input, + void* output); + +enum xnn_status xnn_run_transpose_nd_x8( + const void* input, + void* output, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + uint32_t flags, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_transpose_nd_x16( + uint32_t flags, + xnn_operator_t* transpose_op_out); + +enum xnn_status xnn_reshape_transpose_nd_x16( + xnn_operator_t transpose_op, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_transpose_nd_x16( + xnn_operator_t transpose_op, + const void* input, + void* output); + +enum xnn_status xnn_run_transpose_nd_x16( + const void* input, + void* output, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + uint32_t flags, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_transpose_nd_x32( + uint32_t flags, + xnn_operator_t* transpose_op_out); + +enum xnn_status xnn_reshape_transpose_nd_x32( + xnn_operator_t transpose_op, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_transpose_nd_x32( + xnn_operator_t transpose_op, + const void* input, + void* output); + +enum xnn_status xnn_run_transpose_nd_x32( + const void* input, + void* output, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + uint32_t flags, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_transpose_nd_x64( + uint32_t flags, + xnn_operator_t* transpose_op_out); + +enum xnn_status xnn_reshape_transpose_nd_x64( + xnn_operator_t transpose_op, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_transpose_nd_x64( + xnn_operator_t transpose_op, + const void* input, + void* output); + +enum xnn_status xnn_run_transpose_nd_x64( + const void* input, + void* output, + size_t num_dims, + const size_t* input_shape, + const size_t* output_perm, + uint32_t flags, + pthreadpool_t threadpool); + +enum xnn_status xnn_create_unpooling2d_nhwc_x32( + uint32_t input_padding_top, + uint32_t input_padding_right, + uint32_t input_padding_bottom, + uint32_t input_padding_left, + uint32_t pooling_height, + uint32_t pooling_width, + size_t channels, + size_t input_pixel_stride, + size_t output_pixel_stride, + uint32_t flags, + xnn_operator_t* unpooling_op_out); + +enum xnn_status xnn_reshape_unpooling2d_nhwc_x32( + xnn_operator_t unpooling_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t* output_height_out, + size_t* output_width_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_unpooling2d_nhwc_x32( + xnn_operator_t unpooling_op, + const void* input, + const uint32_t* index, + void* output); + +enum xnn_status xnn_create_slice_nd_x8( + uint32_t flags, + xnn_operator_t* slice_op_out); + +enum xnn_status xnn_reshape_slice_nd_x8( + xnn_operator_t slice_op, + size_t num_dims, + const size_t* input_shape, + const size_t* offsets, + const size_t* sizes, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_slice_nd_x8( + xnn_operator_t slice_op, + const void* input, + void* output); + +enum xnn_status xnn_create_space_to_depth_nhwc_x8( + uint32_t block_size, + uint32_t flags, + xnn_operator_t* space_to_depth_op_out); + +enum xnn_status xnn_reshape_space_to_depth_nhwc_x8( + xnn_operator_t space_to_depth_op, + size_t batch_size, + size_t input_height, + size_t input_width, + size_t input_channels, + size_t* output_height_out, + size_t* output_width_out, + size_t* output_channels_out, + pthreadpool_t threadpool); + +enum xnn_status xnn_setup_space_to_depth_nhwc_x8( + xnn_operator_t space_to_depth_op, + const void* input, + void* output); + +#ifdef __cplusplus +} // extern "C" +#endif